# Source code for ralph.mcp.protocol.transport

"""MCP transport layer.

Provides transport abstractions for MCP communication. Supports stdio and
HTTP/SSE connections to MCP clients and servers.
"""

from __future__ import annotations

import asyncio
import contextlib
import json
import threading
from queue import Empty, Queue
from subprocess import PIPE as _SUBPROCESS_PIPE
from subprocess import TimeoutExpired
from typing import IO, TYPE_CHECKING, Protocol, cast

from loguru import logger

from ralph.mcp.protocol._mcp_message import MCPMessage
from ralph.mcp.protocol._transport_error import TransportError
from ralph.process.manager import (
    ManagedProcess,
    ProcessTerminationError,
    SpawnOptions,
    get_process_manager,
)

__all__ = [
    "MCPMessage",
    "MCPTransport",
    "StdioTransport",
    "TransportError",
]

if TYPE_CHECKING:
    from collections.abc import AsyncIterator, Callable
if TYPE_CHECKING:

    class ProcessLike(Protocol):
        """Subset of subprocess.Popen required by StdioTransport."""

        @property
        def stdin(self) -> IO[bytes] | None: ...

        @property
        def stdout(self) -> IO[bytes] | None: ...

        @property
        def stderr(self) -> IO[bytes] | None: ...

        def terminate(self) -> None: ...
        def wait(self, timeout: float | None = None) -> int: ...
        def kill(self) -> None: ...

    class ThreadLike(Protocol):
        """Structural type for the worker-thread objects StdioTransport launches.

        Only ``start`` is required, matching how the transport uses the
        objects returned by its thread factory.
        """

        def start(self) -> None:
            """Begin executing the thread's target."""


class StdioTransport:
    """MCP transport over stdio.

    Communicates with an MCP server process via stdin/stdout. Each line is a
    JSON-RPC message. Two worker threads do the blocking pipe I/O: one reads
    lines from the child's stdout into a receive queue, the other drains a
    send queue into the child's stdin.

    NOTE(review): the default process factory opens a stderr pipe, but nothing
    in this class reads it. Confirm the process manager drains it; otherwise a
    server that writes a lot to stderr could block on a full pipe.
    """

    # Seconds the writer thread blocks on the send queue before re-checking
    # the closed flag.
    _POLL_INTERVAL_S = 0.1
    # Seconds ``recv`` sleeps (yielding to the event loop) when no message is
    # queued.
    _RECV_POLL_INTERVAL_S = 0.01
    # Seconds a terminated process gets to exit before it is killed.
    _TERMINATE_GRACE_S = 5.0

    def __init__(
        self,
        command: list[str],
        cwd: str | None = None,
        *,
        process_factory: Callable[[list[str], str | None], ProcessLike] | None = None,
        thread_factory: Callable[[Callable[[], None], bool], ThreadLike] | None = None,
    ) -> None:
        """Initialize stdio transport.

        Args:
            command: Command and arguments to spawn as MCP server.
            cwd: Working directory for the subprocess.
            process_factory: Callable ``(command, cwd) -> process`` used to
                spawn the server. Defaults to the module's default factory.
            thread_factory: Callable ``(target, daemon) -> thread`` used to
                create the reader/writer threads. Defaults to the module's
                default factory.
        """
        self._command = command
        self._cwd = cwd
        self._process_factory = process_factory or _default_process_factory
        self._thread_factory = thread_factory or _default_thread_factory
        self._process: ProcessLike | None = None
        self._reader_thread: ThreadLike | None = None
        self._writer_thread: ThreadLike | None = None
        self._send_queue: Queue[str | dict[str, object]] = Queue()
        self._recv_queue: Queue[MCPMessage] = Queue()
        self._closed = False
        self._lock = threading.Lock()

    def start(self) -> None:
        """Start the MCP server process and its reader/writer threads.

        A repeated call while a process is already attached is ignored, so the
        first process is not orphaned.

        Raises:
            TransportError: If the transport has already been closed; a
                process spawned at that point would never be terminated.
        """
        with self._lock:
            if self._closed:
                raise TransportError("Transport is closed")
            if self._process is not None:
                logger.warning("MCP server already started; ignoring repeated start()")
                return
            self._process = self._process_factory(self._command, self._cwd)
        self._reader_thread = self._thread_factory(self._read_loop, True)
        self._writer_thread = self._thread_factory(self._write_loop, True)
        self._reader_thread.start()
        self._writer_thread.start()
        logger.info("Started MCP server: {}", " ".join(self._command))

    @staticmethod
    def _decode_line(raw: bytes) -> MCPMessage | None:
        """Parse one stdout line into an ``MCPMessage``.

        Returns ``None`` (after logging a warning) when the line is not valid
        UTF-8, not valid JSON, or not a JSON object, so that a single bad line
        cannot take down the reader thread.
        """
        try:
            payload = json.loads(raw.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            logger.warning("Failed to parse MCP message: {}", exc)
            return None
        if not isinstance(payload, dict):
            logger.warning(
                "Ignoring non-object MCP message of type {}", type(payload).__name__
            )
            return None

        fields = cast("dict[str, object]", payload)
        method_value = fields.get("method")
        params_value = fields.get("params")
        error_value = fields.get("error")
        msg_id_value = fields.get("id")
        # Fields of an unexpected type are replaced by a neutral default
        # instead of rejecting the whole message.
        return MCPMessage(
            method=method_value if isinstance(method_value, str) else "",
            params=(
                cast("dict[str, object]", params_value)
                if isinstance(params_value, dict)
                else None
            ),
            msg_id=msg_id_value if isinstance(msg_id_value, str | int) else None,
            error=(
                cast("dict[str, object]", error_value)
                if isinstance(error_value, dict)
                else None
            ),
        )

    def _read_loop(self) -> None:
        """Read from stdout in a loop, queueing every parseable message."""
        proc = self._process
        if proc is None or proc.stdout is None:
            return
        for raw_line in proc.stdout:
            if self._closed:
                break
            stripped = raw_line.strip()
            if not stripped:
                continue
            message = self._decode_line(stripped)
            if message is not None:
                self._recv_queue.put(message)

    def _write_loop(self) -> None:
        """Write queued messages to stdin in a loop, one JSON line each."""
        proc = self._process
        if proc is None or proc.stdin is None:
            return
        stdin = proc.stdin
        while not self._closed:
            try:
                payload = self._send_queue.get(timeout=self._POLL_INTERVAL_S)
            except Empty:
                continue
            try:
                line = (json.dumps(payload) + "\n").encode("utf-8")
            except (TypeError, ValueError) as exc:
                # One unserializable payload must not stop later messages.
                logger.warning("Dropping unserializable MCP message: {}", exc)
                continue
            try:
                stdin.write(line)
                stdin.flush()
            except (OSError, ValueError):
                # The pipe is broken or already closed; nothing more can be sent.
                break

    async def send(self, message: MCPMessage) -> None:
        """Queue a message for delivery to the MCP server.

        The message is serialized as a JSON-RPC 2.0 object; ``params`` and
        ``id`` are included only when set on ``message``.

        Raises:
            TransportError: If the transport is closed.
        """
        if self._closed:
            raise TransportError("Transport is closed")
        # JSON-RPC 2.0 requires the version member on every message.
        msg_dict: dict[str, object] = {"jsonrpc": "2.0", "method": message.method}
        if message.params is not None:
            msg_dict["params"] = message.params
        if message.msg_id is not None:
            msg_dict["id"] = message.msg_id
        self._send_queue.put(msg_dict)

    async def recv(self) -> AsyncIterator[MCPMessage]:
        """Receive messages as an async iterator.

        Yields queued messages in arrival order until the transport is
        closed. While the queue is empty the coroutine sleeps instead of
        blocking, so other tasks on the event loop keep running.
        """
        while not self._closed:
            try:
                message = self._recv_queue.get_nowait()
            except Empty:
                await asyncio.sleep(self._RECV_POLL_INTERVAL_S)
                continue
            yield message

    def _shutdown_process(self, proc: ProcessLike) -> None:
        """Terminate ``proc`` and close whichever of its pipes exist."""
        if isinstance(proc, ManagedProcess):
            with contextlib.suppress(ProcessTerminationError):
                proc.terminate(grace_period_s=self._TERMINATE_GRACE_S)
        else:
            proc.terminate()
            try:
                proc.wait(timeout=self._TERMINATE_GRACE_S)
            except TimeoutExpired:
                # The child ignored the terminate request; force it down so
                # close() cannot hang indefinitely.
                proc.kill()
                proc.wait()
        for attr in ("stdin", "stdout", "stderr"):
            pipe: IO[bytes] | None = getattr(proc, attr, None)
            if pipe is not None:
                # Best-effort cleanup: a pipe that is already closed or
                # broken is not an error at this point.
                with contextlib.suppress(Exception):
                    pipe.close()

    async def close(self) -> None:
        """Close the transport.

        Marks the transport closed (which stops the worker loops and
        ``recv``), then terminates the server process and closes its pipes.
        Calling it again is a no-op.
        """
        with self._lock:
            if self._closed:
                return
            self._closed = True
            proc = self._process
            self._process = None
        if proc is not None:
            self._shutdown_process(proc)
        logger.info("Closed stdio transport")
# Alias: the generic transport name currently resolves to the stdio transport.
MCPTransport = StdioTransport


def _default_process_factory(command: list[str], cwd: str | None) -> ManagedProcess:
    """Spawn ``command`` through the process manager with all stdio piped.

    The process is labelled ``mcp-stdio:<first element of command>``.
    """
    label = f"mcp-stdio:{command[0]}"
    manager = get_process_manager()
    options = SpawnOptions(
        cwd=cwd,
        stdin=_SUBPROCESS_PIPE,
        stdout=_SUBPROCESS_PIPE,
        stderr=_SUBPROCESS_PIPE,
        label=label,
    )
    return manager.spawn(command, options)


def _default_thread_factory(target: Callable[[], None], daemon: bool) -> threading.Thread:
    """Create (but do not start) a thread that will run ``target``."""
    worker = threading.Thread(target=target, daemon=daemon)
    return worker