"""Source code for ``secantus.server``: the secantus wire-protocol server."""

from __future__ import annotations

import contextlib
import itertools
import logging
import socket
import ssl
import threading
import time
from types import TracebackType
from typing import TYPE_CHECKING, Any

import bson

if TYPE_CHECKING:
    from typing import Self

from secantus.auth import ConnectionAuth, subject_dn_from_peercert
from secantus.commands import CommandContext, dispatch
from secantus.connreg import ConnectionRegistry
from secantus.cursors import CursorRegistry
from secantus.failpoints import CloseConnectionRequested, FailPointRegistry
from secantus.logbuf import LogBuffer
from secantus.metrics import Metrics
from secantus.sessions import SessionRegistry
from secantus.storage import Storage
from secantus.transactions import Transaction, TransactionRegistry
from secantus.wire import (
    OP_MSG_FLAG_EXHAUST_ALLOWED,
    OP_MSG_FLAG_MORE_TO_COME,
    ConnectionClosed,
    MalformedBodyError,
    OpMsg,
    OpQuery,
    WireProtocolError,
    build_op_msg_reply,
    build_op_reply,
    read_message,
)

# Module-level logger, named after this module via ``__name__``; every
# server diagnostic below (accept / TLS / per-connection) goes through it.
logger = logging.getLogger(__name__)


def _merge_op_msg_body(op: OpMsg) -> dict[str, Any]:
    body = dict(op.body)
    for identifier, docs in op.document_sequences:
        body[identifier] = docs
    return body


def _db_from_namespace(ns: str) -> str:
    db, _, _ = ns.partition(".")
    return db or "admin"


# Seconds a client connection may stay silent (nothing sent at all, or a
# message stalled half-way) before the server drops it. This keeps a
# peer that opens a socket and never writes from holding a handler
# thread indefinitely. Five minutes.
DEFAULT_CLIENT_IDLE_TIMEOUT_S = 5 * 60.0

# Upper bound on concurrently open client connections. Every connection
# is served by its own thread, so without a limit a connection flood
# could keep spawning threads until the process runs out of resources.
DEFAULT_MAX_CONNECTIONS = 1_000


class SecantusDBServer:
    """MongoDB wire-protocol server backed by a secantus ``Storage``.

    ``start()`` binds a listening TCP socket and runs an accept loop on a
    daemon thread. Each admitted client is served on its own daemon thread
    (``_handle_client``), which reads OP_MSG / OP_QUERY frames, runs them
    through ``secantus.commands.dispatch`` and writes the replies back.
    ``stop()`` closes the listener, drains the per-connection threads and
    then closes storage. The server is also usable as a context manager.
    """

    # Upper bound, in seconds, on the TLS handshake. The handshake runs on
    # the accept thread, so without a bound a peer that connects and never
    # sends a ClientHello would block ``wrap_socket`` (and with it every
    # later accept) indefinitely.
    _TLS_HANDSHAKE_TIMEOUT_S: float = 10.0

    # Pause after a transient ``accept()`` failure (e.g. the process is out
    # of file descriptors) so the accept loop retries without spinning hot.
    _ACCEPT_ERROR_BACKOFF_S: float = 0.05

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 0,
        storage_path: str = "./secantus-data",
        *,
        replica_set_name: str | None = "secantus",
        require_auth: bool = False,
        ttl_sweep_seconds: float = 60.0,
        noop_heartbeat_seconds: float = 0.0,
        client_idle_timeout_s: float = DEFAULT_CLIENT_IDLE_TIMEOUT_S,
        max_connections: int = DEFAULT_MAX_CONNECTIONS,
        oplog_retention_seconds: float = 3600.0,
        oplog_max_entries: int = 100_000,
        cache_size: str = "1G",
        session_max: int = 1000,
        sync_on_commit: bool = False,
        transaction_lifetime_seconds: float = 60.0,
        tls_cert_file: str | None = None,
        tls_key_file: str | None = None,
        tls_ca_file: str | None = None,
        tls_require_client_cert: bool = False,
        engine: str | None = None,
    ) -> None:
        """Configure the server and open its storage (does not listen yet).

        Raises:
            ValueError: on an inconsistent TLS / mTLS configuration.
            FileNotFoundError, ssl.SSLError: when the TLS cert / key / CA
                files cannot be loaded.
        """
        # TLS: when both cert + key files are provided, accept()ed sockets
        # are wrapped with an SSLContext before being handed off to the
        # connection thread. Clients then negotiate TLS first, and only
        # then start sending mongo wire frames over the encrypted channel.
        # Without TLS the daemon stays plaintext as it always has.
        #
        # All TLS validation and context building happens BEFORE the
        # process-wide engine selection below, so a rejected configuration
        # doesn't leave the process with a changed engine.
        if (tls_cert_file is None) != (tls_key_file is None):
            raise ValueError("tls_cert_file and tls_key_file must both be set or both be None")
        # mTLS knobs (ca_file / require_client_cert) only make sense when
        # server-side TLS is on — without it, there's no TLS handshake at
        # which to verify a client cert. Reject the misconfiguration loudly
        # rather than silently ignoring the mTLS settings.
        if tls_cert_file is None and (tls_ca_file is not None or tls_require_client_cert):
            raise ValueError(
                "tls_ca_file / tls_require_client_cert require tls_cert_file "
                "and tls_key_file (mTLS is a layer on top of server-side TLS)"
            )
        if tls_require_client_cert and tls_ca_file is None:
            raise ValueError(
                "tls_require_client_cert=True requires tls_ca_file so the "
                "presented client cert can be verified"
            )
        ssl_context: ssl.SSLContext | None = None
        if tls_cert_file is not None and tls_key_file is not None:
            # PROTOCOL_TLS_SERVER picks the highest TLS version both ends
            # support and refuses SSLv2/3 — Python's recommended preset
            # for new server code. load_cert_chain raises FileNotFoundError
            # or ssl.SSLError on bad inputs; we let those propagate so a
            # misconfigured server fails loudly at startup rather than
            # silently falling back to plaintext.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
            ssl_context.load_cert_chain(certfile=tls_cert_file, keyfile=tls_key_file)
            # mTLS: if a CA bundle is configured, ask clients for a cert
            # during the handshake. ``CERT_REQUIRED`` rejects clients
            # without one; ``CERT_OPTIONAL`` accepts both — useful for
            # staged rollouts where some clients are not yet
            # cert-enabled.
            if tls_ca_file is not None:
                ssl_context.load_verify_locations(cafile=tls_ca_file)
                ssl_context.verify_mode = (
                    ssl.CERT_REQUIRED if tls_require_client_cert else ssl.CERT_OPTIONAL
                )

        # Engine selection is process-wide (see ``secantus.engine``). ``None``
        # leaves the current selection (SECANTUS_ENGINE env / default Python)
        # untouched; "python" / "rust" / "auto" set it for the whole process.
        if engine is not None:
            from secantus import engine as _engine

            _engine.set_engine(engine)

        self.host = host
        self.port = port
        self.replica_set_name = replica_set_name
        self.require_auth = require_auth
        self.client_idle_timeout_s = client_idle_timeout_s
        self.max_connections = max_connections
        self._socket: socket.socket | None = None
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self._ssl_context: ssl.SSLContext | None = ssl_context
        self.tls_cert_file = tls_cert_file
        self.tls_key_file = tls_key_file
        self.tls_ca_file = tls_ca_file
        self.tls_require_client_cert = tls_require_client_cert
        # Active connection counter — incremented by _serve_forever when a
        # connection is admitted, decremented when it is torn down
        # (_release_connection_slot). Read by _serve_forever to enforce
        # max_connections.
        self._active_conns = 0
        self._active_conns_lock = threading.Lock()
        # Oplog writes are only useful when change-stream clients can
        # connect, which requires replica-set advertisement in `hello`.
        # Skip them in pure-standalone mode to drop a per-write BSON
        # encode + oplog-table cursor write per modified document.
        self.storage = Storage(
            storage_path,
            enable_oplog=replica_set_name is not None,
            ttl_sweep_seconds=ttl_sweep_seconds,
            noop_heartbeat_seconds=noop_heartbeat_seconds,
            oplog_retention_seconds=oplog_retention_seconds,
            oplog_max_entries=oplog_max_entries,
            cache_size=cache_size,
            session_max=session_max,
            sync_on_commit=sync_on_commit,
        )
        # Replica-set initiation: seed the bootstrap oplog noop so
        # ``local.oplog.rs`` is never empty (mongod parity). No-op when the
        # oplog is disabled or already populated.
        self.storage.ensure_oplog_bootstrap()
        self.cursors = CursorRegistry()
        # Per-server counters surfaced through `serverStatus`. Started
        # eagerly so `start_monotonic` reflects construction time, not
        # the first command — uptime then matches what users expect.
        self.metrics = Metrics()
        # Per-connection visibility (currentOp) and an in-memory log
        # buffer (getLog). Both initialised eagerly so the registry /
        # buffer survive the lifetime of the server, not just the first
        # connection.
        self.connections = ConnectionRegistry()
        self.logs = LogBuffer()
        # Logical-session registry: drivers send an lsid on every
        # command and bracket session lifetime via ``startSession`` /
        # ``endSessions`` / ``refreshSessions``. Tracked here so the
        # session-management commands actually have state to operate
        # on instead of returning canned ``{ok: 1}``.
        self.sessions = SessionRegistry()
        # Per-server registry of active ``configureFailPoint`` entries.
        # Driver test suites lean on this to inject deterministic
        # errors at the wire (failCommand → errorCode / writeConcernError).
        # See ``secantus.failpoints`` for the supported subset.
        self.failpoints = FailPointRegistry()
        # Multi-document transaction state machine. The WT work is
        # bound here so the registry itself stays storage-agnostic;
        # ``txn.handle`` is None when the transaction never executed a
        # statement (the WT session is created lazily at the first one).
        self.transactions = TransactionRegistry(
            commit_func=self._commit_txn_handle,
            rollback_func=self._rollback_txn_handle,
            lifetime_seconds=transaction_lifetime_seconds,
        )

    def _commit_txn_handle(self, txn: Transaction) -> None:
        """Commit callback for ``TransactionRegistry``; no-op without a WT handle."""
        if txn.handle is not None:
            self.storage.commit_user_transaction(
                txn.handle, lsid_doc=txn.lsid_doc, txn_number=txn.txn_number
            )

    def _rollback_txn_handle(self, txn: Transaction) -> None:
        """Rollback callback for ``TransactionRegistry``; no-op without a WT handle."""
        if txn.handle is not None:
            self.storage.abort_user_transaction(txn.handle)

    @property
    def address(self) -> tuple[str, int]:
        """``(host, port)`` the listening socket is bound to.

        Raises:
            RuntimeError: if the server is not started.
        """
        if self._socket is None:
            raise RuntimeError("server is not started")
        host, port, *_ = self._socket.getsockname()
        return host, port

    @property
    def uri(self) -> str:
        """``mongodb://host:port/`` connection string for the bound address."""
        host, port = self.address
        return f"mongodb://{host}:{port}/"

    def start(self) -> None:
        """Bind the listening socket and launch the accept thread.

        Raises:
            RuntimeError: if the server is already started.
            OSError: if the address cannot be bound / listened on; the
                half-initialised socket is closed before re-raising.
        """
        if self._socket is not None:
            raise RuntimeError("server is already started")
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((self.host, self.port))
            sock.listen()
        except BaseException:
            # Don't leak the descriptor when bind/listen fails (e.g. the
            # port is already in use).
            sock.close()
            raise
        self._socket = sock
        self.host, self.port = sock.getsockname()[:2]
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._serve_forever, name="secantus-accept", daemon=True
        )
        self._thread.start()
        logger.info("secantus listening on %s:%d", self.host, self.port)

    def stop(self) -> None:
        """Stop accepting, drain connection threads, then close storage."""
        self._stop_event.set()
        if self._socket is not None:
            with contextlib.suppress(OSError):
                self._socket.close()
            self._socket = None
        if self._thread is not None:
            self._thread.join(timeout=2.0)
            self._thread = None
        # Drain in-flight per-connection handler threads BEFORE touching
        # WiredTiger. Each connection runs on its own daemon thread; if one is
        # mid-WT-operation when ``storage.close()`` frees the WT connection,
        # that's a use-after-free — a native crash (the intermittent xdist
        # worker death). Close every connection socket so blocked reads return
        # and the loops exit, wake any tailable getMore blocked on the oplog
        # condition variable, then wait for the active-connection count to
        # reach zero so no handler is still using storage.
        self.connections.close_all()
        self.storage.signal_shutdown()
        self._await_connections_drained(timeout=5.0)
        # Roll back open transactions while storage is still usable;
        # ``Storage.close()``'s session sweep would roll them back too,
        # but this releases their WT sessions in an orderly way first.
        self.transactions.abort_all()
        self.storage.close()

    def _await_connections_drained(self, timeout: float) -> None:
        """Block until every per-connection handler thread has exited (the
        active-connection counter reaches zero), or ``timeout`` elapses.

        Run at stop, after the connection sockets are closed and tailable
        waiters woken, so storage isn't torn down under an in-flight handler.

        ``close_all`` is re-run on every poll, not just once before this loop.
        The accept thread bumps ``_active_conns`` and spawns the handler
        *before* the handler registers its socket via ``connections.open``; a
        connection accepted in the instant before ``stop`` can therefore
        register its socket *after* the initial ``close_all`` snapshot, leaving
        its blocking ``recv`` un-woken and the drain stuck until the idle
        timeout. Re-snapshotting each poll closes such late arrivals within
        5ms (every call is idempotent — already-closed sockets are no-ops)."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            with self._active_conns_lock:
                if self._active_conns == 0:
                    return
            self.connections.close_all()
            time.sleep(0.005)
        with self._active_conns_lock:
            remaining = self._active_conns
        if remaining:
            logger.warning(
                "server stop: %d connection thread(s) still active after %.1fs; "
                "closing storage anyway",
                remaining,
                timeout,
            )

    def wait(self) -> None:
        """Block until ``stop()`` has been called."""
        self._stop_event.wait()

    def _release_connection_slot(self) -> None:
        """Give back one slot of the ``max_connections`` budget."""
        with self._active_conns_lock:
            self._active_conns -= 1

    def _serve_forever(self) -> None:
        """Accept loop run on the ``secantus-accept`` thread.

        Admits each new client against ``max_connections``, performs the TLS
        handshake when TLS is configured, applies the idle timeout and hands
        the socket to a fresh ``_handle_client`` daemon thread. Returns once
        ``stop()`` has set the stop event and closed the listening socket.
        """
        # Bind the listener locally: ``stop()`` sets ``self._socket`` to None
        # concurrently, and re-reading the attribute on every iteration could
        # end up calling ``None.accept()`` instead of getting a clean OSError
        # from the closed socket.
        sock = self._socket
        assert sock is not None
        while not self._stop_event.is_set():
            try:
                conn, addr = sock.accept()
            except OSError as exc:
                # ``stop()`` closing the listener surfaces here; that is the
                # normal way out of the loop.
                if self._stop_event.is_set() or sock.fileno() == -1:
                    return
                # Anything else (e.g. EMFILE / ENFILE when out of file
                # descriptors, ECONNABORTED for a peer that gave up while
                # queued) is transient. Returning here would silently stop
                # the server from accepting while it still looks started.
                logger.warning("accept() failed, retrying: %s", exc)
                time.sleep(self._ACCEPT_ERROR_BACKOFF_S)
                continue
            # Refuse new connections beyond the cap. A flood-DoS attacker
            # would otherwise spawn a thread per accepted socket. We close
            # the socket immediately rather than queuing — clients should
            # see "connection refused" / EOF, not block.
            with self._active_conns_lock:
                active = self._active_conns
                admitted = active < self.max_connections
                if admitted:
                    self._active_conns += 1
            if not admitted:
                logger.warning(
                    "rejecting connection from %s: %d active >= %d cap",
                    addr,
                    active,
                    self.max_connections,
                )
                with contextlib.suppress(OSError):
                    conn.close()
                continue
            # TLS handshake (when configured) happens here, on the accept
            # thread, BEFORE _handle_client takes over. wrap_socket blocks
            # on the handshake, so it is bounded by _TLS_HANDSHAKE_TIMEOUT_S
            # (a silent peer would otherwise stall every later accept). A
            # malformed / plaintext / too-slow client surfaces as
            # ssl.SSLError or OSError (incl. TimeoutError), which we log +
            # close cleanly so a single bad client can't take the daemon
            # down or stall the loop.
            if self._ssl_context is not None:
                try:
                    conn.settimeout(self._TLS_HANDSHAKE_TIMEOUT_S)
                    conn = self._ssl_context.wrap_socket(conn, server_side=True)
                except (ssl.SSLError, OSError) as exc:
                    logger.warning("TLS handshake from %s failed: %s", addr, exc)
                    with contextlib.suppress(OSError):
                        conn.close()
                    self._release_connection_slot()
                    continue
            # Set an idle timeout on the socket so a hostile or stuck
            # peer can't pin a thread forever. Applies to all socket
            # reads in _handle_client.
            with contextlib.suppress(OSError):
                conn.settimeout(self.client_idle_timeout_s)
            handler = threading.Thread(
                target=self._handle_client,
                args=(conn, addr),
                daemon=True,
            )
            try:
                handler.start()
            except RuntimeError as exc:
                # "can't start new thread": the handler will never run, so
                # its slot and socket must be released here — and the accept
                # loop must survive rather than die with the exception.
                logger.warning("cannot start handler thread for %s: %s", addr, exc)
                with contextlib.suppress(OSError):
                    conn.close()
                self._release_connection_slot()

    def _stream_exhaust_getmore(
        self,
        conn: socket.socket,
        response_to: int,
        reply_ids: itertools.count[int],
        body: dict[str, Any],
        ctx: CommandContext,
        first_doc: dict[str, Any],
    ) -> bool:
        """Stream the rest of an exhaust cursor over one socket.

        Called when a getMore arrives with the OP_MSG ``exhaustAllowed``
        flag set. ``first_doc`` is the reply the getMore handler already
        produced. We send that batch (and every subsequent one) with the
        ``moreToCome`` flag set, pulling further batches with synthetic
        getMores, until the cursor drains — then a final reply with
        ``id: 0`` and ``moreToCome`` clear closes the stream.

        mongod keeps the cursor alive until a getMore returns an empty
        batch, so even a cursor that drains exactly on a non-empty batch
        gets a trailing empty reply (this is what makes pymongo's
        command-monitoring see ``find, getMore, getMore, getMore`` for
        three docs at batchSize 1).

        Returns True if the whole stream was written (connection survives),
        False if a socket write failed (caller should drop the connection).
        """
        target_id = bson.Int64(int(body["getMore"]))
        coll = body.get("collection", "")
        ns = first_doc["cursor"].get("ns", f"{ctx.db_name}.{coll}")
        db = body.get("$db", ctx.db_name)

        def send(doc: dict[str, Any], *, more: bool) -> bool:
            flags = OP_MSG_FLAG_MORE_TO_COME if more else 0
            try:
                conn.sendall(
                    build_op_msg_reply(
                        response_to=response_to,
                        request_id=next(reply_ids),
                        body=doc,
                        flags=flags,
                    )
                )
            except OSError:
                return False
            return True

        doc = first_doc
        while True:
            cursor = doc.get("cursor")
            if not isinstance(cursor, dict):
                # An error reply (ok: 0) mid-stream — deliver it without
                # moreToCome and end the stream.
                return send(doc, more=False)
            batch = cursor.get("nextBatch", cursor.get("firstBatch", []))
            drained = int(cursor.get("id", 0)) == 0
            if drained:
                if batch and not send(
                    {
                        "cursor": {"nextBatch": batch, "id": target_id, "ns": ns},
                        "ok": 1.0,
                    },
                    more=True,
                ):
                    return False
                return send(
                    {
                        "cursor": {"nextBatch": [], "id": bson.Int64(0), "ns": ns},
                        "ok": 1.0,
                    },
                    more=False,
                )
            if not batch:
                # A live cursor that yielded nothing this round — a
                # tailable / awaitData getMore whose wait expired. Don't
                # keep streaming (that would spin forever); deliver this
                # empty batch without moreToCome and let the client fall
                # back to ordinary getMores. Normal cursors never reach
                # here (an empty batch always drains the cursor to id 0).
                return send(doc, more=False)
            if not send(doc, more=True):
                return False
            getmore: dict[str, Any] = {"getMore": target_id, "collection": coll}
            if "batchSize" in body:
                getmore["batchSize"] = body["batchSize"]
            if "maxTimeMS" in body:
                getmore["maxTimeMS"] = body["maxTimeMS"]
            getmore["$db"] = db
            try:
                doc = dispatch(getmore, ctx)
            except Exception:
                # We've already sent a `moreToCome` reply this round, so the
                # client is waiting for the rest of the stream. If the next
                # getMore blows up unexpectedly, terminate the stream with a
                # final `moreToCome`-clear reply rather than letting the
                # exception drop the connection mid-stream (which the client
                # surfaces as "Server ended moreToCome unexpectedly").
                logger.exception("error streaming exhaust getMore on cursor %d", int(target_id))
                return send(
                    {
                        "cursor": {"nextBatch": [], "id": bson.Int64(0), "ns": ns},
                        "ok": 1.0,
                    },
                    more=False,
                )

    @staticmethod
    def _peer_cert_dn(conn: socket.socket) -> str | None:
        """Subject DN of the verified client cert, or None.

        MONGODB-X509 needs it to look up the user. ``getpeercert()`` returns
        ``{}`` when the peer didn't present a cert (CERT_OPTIONAL mode with a
        plain client), or the parsed cert dict otherwise. ``getpeercert``
        raises on non-SSL sockets, so plain TCP connections return None
        without calling it.
        """
        if not isinstance(conn, ssl.SSLSocket):
            return None
        with contextlib.suppress(ssl.SSLError, OSError, ValueError):
            return subject_dn_from_peercert(conn.getpeercert())
        return None

    def _current_server_address(self) -> tuple[str, int] | None:
        """Listening address for ``CommandContext``, or None when stopped."""
        try:
            return self.address if self._socket is not None else None
        except (RuntimeError, OSError):
            # OSError: stop() closed the listen socket between the None
            # check and getsockname() — shutdown race, treat the same as
            # "not started". RuntimeError: ``_socket`` became None in
            # between.
            return None

    def _command_context(
        self,
        *,
        connection_id: int,
        db_name: str,
        server_addr: tuple[str, int] | None,
        connection_auth: ConnectionAuth,
        peer_cert_dn: str | None,
    ) -> CommandContext:
        """Build the per-command context shared by the OP_MSG and OP_QUERY paths."""
        return CommandContext(
            connection_id=connection_id,
            storage=self.storage,
            cursors=self.cursors,
            db_name=db_name,
            server_address=server_addr,
            replica_set_name=self.replica_set_name,
            connection_auth=connection_auth,
            require_auth=self.require_auth,
            metrics=self.metrics,
            connections=self.connections,
            logs=self.logs,
            sessions=self.sessions,
            failpoints=self.failpoints,
            transactions=self.transactions,
            peer_cert_dn=peer_cert_dn,
        )

    def _handle_client(self, conn: socket.socket, addr: tuple[str, int]) -> None:
        """Serve one client connection until it closes, errors, or the server stops.

        Runs on the connection's own daemon thread. The connection holds one
        slot of the ``max_connections`` budget (taken by ``_serve_forever``).
        Every exit path gives that slot back, together with the registry
        entry, the metrics counter and the thread's cached WT session.
        """
        # Register the socket alongside the conn_id so killOp can
        # shut it down from another thread.
        try:
            connection_id = self.connections.open((addr[0], addr[1]), sock=conn)
        except Exception:
            # Nothing below runs, so release what the accept loop handed us
            # instead of leaking the slot and the socket.
            logger.exception("failed to register connection from %s", addr)
            with contextlib.suppress(OSError):
                conn.close()
            self._release_connection_slot()
            return
        self.metrics.connection_opened()
        try:
            with conn:
                logger.debug("client %d connected from %s", connection_id, addr)
                self.logs.append(
                    "I", "NETWORK", "connection accepted", {"conn_id": connection_id, "from": addr}
                )
                reply_ids = itertools.count(1)
                connection_auth = ConnectionAuth()
                # Captured once per connection, inside the cleanup scope so
                # an unexpected failure parsing the cert still releases the
                # connection's resources in ``finally``.
                peer_cert_dn = self._peer_cert_dn(conn)
                while not self._stop_event.is_set():
                    try:
                        message = read_message(conn)
                    except ConnectionClosed:
                        return
                    except ConnectionResetError:
                        # Abrupt hang-up (RST instead of FIN) — the Go
                        # driver's tools (mongodump, mongostat, ...) close
                        # pooled connections this way routinely. A normal
                        # disconnect, not an error worth a traceback.
                        logger.debug("client %d reset connection", connection_id)
                        return
                    except TimeoutError:
                        # Idle timeout fired — drop the connection so the
                        # thread can be reaped instead of pinned forever.
                        logger.debug("client %d idle timeout, closing", connection_id)
                        return
                    except OSError as exc:
                        # The socket was closed / aborted under us — usually
                        # ``stop()`` closing this connection to drain it (on
                        # Windows that surfaces as ConnectionAbortedError /
                        # BrokenPipeError, WinError 10053/10058), or the peer
                        # hanging up abruptly. A clean disconnect, not a fault:
                        # exit the loop so the handler thread is reaped.
                        logger.debug("client %d connection closed: %s", connection_id, exc)
                        return
                    except MalformedBodyError as exc:
                        # Body bytes failed BSON validation. Build a
                        # targeted error reply so the client gets a
                        # ``BadValue`` and can keep using the connection
                        # for subsequent commands. Mongod returns the
                        # same shape — InvalidBSON in a command body is
                        # not a protocol-level fault.
                        logger.warning("malformed BSON from %s: %s", addr, exc)
                        try:
                            reply = build_op_msg_reply(
                                response_to=exc.header.request_id,
                                request_id=next(reply_ids),
                                body={
                                    "ok": 0.0,
                                    "errmsg": str(exc),
                                    "code": 2,
                                    "codeName": "BadValue",
                                },
                            )
                            conn.sendall(reply)
                        except OSError:
                            # Client may have hung up before we could
                            # respond; fall through to the close path.
                            return
                        continue
                    except WireProtocolError as exc:
                        logger.warning("wire protocol error from %s: %s", addr, exc)
                        return
                    op = message.op
                    server_addr = self._current_server_address()
                    if isinstance(op, OpMsg):
                        body = _merge_op_msg_body(op)
                        ctx = self._command_context(
                            connection_id=connection_id,
                            db_name=body.get("$db", "admin"),
                            server_addr=server_addr,
                            connection_auth=connection_auth,
                            peer_cert_dn=peer_cert_dn,
                        )
                        try:
                            response_doc = dispatch(body, ctx)
                        except CloseConnectionRequested:
                            # ``failCommand`` failpoint with
                            # ``closeConnection: true``: drop the TCP
                            # connection without replying. The driver
                            # sees a closed socket and surfaces it as
                            # a client-side network error.
                            logger.debug(
                                "closeConnection failpoint fired on conn %d",
                                connection_id,
                            )
                            return
                        # `moreToCome` (bit 1) is the wire signal for
                        # fire-and-forget requests — `writeConcern: {w: 0}`
                        # unacknowledged writes use it. The spec requires the
                        # server NOT to send a reply; if we do, the driver
                        # mis-pairs it with the next response and aborts the
                        # connection with a responseTo/requestId mismatch.
                        if op.flags & OP_MSG_FLAG_MORE_TO_COME:
                            continue
                        # OP_MSG exhaust: when the client sets the
                        # `exhaustAllowed` flag on a getMore, the server
                        # streams every remaining batch back over the same
                        # socket using the `moreToCome` flag, instead of
                        # waiting for a getMore per batch. mongod only
                        # streams on getMore (the find/aggregate reply that
                        # opens the cursor is sent normally), so we mirror
                        # that — find replies fall through to the single
                        # reply below.
                        if (
                            op.flags & OP_MSG_FLAG_EXHAUST_ALLOWED
                            and "getMore" in body
                            and isinstance(response_doc.get("cursor"), dict)
                        ):
                            if self._stream_exhaust_getmore(
                                conn,
                                message.header.request_id,
                                reply_ids,
                                body,
                                ctx,
                                response_doc,
                            ):
                                continue
                            return
                        reply = build_op_msg_reply(
                            response_to=message.header.request_id,
                            request_id=next(reply_ids),
                            body=response_doc,
                        )
                    elif isinstance(op, OpQuery):
                        ctx = self._command_context(
                            connection_id=connection_id,
                            db_name=_db_from_namespace(op.full_collection_name),
                            server_addr=server_addr,
                            connection_auth=connection_auth,
                            peer_cert_dn=peer_cert_dn,
                        )
                        try:
                            response_doc = dispatch(op.query, ctx)
                        except CloseConnectionRequested:
                            logger.debug(
                                "closeConnection failpoint fired on conn %d (OP_QUERY)",
                                connection_id,
                            )
                            return
                        reply = build_op_reply(
                            response_to=message.header.request_id,
                            request_id=next(reply_ids),
                            documents=[response_doc],
                        )
                    else:
                        logger.warning("unhandled op type %r from %s", type(op), addr)
                        return
                    try:
                        conn.sendall(reply)
                    except OSError:
                        return
        except Exception:
            logger.exception("unhandled error on connection %d", connection_id)
        finally:
            self.metrics.connection_closed()
            self.connections.close(connection_id)
            self._release_connection_slot()
            # Release the WT session cached on this connection thread so
            # the engine's session pool (default 1024) isn't leaked when
            # client churn opens many short-lived connections. Without
            # this, an aggressive driver pool (mongo-rust-driver's spec
            # runners are the canonical case) saturates the pool after
            # ~1k connections and every subsequent ``hello`` errors with
            # ``WT_ERROR: out of sessions`` mid-handshake.
            with contextlib.suppress(Exception):
                self.storage._reset_thread_session()
            logger.debug("client %d disconnected", connection_id)

    def __enter__(self) -> Self:
        """Start the server and return it."""
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Stop the server; exceptions from the ``with`` body propagate."""
        self.stop()