Source code for ralph.pipeline.state

"""Immutable pipeline state model.

This module defines PipelineState - the single source of truth for pipeline
execution progress. It serves dual purposes:
1. Runtime State: Tracks current phase, iteration counters, agent chain state
2. Checkpoint Payload: Serializes to JSON for resume functionality

PipelineState is IMMUTABLE from the reducer's perspective. State transitions
occur exclusively through the reduce function.

POLICY-DRIVEN STATE TRACKING
==============================
Loop counters (loop_iterations / loop_caps) and phase chains (phase_chains)
are keyed by policy-declared names, not hardcoded field names. This enables
custom workflows with arbitrary phase and counter names to work without
modifying source code.

Budget counters (budget_caps / outer_progress) track the cap and completed
cycles for each policy-declared budget counter. Remaining budget is always
derived: remaining = max(0, cap - progress).

Legacy checkpoint fields (budget fields only) are migrated to the generic
dicts at deserialise time via the _migrate_legacy_state_fields model_validator.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Final, cast

from pydantic import Field, field_validator, model_validator

from ralph.config.enums import PipelinePhase
from ralph.pipeline.work_units import WorkUnit
from ralph.pipeline.worker_state import WorkerState

from .state_models import (
    AgentChainState,
    CommitState,
    FalloverRecord,
    RebaseState,
    RunMetrics,
    _FrozenPipelineStateModel,
)

if TYPE_CHECKING:
    from ralph.policy.models import DrainName, PipelinePolicy

_UNSET_PHASE: Final[str] = "__unset__"
_DEFAULT_RECOVERY_CYCLE_CAP: Final[int] = 200


def _migrate_counter_field(
    d: dict[str, object],
    target: dict[str, object],
    legacy_field: str,
    counter_name: str,
) -> None:
    if counter_name not in target and legacy_field in d:
        target[counter_name] = d[legacy_field]


def _resolved_recovery_cycle_cap(raw_cap: object) -> int:
    if raw_cap is None:
        return _DEFAULT_RECOVERY_CYCLE_CAP
    if isinstance(raw_cap, int):
        cap = raw_cap
    elif isinstance(raw_cap, str):
        try:
            cap = int(raw_cap)
        except ValueError:
            return _DEFAULT_RECOVERY_CYCLE_CAP
    else:
        return _DEFAULT_RECOVERY_CYCLE_CAP
    return cap if cap >= 1 else _DEFAULT_RECOVERY_CYCLE_CAP


def _normalize_fallover_history_for_cap(
    history: object,
    recovery_cycle_cap: object,
) -> tuple[FalloverRecord, ...]:
    if history is None:
        return ()
    if not isinstance(history, list | tuple):
        raise TypeError(
            f"Expected list or tuple for fallover_history, got {type(history).__name__!r}"
        )
    records = tuple(
        FalloverRecord.model_validate(item)
        if isinstance(item, dict)
        else cast("FalloverRecord", item)
        for item in history
    )
    cap = _resolved_recovery_cycle_cap(recovery_cycle_cap)
    if len(records) <= cap:
        return records
    return records[-cap:]


[docs] class PipelineState(_FrozenPipelineStateModel): """Immutable snapshot of pipeline execution state. This is the checkpoint payload - the single source of truth for pipeline progress. Serialize it to JSON to save state; deserialize to resume interrupted runs. GENERIC TRACKING FIELDS (policy-keyed): phase_chains: Per-phase agent chain state keyed by canonical phase name. loop_iterations: Loop iteration counters keyed by iteration_state_field name. loop_caps: Loop iteration caps keyed by iteration_state_field name. budget_caps: Max budget keyed by budget counter name (seeded from policy). outer_progress: Completed cycle counts keyed by budget counter name. Remaining budget is derived on-demand: max(0, cap - progress). """ phase: PipelinePhase = _UNSET_PHASE previous_phase: PipelinePhase | None = None # Review outcome tracking (replaces direct review_issues_found writes) review_outcome: str | None = None # Generic per-phase chain state (keyed by canonical phase name from policy) phase_chains: dict[str, AgentChainState] = Field(default_factory=dict) # Generic loop iteration tracking (keyed by iteration_state_field from loop_policy) loop_iterations: dict[str, int] = Field(default_factory=dict) loop_caps: dict[str, int] = Field(default_factory=dict) # Generic budget counter tracking (keyed by budget counter name from budget_counters) # Remaining budget is derived: max(0, budget_caps[k] - outer_progress[k]) budget_caps: dict[str, int] = Field(default_factory=dict) outer_progress: dict[str, int] = Field(default_factory=dict) rebase: RebaseState = Field(default_factory=RebaseState) commit: CommitState = Field(default_factory=CommitState) metrics: RunMetrics = Field(default_factory=RunMetrics) checkpoint_saved_count: int = 0 recovery_epoch: int = 0 interrupted_by_user: bool = False git_auth_configured: bool = False pr_created: bool = False pr_url: str | None = None push_count: int = 0 last_error: str | None = None last_reviewed_sha: str | None = None # Policy-derived fields (set at startup and after phase transitions) policy_entry_phase: PipelinePhase = _UNSET_PHASE policy_format_version: int | None = None current_drain: str | None = None work_units: tuple[WorkUnit, ...] = Field(default_factory=tuple) worker_states: dict[str, WorkerState] = Field(default_factory=dict) # Recovery observability fields — all have defaults so legacy checkpoints load cleanly recovery_cycle_count: int = 0 fallover_history: tuple[FalloverRecord, ...] = Field(default_factory=tuple) last_failure_category: str | None = None last_connectivity_state: str = "unknown" recovery_cycle_cap: int = Field(default=200, ge=1) last_retry_delay_ms: int = 0 last_agent_session_id: str | None = None session_preserve_retry_pending: bool = False @model_validator(mode="after") def _validate_phase_set(self) -> PipelineState: if self.phase == _UNSET_PHASE: raise ValueError( "PipelineState requires phase to be set from PipelinePolicy.entry_phase " "before construction; use PipelineState.from_policy(policy) " "or pass phase= explicitly." ) return self
[docs] @classmethod def from_policy(cls, policy: PipelinePolicy, **overrides: object) -> PipelineState: """Construct initial pipeline state from a loaded PipelinePolicy. The entry phase is derived from policy.entry_phase so no workflow entry semantics are embedded in this class. """ payload: dict[str, object] = { "phase": policy.entry_phase, "policy_entry_phase": policy.entry_phase, "policy_format_version": 2 if policy.entry_block is not None else 1, **overrides, } return cls.model_validate(payload)
@model_validator(mode="before") @classmethod def _migrate_legacy_state_fields(cls, data: object) -> object: """Migrate legacy checkpoint fields into generic dicts. Handles old checkpoints that stored: - Typed chain fields → migrated into phase_chains - Legacy budget fields → migrated into budget_caps / outer_progress Legacy budget_remaining values are converted to outer_progress by computing progress = cap - remaining, so the derived remaining stays the same. """ if not isinstance(data, dict): return data d = cast("dict[str, object]", dict(data)) _raw_bc = d.get("budget_caps") budget_caps_data: dict[str, object] = dict( cast("dict[str, object]", _raw_bc) if _raw_bc is not None else {} ) _raw_op = d.get("outer_progress") outer_progress_data: dict[str, object] = dict( cast("dict[str, object]", _raw_op) if _raw_op is not None else {} ) _migrate_counter_field(d, budget_caps_data, "total_iterations", "iteration") _migrate_counter_field(d, budget_caps_data, "total_reviewer_passes", "reviewer_pass") _migrate_counter_field(d, outer_progress_data, "iteration", "iteration") _migrate_counter_field(d, outer_progress_data, "reviewer_pass", "reviewer_pass") # Migrate very-old scalar remaining fields to outer_progress via progress = cap - remaining for _br_field, _counter in ( ("development_budget_remaining", "iteration"), ("review_budget_remaining", "reviewer_pass"), ): _scalar_br = d.get(_br_field) if ( _scalar_br is not None and _counter not in outer_progress_data and _counter in budget_caps_data ): _cap = int(cast("int", budget_caps_data[_counter])) _rem = int(cast("int", _scalar_br)) outer_progress_data[_counter] = max(0, _cap - _rem) # Migrate legacy budget_remaining dict into outer_progress via progress = cap - remaining _raw_br = d.get("budget_remaining") if isinstance(_raw_br, dict): legacy_br = cast("dict[str, object]", _raw_br) for counter, remaining_val in legacy_br.items(): if counter not in outer_progress_data and counter in budget_caps_data: cap = int(cast("int", budget_caps_data[counter])) remaining = int(cast("int", remaining_val)) outer_progress_data[counter] = max(0, cap - remaining) d["budget_caps"] = budget_caps_data d["outer_progress"] = outer_progress_data if "fallover_history" in d: d["fallover_history"] = _normalize_fallover_history_for_cap( d.get("fallover_history"), d.get("recovery_cycle_cap"), ) # Drop legacy budget_remaining — no longer a field d.pop("budget_remaining", None) return d @field_validator("work_units", mode="before") @classmethod def _coerce_work_units(cls, v: object) -> tuple[WorkUnit, ...]: if v is None: return () if isinstance(v, list): return tuple(v) if isinstance(v, tuple): return v raise TypeError(f"Expected list or tuple for work_units, got {type(v).__name__!r}") @field_validator("worker_states", mode="before") @classmethod def _coerce_worker_states(cls, v: object) -> dict[str, WorkerState]: if v is None: return {} if isinstance(v, dict): return v raise TypeError(f"Expected dict for worker_states, got {type(v).__name__!r}") @field_validator("fallover_history", mode="before") @classmethod def _coerce_fallover_history(cls, v: object) -> tuple[FalloverRecord, ...]: if v is None: return () if isinstance(v, list): return tuple( FalloverRecord.model_validate(item) if isinstance(item, dict) else item for item in v ) if isinstance(v, tuple): return v raise TypeError(f"Expected list or tuple for fallover_history, got {type(v).__name__!r}") @field_validator("phase_chains", mode="before") @classmethod def _coerce_phase_chains(cls, v: object) -> dict[str, AgentChainState]: if v is None: return {} if isinstance(v, dict): return { str(key): AgentChainState.model_validate(value) if isinstance(value, dict) else cast("AgentChainState", value) for key, value in v.items() } raise TypeError(f"Expected dict for phase_chains, got {type(v).__name__!r}") @field_validator("loop_iterations", mode="before") @classmethod def _coerce_loop_iterations(cls, v: object) -> dict[str, int]: if v is None: return {} if isinstance(v, dict): return {str(k): int(val) for k, val in v.items()} raise TypeError(f"Expected dict for loop_iterations, got {type(v).__name__!r}") @field_validator("loop_caps", mode="before") @classmethod def _coerce_loop_caps(cls, v: object) -> dict[str, int]: if v is None: return {} if isinstance(v, dict): return {str(k): int(val) for k, val in v.items()} raise TypeError(f"Expected dict for loop_caps, got {type(v).__name__!r}") @field_validator("budget_caps", mode="before") @classmethod def _coerce_budget_caps(cls, v: object) -> dict[str, int]: if v is None: return {} if isinstance(v, dict): return {str(k): int(val) for k, val in v.items()} raise TypeError(f"Expected dict for budget_caps, got {type(v).__name__!r}") @field_validator("outer_progress", mode="before") @classmethod def _coerce_outer_progress(cls, v: object) -> dict[str, int]: if v is None: return {} if isinstance(v, dict): return {str(k): int(val) for k, val in v.items()} raise TypeError(f"Expected dict for outer_progress, got {type(v).__name__!r}")
[docs] def is_complete(self, policy: PipelinePolicy) -> bool: """Check if pipeline has reached a terminal success state. Args: policy: PipelinePolicy. Compares current phase against policy.terminal_phase to determine completion. Raises: RuntimeError: When policy is None (routing requires loaded policy). """ return self.phase == policy.terminal_phase
[docs] def current_agent(self) -> str | None: """Get the current agent for the active phase.""" chain = self.chain_for_phase(self.phase) if chain is None: return None if not chain.agents or chain.current_index >= len(chain.agents): return None return chain.agents[chain.current_index]
[docs] def remaining_retries(self) -> int: """Calculate remaining retries for current agent.""" chain = self.chain_for_phase(self.phase) if chain is None: return 0 return max(0, 3 - chain.retries)
[docs] def advance_agent(self) -> PipelineState: """Advance to the next agent in the fallback chain.""" chain = self.chain_for_phase(self.phase) if chain is None: return self new_chain = AgentChainState( agents=chain.agents, current_index=min(chain.current_index + 1, len(chain.agents) - 1), retries=0, ) return self.with_phase_chain(self.phase, new_chain)
[docs] def chain_for_phase(self, phase: PipelinePhase | str) -> AgentChainState | None: """Get the tracked agent chain state for a phase, if any.""" return self.phase_chains.get(str(phase))
[docs] def with_phase_chain( self, phase: PipelinePhase | str, chain: AgentChainState, ) -> PipelineState: """Return a copy with the chain state for the given phase updated.""" phase_key = str(phase) new_chains = {**self.phase_chains, phase_key: chain} return self.copy_with(phase_chains=new_chains)
[docs] def with_drain(self, drain: DrainName | None) -> PipelineState: """Return a copy with the current_drain set.""" return self.copy_with(current_drain=drain)
[docs] def get_loop_iteration(self, field_name: str) -> int: """Get the loop iteration counter for a policy-declared iteration field. Args: field_name: The iteration_state_field value from PhaseLoopPolicy. Returns: Current iteration count (0 when not yet set). """ return self.loop_iterations.get(field_name, 0)
[docs] def with_loop_iteration(self, field_name: str, value: int) -> PipelineState: """Return a copy with the specified loop iteration field set to value. Args: field_name: The iteration_state_field value from PhaseLoopPolicy. value: New iteration count. Returns: New PipelineState with the iteration counter updated. """ return self.copy_with( loop_iterations={**self.loop_iterations, field_name: value}, )
[docs] def get_budget_remaining(self, counter_name: str) -> int: """Get the remaining budget for a policy-declared budget counter. Args: counter_name: The budget counter name from PhaseCommitPolicy.increments_counter. Returns: Remaining budget count, derived as max(0, cap - completed). """ return max( 0, self.budget_caps.get(counter_name, 0) - self.outer_progress.get(counter_name, 0), )
[docs] def get_budget_cap(self, counter_name: str) -> int: """Get the budget cap for a policy-declared budget counter. Args: counter_name: The budget counter name. Returns: Budget cap (maximum allowed), or 0 if not set. """ return self.budget_caps.get(counter_name, 0)
[docs] def get_outer_progress(self, counter_name: str) -> int: """Get the completed cycle count for a policy-declared budget counter.""" return self.outer_progress.get(counter_name, 0)
[docs] def with_outer_progress(self, counter_name: str, value: int) -> PipelineState: """Return a copy with the specified outer progress counter set to value.""" return self.copy_with(outer_progress={**self.outer_progress, counter_name: value})
[docs] def with_budget_cap(self, counter_name: str, value: int) -> PipelineState: """Return a copy with the specified budget cap set to value.""" return self.copy_with( budget_caps={**self.budget_caps, counter_name: value}, )
[docs] def with_fallover_record(self, record: FalloverRecord) -> PipelineState: """Return a copy with one additional fallover record, trimmed to cycle cap.""" return self.copy_with(fallover_history=(*self.fallover_history, record))
[docs] def copy_with(self, **updates: object) -> PipelineState: """Return a copy with updates applied in a typed-safe manner.""" if self.work_units and "work_units" in updates and updates["work_units"] != self.work_units: updates = {k: v for k, v in updates.items() if k != "work_units"} if "fallover_history" in updates or "recovery_cycle_cap" in updates: updates = { **updates, "fallover_history": _normalize_fallover_history_for_cap( updates.get("fallover_history", self.fallover_history), updates.get("recovery_cycle_cap", self.recovery_cycle_cap), ), } return self.model_copy(update=updates)
# Resolve forward references from TYPE_CHECKING imports at runtime PipelineState.model_rebuild( _types_namespace={ "PipelinePhase": PipelinePhase, "WorkUnit": WorkUnit, "WorkerState": WorkerState, } ) __all__ = [ "AgentChainState", "CommitState", "FalloverRecord", "PipelineState", "RebaseState", "RunMetrics", ]