"""ParallelCoordinator and supporting helpers for fan-out execution."""
from __future__ import annotations
import asyncio
from dataclasses import replace
from typing import TYPE_CHECKING, cast
from loguru import logger
from ralph import logging as ralph_logging
from ralph.agents import subprocess_executor
from ralph.mcp.artifacts.store import list_artifacts
from ralph.mcp.protocol.env import (
AGENT_LABEL_SCOPE_ENV,
MCP_ENDPOINT_ENV,
RALPH_PARALLEL_WORKER_MANIFEST_ENV,
WORKER_ARTIFACT_DIR_ENV,
WORKER_ID_ENV,
WORKER_NAMESPACE_ENV,
)
from ralph.mcp.server import factory_impl
from ralph.pipeline.events import (
Event,
PipelineEvent,
WorkerCompletedEvent,
WorkerFailedEvent,
WorkerStartedEvent,
)
from ralph.pipeline.parallel import worker_session
from ralph.pipeline.parallel.scheduler import schedule_next_wave
from ralph.pipeline.parallel.worker_context import WorkerContext
from ralph.pipeline.parallel.worker_failure_error import WorkerFailureError as _WorkerFailureError
from ralph.pipeline.parallel.worker_log import WorkerLog
from ralph.pipeline.work_units import (
WorkUnitsPlan,
WorkUnitsValidationError,
validate_for_same_workspace,
)
from ralph.pipeline.worker_state import WorkerStatus
from ralph.process.manager import ProcessTerminationError, get_process_manager
from ralph.workspace import fs
from ralph.workspace.scope import WorkspaceScope
if TYPE_CHECKING:
from pathlib import Path
from ralph.agents.executor import AgentExecutor, WorkerResult
from ralph.display.activity_router import ActivityRouter
from ralph.display.parallel_display import ParallelDisplay
from ralph.pipeline.effects import FanOutEffect
from ralph.pipeline.parallel.mode import SameWorkspaceContext
from ralph.pipeline.parallel.worker_session import WorkerSessionBundle
from ralph.pipeline.work_units import WorkUnit
[docs]
class ParallelCoordinator:
"""Orchestrates parallel work-unit execution with DAG dependency ordering."""
def __init__(self, *, activity_router: ActivityRouter | None = None) -> None:
self.activity_router = activity_router
[docs]
async def run_fan_out(
self,
effect: FanOutEffect,
executor: AgentExecutor,
display: ParallelDisplay,
ctx: WorkerContext | None = None,
) -> list[Event]:
"""Execute parallel work units while respecting DAG dependencies and worker caps."""
effective_router = self.activity_router
if effective_router is None and hasattr(display, "activity_router"):
effective_router = display.activity_router
worker_ctx = (
WorkerContext(activity_router=effective_router)
if ctx is None
else replace(ctx, activity_router=effective_router)
)
same_workspace = worker_ctx.same_workspace if worker_ctx is not None else None
ns_root = (
str(same_workspace.worker_namespace_root)
if same_workspace is not None and same_workspace.worker_namespace_root is not None
else "unknown"
)
logger.info(
"fan-out start mode=same_workspace units={n} namespace_root={ns}",
n=len(effect.work_units),
ns=ns_root,
)
if effect.work_units:
try:
validate_for_same_workspace(WorkUnitsPlan(work_units=list(effect.work_units)))
except WorkUnitsValidationError as exc:
logger.error("coordinator preflight rejected plan: {}", exc)
return [
WorkerFailedEvent(
unit_id="__preflight__",
exit_code=2,
error=f"parallel preflight rejected plan: {exc}",
)
]
events: list[Event] = [PipelineEvent.FAN_OUT_STARTED]
if not effect.work_units:
return [*events, PipelineEvent.ALL_WORKERS_COMPLETE]
pending = {unit.unit_id for unit in effect.work_units}
completed: set[str] = set()
running: dict[str, WorkUnit] = {}
completion_queue: asyncio.Queue[WorkerResult] = asyncio.Queue()
try:
async with asyncio.TaskGroup() as task_group:
while pending or running:
ready = schedule_next_wave(
completed,
effect.work_units,
set(running),
effect.max_workers,
)
for unit in ready:
pending.discard(unit.unit_id)
running[unit.unit_id] = unit
events.append(WorkerStartedEvent(unit_id=unit.unit_id))
task_group.create_task(
_run_worker(
unit,
executor,
display,
completion_queue,
worker_ctx,
),
name=unit.unit_id,
)
if running:
result = await completion_queue.get()
running.pop(result.unit_id, None)
completed.add(result.unit_id)
events.append(
WorkerCompletedEvent(
unit_id=result.unit_id,
exit_code=result.exit_code,
)
)
continue
if pending:
break
except* Exception as group:
failures, unexpected = _flatten_worker_failures(group.exceptions)
_append_terminal_failure_events(
events=events,
work_units=effect.work_units,
pending=pending,
running=running,
failures=failures,
)
if unexpected:
raise ExceptionGroup("Unexpected fan-out coordinator failure", unexpected) from None
else:
events.append(PipelineEvent.ALL_WORKERS_COMPLETE)
return events
def _flatten_worker_failures(
exceptions: tuple[BaseException, ...],
) -> tuple[list[_WorkerFailureError], list[Exception]]:
failures: list[_WorkerFailureError] = []
unexpected: list[Exception] = []
stack = list(exceptions)
while stack:
current = stack.pop()
if isinstance(current, BaseExceptionGroup):
stack.extend(current.exceptions)
continue
if isinstance(current, _WorkerFailureError):
failures.append(current)
continue
if isinstance(current, Exception):
unexpected.append(current)
return failures, unexpected
def _prepare_executor(
unit: WorkUnit,
executor: AgentExecutor,
same_workspace: SameWorkspaceContext | None,
activity_router: ActivityRouter | None = None,
) -> tuple[AgentExecutor, WorkerSessionBundle | None, Path | None]:
if same_workspace is None:
if activity_router is not None and isinstance(
executor, subprocess_executor.SubprocessAgentExecutor
):
executor.activity_router = activity_router
return executor, None, None
ns_root = same_workspace.worker_namespace_root or (
same_workspace.repo_root / ".agent" / "workers"
)
worker_namespace = ns_root / unit.unit_id
for subdir in ("artifacts", "tmp", "logs", "handoffs"):
(worker_namespace / subdir).mkdir(parents=True, exist_ok=True)
worker_scope = WorkspaceScope.for_same_workspace_worker(
repo_root=same_workspace.repo_root,
allowed_directories=tuple(unit.allowed_directories),
worker_namespace=worker_namespace,
)
_session_cfg = worker_session.WorkerSessionConfig(
worker_artifact_dir=worker_namespace / "artifacts",
worker_namespace=worker_namespace,
session_drain=same_workspace.session_drain,
session_capabilities=same_workspace.session_capabilities,
session_model_identity=same_workspace.session_model_identity,
session_capability_profile=same_workspace.session_capability_profile,
)
if same_workspace.executor_command is None:
bundle = worker_session.build_worker_session(
unit, same_workspace.mcp_factory, worker_scope, _session_cfg
)
return executor, bundle, worker_namespace
worker_workspace = fs.FsWorkspace(
same_workspace.repo_root,
allowed_roots=worker_scope.allowed_roots,
)
worker_mcp_factory = factory_impl.DynamicBindingMcpServerFactory(workspace=worker_workspace)
bundle = worker_session.build_worker_session(
unit, worker_mcp_factory, worker_scope, _session_cfg
)
worker_artifact_dir = worker_namespace / "artifacts"
agent_label_scope = bundle.session.session_id
command = same_workspace.worker_commands.get(unit.unit_id, same_workspace.executor_command)
manifest_path = same_workspace.worker_manifest_paths.get(unit.unit_id)
return (
cast(
"AgentExecutor",
subprocess_executor.SubprocessAgentExecutor(
command,
signal_bridge=same_workspace.signal_bridge,
cwd=same_workspace.repo_root,
extra_env={
str(MCP_ENDPOINT_ENV): bundle.mcp_handle.endpoint,
str(WORKER_ID_ENV): unit.unit_id,
str(WORKER_NAMESPACE_ENV): str(worker_namespace),
str(WORKER_ARTIFACT_DIR_ENV): str(worker_artifact_dir),
str(RALPH_PARALLEL_WORKER_MANIFEST_ENV): (
str(manifest_path) if manifest_path is not None else ""
),
str(AGENT_LABEL_SCOPE_ENV): agent_label_scope,
},
activity_router=activity_router,
raw_overflow_root=worker_namespace / "logs",
),
),
bundle,
worker_namespace,
)
def _blocked_dependency_error(unit: WorkUnit, failed_unit_ids: set[str]) -> str | None:
blocked_by = sorted(dep for dep in unit.dependencies if dep in failed_unit_ids)
if not blocked_by:
return None
return f"Blocked by failed dependencies: {', '.join(blocked_by)}"
def _blocked_pending_failures(
work_units: tuple[WorkUnit, ...],
pending_unit_ids: set[str],
failed_unit_ids: set[str],
) -> list[WorkerFailedEvent]:
pending_units = {unit.unit_id: unit for unit in work_units if unit.unit_id in pending_unit_ids}
blocked_events: list[WorkerFailedEvent] = []
expanded_failures = set(failed_unit_ids)
while True:
progress_made = False
for unit_id, unit in list(pending_units.items()):
blocked_error = _blocked_dependency_error(unit, expanded_failures)
if blocked_error is None:
continue
blocked_events.append(
WorkerFailedEvent(unit_id=unit_id, exit_code=1, error=blocked_error)
)
expanded_failures.add(unit_id)
del pending_units[unit_id]
progress_made = True
if not progress_made:
return blocked_events
def _append_terminal_failure_events(
*,
events: list[Event],
work_units: tuple[WorkUnit, ...],
pending: set[str],
running: dict[str, WorkUnit],
failures: list[_WorkerFailureError],
) -> None:
seen_failures = {event.unit_id for event in events if isinstance(event, WorkerFailedEvent)}
failed_unit_ids = {failure.unit_id for failure in failures}
for failure in failures:
if failure.unit_id in seen_failures:
continue
running.pop(failure.unit_id, None)
events.append(
WorkerFailedEvent(
unit_id=failure.unit_id,
exit_code=failure.exit_code,
error=failure.error,
)
)
seen_failures.add(failure.unit_id)
for unit_id in list(running):
if unit_id in seen_failures:
continue
events.append(
WorkerFailedEvent(
unit_id=unit_id,
exit_code=1,
error="Cancelled because another worker failed",
)
)
seen_failures.add(unit_id)
blocked_events = _blocked_pending_failures(
work_units,
pending,
failed_unit_ids | seen_failures,
)
for blocked_event in blocked_events:
if blocked_event.unit_id in seen_failures:
continue
events.append(blocked_event)
seen_failures.add(blocked_event.unit_id)
async def _run_worker(
unit: WorkUnit,
executor: AgentExecutor,
display: ParallelDisplay,
completion_queue: asyncio.Queue[WorkerResult],
ctx: WorkerContext | None = None,
) -> None:
log = ctx.log if ctx is not None else None
same_workspace = ctx.same_workspace if ctx is not None else None
activity_router = ctx.activity_router if ctx is not None else None
with logger.contextualize(unit_id=unit.unit_id):
sink_handle = (
ralph_logging.bind_worker_sink(
unit_id=unit.unit_id, log_dir=log.log_dir, run_id=log.run_id
)
if log is not None
else None
)
bundle = None
worker_succeeded = False
active_executor = executor
def on_output(line: str) -> None:
display.emit(unit.unit_id, line)
def on_status(status: WorkerStatus) -> None:
display.set_status(unit.unit_id, status)
try:
active_executor, bundle, worker_namespace = _prepare_executor(
unit,
executor,
same_workspace,
activity_router,
)
try:
result = await active_executor.run(unit, on_output=on_output, on_status=on_status)
except asyncio.CancelledError:
display.set_status(unit.unit_id, WorkerStatus.CANCELLED)
raise
except BaseException as exc:
if isinstance(exc, _WorkerFailureError):
raise
if isinstance(exc, Exception):
display.set_status(unit.unit_id, WorkerStatus.FAILED)
raise _WorkerFailureError(unit.unit_id, 1, str(exc)) from exc
raise
if bundle is not None and worker_namespace is not None:
artifact_dir = worker_namespace / "artifacts"
if not list_artifacts(artifact_dir):
display.set_status(unit.unit_id, WorkerStatus.FAILED)
raise _WorkerFailureError(
unit_id=unit.unit_id,
exit_code=result.exit_code,
error=(
f"Worker {unit.unit_id!r} produced no worker-local artifact "
f"evidence under {artifact_dir} "
f"(exit_code={result.exit_code})"
),
)
display.set_status(unit.unit_id, WorkerStatus.SUCCEEDED)
await completion_queue.put(result)
worker_succeeded = True
finally:
if bundle is not None:
bundle.mcp_handle.shutdown()
label_env: dict[str, str] | None = None
if bundle is not None and same_workspace is not None:
label_env = {str(AGENT_LABEL_SCOPE_ENV): bundle.session.session_id}
try:
get_process_manager().shutdown_all_for_label(
subprocess_executor.agent_process_label_prefix(unit.unit_id, label_env),
grace_period_s=2.0,
)
except ProcessTerminationError as exc:
logger.error(
"Failed to terminate agent processes for worker {}: {}", unit.unit_id, exc
)
if sink_handle is not None:
ralph_logging.remove_worker_sink(sink_handle)
del worker_succeeded
[docs]
async def run_fan_out(
effect: FanOutEffect,
executor: AgentExecutor,
display: ParallelDisplay,
ctx: WorkerContext | None = None,
activity_router: ActivityRouter | None = None,
) -> list[Event]:
"""Execute a fan-out effect using a fresh ParallelCoordinator instance."""
coord = ParallelCoordinator(activity_router=activity_router)
return await coord.run_fan_out(effect, executor, display, ctx)
prepare_executor = _prepare_executor
__all__ = [
"ParallelCoordinator",
"WorkerContext",
"WorkerLog",
"prepare_executor",
"run_fan_out",
]