"""Shared MCP transport helpers: mcp.toml loading, upstream merging, env serialization."""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, cast
from loguru import logger
from ralph.config.mcp_loader import load_mcp_config
from ralph.mcp.upstream.config import (
UPSTREAM_MCP_CONFIG_ENV,
UpstreamMcpServer,
serialize_upstream_mcp_servers,
)
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from ralph.config.mcp_models import McpConfig
[docs]
def mcp_toml_as_upstreams(workspace_path: Path | None) -> tuple[UpstreamMcpServer, ...]:
    """Load ``.agent/mcp.toml`` and return the configured upstream MCP servers.

    Args:
        workspace_path: Workspace root directory. The config file is expected at
            ``<workspace_path>/.agent/mcp.toml``. When ``None``, no path is
            derived and ``load_mcp_config`` is called with ``config_path=None``.

    Returns:
        The servers from the loaded config, converted by
        ``mcp_config_as_upstreams`` (each tagged ``origin="custom"``).
    """
    if workspace_path is None:
        toml_path = None
    else:
        toml_path = workspace_path / ".agent" / "mcp.toml"
    loaded = load_mcp_config(config_path=toml_path)
    return mcp_config_as_upstreams(loaded)
def _parse_json_config_file(path: Path) -> dict[str, object]:
if not path.exists():
return {}
try:
raw_payload: object = json.loads(path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
return {}
if not isinstance(raw_payload, dict):
return {}
return cast("dict[str, object]", raw_payload)
def _load_mcpservers_from_paths(
    paths: tuple[Path, ...],
    entry_normalizer: Callable[[str, object], tuple[str, object] | None] | None = None,
) -> dict[str, object]:
    """Collect the ``mcpServers`` objects from several JSON config files into one dict.

    Files are processed in the order given. An entry from a later file replaces
    a same-named entry from an earlier one. Files that are missing, unparsable,
    or lack an ``mcpServers`` object are skipped.

    Args:
        paths: JSON config files to read, lowest precedence first.
        entry_normalizer: Optional hook called as ``(name, entry)`` for every
            server entry. A ``None`` result drops the entry. Otherwise the
            returned pair's first element is used as the key and its second
            element as the stored value. When omitted, entries are copied as-is.

    Returns:
        Mapping of server name to (possibly normalized) server entry.
    """
    combined: dict[str, object] = {}
    for config_path in paths:
        # An empty dict from the parser has no "mcpServers" key, so it falls
        # through the isinstance check below just like a non-object value.
        servers_value = _parse_json_config_file(config_path).get("mcpServers")
        if not isinstance(servers_value, dict):
            continue
        entries = cast("dict[str, object]", servers_value)
        if entry_normalizer is not None:
            for entry_name, entry_value in entries.items():
                outcome = entry_normalizer(entry_name, entry_value)
                if outcome is None:
                    continue
                combined[outcome[0]] = outcome[1]
        else:
            combined.update(entries)
    return combined
[docs]
def mcp_config_as_upstreams(mcp_config: McpConfig) -> tuple[UpstreamMcpServer, ...]:
    """Convert loaded MCP config into Ralph custom upstream server records.

    Args:
        mcp_config: Loaded configuration. Every value of its ``mcp_servers``
            mapping is converted, in iteration order.

    Returns:
        One ``UpstreamMcpServer`` per configured server. ``args`` is copied into
        a tuple, ``env`` into a new dict, and ``origin`` is set to ``"custom"``.
    """
    records: list[UpstreamMcpServer] = []
    for server_spec in mcp_config.mcp_servers.values():
        # Copy args/env so the records do not share mutable state with the config.
        record = UpstreamMcpServer(
            name=server_spec.name,
            transport=server_spec.transport,
            url=server_spec.url,
            command=server_spec.command,
            args=tuple(server_spec.args),
            env=dict(server_spec.env),
            origin="custom",
        )
        records.append(record)
    return tuple(records)
[docs]
def merge_mcp_toml_into_upstreams(
    agent_native: tuple[UpstreamMcpServer, ...],
    mcp_toml_servers: tuple[UpstreamMcpServer, ...],
) -> tuple[UpstreamMcpServer, ...]:
    """Merge mcp.toml servers into agent-native upstreams, preferring mcp.toml on conflict.

    Servers are keyed by ``name``. The result lists agent-native servers first,
    in their original order. An agent-native server overridden by mcp.toml keeps
    its position but holds the mcp.toml record. mcp.toml servers with new names
    follow. Where either input repeats a name, the later record wins.

    Args:
        agent_native: Upstreams discovered from the agent's own configuration.
        mcp_toml_servers: Upstreams loaded from ``mcp.toml``.

    Returns:
        The merged upstreams, one per distinct name.
    """
    merged: dict[str, UpstreamMcpServer] = {s.name: s for s in agent_native}
    # Snapshot the agent-native names before merging. This way the warning only
    # reports a genuine agent-native override, not an mcp.toml entry that
    # repeats a name introduced earlier by mcp.toml itself.
    native_names = frozenset(merged)
    for server in mcp_toml_servers:
        if server.name in native_names:
            logger.warning(
                "mcp.toml server '{}' overrides agent-native upstream config",
                server.name,
            )
        merged[server.name] = server
    return tuple(merged.values())
[docs]
def set_upstream_mcp_config(
    runtime_env: dict[str, str], upstreams: tuple[UpstreamMcpServer, ...]
) -> None:
    """Write (or clear) the serialized upstream MCP config in ``runtime_env``.

    Args:
        runtime_env: Environment mapping, modified in place.
        upstreams: Servers to expose. When non-empty, they are serialized with
            ``serialize_upstream_mcp_servers`` under ``UPSTREAM_MCP_CONFIG_ENV``.
            When empty, that key is removed if present, so no stale value is left.
    """
    if not upstreams:
        runtime_env.pop(UPSTREAM_MCP_CONFIG_ENV, None)
    else:
        runtime_env[UPSTREAM_MCP_CONFIG_ENV] = serialize_upstream_mcp_servers(upstreams)
# Public API of this module. The underscore-prefixed JSON helpers are not exported.
__all__ = [
    "mcp_config_as_upstreams",
    "mcp_toml_as_upstreams",
    "merge_mcp_toml_into_upstreams",
    "set_upstream_mcp_config",
]