Source code for ralph.git.rebase.rebase_checkpoint

"""Rebase checkpoint persistence and locking utilities."""

from __future__ import annotations

import json
import os
import shutil
import tempfile
from dataclasses import dataclass, field
from datetime import UTC, datetime
from pathlib import Path
from typing import TYPE_CHECKING

from ralph.git.rebase._rebase_lock import RebaseLock
from ralph.git.rebase._rebase_phase import RebasePhase

if TYPE_CHECKING:
    from collections.abc import Mapping

__all__ = [
    "RebaseCheckpoint",
    "acquire_rebase_lock",
    "clear_rebase_checkpoint",
    "load_rebase_checkpoint",
    "rebase_checkpoint_exists",
    "release_rebase_lock",
    "restore_from_backup",
    "save_rebase_checkpoint",
]

AGENT_DIR = Path(".agent")
CHECKPOINT_FILE = "rebase_checkpoint.json"
BACKUP_SUFFIX = ".bak"
LOCK_FILE = "rebase.lock"
LOCK_TIMEOUT_SECONDS = 1_800


def _current_timestamp() -> str:
    return datetime.now(UTC).isoformat()


def _ensure_agent_dir() -> None:
    AGENT_DIR.mkdir(parents=True, exist_ok=True)


def _checkpoint_path() -> Path:
    return AGENT_DIR / CHECKPOINT_FILE


def _backup_path() -> Path:
    return AGENT_DIR / f"{CHECKPOINT_FILE}{BACKUP_SUFFIX}"


def _lock_path() -> Path:
    return AGENT_DIR / LOCK_FILE


def _json_object(value: object) -> dict[str, object]:
    if not isinstance(value, dict):
        raise ValueError("Checkpoint payload must be a JSON object")

    payload: dict[str, object] = {}
    for key, item in value.items():
        if not isinstance(key, str):
            raise ValueError("Checkpoint payload keys must be strings")
        payload[key] = item
    return payload


def _load_checkpoint_payload(path: Path) -> dict[str, object]:
    raw_payload: object = json.loads(path.read_text())
    return _json_object(raw_payload)


def _string_list(data: Mapping[str, object], key: str) -> list[str]:
    value = data.get(key, [])
    if not isinstance(value, list):
        return []
    return [str(item) for item in value]


def _int_value(data: Mapping[str, object], key: str, default: int = 0) -> int:
    value = data.get(key, default)
    match value:
        case int():
            return value
        case str() | bytes() | bytearray():
            try:
                return int(value)
            except ValueError:
                return default
        case _:
            return default


[docs] @dataclass class RebaseCheckpoint: """Persisted state for a rebase operation, written to ``.agent/rebase_checkpoint.json``.""" phase: RebasePhase = field(default_factory=lambda: RebasePhase.NotStarted) upstream_branch: str = "" conflicted_files: list[str] = field(default_factory=list) resolved_files: list[str] = field(default_factory=list) error_count: int = 0 last_error: str | None = None timestamp: str = field(default_factory=_current_timestamp) phase_error_count: int = 0 @classmethod def new(cls, upstream_branch: str) -> RebaseCheckpoint: return cls(upstream_branch=upstream_branch) def set_phase(self, phase: RebasePhase) -> None: if self.phase != phase: self.phase_error_count = 0 self.phase = phase self.timestamp = _current_timestamp() def add_conflicted_file(self, file: str) -> None: if file not in self.conflicted_files: self.conflicted_files.append(file) self.timestamp = _current_timestamp() def add_resolved_file(self, file: str) -> None: if file not in self.resolved_files: self.resolved_files.append(file) self.timestamp = _current_timestamp() def record_error(self, error: str) -> None: self.error_count += 1 self.phase_error_count += 1 self.last_error = error self.timestamp = _current_timestamp() def all_conflicts_resolved(self) -> bool: return all(file in self.resolved_files for file in self.conflicted_files) def unresolved_conflict_count(self) -> int: return sum(1 for file in self.conflicted_files if file not in self.resolved_files) def to_dict(self) -> dict[str, object]: return { "phase": self.phase.value, "upstream_branch": self.upstream_branch, "conflicted_files": list(self.conflicted_files), "resolved_files": list(self.resolved_files), "error_count": self.error_count, "last_error": self.last_error, "timestamp": self.timestamp, "phase_error_count": self.phase_error_count, } @classmethod def from_dict(cls, data: Mapping[str, object]) -> RebaseCheckpoint: phase_value = data.get("phase") phase = RebasePhase.NotStarted if isinstance(phase_value, str): try: phase = RebasePhase(phase_value) except ValueError: phase = RebasePhase.NotStarted last_error_value = data.get("last_error") last_error = None if last_error_value is None else str(last_error_value) return cls( phase=phase, upstream_branch=str(data.get("upstream_branch", "")), conflicted_files=_string_list(data, "conflicted_files"), resolved_files=_string_list(data, "resolved_files"), error_count=_int_value(data, "error_count"), last_error=last_error, timestamp=str(data.get("timestamp", _current_timestamp())), phase_error_count=_int_value(data, "phase_error_count"), )
[docs] def save_rebase_checkpoint(checkpoint: RebaseCheckpoint) -> None: """Atomically persist ``checkpoint`` to the agent rebase checkpoint file.""" _ensure_agent_dir() path = _checkpoint_path() checkpoint_existed = path.exists() _backup_checkpoint() fd, temp_name = tempfile.mkstemp( prefix=f"{path.name}.", suffix=".tmp", dir=path.parent, ) os.close(fd) temp_path = Path(temp_name) try: temp_path.write_text(json.dumps(checkpoint.to_dict(), indent=2), encoding="utf-8") temp_path.replace(path) finally: if temp_path.exists(): temp_path.unlink() if not checkpoint_existed: _backup_checkpoint()
def _backup_checkpoint() -> None: path = _checkpoint_path() if not path.exists(): return backup = _backup_path() fd, temp_name = tempfile.mkstemp( prefix=f"{backup.name}.", suffix=".tmp", dir=backup.parent, ) os.close(fd) temp_path = Path(temp_name) try: shutil.copy2(path, temp_path) temp_path.replace(backup) finally: if temp_path.exists(): temp_path.unlink()
[docs] def load_rebase_checkpoint() -> RebaseCheckpoint | None: """Load and validate the rebase checkpoint, falling back to backup on error.""" path = _checkpoint_path() if not path.exists(): return None try: payload = _load_checkpoint_payload(path) checkpoint = RebaseCheckpoint.from_dict(payload) validate_checkpoint(checkpoint) return checkpoint except (OSError, ValueError, json.JSONDecodeError): restored = restore_from_backup() if restored is not None: return restored raise
[docs] def clear_rebase_checkpoint() -> None: """Delete the rebase checkpoint file if it exists.""" path = _checkpoint_path() if path.exists(): path.unlink()
[docs] def rebase_checkpoint_exists() -> bool: """Return True if a rebase checkpoint file exists on disk.""" return _checkpoint_path().exists()
def validate_checkpoint(checkpoint: RebaseCheckpoint) -> None: """Raise ``ValueError`` if ``checkpoint`` contains invalid or inconsistent data.""" if checkpoint.phase != RebasePhase.NotStarted and not checkpoint.upstream_branch: raise ValueError("Checkpoint must contain upstream branch once the rebase starts") try: datetime.fromisoformat(checkpoint.timestamp) except ValueError as exc: raise ValueError("Checkpoint has invalid timestamp") from exc for resolved in checkpoint.resolved_files: if resolved not in checkpoint.conflicted_files: raise ValueError("Resolved file missing from conflict list")
[docs] def restore_from_backup() -> RebaseCheckpoint | None: """Attempt to restore a valid checkpoint from the backup file.""" backup = _backup_path() if not backup.exists(): return None payload = _load_checkpoint_payload(backup) checkpoint = RebaseCheckpoint.from_dict(payload) validate_checkpoint(checkpoint) shutil.copy2(backup, _checkpoint_path()) return checkpoint
[docs] def acquire_rebase_lock() -> None: """Acquire the rebase lock file, raising ``OSError`` if another process holds it.""" _ensure_agent_dir() path = _lock_path() if path.exists(): if _is_lock_stale(): path.unlink() else: raise OSError("Rebase lock already held") path.write_text(_lock_content())
[docs] def release_rebase_lock() -> None: """Release the rebase lock file if it exists.""" path = _lock_path() if path.exists(): path.unlink()
def _lock_content() -> str: pid = os.getpid() timestamp = _current_timestamp() return f"pid={pid}\ntimestamp={timestamp}\n" def _is_lock_stale() -> bool: path = _lock_path() try: content = path.read_text() except OSError: return True for line in content.splitlines(): if line.startswith("timestamp="): timestamp = line.split("=", 1)[1] try: cutoff = datetime.fromisoformat(timestamp) except ValueError: return True elapsed = datetime.now(UTC) - cutoff return elapsed.total_seconds() > LOCK_TIMEOUT_SECONDS return True RebaseLock._acquire_fn = acquire_rebase_lock RebaseLock._release_fn = release_rebase_lock