# Source code for ralph.mcp.webvisit.fetcher

"""HTTP fetch layer for the visit_url tool.

Performs a single HTTP GET with SSRF-guard, size cap, and timeout enforcement.
No network IO should escape this module in production code paths.
"""

from __future__ import annotations

import ipaddress
import socket
import threading
from dataclasses import dataclass
from typing import Literal
from urllib.parse import urlparse

import httpx

# Closed set of outcome labels carried in ``FetchOutcome.status``.
FetchStatus = Literal[
    "ok",  # body was read in full within the size cap
    "timeout",  # httpx raised a TimeoutException
    "unreachable",  # any other httpx transport/protocol error
    "http_error",  # response status outside the 2xx range
    "unsupported_content",  # Content-Type not in _SUPPORTED_CONTENT_TYPES
    "too_large",  # body exceeded the caller's max_bytes
    "blocked_by_policy",  # target is a private/loopback address and that is not allowed
    "invalid_url",  # unsupported scheme or missing hostname
]

# Media types (lower-case, without parameters) whose bodies fetch_url will read.
# Compared against the Content-Type header after parameters such as
# "; charset=..." are stripped and the value is lower-cased.
_SUPPORTED_CONTENT_TYPES = frozenset(
    {
        "text/html",
        "text/plain",
        "application/xhtml+xml",
        "application/xml",
        "text/xml",
    }
)

# Hostnames treated as private without any DNS lookup; matched case-insensitively.
_PRIVATE_HOSTNAMES = frozenset({"localhost"})

# Half-open range of HTTP status codes treated as success: MIN <= status < MAX.
_HTTP_SUCCESS_MIN = 200
_HTTP_SUCCESS_MAX = 300


@dataclass(frozen=True)
class FetchOutcome:
    """Result of a single HTTP fetch attempt.

    Instances are immutable. Only ``status`` is required; every other field
    defaults to ``None`` and is filled in when the corresponding information
    is available for that outcome.

    Attributes:
        status: Outcome label; one of the ``FetchStatus`` literals.
        effective_url: URL of the response that was received, as reported by
            the HTTP client. ``None`` when no response was obtained.
        http_status: HTTP status code of the response, if any.
        content_type: Raw ``Content-Type`` header value of the response
            (parameters included), if any.
        body: Response body bytes. Populated only when ``status`` is ``"ok"``.
        error: Human-readable description of the failure for non-``"ok"``
            outcomes.
    """

    status: FetchStatus
    effective_url: str | None = None
    http_status: int | None = None
    content_type: str | None = None
    body: bytes | None = None
    error: str | None = None
_DNS_SSRF_CHECK_TIMEOUT = 5.0  # seconds; upper bound on the SSRF-guard DNS lookup


def _is_private_ip(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
    """Return True when *ip* must not be contacted under the default policy.

    An address is rejected when it is loopback, private, link-local,
    multicast, reserved, unspecified, or otherwise not globally routable.
    """
    return (
        ip.is_loopback
        or ip.is_private
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_reserved
        or ip.is_unspecified
        # ``is_private`` deliberately excludes the 100.64.0.0/10 shared address
        # space (RFC 6598), yet that range is not publicly routable either.
        or not ip.is_global
    )


def _resolve_host_addresses(host: str, *, timeout: float) -> list[str] | None:
    """Resolve *host* to IP address strings, waiting at most *timeout* seconds.

    ``socket.getaddrinfo`` has no timeout parameter and is not affected by
    ``socket.setdefaulttimeout`` (which only applies to new socket objects), so
    the lookup runs in a daemon thread that is joined with a deadline.

    Returns:
        The resolved addresses; an empty list when resolution failed; or
        ``None`` when the lookup did not finish within *timeout*.
    """
    outcome: list[list[str]] = []

    def _lookup() -> None:
        try:
            infos = socket.getaddrinfo(host, None, proto=socket.IPPROTO_TCP)
        except (OSError, UnicodeError):
            # UnicodeError: the host could not be IDNA-encoded (e.g. a label
            # longer than 63 characters). Treated like any failed resolution.
            outcome.append([])
        else:
            outcome.append([sockaddr[0] for *_ignored, sockaddr in infos])

    worker = threading.Thread(target=_lookup, name="webvisit-dns-check", daemon=True)
    worker.start()
    worker.join(timeout)
    return outcome[0] if outcome else None


def _is_private_address(host: str) -> bool:
    """Return True when *host* names, or resolves to, a disallowed address.

    Resolution failures return False (there is nothing to connect to, and the
    subsequent HTTP request reports the failure). A lookup that exceeds
    ``_DNS_SSRF_CHECK_TIMEOUT`` returns True: the target could not be verified,
    so the guard fails closed.

    NOTE: the HTTP client resolves the hostname again when it connects, so a
    name whose DNS answer changes between this check and the connection is
    not caught here.
    """
    if host.lower() in _PRIVATE_HOSTNAMES:
        return True

    # A literal IP address needs no DNS lookup.
    try:
        return _is_private_ip(ipaddress.ip_address(host))
    except ValueError:
        pass

    addresses = _resolve_host_addresses(host, timeout=_DNS_SSRF_CHECK_TIMEOUT)
    if addresses is None:
        return True

    for addr in addresses:
        try:
            if _is_private_ip(ipaddress.ip_address(addr)):
                return True
        except ValueError:
            # Not parseable as an IP address; skip it as before.
            continue
    return False


def _content_type_base(content_type_header: str | None) -> str | None:
    """Return the lower-cased media type of a Content-Type header.

    Parameters after the first ``;`` (such as ``charset``) are dropped.
    Returns ``None`` when the header is missing or empty.
    """
    if not content_type_header:
        return None
    media_type, _sep, _params = content_type_header.partition(";")
    return media_type.strip().lower()


def _check_url_policy(url: str, *, allow_private_networks: bool) -> FetchOutcome | None:
    """Validate URL scheme, hostname, and SSRF policy.

    Returns an error FetchOutcome or None.

    Args:
        url: The URL about to be requested.
        allow_private_networks: When False, hosts that are, or resolve to,
            private/loopback/non-global addresses are rejected.

    Returns:
        ``None`` when the URL may be fetched; otherwise a ``FetchOutcome``
        with status ``"invalid_url"`` or ``"blocked_by_policy"``.
    """
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname
    except ValueError as exc:
        # urlparse rejects some malformed inputs (e.g. an unclosed "[" in the
        # netloc) by raising instead of returning a result.
        return FetchOutcome(status="invalid_url", error=f"malformed URL: {exc}")

    if parsed.scheme not in {"http", "https"}:
        return FetchOutcome(status="invalid_url", error=f"unsupported scheme: {parsed.scheme!r}")
    if not hostname:
        return FetchOutcome(status="invalid_url", error="missing hostname")
    if not allow_private_networks and _is_private_address(hostname):
        return FetchOutcome(
            status="blocked_by_policy",
            error=(
                "access to private/loopback networks is disabled by default; "
                "set allow_private_networks=true in [web_visit] config to enable"
            ),
        )
    return None


def _read_streaming_body(
    response: httpx.Response,
    *,
    max_bytes: int,
    effective_url: str,
    http_status: int,
    content_type_header: str | None,
) -> FetchOutcome:
    """Read a streamed response body, aborting once it exceeds *max_bytes*.

    Args:
        response: An open streaming response whose body has not been read.
        max_bytes: Maximum number of body bytes accepted.
        effective_url: URL recorded on the resulting outcome.
        http_status: Status code recorded on the resulting outcome.
        content_type_header: Raw Content-Type header recorded on the outcome.

    Returns:
        An ``"ok"`` outcome carrying the full body, or a ``"too_large"``
        outcome (without a body) as soon as the running total passes the cap.
    """
    chunks: list[bytes] = []
    received = 0
    for chunk in response.iter_bytes():
        received += len(chunk)
        if received > max_bytes:
            # Stop immediately; the rest of the stream is never pulled.
            return FetchOutcome(
                status="too_large",
                effective_url=effective_url,
                http_status=http_status,
                content_type=content_type_header,
                error=f"response body exceeds {max_bytes} bytes",
            )
        chunks.append(chunk)
    return FetchOutcome(
        status="ok",
        effective_url=effective_url,
        http_status=http_status,
        content_type=content_type_header,
        body=b"".join(chunks),
    )
_MAX_REDIRECTS = 20  # same ceiling as httpx's default ``max_redirects``


def fetch_url(
    url: str,
    *,
    timeout_ms: int,
    max_bytes: int,
    user_agent: str,
    allow_private_networks: bool,
) -> FetchOutcome:
    """Fetch a single URL and return a FetchOutcome.

    Never raises on network failures — always returns a FetchOutcome.

    Redirects are followed manually rather than by the HTTP client so that
    every hop is re-validated by ``_check_url_policy``. Otherwise a public URL
    could redirect to a private/loopback target and bypass the SSRF guard.

    Args:
        url: Absolute ``http``/``https`` URL to fetch.
        timeout_ms: Timeout handed to the HTTP client, in milliseconds.
        max_bytes: Maximum response body size accepted.
        user_agent: Value sent in the ``User-Agent`` request header.
        allow_private_networks: When False, requests (including redirect
            targets) to private/loopback addresses are refused.

    Returns:
        A ``FetchOutcome`` whose ``status`` describes what happened; ``body``
        is populated only for ``"ok"``.
    """
    policy_error = _check_url_policy(url, allow_private_networks=allow_private_networks)
    if policy_error is not None:
        return policy_error

    timeout = timeout_ms / 1000.0
    headers = {"User-Agent": user_agent}
    current_url = url

    try:
        with httpx.Client(follow_redirects=False, timeout=timeout) as client:
            # One initial request plus at most _MAX_REDIRECTS redirect hops.
            for _hop in range(_MAX_REDIRECTS + 1):
                with client.stream("GET", current_url, headers=headers) as response:
                    # With follow_redirects=False, httpx sets ``next_request``
                    # only when the response is a redirect with a Location.
                    next_request = response.next_request
                    if next_request is not None:
                        next_url = str(next_request.url)
                        hop_error = _check_url_policy(
                            next_url, allow_private_networks=allow_private_networks
                        )
                        if hop_error is not None:
                            return FetchOutcome(
                                status=hop_error.status,
                                effective_url=next_url,
                                http_status=response.status_code,
                                error=hop_error.error,
                            )
                        current_url = next_url
                        continue

                    effective_url: str = str(response.url)
                    http_status: int = response.status_code
                    content_type_header: str | None = response.headers.get("content-type")
                    content_type_base = _content_type_base(content_type_header)

                    if http_status < _HTTP_SUCCESS_MIN or http_status >= _HTTP_SUCCESS_MAX:
                        return FetchOutcome(
                            status="http_error",
                            effective_url=effective_url,
                            http_status=http_status,
                            content_type=content_type_header,
                            error=f"HTTP {http_status}",
                        )

                    if content_type_base not in _SUPPORTED_CONTENT_TYPES:
                        return FetchOutcome(
                            status="unsupported_content",
                            effective_url=effective_url,
                            http_status=http_status,
                            content_type=content_type_header,
                            error=f"unsupported content type: {content_type_header!r}",
                        )

                    return _read_streaming_body(
                        response,
                        max_bytes=max_bytes,
                        effective_url=effective_url,
                        http_status=http_status,
                        content_type_header=content_type_header,
                    )

            # Redirect budget exhausted; reported as before, when httpx raised
            # TooManyRedirects and it was mapped to "unreachable".
            return FetchOutcome(
                status="unreachable",
                effective_url=current_url,
                error="Exceeded maximum allowed redirects.",
            )
    except httpx.TimeoutException as exc:
        return FetchOutcome(status="timeout", error=str(exc))
    except httpx.InvalidURL as exc:
        # InvalidURL is not an httpx.HTTPError subclass, so it needs its own arm.
        return FetchOutcome(status="invalid_url", error=str(exc))
    except httpx.HTTPError as exc:
        # Covers ConnectError, RemoteProtocolError and the other transport errors.
        return FetchOutcome(status="unreachable", error=str(exc))
# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = ["FetchOutcome", "FetchStatus", "fetch_url"]