Source code for secantus.server

from __future__ import annotations

import contextlib
import itertools
import logging
import socket
import threading
from types import TracebackType
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from typing import Self

from secantus.auth import ConnectionAuth
from secantus.commands import CommandContext, dispatch
from secantus.connreg import ConnectionRegistry
from secantus.cursors import CursorRegistry
from secantus.logbuf import LogBuffer
from secantus.metrics import Metrics
from secantus.sessions import SessionRegistry
from secantus.storage import Storage
from secantus.wire import (
    OP_MSG_FLAG_MORE_TO_COME,
    ConnectionClosed,
    MalformedBodyError,
    OpMsg,
    OpQuery,
    WireProtocolError,
    build_op_msg_reply,
    build_op_reply,
    read_message,
)

logger = logging.getLogger(__name__)


def _merge_op_msg_body(op: OpMsg) -> dict[str, Any]:
    """Flatten an OP_MSG into a single command document.

    Starts from a shallow copy of ``op.body`` and stores each
    ``(identifier, docs)`` pair from ``op.document_sequences`` under its
    identifier, overwriting any same-named key in the body. ``op.body``
    itself is left untouched.
    """
    merged: dict[str, Any] = dict(op.body)
    merged.update((name, payload) for name, payload in op.document_sequences)
    return merged


def _db_from_namespace(ns: str) -> str:
    """Return the database part of a ``db.collection`` namespace.

    The database is everything before the first ``"."``. When that prefix
    is empty (empty namespace, or one starting with a dot) the result is
    ``"admin"``.
    """
    prefix = ns.split(".", 1)[0]
    if prefix:
        return prefix
    return "admin"


# Default per-connection idle timeout, in seconds. The server applies it
# as the socket timeout on every accepted connection, so a client that
# connects and sends nothing (or stalls mid-message) gets disconnected
# once a blocking socket operation has waited this long. Without it, a
# hostile peer can pin a handler thread forever by holding the socket
# open and not writing.
DEFAULT_CLIENT_IDLE_TIMEOUT_S = 300.0

# Default cap on simultaneous client connections. Each connection runs
# on its own dedicated thread (there is no shared pool), so without a
# cap a TCP-flood DoS makes the process spawn threads without bound.
DEFAULT_MAX_CONNECTIONS = 1000


class SecantusDBServer:
    """Threaded TCP server that answers OP_MSG / OP_QUERY wire messages.

    One accept thread (``secantus-accept``) hands every accepted socket to
    its own daemon handler thread. Each handler reads messages with
    ``read_message``, routes the command document through ``dispatch`` and
    writes the reply back on the same socket.

    Args:
        host: Interface to bind. Replaced by the actual bound host once
            ``start()`` has run.
        port: Port to bind. Replaced by the actual bound port once
            ``start()`` has run (so ``0`` becomes the OS-assigned port).
        storage_path: Passed to ``Storage`` as its location.
        replica_set_name: Forwarded to every ``CommandContext``. ``None``
            also turns oplog writes off in ``Storage``.
        require_auth: Forwarded to every ``CommandContext``.
        ttl_sweep_seconds: Forwarded to ``Storage``.
        client_idle_timeout_s: Socket timeout applied to each accepted
            connection.
        max_connections: Maximum number of simultaneously served
            connections; sockets accepted beyond it are closed at once.

    The instance is also a context manager: entering calls ``start()``,
    leaving calls ``stop()``.
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 0,
        storage_path: str = "./secantus-data",
        *,
        replica_set_name: str | None = "secantus",
        require_auth: bool = False,
        ttl_sweep_seconds: float = 60.0,
        client_idle_timeout_s: float = DEFAULT_CLIENT_IDLE_TIMEOUT_S,
        max_connections: int = DEFAULT_MAX_CONNECTIONS,
    ) -> None:
        self.host = host
        self.port = port
        self.replica_set_name = replica_set_name
        self.require_auth = require_auth
        self.client_idle_timeout_s = client_idle_timeout_s
        self.max_connections = max_connections
        self._socket: socket.socket | None = None
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        # Active connection counter — incremented by _serve_forever when
        # it admits a connection, decremented by _handle_client on its
        # way out. Compared against max_connections in _serve_forever.
        self._active_conns = 0
        self._active_conns_lock = threading.Lock()
        # Oplog writes are only useful when change-stream clients can
        # connect, which requires replica-set advertisement in `hello`.
        # Skip them in pure-standalone mode to drop a per-write BSON
        # encode + oplog-table cursor write per modified document.
        self.storage = Storage(
            storage_path,
            enable_oplog=replica_set_name is not None,
            ttl_sweep_seconds=ttl_sweep_seconds,
        )
        self.cursors = CursorRegistry()
        # Per-server counters surfaced through `serverStatus`. Started
        # eagerly so `start_monotonic` reflects construction time, not
        # the first command — uptime then matches what users expect.
        self.metrics = Metrics()
        # Per-connection visibility (currentOp) and an in-memory log
        # buffer (getLog). Both initialised eagerly so the registry /
        # buffer survive the lifetime of the server, not just the first
        # connection.
        self.connections = ConnectionRegistry()
        self.logs = LogBuffer()
        # Logical-session registry: drivers send an lsid on every
        # command and bracket session lifetime via ``startSession`` /
        # ``endSessions`` / ``refreshSessions``. Tracked here so the
        # session-management commands actually have state to operate
        # on instead of returning canned ``{ok: 1}``.
        self.sessions = SessionRegistry()

    @property
    def address(self) -> tuple[str, int]:
        """Return the ``(host, port)`` the listening socket is bound to.

        Raises:
            RuntimeError: If the server is not started.
        """
        # Read the attribute once: stop() may set it to None concurrently.
        sock = self._socket
        if sock is None:
            raise RuntimeError("server is not started")
        host, port, *_ = sock.getsockname()
        return host, port

    @property
    def uri(self) -> str:
        """Return a ``mongodb://host:port/`` URI for the bound address.

        Raises:
            RuntimeError: If the server is not started.
        """
        host, port = self.address
        return f"mongodb://{host}:{port}/"

    def start(self) -> None:
        """Bind the listening socket and launch the accept thread.

        Raises:
            RuntimeError: If the server is already started.
            OSError: If the socket cannot be bound or put into listen mode.
        """
        if self._socket is not None:
            raise RuntimeError("server is already started")
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((self.host, self.port))
            sock.listen()
            bound_host, bound_port = sock.getsockname()[:2]
        except BaseException:
            # Don't leak the descriptor when bind/listen fails
            # (e.g. address already in use).
            sock.close()
            raise
        self._socket = sock
        self.host, self.port = bound_host, bound_port
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._serve_forever, name="secantus-accept", daemon=True
        )
        self._thread.start()
        logger.info("secantus listening on %s:%d", self.host, self.port)

    def stop(self) -> None:
        """Signal shutdown, close the listener and close the storage.

        Waits up to two seconds for the accept thread. Handler threads
        are not joined; they leave their loop the next time they observe
        the stop event.
        """
        self._stop_event.set()
        sock, self._socket = self._socket, None
        if sock is not None:
            # NOTE(review): on some platforms close() alone does not wake
            # a thread blocked in accept(); shutdown() first is a
            # harmless hint (errors on platforms that reject it for a
            # listening socket are suppressed).
            with contextlib.suppress(OSError):
                sock.shutdown(socket.SHUT_RDWR)
            with contextlib.suppress(OSError):
                sock.close()
        accept_thread, self._thread = self._thread, None
        if accept_thread is not None:
            accept_thread.join(timeout=2.0)
        self.storage.close()

    def wait(self) -> None:
        """Block the calling thread until ``stop()`` has been called."""
        self._stop_event.wait()

    def _serve_forever(self) -> None:
        """Accept loop: admit connections up to the cap, one thread each."""
        # Bind the listener locally: stop() resets self._socket to None,
        # and re-reading the attribute on every iteration could race
        # with that and fail on None.
        listener = self._socket
        if listener is None:
            return
        while not self._stop_event.is_set():
            try:
                conn, addr = listener.accept()
            except OSError as exc:
                if not self._stop_event.is_set():
                    # Not a shutdown: leave a trace instead of dying
                    # silently.
                    logger.warning("accept loop terminated: %s", exc)
                return
            # Refuse new connections beyond the cap. A flood-DoS attacker
            # would otherwise spawn a thread per accepted socket. We close
            # the socket immediately rather than queuing — clients should
            # see "connection refused" / EOF, not block.
            with self._active_conns_lock:
                active = self._active_conns
                admitted = active < self.max_connections
                if admitted:
                    self._active_conns = active + 1
            if not admitted:
                logger.warning(
                    "rejecting connection from %s: %d active >= %d cap",
                    addr,
                    active,
                    self.max_connections,
                )
                with contextlib.suppress(OSError):
                    conn.close()
                continue
            # Set an idle timeout on the socket so a hostile or stuck
            # peer can't pin a thread forever. Applies to all socket
            # reads in _handle_client.
            with contextlib.suppress(OSError):
                conn.settimeout(self.client_idle_timeout_s)
            handler = threading.Thread(
                target=self._handle_client,
                args=(conn, addr),
                daemon=True,
            )
            try:
                handler.start()
            except RuntimeError:
                # The thread could not be started, so _handle_client will
                # never run: give the reserved slot back and drop the
                # socket, then keep accepting.
                logger.exception("could not start handler thread for %s", addr)
                with self._active_conns_lock:
                    self._active_conns -= 1
                with contextlib.suppress(OSError):
                    conn.close()

    def _handle_client(self, conn: socket.socket, addr: tuple[str, int]) -> None:
        """Serve one admitted connection, then release its slot.

        The slot reserved by ``_serve_forever`` is released and the
        socket closed no matter how serving ends — including a failure
        while the connection is still being registered.
        """
        try:
            self._run_connection(conn, addr)
        finally:
            with contextlib.suppress(OSError):
                conn.close()
            with self._active_conns_lock:
                self._active_conns -= 1

    def _run_connection(self, conn: socket.socket, addr: tuple[str, int]) -> None:
        """Register the connection and run its request/reply loop."""
        connection_id = self.connections.open((addr[0], addr[1]))
        reply_ids = itertools.count(1)
        connection_auth = ConnectionAuth()
        self.metrics.connection_opened()
        logger.debug("client %d connected from %s", connection_id, addr)
        self.logs.append(
            "I", "NETWORK", "connection accepted", {"conn_id": connection_id, "from": addr}
        )
        try:
            with conn:
                while not self._stop_event.is_set():
                    try:
                        message = read_message(conn)
                    except ConnectionClosed:
                        return
                    except TimeoutError:
                        # Idle timeout fired — drop the connection so the
                        # thread can be reaped instead of pinned forever.
                        logger.debug("client %d idle timeout, closing", connection_id)
                        return
                    except MalformedBodyError as exc:
                        # Body bytes failed BSON validation. Build a
                        # targeted error reply so the client gets a
                        # ``BadValue`` and can keep using the connection
                        # for subsequent commands. Mongod returns the
                        # same shape — InvalidBSON in a command body is
                        # not a protocol-level fault.
                        logger.warning("malformed BSON from %s: %s", addr, exc)
                        try:
                            reply = build_op_msg_reply(
                                response_to=exc.header.request_id,
                                request_id=next(reply_ids),
                                body={
                                    "ok": 0.0,
                                    "errmsg": str(exc),
                                    "code": 2,
                                    "codeName": "BadValue",
                                },
                            )
                            conn.sendall(reply)
                        except OSError:
                            # Client may have hung up before we could
                            # respond; fall through to the close path.
                            return
                        continue
                    except WireProtocolError as exc:
                        logger.warning("wire protocol error from %s: %s", addr, exc)
                        return
                    except OSError as exc:
                        # Socket-level failure while reading: a routine
                        # disconnect, not an unhandled server error.
                        logger.debug(
                            "socket error on connection %d: %s", connection_id, exc
                        )
                        return

                    op = message.op
                    server_addr = self._server_address_or_none()

                    if isinstance(op, OpMsg):
                        body = _merge_op_msg_body(op)
                        ctx = self._make_context(
                            connection_id,
                            body.get("$db", "admin"),
                            server_addr,
                            connection_auth,
                        )
                        response_doc = dispatch(body, ctx)
                        # `moreToCome` (bit 1) is the wire signal for
                        # fire-and-forget requests — `writeConcern: {w: 0}`
                        # unacknowledged writes use it. The spec requires the
                        # server NOT to send a reply; if we do, the driver
                        # mis-pairs it with the next response and aborts the
                        # connection with a responseTo/requestId mismatch.
                        if op.flags & OP_MSG_FLAG_MORE_TO_COME:
                            continue
                        reply = build_op_msg_reply(
                            response_to=message.header.request_id,
                            request_id=next(reply_ids),
                            body=response_doc,
                        )
                    elif isinstance(op, OpQuery):
                        ctx = self._make_context(
                            connection_id,
                            _db_from_namespace(op.full_collection_name),
                            server_addr,
                            connection_auth,
                        )
                        response_doc = dispatch(op.query, ctx)
                        reply = build_op_reply(
                            response_to=message.header.request_id,
                            request_id=next(reply_ids),
                            documents=[response_doc],
                        )
                    else:
                        logger.warning("unhandled op type %r from %s", type(op), addr)
                        return

                    try:
                        conn.sendall(reply)
                    except OSError:
                        return
        except Exception:
            logger.exception("unhandled error on connection %d", connection_id)
        finally:
            self.metrics.connection_closed()
            self.connections.close(connection_id)
            logger.debug("client %d disconnected", connection_id)

    def _server_address_or_none(self) -> tuple[str, int] | None:
        """Return the listener's bound address, or ``None`` if unavailable."""
        sock = self._socket
        if sock is None:
            return None
        try:
            host, port, *_ = sock.getsockname()
        except OSError:
            # Listener closed concurrently by stop().
            return None
        return host, port

    def _make_context(
        self,
        connection_id: Any,
        db_name: str,
        server_addr: tuple[str, int] | None,
        connection_auth: ConnectionAuth,
    ) -> CommandContext:
        """Build the ``CommandContext`` handed to ``dispatch`` for one command."""
        return CommandContext(
            connection_id=connection_id,
            storage=self.storage,
            cursors=self.cursors,
            db_name=db_name,
            server_address=server_addr,
            replica_set_name=self.replica_set_name,
            connection_auth=connection_auth,
            require_auth=self.require_auth,
            metrics=self.metrics,
            connections=self.connections,
            logs=self.logs,
            sessions=self.sessions,
        )

    def __enter__(self) -> Self:
        """Start the server and return it."""
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Stop the server; exceptions from the ``with`` body propagate."""
        self.stop()