Source code for ralph.pipeline.fan_out

"""Fan-out parallel execution for the pipeline runner."""

from __future__ import annotations

import asyncio
import json
import sys
import uuid
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol, cast

from loguru import logger

from ralph.agents.registry import AgentRegistry
from ralph.agents.subprocess_executor import SubprocessAgentExecutor
from ralph.display.context import make_display_context
from ralph.executor.process import run_process_async
from ralph.interrupt.asyncio_bridge import SignalBridge, install_signal_handlers
from ralph.mcp.artifacts.handoffs import sync_markdown_handoff
from ralph.mcp.artifacts.store import list_artifacts
from ralph.mcp.server.factory_impl import DynamicBindingMcpServerFactory
from ralph.mcp.session_plan import SessionMcpPlan, SessionModelOpts, build_session_mcp_plan
from ralph.pipeline import checkpoint as ckpt
from ralph.pipeline.effect_router import config_agents_for_phase as _config_agents_for_phase
from ralph.pipeline.effects import FanOutEffect
from ralph.pipeline.events import (
    PhaseFailureEvent,
    PipelineEvent,
    PostFanoutVerificationEvent,
    WorkerFailedEvent,
)
from ralph.pipeline.legacy_console_display import (
    LegacyConsoleDisplay,
    _parallel_display_cls,
)
from ralph.pipeline.parallel import coordinator
from ralph.pipeline.parallel.mode import SameWorkspaceContext
from ralph.pipeline.parallel.worker_manifest import ParallelWorkerManifest
from ralph.pipeline.parallel.worker_runtime import build_worker_runtime_paths
from ralph.pipeline.reducer import reduce as reducer_reduce
from ralph.pipeline.verification_result import VerificationResult
from ralph.pipeline.work_units import (
    WorkUnitsPlan,
    WorkUnitsValidationError,
    validate_for_same_workspace,
)
from ralph.pipeline.worker_state import WorkerStatus
from ralph.policy.loader import load_agents_policy_for_workspace_scope
from ralph.policy.validation import PolicyValidationError
from ralph.workspace import FsWorkspace

if TYPE_CHECKING:
    from collections.abc import Callable
    from pathlib import Path

    from rich.console import Console

    from ralph.agents.executor import AgentExecutor
    from ralph.config.enums import AgentTransport
    from ralph.config.models import UnifiedConfig
    from ralph.display.parallel_display import ParallelDisplay
    from ralph.executor.process import ProcessResult
    from ralph.mcp.server.factory import McpServerFactory
    from ralph.pipeline.parallel import coordinator as parallel_coordinator
    from ralph.pipeline.state import PipelineState
    from ralph.pipeline.work_units import WorkUnit
    from ralph.policy.models import PipelinePolicy, PolicyBundle
    from ralph.workspace.scope import WorkspaceScope


if TYPE_CHECKING:
    # Structural (duck-typed) shapes for the collaborators this module accepts.
    # They are declared under TYPE_CHECKING, so they exist only for the type
    # checker; at runtime they are referenced solely through string annotations
    # and string arguments to ``cast``.

    class _PipelineSubscriberLike(Protocol):
        """Any object exposing ``notify(state)`` to receive pipeline state updates."""

        def notify(self, state: PipelineState) -> None: ...

    class _InstallSignalHandlersFn(Protocol):
        """Callable that installs signal handlers; accepts any arguments, returns nothing."""

        def __call__(self, *args: object, **kwargs: object) -> None: ...

    class _ExecutorFactory(Protocol):
        """Callable (typically a class) that produces an ``AgentExecutor``."""

        def __call__(self, *args: object, **kwargs: object) -> AgentExecutor: ...

    class _McpFactory(Protocol):
        """Callable (typically a class) that produces an ``McpServerFactory``."""

        def __call__(self, *args: object, **kwargs: object) -> McpServerFactory: ...

    class _RunProcessAsyncFn(Protocol):
        """Awaitable callable that runs a process and yields a ``ProcessResult``."""

        async def __call__(self, *args: object, **kwargs: object) -> ProcessResult: ...

    class _ReducerReduceFn(Protocol):
        """Reducer callable returning ``(new_state, <second value>)``."""

        def __call__(self, *args: object, **kwargs: object) -> tuple[PipelineState, object]: ...


@dataclass(frozen=True)
class _FanOutCtx:
    """Immutable bundle of everything one fan-out run needs.

    Built by ``execute_fan_out_sync`` and consumed by ``_run_fan_out_async`` and
    ``_run_verify_phase``. The trailing ``*_fn`` / ``*_cls`` fields are optional
    overrides; when left as ``None`` the consumers fall back to the real
    implementations imported at module level.
    """

    # The fan-out effect being executed (work units, phase, max_workers, ...).
    effect: FanOutEffect
    # Pipeline state at the moment the fan-out starts.
    state: PipelineState
    # Display handed to the parallel coordinator.
    display: ParallelDisplay
    # Policy bundle; ``.pipeline`` feeds the reducer, ``.agents`` the MCP plan.
    policy_bundle: PolicyBundle
    # Workspace scope; ``.root`` anchors the ``.agent`` directories.
    workspace_scope: WorkspaceScope
    # Root under which ``.agent/workers`` is created
    # (``execute_fan_out_sync`` passes ``workspace_scope.root``).
    repo_root: Path
    # Notified after each state transition; ``None`` disables notifications.
    pipeline_subscriber: _PipelineSubscriberLike | None
    # Unified config used for agent lookup; may be absent.
    config: UnifiedConfig | None
    # Recorded (as a string) in every worker manifest when present.
    config_path: Path | None
    # Copied into every worker manifest; ``None`` is written as an empty dict.
    cli_overrides: dict[str, object] | None
    # When set, assigned to the signal bridge's ``_connectivity_stop`` attribute.
    monitor_stop_cb: Callable[[], None] | None
    # Override for ``install_signal_handlers``.
    install_signal_handlers_fn: _InstallSignalHandlersFn | None = None
    # Override for ``SubprocessAgentExecutor``.
    executor_cls: _ExecutorFactory | None = None
    # Override for ``DynamicBindingMcpServerFactory``.
    mcp_factory_cls: _McpFactory | None = None
    # Override for ``run_process_async`` (used by post-fanout verification).
    run_process_async_fn: _RunProcessAsyncFn | None = None
    # Override for the pipeline reducer's ``reduce``.
    reducer_reduce_fn: _ReducerReduceFn | None = None


def _notify_subscriber(subscriber: _PipelineSubscriberLike | None, state: PipelineState) -> None:
    """Forward *state* to *subscriber*; do nothing when no subscriber is attached."""
    if subscriber is None:
        return
    subscriber.notify(state)


def _save_checkpoint_or_log(state: PipelineState, *, message: str) -> None:
    """Persist *state* as a checkpoint, logging instead of raising on failure.

    Checkpointing is best-effort: any ``Exception`` raised by the save is logged
    (with traceback) and swallowed so the caller's control flow is unaffected.

    Args:
        state: Pipeline state to checkpoint.
        message: Log template used on failure; it is given ``phase`` and ``err``
            as formatting fields.
    """
    try:
        ckpt.save(state)
    except Exception as save_error:
        failed_phase = state.phase
        logger.exception(message, phase=failed_phase, err=save_error)


def write_parallel_development_summary(
    workspace_scope: WorkspaceScope,
    effect: FanOutEffect,
    state: PipelineState,
    verification: VerificationResult | None = None,
) -> None:
    """Write .agent/artifacts/parallel_development_summary.json after fan-out completes.

    One entry is emitted per work unit in ``effect.work_units`` with its outcome
    (``succeeded`` / ``failed`` / ``cancelled`` / ``blocked``), the number of
    artifacts found in the worker's artifact directory, and its final message.
    When verification ran and did not pass, a synthetic ``__verify__`` entry is
    appended and the aggregate flags are forced to "failed". The same summary is
    then passed to ``sync_markdown_handoff``.

    Args:
        workspace_scope: Scope whose ``root`` contains the ``.agent`` directory.
        effect: Fan-out effect listing the work units to report on.
        state: Pipeline state holding the per-worker states.
        verification: Post-fanout verification outcome; when omitted, a
            "did not run" result is recorded.
    """
    result = verification or VerificationResult(ran=False, passed=None, exit_code=None)
    agent_root = workspace_scope.root / ".agent"
    unsuccessful = ("failed", "cancelled", "blocked")

    entries: list[dict[str, object]] = []
    for unit in effect.work_units:
        unit_id = unit.unit_id
        worker_state = state.worker_states.get(unit_id)

        artifact_dir = agent_root / "workers" / unit_id / "artifacts"
        artifact_count = 0
        if artifact_dir.exists():
            artifact_count = len(list_artifacts(artifact_dir))

        final_message: str | None
        if worker_state is None:
            status, final_message = "failed", "Worker state not recorded"
        elif worker_state.status == WorkerStatus.SUCCEEDED:
            status, final_message = "succeeded", None
        else:
            # Every non-success outcome reports the worker's own error message.
            final_message = worker_state.error_message
            if worker_state.status == WorkerStatus.CANCELLED:
                status = "cancelled"
            elif worker_state.status == WorkerStatus.FAILED and (
                final_message or ""
            ).startswith("Blocked by"):
                # A failure caused by an upstream dependency is reported as blocked.
                status = "blocked"
            else:
                status = "failed"

        entries.append(
            {
                "unit_id": unit_id,
                "status": status,
                "artifact_count": artifact_count,
                "final_message": final_message,
            }
        )

    any_failed = any(entry["status"] in unsuccessful for entry in entries)
    all_succeeded = len(entries) > 0 and not any_failed

    if result.ran and not result.passed:
        entries.append(
            {
                "unit_id": "__verify__",
                "status": "failed",
                "artifact_count": 0,
                "final_message": "workspace verification failed",
            }
        )
        any_failed, all_succeeded = True, False

    summary: dict[str, object] = {
        "workers": entries,
        "any_failed": any_failed,
        "all_succeeded": all_succeeded,
        "verification": {
            "ran": result.ran,
            "passed": result.passed,
            "exit_code": result.exit_code,
        },
    }

    summary_path = agent_root / "artifacts" / "parallel_development_summary.json"
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
    logger.debug(
        "Wrote parallel_development_summary.json: any_failed={f} all_succeeded={s}",
        f=any_failed,
        s=all_succeeded,
    )
    sync_markdown_handoff(workspace_scope.root, "parallel_development_summary", summary)
def _fan_out_display_and_subscriber(
    display: ParallelDisplay | LegacyConsoleDisplay,
    pipeline_subscriber: _PipelineSubscriberLike | None,
    dashboard_subscriber: _PipelineSubscriberLike | None,
) -> tuple[ParallelDisplay, _PipelineSubscriberLike | None]:
    """Resolve the parallel display and the subscriber to use for a fan-out run.

    Args:
        display: Display supplied by the caller. Used as-is when it is already an
            instance of the parallel display class; otherwise a parallel display
            is constructed from it.
        pipeline_subscriber: Fallback subscriber.
        dashboard_subscriber: Preferred subscriber when truthy.

    Returns:
        ``(parallel_display, subscriber)``. The subscriber is the dashboard one,
        else the pipeline one, else the display's own ``subscriber`` attribute
        when it has one, else ``None``.
    """
    parallel_display_cls = _parallel_display_cls()
    if isinstance(display, parallel_display_cls):
        parallel_display = display
    elif isinstance(display, LegacyConsoleDisplay):
        # Reuse the legacy display's existing context.
        parallel_display = parallel_display_cls(display._ctx)
    else:
        # Unknown display type: build a fresh context around its console, if any.
        console = cast("Console | None", getattr(display, "console", None))
        parallel_display = parallel_display_cls(make_display_context(console=console))
    effective_subscriber = dashboard_subscriber or pipeline_subscriber
    if effective_subscriber is None and hasattr(parallel_display, "subscriber"):
        effective_subscriber = cast(
            "_PipelineSubscriberLike | None",
            getattr(parallel_display, "subscriber", None),
        )
    return parallel_display, effective_subscriber


def _build_session_mcp_plan_for_phase(
    effect: FanOutEffect,
    policy_bundle: PolicyBundle,
    workspace_scope: WorkspaceScope,
    config: UnifiedConfig | None,
) -> tuple[SessionMcpPlan, str]:
    """Build session MCP plan for fan-out workers matching the serial execution contract.

    The drain is the first truthy value among: the effect's ``drain`` attribute,
    the phase definition's ``drain``, the effect's phase name, and the literal
    ``"development"``. The agent whose ``transport`` / ``model_flag`` are used is
    only looked up when the phase is defined in the pipeline policy and a config
    is available.

    Returns:
        ``(session_mcp_plan, drain)``.
    """
    phase_def = policy_bundle.pipeline.phases.get(effect.phase)
    _effect_drain = cast("str | None", getattr(effect, "drain", None))
    drain: str = (
        cast("str", _effect_drain)
        or (phase_def.drain if phase_def and hasattr(phase_def, "drain") else None)
        or effect.phase
        or "development"
    )
    agent_name: str | None = None
    if phase_def is not None:
        config_agents = _config_agents_for_phase(
            config,
            phase=effect.phase,
            policy_drain=drain,
        )
        if config_agents:
            agent_name = config_agents[0]
        # NOTE(review): this ``else`` is read as pairing with ``if config_agents``
        # (config agents first, policy chain as fallback) — confirm nesting.
        else:
            drain_binding = policy_bundle.agents.agent_drains.get(drain)
            if drain_binding is not None:
                chain_config = policy_bundle.agents.agent_chains.get(drain_binding.chain)
                if chain_config is not None and chain_config.agents:
                    agent_name = chain_config.agents[0]
    agent_config = None
    if isinstance(agent_name, str) and agent_name and config is not None:
        registry = AgentRegistry.from_config(config)
        agent_config = registry.get(agent_name)
    # Without an agent config both values stay None.
    _transport_raw = cast("object", getattr(agent_config, "transport", None))
    transport = cast("AgentTransport | None", _transport_raw) if agent_config is not None else None
    _model_flag_raw = cast("object", getattr(agent_config, "model_flag", None))
    model_flag = cast("str | None", _model_flag_raw) if agent_config is not None else None
    # NOTE(review): ``policy_bundle.pipeline`` was already dereferenced above, so
    # ``policy_bundle`` cannot be None here and the ``else`` arm is unreachable.
    effective_agents_policy = (
        policy_bundle.agents
        if policy_bundle is not None
        else load_agents_policy_for_workspace_scope(workspace_scope, config=config)
    )
    try:
        return build_session_mcp_plan(
            transport=transport,
            drain=drain,
            workspace_path=workspace_scope.root,
            agents_policy=effective_agents_policy,
            model_opts=SessionModelOpts(model_flag=model_flag),
        ), drain
    except PolicyValidationError:
        # The bundled agents policy was rejected: retry once with the policy
        # loaded for the workspace scope. A second failure propagates.
        fallback_agents_policy = load_agents_policy_for_workspace_scope(
            workspace_scope, config=config
        )
        return build_session_mcp_plan(
            transport=transport,
            drain=drain,
            workspace_path=workspace_scope.root,
            agents_policy=fallback_agents_policy,
            model_opts=SessionModelOpts(model_flag=model_flag),
        ), drain


def _fan_out_worker_context(
    *,
    workspace_scope: WorkspaceScope,
    repo_root: Path,
    bridge: SignalBridge,
    session_drain: str,
    worker_commands: dict[str, tuple[str, ...]],
    worker_manifest_paths: dict[str, Path],
    session_mcp_plan: SessionMcpPlan,
    executor_cls: _ExecutorFactory | None = None,
    mcp_factory_cls: _McpFactory | None = None,
) -> tuple[AgentExecutor, parallel_coordinator.WorkerContext]:
    """Create the executor and the coordinator ``WorkerContext`` for a fan-out run.

    Side effect: creates ``<repo_root>/.agent/workers`` if it does not exist.

    Args:
        workspace_scope: Supplies the workspace root and its allowed roots.
        repo_root: Root under which the worker namespace directory lives.
        bridge: Signal bridge shared by the executor and the worker context.
        session_drain: Drain name recorded in the same-workspace context.
        worker_commands: Per-unit command lines.
        worker_manifest_paths: Per-unit manifest file paths.
        session_mcp_plan: Source of capabilities, model identity and capability
            profile for the session.
        executor_cls: Executor factory override; defaults to
            ``SubprocessAgentExecutor``.
        mcp_factory_cls: MCP factory override; defaults to
            ``DynamicBindingMcpServerFactory``.

    Returns:
        ``(executor, worker_context)``.
    """
    _executor_cls = (
        executor_cls
        if executor_cls is not None
        else cast("_ExecutorFactory", SubprocessAgentExecutor)
    )
    _mcp_factory_cls = (
        mcp_factory_cls
        if mcp_factory_cls is not None
        else cast("_McpFactory", DynamicBindingMcpServerFactory)
    )
    executor = _executor_cls(_parallel_worker_command(), signal_bridge=bridge)
    workspace = FsWorkspace(
        workspace_scope.root,
        allowed_roots=workspace_scope.allowed_roots,
    )
    worker_namespace_root = repo_root / ".agent" / "workers"
    worker_namespace_root.mkdir(parents=True, exist_ok=True)
    return executor, coordinator.WorkerContext(
        log=coordinator.WorkerLog(
            log_dir=workspace_scope.root / ".agent" / "logs",
            # A fresh random run id is generated for every call.
            run_id=str(uuid.uuid4()),
        ),
        same_workspace=SameWorkspaceContext(
            repo_root=repo_root,
            mcp_factory=_mcp_factory_cls(workspace=workspace),
            executor_command=_parallel_worker_command(),
            worker_commands=worker_commands,
            signal_bridge=bridge,
            worker_namespace_root=worker_namespace_root,
            worker_manifest_paths=worker_manifest_paths,
            session_drain=session_drain,
            session_capabilities=session_mcp_plan.capabilities,
            session_model_identity=session_mcp_plan.model_identity,
            session_capability_profile=session_mcp_plan.capability_profile,
        ),
    )


def _persist_parallel_worker_manifests(
    *,
    effect: FanOutEffect,
    repo_root: Path,
    session_drain: str,
    config_path: Path | None = None,
    cli_overrides: dict[str, object] | None = None,
) -> dict[str, Path]:
    """Write one ``worker-manifest.json`` per work unit and return their paths.

    For every unit in ``effect.work_units`` the directory
    ``<repo_root>/.agent/workers/<unit_id>`` is created and a
    ``ParallelWorkerManifest`` is serialized into it as JSON.

    Returns:
        Mapping of unit id to the path of that unit's manifest file.
    """
    worker_namespace_root = repo_root / ".agent" / "workers"
    manifests: dict[str, Path] = {}
    for unit in effect.work_units:
        worker_namespace = worker_namespace_root / unit.unit_id
        worker_namespace.mkdir(parents=True, exist_ok=True)
        runtime_paths = build_worker_runtime_paths(
            workspace_root=repo_root,
            worker_namespace=worker_namespace,
            phase=effect.phase,
        )
        manifest = ParallelWorkerManifest(
            unit_id=unit.unit_id,
            description=unit.description,
            allowed_directories=list(unit.allowed_directories),
            phase=effect.phase,
            drain=session_drain,
            config_path=str(config_path) if config_path is not None else None,
            # Copy so the manifest never aliases the caller's dict.
            cli_overrides=dict(cli_overrides or {}),
            worker_namespace=str(worker_namespace),
            worker_artifact_dir=str(worker_namespace / "artifacts"),
            prompt_file=str(runtime_paths.prompt_dump_path),
            workspace_root=str(repo_root),
        )
        manifest_path = worker_namespace / "worker-manifest.json"
        manifest_path.write_text(manifest.model_dump_json(indent=2), encoding="utf-8")
        manifests[unit.unit_id] = manifest_path
    return manifests


def _worker_commands_from_manifests(
    manifest_paths: dict[str, Path],
) -> dict[str, tuple[str, ...]]:
    """Map each unit id to the worker command line that points at its manifest."""
    return {
        unit_id: _parallel_worker_command(manifest_path)
        for unit_id, manifest_path in manifest_paths.items()
    }


def _resume_fan_out_state(
    state: PipelineState,
    effect: FanOutEffect,
    pipeline_policy: PipelinePolicy,
    subscriber: _PipelineSubscriberLike | None,
    *,
    reducer_reduce_fn: _ReducerReduceFn | None = None,
) -> tuple[PipelineState, tuple[WorkUnit, ...]]:
    """Apply the ``WORKERS_RESUMED`` event and select the units that still need to run.

    Args:
        state: State to reduce.
        effect: Fan-out effect whose work units are filtered.
        pipeline_policy: Policy passed through to the reducer.
        subscriber: Notified with the resumed state when not ``None``.
        reducer_reduce_fn: Reducer override; defaults to the pipeline reducer.

    Returns:
        ``(resumed_state, resume_units)`` where ``resume_units`` are the effect's
        units whose recorded worker status is not ``SUCCEEDED`` (units with no
        recorded state are included).
    """
    _reduce = (
        reducer_reduce_fn
        if reducer_reduce_fn is not None
        else cast("_ReducerReduceFn", reducer_reduce)
    )
    resumed_state, _ = _reduce(state, PipelineEvent.WORKERS_RESUMED, pipeline_policy)
    _notify_subscriber(subscriber, resumed_state)
    completed_ids = {
        uid for uid, ws in resumed_state.worker_states.items() if ws.status == WorkerStatus.SUCCEEDED
    }
    resume_units = tuple(u for u in effect.work_units if u.unit_id not in completed_ids)
    return resumed_state, resume_units


async def _run_post_fanout_verification(
    workspace_scope: WorkspaceScope,
    *,
    run_process_async_fn: _RunProcessAsyncFn | None = None,
) -> str | None:
    """Run workspace-wide verification exactly once after all workers complete.

    Invokes ``make -C <workspace root>/ralph-workflow verify``.

    Args:
        workspace_scope: Scope whose ``root`` contains the ``ralph-workflow``
            directory.
        run_process_async_fn: Process runner override; defaults to
            ``run_process_async``.

    Returns:
        ``None`` when the command exits with 0, otherwise an error message that
        includes the exit code and the stripped stderr (or stdout when stderr is
        empty).
    """
    logger.debug("Running post-fanout workspace-wide verification (serialized)")
    _run = (
        run_process_async_fn
        if run_process_async_fn is not None
        else cast("_RunProcessAsyncFn", run_process_async)
    )
    verify_result = await _run(
        "make",
        ["-C", str(workspace_scope.root / "ralph-workflow"), "verify"],
    )
    if verify_result.returncode != 0:
        return (
            f"Post-fanout workspace verification failed "
            f"(exit {verify_result.returncode}): "
            f"{verify_result.stderr.strip() or verify_result.stdout.strip()}"
        )
    return None


async def _run_verify_phase(
    ctx: _FanOutCtx, current: PipelineState, any_worker_failed: bool
) -> tuple[PipelineState, VerificationResult]:
    """Optionally run post-fanout verification and fold its outcome into the state.

    Verification is skipped (state returned unchanged, result marked as not run)
    when the effect does not request it or when any worker failed. Otherwise the
    verification event is reduced, the subscriber notified and a checkpoint saved.

    Returns:
        ``(state, verification_result)``.
    """
    if not ctx.effect.run_post_fanout_verification:
        return current, VerificationResult(ran=False, passed=None, exit_code=None)
    if any_worker_failed:
        logger.debug("Post-fanout verification skipped: one or more workers failed in this wave")
        return current, VerificationResult(ran=False, passed=None, exit_code=None)
    verify_error = await _run_post_fanout_verification(
        ctx.workspace_scope, run_process_async_fn=ctx.run_process_async_fn
    )
    if verify_error is not None:
        logger.error(verify_error)
        # NOTE(review): the helper returns only a message, so the recorded exit
        # code is always 1 here; the real code appears only in the message text.
        v = VerificationResult(ran=True, passed=False, exit_code=1)
        verify_ev = PostFanoutVerificationEvent(success=False, exit_code=1, error=verify_error)
    else:
        v = VerificationResult(ran=True, passed=True, exit_code=0)
        verify_ev = PostFanoutVerificationEvent(success=True, exit_code=0)
    _reduce = (
        ctx.reducer_reduce_fn
        if ctx.reducer_reduce_fn is not None
        else cast("_ReducerReduceFn", reducer_reduce)
    )
    current, _ = _reduce(current, verify_ev, ctx.policy_bundle.pipeline)
    _notify_subscriber(ctx.pipeline_subscriber, current)
    _save_checkpoint_or_log(
        current,
        message="Checkpoint save failed after verification in phase={phase}: {err}",
    )
    return current, v


async def _run_fan_out_async(ctx: _FanOutCtx) -> PipelineState:
    """Run one fan-out wave end to end and return the resulting pipeline state.

    Steps, in order: install signal handlers, validate the work-unit plan,
    build the session MCP plan, persist worker manifests, build the worker
    context, reduce ``WORKERS_RESUMED`` and drop already-succeeded units, run the
    coordinator, reduce its events, checkpoint, run the verify phase, and write
    the parallel development summary.

    Failure handling: a rejected plan, or any exception other than
    ``KeyboardInterrupt``, is turned into a recoverable ``PhaseFailureEvent``;
    the reduced state is notified, checkpointed and returned instead of raising.
    """
    current = ctx.state
    _reduce = (
        ctx.reducer_reduce_fn
        if ctx.reducer_reduce_fn is not None
        else cast("_ReducerReduceFn", reducer_reduce)
    )
    _install = (
        ctx.install_signal_handlers_fn
        if ctx.install_signal_handlers_fn is not None
        else cast("_InstallSignalHandlersFn", install_signal_handlers)
    )
    try:
        loop = asyncio.get_running_loop()
        bridge = SignalBridge()
        if ctx.monitor_stop_cb is not None:
            bridge._connectivity_stop = ctx.monitor_stop_cb
        root_task = cast("asyncio.Task[object] | None", asyncio.current_task())
        # Narrowing only: this coroutine always runs inside a task.
        assert root_task is not None
        _install(loop, root_task, bridge)
        try:
            validate_for_same_workspace(WorkUnitsPlan(work_units=list(ctx.effect.work_units)))
        except WorkUnitsValidationError as exc:
            failure_reason = f"Parallel plan rejected (same-workspace safety check failed): {exc}"
            logger.error(failure_reason)
            failure_event = PhaseFailureEvent(
                phase=current.phase, reason=failure_reason, recoverable=True
            )
            recovered, _ = _reduce(
                current, failure_event, ctx.policy_bundle.pipeline, recovery=None
            )
            _notify_subscriber(ctx.pipeline_subscriber, recovered)
            _save_checkpoint_or_log(
                recovered,
                message="Checkpoint save failed after plan rejection in phase={phase}: {err}",
            )
            return recovered
        session_mcp_plan, session_drain = _build_session_mcp_plan_for_phase(
            effect=ctx.effect,
            policy_bundle=ctx.policy_bundle,
            workspace_scope=ctx.workspace_scope,
            config=ctx.config,
        )
        worker_manifest_paths = _persist_parallel_worker_manifests(
            effect=ctx.effect,
            repo_root=ctx.repo_root,
            session_drain=session_drain,
            config_path=ctx.config_path,
            cli_overrides=ctx.cli_overrides,
        )
        worker_commands = _worker_commands_from_manifests(worker_manifest_paths)
        executor, worker_ctx = _fan_out_worker_context(
            workspace_scope=ctx.workspace_scope,
            repo_root=ctx.repo_root,
            bridge=bridge,
            session_drain=session_drain,
            worker_commands=worker_commands,
            worker_manifest_paths=worker_manifest_paths,
            session_mcp_plan=session_mcp_plan,
            executor_cls=ctx.executor_cls,
            mcp_factory_cls=ctx.mcp_factory_cls,
        )
        current, resume_units = _resume_fan_out_state(
            ctx.state,
            ctx.effect,
            ctx.policy_bundle.pipeline,
            ctx.pipeline_subscriber,
            reducer_reduce_fn=_reduce,
        )
        # Every unit already succeeded: return without running the coordinator,
        # the verify phase or the summary writer.
        if not resume_units:
            return current
        fan_out_events = await coordinator.run_fan_out(
            effect=FanOutEffect(
                work_units=resume_units,
                max_workers=ctx.effect.max_workers,
                phase=ctx.effect.phase,
            ),
            executor=executor,
            display=ctx.display,
            ctx=worker_ctx,
        )
        # NOTE(review): read as "notify per event, checkpoint once after the
        # loop" — confirm the loop body boundaries.
        for ev in fan_out_events:
            current, _ = _reduce(current, ev, ctx.policy_bundle.pipeline)
            _notify_subscriber(ctx.pipeline_subscriber, current)
        _save_checkpoint_or_log(
            current,
            message="Checkpoint save failed after fan-out in phase={phase}: {err}",
        )
        any_worker_failed = any(isinstance(ev, WorkerFailedEvent) for ev in fan_out_events)
        current, verification = await _run_verify_phase(ctx, current, any_worker_failed)
        write_parallel_development_summary(ctx.workspace_scope, ctx.effect, current, verification)
        return current
    except KeyboardInterrupt:
        raise
    # NOTE(review): BaseException also covers asyncio.CancelledError and
    # SystemExit, so task cancellation is recorded as a recoverable failure
    # rather than propagated — confirm this is intended.
    except BaseException as exc:
        logger.exception(
            "Fan-out execution crashed in phase={phase}: {err}",
            phase=current.phase,
            err=exc,
        )
        failure_event = PhaseFailureEvent(
            phase=current.phase,
            reason=f"Fan-out execution crashed: {type(exc).__name__}: {exc}",
            recoverable=True,
        )
        recovered, _ = _reduce(current, failure_event, ctx.policy_bundle.pipeline, recovery=None)
        _notify_subscriber(ctx.pipeline_subscriber, recovered)
        _save_checkpoint_or_log(
            recovered,
            message=(
                "Checkpoint save failed while recording fan-out recovery in phase={phase}: {err}"
            ),
        )
        return recovered
def execute_fan_out_sync(
    *,
    effect: FanOutEffect,
    state: PipelineState,
    display: ParallelDisplay | LegacyConsoleDisplay,
    **opts: object,
) -> PipelineState:
    """Execute fan-out development synchronously by wrapping asyncio.run().

    Args:
        effect: Fan-out effect to execute.
        state: Pipeline state at the start of the fan-out.
        display: Display to use; converted to a parallel display when needed.
        **opts: Required keys ``policy_bundle`` and ``workspace_scope``. Optional
            keys ``pipeline_subscriber``, ``dashboard_subscriber``, ``config``,
            ``config_path``, ``cli_overrides``, and the underscore-prefixed
            overrides ``_monitor_stop_cb``, ``_install_signal_handlers``,
            ``_executor_cls``, ``_mcp_factory_cls``, ``_run_process_async`` and
            ``_reducer_reduce``. Missing optional keys default to ``None``.

    Returns:
        The pipeline state produced by the fan-out run.

    Raises:
        KeyError: If ``policy_bundle`` or ``workspace_scope`` is missing.
    """
    # Required options are read first so a missing key fails before any work.
    bundle = cast("PolicyBundle", opts["policy_bundle"])
    scope = cast("WorkspaceScope", opts["workspace_scope"])

    resolved_display, subscriber = _fan_out_display_and_subscriber(
        display,
        cast("_PipelineSubscriberLike | None", opts.get("pipeline_subscriber")),
        cast("_PipelineSubscriberLike | None", opts.get("dashboard_subscriber")),
    )

    fan_out_ctx = _FanOutCtx(
        effect=effect,
        state=state,
        display=resolved_display,
        policy_bundle=bundle,
        workspace_scope=scope,
        repo_root=scope.root,
        pipeline_subscriber=subscriber,
        config=cast("UnifiedConfig | None", opts.get("config")),
        config_path=cast("Path | None", opts.get("config_path")),
        cli_overrides=cast("dict[str, object] | None", opts.get("cli_overrides")),
        monitor_stop_cb=cast("Callable[[], None] | None", opts.get("_monitor_stop_cb")),
        install_signal_handlers_fn=cast(
            "_InstallSignalHandlersFn | None", opts.get("_install_signal_handlers")
        ),
        executor_cls=cast("_ExecutorFactory | None", opts.get("_executor_cls")),
        mcp_factory_cls=cast("_McpFactory | None", opts.get("_mcp_factory_cls")),
        run_process_async_fn=cast("_RunProcessAsyncFn | None", opts.get("_run_process_async")),
        reducer_reduce_fn=cast("_ReducerReduceFn | None", opts.get("_reducer_reduce")),
    )
    return asyncio.run(_run_fan_out_async(fan_out_ctx))
def _parallel_worker_command(manifest_path: Path | None = None) -> tuple[str, ...]: command: tuple[str, ...] = (sys.executable, "-m", "ralph") if manifest_path is None: return command return (*command, "--parallel-worker-manifest", str(manifest_path))