Source code for ralph.prompts.template_registry

"""Simple registry for prompt templates."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from ralph.prompts.template_not_found_error import TemplateNotFoundError

__all__ = [
    "TemplateNotFoundError",
    "TemplateRegistry",
    "default_template_dirs",
    "load_partial_templates",
    "packaged_template_root",
]

if TYPE_CHECKING:
    from collections.abc import Iterable


[docs] class TemplateRegistry: """Registry that holds prompt templates by name.""" def __init__(self, *, template_dirs: tuple[Path, ...] = ()) -> None: self._templates: dict[str, str] = {} self._template_dirs = template_dirs
[docs] def register_template(self, name: str, content: str) -> None: """Register or replace a prompt template.""" self._templates[name] = content
[docs] def get_template(self, name: str) -> str: """Return the template associated with ``name`` or raise if missing.""" try: return self._templates[name] except KeyError as exc: discovered = self._discover_template(name) if discovered is not None: return discovered raise TemplateNotFoundError(name) from exc
def _discover_template(self, name: str) -> str | None: candidates = _template_candidates(name) for directory in self._template_dirs: for candidate in candidates: path = directory / candidate if path.exists() and path.is_file(): return path.read_text(encoding="utf-8") return None
def load_partial_templates(template_dirs: Iterable[Path]) -> dict[str, str]:
    """Load every ``.jinja``, ``.j2`` and ``.txt`` file under ``template_dirs``.

    Each directory is searched recursively. A template's key is its path
    relative to the directory it was found in, in POSIX form and without its
    final suffix (``<dir>/shared/header.jinja`` -> ``"shared/header"``).
    Entries of ``template_dirs`` that do not exist or are not directories are
    skipped. Sub-directories whose names happen to end in a template suffix
    are also skipped.

    Precedence when two files map to the same key:

    * within one directory, ``.jinja`` beats ``.j2``, which beats ``.txt``.
      This is the same order in which ``_template_candidates`` lists suffixes;
    * across directories, a later directory overrides an earlier one.

    NOTE(review): the cross-directory rule is the opposite of
    ``TemplateRegistry._discover_template``, where the first directory wins.
    With ``default_template_dirs`` the packaged directories come last. As a
    result, bundled partials shadow workspace partials that have the same key.
    Confirm this is intended.

    Args:
        template_dirs: Directories to scan, lowest precedence first.

    Returns:
        Mapping of template key to file contents, decoded as UTF-8.
    """
    suffixes = (".jinja", ".j2", ".txt")  # highest priority first
    partials: dict[str, str] = {}
    for directory in template_dirs:
        if not directory.is_dir():
            continue
        from_this_dir: dict[str, str] = {}
        for suffix in suffixes:
            # rglob yields entries in filesystem order; sort for reproducible results.
            for path in sorted(directory.rglob(f"*{suffix}")):
                # rglob also matches directories (e.g. "notes.txt/"). Reading one
                # would raise IsADirectoryError, so skip anything that is not a file.
                if not path.is_file():
                    continue
                key = _relative_template_key(directory, path)
                # The first (highest-priority) suffix seen for a key wins.
                if key not in from_this_dir:
                    from_this_dir[key] = path.read_text(encoding="utf-8")
        partials.update(from_this_dir)
    return partials
def packaged_template_root() -> Path:
    """Locate the prompt templates bundled with this package.

    Returns:
        The ``templates`` directory next to this module, as an absolute path
        with symlinks resolved. Whether it exists is not checked.
    """
    module_dir = Path(__file__).resolve().parent
    return module_dir.joinpath("templates")
def _template_candidates(name: str) -> tuple[str, ...]: path = Path(name) if path.suffix: return (name,) return (f"{name}.jinja", f"{name}.j2", f"{name}.txt") def _relative_template_key(root: Path, path: Path) -> str: relative = path.relative_to(root) without_suffix = relative.with_suffix("") return without_suffix.as_posix()
def default_template_dirs(workspace_root: Path) -> tuple[Path, ...]:
    """Return the conventional prompt template directories, workspace first.

    Convention over configuration: a workspace customises prompts by placing
    files under ``<workspace_root>/.agent/prompts``. The templates bundled
    with the package follow after it.

    Args:
        workspace_root: Root of the workspace whose ``.agent/prompts`` tree is used.

    Returns:
        Five directories, in this order:

        * workspace ``shared``
        * workspace ``prompts``
        * workspace ``partials``
        * the packaged template root
        * the packaged ``shared`` subdirectory

        Existence is not checked.
    """
    prompts_dir = workspace_root / ".agent" / "prompts"
    bundled_dir = packaged_template_root()
    return (
        prompts_dir / "shared",
        prompts_dir,
        prompts_dir / "partials",
        bundled_dir,
        bundled_dir / "shared",
    )