Source code for ralph.recovery.connectivity

"""Proactive connectivity detection with auto-resume."""

from __future__ import annotations

import asyncio
import contextlib
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime

from ralph.recovery.connectivity_event import ConnectivityEvent
from ralph.recovery.connectivity_state import ConnectivityState

__all__ = ["ConnectivityEvent", "ConnectivityMonitor", "ConnectivityState"]

# Default probe targets: DNS resolvers over TCP
_DEFAULT_PROBE_TARGETS: list[tuple[str, int]] = [
    ("1.1.1.1", 53),
    ("8.8.8.8", 53),
]
_DEFAULT_PROBE_INTERVAL_S: float = 10.0
_DEFAULT_PROBE_TIMEOUT_S: float = 2.0


async def _default_probe(host: str, port: int, timeout_s: float) -> bool:
    """Attempt TCP connection; return True if reachable."""
    try:
        pair: tuple[asyncio.StreamReader, asyncio.StreamWriter] = await asyncio.wait_for(
            asyncio.open_connection(host, port),
            timeout=timeout_s,
        )
        writer = pair[1]
        writer.close()
        with contextlib.suppress(Exception):
            await writer.wait_closed()
        return True
    except Exception:
        return False


ProbeCallable = Callable[[str, int, float], Awaitable[bool]]
ListenerCallable = Callable[[ConnectivityEvent], None]


[docs] class ConnectivityMonitor: """Proactively detect connectivity loss and surface state transitions. All timing and network I/O is injectable so tests run deterministically without real sockets. """ def __init__( self, *, probe_targets: list[tuple[str, int]] = _DEFAULT_PROBE_TARGETS, probe_interval_s: float = _DEFAULT_PROBE_INTERVAL_S, probe_timeout_s: float = _DEFAULT_PROBE_TIMEOUT_S, probe: ProbeCallable | None = None, ) -> None: self._targets = probe_targets self._interval = probe_interval_s self._timeout = probe_timeout_s self._probe = probe or _default_probe self._current_state: ConnectivityState = ConnectivityState.UNKNOWN self._listeners: list[ListenerCallable] = [] self._task: asyncio.Task[None] | None = None self._online_event: asyncio.Event = asyncio.Event() self._stopped: bool = False @property def current_state(self) -> ConnectivityState: return self._current_state
[docs] def add_listener(self, cb: ListenerCallable) -> Callable[[], None]: """Register a listener for connectivity events. Returns an unsubscribe callable.""" self._listeners.append(cb) def _unsubscribe() -> None: self._listeners.remove(cb) return _unsubscribe
[docs] async def start(self) -> None: """Start the background connectivity probe loop.""" self._stopped = False self._task = asyncio.create_task(self._probe_loop())
[docs] async def stop(self) -> None: """Stop the background probe loop and unblock any waiters.""" self._stopped = True self._online_event.set() # unblock any wait_online() callers if self._task is not None: self._task.cancel() with contextlib.suppress(asyncio.CancelledError, Exception): await self._task self._task = None
[docs] async def wait_online(self) -> None: """Suspend until connectivity is restored (or monitor is stopped).""" await self._online_event.wait()
async def _probe_once(self) -> bool: """Probe all targets; return True if any are reachable.""" results = await asyncio.gather( *[self._probe(host, port, self._timeout) for host, port in self._targets], return_exceptions=True, ) return any(r is True for r in results) async def _probe_loop(self) -> None: while not self._stopped: reachable = await self._probe_once() new_state = ConnectivityState.ONLINE if reachable else ConnectivityState.OFFLINE if new_state != self._current_state: self._current_state = new_state if new_state == ConnectivityState.ONLINE: self._online_event.set() else: self._online_event.clear() evt = ConnectivityEvent( state=new_state, since=datetime.now(UTC), reason="probe" if reachable else "all probes failed", ) for listener in list(self._listeners): with contextlib.suppress(Exception): listener(evt) await asyncio.sleep(self._interval)