"""TOML policy loader with fallback to bundled defaults.
Loads agents.toml, pipeline.toml, and artifacts.toml from the user's .agent/
config directory, falling back to the packaged defaults when files are absent.
All loading goes through Pydantic validation so any malformed config surfaces
as a PolicyValidationError with field-level detail.
User-global policy overrides prefer branded filenames
(`ralph-workflow-pipeline.toml`, `ralph-workflow-artifacts.toml`) while
still accepting the legacy unprefixed names for backward compatibility.
"""
from __future__ import annotations
import tomllib
from collections.abc import Mapping, Sequence
from os import getenv
from pathlib import Path
from typing import TYPE_CHECKING, cast
from loguru import logger
from pydantic import TypeAdapter, ValidationError
import ralph.policy
from ralph.phases import register_role_handlers
from ralph.policy.models import (
AgentChainConfig,
AgentDrainConfig,
AgentsPolicy,
ArtifactsPolicy,
IndividualPolicyBlock,
PipelinePolicy,
PolicyBlock,
PolicyBundle,
)
from ralph.policy.validation import (
PolicyValidationError,
validate_drain_contracts,
validate_policy_completeness,
)
if TYPE_CHECKING:
from ralph.config.models import UnifiedConfig
from ralph.workspace.scope import WorkspaceScope
__all__ = [
"load_policy",
"load_policy_or_die",
]
def _load_toml(path: Path) -> dict[str, object]:
"""Load a TOML file, returning empty dict if absent.
Args:
path: Path to the TOML file.
Returns:
Parsed TOML content or empty dict if file doesn't exist.
Raises:
PolicyValidationError: If TOML parsing fails.
"""
if not path.exists():
return {}
try:
with path.open("rb") as fh:
data: dict[str, object] = tomllib.load(fh)
return data
except Exception as exc:
raise PolicyValidationError(
f"Failed to parse TOML at {path}: {exc}",
source=str(path.name),
) from exc
# Shape of a single entry from pydantic's ``ValidationError.errors()`` and of the full list.
ValidationErrorDetail = Mapping[str, object]
ValidationErrorDetails = Sequence[ValidationErrorDetail]

# Maps each workspace policy filename to the branded filename used for the
# user-global override of the same policy.
_GLOBAL_POLICY_FILENAME_MAP = {
    "pipeline.toml": "ralph-workflow-pipeline.toml",
    "artifacts.toml": "ralph-workflow-artifacts.toml",
}

# Top-level keys treated as pipeline policy fields; used to detect the
# obsolete ``[pipeline]`` wrapper format.
PIPELINE_POLICY_FIELDS = frozenset(
    (
        "blocks",
        "entry_block",
        "phases",
        "entry_phase",
        "terminal_phase",
        "loop_counters",
        "budget_counters",
        "post_commit_routes",
        "lifecycle_phases",
        "default_phase_retry_policy",
        "recovery",
    )
)
# Module-level adapter used to validate the raw ``blocks`` table of pipeline.toml
# into a mapping of block name to PolicyBlock.
_BLOCKS_ADAPTER = TypeAdapter(dict[str, PolicyBlock])
def _phase_name_for_block(
    block_name: str,
    blocks: Mapping[str, PolicyBlock],
    *,
    path: tuple[str, ...] = (),
) -> str:
    """Resolve a block name to the phase name of its first individual block.

    Group blocks are followed through their first ``child_blocks`` entry until
    an individual block is reached.

    Args:
        block_name: Name of the block to resolve.
        blocks: All declared blocks keyed by name.
        path: Group blocks already visited on the way to ``block_name``.

    Returns:
        The ``phase_name`` of the individual block that is reached.

    Raises:
        PolicyValidationError: If a referenced block is undefined, the group
            chain loops back on itself, or a group has no child blocks.
    """
    visited = path
    current = block_name
    while True:
        block = blocks.get(current)
        if block is None:
            raise PolicyValidationError(
                f"pipeline.toml block '{current}' is referenced but not defined",
                source="pipeline",
            )
        if isinstance(block, IndividualPolicyBlock):
            return block.phase_name
        if current in visited:
            cycle = " -> ".join((*visited, current))
            raise PolicyValidationError(
                f"pipeline.toml block graph contains a cycle: {cycle}",
                source="pipeline",
            )
        if not block.child_blocks:
            raise PolicyValidationError(
                "pipeline.toml group block "
                f"'{current}' must declare at least one child_blocks entry",
                source="pipeline",
            )
        # Descend into the first child, remembering this group for cycle detection.
        visited = (*visited, current)
        current = block.child_blocks[0]
def _compile_block_pipeline_data(data: dict[str, object]) -> dict[str, object]:
    """Compile block-authored pipeline data into phase-level fields.

    When ``data`` has no ``blocks`` key it is returned unchanged. Otherwise the
    blocks are validated and a shallow copy of ``data`` is returned in which
    ``blocks``, ``phases``, ``entry_phase`` and ``lifecycle_phases`` are
    replaced by values derived from the block graph.

    Args:
        data: Raw pipeline TOML dictionary.

    Returns:
        ``data`` itself when no blocks are declared, otherwise a normalized copy.

    Raises:
        PolicyValidationError: If ``blocks`` is not a mapping, a block
            definition fails validation, two individual blocks share a
            ``phase_name``, a group references an unknown child, a group's
            ``completion_block`` is not one of its children, or
            ``entry_block`` is missing or not a non-empty string.
    """
    raw_blocks = data.get("blocks")
    if raw_blocks is None:
        return data
    if not isinstance(raw_blocks, Mapping):
        raise PolicyValidationError(
            "pipeline.toml blocks must be a mapping of block names to block definitions",
            source="pipeline",
        )
    try:
        blocks = _BLOCKS_ADAPTER.validate_python(raw_blocks)
    except ValidationError as exc:
        # Surface malformed block definitions as PolicyValidationError (like
        # every other policy failure) instead of leaking the raw pydantic
        # error. Locations are prefixed with "blocks" so they read as paths
        # into pipeline.toml.
        details = cast("ValidationErrorDetails", exc.errors())
        msgs = [
            format_validation_error_detail(
                {
                    **detail,
                    "loc": ("blocks", *cast("Sequence[object]", detail.get("loc") or ())),
                }
            )
            for detail in details
        ]
        raise PolicyValidationError(
            "pipeline.toml validation failed:\n" + "\n".join(msgs),
            source="pipeline",
        ) from exc
    phases: dict[str, object] = {}
    lifecycle_phases: dict[str, object] = {}
    for block_name, block in blocks.items():
        if isinstance(block, IndividualPolicyBlock):
            # Individual blocks contribute exactly one phase each.
            if block.phase_name in phases:
                raise PolicyValidationError(
                    "pipeline.toml phase_name "
                    f"'{block.phase_name}' is declared by more than one individual block",
                    source="pipeline",
                )
            phases[block.phase_name] = block.phase
            continue
        # Group blocks: every child must exist and the completion block must be a child.
        for child_name in block.child_blocks:
            if child_name not in blocks:
                raise PolicyValidationError(
                    "pipeline.toml group block "
                    f"'{block_name}' references unknown child_blocks entry '{child_name}'",
                    source="pipeline",
                )
        if block.completion_block not in block.child_blocks:
            raise PolicyValidationError(
                "pipeline.toml group block "
                f"'{block_name}' completion_block '{block.completion_block}' "
                "must also appear in child_blocks",
                source="pipeline",
            )
        # Lifecycle metadata is keyed by the phase the completion block resolves to.
        completion_phase = _phase_name_for_block(block.completion_block, blocks)
        lifecycle_phases[completion_phase] = {
            "lifecycle_name": block_name,
            "completion_block": block.completion_block,
            "increments_counter": block.increments_counter,
            "loop_resets": list(block.loop_resets),
            "before_complete": list(block.before_complete),
            "after_complete": list(block.after_complete),
        }
    entry_block = data.get("entry_block")
    if not isinstance(entry_block, str) or not entry_block:
        raise PolicyValidationError(
            "pipeline.toml block-authored policies must declare entry_block",
            source="pipeline",
        )
    entry_phase = _phase_name_for_block(entry_block, blocks)
    normalized = dict(data)
    normalized["blocks"] = blocks
    normalized["phases"] = phases
    normalized["entry_phase"] = entry_phase
    normalized["lifecycle_phases"] = lifecycle_phases
    return normalized
def _normalize_pipeline_data(data: dict[str, object]) -> dict[str, object]:
    """Reject the obsolete ``[pipeline]`` wrapper, then compile block data.

    Raises:
        PolicyValidationError: If ``data`` has a ``pipeline`` table and none of
            the recognised top-level pipeline fields, or if block compilation
            fails.
    """
    wrapper = data.get("pipeline")
    wrapper_only = isinstance(wrapper, Mapping) and PIPELINE_POLICY_FIELDS.isdisjoint(data)
    if wrapper_only:
        raise PolicyValidationError(
            "pipeline.toml uses the obsolete [pipeline] wrapper format. "
            "This redesign is a hard break; regenerate the policy with the "
            "new block-based schema.",
            source="pipeline",
        )
    return _compile_block_pipeline_data(data)
def format_validation_error_messages(exc: ValidationError) -> list[str]:
    """Render every error carried by a pydantic ValidationError as one line each."""
    error_entries = cast("ValidationErrorDetails", exc.errors())
    return list(map(format_validation_error_detail, error_entries))
def format_validation_error_detail(detail: ValidationErrorDetail) -> str:
    """Render one pydantic error entry as an indented 'location: message' line."""
    location = format_validation_location(detail.get("loc"))
    message = format_validation_message(detail.get("msg"))
    return f" {location}: {message}"
def format_validation_location(raw_loc: object | None) -> str:
"""Format a pydantic error location tuple to a dotted path string."""
if raw_loc is None:
return "<root>"
if isinstance(raw_loc, list | tuple):
if not raw_loc:
return "<root>"
return ".".join(str(component) for component in raw_loc)
return str(raw_loc)
def format_validation_message(raw_msg: object | None) -> str:
    """Return the error message as text, or a placeholder when it is ``None``."""
    if raw_msg is None:
        return "<missing message>"
    return raw_msg if isinstance(raw_msg, str) else str(raw_msg)
def _validate_agents(data: dict[str, object]) -> AgentsPolicy:
    """Build an AgentsPolicy from raw TOML data.

    Args:
        data: Raw TOML dictionary.

    Returns:
        The validated AgentsPolicy instance.

    Raises:
        PolicyValidationError: If model validation fails; the message lists
            each failing field.
    """
    try:
        policy = AgentsPolicy.model_validate(data)
    except ValidationError as exc:
        details = "\n".join(format_validation_error_messages(exc))
        raise PolicyValidationError(
            f"agents.toml validation failed:\n{details}",
            source="agents",
        ) from exc
    return policy
def _validate_pipeline(data: dict[str, object]) -> PipelinePolicy:
    """Normalize raw pipeline data and build a PipelinePolicy from it.

    Args:
        data: Raw TOML dictionary.

    Returns:
        The validated PipelinePolicy instance.

    Raises:
        PolicyValidationError: If normalization rejects the data or model
            validation fails.
    """
    compiled = _normalize_pipeline_data(data)
    try:
        policy = PipelinePolicy.model_validate(compiled)
    except ValidationError as exc:
        details = "\n".join(format_validation_error_messages(exc))
        raise PolicyValidationError(
            f"pipeline.toml validation failed:\n{details}",
            source="pipeline",
        ) from exc
    return policy
def _validate_artifacts(data: dict[str, object]) -> ArtifactsPolicy:
    """Build an ArtifactsPolicy from raw TOML data.

    Args:
        data: Raw TOML dictionary.

    Returns:
        The validated ArtifactsPolicy instance.

    Raises:
        PolicyValidationError: If model validation fails; the message lists
            each failing field.
    """
    try:
        policy = ArtifactsPolicy.model_validate(data)
    except ValidationError as exc:
        details = "\n".join(format_validation_error_messages(exc))
        raise PolicyValidationError(
            f"artifacts.toml validation failed:\n{details}",
            source="artifacts",
        ) from exc
    return policy
def _merge_mapping_defaults(
    defaults: Mapping[str, object], overrides: Mapping[str, object]
) -> dict[str, object]:
    """Deep-merge ``overrides`` on top of ``defaults`` into a new dict.

    Keys present only in ``defaults`` are kept. Where both sides hold a mapping
    for the same key the two are merged recursively; in every other case the
    override value replaces the default.
    """
    result: dict[str, object] = dict(defaults)
    for key, incoming in overrides.items():
        existing = result.get(key)
        if isinstance(existing, Mapping) and isinstance(incoming, Mapping):
            result[key] = _merge_mapping_defaults(existing, incoming)
        else:
            result[key] = incoming
    return result
def _merge_pipeline_defaults(
    defaults: Mapping[str, object], overrides: Mapping[str, object]
) -> dict[str, object]:
    """Deep-merge pipeline data, dropping default phases the override omits.

    Everything is first merged with ``_merge_mapping_defaults``. When both
    sides carry a ``phases`` mapping, the merged ``phases`` entry is then
    rebuilt from the override's phase names only: each declared phase still
    inherits fields from the same-named default phase, but default phases
    absent from the override are not carried over.
    """
    merged = _merge_mapping_defaults(defaults, overrides)
    override_phases = overrides.get("phases")
    default_phases = defaults.get("phases")
    if isinstance(override_phases, Mapping) and isinstance(default_phases, Mapping):
        rebuilt: dict[str, object] = {}
        for name, phase in override_phases.items():
            base = default_phases.get(name)
            can_inherit = isinstance(base, Mapping) and isinstance(phase, Mapping)
            rebuilt[name] = _merge_mapping_defaults(base, phase) if can_inherit else phase
        merged["phases"] = rebuilt
    return merged
def _is_phase_authored_pipeline_data(data: Mapping[str, object]) -> bool:
    """Report whether pipeline data uses the legacy phase-authored schema.

    That is the case when ``data`` is non-empty, has a ``phases`` mapping, has
    no ``blocks`` mapping, and has no string ``entry_block``.
    """
    if not data:
        return False
    if not isinstance(data.get("phases"), Mapping):
        return False
    if isinstance(data.get("blocks"), Mapping):
        return False
    return not isinstance(data.get("entry_block"), str)
def _reject_phase_authored_pipeline_override(*, scope: str, path: Path) -> None:
    """Raise the error for an override that still uses the phase-authored schema.

    Args:
        scope: Label for where the override came from; it begins the message.
        path: Location of the offending file, quoted in the message.

    Raises:
        PolicyValidationError: Always.
    """
    message = (
        f"{scope} pipeline.toml at '{path}' uses the obsolete phase-authored schema and must "
        f"not be merged with the bundled block-authored default pipeline. Remove the outdated "
        f"file so Ralph Workflow falls back to the current default pipeline template."
    )
    raise PolicyValidationError(message, source="pipeline")
def _resolve_pipeline_data(
    *,
    default_policy_dir: Path,
    pipeline_path: Path,
    global_pipeline_path: Path | None,
) -> dict[str, object]:
    """Pick or merge the raw pipeline data from the available layers.

    Without a global path, the workspace file is used as-is when it has
    content, else the bundled default. With a global path, the bundled default
    is overlaid by the global file and then by the workspace file; either
    override is rejected if it uses the phase-authored schema.

    Raises:
        PolicyValidationError: If a file cannot be parsed or an override is
            phase-authored.
    """
    bundled = _load_toml(default_policy_dir / "pipeline.toml")
    workspace = _load_toml(pipeline_path)
    if global_pipeline_path is None:
        return workspace or bundled

    user_global = _load_toml(global_pipeline_path)
    # Check the global layer before the workspace layer.
    layers = (
        ("User-global", user_global, global_pipeline_path),
        ("Workspace-local", workspace, pipeline_path),
    )
    for scope, layer_data, layer_path in layers:
        if _is_phase_authored_pipeline_data(layer_data):
            _reject_phase_authored_pipeline_override(scope=scope, path=layer_path)

    resolved = _merge_pipeline_defaults(bundled, user_global)
    if workspace:
        resolved = _merge_pipeline_defaults(resolved, workspace)
    return resolved
def _config_defines_agent_policy(config: object) -> bool:
    """Report whether ``config`` has non-empty mappings for both chains and drains."""
    chains: object = getattr(config, "agent_chains", None)
    drains: object = getattr(config, "agent_drains", None)
    if not (isinstance(chains, Mapping) and isinstance(drains, Mapping)):
        return False
    return bool(chains) and bool(drains)
def _coerce_agent_chain_config(
    value: object,
    *,
    retry_budget: int,
    retry_delay_ms: int,
) -> AgentChainConfig:
    """Return ``value`` as an AgentChainConfig.

    An existing AgentChainConfig is returned untouched. Anything else is
    treated as a sequence of agent names and wrapped with the given retry
    settings.
    """
    if not isinstance(value, AgentChainConfig):
        agent_names = list(cast("Sequence[str]", value))
        return AgentChainConfig(
            agents=agent_names,
            max_retries=retry_budget,
            retry_delay_ms=retry_delay_ms,
        )
    return value
def _coerce_agent_drain_config(
    drain: str,
    value: object,
    *,
    builtin_drain_classes: Mapping[str, str],
) -> AgentDrainConfig:
    """Return ``value`` as an AgentDrainConfig with a drain class filled in.

    An existing AgentDrainConfig is copied, keeping its own ``drain_class``
    when set and otherwise using the built-in class for ``drain``. Any other
    value is treated as the chain name and paired with the built-in class.
    """
    builtin_class = builtin_drain_classes.get(drain)
    if not isinstance(value, AgentDrainConfig):
        return AgentDrainConfig(
            chain=cast("str", value),
            drain_class=builtin_class,
        )
    return AgentDrainConfig(
        chain=value.chain,
        drain_class=value.drain_class or builtin_class,
        capability_class=value.capability_class,
    )
def build_agents_policy_from_config(config: UnifiedConfig) -> AgentsPolicy:
    """Build the agents policy from the chains and drains on the main config.

    Each entry of ``config.agent_chains`` becomes an AgentChainConfig, using
    the retry settings from ``config.general`` (falling back to 3 retries and
    1000 ms when those are missing or not integers). Each entry of
    ``config.agent_drains`` becomes an AgentDrainConfig; drains whose name is
    one of the built-in names below get an explicit ``drain_class`` unless
    they already declare one.
    """
    general_section: object = getattr(config, "general", None)
    configured_retries: object = getattr(general_section, "max_retries", 3)
    configured_delay: object = getattr(general_section, "retry_delay_ms", 1000)
    retry_budget = configured_retries if isinstance(configured_retries, int) else 3
    retry_delay_ms = configured_delay if isinstance(configured_delay, int) else 1000

    chains_attr: object = getattr(config, "agent_chains", {})
    chain_source: Mapping[str, object] = (
        cast("Mapping[str, object]", chains_attr) if isinstance(chains_attr, Mapping) else {}
    )
    chain_configs: dict[str, AgentChainConfig] = {}
    for chain_name, chain_value in chain_source.items():
        chain_configs[chain_name] = _coerce_agent_chain_config(
            chain_value,
            retry_budget=retry_budget,
            retry_delay_ms=retry_delay_ms,
        )

    # Built-in drain name -> drain class used when the drain declares none.
    builtin_drain_classes: dict[str, str] = {
        "planning": "planning",
        "development": "development",
        "development_analysis": "analysis",
        "planning_analysis": "analysis",
        "review_analysis": "analysis",
        "analysis": "analysis",
        "review": "review",
        "fix": "fix",
        "development_commit": "commit",
        "review_commit": "commit",
        "commit": "commit",
    }

    drains_attr: object = getattr(config, "agent_drains", {})
    drain_source: Mapping[str, object] = (
        cast("Mapping[str, object]", drains_attr) if isinstance(drains_attr, Mapping) else {}
    )
    drain_configs: dict[str, AgentDrainConfig] = {}
    for drain_name, drain_value in drain_source.items():
        drain_configs[drain_name] = _coerce_agent_drain_config(
            drain_name,
            drain_value,
            builtin_drain_classes=builtin_drain_classes,
        )

    return AgentsPolicy(agent_chains=chain_configs, agent_drains=drain_configs)
# Holds at most one element: the bundled default agents policy once it has been loaded.
_DEFAULT_AGENTS_POLICY_CACHE: list[AgentsPolicy] = []


def _cached_default_agents_policy() -> AgentsPolicy:
    """Return the bundled default agents policy, loading it on first use."""
    if _DEFAULT_AGENTS_POLICY_CACHE:
        return _DEFAULT_AGENTS_POLICY_CACHE[0]
    bundled_policy = _validate_agents(_load_toml(default_dir() / "agents.toml"))
    _DEFAULT_AGENTS_POLICY_CACHE.append(bundled_policy)
    return bundled_policy
def _load_agents_policy_from_path(
    agents_path: Path,
    config: UnifiedConfig | None = None,
) -> AgentsPolicy:
    """Resolve the agents policy from config, an agents file, or the default.

    A config that defines both chains and drains takes precedence. Otherwise
    ``agents_path`` is used when it exists and parses to non-empty data, and
    the cached bundled default is used in every remaining case.
    """
    if config is not None and _config_defines_agent_policy(config):
        return build_agents_policy_from_config(config)
    if agents_path.exists():
        agents_data = _load_toml(agents_path)
        if agents_data:
            return _validate_agents(agents_data)
    return _cached_default_agents_policy()
def load_agents_policy(config_dir: Path, config: UnifiedConfig | None = None) -> AgentsPolicy:
    """Load just the agents policy for ``config_dir``.

    Looks for ``agents.toml`` inside ``config_dir`` unless ``config`` already
    defines the agent chains and drains. The pipeline and artifacts policies
    are not loaded.
    """
    agents_file = config_dir / "agents.toml"
    return _load_agents_policy_from_path(agents_file, config=config)
def load_agents_policy_for_workspace_scope(
    workspace_scope: WorkspaceScope,
    config: UnifiedConfig | None = None,
) -> AgentsPolicy:
    """Load just the agents policy, locating ``agents.toml`` via the workspace scope."""
    agents_file = workspace_scope.resolve_agent_file("agents.toml")
    return _load_agents_policy_from_path(agents_file, config=config)
def _load_policy_from_paths(
    *,
    agents_path: Path,
    pipeline_path: Path,
    artifacts_path: Path,
    config: UnifiedConfig | None = None,
    global_policy_paths: tuple[Path | None, Path | None] | None = None,
) -> PolicyBundle:
    """Assemble and validate a PolicyBundle from explicit file locations.

    Args:
        agents_path: Workspace ``agents.toml`` location.
        pipeline_path: Workspace ``pipeline.toml`` location.
        artifacts_path: Workspace ``artifacts.toml`` location.
        config: Optional main config; may supply the agents policy directly.
        global_policy_paths: Optional ``(pipeline, artifacts)`` user-global
            override locations.

    Returns:
        The validated bundle. Role handlers are registered for its pipeline
        policy before it is returned.

    Raises:
        PolicyValidationError: If any file, individual policy, or cross-policy
            check fails.
    """
    if global_policy_paths is None:
        global_pipeline_path = global_artifacts_path = None
    else:
        global_pipeline_path, global_artifacts_path = global_policy_paths

    defaults_root = default_dir()
    pipeline_data = _resolve_pipeline_data(
        default_policy_dir=defaults_root,
        pipeline_path=pipeline_path,
        global_pipeline_path=global_pipeline_path,
    )

    # Artifacts layering: bundled defaults, then user-global, then workspace.
    bundled_artifacts = _load_toml(defaults_root / "artifacts.toml")
    workspace_artifacts = _load_toml(artifacts_path)
    artifacts_data: dict[str, object] = bundled_artifacts
    if global_artifacts_path is not None:
        artifacts_data = _merge_mapping_defaults(
            bundled_artifacts, _load_toml(global_artifacts_path)
        )
    if workspace_artifacts:
        artifacts_data = _merge_mapping_defaults(artifacts_data, workspace_artifacts)

    agents_policy = _load_agents_policy_from_path(agents_path, config=config)
    pipeline_policy = _validate_pipeline(pipeline_data)
    artifacts_policy = _validate_artifacts(artifacts_data)

    try:
        bundle = PolicyBundle(
            agents=agents_policy,
            pipeline=pipeline_policy,
            artifacts=artifacts_policy,
        )
    except ValidationError as exc:
        details = "\n".join(format_validation_error_messages(exc))
        raise PolicyValidationError(
            f"Cross-policy validation failed (drain bindings / analysis contracts):\n{details}",
            source=None,
        ) from exc

    try:
        validate_drain_contracts(bundle)
    except PolicyValidationError as exc:
        # Drain contract failures are always attributed to the agents policy.
        raise PolicyValidationError(exc.message, source="agents") from exc

    try:
        validate_policy_completeness(bundle)
    except PolicyValidationError as exc:
        # Keep the reported source when there is one; otherwise label it "completeness".
        raise PolicyValidationError(
            exc.message,
            source=exc.source or "completeness",
        ) from exc

    register_role_handlers(pipeline_policy)
    return bundle
# Public API: load_policy is exported via __all__.
def load_policy(config_dir: Path, config: UnifiedConfig | None = None) -> PolicyBundle:
    """Load the agents, pipeline and artifacts policies as one validated bundle.

    The three TOML files are looked up inside ``config_dir`` (the .agent/
    directory); bundled defaults stand in for any file that is absent. No
    user-global override paths are consulted by this entry point.
    """
    agents_file = config_dir / "agents.toml"
    pipeline_file = config_dir / "pipeline.toml"
    artifacts_file = config_dir / "artifacts.toml"
    return _load_policy_from_paths(
        agents_path=agents_file,
        pipeline_path=pipeline_file,
        artifacts_path=artifacts_file,
        config=config,
    )
def load_policy_for_workspace_scope(
    workspace_scope: WorkspaceScope,
    config: UnifiedConfig | None = None,
) -> PolicyBundle:
    """Load the policy bundle for a workspace scope.

    Each policy file is located through ``workspace_scope.resolve_agent_file``,
    and the user-global pipeline and artifacts override paths are passed along
    for layering.
    """
    agents_file = workspace_scope.resolve_agent_file("agents.toml")
    pipeline_file = workspace_scope.resolve_agent_file("pipeline.toml")
    artifacts_file = workspace_scope.resolve_agent_file("artifacts.toml")
    global_overrides = (
        _global_policy_path("pipeline.toml"),
        _global_policy_path("artifacts.toml"),
    )
    return _load_policy_from_paths(
        agents_path=agents_file,
        pipeline_path=pipeline_file,
        artifacts_path=artifacts_file,
        config=config,
        global_policy_paths=global_overrides,
    )
def default_dir() -> Path:
    """Return the ``defaults`` directory that sits beside the ``ralph.policy`` package file."""
    package_dir = Path(ralph.policy.__file__).parent
    return package_dir.joinpath("defaults")
def _global_policy_path(filename: str) -> Path:
    """Return the user-global location for a policy file.

    The base directory is ``$XDG_CONFIG_HOME`` when that variable is set and
    non-empty, otherwise ``~/.config``. The filename is swapped for its
    branded equivalent when one is registered.
    """
    config_home = getenv("XDG_CONFIG_HOME")
    if config_home:
        root = Path(config_home)
    else:
        root = Path.home() / ".config"
    return root / _GLOBAL_POLICY_FILENAME_MAP.get(filename, filename)
# Public API: load_policy_or_die is exported via __all__.
def load_policy_or_die(config_dir: Path, config: UnifiedConfig | None = None) -> PolicyBundle:
    """Load the policy bundle, logging and exiting if validation fails.

    Args:
        config_dir: Path to the .agent/ configuration directory.
        config: Optional main config, forwarded to ``load_policy`` when given.

    Returns:
        Validated PolicyBundle.

    Raises:
        SystemExit: With exit code 1 when a PolicyValidationError is raised
            while loading.
    """
    try:
        # Only pass ``config`` through when the caller actually supplied one.
        if config is not None:
            return load_policy(config_dir, config=config)
        return load_policy(config_dir)
    except PolicyValidationError as exc:
        logger.error("Policy validation failed: {}", exc.message)
        if exc.source:
            logger.error(" Source: {}", exc.source)
        raise SystemExit(1) from exc