"""StdioUpstreamClient — upstream MCP client that communicates over stdio."""
from __future__ import annotations
import json
import os
import subprocess
from collections.abc import Callable, Mapping
from typing import TYPE_CHECKING, cast
from ralph.mcp.upstream.models import UpstreamCallError, UpstreamTool
from ralph.process.manager import SpawnOptions, get_process_manager
if TYPE_CHECKING:
from ralph.mcp.upstream.config import UpstreamMcpServer
# A decoded JSON object (request params, response result, ...).
JsonObject = dict[str, object]
# Transport signature: ``caller(method, params)`` returns the JSON-RPC ``result`` object.
JsonRpcCaller = Callable[[str, JsonObject], JsonObject]
class StdioUpstreamClient:
    """Upstream MCP client that talks to a server subprocess over stdio.

    Every request goes through a single :data:`JsonRpcCaller`. Unless one is
    injected through ``caller``, it is produced by ``_make_stdio_caller(server)``.
    Any exception the caller raises that is not already an
    :class:`UpstreamCallError` is re-raised as one, chained to the original.
    """

    def __init__(
        self,
        server: UpstreamMcpServer,
        *,
        caller: JsonRpcCaller | None = None,
    ) -> None:
        """Store ``server`` and select the transport.

        Args:
            server: Upstream server configuration; its ``name`` appears in
                error messages.
            caller: Optional transport override (for example a fake in tests).
                When omitted, the default stdio transport for ``server`` is built.
        """
        self._server = server
        if caller is None:
            caller = _make_stdio_caller(server)
        self._caller: JsonRpcCaller = caller

    def list_tools(self) -> list[UpstreamTool]:
        """Request ``tools/list`` and parse the reply into tool descriptors.

        Raises:
            UpstreamCallError: If the underlying call fails.
        """
        listing = self._request("tools/list", {}, "tools/list")
        return _parse_tools(listing)

    def call_tool(self, name: str, arguments: JsonObject) -> object:
        """Invoke ``tools/call`` for tool ``name`` and return the raw result.

        Raises:
            UpstreamCallError: If the underlying call fails.
        """
        request: JsonObject = {"name": name, "arguments": arguments}
        return self._request("tools/call", request, f"tool '{name}'")

    def _request(self, method: str, params: JsonObject, label: str) -> JsonObject:
        """Send one request, normalising failures to :class:`UpstreamCallError`.

        ``label`` names the operation inside the error message.
        """
        try:
            return self._caller(method, params)
        except UpstreamCallError:
            raise
        except Exception as exc:
            raise UpstreamCallError(
                f"upstream server '{self._server.name}' {label} failed: {exc}"
            ) from exc
def _parse_tools(result: JsonObject) -> list[UpstreamTool]:
raw_tools = result.get("tools")
if not isinstance(raw_tools, list):
return []
tools: list[UpstreamTool] = []
for item in raw_tools:
if not isinstance(item, Mapping):
continue
item_map = cast("Mapping[str, object]", item)
name = item_map.get("name")
if not isinstance(name, str) or not name:
continue
description_raw = item_map.get("description")
description = str(description_raw) if description_raw is not None else ""
schema_raw = item_map.get("inputSchema") or item_map.get("input_schema")
if isinstance(schema_raw, Mapping):
input_schema: dict[str, object] = dict(cast("Mapping[str, object]", schema_raw))
else:
input_schema = {}
tools.append(UpstreamTool(name=name, description=description, input_schema=input_schema))
return tools
def _make_stdio_caller(server: UpstreamMcpServer) -> JsonRpcCaller:
def _call(method: str, params: JsonObject) -> JsonObject:
if not server.command:
raise UpstreamCallError(f"upstream server '{server.name}' has no command configured")
command = [server.command, *server.args]
initialize_payload: JsonObject = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "ralph-upstream", "version": "0"},
},
}
initialized_payload: JsonObject = {
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": {},
}
method_payload: JsonObject = {
"jsonrpc": "2.0",
"id": 2,
"method": method,
"params": params,
}
payload_lines = [
json.dumps(initialize_payload, separators=(",", ":")),
json.dumps(initialized_payload, separators=(",", ":")),
json.dumps(method_payload, separators=(",", ":")),
]
payload = "\n".join(payload_lines) + "\n"
env: dict[str, str] = {**os.environ, **server.env}
handle = get_process_manager().spawn(
command,
SpawnOptions(
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env,
label=f"upstream:{server.name}",
),
)
try:
stdout_bytes, _stderr = handle.communicate(input=payload.encode(), timeout=30)
except subprocess.TimeoutExpired:
handle.terminate(grace_period_s=0)
raise UpstreamCallError(f"upstream server '{server.name}' timed out") from None
if (handle.returncode or 0) != 0:
raise UpstreamCallError(
f"upstream server '{server.name}' process exited {handle.returncode}"
)
stdout_str = stdout_bytes.decode() if stdout_bytes else ""
stdout_lines = [line for line in stdout_str.splitlines() if line.strip()]
if not stdout_lines:
raise UpstreamCallError(f"upstream server '{server.name}' returned no JSON-RPC output")
raw: object = json.loads(stdout_lines[-1])
return _json_rpc_result(raw, f"'{server.name}'")
return _call
def _json_rpc_result(raw: object, context: str) -> JsonObject:
if not isinstance(raw, Mapping):
raise UpstreamCallError(f"unexpected response type from {context}")
raw_map = cast("Mapping[str, object]", raw)
err = raw_map.get("error")
if err is not None:
raise UpstreamCallError(f"JSON-RPC error from {context}: {err}")
result = raw_map.get("result")
if isinstance(result, Mapping):
return dict(cast("Mapping[str, object]", result))
return {}
# Public API of this module: only the client class is exported.
__all__: list[str] = [
    "StdioUpstreamClient",
]