from __future__ import annotations
import contextlib
import itertools
import logging
import socket
import threading
from types import TracebackType
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from typing import Self
from secantus.auth import ConnectionAuth
from secantus.commands import CommandContext, dispatch
from secantus.cursors import CursorRegistry
from secantus.storage import Storage
from secantus.wire import (
OP_MSG_FLAG_MORE_TO_COME,
ConnectionClosed,
OpMsg,
OpQuery,
WireProtocolError,
build_op_msg_reply,
build_op_reply,
read_message,
)
# Module-level logger; everything this module logs goes through a logger named after the module.
logger = logging.getLogger(__name__)
def _merge_op_msg_body(op: OpMsg) -> dict[str, Any]:
body = dict(op.body)
for identifier, docs in op.document_sequences:
body[identifier] = docs
return body
def _db_from_namespace(ns: str) -> str:
db, _, _ = ns.partition(".")
return db or "admin"
class SecantusDBServer:
    """Threaded TCP server that speaks the MongoDB wire protocol.

    ``start`` binds an IPv4 listening socket and runs a daemon accept thread;
    each accepted client is served on its own daemon thread, which reads
    messages with ``read_message``, routes OP_MSG / OP_QUERY commands through
    ``dispatch`` and writes the reply back. ``stop`` (or leaving the ``with``
    block) shuts the listener down and closes the storage.
    """

    def __init__(
        self,
        host: str = "127.0.0.1",
        port: int = 0,
        storage_path: str = "./secantus-data",
        *,
        replica_set_name: str | None = "secantus",
        require_auth: bool = False,
    ) -> None:
        """Configure the server; no socket is opened until ``start``.

        Args:
            host: Interface to bind (the socket is ``AF_INET``, so IPv4 only).
            port: TCP port; ``0`` lets the OS choose. After ``start`` the
                actual bound port is stored back into ``self.port``.
            storage_path: First argument passed to ``Storage``.
            replica_set_name: Forwarded to every ``CommandContext``; ``None``
                also disables oplog writes in ``Storage``.
            require_auth: Forwarded to every ``CommandContext``.
        """
        self.host = host
        self.port = port
        self.replica_set_name = replica_set_name
        self.require_auth = require_auth
        self._socket: socket.socket | None = None
        self._thread: threading.Thread | None = None
        self._stop_event = threading.Event()
        self._connection_ids = itertools.count(1)
        # Oplog writes are only useful when change-stream clients can
        # connect, which requires replica-set advertisement in `hello`.
        # Skip them in pure-standalone mode to drop a per-write BSON
        # encode + oplog-table cursor write per modified document.
        self.storage = Storage(storage_path, enable_oplog=replica_set_name is not None)
        self.cursors = CursorRegistry()

    @property
    def address(self) -> tuple[str, int]:
        """``(host, port)`` of the listening socket.

        Raises:
            RuntimeError: if the server is not started.
            OSError: if the socket was closed concurrently by ``stop``.
        """
        # Read the attribute once: ``stop`` may set it to None concurrently.
        sock = self._socket
        if sock is None:
            raise RuntimeError("server is not started")
        host, port, *_ = sock.getsockname()
        return host, port

    @property
    def uri(self) -> str:
        """``mongodb://host:port/`` URI built from ``address``."""
        host, port = self.address
        return f"mongodb://{host}:{port}/"

    def start(self) -> None:
        """Bind the listening socket and start the accept thread.

        Raises:
            RuntimeError: if the server is already started.
            OSError: if binding or listening fails (e.g. the port is in use);
                the half-created socket is closed and the server stays stopped.
        """
        if self._socket is not None:
            raise RuntimeError("server is already started")
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind((self.host, self.port))
            sock.listen()
        except BaseException:
            # Don't leak the file descriptor on a failed bind/listen.
            sock.close()
            raise
        self._socket = sock
        self.host, self.port = sock.getsockname()[:2]
        self._stop_event.clear()
        # Hand the socket to the thread directly so it never has to re-read
        # ``self._socket``, which ``stop`` resets to None.
        self._thread = threading.Thread(
            target=self._serve_forever,
            args=(sock,),
            name="secantus-accept",
            daemon=True,
        )
        self._thread.start()
        logger.info("secantus listening on %s:%d", self.host, self.port)

    def stop(self) -> None:
        """Stop accepting connections, join the accept thread, close storage.

        Client handler threads are daemons and are not joined; a handler
        still mid-command when storage closes may fail (its error is logged
        by ``_handle_client``). ``self.storage.close()`` runs on every call,
        including when the server was never started.
        """
        self._stop_event.set()
        sock, self._socket = self._socket, None
        if sock is not None:
            # On Linux, close() alone does not wake a thread blocked in
            # accept(); shutdown() makes that accept() fail. Some platforms
            # reject shutdown() on a listening socket, hence the suppress.
            with contextlib.suppress(OSError):
                sock.shutdown(socket.SHUT_RDWR)
            with contextlib.suppress(OSError):
                sock.close()
        thread, self._thread = self._thread, None
        if thread is not None:
            thread.join(timeout=2.0)
            if thread.is_alive():
                logger.warning("secantus accept thread did not exit within 2s")
        self.storage.close()

    def wait(self) -> None:
        """Block until the stop event is set (``stop`` sets it)."""
        self._stop_event.wait()

    def _serve_forever(self, sock: socket.socket | None = None) -> None:
        """Accept loop: spawn one daemon handler thread per client.

        Exits when the stop event is set or ``accept`` raises ``OSError``,
        which is how ``stop`` interrupts a blocked ``accept``.

        Args:
            sock: Listening socket; falls back to ``self._socket`` if omitted.
        """
        if sock is None:
            sock = self._socket
        if sock is None:
            return
        while not self._stop_event.is_set():
            try:
                conn, addr = sock.accept()
            except OSError:
                return
            if self._stop_event.is_set():
                # Accepted while stopping: don't serve it against closed storage.
                with contextlib.suppress(OSError):
                    conn.close()
                return
            handler = threading.Thread(
                target=self._handle_client,
                args=(conn, addr),
                daemon=True,
            )
            handler.start()

    def _current_server_address(self) -> tuple[str, int] | None:
        """Return ``address``, or None if the server is stopped/stopping."""
        try:
            return self.address
        except (RuntimeError, OSError):
            # RuntimeError: not started / already stopped; OSError:
            # getsockname() on a socket that stop() just closed.
            return None

    def _make_context(
        self,
        connection_id: int,
        connection_auth: ConnectionAuth,
        db_name: str,
        server_address: tuple[str, int] | None,
    ) -> CommandContext:
        """Build the ``CommandContext`` shared by the OP_MSG and OP_QUERY paths."""
        return CommandContext(
            connection_id=connection_id,
            storage=self.storage,
            cursors=self.cursors,
            db_name=db_name,
            server_address=server_address,
            replica_set_name=self.replica_set_name,
            connection_auth=connection_auth,
            require_auth=self.require_auth,
        )

    def _handle_client(self, conn: socket.socket, addr: tuple[str, int]) -> None:
        """Serve one client connection until it closes or the server stops.

        Each OP_MSG / OP_QUERY is dispatched and answered with a reply whose
        request ids count up from 1 per connection; OP_MSG requests carrying
        ``moreToCome`` get no reply. Unknown op types, wire-protocol errors,
        socket errors and unexpected exceptions all end the connection; the
        socket is closed by the ``with`` block in every case.
        """
        connection_id = next(self._connection_ids)
        reply_ids = itertools.count(1)
        connection_auth = ConnectionAuth()
        logger.debug("client %d connected from %s", connection_id, addr)
        try:
            with conn:
                while not self._stop_event.is_set():
                    try:
                        message = read_message(conn)
                    except ConnectionClosed:
                        return
                    except WireProtocolError as exc:
                        logger.warning("wire protocol error from %s: %s", addr, exc)
                        return
                    except OSError as exc:
                        # Resets/aborts are ordinary client disconnects, not
                        # server bugs worth a traceback.
                        logger.debug("client %d read failed: %s", connection_id, exc)
                        return
                    if self._stop_event.is_set():
                        # stop() may have closed storage while we were blocked
                        # in read_message; don't dispatch against it.
                        return
                    op = message.op
                    server_addr = self._current_server_address()
                    if isinstance(op, OpMsg):
                        body = _merge_op_msg_body(op)
                        ctx = self._make_context(
                            connection_id,
                            connection_auth,
                            body.get("$db", "admin"),
                            server_addr,
                        )
                        response_doc = dispatch(body, ctx)
                        # `moreToCome` (bit 1) is the wire signal for
                        # fire-and-forget requests — `writeConcern: {w: 0}`
                        # unacknowledged writes use it. The spec requires the
                        # server NOT to send a reply; if we do, the driver
                        # mis-pairs it with the next response and aborts the
                        # connection with a responseTo/requestId mismatch.
                        if op.flags & OP_MSG_FLAG_MORE_TO_COME:
                            continue
                        reply = build_op_msg_reply(
                            response_to=message.header.request_id,
                            request_id=next(reply_ids),
                            body=response_doc,
                        )
                    elif isinstance(op, OpQuery):
                        ctx = self._make_context(
                            connection_id,
                            connection_auth,
                            _db_from_namespace(op.full_collection_name),
                            server_addr,
                        )
                        response_doc = dispatch(op.query, ctx)
                        reply = build_op_reply(
                            response_to=message.header.request_id,
                            request_id=next(reply_ids),
                            documents=[response_doc],
                        )
                    else:
                        logger.warning("unhandled op type %r from %s", type(op), addr)
                        return
                    try:
                        conn.sendall(reply)
                    except OSError:
                        return
        except Exception:
            logger.exception("unhandled error on connection %d", connection_id)
        finally:
            logger.debug("client %d disconnected", connection_id)

    def __enter__(self) -> Self:
        """Start the server and return it."""
        self.start()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Stop the server; exceptions from the ``with`` body propagate."""
        self.stop()