"""Startup validation for user-defined upstream MCP servers.
Ralph fails fast if any custom MCP server cannot complete the standard
``initialize`` → ``notifications/initialized`` → ``tools/list`` handshake.
Set ``RALPH_MCP_STRICT=0`` to fall back to the legacy warn-and-skip
behaviour for CI smoke runs.
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from loguru import logger
from ralph.mcp.protocol.startup import (
PreflightError,
mcp_preflight_timeout_from_env,
preflight_http_mcp_server_tools,
)
from ralph.mcp.upstream._upstream_server_report import UpstreamServerReport
from ralph.mcp.upstream._upstream_validation_error import UpstreamValidationError
from ralph.mcp.upstream.client import make_upstream_client
from ralph.mcp.upstream.models import UpstreamCallError
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Mapping
from datetime import timedelta
from ralph.mcp.upstream.config import UpstreamMcpServer
if TYPE_CHECKING:
class HttpPreflightFn(Protocol):
"""Callable protocol for running an HTTP MCP server preflight check."""
def __call__(
self, endpoint: str, required_tools: tuple[str, ...], timeout: timedelta
) -> None: ...
_STRICT_ENV_VAR = "RALPH_MCP_STRICT"
_STRICT_FALSE_VALUES = frozenset({"0", "false", "no", "off"})
[docs]
@dataclass(frozen=True)
class UpstreamValidationReport:
"""Aggregated validation results for all configured upstream MCP servers."""
servers: tuple[UpstreamServerReport, ...]
@property
def all_ok(self) -> bool:
return all(s.ok for s in self.servers)
@property
def failures(self) -> tuple[UpstreamServerReport, ...]:
return tuple(s for s in self.servers if not s.ok)
[docs]
def strict_mode_from_env(env: Mapping[str, str] | None = None) -> bool:
"""Return True when strict mode is active (the default)."""
env_map = os.environ if env is None else env
raw = env_map.get(_STRICT_ENV_VAR)
if raw is None:
return True
return raw.strip().lower() not in _STRICT_FALSE_VALUES
def _list_stdio_tools(server: UpstreamMcpServer, timeout: timedelta) -> list[str]:
"""Probe an stdio upstream by spawning the configured client.
The ``timeout`` budget bounds the subprocess via :mod:`subprocess.run` so we
never hang the orchestrator when an MCP binary forgets to flush stdout.
"""
del timeout # subprocess timeout is enforced by the underlying client
client = make_upstream_client(server)
return [tool.name for tool in client.list_tools()]
def _redact_error(server: UpstreamMcpServer, exc: BaseException) -> str:
"""Render an exception message with upstream env values stripped out."""
message = str(exc)
for value in server.env.values():
if value:
message = message.replace(value, "***")
return message
def _format_failure_report(failures: Iterable[UpstreamServerReport]) -> str:
lines: list[str] = []
for failure in failures:
keys_part = f" env_keys={list(failure.secret_keys)}" if failure.secret_keys else ""
lines.append(
f"- {failure.name} (transport={failure.transport}){keys_part}: {failure.error}"
)
return "\n".join(lines)
[docs]
def validate_upstream_mcp_servers(
servers: Iterable[UpstreamMcpServer],
*,
timeout: timedelta | None = None,
strict: bool | None = None,
preflight_http: HttpPreflightFn = preflight_http_mcp_server_tools,
list_stdio_tools: Callable[[UpstreamMcpServer, timedelta], list[str]] | None = None,
) -> UpstreamValidationReport:
"""Validate every configured upstream MCP server at startup.
Args:
servers: Iterable of normalized upstream MCP server definitions.
timeout: Optional preflight timeout. Defaults to
:func:`mcp_preflight_timeout_from_env` (30s, tunable via
``RALPH_MCP_PREFLIGHT_TIMEOUT_MS``).
strict: Override strict-mode autodetection. If unset, reads
``RALPH_MCP_STRICT`` from the environment.
preflight_http: Injection point for the HTTP preflight helper. Tests
override this to drive the validator without touching the network.
list_stdio_tools: Injection point for the stdio probe. Defaults to
:func:`_list_stdio_tools`, which spawns the configured command
through :class:`StdioUpstreamClient`.
Returns:
:class:`UpstreamValidationReport` with one entry per server. In soft
mode failures are reported with ``ok=False`` and a warning is logged
per failure. In strict mode an :class:`UpstreamValidationError` is
raised after all servers are inspected so the diagnostic listing names
every problem at once.
"""
effective_timeout = timeout or mcp_preflight_timeout_from_env()
effective_strict = strict_mode_from_env() if strict is None else strict
effective_stdio_probe = list_stdio_tools or _list_stdio_tools
reports: list[UpstreamServerReport] = []
server_list = list(servers)
for server in server_list:
secret_keys = tuple(sorted(server.env.keys()))
try:
tool_count = _probe_one_server(
server,
effective_timeout,
preflight_http=preflight_http,
list_stdio_tools=effective_stdio_probe,
)
except (PreflightError, UpstreamCallError, ValueError, OSError) as exc:
reports.append(
UpstreamServerReport(
name=server.name,
transport=server.transport,
ok=False,
tool_count=0,
error=_redact_error(server, exc),
secret_keys=secret_keys,
)
)
continue
reports.append(
UpstreamServerReport(
name=server.name,
transport=server.transport,
ok=True,
tool_count=tool_count,
error=None,
secret_keys=secret_keys,
)
)
report = UpstreamValidationReport(servers=tuple(reports))
failures = report.failures
if not failures:
if server_list:
logger.info("Validated {} custom MCP server(s); all reachable.", len(server_list))
return report
if effective_strict:
raise UpstreamValidationError(
"Custom MCP servers failed startup validation:\n" + _format_failure_report(failures)
)
for failure in failures:
logger.warning(
"Custom MCP server '{}' ({}) failed validation: {}",
failure.name,
failure.transport,
failure.error,
)
return report
def _probe_one_server(
server: UpstreamMcpServer,
timeout: timedelta,
*,
preflight_http: HttpPreflightFn,
list_stdio_tools: Callable[[UpstreamMcpServer, timedelta], list[str]],
) -> int:
if server.transport == "http":
if not server.url:
raise ValueError(f"upstream server '{server.name}' is missing 'url'")
preflight_http(server.url, (), timeout)
# The preflight helper does not return tool counts, so probe again to
# enumerate tools for diagnostic display.
client = make_upstream_client(server)
return len(client.list_tools())
if server.transport == "stdio":
if not server.command:
raise ValueError(f"upstream server '{server.name}' is missing 'command'")
return len(list_stdio_tools(server, timeout))
raise ValueError(
f"upstream server '{server.name}' has unsupported transport '{server.transport}'"
)
__all__ = [
"UpstreamServerReport",
"UpstreamValidationError",
"UpstreamValidationReport",
"strict_mode_from_env",
"validate_upstream_mcp_servers",
]