"""RecoveryController: single owner of failure classification, budget, and fallover."""
from __future__ import annotations
from datetime import UTC, datetime
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, cast
from loguru import logger
from ralph.pipeline.effects import ExitFailureEffect
from ralph.pipeline.state import FalloverRecord
from ralph.recovery.budget import AgentBudgetRegistry
from ralph.recovery.classifier import (
ClassifiedFailure,
FailureCategory,
FailureClassifier,
FailureContext,
)
from ralph.recovery.cycle_cap import CycleCap
from ralph.recovery.events import FailureEvent, FailureEventBus, FalloverEvent
from ralph.recovery.recovery_controller_options import RecoveryControllerOptions
__all__ = ["RecoveryController", "RecoveryControllerOptions", "compute_backoff_ms"]
if TYPE_CHECKING:
from collections.abc import Callable
from ralph.pipeline.effects import Effect
from ralph.pipeline.state import AgentChainState, PipelineState
from ralph.policy.models import AgentChainConfig
def _build_exit_failure_effect(*, reason: str) -> Effect:
    """Create the terminal ``ExitFailureEffect`` carrying *reason*.

    Args:
        reason: Human-readable explanation attached to the exit effect.

    Returns:
        A new ``ExitFailureEffect`` instance.
    """
    exit_effect = ExitFailureEffect(reason=reason)
    return exit_effect
def _build_fallover_record(
    *,
    phase: str,
    from_agent: str,
    to_agent: str,
    timestamp_iso: str,
) -> FalloverRecord:
    """Assemble a ``FalloverRecord`` describing one agent-to-agent handoff.

    Args:
        phase: Pipeline phase in which the fallover happened.
        from_agent: Agent being abandoned.
        to_agent: Agent taking over.
        timestamp_iso: ISO-8601 timestamp string of the fallover.

    Returns:
        The populated ``FalloverRecord``.
    """
    record_fields = {
        "phase": phase,
        "from_agent": from_agent,
        "to_agent": to_agent,
        "timestamp_iso": timestamp_iso,
    }
    return FalloverRecord(**record_fields)
def _get_required_artifact_helpers() -> tuple[Callable[[str, str], str], Callable[[str], str]]:
# Lazy import to avoid circular dependency via ralph.phases import chain
module = import_module("ralph.phases.required_artifacts")
namespace = cast("dict[str, object]", module.__dict__)
build_retry_hint = cast("Callable[[str, str], str]", namespace["build_retry_hint"])
retry_hint_path = cast("Callable[[str], str]", namespace["retry_hint_path"])
return build_retry_hint, retry_hint_path
def compute_backoff_ms(base_ms: int, attempt: int, max_ms: int = 30_000) -> int:
    """Compute exponential backoff delay with cap.

    The delay is ``base_ms * 2**attempt``, clamped into ``[0, max_ms]``.

    Args:
        base_ms: Base delay in milliseconds.
        attempt: Current retry attempt (0-indexed). Negative values are
            treated as attempt 0.
        max_ms: Maximum delay cap in milliseconds.

    Returns:
        Delay in milliseconds, capped at max_ms and never negative.
    """
    # A negative exponent would make 2**attempt a float (2**-1 == 0.5) and
    # break the int return contract, so treat it as the first attempt.
    exponent_factor: int = 2 ** max(0, attempt)
    delay: int = base_ms * exponent_factor
    # A negative base (or cap) must never produce a negative sleep duration.
    return max(0, min(delay, max_ms))
[docs]
class RecoveryController:
"""Single conceptual owner of recovery logic.
Handles classification, budget debiting, chain fallover, and cycle cap.
Delegates nothing to the reducer's internal retry counter when active.
"""
def __init__(
self,
*,
options: RecoveryControllerOptions | None = None,
) -> None:
opts = options or RecoveryControllerOptions()
self._cap = CycleCap(cap=opts.cycle_cap)
self._classifier = opts.classifier or FailureClassifier()
self._bus = opts.event_bus or FailureEventBus()
self._registry = opts.budget_registry or AgentBudgetRegistry()
self._policy_bundle = opts.policy_bundle
self._backoff_attempts: dict[str, int] = opts.backoff_attempts or {}
self._technical_retry_cap = max(0, opts.technical_retry_cap)
@property
def event_bus(self) -> FailureEventBus:
return self._bus
@property
def budget_registry(self) -> AgentBudgetRegistry:
return self._registry
[docs]
def handle(
self,
state: PipelineState,
raw_failure: BaseException | str,
context: FailureContext,
) -> tuple[PipelineState, list[Effect], FailureEvent]:
"""Classify a failure and compute the recovery transition.
Args:
state: Current pipeline state.
raw_failure: The raw exception or string error message.
context: Phase/agent context and optional pre-classified failure.
Returns:
Tuple of (new_state, effects, failure_event).
"""
phase = context.phase
agent = context.agent
retry_in_session = context.retry_in_session
failure = context.classified_failure or self._classifier.classify(
raw_failure, phase=phase, agent=agent
)
chain = state.chain_for_phase(phase)
chain_capacity = 0
retry_delay_ms = 0
if chain is not None:
chain_capacity = max(0, len(chain.agents) - chain.current_index - 1)
# Compute retry delay from chain config
if agent is not None and failure.counts_against_budget:
retry_delay_ms = self._compute_retry_delay(phase, agent)
failure_evt = FailureEvent(
timestamp=datetime.now(UTC),
phase=phase,
agent=agent,
category=str(failure.category),
reason=failure.reason,
counted_against_budget=failure.counts_against_budget,
chain_capacity_remaining=chain_capacity,
recovery_cycle=state.recovery_cycle_count,
retry_delay_ms=retry_delay_ms,
)
self._bus.publish(failure_evt)
# ALWAYS set last_failure_category and last_retry_delay_ms on state first
new_state = state.copy_with(
last_failure_category=str(failure.category),
last_retry_delay_ms=retry_delay_ms,
)
if failure.category == FailureCategory.ENVIRONMENTAL:
logger.info(
"Environmental failure in phase={} (not counted against budget): {}",
phase,
failure.reason[:200],
)
new_state = new_state.copy_with(last_error=failure.reason)
new_state, effects = self._handle_technical_retry_exhaustion(
new_state,
failure,
phase,
agent,
retry_in_session=retry_in_session,
)
return new_state, effects, failure_evt
if failure.category in (
FailureCategory.ARTIFACT_VALIDATION,
FailureCategory.AMBIGUOUS,
):
category_label = (
"Artifact validation"
if failure.category == FailureCategory.ARTIFACT_VALIDATION
else "Ambiguous"
)
logger.info(
"{} failure in phase={} (retry without budget debit): {}",
category_label,
phase,
failure.reason[:200],
)
new_state = new_state.copy_with(last_error=failure.reason)
new_state, effects = self._handle_technical_retry_exhaustion(
new_state,
failure,
phase,
agent,
retry_in_session=retry_in_session,
)
return new_state, effects, failure_evt
if failure.category == FailureCategory.USER_CONFIG:
logger.error(
"User/config failure reached runtime controller in phase={} (bug): {}",
phase,
failure.reason[:200],
)
return (
self._enter_phase_failed(new_state, failure.reason, failure.category),
[],
failure_evt,
)
# AGENT category: debit budget and handle chain progression
if failure.reset_session:
logger.warning(
"Stale session detected in phase={} (session id invalid): {}",
phase,
failure.reason[:200],
)
new_state = new_state.copy_with(
last_agent_session_id=None,
session_preserve_retry_pending=False,
)
self._write_session_reset_hint(phase, failure)
if agent is not None:
self._registry = self._registry.debit(phase, agent, failure)
# Track backoff attempt
if failure.counts_against_budget:
key = f"{phase}:{agent}"
self._backoff_attempts[key] = self._backoff_attempts.get(key, 0) + 1
new_state, effects = self._handle_agent_budget_exhaustion(
new_state, failure, phase, agent, retry_in_session=retry_in_session
)
if self._cap.is_exceeded(new_state.recovery_cycle_count):
exit_reason = self._cap.exit_reason(
new_state.recovery_cycle_count,
str(failure.category),
failure.reason[:200],
)
logger.error("Recovery cycle cap exceeded: {}", exit_reason)
# Cycle exceeded: no retry delay
return (
new_state.copy_with(last_retry_delay_ms=0),
[_build_exit_failure_effect(reason=exit_reason)],
failure_evt,
)
return new_state, effects, failure_evt
def _increment_chain_retries(self, state: PipelineState, phase: str) -> PipelineState:
"""Increment chain.retries for the given phase without debiting the budget."""
chain = state.chain_for_phase(phase)
if chain is None:
return state
return state.with_phase_chain(phase, chain.with_retry_increment())
def _handle_technical_retry_exhaustion(
self,
state: PipelineState,
failure: ClassifiedFailure,
phase: str,
agent: str | None,
*,
retry_in_session: bool = False,
) -> tuple[PipelineState, list[Effect]]:
return self._handle_retry_progression(
state,
failure,
phase,
agent,
retry_in_session=retry_in_session,
max_retries=self._technical_retry_cap,
use_budget=False,
)
def _apply_chain_retry(
self,
state: PipelineState,
phase: str,
chain: AgentChainState,
*,
retry_in_session: bool,
) -> PipelineState:
"""Apply a single retry to the chain and optionally preserve the agent session."""
retried_state = state.with_phase_chain(phase, chain.with_retry_increment())
if retry_in_session and state.last_agent_session_id:
retried_state = retried_state.copy_with(session_preserve_retry_pending=True)
return retried_state
def _chain_config_for_phase(self, phase: str) -> AgentChainConfig | None:
"""Resolve the AgentChainConfig backing the given phase, or None."""
if self._policy_bundle is None:
return None
phase_def = self._policy_bundle.pipeline.phases.get(phase)
if phase_def is None:
return None
drain_config = self._policy_bundle.agents.agent_drains.get(phase_def.drain)
if drain_config is None:
return None
return self._policy_bundle.agents.agent_chains.get(drain_config.chain)
def _compute_retry_delay(
self,
phase: str,
agent: str | None,
) -> int:
"""Compute the retry delay for a given phase and agent.
Uses the chain's retry_delay_ms from policy configuration.
"""
chain_config = self._chain_config_for_phase(phase)
if chain_config is None:
return 0
# Get backoff attempt count for this phase:agent
key = f"{phase}:{agent}" if agent else phase
attempt = self._backoff_attempts.get(key, 0)
return compute_backoff_ms(chain_config.retry_delay_ms, attempt)
[docs]
def reset_backoff(self, phase: str, agent: str | None) -> None:
"""Reset backoff counter for a phase/agent after successful invocation."""
key = f"{phase}:{agent}" if agent else phase
self._backoff_attempts.pop(key, None)
def _write_session_reset_hint(
self,
phase: str,
failure: ClassifiedFailure,
) -> None:
"""Write a retry hint file describing the stale-session failure.
Args:
phase: Pipeline phase where the failure occurred.
failure: Classified failure with stale-session detail.
"""
build_retry_hint, retry_hint_path = _get_required_artifact_helpers()
detail = (
"Previous session id was invalid; restart with fresh session."
f" Original failure: {failure.raw_message}"
)
hint_content = build_retry_hint(phase, detail)
hint_file = Path(retry_hint_path(phase))
try:
hint_file.parent.mkdir(parents=True, exist_ok=True)
hint_file.write_text(hint_content, encoding="utf-8")
except OSError:
logger.warning("Failed to write session reset hint to {}", hint_file)
def _handle_agent_budget_exhaustion(
self,
state: PipelineState,
failure: ClassifiedFailure,
phase: str,
agent: str | None,
*,
retry_in_session: bool = False,
) -> tuple[PipelineState, list[Effect]]:
"""Handle agent failure with budget debit and chain progression."""
return self._handle_retry_progression(
state,
failure,
phase,
agent,
retry_in_session=retry_in_session,
max_retries=self._get_max_retries_for_chain(phase),
use_budget=True,
)
def _handle_retry_progression(
self,
state: PipelineState,
failure: ClassifiedFailure,
phase: str,
agent: str | None,
*,
retry_in_session: bool,
max_retries: int,
use_budget: bool,
) -> tuple[PipelineState, list[Effect]]:
chain = state.chain_for_phase(phase)
if chain is None:
return state, []
current_agent = agent or (
chain.agents[chain.current_index]
if chain.agents and chain.current_index < len(chain.agents)
else None
)
budget_state = (
self._registry.get(phase, current_agent)
if use_budget and current_agent is not None
else None
)
should_retry_in_chain = current_agent is not None and (
(budget_state is not None and not budget_state.exhausted)
or (budget_state is None and chain.retries < max_retries)
)
if should_retry_in_chain:
return (
self._apply_chain_retry(state, phase, chain, retry_in_session=retry_in_session),
[],
)
if chain.current_index + 1 < len(chain.agents):
next_agent = chain.agents[chain.current_index + 1]
from_agent = current_agent or f"agent[{chain.current_index}]"
fallover_record = _build_fallover_record(
phase=phase,
from_agent=from_agent,
to_agent=next_agent,
timestamp_iso=datetime.now(UTC).isoformat(),
)
fallover_evt = FalloverEvent.now(
phase=phase,
from_agent=from_agent,
to_agent=next_agent,
reason=failure.reason,
)
self._bus.publish(fallover_evt)
new_state = (
state.with_phase_chain(phase, chain.with_advance())
.copy_with(last_retry_delay_ms=0)
.with_fallover_record(fallover_record)
)
return new_state, []
new_state = state.copy_with(recovery_cycle_count=state.recovery_cycle_count + 1)
failed_state = self._enter_phase_failed(new_state, failure.reason, failure.category)
return failed_state, []
def _get_max_retries_for_chain(self, phase: str) -> int:
"""Get max_retries from policy for the chain used by this phase."""
chain_config = self._chain_config_for_phase(phase)
if chain_config is None:
return 3
return chain_config.max_retries
[docs]
def snapshot(self) -> dict[str, object]:
"""Return a runtime observability snapshot of recovery state."""
return {
"cycle_cap": self._cap.cap,
"budgets": {
f"{phase}:{agent}": {
"max_retries": budget.max_retries,
"consumed": budget.consumed,
"remaining": budget.remaining,
"exhausted": budget.exhausted,
}
for (phase, agent), budget in self._registry.items()
},
"backoff_attempts": dict(self._backoff_attempts),
"technical_retry_cap": self._technical_retry_cap,
}
def _enter_phase_failed(
self,
state: PipelineState,
reason: str,
category: object,
) -> PipelineState:
"""Enter the terminal failure phase.
Uses policy.declared.failed_route when available, raising a RuntimeError
if policy is not set (signals missing policy at a routing call site).
"""
if self._policy_bundle is None:
raise RuntimeError(
"_enter_phase_failed requires policy_bundle to be set on the controller. "
"Without policy, the runtime cannot determine the failure route. "
"Set policy_bundle when constructing RecoveryController."
)
failed_route = self._policy_bundle.pipeline.recovery.failed_route
return state.copy_with(
phase=failed_route,
previous_phase=state.phase,
last_error=reason,
recovery_epoch=state.recovery_epoch + 1,
last_failure_category=str(category),
last_retry_delay_ms=0,
)