Source code for secantus.storage

"""WiredTiger-backed document store.

WiredTiger is the default storage engine for MongoDB. We use the same
engine here so that on-disk semantics line up with what test code would
see against a real ``mongod``.

Indexes use a sidecar entries table (``table:secantus_index_entries``)
with a single trailing ``u`` column packing
``escape(sortkey) + b"\\x00\\x00" + id_key``. The sortkey comes from
``secantus.sortkey`` (typed, byte-sortable BSON encoding), so the WT
B-tree gives us ordered access for free. ``find_matching`` routes a wide
range of filter shapes through the index — equality, ``$eq``, ``$in``,
``$gt``/``$gte``/``$lt``/``$lte`` on a single field, plus compound
indexes when filter fields cover a leading prefix (with optional range
on the next field). Sort-by-indexed-field walks the B-tree in order.
"""

from __future__ import annotations

import contextlib
import datetime as _dt
import math
import os
import shutil
import tempfile
import threading
import time as _time
import uuid as _uuid
from collections.abc import Callable, Iterable, Mapping
from decimal import Decimal, InvalidOperation
from typing import Any

import bson
import wiredtiger as wt
from bson import Binary, Decimal128, MaxKey, MinKey, ObjectId, Regex
from bson.timestamp import Timestamp

from secantus.diff import compute_update_description
from secantus.geo import GeoError, parse_doc_geometry, parse_query_geometry, validate_coordinates
from secantus.geo_index import (
    encode_cell,
    planar_2d_covering,
    planar_2d_index_for_point,
    s2_doc_covering,
    s2_query_covering,
)
from secantus.paths import get_path, has_path
from secantus.projection import apply_projection
from secantus.query import matches
from secantus.sortkey import COMPOUND_SEP, encode_value, encode_value_directed
from secantus.update import apply_update, find_positional_matches

# Key-spec values that declare a geospatial index. ``_geo_type_of`` looks
# for one of these strings where a normal index has a ``1`` / ``-1``
# direction.
_GEO_2DSPHERE = "2dsphere"  # spherical index, covered with S2 cells
_GEO_2D = "2d"  # planar index over single points only
_GEO_TYPES = frozenset({_GEO_2DSPHERE, _GEO_2D})


def _geo_type_of(key_spec: Mapping[str, Any]) -> tuple[str, str] | None:
    """Locate the geo field of an index key spec, if there is one.

    Returns ``(field, geo_type)`` for the first field whose spec value is
    the string ``"2dsphere"`` or ``"2d"`` (instead of a ``1`` / ``-1``
    direction), or ``None`` when no field is geospatial.

    Compound geo specs (a geo field plus scalar trailing fields) are out
    of scope in Phase 2: any spec that contains a geo field is handled as
    geo-only and its other fields are ignored. The picker still works
    because ``$geoWithin`` and friends are answered by the cell scan plus
    the verifier.
    """
    geo_fields = (
        (field, spec)
        for field, spec in key_spec.items()
        if isinstance(spec, str) and spec in _GEO_TYPES
    )
    return next(geo_fields, None)


def _doc_geo_cells(
    doc: Mapping[str, Any],
    field: str,
    geo_type: str,
    options: Mapping[str, Any],
    *,
    index_name: str = "",
) -> list[bytes]:
    """Encoded index-cell bytes for the geo field ``field`` of ``doc``.

    A missing or ``null`` field yields ``[]``: no index entry, which is
    the sparse-by-default behaviour of mongod's 2dsphere / 2d indexes.

    A *present* value that cannot be indexed raises
    :class:`GeoExtractError`. That covers a shape that doesn't parse,
    coordinates that fail validation, or a non-point geometry under a
    ``2d`` index. Callers surface that error on the wire as write error
    16572 ("Can't extract geo keys").
    """

    def _fail(reason: str) -> GeoExtractError:
        # ``_id`` is looked up only when an error is actually raised.
        return GeoExtractError(index_name, field, doc.get("_id"), reason)

    value = get_path(dict(doc), field)
    if value is None:
        return []

    geom = parse_doc_geometry(value)
    if geom is None:
        raise _fail(f"value {value!r} is not a recognised geometry")

    try:
        validate_coordinates(geom, geo_type=geo_type, options=options)
    except GeoError as exc:
        raise _fail(str(exc)) from exc

    if geo_type != _GEO_2DSPHERE:
        # Planar ``2d`` indexes hold a single cell for a single point.
        from shapely.geometry import Point as _Point

        if not isinstance(geom, _Point):
            raise _fail("2d index requires a point; got a non-point geometry")
        return [encode_cell(planar_2d_index_for_point(geom.x, geom.y, options))]

    return [encode_cell(cell) for cell in s2_doc_covering(geom)]


# WiredTiger table URIs used by this module.
_COLL_TABLE = "table:secantus_collections"
_DOC_TABLE = "table:secantus_documents"
_IDX_TABLE = "table:secantus_indexes"
# Sidecar table of packed index entries (see ``_pack_entry``).
_IDX_ENTRIES_TABLE = "table:secantus_index_entries"
_OPLOG_TABLE = "table:secantus_oplog"
_PREIMAGE_TABLE = "table:secantus_preimages"
_OPLOG_META_TABLE = "table:secantus_oplog_meta"
_USERS_TABLE = "table:secantus_users"
_ROLES_TABLE = "table:secantus_roles"
_PROFILE_TABLE = "table:secantus_profile_settings"

_OPLOG_PRUNE_INTERVAL = 1000  # call prune_oplog every N emits

# Separator between the escaped sortkey and the ``_id`` key inside one
# packed index-entry column. ``_escape_kb`` follows every ``\x00`` byte
# with ``\xff``, so this pair never occurs inside an escaped sortkey.
_ENTRY_SEP = b"\x00\x00"


def _escape_kb(kb: bytes) -> bytes:
    """Order-preserving escape so ``\\x00\\x00`` is unambiguous as a separator."""
    return kb.replace(b"\x00", b"\x00\xff")


def _pack_entry(kb: bytes, id_key: bytes) -> bytes:
    """Build the single ``u`` column value for one index entry.

    Layout: ``escape(kb) + _ENTRY_SEP + id_key``. Keeping everything in
    one trailing ``u`` column is deliberate. WiredTiger length-prefixes
    ``u`` columns that are not last in the key, which would break
    lexicographic ordering. A single trailing column lets the B-tree
    order entries by sortkey, then by ``_id`` key, on its own.
    """
    return b"".join((_escape_kb(kb), _ENTRY_SEP, id_key))


def _unpack_entry(packed: bytes) -> tuple[bytes, bytes]:
    """Split a packed index entry into ``(escaped_kb, id_key)``.

    Inverse of ``_pack_entry``, except that the sortkey is returned still
    escaped. The escaped sortkey never contains ``\\x00\\x00``, because
    ``_escape_kb`` follows every ``\\x00`` with ``\\xff``. The first
    separator occurrence is therefore always the real one, and any
    separator bytes inside ``id_key`` are left untouched.

    Raises:
        ValueError: ``packed`` contains no separator, so it was not
            produced by ``_pack_entry``. Previously such input was silently
            mis-split, because ``find`` returned -1.
    """
    escaped_kb, sep, id_key = packed.partition(_ENTRY_SEP)
    if not sep:
        raise ValueError(f"malformed index entry (no separator): {packed!r}")
    return escaped_kb, id_key


def extract_backup_archive(
    archive_path: str,
    target_dir: str,
    *,
    allow_existing: bool = False,
) -> dict[str, int | str]:
    """Extract a SecantusDB backup archive into ``target_dir``.

    Side-channel restore: the archive is unpacked into a fresh
    directory that the caller then points a new ``SecantusDBServer`` at
    (``SecantusDBServer(storage_path=<target_dir>)``). The function
    does **not** touch any running server's storage — that mode of
    "hot restore over a live WT connection" can't be done safely
    without restructuring how connection threads cache WT sessions,
    and isn't what real mongod's restore tooling supports either.

    Returns ``{"targetDir": <abs>, "fileCount": <int>, "archive": <abs>}``
    on success. Raises ``RuntimeError`` if:

    * the archive doesn't exist,
    * the archive doesn't contain a ``WiredTiger`` metadata file
      (so it's not a SecantusDB / WT backup at all),
    * ``target_dir`` already exists, is non-empty, and ``allow_existing``
      is False (default).

    The WT metadata check runs **before** extraction so a malformed
    archive can't pollute ``target_dir``.
    """
    import tarfile

    abs_archive = os.path.abspath(archive_path)
    abs_target = os.path.abspath(target_dir)
    if not os.path.isfile(abs_archive):
        raise RuntimeError(f"extract_backup_archive: archive not found: {abs_archive}")
    if os.path.exists(abs_target):
        if not os.path.isdir(abs_target):
            raise RuntimeError(
                f"extract_backup_archive: target exists and is not a directory: {abs_target}"
            )
        if os.listdir(abs_target) and not allow_existing:
            raise RuntimeError(
                "extract_backup_archive: target directory is not empty "
                f"(pass allow_existing=True to overlay): {abs_target}"
            )
    else:
        os.makedirs(abs_target)

    with tarfile.open(abs_archive, "r:*") as tar:
        names = tar.getnames()
        if "WiredTiger" not in names:
            raise RuntimeError(
                f"extract_backup_archive: archive {abs_archive!r} is not "
                "a SecantusDB backup (no WiredTiger metadata file inside)"
            )
        tar.extractall(abs_target, filter="data")

    return {
        "targetDir": abs_target,
        "fileCount": len(names),
        "archive": abs_archive,
    }


class DuplicateKeyError(Exception):
    """A document's ``_id`` already exists; ``doc_id`` holds the offending value."""

    def __init__(self, doc_id: Any) -> None:
        self.doc_id = doc_id
        super().__init__(f"duplicate _id: {doc_id!r}")


def _id_key(doc_id: Any) -> bytes:
    """Canonical, byte-sortable key bytes for an ``_id`` value.

    This is the same ``encode_value`` encoding the secondary-index
    entries table uses, which has two consequences:

    * Numerically equal ``_id`` values of different types (``1``,
      ``1.0``, ``Decimal128("1")``) encode to identical bytes, because
      ``encode_value`` normalises numerics through ``Decimal``. They hit
      the same document and clash on uniqueness.
    * Walking the document table in WT-key order yields documents in
      BSON cross-type sort order, matching what MongoDB calls "natural
      order" for non-capped collections.
    """
    encoded = encode_value(doc_id)
    return encoded


def _doc_makes_multikey(doc: Mapping[str, Any], key_spec: Mapping[str, Any]) -> bool:
    """True if any field in ``key_spec`` resolves to a list value in ``doc``.

    Such a value is encoded as a single composite array sortkey, so a
    later scalar-equality query against this index would silently miss
    the doc, and the index must fall back to a full scan.
    """
    # One plain-dict copy shared by every path lookup, instead of a fresh
    # copy per indexed field.
    plain = dict(doc)
    return any(isinstance(get_path(plain, field), list) for field in key_spec)


def _index_key(
    doc: Mapping[str, Any], key_spec: Mapping[str, Any], *, sparse: bool
) -> bytes | None:
    """Direction-aware byte-sortable encoding for an index ``key_spec``.

    Each field is encoded with ``encode_value_directed`` so ``-1``
    (descending) fields get bitwise-inverted bytes, making a forward
    B-tree walk yield values in descending order. Components are joined
    with ``COMPOUND_SEP``. A single-field spec yields that field's
    encoding unchanged, because joining one part adds no separator.

    Returns ``None`` when ``sparse`` is set and any indexed field is
    missing from ``doc``.

    For docs whose indexed field is array-valued, this returns the
    whole-array sortkey only: the single canonical "doc-shape" key used
    by uniqueness probes. The full set of multikey entries (per-element
    plus whole-array) is produced by :func:`_index_key_variants`.
    """
    # One plain-dict copy for every has_path / get_path probe below.
    plain = dict(doc)
    if sparse and not all(has_path(plain, field) for field in key_spec):
        return None
    return COMPOUND_SEP.join(
        encode_value_directed(get_path(plain, field), int(direction))
        for field, direction in key_spec.items()
    )


def _index_key_variants(
    doc: Mapping[str, Any], key_spec: Mapping[str, Any], *, sparse: bool
) -> list[bytes]:
    """All byte-keys this doc contributes to an index under ``key_spec``.

    For scalar-valued fields, returns one key, the same as ``_index_key``.
    For array-valued fields, returns one key per array element *and*
    the whole-array key, mirroring real ``mongod``'s multikey index
    layout. This makes:

    * ``{tags: "python"}`` against ``{tags: ["python", "go"]}`` light
      up via the per-element entry for ``"python"``.
    * ``{tags: ["python", "go"]}`` (whole-array equality) light up via
      the whole-array entry. Without it, the equality lookup would
      false-negative.
    * Range / ``$in`` queries on array fields hit at least all true
      matches (the post-index ``matches()`` filter discards
      false-positives).

    For compound indexes whose multiple fields are array-valued, the
    cartesian product is taken across each field's candidate values.
    Real mongod restricts compound indexes to one multikey field per
    doc. We don't enforce that; we just emit the cross-product. That is
    correct (it over-includes and the post-filter discards extras), but
    the user pays the cardinality blow-up.

    Returns an empty list when ``sparse`` and any field is missing.
    Keys are deduplicated by their encoded bytes (first occurrence
    wins), so ``[1, 1, 2]`` writes two element entries (``1`` and
    ``2``) plus the whole-array entry, not three.
    """
    from itertools import product

    # One plain-dict copy for every has_path / get_path probe below.
    plain = dict(doc)
    fields = list(key_spec)
    if sparse and not all(has_path(plain, field) for field in fields):
        return []

    # Encode each field's candidates exactly once. Scalars contribute
    # [value]; arrays contribute [unique elements..., whole array]. The
    # compound product below then only concatenates bytes instead of
    # re-encoding every value once per combination.
    per_field: list[list[bytes]] = []
    for field in fields:
        direction = int(key_spec[field])
        value = get_path(plain, field)
        if isinstance(value, list):
            # dict.fromkeys keeps first-seen order while dropping repeats.
            encoded = list(
                dict.fromkeys(encode_value_directed(elem, direction) for elem in value)
            )
            # The whole-array key is not deduplicated against the element
            # keys here; any collision is removed by the final dedup.
            encoded.append(encode_value_directed(value, direction))
        else:
            encoded = [encode_value_directed(value, direction)]
        per_field.append(encoded)

    # For a single field every combo is a 1-tuple, and joining one part
    # yields that part unchanged, so one code path serves both shapes.
    return list(dict.fromkeys(COMPOUND_SEP.join(combo) for combo in product(*per_field)))


def _to_decimal(value: Any) -> Decimal:
    if isinstance(value, Decimal128):
        return value.to_decimal()
    if isinstance(value, float):
        return Decimal(repr(value))
    return Decimal(value)


def _bson_type_rank(value: Any) -> int:
    """Rank for MongoDB's cross-type sort order. Lower rank sorts first.

    MinKey 1, null 2, numbers (int / float / Decimal128) 3, string 4,
    embedded document 5, array 6, binary 7, ObjectId 8, bool 9, datetime
    10, Timestamp 11, regex 12, MaxKey 13. Anything unrecognised falls
    back to 5.

    ``bool`` is tested before the numeric types because it is an ``int``
    subclass in Python but a distinct BSON type.

    This runs twice per comparison during ``sort_docs``, so the bson
    classes and ``datetime`` come from module-level imports. The previous
    per-call ``import`` statements were redundant work on that hot path.
    """
    if isinstance(value, MinKey):
        return 1
    if value is None:
        return 2
    if isinstance(value, bool):
        return 9
    if isinstance(value, (int, float, Decimal128)):
        return 3
    if isinstance(value, str):
        return 4
    if isinstance(value, Mapping):
        return 5
    if isinstance(value, list):
        return 6
    if isinstance(value, (bytes, Binary)):
        return 7
    if isinstance(value, ObjectId):
        return 8
    if isinstance(value, _dt.datetime):
        return 10
    if isinstance(value, Timestamp):
        return 11
    if isinstance(value, Regex):
        return 12
    if isinstance(value, MaxKey):
        return 13
    return 5


class _SortKey:
    """One column of a ``sort_docs`` key, ordered by BSON sort rules.

    ``sort_docs`` builds a tuple of these per document and lets ``sorted``
    compare the tuples. Tuple comparison first scans for the leftmost
    position whose items are *not* ``==`` and only then applies ``<``
    there. ``__eq__`` must therefore agree with ``__lt__``: two keys are
    equal exactly when neither sorts before the other.

    Python's own ``==`` does not meet that bar:

    * ``True == 1``, although booleans and numbers are different BSON types.
    * ``{"a": 1, "b": 2} == {"b": 2, "a": 1}``, although ``_bson_lt``
      compares embedded documents in field order.
    * ``Decimal128("1") != 1``, although they tie numerically.

    With Python ``==``, a multi-field sort could decide on the wrong
    column or ignore later columns entirely.
    """

    __slots__ = ("val", "_reverse")

    def __init__(self, val: Any, reverse: bool = False) -> None:
        self.val = val
        # True for a descending (-1) sort field.
        self._reverse = reverse

    def __lt__(self, other: _SortKey) -> bool:
        # Swap operands when this key is descending. The same comparison
        # logic then yields the correct order for desc fields, and the
        # equal-keys case still returns False on both sides (stable sort
        # preserves doc order). Both sides of the comparison must agree on
        # direction (they're in the same column), which our caller
        # guarantees.
        if self._reverse:
            a, b = other.val, self.val
        else:
            a, b = self.val, other.val
        return _bson_lt(a, b)

    def __eq__(self, other: object) -> bool:
        # Ordering-based equality (see class docstring). Direction doesn't
        # matter here: "neither sorts first" is symmetric.
        if not isinstance(other, _SortKey):
            return False
        a, b = self.val, other.val
        if a is b:
            return True
        return not _bson_lt(a, b) and not _bson_lt(b, a)


def _is_bson_nan(value: Any) -> bool:
    """True when ``value`` is a float or ``Decimal128`` NaN (quiet or signalling)."""
    if isinstance(value, float):
        return math.isnan(value)
    if isinstance(value, Decimal128):
        return value.to_decimal().is_nan()
    return False


def _bson_lt(a: Any, b: Any) -> bool:
    """BSON sort-order ``<`` for two values.

    Rules, applied in order:

    1. Different BSON type classes order by :func:`_bson_type_rank`.
    2. Two nulls are equal.
    3. Numbers: NaN sorts before every other number and ties with any
       other NaN, as in mongod. Without this, NaN was neither less nor
       greater than anything, which made the sort order inconsistent.
    4. If either side is ``Decimal128``, both are compared as ``Decimal``.
    5. Embedded documents compare field by field in insertion order. The
       first differing field name or value decides; a strict prefix
       sorts first. Python's dict ``<`` would raise ``TypeError``.
    6. Arrays compare element by element; a strict prefix sorts first.
    7. Otherwise native ``<``, falling back to comparing type names when
       the two values aren't mutually orderable.
    """
    ra = _bson_type_rank(a)
    rb = _bson_type_rank(b)
    if ra != rb:
        return ra < rb
    if a is None or b is None:
        return False
    if ra == 3:  # numeric rank: int / float / Decimal128
        a_nan = _is_bson_nan(a)
        b_nan = _is_bson_nan(b)
        if a_nan or b_nan:
            return a_nan and not b_nan
    if isinstance(a, Decimal128) or isinstance(b, Decimal128):
        try:
            ad = _to_decimal(a)
            bd = _to_decimal(b)
            return bool(ad < bd)
        except (InvalidOperation, ValueError):
            pass
    if isinstance(a, Mapping) and isinstance(b, Mapping):
        a_items = list(a.items())
        b_items = list(b.items())
        for (ak, av), (bk, bv) in zip(a_items, b_items, strict=False):
            if ak != bk:
                return ak < bk
            if _bson_lt(av, bv):
                return True
            if _bson_lt(bv, av):
                return False
        return len(a_items) < len(b_items)
    if isinstance(a, list) and isinstance(b, list):
        for av, bv in zip(a, b, strict=False):
            if _bson_lt(av, bv):
                return True
            if _bson_lt(bv, av):
                return False
        return len(a) < len(b)
    try:
        return bool(a < b)
    except TypeError:
        return type(a).__name__ < type(b).__name__


def sort_docs(
    docs: list[dict[str, Any]], sort_spec: Mapping[str, Any] | None
) -> list[dict[str, Any]]:
    """Return ``docs`` ordered by ``sort_spec`` using BSON sort rules.

    An empty or missing spec returns the very same list object unsorted.
    Otherwise each document gets one composite key (a tuple of
    ``_SortKey`` columns, one per spec field, with ``-1`` meaning
    descending). A single ``sorted`` call then does the work, calling
    ``get_path`` once per field per document.
    """
    if not sort_spec:
        return docs
    columns = [(path, int(direction) == -1) for path, direction in sort_spec.items()]

    def _composite_key(doc: dict[str, Any]) -> tuple[_SortKey, ...]:
        return tuple(
            _SortKey(get_path(doc, path), reverse=descending) for path, descending in columns
        )

    return sorted(docs, key=_composite_key)


# Name of the built-in ``_id`` index (the same name mongod reports for it).
_ID_INDEX_NAME = "_id_"


class IndexConflict(Exception):
    """Duplicate-key violation on an index (mongod's ``E11000`` error).

    ``key_pattern`` (the index spec) and ``key_value`` (the conflicting
    field values) mirror the ``keyPattern`` / ``keyValue`` fields that
    real mongod puts in its dup-key error response. Drivers expose them
    via ``errorResponse``, and mongo-java-driver's
    ``findOneAndUpdate-errorResponse`` test asserts both. They are
    optional because some raise-sites (an ``_id`` collision caught before
    the index machinery runs, recovery paths) don't have the index spec
    handy.
    """

    def __init__(
        self,
        index_name: str,
        doc_id: Any,
        *,
        key_pattern: dict[str, Any] | None = None,
        key_value: dict[str, Any] | None = None,
    ) -> None:
        message = f"E11000 duplicate key error in index {index_name}: _id={doc_id!r}"
        super().__init__(message)
        self.index_name, self.doc_id = index_name, doc_id
        self.key_pattern, self.key_value = key_pattern, key_value


class DocumentValidationError(Exception):
    """A write produced a document that fails the collection's ``validator``.

    The command layer catches this and answers with mongod's writeError
    shape: code 121, ``DocumentValidationFailure``, including the
    ``errInfo.failingDocumentId`` field that drivers' errorResponse tests
    assert on. The failing document's ``_id`` is kept in ``doc_id``.
    """

    def __init__(self, doc_id: Any) -> None:
        self.doc_id = doc_id
        super().__init__("Document failed validation")


class CreateIndexUnsupported(Exception):
    """Requested index type isn't supported by SecantusDB (currently ``text`` / ``hashed``).

    Raised by ``create_index`` and caught at the command layer, which
    reports a typed wire error instead of letting the cell encoder fail
    later with an opaque internal exception.
    """


class IndexOptionsConflict(Exception):
    """An index with this name already exists but with different options.

    Raised by ``create_index`` when ``unique``, ``sparse``, ``hidden``,
    ``expireAfterSeconds`` or ``partialFilterExpression`` disagree with the
    existing index. Real mongod rejects this as ``IndexOptionsConflict``
    (code 85), and mongo-ruby-driver's ``Collection#create_indexes`` specs
    assert on that rejection.
    """


class GeoExtractError(Exception):
    """A document's geo field can't be indexed: bad shape or out-of-range coordinates.

    Raised on the geo-index write path (insert, update, index build) for
    values the geo extractor can't handle. Examples include malformed
    GeoJSON, non-numeric coordinates, and longitude / latitude outside
    the valid range. The command-layer write boundary turns it into a
    wire write error with mongod's code 16572 ("Can't extract geo keys").
    """

    def __init__(self, index_name: str, field: str, doc_id: Any, reason: str) -> None:
        message = (
            f"Can't extract geo keys for index {index_name!r} on field {field!r}: {reason}"
        )
        super().__init__(message)
        self.index_name, self.field = index_name, field
        self.doc_id, self.reason = doc_id, reason


class BadHint(Exception):
    """Signals an unusable ``hint`` argument to ``find_matching``.

    Raised when the supplied hint does not refer to any index that
    currently exists.
    """


[docs] class Storage: def __init__( self, path: str = ":memory:", *, oplog_retention_seconds: float = 3600.0, oplog_max_entries: int = 100_000, time_func: Callable[[], float] | None = None, enable_oplog: bool = True, ttl_sweep_seconds: float = 60.0, noop_heartbeat_seconds: float = 0.0, cache_size: str = "1G", session_max: int = 1000, sync_on_commit: bool = False, ) -> None: # When False, _emit_oplog short-circuits and writes nothing — # used in standalone (non-replica-set) mode to skip the per-write # BSON encode + WT cursor write cost of oplog entries that no # change-stream client will ever read. The oplog WT tables are # still created so toggling at runtime stays safe. self.enable_oplog = enable_oplog self._lock = threading.RLock() self._closed = False self._tempdir: str | None = None # session_max default is ~120; each client connection thread # caches its own session in `threading.local()`, and cross- # thread oplog readers open additional short-lived sessions on # demand. With a few dozen concurrent client connections plus # active change-stream tailers, the default ceiling is hit # mid-handshake and surfaces as `out of sessions` / # WT_ERROR. mongod itself runs with session_max=33000 — 1000 # is a generous floor for a single-node test surrogate while # still well under the WT hard limit. # cache_size default is 100 MB. With ``in_memory=true`` every # write also lives in cache, so a workload that inserts a # handful of 16 MB documents (mongod's per-doc max) blows the # cap as ``WT_CACHE_FULL: operation would overflow cache``. # 1 GB gives generous headroom for tests + reasonable # in-process workloads while staying well under the limits # ``mongod`` itself runs with on a normal box. # Tracked so ``checkpoint()`` calls are skipped in in-memory # mode (WT's in_memory backend rejects them with a noisy # ``__wt_inmem_unsupported_op`` log line on every call). self._in_memory = path == ":memory:" # Stashed for reuse in restore-archive / explain output. 
self.cache_size = cache_size self.session_max = session_max self.sync_on_commit = sync_on_commit if path == ":memory:": self._tempdir = tempfile.mkdtemp(prefix="secantus_wt_") home = self._tempdir # in_memory=true disables the journal entirely (no files); # ephemeral by definition, so durability isn't a concern. config = f"create,in_memory=true,session_max={session_max},cache_size={cache_size}" else: os.makedirs(path, exist_ok=True) home = path # ``log=(enabled=true)`` turns on WT's redo journal: every # transaction commit writes a log record before it returns, # and recovery replays the log on reopen. Without this, # WT's only durability mechanism is checkpoints (default # cadence: every 60s, or on clean ``WT_CONNECTION->close``). # On SIGKILL between checkpoints, every uncommitted write # is lost — which is exactly the failure mode observed by # ``bench/chaos.py`` (3-min chaos run, 17 SIGKILLs: # 432,881 acked / 1 persisted). # # ``transaction_sync`` is the per-commit durability knob. # Default ``enabled=false,method=fsync`` matches mongod's # default ``writeConcern: {w:1, j:false}`` — log records # land in the OS page cache, the OS flushes them on its # own schedule, SIGKILL is durable, true power-loss # between commits can lose data. # # ``sync_on_commit=True`` (config-file knob) bumps to # ``enabled=true,method=fsync``: every commit fsyncs the # log before returning, so the wire-protocol equivalent of # ``writeConcern: {j: true}`` is effectively enforced for # the whole connection. Throughput cost on small-doc # inserts is significant (1-2 orders of magnitude), # which is why it's opt-in. # # ``file_max=10MB`` bounds journal segment size; smaller # files churn the log more, larger files delay reclamation. # 10 MB matches mongod's WT default. 
sync_part = ( "transaction_sync=(enabled=true,method=fsync)" if sync_on_commit else "transaction_sync=(enabled=false,method=fsync)" ) config = ( f"create,session_max={session_max},cache_size={cache_size}," f"log=(enabled=true,file_max=10MB)," f"{sync_part}" ) # The on-disk WT home is stashed so ``create_archive`` can tar # it after a checkpoint without re-deriving the path. self.home_path = home self._conn = wt.wiredtiger_open(home, config) self._tls = threading.local() self._all_sessions: list[Any] = [] boot = self._conn.open_session() try: boot.create(_COLL_TABLE, "key_format=SS,value_format=u") boot.create(_DOC_TABLE, "key_format=SSu,value_format=u") boot.create(_IDX_TABLE, "key_format=SSS,value_format=u") boot.create(_IDX_ENTRIES_TABLE, "key_format=SSSu,value_format=u") boot.create(_OPLOG_TABLE, "key_format=q,value_format=u") boot.create(_PREIMAGE_TABLE, "key_format=q,value_format=u") boot.create(_OPLOG_META_TABLE, "key_format=S,value_format=u") boot.create(_USERS_TABLE, "key_format=SS,value_format=u") boot.create(_ROLES_TABLE, "key_format=SS,value_format=u") boot.create(_PROFILE_TABLE, "key_format=S,value_format=u") finally: boot.close() # Oplog state — durable across restart via _OPLOG_META_TABLE. self.oplog_retention_seconds = float(oplog_retention_seconds) self.oplog_max_entries = int(oplog_max_entries) self._time = time_func or _time.time self._oplog_cv = threading.Condition(threading.Lock()) self._oplog_emit_count = 0 # Tiny fine-grained lock for seq + timestamp minting. Held in # microseconds while reserving the next seq range and bumping # the cluster-time counter. Carved out of ``_lock`` (Phase 2.1 # of the WT concurrency plan) so concurrent writers can mint # without contending on the global storage lock. self._oplog_seq_lock = threading.Lock() # Per-collection RLocks for the CRUD path (Phase 2.4 of the WT # concurrency plan). 
Writes to *different* collections can now # run in parallel; writes to the *same* collection still # serialise (preserves unique-index correctness + the pre-check # racing windows that would otherwise need an architectural # refactor of the index-entries schema). DDL operations also # acquire the per-coll lock(s) they affect so they cannot # reshape schema mid-CRUD-write. self._coll_locks: dict[tuple[str, str], threading.RLock] = {} self._coll_locks_mutex = threading.Lock() with self._lock: self._next_seq, self._last_ts_secs, self._last_ts_ord = self._load_oplog_meta() # TTL sweeper. Real mongod runs ``ttlMonitor`` every 60s by # default; we mirror that. ``ttl_sweep_seconds <= 0`` disables # the thread entirely (tests that drive expiry deterministically # via ``prune_ttl(now=...)`` use that escape hatch). The # sweeper walks every (db, coll) and calls ``prune_ttl`` on # each — collections with no TTL index short-circuit cheaply # at the index-scan step, so the steady-state cost is small. self._ttl_sweep_seconds = float(ttl_sweep_seconds) self._ttl_stop = threading.Event() self._ttl_thread: threading.Thread | None = None if self._ttl_sweep_seconds > 0: self._ttl_thread = threading.Thread( target=self._ttl_sweep_loop, name="secantus-ttl-sweeper", daemon=True ) self._ttl_thread.start() # Periodic noop heartbeat. Real mongod writes ``{op: "n"}`` # entries to the oplog every ~10s (configurable via # ``periodicNoopIntervalSecs``) so cluster time advances and # change-stream resume tokens minted from the oplog don't fall # outside the retention window during quiet stretches. Default # disabled (0) — embedded test users typically don't need it # and the extra writes would noise up tight oplog assertions. # Set ``noop_heartbeat_seconds=10`` (mongod default) for # production-ish behaviour. ``enable_oplog=False`` short- # circuits anyway, so the heartbeat is a no-op in that mode. 
self._noop_heartbeat_seconds = float(noop_heartbeat_seconds) self._noop_stop = threading.Event() self._noop_thread: threading.Thread | None = None if self._noop_heartbeat_seconds > 0 and self.enable_oplog: self._noop_thread = threading.Thread( target=self._noop_heartbeat_loop, name="secantus-noop-heartbeat", daemon=True ) self._noop_thread.start() def _load_oplog_meta(self) -> tuple[int, int, int]: c = self._cursor(_OPLOG_META_TABLE) c.set_key("state") if c.search() == 0: blob = bytes(c.get_value()) if blob: state = bson.decode(blob) return ( int(state.get("next_seq", 1)), int(state.get("last_ts_secs", 0)), int(state.get("last_ts_ord", 0)), ) # Fallback: scan oplog table for max key + reconstruct from entry. c2 = self._cursor(_OPLOG_TABLE) # Walk to last row. last_seq = 0 last_secs = 0 last_ord = 0 rc = c2.next() while rc == 0: seq = int(c2.get_key()) if seq > last_seq: last_seq = seq blob = bytes(c2.get_value()) if blob: entry = bson.decode(blob) ts = entry.get("ts") if isinstance(ts, Timestamp): last_secs, last_ord = ts.time, ts.inc rc = c2.next() return last_seq + 1, last_secs, last_ord def _persist_oplog_meta(self) -> None: c = self._cursor(_OPLOG_META_TABLE) c["state"] = bson.encode( { "next_seq": self._next_seq, "last_ts_secs": self._last_ts_secs, "last_ts_ord": self._last_ts_ord, } ) def _mint_ts(self) -> Timestamp: """Return a strictly-monotonic ``Timestamp(secs, ord)``. Caller must hold ``self._oplog_seq_lock``. Within a single wall-clock second ``ord`` increments; on a new second it resets to 1. Recovered state on startup ensures the first mint after restart is strictly greater than any previously-emitted timestamp. """ now = int(self._time()) if now > self._last_ts_secs: self._last_ts_secs = now self._last_ts_ord = 1 else: self._last_ts_ord += 1 return Timestamp(self._last_ts_secs, self._last_ts_ord) def _coll_lock(self, db: str, coll: str) -> threading.RLock: """Return the per-collection RLock for ``(db, coll)``, creating it on first reference. 
Phase 2.4 of the WT concurrency plan. CRUD on a given collection serialises through this lock; CRUD on *other* collections proceeds in parallel. DDL on this collection also acquires this lock so schema changes cannot interleave with in-flight writes. """ key = (db, coll) # Fast path: lock already exists — read without any mutation, # safe under GIL. existing = self._coll_locks.get(key) if existing is not None: return existing # Create-or-fetch under the small registry mutex. RLocks are # never removed (collections come and go but the lock identity # for a given (db, coll) stays stable across drop+recreate to # avoid races with in-flight writers). with self._coll_locks_mutex: existing = self._coll_locks.get(key) if existing is not None: return existing lock = threading.RLock() self._coll_locks[key] = lock return lock def _mint_oplog_seq_and_ts(self, n: int) -> tuple[int, list[Timestamp]]: """Atomically reserve ``n`` consecutive oplog seq numbers and mint ``n`` strictly-monotonic timestamps. Returns ``(start_seq, [ts_0, ..., ts_{n-1}])``. Held only under ``_oplog_seq_lock`` (microseconds of work) — the actual oplog cursor writes happen in the caller's WT session without blocking other writers on this lock. """ with self._oplog_seq_lock: start = self._next_seq self._next_seq += n timestamps = [self._mint_ts() for _ in range(n)] return start, timestamps def _collection_uuid(self, db: str, coll: str) -> _uuid.UUID: """Return the collection's UUID, minting and persisting on first call. Fast path (UUID already present): no Python lock — straight WT cursor read on the calling thread's session. This was a major per-insert bottleneck before Phase 2.4: every write re-acquired ``self._lock`` here, defeating the per-collection lock split. Slow path (mint a new UUID): take ``_coll_lock`` for the namespace to serialise the persist; double-check inside the lock so two racing callers can't mint different UUIDs for the same collection. 
""" opts = self._coll_options(db, coll) or {} existing = opts.get("uuid") if isinstance(existing, _uuid.UUID): return existing if isinstance(existing, bson.Binary) and len(existing) == 16: return _uuid.UUID(bytes=bytes(existing)) if isinstance(existing, bytes) and len(existing) == 16: return _uuid.UUID(bytes=existing) # Mint path — take the per-coll lock; re-read after acquiring # so a racer that won the mint race is observed. with self._coll_lock(db, coll): opts = self._coll_options(db, coll) or {} existing = opts.get("uuid") if isinstance(existing, _uuid.UUID): return existing if isinstance(existing, bson.Binary) and len(existing) == 16: return _uuid.UUID(bytes=bytes(existing)) if isinstance(existing, bytes) and len(existing) == 16: return _uuid.UUID(bytes=existing) new_uuid = _uuid.uuid4() opts["uuid"] = new_uuid self._write_coll_options(db, coll, opts) return new_uuid def collection_uuid(self, db: str, coll: str) -> _uuid.UUID: """Public alias for ``_collection_uuid``.""" return self._collection_uuid(db, coll) def current_cluster_time(self) -> Timestamp: """Return a strictly-monotonic ``Timestamp`` advancing the cluster clock.""" with self._oplog_seq_lock: ts = self._mint_ts() # Meta persist uses the calling thread's WT session/cursor — # safe to do outside the seq lock since it doesn't depend on # the in-memory counters being held stable past the mint. self._persist_oplog_meta() return ts def _write_coll_options(self, db: str, coll: str, opts: Mapping[str, Any]) -> None: c = self._cursor(_COLL_TABLE) # bson can't directly encode a uuid.UUID without a codec, so store as Binary subtype 4. 
encoded: dict[str, Any] = {} for k, v in opts.items(): if isinstance(v, _uuid.UUID): encoded[k] = bson.Binary(v.bytes, subtype=4) else: encoded[k] = v c[db, coll] = bson.encode(encoded) if encoded else b"" def set_collection_options(self, db: str, coll: str, **opts: Any) -> None: """Merge ``opts`` into the collection's options blob (creates if absent).""" with self._lock: self._ensure_collection(db, coll) current = self._coll_options(db, coll) or {} current.update(opts) self._write_coll_options(db, coll, current) def get_collection_options(self, db: str, coll: str) -> dict[str, Any]: """Return the collection's options blob, or ``{}`` if absent.""" if self.enable_oplog and db == "local" and coll == "oplog.rs": # Synthetic ``local.oplog.rs``: report the capped-collection # shape mongod uses so $collStats / listCollections options # match. ``size`` is a notional byte cap derived from the # entry cap × a conservative per-entry estimate; we don't # track real byte usage, only entry count. return { "capped": True, "size": self.oplog_max_entries * 16 * 1024, "max": self.oplog_max_entries, } self._refresh_read_snapshot() with self._lock: opts = self._coll_options(db, coll) or {} # Decode UUID Binary back into uuid.UUID for callers. decoded: dict[str, Any] = {} for k, v in opts.items(): if k == "uuid" and isinstance(v, bson.Binary) and len(v) == 16: decoded[k] = _uuid.UUID(bytes=bytes(v)) else: decoded[k] = v return decoded def _is_oplog_rs(self, db: str, coll: str) -> bool: """``(local, oplog.rs)`` is the synthetic oplog view.""" return self.enable_oplog and db == "local" and coll == "oplog.rs" def _scan_oplog_entries(self) -> list[dict[str, Any]]: """Walk every persisted oplog entry and return the decoded docs. Uses a private short-lived session so the read view always reflects rows committed by writer threads on other connections (same pattern as ``read_oplog``). 
""" rows: list[dict[str, Any]] = [] with self._lock: session = self._conn.open_session() try: c = session.open_cursor(_OPLOG_TABLE, None) try: rc = c.next() while rc == 0: blob = bytes(c.get_value()) if blob: rows.append(bson.decode(blob)) rc = c.next() finally: with contextlib.suppress(Exception): c.close() finally: with contextlib.suppress(Exception): session.close() return rows def _find_oplog_rs( self, filter: dict[str, Any] | None, *, skip: int, limit: int, sort: Mapping[str, Any] | None, projection: Mapping[str, Any] | None, let: dict[str, Any] | None, collation: Any, ) -> list[dict[str, Any]]: """Read path for the synthetic ``local.oplog.rs`` view. Entries are walked in seq order (== ts order). Filter / sort / skip / limit / projection are all honoured against the decoded entry docs via the existing pure-Python helpers. """ from secantus.collation import parse as _parse_collation collation_obj = _parse_collation(collation) rows = self._scan_oplog_entries() if filter: rows = [r for r in rows if matches(r, filter, vars=let, collation=collation_obj)] if sort: rows = sort_docs(rows, sort) if skip: rows = rows[skip:] if limit > 0: rows = rows[:limit] if projection: rows = [apply_projection(r, projection) for r in rows] return rows def _count_oplog_rs( self, filter: dict[str, Any] | None, *, let: dict[str, Any] | None, collation: Any, ) -> int: from secantus.collation import parse as _parse_collation collation_obj = _parse_collation(collation) if not filter: return len(self._scan_oplog_entries()) return sum( 1 for r in self._scan_oplog_entries() if matches(r, filter, vars=let, collation=collation_obj) ) def _emit_oplog( self, entries: list[dict[str, Any]], pre_images: list[bytes | None] | None = None, ) -> int: """Append ``entries`` to the oplog table under ``self._lock``. ``pre_images`` is parallel to ``entries``; non-None elements are stored under the matching seq in ``_PREIMAGE_TABLE``. Returns the highest seq emitted (0 if ``entries`` is empty). 
Notifies waiters on ``self._oplog_cv`` once writes have committed. If ``self.enable_oplog`` is False, returns 0 immediately — the caller's prebuilt ``entries`` list is discarded. The change-stream condvar is still notified so any tailable getMore wakes up and observes the (empty) state. """ if not self.enable_oplog: with self._oplog_cv: self._oplog_cv.notify_all() return 0 if not entries: return 0 if pre_images is None: pre_images = [None] * len(entries) assert len(pre_images) == len(entries) # Reserve seq + ts range up-front under the tiny seq lock. # The actual cursor writes below run on this thread's WT # session without holding any cross-thread Python lock. n = len(entries) start_seq, ts_range = self._mint_oplog_seq_and_ts(n) op_cur = self._cursor(_OPLOG_TABLE) pre_cur = None last_seq = 0 for i, (entry, pre) in enumerate(zip(entries, pre_images, strict=True)): seq = start_seq + i entry_with_ts = dict(entry) if "ts" not in entry_with_ts: entry_with_ts["ts"] = ts_range[i] if "wall" not in entry_with_ts: entry_with_ts["wall"] = _dt.datetime.now(_dt.timezone.utc) op_cur[seq] = bson.encode(entry_with_ts) if pre is not None: if pre_cur is None: pre_cur = self._cursor(_PREIMAGE_TABLE) pre_cur[seq] = pre last_seq = seq # ``_persist_oplog_meta`` was called here on every emit, but # under concurrent writers it WT-rollbacks half the time — # every writer hits the same single ``"state"`` meta row. # The meta row is purely a recovery optimisation; if it's # stale, ``_load_oplog_meta``'s fallback scans the oplog # table for the actual max seq. So we now persist only on # close + on prune_oplog, both of which are rare. The seq # mint itself is durable because the actual oplog rows are # written on every emit. 
self._oplog_emit_count += len(entries) if self._oplog_emit_count >= _OPLOG_PRUNE_INTERVAL: self._oplog_emit_count = 0 self._prune_oplog_locked(now=self._time()) with self._oplog_cv: self._oplog_cv.notify_all() return last_seq def read_oplog( self, *, start_seq: int, limit: int, ns_filter: Callable[[str], bool] | None = None, ) -> list[tuple[int, dict[str, Any]]]: """Forward-scan the oplog from ``start_seq`` (inclusive). Uses a private short-lived session so the read view always reflects rows committed by other sessions. The cached per-thread session's snapshot is sticky — under WiredTiger's MVCC, reusing it across getMore polls would never observe oplog rows produced by a writer running on a different connection thread. """ out: list[tuple[int, dict[str, Any]]] = [] with self._lock: session = self._conn.open_session() try: c = session.open_cursor(_OPLOG_TABLE, None) try: c.set_key(int(start_seq)) rc = c.search_near() if rc == wt.WT_NOTFOUND: return out if rc < 0 and c.next() != 0: return out while True: seq = int(c.get_key()) blob = bytes(c.get_value()) if blob: entry = bson.decode(blob) if ns_filter is None or ns_filter(str(entry.get("ns", ""))): out.append((seq, entry)) if len(out) >= limit: break if c.next() != 0: break finally: with contextlib.suppress(Exception): c.close() finally: with contextlib.suppress(Exception): session.close() return out def read_preimage(self, seq: int) -> dict[str, Any] | None: """Return the pre-image doc for ``seq`` if one was stored, else ``None``. Uses a private session for cross-thread visibility (see ``read_oplog``). 
""" with self._lock: session = self._conn.open_session() try: c = session.open_cursor(_PREIMAGE_TABLE, None) try: c.set_key(int(seq)) if c.search() != 0: return None blob = bytes(c.get_value()) if not blob: return None return bson.decode(blob) finally: with contextlib.suppress(Exception): c.close() finally: with contextlib.suppress(Exception): session.close() def oplog_tail_seq(self) -> int: """Highest seq currently present (or last emitted). 0 if empty.""" with self._lock: return self._next_seq - 1 def oplog_tail_seq_nolock(self) -> int: """Highest seq read without acquiring ``self._lock``. Safe for use **only** as the wake predicate for a tailable ``getMore`` waiting on ``self._oplog_cv``: lock order in the write path is ``_lock`` -> ``_oplog_cv``, so a waiter that already holds ``_oplog_cv`` (which is what ``cv.wait_for`` does) MUST NOT then take ``_lock`` -- that's an ABBA deadlock with any concurrent writer. Reading ``_next_seq`` directly is safe because (a) ``int`` reads are atomic under the GIL and (b) the cv is also notified on every commit, so any momentary stale read self-corrects on the next iteration of the ``wait_for`` predicate. """ return self._next_seq - 1 def oplog_floor_seq(self) -> int: """Smallest seq currently present after pruning. 0 if empty. Uses a private session for cross-thread visibility. """ with self._lock: session = self._conn.open_session() try: c = session.open_cursor(_OPLOG_TABLE, None) try: rc = c.next() if rc != 0: return 0 return int(c.get_key()) finally: with contextlib.suppress(Exception): c.close() finally: with contextlib.suppress(Exception): session.close() def find_seq_for_ts(self, ts: Timestamp) -> int: """Smallest seq whose entry ``ts >= target``. Tail+1 if none qualify. Uses a private session for cross-thread visibility. 
""" with self._lock: session = self._conn.open_session() try: c = session.open_cursor(_OPLOG_TABLE, None) try: rc = c.next() while rc == 0: seq = int(c.get_key()) blob = bytes(c.get_value()) if blob: entry = bson.decode(blob) entry_ts = entry.get("ts") if isinstance(entry_ts, Timestamp) and ( entry_ts.time > ts.time or (entry_ts.time == ts.time and entry_ts.inc >= ts.inc) ): return seq rc = c.next() return self._next_seq finally: with contextlib.suppress(Exception): c.close() finally: with contextlib.suppress(Exception): session.close() def prune_oplog(self, *, now: float | None = None) -> int: """Drop oplog rows older than retention or above the entry cap.""" with self._lock: return self._prune_oplog_locked(now=now) def _ns(self, db: str, coll: str) -> str: return f"{db}.{coll}" def _pre_post_images_enabled(self, db: str, coll: str) -> bool: opts = self._coll_options(db, coll) or {} cfg = opts.get("changeStreamPreAndPostImages") return isinstance(cfg, Mapping) and bool(cfg.get("enabled")) def _prune_oplog_locked(self, *, now: float | None = None) -> int: when = now if now is not None else self._time() cutoff_secs = int(when - self.oplog_retention_seconds) # Two-phase: collect doomed seqs, then delete (avoid mutating during scan). doomed: list[int] = [] all_seqs: list[int] = [] c = self._cursor(_OPLOG_TABLE) rc = c.next() while rc == 0: seq = int(c.get_key()) blob = bytes(c.get_value()) all_seqs.append(seq) if blob: entry = bson.decode(blob) ts = entry.get("ts") if isinstance(ts, Timestamp) and ts.time < cutoff_secs: doomed.append(seq) rc = c.next() # Trim to entry cap by extending doom set to oldest entries. 
kept_count = len(all_seqs) - len(doomed) if kept_count > self.oplog_max_entries: extra = kept_count - self.oplog_max_entries doomed_set = set(doomed) for seq in all_seqs: if extra <= 0: break if seq not in doomed_set: doomed.append(seq) doomed_set.add(seq) extra -= 1 if not doomed: return 0 op_del = self._cursor(_OPLOG_TABLE) pre_del = self._cursor(_PREIMAGE_TABLE) for seq in doomed: op_del.set_key(seq) with contextlib.suppress(wt.WiredTigerError): op_del.remove() op_del.reset() pre_del.set_key(seq) with contextlib.suppress(wt.WiredTigerError): pre_del.remove() pre_del.reset() return len(doomed) # --- Users (auth) --- def add_user( self, db: str, username: str, record: Mapping[str, Any], *, replace: bool = False, ) -> bool: """Persist a user record. Returns True if added; False if it already existed and ``replace=False``. ``record`` is a BSON-encodable dict of arbitrary shape (the commands layer owns the structure). Stored verbatim. """ with self._lock: c = self._cursor(_USERS_TABLE) c.set_key(db, username) if c.search() == 0 and not replace: return False c.reset() c[db, username] = bson.encode(dict(record)) return True def get_user(self, db: str, username: str) -> dict[str, Any] | None: with self._lock: c = self._cursor(_USERS_TABLE) c.set_key(db, username) if c.search() != 0: return None blob = bytes(c.get_value()) return bson.decode(blob) if blob else None def drop_user(self, db: str, username: str) -> bool: with self._lock: c = self._cursor(_USERS_TABLE) c.set_key(db, username) if c.search() != 0: return False c.remove() return True def list_users( self, db: str | None = None, *, skip: int = 0, limit: int = 100, ) -> list[dict[str, Any]]: """Paginated user listing. 
``db=None`` lists across all databases.""" if limit <= 0 or limit > 1000: limit = 1000 out: list[dict[str, Any]] = [] with self._lock: c = self._cursor(_USERS_TABLE) rc = c.next() seen = 0 while rc == 0: k = c.get_key() row_db = k[0] if db is None or row_db == db: if seen >= skip: blob = bytes(c.get_value()) if blob: out.append(bson.decode(blob)) if len(out) >= limit: break seen += 1 rc = c.next() return out # ------------------------------------------------------------------ # Per-database profiling settings. # # Real mongod tracks (level, slowms, sampleRate) per database in # memory + persists to the database's metadata. We persist in a # dedicated WT table keyed by db name. The dispatch path reads # these settings on every command — keep ``get_profile`` fast. # ------------------------------------------------------------------ def get_profile(self, db: str) -> dict[str, Any]: """Return the active profile settings for ``db``, defaults if unset. Defaults match mongod: level 0 (off), slowms 100, sampleRate 1.0. """ with self._lock: c = self._cursor(_PROFILE_TABLE) c.set_key(db) if c.search() != 0: return {"level": 0, "slowms": 100, "sampleRate": 1.0} blob = bytes(c.get_value()) if not blob: return {"level": 0, "slowms": 100, "sampleRate": 1.0} doc = bson.decode(blob) # ``or default`` is wrong here — slowms=0 / sampleRate=0.0 are # legitimate values that must round-trip, not be replaced # with defaults. Use direct ``.get`` with the default and # coerce only when a value is actually present. 
level_v = doc.get("level", 0) slowms_v = doc.get("slowms", 100) rate_v = doc.get("sampleRate", 1.0) return { "level": int(level_v) if level_v is not None else 0, "slowms": int(slowms_v) if slowms_v is not None else 100, "sampleRate": float(rate_v) if rate_v is not None else 1.0, } def set_profile( self, db: str, *, level: int, slowms: int = 100, sample_rate: float = 1.0, ) -> None: """Persist profile settings for ``db``.""" if level not in (0, 1, 2): raise ValueError("level must be 0, 1, or 2") if slowms < 0: raise ValueError("slowms must be non-negative") if not (0.0 <= sample_rate <= 1.0): raise ValueError("sampleRate must be in [0, 1]") doc = {"level": int(level), "slowms": int(slowms), "sampleRate": float(sample_rate)} with self._lock: c = self._cursor(_PROFILE_TABLE) c[db] = bson.encode(doc) def ensure_profile_collection(self, db: str, *, size_bytes: int = 10 * 1024 * 1024) -> None: """Ensure ``<db>.system.profile`` exists as a 10 MB-default capped collection.""" if self.collection_exists(db, "system.profile"): return self.create_collection(db, "system.profile") self.set_collection_options(db, "system.profile", capped=True, size=int(size_bytes)) # ------------------------------------------------------------------ # Custom roles. Storage layer is a thin BSON-blob CRUD; the commands # layer owns the role-record shape (privileges + inherited roles) # and ``secantus.rbac`` owns the privilege-check logic that walks # the inheritance graph. # ------------------------------------------------------------------ def add_role( self, db: str, name: str, record: Mapping[str, Any], *, replace: bool = False, ) -> bool: """Persist a custom role record. 
Returns True if added; False if it already existed and ``replace=False``.""" with self._lock: c = self._cursor(_ROLES_TABLE) c.set_key(db, name) if c.search() == 0 and not replace: return False c.reset() c[db, name] = bson.encode(dict(record)) return True def get_role(self, db: str, name: str) -> dict[str, Any] | None: # Use a private short-lived session so cross-thread visibility # is guaranteed: connection-thread A may have written a role # while we're on connection-thread B, and B's cached session # carries a sticky snapshot that won't observe A's commit. # Same pattern as ``read_oplog``. The cost (one open_session + # close per call) is negligible vs the correctness win. with self._lock: session = self._conn.open_session() try: c = session.open_cursor(_ROLES_TABLE, None, None) try: c.set_key(db, name) if c.search() != 0: return None blob = bytes(c.get_value()) return bson.decode(blob) if blob else None finally: with contextlib.suppress(Exception): c.close() finally: with contextlib.suppress(Exception): session.close() def drop_role(self, db: str, name: str) -> bool: with self._lock: c = self._cursor(_ROLES_TABLE) c.set_key(db, name) if c.search() != 0: return False c.remove() return True def list_roles( self, db: str | None = None, *, skip: int = 0, limit: int = 100, ) -> list[dict[str, Any]]: """Paginated custom-role listing. ``db=None`` spans every db.""" if limit <= 0 or limit > 1000: limit = 1000 out: list[dict[str, Any]] = [] with self._lock: c = self._cursor(_ROLES_TABLE) rc = c.next() seen = 0 while rc == 0: k = c.get_key() row_db = k[0] if db is None or row_db == db: if seen >= skip: blob = bytes(c.get_value()) if blob: out.append(bson.decode(blob)) if len(out) >= limit: break seen += 1 rc = c.next() return out def close(self) -> None: # Stop background threads before tearing down WT — both the # TTL sweeper and the noop heartbeat acquire ``self._lock``, # so racing them against close would deadlock or # use-after-close. 
self._ttl_stop.set() if self._ttl_thread is not None and self._ttl_thread.is_alive(): self._ttl_thread.join(timeout=2.0) self._ttl_thread = None self._noop_stop.set() if self._noop_thread is not None and self._noop_thread.is_alive(): self._noop_thread.join(timeout=2.0) self._noop_thread = None with self._lock: if self._closed: return self._closed = True # Persist the oplog meta one last time. We dropped the # per-emit persist in Phase 2.4 (it caused WT-rollback # storms under concurrent writers), so this is the # canonical place to write the in-memory ``_next_seq`` # and timestamp counters down to disk before shutdown. with contextlib.suppress(Exception): self._persist_oplog_meta() # Force a checkpoint before tearing the connection down. # ``WT_CONNECTION->close`` does this implicitly, but only # when logging is off (or hits the connection's # close-time flush window). Driving it explicitly here # gives a durable on-disk image of the dataset at the # moment of shutdown regardless of journal state — the # behaviour callers reasonably expect from ``close()``. # Skip for in-memory backends: WT's in_memory engine # rejects checkpoint() with a noisy stderr log # (``__wt_inmem_unsupported_op``) on every call. if not self._in_memory: with contextlib.suppress(Exception): self._session().checkpoint() for s in self._all_sessions: with contextlib.suppress(Exception): s.close() self._all_sessions.clear() with contextlib.suppress(Exception): self._conn.close() if self._tempdir is not None: # Don't follow symlinks during cleanup. A local attacker # racing the mkdtemp could replace `_tempdir` with a # symlink to elsewhere on the filesystem before close() # fires — `shutil.rmtree(symlink, ignore_errors=True)` # would then delete the symlink target. The mkdtemp # already creates with mode 0700 (owner-only), but the # parent /tmp is world-writable, so this is the # belt-and-braces guard. Failures during cleanup are # logged but not raised — close() must remain idempotent. 
try: if not os.path.islink(self._tempdir): shutil.rmtree(self._tempdir) except OSError: # Best-effort: log via warnings rather than crash close(). import warnings as _warn _warn.warn( f"failed to remove WiredTiger tempdir {self._tempdir!r}", ResourceWarning, stacklevel=2, ) self._tempdir = None def prune_ttl_all_collections(self, *, now: _dt.datetime | None = None) -> int: """Run :meth:`prune_ttl` against every collection, returning the total docs pruned. Used by the background sweeper and exposed publicly so callers (admin tooling, tests) can drive a deterministic global pass. Callers using the cached per-thread session must call :meth:`_reset_thread_session` first — WiredTiger snapshots are sticky per-session, so reads otherwise miss rows committed by other threads. The sweeper does this on every iteration; one-shot user calls happen on the writer's thread and see their own writes. """ with self._lock: c = self._cursor(_COLL_TABLE) namespaces: list[tuple[str, str]] = [] rc = c.next() while rc == 0: k = c.get_key() namespaces.append((k[0], k[1])) rc = c.next() total = 0 for db, coll in namespaces: with contextlib.suppress(Exception): # Storage close races: drop_collection between snapshot # and prune fails inside prune_ttl with a missing-coll # error. The sweeper should never crash the daemon. total += self.prune_ttl(db, coll, now=now) return total def _ttl_sweep_loop(self) -> None: """Background sweeper: every ``ttl_sweep_seconds`` walk all collections and prune expired docs. Stops when ``_ttl_stop`` is set or the storage is closed. Drops the per-thread WT session before each iteration so the next cursor call opens a fresh session. WiredTiger sessions carry a sticky read snapshot — without the reset, reads on this thread would never observe rows committed by other writers, and TTL sweeps would always return 0 even when expired docs existed. Same pattern as ``read_oplog``. 
""" import logging log = logging.getLogger("secantus.storage.ttl") while not self._ttl_stop.wait(self._ttl_sweep_seconds): if self._closed: return self._reset_thread_session() try: self.prune_ttl_all_collections() except Exception: # Sweeper failures must not propagate — they'd kill # the daemon thread and silently disable expiry. log.exception("ttl sweep failed") def emit_noop_heartbeat(self) -> int: """Append one ``{op: "n"}`` heartbeat to the oplog and return its seq. The entry shape mirrors mongod's periodic noop: ``op = "n"``, an empty namespace, current cluster time, and a small ``o = {msg: "periodic noop"}`` payload. Change-stream consumers skip ``op: "n"`` rows in projection but still advance their ``position_seq`` and ``last_token`` past them, so the resume token of a quiet collection stays current. Public so callers (admin tooling, tests that drive heartbeats deterministically) can fire one explicitly. """ with self._lock: return self._emit_oplog( [ { "op": "n", "ns": "", "o": {"msg": "periodic noop"}, } ] ) def _noop_heartbeat_loop(self) -> None: """Background heartbeat: emit one ``{op: "n"}`` oplog entry every ``noop_heartbeat_seconds``. Stops when ``_noop_stop`` is set or the storage is closed. Failures are logged and swallowed — a transient WT error must not kill the daemon thread. """ import logging log = logging.getLogger("secantus.storage.noop") while not self._noop_stop.wait(self._noop_heartbeat_seconds): if self._closed: return try: self.emit_noop_heartbeat() except Exception: log.exception("noop heartbeat failed") def _reset_thread_session(self) -> None: """Close the calling thread's cached WT session + cursors so the next ``_session()`` call opens fresh ones. 
Needed when a thread reads in a loop and must observe writes from other threads (snapshot is otherwise sticky).""" s = getattr(self._tls, "session", None) if s is None: return cursors = getattr(self._tls, "cursors", {}) or {} for c in cursors.values(): with contextlib.suppress(Exception): c.close() with contextlib.suppress(Exception): s.close() with self._lock, contextlib.suppress(ValueError): self._all_sessions.remove(s) self._tls.session = None self._tls.cursors = {} def checkpoint(self) -> None: """Force a WiredTiger checkpoint to flush dirty pages to disk. Backs the ``fsync`` command and the admin UI's maintenance slice. Lock-protected so concurrent commands wait their turn. On in-memory backends the call is a no-op (WT's in_memory engine has no disk to flush and rejects with a noisy stderr log). """ with self._lock: if self._closed or self._in_memory: return self._session().checkpoint() def create_archive(self, output_path: str) -> dict[str, int | str]: """Force a checkpoint, then tar the consistent file set into ``output_path``. Returns ``{"path": <abs>, "sizeBytes": <int>}`` on success. Raises ``RuntimeError`` for in-memory backends — there's no on-disk state to archive. Uses WT's dedicated ``backup:`` cursor to enumerate the files that constitute a consistent snapshot. WT promises during the cursor's lifetime that those files won't change and that they are read-shareable — the latter matters on Windows, where WT otherwise holds exclusive file locks that block ``tarfile``'s reads. Walking the directory directly worked on Unix (open files are shareable by default) but ``PermissionError``'d on Windows. Output is a single ``.tar.gz`` (gzip-compressed) so the archive round-trips cleanly through git/mail/scp; the typical workload compresses well because WT pages aren't snappy/zstd at rest. 
""" import tarfile if self._in_memory: raise RuntimeError( "create_archive: cannot archive an in-memory backend " "(WT in_memory engine has no on-disk state)" ) # Resolve to absolute so the returned ``path`` is unambiguous # for the caller even if their cwd has shifted. abs_out = os.path.abspath(output_path) os.makedirs(os.path.dirname(abs_out) or ".", exist_ok=True) with self._lock: if self._closed: raise RuntimeError("create_archive: storage is closed") self._session().checkpoint() # A private session for the backup cursor so its lifecycle # doesn't interfere with the per-thread cached session # that handles regular work. backup_session = self._conn.open_session() try: cursor = backup_session.open_cursor("backup:", None, None) try: # Tar inline while the cursor is open: WT creates # the ``WiredTiger.backup`` metadata file as part # of the cursor's open state and removes it on # close, so collecting filenames first then tarring # would race the cleanup. Iterate-and-add keeps # every file readable for the duration of the tar. with tarfile.open(abs_out, "w:gz") as tar: while cursor.next() == 0: rel = cursor.get_key() full = os.path.join(self.home_path, rel) tar.add(full, arcname=rel) finally: cursor.close() finally: backup_session.close() return {"path": abs_out, "sizeBytes": os.path.getsize(abs_out)} @contextlib.contextmanager def _batch_transaction(self, *, sync: bool = False) -> Any: """Group multiple cursor writes into one WT transaction = one log record. WT auto-commits every individual ``cursor.insert()`` / ``cursor.update()`` etc., which means N writes produce N log records and N commit overheads. With this wrapper, the same N writes share a single commit (and therefore a single log record): on a typical bulk insert that's a 2-5x throughput win for ``--batch-size > 1`` on the wire side, with the same durability guarantee (all-or-nothing on commit). 
``sync=True`` overrides the connection-level ``transaction_sync`` setting and forces this individual commit to fsync the log to disk before returning — the per-transaction equivalent of the server-wide ``sync_on_commit`` knob. Used to honour ``writeConcern: {j: true}`` on a single write even when the daemon is otherwise running with ``sync_on_commit=false``. ``sync=False`` (default) inherits the connection's ``transaction_sync`` config. Caller must already hold ``self._lock``. Reads within the transaction observe the in-progress writes — fine for our unique-conflict probes which need to see uncommitted siblings in the same batch. On exception the transaction is rolled back. Callers that accumulate per-doc errors (e.g. ``ordered=False`` insert) should NOT raise out of the block — they handle the per-doc errors locally and let the surviving writes commit. """ session = self._session() # Cached cursors must be reset before begin_transaction so they # don't carry a stale snapshot from before the transaction # boundary. WT documents this requirement explicitly. for c in getattr(self._tls, "cursors", {}).values(): with contextlib.suppress(Exception): c.reset() session.begin_transaction() try: yield session except Exception: with contextlib.suppress(Exception): session.rollback_transaction() raise else: session.commit_transaction("sync=on" if sync else None) def _session(self) -> Any: s = getattr(self._tls, "session", None) if s is None: s = self._conn.open_session() self._tls.session = s self._tls.cursors = {} with self._lock: self._all_sessions.append(s) return s def _refresh_read_snapshot(self) -> None: """Force the per-thread WT session to acquire a fresh read snapshot. WiredTiger's default snapshot isolation pins a session's read view at first cursor access; subsequent reads on the same session see exactly that point-in-time view until the session commits / rolls back a transaction. 
That's correct for a single in-flight operation, but our daemon reuses one session per connection thread across the full lifetime of a TCP connection. Without an explicit snapshot refresh, a long-lived client connection (Java's ``ClusterFixture`` is the canonical case) does an insert, idles while another connection commits a write, then reads — and sees the stale pre-other-write view. ``session.reset_snapshot()`` releases the held snapshot so the next cursor read picks up the latest committed state. Called at the top of every public read entry point (``find_matching``, ``count_matching``, ``list_*``, ``explain_plan``) so cross- connection visibility matches real ``mongod``. """ s = getattr(self._tls, "session", None) if s is None: return with contextlib.suppress(Exception): # ``reset_snapshot()`` errors if the session is in an # explicit transaction. Reads never run inside one # (``_batch_transaction`` is write-only), so the exception # path is defensive — log via ``suppress`` and move on. 
s.reset_snapshot() def _cursor(self, table: str, *, overwrite: bool = True) -> Any: self._session() cursors: dict[tuple[str, bool], Any] = self._tls.cursors key = (table, overwrite) c = cursors.get(key) if c is None: cfg = None if overwrite else "overwrite=false" c = self._tls.session.open_cursor(table, None, cfg) cursors[key] = c else: c.reset() return c def _coll_options(self, db: str, coll: str) -> dict[str, Any] | None: c = self._cursor(_COLL_TABLE) c.set_key(db, coll) rc = c.search() if rc != 0: return None blob = bytes(c.get_value()) return bson.decode(blob) if blob else {} def _ensure_collection(self, db: str, coll: str) -> None: c = self._cursor(_COLL_TABLE) c.set_key(db, coll) if c.search() == 0: return c.reset() c[db, coll] = b"" def collection_exists(self, db: str, coll: str) -> bool: with self._lock: return self._coll_options(db, coll) is not None def create_collection(self, db: str, coll: str) -> bool: with self._lock: c = self._cursor(_COLL_TABLE) c.set_key(db, coll) if c.search() == 0: return False c.reset() c[db, coll] = b"" self._collection_uuid(db, coll) # mint and persist ui = self._collection_uuid(db, coll) self._emit_oplog( [ { "op": "c", "ns": f"{db}.$cmd", "ui": bson.Binary(ui.bytes, subtype=4), "o": { "create": coll, "idIndex": {"v": 2, "key": {"_id": 1}, "name": "_id_"}, }, } ] ) return True def _scan_docs(self, db: str, coll: str) -> Iterable[tuple[bytes, bytes]]: c = self._cursor(_DOC_TABLE) c.set_key(db, coll, b"") rc = c.search_near() if rc == wt.WT_NOTFOUND: return if rc < 0 and c.next() != 0: return while True: k = c.get_key() if k[0] != db or k[1] != coll: return yield bytes(k[2]), bytes(c.get_value()) if c.next() != 0: return def _all_docs(self, db: str, coll: str) -> list[dict[str, Any]]: # Two-stage to keep ``bson.decode`` out of ``self._lock`` — # otherwise an N-doc scan blocks every other thread for the # whole decode loop. Lock owns the WT cursor walk; decode # happens after release. 
with self._lock: blobs = [blob for _id_k, blob in self._scan_docs(db, coll)] return [bson.decode(blob) for blob in blobs] def _all_docs_with_id_key(self, db: str, coll: str) -> list[tuple[dict[str, Any], bytes]]: with self._lock: raw = [(id_k, blob) for id_k, blob in self._scan_docs(db, coll)] return [(bson.decode(blob), id_k) for id_k, blob in raw] def scan_docs_after_id_key( self, db: str, coll: str, after: bytes | None ) -> list[tuple[bytes, dict[str, Any]]]: """Scan the document table in natural (id_key) order, returning only rows whose ``id_key`` is strictly greater than ``after``. ``after`` of ``None`` returns the entire collection. Used by the tailable-cursor producer to emit only the docs inserted since the last poll. Returns ``[(id_key, doc), ...]`` — callers update their ``after`` checkpoint to the last returned ``id_key`` for the next poll. """ # Two-stage: collect raw bytes under the lock, decode after. with self._lock: raw: list[tuple[bytes, bytes]] = [ (id_k, blob) for id_k, blob in self._scan_docs(db, coll) if after is None or id_k > after ] return [(id_k, bson.decode(blob)) for id_k, blob in raw] def collection_is_capped(self, db: str, coll: str) -> bool: """Public predicate: does the collection have ``capped: true`` set?""" with self._lock: opts = self._coll_options(db, coll) or {} return bool(opts.get("capped")) def insert( self, db: str, coll: str, docs: Iterable[dict[str, Any]], *, ordered: bool = True, journal: bool = False, ) -> tuple[int, list[dict[str, Any]]]: inserted = 0 errors: list[dict[str, Any]] = [] oplog_entries: list[dict[str, Any]] = [] fresh_id_keys: set[bytes] = set() oplog_on = self.enable_oplog with self._coll_lock(db, coll), self._batch_transaction(sync=journal): # Per-collection lock (Phase 2.4): writes to other # collections proceed in parallel; same-collection writes # still serialise to keep the unique-index pre-check # race-free. 
_batch_transaction wraps the per-doc cursor # inserts (doc table + index entries + oplog) in one # explicit WT transaction so they share a single commit / # log record. self._ensure_collection(db, coll) ns = self._ns(db, coll) if oplog_on else "" ui = self._collection_uuid(db, coll) if oplog_on else None indexes = self._all_indexes(db, coll) partials = self._partial_filters(db, coll) multikey_names = self._multikey_index_names(db, coll) for index, doc in enumerate(docs): if "_id" not in doc: doc["_id"] = bson.ObjectId() key = _id_key(doc["_id"]) conflict = self._unique_conflict( db, coll, doc, indexes, exclude_id_key=None, partials=partials ) if conflict is not None: cname, kpat, kval = conflict errors.append( { "index": index, "code": 11000, "errmsg": ( f"E11000 duplicate key error in index {cname}: _id={doc['_id']!r}" ), "keyPattern": kpat, "keyValue": kval, } ) if ordered: break continue # Pre-flight every geo index: a bad geometry should reject # the doc *before* it lands in the doc table, so we don't # leave a half-indexed write behind. Validation is cheap; # _write_index_entries below recomputes the same cells. 
try: self._validate_geo_indexes(db, coll, doc, indexes, partials) except GeoExtractError as exc: errors.append({"index": index, "code": 16572, "errmsg": str(exc)}) if ordered: break continue blob = bson.encode(doc) doc_cur = self._cursor(_DOC_TABLE, overwrite=False) doc_cur.set_key(db, coll, key) doc_cur.set_value(blob) try: doc_cur.insert() except wt.WiredTigerError: errors.append( { "index": index, "code": 11000, "errmsg": f"E11000 duplicate key error: _id {doc['_id']!r}", } ) if ordered: break continue self._write_index_entries(db, coll, doc, indexes, partials) multikey_names = self._maybe_mark_multikey(db, coll, doc, indexes, multikey_names) inserted += 1 if oplog_on: oplog_entries.append( { "op": "i", "ns": ns, "ui": bson.Binary(ui.bytes, subtype=4), "o": dict(doc), "o2": {"_id": doc["_id"]}, } ) fresh_id_keys.add(key) cap_entries, cap_pre_images = self._enforce_capped_bounds_locked( db, coll, fresh_id_keys, indexes, partials, oplog_on, ns, ui ) if oplog_entries or cap_entries: pre_images = [None] * len(oplog_entries) + cap_pre_images self._emit_oplog(oplog_entries + cap_entries, pre_images) return inserted, errors def _enforce_capped_bounds_locked( self, db: str, coll: str, fresh_id_keys: set[bytes], indexes: list[tuple[str, dict[str, Any], bool, bool]], partials: dict[str, dict[str, Any]], oplog_on: bool, ns: str, ui: _uuid.UUID | None, ) -> tuple[list[dict[str, Any]], list[bytes | None]]: """Evict oldest non-fresh docs from a capped collection until within bounds. "Oldest" is the natural-order walk over the doc table, which matches insertion order when ``_id`` is monotonic (e.g. the default ObjectId). For non-monotonic ``_id`` values the eviction order reflects ``_id`` byte order, not literal insertion order — capped users with custom ``_id`` should not rely on FIFO semantics. 
""" raw = self._coll_options(db, coll) or {} if not raw.get("capped"): return [], [] size_limit = raw.get("size") max_limit = raw.get("max") if size_limit is None and max_limit is None: return [], [] scanned = list(self._scan_docs(db, coll)) total = sum(len(blob) for _id_k, blob in scanned) count = len(scanned) oplog_entries: list[dict[str, Any]] = [] pre_images: list[bytes | None] = [] preimages_on = oplog_on and self._pre_post_images_enabled(db, coll) for id_k, blob in scanned: over_size = size_limit is not None and total > size_limit over_max = max_limit is not None and count > max_limit if not over_size and not over_max: break if id_k in fresh_id_keys: # Don't evict docs we just inserted in this batch — they # always sort to the tail with monotonic _ids, so reaching # one means everything left is fresh too. break doc = bson.decode(blob) self._delete_index_entries(db, coll, doc, indexes, partials) doc_cur = self._cursor(_DOC_TABLE) doc_cur.set_key(db, coll, id_k) doc_cur.remove() total -= len(blob) count -= 1 if oplog_on: entry: dict[str, Any] = { "op": "d", "ns": ns, "o": {"_id": doc["_id"]}, "o2": {"_id": doc["_id"]}, } if ui is not None: entry["ui"] = bson.Binary(ui.bytes, subtype=4) oplog_entries.append(entry) pre_images.append(bson.encode(doc) if preimages_on else None) return oplog_entries, pre_images def find_matching( self, db: str, coll: str, filter: dict[str, Any] | None = None, *, skip: int = 0, limit: int = 0, sort: Mapping[str, Any] | None = None, projection: Mapping[str, Any] | None = None, hint: str | Mapping[str, Any] | None = None, let: dict[str, Any] | None = None, collation: Any = None, ) -> list[dict[str, Any]]: if self._is_oplog_rs(db, coll): return self._find_oplog_rs( filter, skip=skip, limit=limit, sort=sort, projection=projection, let=let, collation=collation, ) from secantus.collation import parse as _parse_collation collation_obj = _parse_collation(collation) self._refresh_read_snapshot() filter = filter or {} in_sort_order = False # 
Two-stage decode discipline: the lock is held only for the # WT cursor walk and any index routing; the COLLSCAN fallback # collects raw blobs while the lock is held and defers # ``bson.decode`` (and the ``matches()`` predicate, sorting, # projection) to *after* the lock releases. Concurrent readers # then decode in parallel even while a writer holds the lock # for inserts. Index-path candidates still come back already # decoded — that's a deeper refactor (Phase 2 territory). candidates: list[dict[str, Any]] | None = None raw_blobs: list[bytes] | None = None with self._lock: sort_field, sort_dir = self._single_sort_spec(sort) if hint is not None: resolved = self._resolve_hint(db, coll, hint) candidates, in_sort_order = self._candidates_from_hint( db, coll, resolved, sort_field, sort_dir ) elif collation_obj is not None: # Index entries are byte-encoded under the default # codepoint ordering — they don't carry the collation's # case- or accent-folding. A case-insensitive equality # like ``{x: "PING"}`` would miss an indexed doc with # ``x: "ping"`` because the entries are stored under # the literal bytes of the inserted value. Force a # COLLSCAN so ``matches()`` does the comparison with # the collation in hand. ``mongod`` only honours an # index when its definition's collation matches the # query's; we don't support per-index collation yet, # so the safe path is always-COLLSCAN-when-collation. 
raw_blobs = [b for _, b in self._scan_docs(db, coll)] else: candidates = self._try_index_lookup(db, coll, filter) if candidates is not None and sort_field is not None: if ( len(filter) == 1 and not next(iter(filter)).startswith("$") and next(iter(filter)) == sort_field ): in_sort_order = True idx = self._find_leading_field_index(db, coll, sort_field, filter) idx_dir = idx[1] if idx else 1 if sort_dir != idx_dir: candidates = list(reversed(candidates)) elif candidates is None and not filter and sort_field is not None: idx = self._find_leading_field_index(db, coll, sort_field, filter) if idx is not None: idx_name, idx_dir, _is_compound = idx # If the index direction matches the sort direction, # walk forward; if it's opposite, walk backward. reverse = sort_dir != idx_dir candidates = self._walk_index_in_order(db, coll, idx_name, reverse=reverse) in_sort_order = True # Multi-field sort acceleration: when sort has 2+ fields and # filter is empty, try to find a compound index whose key # spec exactly matches (or fully inverts) the sort. Walking # that index in the right direction yields the requested # order without a Python-side post-sort. 
if candidates is None and not filter and sort_field is None and sort: multi_spec = self._multi_sort_spec(sort) if multi_spec is not None and len(multi_spec) > 1: match = self._compound_index_for_sort(db, coll, multi_spec) if match is not None: idx_name, reverse = match candidates = self._walk_index_in_order( db, coll, idx_name, reverse=reverse ) in_sort_order = True if candidates is None: raw_blobs = [b for _, b in self._scan_docs(db, coll)] if candidates is None: assert raw_blobs is not None candidates = [bson.decode(b) for b in raw_blobs] out = [d for d in candidates if matches(d, filter, vars=let, collation=collation_obj)] if sort and not in_sort_order: out = sort_docs(out, sort) if skip: out = out[skip:] if limit > 0: out = out[:limit] if projection: out = [apply_projection(d, projection) for d in out] return out def _resolve_hint(self, db: str, coll: str, hint: str | Mapping[str, Any]) -> str: """Resolve ``hint`` to an index name (or ``$natural``). ``hint`` may be an index name string, a key-spec dict matching an existing index, ``"$natural"``, or ``{"$natural": +/-1}``. Anything else raises ``BadHint`` so the command layer can return a Mongo ``BadValue`` error. 
""" if isinstance(hint, str): if hint == "$natural": return "$natural" if hint == _ID_INDEX_NAME: return _ID_INDEX_NAME for name, _key_spec, _sparse, _unique in self._all_indexes(db, coll): if name == hint: return name raise BadHint(f"hint {hint!r} does not correspond to an existing index") if isinstance(hint, Mapping): if list(hint) == ["$natural"]: return "$natural" if list(hint) == ["_id"] and int(hint["_id"]) == 1: return _ID_INDEX_NAME for name, key_spec, _sparse, _unique in self._all_indexes(db, coll): if dict(key_spec) == dict(hint): return name raise BadHint(f"hint {dict(hint)!r} does not correspond to an existing index") raise BadHint(f"invalid hint type: {type(hint).__name__}") def _candidates_from_hint( self, db: str, coll: str, resolved: str, sort_field: str | None, sort_dir: int, ) -> tuple[list[dict[str, Any]], bool]: """Walk the index named by ``resolved`` (or full collection for $natural). Returns ``(candidates, in_sort_order)`` where ``in_sort_order`` is True when the hint's leading field matches the sort field — in which case ``find_matching`` skips the post-sort step. """ if resolved == "$natural": return [bson.decode(b) for _, b in self._scan_docs(db, coll)], False if resolved == _ID_INDEX_NAME: # The doc table is keyed by id_key; iterating it gives entries # sorted by encoded _id, which matches the _id_ index walk. 
docs = [bson.decode(b) for _, b in self._scan_docs(db, coll)] in_order = sort_field == "_id" if in_order and sort_dir == -1: docs = list(reversed(docs)) return docs, in_order # Find the index's leading field and its direction leading: str | None = None leading_dir = 1 for name, key_spec, _sparse, _unique in self._all_indexes(db, coll): if name == resolved: first = next(iter(key_spec)) leading = first leading_dir = int(key_spec[first]) break candidates = self._walk_index_in_order(db, coll, resolved, reverse=False) in_order = sort_field is not None and sort_field == leading if in_order and sort_dir != leading_dir: candidates = list(reversed(candidates)) return candidates, in_order @staticmethod def _single_sort_spec(sort: Mapping[str, Any] | None) -> tuple[str | None, int]: """Return ``(field, direction)`` if ``sort`` is single-field +/-1, else ``(None, 0)``.""" if not sort or len(sort) != 1: return None, 0 f, d = next(iter(sort.items())) if f.startswith("$"): return None, 0 try: di = int(d) except (TypeError, ValueError): return None, 0 if di not in (-1, 1): return None, 0 return f, di @staticmethod def _multi_sort_spec( sort: Mapping[str, Any] | None, ) -> list[tuple[str, int]] | None: """Return a list of ``(field, direction)`` pairs for a multi-field sort spec, or ``None`` if any entry is operator-prefixed or has a non-``±1`` direction. Used for compound-index sort acceleration: an index whose key spec exactly matches (or fully inverts) the returned list lets ``find_matching`` walk WT in the requested order and skip the Python-side post-sort entirely. 
""" if not sort: return None out: list[tuple[str, int]] = [] for field, direction in sort.items(): if field.startswith("$"): return None try: d = int(direction) except (TypeError, ValueError): return None if d not in (-1, 1): return None out.append((field, d)) return out def _compound_index_for_sort( self, db: str, coll: str, sort_fields: list[tuple[str, int]] ) -> tuple[str, bool] | None: """Find a compound index that satisfies ``sort_fields`` end-to-end. Returns ``(index_name, reverse_walk)`` where ``reverse_walk`` is True when the matching index is the *fully-inverted* permutation of the sort (walking backward yields the requested order). Multikey indexes are excluded — array values in the index could produce row order that doesn't match the BSON cross-type sort the user expects from a sort spec, so we'd fall back to Python sort anyway. Strict match only: the index key spec must have the same fields in the same order with directions either matching the sort spec or being the full inverse. Partial-prefix matches (sort uses 3 fields, index has 2) aren't accelerated; the savings on the leading prefix are usually less than the cost of the trailing Python sort over the materialised set. """ multikey = self._multikey_index_names(db, coll) target = list(sort_fields) inverted = [(f, -d) for f, d in target] for name, key_spec, _sparse, _unique in self._all_indexes(db, coll): if name in multikey: continue try: idx_pairs = [(f, int(d)) for f, d in key_spec.items()] except (TypeError, ValueError): continue if any(d not in (-1, 1) for _, d in idx_pairs): continue if idx_pairs == target: return name, False if idx_pairs == inverted: return name, True return None def _single_field_index_for(self, db: str, coll: str, field: str) -> tuple[str, int] | None: """Return ``(index_name, direction)`` for a single-field index on ``field``, or ``None`` if no such index exists. 
Direction is the index's stored sort direction (`+1` for ASC, `-1` for DESC).""" for name, key_spec, _sparse, _unique in self._all_indexes(db, coll): if list(key_spec.keys()) == [field]: d = int(key_spec[field]) if d in (1, -1): return name, d return None def _walk_index_in_order( self, db: str, coll: str, name: str, *, reverse: bool = False ) -> list[dict[str, Any]]: c = self._cursor(_IDX_ENTRIES_TABLE) c.set_key(db, coll, name, b"") rc = c.search_near() if rc == wt.WT_NOTFOUND: return [] if rc < 0 and c.next() != 0: return [] id_keys: list[bytes] = [] while True: k = c.get_key() if (k[0], k[1], k[2]) != (db, coll, name): break packed = bytes(k[3]) _esc, row_id = _unpack_entry(packed) id_keys.append(row_id) if c.next() != 0: break if reverse: id_keys.reverse() return self._docs_by_id_keys(db, coll, id_keys)
def explain_plan(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any] | None = None,
    *,
    sort: Mapping[str, Any] | None = None,
    hint: str | Mapping[str, Any] | None = None,
) -> dict[str, Any]:
    """Plan summary for what ``find_matching`` would do with these args.

    No execution; mirrors the same routing decisions. Returns
    ``{"kind": "COLLSCAN"}`` or ``{"kind": "IXSCAN", "index_name",
    "key_pattern", "direction"}``. ``direction`` is ``"forward"``
    unless a sort spec inverts it relative to the chosen index.

    Routing is checked in this order:

    1. ``hint``: ``$natural``, or a hint that ``_resolve_hint`` rejects
       with ``BadHint``, reports COLLSCAN (the planner never raises,
       whereas ``find_matching`` lets ``BadHint`` propagate). The
       ``_id`` index and other named indexes report IXSCAN.
    2. The index ``_pick_index_for_filter`` would route the filter to.
    3. Empty filter + single-field sort: an index led by the sort field.
    4. Empty filter + multi-field sort: a compound index whose key spec
       exactly matches, or fully inverts, the sort.
    5. Otherwise COLLSCAN.
    """
    # Take a fresh WT read snapshot first, exactly like
    # ``find_matching`` and ``count_matching``. The per-thread
    # session's snapshot is otherwise sticky, so index metadata
    # committed by another connection could be invisible here and
    # explain would describe a different plan than the real query
    # (which does refresh) actually runs.
    self._refresh_read_snapshot()
    filter = filter or {}
    with self._lock:
        sort_field, sort_dir = self._single_sort_spec(sort)
        if hint is not None:
            try:
                resolved = self._resolve_hint(db, coll, hint)
            except BadHint:
                return {"kind": "COLLSCAN"}
            if resolved == "$natural":
                return {"kind": "COLLSCAN"}
            if resolved == _ID_INDEX_NAME:
                # Mirrors ``_candidates_from_hint``: the _id walk runs
                # forward and is reversed only for a descending sort
                # on ``_id`` itself.
                direction = "forward"
                if sort_field == "_id" and sort_dir == -1:
                    direction = "backward"
                return {
                    "kind": "IXSCAN",
                    "index_name": _ID_INDEX_NAME,
                    "key_pattern": {"_id": 1},
                    "direction": direction,
                }
            key_spec = self._key_spec_for(db, coll, resolved)
            if key_spec is None:
                return {"kind": "COLLSCAN"}
            return self._make_ixscan_plan(resolved, key_spec, sort_field, sort_dir)

        picked = self._pick_index_for_filter(db, coll, filter)
        if picked is not None:
            name, key_spec = picked
            return self._make_ixscan_plan(name, key_spec, sort_field, sort_dir)

        if not filter and sort_field is not None:
            idx = self._find_leading_field_index(db, coll, sort_field, filter)
            if idx is not None:
                name, _idx_dir, _is_compound = idx
                key_spec = self._key_spec_for(db, coll, name)
                if key_spec is not None:
                    return self._make_ixscan_plan(name, key_spec, sort_field, sort_dir)

        # Multi-field sort acceleration mirrored in the planner: same
        # rules as find_matching (compound key spec exactly matches
        # or fully inverts the sort, filter empty).
        if not filter and sort_field is None and sort:
            multi_spec = self._multi_sort_spec(sort)
            if multi_spec is not None and len(multi_spec) > 1:
                match = self._compound_index_for_sort(db, coll, multi_spec)
                if match is not None:
                    name, reverse = match
                    key_spec = self._key_spec_for(db, coll, name)
                    if key_spec is not None:
                        return {
                            "kind": "IXSCAN",
                            "index_name": name,
                            "key_pattern": key_spec,
                            "direction": "backward" if reverse else "forward",
                        }
    return {"kind": "COLLSCAN"}
def _key_spec_for(self, db: str, coll: str, name: str) -> dict[str, Any] | None:
    """Return a fresh copy of index ``name``'s key spec, or ``None`` if absent."""
    for n, key_spec, _sparse, _unique in self._all_indexes(db, coll):
        if n == name:
            return dict(key_spec)
    return None


def _pick_geo_index_for_filter(
    self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
    """Mirror :meth:`_try_geo_index_id_keys`'s index selection (no exec).

    Returns ``(name, key_spec)`` for the first geo index whose geo field
    carries a geo operator in ``filter``; ``None`` otherwise.

    The pick is optimistic, not exact: ``_try_geo_index_id_keys`` may
    still bail at execution time (e.g. ``$near`` with no max distance),
    but ``explain`` reports IXSCAN whenever an index *could* serve the
    query, matching mongod's planner explain.
    """
    for field, value in filter.items():
        if not isinstance(value, dict):
            continue
        if not any(op in value for op in self._GEO_OPS):
            continue
        for name, key_spec, _opts in self._iter_indexes(db, coll):
            geo = _geo_type_of(key_spec)
            if geo is not None and geo[0] == field:
                return name, dict(key_spec)
    return None


def _pick_index_for_filter(
    self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
    """Mirror ``_try_index_lookup``'s index-selection (no execution).

    Tried in order: geo indexes; compound equality (all values are
    plain, non-dict); compound range (two or more fields); then a
    single-field filter whose value is either a plain value or a dict
    made up solely of ``self._RANGE_OPS`` operators, served by an index
    led by that field. Filters with top-level ``$`` keys are never
    index-routed. Returns ``(name, key_spec)`` or ``None``.
    """
    if not filter:
        return None
    if any(f.startswith("$") for f in filter):
        return None
    # Mirror `_try_index_id_keys`: geo dispatch first.
    geo_pick = self._pick_geo_index_for_filter(db, coll, filter)
    if geo_pick is not None:
        return geo_pick
    if all(not isinstance(v, dict) for v in filter.values()):
        picked = self._pick_compound_eq_index(db, coll, filter)
        if picked is not None:
            return picked
    if len(filter) >= 2:
        picked = self._pick_compound_range_index(db, coll, filter)
        if picked is not None:
            return picked
    if len(filter) != 1:
        return None
    field, value = next(iter(filter.items()))
    idx_match = self._find_leading_field_index(db, coll, field, filter)
    if idx_match is None:
        return None
    if isinstance(value, dict):
        # Only pure operator documents built from supported range ops
        # can be served; embedded-document literals and mixed dicts can't.
        if not value or not all(k.startswith("$") for k in value):
            return None
        if not all(op in self._RANGE_OPS for op in value):
            return None
    name, _direction, _is_compound = idx_match
    key_spec = self._key_spec_for(db, coll, name)
    if key_spec is None:
        return None
    return name, key_spec


@staticmethod
def _make_ixscan_plan(
    name: str,
    key_spec: Mapping[str, Any],
    sort_field: str | None,
    sort_dir: int,
) -> dict[str, Any]:
    """Build an ``IXSCAN`` explain summary for index ``name``.

    ``direction`` is ``"backward"`` only when ``sort_field`` is a key of
    the index with a numeric direction whose sign differs from
    ``sort_dir``. Non-numeric key values (``"2dsphere"``, ``"2d"``,
    ``"hashed"``, ...) have no scan order to invert, so they report
    ``"forward"`` instead of crashing on ``int()``. A zero ``sort_dir``
    also reports ``"forward"``.
    """
    direction = "forward"
    if sort_field is not None and sort_field in key_spec and sort_dir != 0:
        raw = key_spec[sort_field]
        # Compare by sign: MongoDB treats any positive number as
        # ascending and any negative as descending.
        if isinstance(raw, (int, float)) and raw != 0:
            if (raw > 0) != (sort_dir > 0):
                direction = "backward"
    return {
        "kind": "IXSCAN",
        "index_name": name,
        "key_pattern": dict(key_spec),
        "direction": direction,
    }


def count_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any] | None = None,
    *,
    let: dict[str, Any] | None = None,
    collation: Any = None,
) -> int:
    """Count documents in ``db.coll`` that match ``filter``.

    Namespaces for which ``_is_oplog_rs`` is true are delegated to
    ``_count_oplog_rs``. The read snapshot is refreshed first so rows
    committed by other connections are visible. An empty filter counts
    raw doc-table rows without decoding; otherwise every document is
    checked with ``matches`` (honouring ``let`` and ``collation``).
    """
    if self._is_oplog_rs(db, coll):
        return self._count_oplog_rs(filter, let=let, collation=collation)
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    self._refresh_read_snapshot()
    if not filter:
        with self._lock:
            return sum(1 for _ in self._scan_docs(db, coll))
    return sum(
        1
        for doc in self._all_docs(db, coll)
        if matches(doc, filter, vars=let, collation=collation_obj)
    )
def collection_data_size(self, db: str, coll: str) -> int:
    """Return the total byte length of the encoded documents stored for ``coll``.

    Feeds the ``size`` / ``dataSize`` figures of ``collStats`` and
    ``dbStats``. Only the stored document payloads are counted; WiredTiger
    block and page overhead is not, so treat the result as an estimate.
    """
    total = 0
    with self._lock:
        for _doc_key, encoded in self._scan_docs(db, coll):
            total += len(encoded)
    return total
def index_sizes(self, db: str, coll: str) -> dict[str, int]:
    """Return per-index byte totals for ``coll``, keyed by index name.

    The ``_id_`` entry is derived from the doc table: it is the summed
    length of every document's id key, and it is omitted when that sum is
    zero. Every secondary index contributes the summed length of its
    packed entry keys from the entries table. Together these give callers
    what they need for an accurate ``totalIndexSize``.
    """
    with self._lock:
        result: dict[str, int] = {}

        total_id_bytes = 0
        for id_key_bytes, _encoded_doc in self._scan_docs(db, coll):
            total_id_bytes += len(id_key_bytes)
        if total_id_bytes:
            result[_ID_INDEX_NAME] = total_id_bytes

        # Entry-table keys are (db, coll, index_name, packed_key).
        for key, _value in self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll)):
            index_name = key[2]
            result[index_name] = result.get(index_name, 0) + len(bytes(key[3]))
        return result
def update_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any],
    update: dict[str, Any],
    *,
    multi: bool = False,
    upsert: bool = False,
    array_filters: list[dict[str, Any]] | None = None,
    let: dict[str, Any] | None = None,
    collation: Any = None,
    validator: dict[str, Any] | None = None,
    journal: bool = False,
) -> dict[str, Any]:
    """Apply ``update`` to documents in ``db.coll`` that match ``filter``.

    Only the first match is processed unless ``multi`` is true. With
    ``upsert`` and no match, a new document is seeded from the filter's
    equality clauses and ``update`` is applied to it. All doc-table,
    index-entry and oplog writes share one per-collection batch
    transaction.

    Args:
        filter: Query selecting the documents to update.
        update: Operator document (keys starting with ``$``) or a
            replacement document.
        multi: Update every match instead of stopping after the first.
        upsert: Insert a new document when nothing matches.
        array_filters: Forwarded to ``apply_update``.
        let: Variables visible to both the filter and the update
            (including the upsert path).
        collation: Collation spec. When set, candidates come from a full
            doc-table scan instead of an index walk.
        validator: Collection validator that every changed or upserted
            document must satisfy; ``None`` skips validation.
        journal: Forwarded as ``sync`` to the batch transaction.

    Returns:
        ``{"matched": int, "modified": int, "upserted_id": Any}``.

    Raises:
        DocumentValidationError: A changed/upserted doc fails ``validator``.
        IndexConflict: ``_unique_conflict`` reports a unique-index clash.
    """
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    # Release any sticky session snapshot before the write's
    # ``begin_transaction`` acquires a new one. Otherwise the
    # transaction inherits a stale view and the candidate scan
    # misses rows committed by other connections (the cross-
    # connection visibility fix applied to reads — see
    # ``_refresh_read_snapshot``).
    self._refresh_read_snapshot()
    matched = 0
    modified = 0
    upserted_id: Any = None
    oplog_entries: list[dict[str, Any]] = []
    pre_images: list[bytes | None] = []
    oplog_on = self.enable_oplog
    with self._coll_lock(db, coll), self._batch_transaction(sync=journal):
        # Per-collection lock + one WT transaction per call. Every
        # doc-table write + index-entry delete/insert + oplog write
        # that lands in this method shares a single commit. Phase
        # 2.4: was self._lock; now per-coll so different
        # collections update in parallel.
        self._ensure_collection(db, coll)
        ns = self._ns(db, coll)
        ui = self._collection_uuid(db, coll) if oplog_on else None
        preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        multikey_names = self._multikey_index_names(db, coll)
        # Index-routed when the filter is covered (only matching id_keys
        # come back from the index walk); full scan otherwise. Either
        # way the doc cursor isn't held across writes — bytes are
        # eagerly buffered. Only matching docs pay ``bson.decode``.
        # With a collation in effect, fall back to a doc-table scan:
        # the index entries don't carry the collation's folding, so
        # an indexed equality probe would miss case-insensitive
        # matches. Always materialise the list — the update loop
        # rewrites the doc table via the cached cursor, which
        # invalidates a still-walking ``_scan_docs`` cursor on the
        # same session.
        if collation_obj is not None:
            candidates = list(self._scan_docs(db, coll))
        else:
            candidates = self._candidates_iter(db, coll, filter)
        for id_k, blob in candidates:
            doc = bson.decode(blob)
            if not matches(doc, filter, vars=let, collation=collation_obj):
                continue
            matched += 1
            pos = find_positional_matches(doc, filter)
            new = apply_update(
                doc,
                update,
                array_filters=array_filters,
                positional_matches=pos,
                let=let,
            )
            if new != doc:
                # Document-validator check: collection-level
                # ``validator`` (set via ``create`` / ``collMod``)
                # rejects updates whose result fails the predicate.
                # Caller passes ``None`` to skip
                # (``bypassDocumentValidation: true``).
                if validator is not None and not matches(new, validator):
                    raise DocumentValidationError(new.get("_id"))
                # NOTE(review): the write below keys on the *new* _id;
                # this assumes apply_update keeps ``_id`` immutable. If
                # it could change, the old ``id_k`` row would be left
                # behind — TODO confirm.
                new_id_key = _id_key(new["_id"])
                conflict = self._unique_conflict(
                    db, coll, new, indexes, exclude_id_key=id_k, partials=partials
                )
                if conflict is not None:
                    cname, kpat, kval = conflict
                    raise IndexConflict(cname, new["_id"], key_pattern=kpat, key_value=kval)
                # Geo validation must reject the update before any
                # write happens, otherwise we'd be left with a
                # half-deleted set of index entries.
                self._validate_geo_indexes(db, coll, new, indexes, partials)
                modified += 1
                self._delete_index_entries(db, coll, doc, indexes, partials)
                doc_cur = self._cursor(_DOC_TABLE)
                doc_cur[db, coll, new_id_key] = bson.encode(new)
                self._write_index_entries(db, coll, new, indexes, partials)
                multikey_names = self._maybe_mark_multikey(
                    db, coll, new, indexes, multikey_names
                )
                if oplog_on:
                    is_replacement = not any(
                        isinstance(k, str) and k.startswith("$") for k in update
                    )
                    if is_replacement:
                        o_field: dict[str, Any] = dict(new)
                    else:
                        o_field = {"$v": 2, "diff": compute_update_description(doc, new)}
                    oplog_entries.append(
                        {
                            "op": "u",
                            "ns": ns,
                            "ui": bson.Binary(ui.bytes, subtype=4),
                            "o": o_field,
                            "o2": {"_id": doc["_id"]},
                        }
                    )
                    pre_images.append(bson.encode(doc) if preimages_on else None)
            if not multi:
                break
        if matched == 0 and upsert:
            # Seed the inserted document from the filter's equality
            # clauses, as mongod does: plain values, embedded-document
            # literals (including ``{}``), and a lone ``{"$eq": v}``.
            # Other operator expressions ($gt, $in, ...) pin no value.
            seed: dict[str, Any] = {}
            for k, v in filter.items():
                if k.startswith("$"):
                    continue
                if isinstance(v, dict) and any(
                    isinstance(op, str) and op.startswith("$") for op in v
                ):
                    if len(v) == 1 and "$eq" in v:
                        seed[k] = v["$eq"]
                    continue
                seed[k] = v
            # ``let`` must reach the upsert path too, so ``$$var``
            # references in the update resolve the same way they do for
            # matched documents.
            new = apply_update(
                seed, update, is_upsert=True, array_filters=array_filters, let=let
            )
            if "_id" not in new:
                new["_id"] = bson.ObjectId()
            if validator is not None and not matches(new, validator):
                raise DocumentValidationError(new.get("_id"))
            upserted_id = new["_id"]
            conflict = self._unique_conflict(
                db, coll, new, indexes, exclude_id_key=None, partials=partials
            )
            if conflict is not None:
                cname, kpat, kval = conflict
                raise IndexConflict(cname, new["_id"], key_pattern=kpat, key_value=kval)
            self._validate_geo_indexes(db, coll, new, indexes, partials)
            # NOTE(review): this assignment overwrites any existing row
            # with the same _id (e.g. filter {_id: 1, x: 2} against a
            # stored {_id: 1, x: 3}); confirm ``_unique_conflict``
            # covers the _id index, otherwise mongod would raise E11000.
            doc_cur = self._cursor(_DOC_TABLE)
            doc_cur[db, coll, _id_key(upserted_id)] = bson.encode(new)
            self._write_index_entries(db, coll, new, indexes, partials)
            self._maybe_mark_multikey(db, coll, new, indexes, multikey_names)
            if oplog_on:
                oplog_entries.append(
                    {
                        "op": "i",
                        "ns": ns,
                        "ui": bson.Binary(ui.bytes, subtype=4),
                        "o": dict(new),
                        "o2": {"_id": upserted_id},
                    }
                )
                pre_images.append(None)
        cap_ns = ns if oplog_on else ""
        cap_entries, cap_pre = self._enforce_capped_bounds_locked(
            db, coll, set(), indexes, partials, oplog_on, cap_ns, ui
        )
        if cap_entries:
            oplog_entries.extend(cap_entries)
            pre_images.extend(cap_pre)
        if oplog_entries:
            self._emit_oplog(oplog_entries, pre_images)
    return {"matched": matched, "modified": modified, "upserted_id": upserted_id}


def delete_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any],
    *,
    limit: int = 0,
    let: dict[str, Any] | None = None,
    collation: Any = None,
    journal: bool = False,
) -> int:
    """Delete documents in ``db.coll`` matching ``filter``.

    Args:
        filter: Query selecting the documents to delete.
        limit: Stop after this many deletions; ``0`` (or negative) means
            no limit.
        let: Variables visible to the filter.
        collation: Collation spec. When set, candidates come from a full
            doc-table scan instead of an index walk.
        journal: Forwarded as ``sync`` to the batch transaction.

    Returns:
        The number of documents deleted. Each deletion also removes the
        doc's index entries and, when the oplog is on, records a ``"d"``
        entry, all in one per-collection batch transaction.
    """
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    # See ``update_matching`` — release the sticky snapshot so the
    # candidate scan sees writes committed by other connections.
    self._refresh_read_snapshot()
    deleted = 0
    oplog_entries: list[dict[str, Any]] = []
    pre_images: list[bytes | None] = []
    oplog_on = self.enable_oplog
    with self._coll_lock(db, coll), self._batch_transaction(sync=journal):
        # Per-collection lock (Phase 2.4) + one WT transaction.
        # Groups the per-doc removes + index-entry deletes + oplog
        # writes into one commit. Other collections delete in
        # parallel.
        ns = self._ns(db, coll) if oplog_on else ""
        preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)
        ui = (
            self._collection_uuid(db, coll)
            if oplog_on and self._coll_options(db, coll) is not None
            else None
        )
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        # Index-routed candidates when the filter is covered; full scan
        # otherwise. See update_matching for the full-scan rationale.
        # Collation forces a full scan — index entries don't carry the
        # collation's folding. Always materialise into a list so the
        # delete loop's writes don't invalidate the iteration cursor
        # mid-scan (deletes via ``_cursor(_DOC_TABLE)`` share the
        # cached cursor with ``_scan_docs``).
        if collation_obj is not None:
            candidates = list(self._scan_docs(db, coll))
        else:
            candidates = self._candidates_iter(db, coll, filter)
        for id_k, blob in candidates:
            doc = bson.decode(blob)
            if not matches(doc, filter, vars=let, collation=collation_obj):
                continue
            self._delete_index_entries(db, coll, doc, indexes, partials)
            doc_cur = self._cursor(_DOC_TABLE)
            doc_cur.set_key(db, coll, id_k)
            doc_cur.remove()
            deleted += 1
            if oplog_on:
                entry: dict[str, Any] = {
                    "op": "d",
                    "ns": ns,
                    "o": {"_id": doc["_id"]},
                    "o2": {"_id": doc["_id"]},
                }
                if ui is not None:
                    entry["ui"] = bson.Binary(ui.bytes, subtype=4)
                oplog_entries.append(entry)
                pre_images.append(bson.encode(doc) if preimages_on else None)
            if limit > 0 and deleted >= limit:
                break
        if oplog_entries:
            self._emit_oplog(oplog_entries, pre_images)
    return deleted
def prune_ttl(
    self,
    db: str,
    coll: str,
    *,
    now: _dt.datetime | None = None,
) -> int:
    """Delete docs whose indexed Date field is older than ``now - TTL``.

    Only indexes on ``coll`` with a non-negative numeric
    ``expireAfterSeconds`` option are considered; each one's first key
    field is read from every document. A document is pruned when that
    field is strictly older than ``now - expireAfterSeconds``. The field
    may hold a ``datetime`` or an array; for an array, the earliest
    ``datetime`` in it is used, matching MongoDB. Docs without the field,
    with non-date values, or with values inside the TTL window are left
    in place. Naive datetimes (stored values and ``now``) are treated as
    UTC.

    Real MongoDB runs this on a 60s background sweeper; SecantusDB
    invokes it explicitly so tests can drive expiry with an injected
    ``now``. When the oplog is enabled, each removal is recorded as a
    ``"d"`` oplog entry.

    Returns the number of docs pruned.
    """
    ttl_indexes: list[tuple[str, str, float]] = []
    for name, key_spec, opts in self._iter_indexes(db, coll):
        ttl = opts.get("expireAfterSeconds")
        if not isinstance(ttl, (int, float)) or ttl < 0:
            continue
        field = next(iter(key_spec), None)
        if not isinstance(field, str):
            continue
        ttl_indexes.append((name, field, float(ttl)))
    if not ttl_indexes:
        return 0
    when = now if now is not None else _dt.datetime.now(_dt.timezone.utc)
    if when.tzinfo is None:
        when = when.replace(tzinfo=_dt.timezone.utc)

    def _earliest_date(value: Any) -> _dt.datetime | None:
        # Only datetimes participate; for arrays MongoDB expires on the
        # earliest date element.
        if isinstance(value, _dt.datetime):
            dates = [value]
        elif isinstance(value, list):
            dates = [v for v in value if isinstance(v, _dt.datetime)]
        else:
            return None
        # Normalise naive values to UTC so they compare against ``when``.
        return min(
            (d if d.tzinfo else d.replace(tzinfo=_dt.timezone.utc) for d in dates),
            default=None,
        )

    # Same cross-connection visibility fix as ``delete_matching``: drop
    # any sticky read snapshot so the scan sees docs other connections
    # have committed.
    self._refresh_read_snapshot()
    pruned = 0
    oplog_entries: list[dict[str, Any]] = []
    pre_images: list[bytes | None] = []
    oplog_on = self.enable_oplog
    with self._lock:
        ns = self._ns(db, coll) if oplog_on else ""
        preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)
        ui = (
            self._collection_uuid(db, coll)
            if oplog_on and self._coll_options(db, coll) is not None
            else None
        )
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        # Materialise first: removes go through the cached doc cursor,
        # which would otherwise invalidate a still-walking scan.
        candidates = list(self._scan_docs(db, coll))
        for id_k, blob in candidates:
            doc = bson.decode(blob)
            expired = False
            for _name, field, ttl_seconds in ttl_indexes:
                oldest = _earliest_date(get_path(doc, field))
                if oldest is None:
                    continue
                if (when - oldest).total_seconds() > ttl_seconds:
                    expired = True
                    break
            if not expired:
                continue
            self._delete_index_entries(db, coll, doc, indexes, partials)
            doc_cur = self._cursor(_DOC_TABLE)
            doc_cur.set_key(db, coll, id_k)
            doc_cur.remove()
            pruned += 1
            if oplog_on:
                entry: dict[str, Any] = {
                    "op": "d",
                    "ns": ns,
                    "o": {"_id": doc["_id"]},
                    "o2": {"_id": doc["_id"]},
                }
                if ui is not None:
                    entry["ui"] = bson.Binary(ui.bytes, subtype=4)
                oplog_entries.append(entry)
                pre_images.append(bson.encode(doc) if preimages_on else None)
        if oplog_entries:
            self._emit_oplog(oplog_entries, pre_images)
    return pruned
@staticmethod def _table_kf(table: str) -> str: return { _COLL_TABLE: "SS", _DOC_TABLE: "SSu", _IDX_TABLE: "SSS", _IDX_ENTRIES_TABLE: "SSSu", }[table] @staticmethod def _smallest_for_kf(kf: str) -> tuple[Any, ...]: return tuple(b"" if c == "u" else "" for c in kf) def _collect_prefix( self, table: str, prefix: tuple[Any, ...] ) -> list[tuple[tuple[Any, ...], Any]]: c = self._cursor(table) kf = self._table_kf(table) seed = prefix + self._smallest_for_kf(kf)[len(prefix) :] c.set_key(*seed) rc = c.search_near() if rc == wt.WT_NOTFOUND: return [] if rc < 0 and c.next() != 0: return [] out: list[tuple[tuple[Any, ...], Any]] = [] while True: k = tuple(c.get_key()) if k[: len(prefix)] != prefix: break v = c.get_value() out.append((k, bytes(v) if isinstance(v, (bytes, bytearray)) else v)) if c.next() != 0: break return out def _delete_keys(self, table: str, keys: list[tuple[Any, ...]]) -> None: if not keys: return c = self._cursor(table) for k in keys: c.set_key(*k) c.remove() c.reset() def drop_collection(self, db: str, coll: str) -> bool: with self._lock: existed = self._coll_options(db, coll) is not None ui = self._collection_uuid(db, coll) if existed else None for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE): rows = self._collect_prefix(tbl, (db, coll)) self._delete_keys(tbl, [k for k, _ in rows]) c = self._cursor(_COLL_TABLE) c.set_key(db, coll) if c.search() == 0: c.remove() if existed and ui is not None: self._emit_oplog( [ { "op": "c", "ns": f"{db}.$cmd", "ui": bson.Binary(ui.bytes, subtype=4), "o": {"drop": coll}, } ] ) return existed def drop_database(self, db: str) -> None: with self._lock: colls_with_ui: list[tuple[str, _uuid.UUID]] = [] for c_name in self.list_collections(db): ui = self._collection_uuid(db, c_name) colls_with_ui.append((c_name, ui)) for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE, _COLL_TABLE): rows = self._collect_prefix(tbl, (db,)) self._delete_keys(tbl, [k for k, _ in rows]) entries: list[dict[str, Any]] = [] for c_name, ui in 
colls_with_ui: entries.append( { "op": "c", "ns": f"{db}.$cmd", "ui": bson.Binary(ui.bytes, subtype=4), "o": {"drop": c_name}, } ) entries.append({"op": "c", "ns": f"{db}.$cmd", "o": {"dropDatabase": 1}}) self._emit_oplog(entries) def rename_collection( self, src_db: str, src_coll: str, dst_db: str, dst_coll: str, *, drop_target: bool = False, ) -> tuple[bool, str | None]: with self._lock: if self._coll_options(src_db, src_coll) is None: return False, f"source namespace does not exist: {src_db}.{src_coll}" if (src_db, src_coll) == (dst_db, dst_coll): return True, None ui = self._collection_uuid(src_db, src_coll) dst_existed = self._coll_options(dst_db, dst_coll) is not None dst_ui = self._collection_uuid(dst_db, dst_coll) if dst_existed else None if dst_existed: if not drop_target: return False, f"target namespace exists: {dst_db}.{dst_coll}" for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE): rows = self._collect_prefix(tbl, (dst_db, dst_coll)) self._delete_keys(tbl, [k for k, _ in rows]) c = self._cursor(_COLL_TABLE) c.set_key(dst_db, dst_coll) if c.search() == 0: c.remove() for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE): rows = self._collect_prefix(tbl, (src_db, src_coll)) self._delete_keys(tbl, [k for k, _ in rows]) c = self._cursor(tbl) for k, v in rows: new_k = (dst_db, dst_coll) + k[2:] c.set_key(*new_k) c.set_value(v) c.insert() c.reset() ensure = self._cursor(_COLL_TABLE) ensure.set_key(dst_db, dst_coll) if ensure.search() != 0: ensure.reset() ensure[dst_db, dst_coll] = b"" ensure.reset() ensure.set_key(src_db, src_coll) if ensure.search() == 0: ensure.remove() entries: list[dict[str, Any]] = [] if dst_existed and dst_ui is not None: entries.append( { "op": "c", "ns": f"{dst_db}.$cmd", "ui": bson.Binary(dst_ui.bytes, subtype=4), "o": {"drop": dst_coll}, } ) entries.append( { "op": "c", "ns": f"{src_db}.$cmd", "ui": bson.Binary(ui.bytes, subtype=4), "o": { "renameCollection": f"{src_db}.{src_coll}", "to": f"{dst_db}.{dst_coll}", }, } ) 
self._emit_oplog(entries) return True, None def list_collections(self, db: str) -> list[str]: self._refresh_read_snapshot() with self._lock: c = self._cursor(_COLL_TABLE) c.set_key(db, "") rc = c.search_near() if rc != wt.WT_NOTFOUND and not (rc < 0 and c.next() != 0): out: list[str] = [] while True: k = c.get_key() if k[0] != db: break out.append(k[1]) if c.next() != 0: break else: out = [] # Synthesise ``local.oplog.rs`` for the ``local`` db whenever the # oplog is enabled. The collection isn't materialised in # ``_COLL_TABLE`` — it's a view over the oplog WT table — but # ``listCollections`` needs to surface it so pymongo clients can # discover it before querying. if self.enable_oplog and db == "local" and "oplog.rs" not in out: out.append("oplog.rs") return sorted(out) def list_databases(self) -> list[str]: self._refresh_read_snapshot() with self._lock: c = self._cursor(_COLL_TABLE) seen: set[str] = set() rc = c.next() while rc == 0: k = c.get_key() seen.add(k[0]) rc = c.next() # mongod always exposes the ``local`` database; mirror that # when the oplog is enabled so listDatabases includes it even # before any user-created collection lands in ``local``. if self.enable_oplog: seen.add("local") return sorted(seen) def create_index( self, db: str, coll: str, name: str, key_spec: Mapping[str, Any], options: Mapping[str, Any] | None = None, ) -> bool: if name == _ID_INDEX_NAME: return False # Text / hashed indexes are documented out-of-scope (CLAUDE.md # "Out of scope regardless: text / hashed / wildcard indexes"). # Surface the rejection as a typed exception (caught in # ``commands._create_indexes``) instead of letting the geo # picker / encoder later fall over with an opaque internal # error. Mongo-node-driver's ``Find should correctly sort using # text search`` test expects a clean error here. 
for _field, _spec_val in key_spec.items(): if _spec_val in ("text", "hashed"): raise CreateIndexUnsupported(f"{_spec_val} indexes are not supported by SecantusDB") options = dict(options or {}) with self._lock: self._ensure_collection(db, coll) c = self._cursor(_IDX_TABLE) c.set_key(db, coll, name) if c.search() == 0: # Index exists. Mongo rejects re-creation with conflicting # options (different ``unique`` / ``sparse`` / ``hidden`` # / ``expireAfterSeconds``). Silently succeeding hides # a bug surface that mongo-ruby-driver's ``Collection# # create_indexes when index creation fails`` test pins. existing_raw = bytes(c.get_value()) existing = bson.decode(existing_raw) if existing_raw else {} existing_opts = dict(existing.get("options") or {}) _CONFLICTING_OPTS = ( "unique", "sparse", "hidden", "expireAfterSeconds", "partialFilterExpression", ) for opt in _CONFLICTING_OPTS: if (opt in options or opt in existing_opts) and options.get( opt ) != existing_opts.get(opt): raise IndexOptionsConflict( f"Index with name '{name}' already exists with different options" ) return False sparse = bool(options.get("sparse")) unique = bool(options.get("unique")) partial_filter = options.get("partialFilterExpression") if not isinstance(partial_filter, Mapping) or not partial_filter: partial_filter = None key_spec_dict = dict(key_spec) geo = _geo_type_of(key_spec_dict) # Geo indexes use the same entries table but write **multiple** # entries per doc (one per S2 cell or 2d bucket). They're inherently # multikey-style; uniqueness is meaningless for geo and is rejected # by mongod, so we mirror. if geo is not None: if unique: raise IndexConflict(name, None) geo_field, geo_type = geo # Geo indexes are always multikey from the picker's perspective # — each doc may produce many cell entries. Mark it so the # regular pickers skip the index for non-geo queries. 
options["multikey"] = True entries: list[tuple[bytes, bytes]] = [] for id_k, blob in self._scan_docs(db, coll): d = bson.decode(blob) if partial_filter is not None and not matches(d, partial_filter): continue for cell_bytes in _doc_geo_cells( d, geo_field, geo_type, options, index_name=name ): entries.append((cell_bytes, id_k)) payload = bson.encode({"key": dict(key_spec), "options": options}) c.reset() c[db, coll, name] = payload entry_cur = self._cursor(_IDX_ENTRIES_TABLE) for kb, id_k in entries: entry_cur.reset() entry_cur[db, coll, name, _pack_entry(kb, id_k)] = b"" else: # Single doc-table walk: decode each blob once and fold all # three checks (uniqueness, multikey detection, entry build) # into one pass. Uniqueness is probed against the canonical # whole-doc key (``_index_key``); index entries are written # for every key variant (``_index_key_variants``) so per- # element multikey lookups land at IXSCAN. seen: dict[bytes, Any] | None = {} if unique else None multikey = False entries = [] for id_k, blob in self._scan_docs(db, coll): d = bson.decode(blob) if partial_filter is not None and not matches(d, partial_filter): continue if not multikey and _doc_makes_multikey(d, key_spec_dict): multikey = True if seen is not None: canonical = _index_key(d, key_spec_dict, sparse=sparse) if canonical is not None: if canonical in seen: raise IndexConflict(name, d.get("_id")) seen[canonical] = d.get("_id") for kb in _index_key_variants(d, key_spec_dict, sparse=sparse): entries.append((kb, id_k)) if multikey: options["multikey"] = True payload = bson.encode({"key": dict(key_spec), "options": options}) c.reset() c[db, coll, name] = payload entry_cur = self._cursor(_IDX_ENTRIES_TABLE) for kb, id_k in entries: entry_cur.reset() entry_cur[db, coll, name, _pack_entry(kb, id_k)] = b"" ui = self._collection_uuid(db, coll) self._emit_oplog( [ { "op": "c", "ns": f"{db}.$cmd", "ui": bson.Binary(ui.bytes, subtype=4), "o": { "createIndexes": coll, "indexes": [{"v": 2, "key": 
dict(key_spec), "name": name, **options}], }, } ] ) return True def list_indexes(self, db: str, coll: str) -> list[dict[str, Any]]: self._refresh_read_snapshot() with self._lock: if self._coll_options(db, coll) is None: return [] out: list[dict[str, Any]] = [{"v": 2, "key": {"_id": 1}, "name": _ID_INDEX_NAME}] for name, key_spec, opts in self._iter_indexes(db, coll): entry: dict[str, Any] = {"v": 2, "key": key_spec, "name": name} for k, v in opts.items(): entry[k] = v out.append(entry) out.sort(key=lambda e: e["name"]) return out def _iter_indexes( self, db: str, coll: str ) -> Iterable[tuple[str, dict[str, Any], dict[str, Any]]]: c = self._cursor(_IDX_TABLE) c.set_key(db, coll, "") rc = c.search_near() if rc == wt.WT_NOTFOUND: return if rc < 0 and c.next() != 0: return while True: k = c.get_key() if k[0] != db or k[1] != coll: return payload = bson.decode(bytes(c.get_value())) yield k[2], payload.get("key", {}), payload.get("options", {}) if c.next() != 0: return def drop_index(self, db: str, coll: str, name: str) -> bool: if name == _ID_INDEX_NAME: return False with self._lock: c = self._cursor(_IDX_TABLE) c.set_key(db, coll, name) if c.search() != 0: return False c.remove() entry_rows = self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll, name)) self._delete_keys(_IDX_ENTRIES_TABLE, [k for k, _ in entry_rows]) ui = self._collection_uuid(db, coll) self._emit_oplog( [ { "op": "c", "ns": f"{db}.$cmd", "ui": bson.Binary(ui.bytes, subtype=4), "o": {"dropIndexes": coll, "index": name}, } ] ) return True def drop_all_indexes(self, db: str, coll: str) -> int: with self._lock: rows = self._collect_prefix(_IDX_TABLE, (db, coll)) names = [k[2] for k, _ in rows] self._delete_keys(_IDX_TABLE, [k for k, _ in rows]) entry_rows = self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll)) self._delete_keys(_IDX_ENTRIES_TABLE, [k for k, _ in entry_rows]) if names: ui = self._collection_uuid(db, coll) self._emit_oplog( [ { "op": "c", "ns": f"{db}.$cmd", "ui": bson.Binary(ui.bytes, 
subtype=4), "o": {"dropIndexes": coll, "index": n}, } for n in names ] ) return len(rows) def _all_indexes(self, db: str, coll: str) -> list[tuple[str, dict[str, Any], bool, bool]]: """Every non-_id_ index: (name, key_spec, sparse, unique).""" out: list[tuple[str, dict[str, Any], bool, bool]] = [] for name, key_spec, opts in list(self._iter_indexes(db, coll)): out.append((name, key_spec, bool(opts.get("sparse")), bool(opts.get("unique")))) return out def _partial_filters(self, db: str, coll: str) -> dict[str, dict[str, Any]]: """Map of index name → ``partialFilterExpression`` for indexes that have one. Indexes without a partial filter are absent from the dict. """ out: dict[str, dict[str, Any]] = {} for name, _key_spec, opts in self._iter_indexes(db, coll): pf = opts.get("partialFilterExpression") if isinstance(pf, Mapping) and pf: out[name] = dict(pf) return out @staticmethod def _query_implies_partial(query: Mapping[str, Any], partial: Mapping[str, Any]) -> bool: """True if ``query`` is at least as restrictive as ``partial`` — every key/value pair in ``partial`` appears with the same bare value in ``query``. Conservative: anything more sophisticated (operator-form clauses, $and, etc.) is treated as not implying the partial filter. """ for key, value in partial.items(): if key not in query: return False if query[key] != value: return False return True def _multikey_index_names(self, db: str, coll: str) -> set[str]: """Names of indexes flagged ``multikey`` (must fall back to scan). Without true multi-key indexing, an index where any doc has a list-valued field can't serve scalar-element matches — so the pickers skip these names and ``find_matching`` falls back to a full scan. 
""" return { name for name, _key_spec, opts in self._iter_indexes(db, coll) if opts.get("multikey") } def _maybe_mark_multikey( self, db: str, coll: str, doc: Mapping[str, Any], indexes: list[tuple[str, dict[str, Any], bool, bool]], already_multikey: set[str], ) -> set[str]: """For each non-multikey index, flag it if ``doc`` has an array value on any indexed field. Returns the (possibly grown) set of multikey index names so the caller can avoid re-checking. """ c = self._cursor(_IDX_TABLE) for name, key_spec, _sparse, _unique in indexes: if name in already_multikey: continue if not _doc_makes_multikey(doc, key_spec): continue c.reset() c.set_key(db, coll, name) if c.search() != 0: continue payload = bson.decode(bytes(c.get_value())) opts = dict(payload.get("options") or {}) if opts.get("multikey"): already_multikey.add(name) continue opts["multikey"] = True payload["options"] = opts c.reset() c[db, coll, name] = bson.encode(payload) already_multikey.add(name) return already_multikey def _write_index_entries( self, db: str, coll: str, doc: dict[str, Any], indexes: list[tuple[str, dict[str, Any], bool, bool]], partials: dict[str, dict[str, Any]] | None = None, ) -> None: if not indexes: return c = self._cursor(_IDX_ENTRIES_TABLE) id_k = _id_key(doc["_id"]) if partials is None: partials = self._partial_filters(db, coll) index_options = self._index_options_map(db, coll) for name, key_spec, sparse, _unique in indexes: pf = partials.get(name) if pf is not None and not matches(doc, pf): continue geo = _geo_type_of(key_spec) if geo is not None: geo_field, geo_type = geo opts = index_options.get(name, {}) for cell_bytes in _doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name): c.reset() c[db, coll, name, _pack_entry(cell_bytes, id_k)] = b"" continue for kb in _index_key_variants(doc, key_spec, sparse=sparse): c.reset() c[db, coll, name, _pack_entry(kb, id_k)] = b"" def _delete_index_entries( self, db: str, coll: str, doc: dict[str, Any], indexes: list[tuple[str, 
dict[str, Any], bool, bool]], partials: dict[str, dict[str, Any]] | None = None, ) -> None: if not indexes: return c = self._cursor(_IDX_ENTRIES_TABLE) id_k = _id_key(doc["_id"]) if partials is None: partials = self._partial_filters(db, coll) index_options = self._index_options_map(db, coll) for name, key_spec, sparse, _unique in indexes: pf = partials.get(name) if pf is not None and not matches(doc, pf): continue geo = _geo_type_of(key_spec) if geo is not None: geo_field, geo_type = geo opts = index_options.get(name, {}) # On the delete path, swallow GeoExtractError. A doc that # was inserted before geo validation became strict might # have bad geometry; we still need to allow it to be # deleted. The index may end up with stale entries we # can't match, but the next compact / drop_index cleans # those up. Insert/update remain strict. try: cells = _doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name) except GeoExtractError: continue for cell_bytes in cells: c.reset() c.set_key(db, coll, name, _pack_entry(cell_bytes, id_k)) if c.search() == 0: c.remove() continue for kb in _index_key_variants(doc, key_spec, sparse=sparse): c.reset() c.set_key(db, coll, name, _pack_entry(kb, id_k)) if c.search() == 0: c.remove() def _validate_geo_indexes( self, db: str, coll: str, doc: dict[str, Any], indexes: list[tuple[str, dict[str, Any], bool, bool]], partials: dict[str, dict[str, Any]] | None = None, ) -> None: """Pre-flight every geo index for ``doc``: raise on bad geometry. Used by the insert / update paths to reject docs *before* writing them, so a single bad geo coordinate doesn't leave a half-indexed document behind. The work duplicates ``_write_index_entries``'s cell computation but is cheap (one Shapely parse + bounds check per indexed field). 
""" if not indexes: return if partials is None: partials = self._partial_filters(db, coll) options_map = self._index_options_map(db, coll) for name, key_spec, _sparse, _unique in indexes: geo = _geo_type_of(key_spec) if geo is None: continue pf = partials.get(name) if pf is not None and not matches(doc, pf): continue geo_field, geo_type = geo opts = options_map.get(name, {}) # Compute & discard — `_doc_geo_cells` raises GeoExtractError # on bad shape or out-of-bounds coords; that's the signal we # want to bubble up. _doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name) def _index_options_map(self, db: str, coll: str) -> dict[str, dict[str, Any]]: """Map of index name → its full options blob. Used by the geo write/delete paths: 2d indexes carry per-index ``bits`` / ``min`` / ``max`` settings that affect the cell encoder, so we need the option blob to compute the right bucket. Cached per call (the caller handles per-doc loops). """ return {name: dict(opts) for name, _key_spec, opts in self._iter_indexes(db, coll)} def _unique_conflict( self, db: str, coll: str, candidate_doc: dict[str, Any], indexes: list[tuple[str, dict[str, Any], bool, bool]], *, exclude_id_key: bytes | None, partials: dict[str, dict[str, Any]] | None = None, ) -> tuple[str, dict[str, Any], dict[str, Any]] | None: # Returns ``(index_name, key_pattern, key_value)`` so callers # can build a mongod-shaped dup-key error response with the # ``keyPattern`` + ``keyValue`` fields drivers' errorResponse # tests assert on. ``None`` when no conflict. 
if not indexes: return None c = self._cursor(_IDX_ENTRIES_TABLE) if partials is None: partials = self._partial_filters(db, coll) for name, key_spec, sparse, unique in indexes: if not unique: continue pf = partials.get(name) if pf is not None and not matches(candidate_doc, pf): continue kb = _index_key(candidate_doc, key_spec, sparse=sparse) if kb is None: continue esc_kb = _escape_kb(kb) seed = esc_kb + _ENTRY_SEP c.reset() c.set_key(db, coll, name, seed) rc = c.search_near() if rc == wt.WT_NOTFOUND: continue if rc < 0 and c.next() != 0: continue while True: k = c.get_key() if (k[0], k[1], k[2]) != (db, coll, name): break packed = bytes(k[3]) row_esc, row_id = _unpack_entry(packed) if row_esc != esc_kb: break if exclude_id_key is None or row_id != exclude_id_key: key_value = { field: get_path(candidate_doc, field, default=None) for field in key_spec } return name, dict(key_spec), key_value if c.next() != 0: break return None def _scan_index_for_id_keys( self, db: str, coll: str, name: str, kb: bytes, *, prefix: bool = False ) -> list[bytes]: """Walk the index entries for ``name`` matching ``kb``. With ``prefix=False`` (default), only rows whose ``escaped_kb`` is exactly equal to ``escape(kb)`` are returned — equality lookup. With ``prefix=True``, any row whose ``escaped_kb`` starts with ``escape(kb)`` is returned — compound-prefix lookup. 
""" c = self._cursor(_IDX_ENTRIES_TABLE) esc_kb = _escape_kb(kb) seed = esc_kb if prefix else esc_kb + _ENTRY_SEP c.set_key(db, coll, name, seed) rc = c.search_near() if rc == wt.WT_NOTFOUND: return [] if rc < 0 and c.next() != 0: return [] out: list[bytes] = [] while True: k = c.get_key() if (k[0], k[1], k[2]) != (db, coll, name): break packed = bytes(k[3]) row_esc, row_id = _unpack_entry(packed) if prefix: if not row_esc.startswith(esc_kb): break elif row_esc != esc_kb: break out.append(row_id) if c.next() != 0: break return out def _docs_by_id_keys(self, db: str, coll: str, id_keys: list[bytes]) -> list[dict[str, Any]]: if not id_keys: return [] c = self._cursor(_DOC_TABLE) # Two-stage: WT cursor walk first (raw bytes), then ``bson.decode`` # outside that loop. The cursor work is what needs lock scope; # decode is pure CPU and benefits from running unsynchronised. # Multikey indexes write per-element entries, so the same doc's # id_key can appear more than once for queries that match # multiple elements. Dedupe while preserving order. raw: list[bytes] = [] for id_k in dict.fromkeys(id_keys): c.reset() c.set_key(db, coll, id_k) if c.search() == 0: raw.append(bytes(c.get_value())) return [bson.decode(b) for b in raw] _RANGE_OPS: tuple[str, ...] = ("$eq", "$gt", "$gte", "$lt", "$lte", "$in") _GEO_OPS: tuple[str, ...] = ("$geoWithin", "$geoIntersects", "$near", "$nearSphere") def _try_geo_index_id_keys( self, db: str, coll: str, filter: dict[str, Any] ) -> list[bytes] | None: """If ``filter`` contains a geo operator on a geo-indexed field, scan that index's covering cells and return candidate id_keys. Returns ``None`` when no geo operator is present or no matching geo index exists — caller falls through to regular pickers and eventually a full scan. Returns a list (possibly empty) when a geo index covers the query — caller short-circuits the regular pickers. 
The cell scan over-collects (cell-covering is a superset of the true intersection); the caller's ``matches()`` step verifies via :func:`secantus.geo.geo_within` / ``geo_intersects`` and removes false positives. """ # Find a single field with a geo operator on it. geo_field: str | None = None geo_op: str | None = None geo_arg: Any = None for field, value in filter.items(): if not isinstance(value, dict): continue for op in self._GEO_OPS: if op in value: geo_field = field geo_op = op geo_arg = value[op] break if geo_field is not None: break if geo_field is None: return None # Locate a geo index on that field. chosen_name: str | None = None chosen_type: str | None = None chosen_opts: dict[str, Any] = {} for name, key_spec, opts in self._iter_indexes(db, coll): geo = _geo_type_of(key_spec) if geo is None: continue if geo[0] == geo_field: chosen_name = name chosen_type = geo[1] chosen_opts = dict(opts) break if chosen_name is None or chosen_type is None: return None # Build the query geometry from the operator arg. cells = self._geo_query_cells(geo_op, geo_arg, chosen_type, chosen_opts) if cells is None: # Couldn't compute a covering — defer to full scan. return None return self._collect_geo_candidates(db, coll, chosen_name, cells) def _geo_query_cells( self, op: str, arg: Any, geo_type: str, options: Mapping[str, Any] ) -> list[tuple[bytes, bytes]] | None: """Byte ranges covering the query geometry, one per covering cell. Both 2dsphere and 2d return ``list[tuple[bytes, bytes]]`` — for 2dsphere each entry is the (range_min, range_max) byte pair of an S2 covering cell expanded to its leaf descendants; for 2d it's the single (lo, hi) bbox range from `planar_2d_covering`. Callers use :meth:`_scan_geo_range` for both. 
""" from secantus.geo import GeoError try: if op in ("$geoWithin", "$geoIntersects"): if not isinstance(arg, Mapping): return None geom, _ = parse_query_geometry(arg) elif op in ("$near", "$nearSphere"): # `$near` without a max distance: caller falls through to # full scan (signal None). With a max, expand into a cap # (2dsphere) or planar disk (2d). center, max_d, _min_d, _spherical = self._near_query_geom(arg) if max_d is None: return None from shapely.geometry import Point as _Point from secantus.geo import _SphericalCircle if geo_type == _GEO_2DSPHERE: from secantus.geo import EARTH_RADIUS_METERS radius_rad = max_d / EARTH_RADIUS_METERS geom = _SphericalCircle(center[0], center[1], radius_rad) else: # 2d planar — circular disk geom = _Point(*center).buffer(max_d, quad_segs=16) else: return None except GeoError: return None if geo_type == _GEO_2DSPHERE: # Each cell becomes a degenerate (cell, cell) range so the # storage scanner does an exact point-lookup. Treating # 2dsphere uniformly as a list-of-ranges keeps the storage # path single-shaped. return [(encode_cell(c), encode_cell(c)) for c in s2_query_covering(geom)] # 2d: shape must be planar; convert to a single (lo, hi) range. from shapely.geometry.base import BaseGeometry as _BG if not isinstance(geom, _BG): return None lo, hi = planar_2d_covering(geom, options) return [(encode_cell(lo), encode_cell(hi))] def _near_query_geom( self, arg: Any ) -> tuple[tuple[float, float], float | None, float | None, bool]: """Reuse :mod:`secantus.query`'s ``$near`` parser for the picker. Routing it through `_parse_near_spec` keeps the spec semantics in one place — the operator handler and the picker agree on what a ``$near`` arg means. 
""" from secantus.query import _parse_near_spec # type: ignore[attr-defined] return _parse_near_spec(arg, default_spherical=False) def _collect_geo_candidates( self, db: str, coll: str, index_name: str, cells: list[tuple[bytes, bytes]], ) -> list[bytes]: """Walk index entries in each (lo, hi) range; return deduplicated id_keys. A doc with N covering cells produces N index entries; we collect just one ``_id`` per doc. The post-fetch verifier (in ``find_matching``'s ``matches()`` step) discards docs whose actual geometry doesn't match the query. """ c = self._cursor(_IDX_ENTRIES_TABLE) seen: set[bytes] = set() out: list[bytes] = [] for lo_bytes, hi_bytes in cells: self._scan_geo_range(c, db, coll, index_name, lo_bytes, hi_bytes, seen, out) return out def _scan_geo_range( self, c: Any, db: str, coll: str, name: str, lo_bytes: bytes, hi_bytes: bytes, seen: set[bytes], out: list[bytes], ) -> None: """Walk every index entry whose escaped cell-id is in [lo_bytes, hi_bytes]. Lex byte order over `_escape_kb`-escaped fixed-width cell IDs is the same as numeric cell-id order, so a forward WT cursor walk between the two escaped boundary keys visits every entry inside the range exactly once. Cell IDs are packed as fixed 8-byte big-endian, so escaping never changes their relative order. 
""" lo_prefix = _escape_kb(lo_bytes) hi_prefix = _escape_kb(hi_bytes) c.reset() c.set_key(db, coll, name, lo_prefix) rc = c.search_near() if rc == wt.WT_NOTFOUND: return if rc < 0 and c.next() != 0: return while True: k = c.get_key() if k[0] != db or k[1] != coll or k[2] != name: return packed = bytes(k[3]) sep_pos = packed.find(_ENTRY_SEP) if sep_pos < 0: if c.next() != 0: return continue kb_part = packed[:sep_pos] if kb_part > hi_prefix: return id_key = packed[sep_pos + len(_ENTRY_SEP) :] if id_key not in seen: seen.add(id_key) out.append(id_key) if c.next() != 0: return def _try_index_lookup( self, db: str, coll: str, filter: dict[str, Any] ) -> list[dict[str, Any]] | None: id_keys = self._try_index_id_keys(db, coll, filter) if id_keys is None: return None return self._docs_by_id_keys(db, coll, id_keys) def _try_index_id_keys(self, db: str, coll: str, filter: dict[str, Any]) -> list[bytes] | None: """Same dispatch as ``_try_index_lookup`` but returns id_keys instead of materialised docs. Used by the write paths (update / delete) so only matching docs pay ``bson.decode``. """ if not filter: return None if any(f.startswith("$") for f in filter): return None # Geo dispatch first — a $geoWithin / $geoIntersects / $near clause # on a field with a 2dsphere or 2d index uses the cell-covering # path. The picker returns None if no geo index covers the query, # and we fall through to the regular pickers below. geo_ids = self._try_geo_index_id_keys(db, coll, filter) if geo_ids is not None: return geo_ids # Bare-equality filters of any size can use a compound index whose # leading fields cover the filter set. if all(not isinstance(v, dict) for v in filter.values()): result = self._try_compound_eq_id_keys(db, coll, filter) if result is not None: return result # Compound prefix + trailing operator field (eq fields then range/in). 
if len(filter) >= 2: result = self._try_compound_range_id_keys(db, coll, filter) if result is not None: return result if len(filter) != 1: return None field, value = next(iter(filter.items())) idx_match = self._find_leading_field_index(db, coll, field, filter) if idx_match is None: return None return self._lookup_id_keys_via_leading_field(db, coll, idx_match, value) def _candidates_iter( self, db: str, coll: str, filter: dict[str, Any] | None ) -> list[tuple[bytes, bytes]]: """Return (id_key, blob) pairs that the write paths should consider. If an index covers the filter, only the indexed candidates are fetched; otherwise the full doc table is scanned. Either way, BSON decode is left to the caller so non-matching docs don't pay for it. Caller still applies ``matches()`` to the decoded doc — index lookups can produce false-positive candidates for partial scans (multikey, prefix overlap, etc). """ if filter: id_keys = self._try_index_id_keys(db, coll, filter) if id_keys is not None: c = self._cursor(_DOC_TABLE) out: list[tuple[bytes, bytes]] = [] # Same dedup contract as ``_docs_by_id_keys``: multikey # indexes can yield duplicate id_keys for one doc. for id_k in dict.fromkeys(id_keys): c.reset() c.set_key(db, coll, id_k) if c.search() == 0: out.append((id_k, bytes(c.get_value()))) return out return list(self._scan_docs(db, coll)) def _find_leading_field_index( self, db: str, coll: str, field: str, query: Mapping[str, Any] | None = None, ) -> tuple[str, int, bool] | None: """Best index whose leading field is ``field``. Returns ``(name, direction, is_compound)``. Single-field indexes win over compound (tighter scan, no separator math). All fields must be ASC or DESC. Partial indexes are skipped unless ``query`` implies their ``partialFilterExpression``. Multikey indexes are not skipped — ``_index_key_variants`` writes per-element entries, so equality / range / ``$in`` lookups on the leading field hit at least all true matches. 
The geo ``2dsphere`` / ``2d`` indexes have non-numeric direction values and are excluded by the ASC/DESC check below. """ partials = self._partial_filters(db, coll) query = query or {} compound_fallback: tuple[str, int, bool] | None = None for name, key_spec, _sparse, _unique in self._all_indexes(db, coll): pf = partials.get(name) if pf is not None and not self._query_implies_partial(query, pf): continue idx_fields = list(key_spec) if not idx_fields or idx_fields[0] != field: continue # Geo / hashed / text indexes carry string direction values # ("2dsphere", "2d", "hashed", "text"); the bare equality # picker can't drive them. Real numeric direction values are # 1 / -1. if any(key_spec[f] not in (1, -1) for f in idx_fields): continue d = int(key_spec[field]) if len(idx_fields) == 1: return name, d, False if compound_fallback is None: compound_fallback = (name, d, True) return compound_fallback def _lookup_id_keys_via_leading_field( self, db: str, coll: str, idx_match: tuple[str, int, bool], value: Any, ) -> list[bytes] | None: name, direction, is_compound = idx_match if not isinstance(value, dict): return self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, value) if not value or not all(k.startswith("$") for k in value): return None if not all(op in self._RANGE_OPS for op in value): return None if "$in" in value: if len(value) != 1 or not isinstance(value["$in"], list): return None seen: set[bytes] = set() id_keys: list[bytes] = [] for v in value["$in"]: if isinstance(v, dict): return None for id_k in self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, v): if id_k not in seen: seen.add(id_k) id_keys.append(id_k) return id_keys lower: bytes | None = None lower_inclusive = True upper: bytes | None = None upper_inclusive = True for op, bound in value.items(): if isinstance(bound, dict): return None if op == "$eq": return self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, bound) kb = encode_value_directed(bound, 
direction) # Operator semantics flip when stored bytes are inverted: in a # DESC index, "x > 5" means we want stored bytes < enc_desc(5). effective_op = op if direction == -1: effective_op = {"$gt": "$lt", "$gte": "$lte", "$lt": "$gt", "$lte": "$gte"}[op] if effective_op == "$gt": lower, lower_inclusive = kb, False elif effective_op == "$gte": lower, lower_inclusive = kb, True elif effective_op == "$lt": upper, upper_inclusive = kb, False elif effective_op == "$lte": upper, upper_inclusive = kb, True if is_compound: return self._range_scan_index_leading( db, coll, name, lower, lower_inclusive, upper, upper_inclusive ) return self._range_scan_index( db, coll, name, lower, lower_inclusive, upper, upper_inclusive ) def _eq_id_keys_via_leading( self, db: str, coll: str, name: str, direction: int, is_compound: bool, value: Any, ) -> list[bytes]: kb = encode_value_directed(value, direction) if is_compound: return self._scan_index_for_id_keys(db, coll, name, kb + COMPOUND_SEP, prefix=True) return self._scan_index_for_id_keys(db, coll, name, kb) def _pick_compound_eq_index( self, db: str, coll: str, filter: dict[str, Any] ) -> tuple[str, dict[str, Any]] | None: """Find the index that ``_try_compound_eq_id_keys`` would walk for ``filter``. Returns ``(name, key_spec)`` of the chosen index, or ``None`` if no index covers the filter as a leading prefix. Pure picker — does not scan. Multikey indexes are eligible (per-element entries cover equality lookups); the ASC/DESC direction check excludes geo indexes. """ filter_fields = set(filter) partials = self._partial_filters(db, coll) best: tuple[str, dict[str, Any]] | None = None for name, key_spec, _sparse, _unique in self._all_indexes(db, coll): pf = partials.get(name) if pf is not None: if not self._query_implies_partial(filter, pf): continue # Partial-filter clauses are guaranteed by the index itself, # so they don't have to appear in the index key. 
eff_fields = filter_fields - set(pf) else: eff_fields = filter_fields idx_fields = list(key_spec.keys()) # Geo / hashed / text indexes (string direction values) can't # serve a bare-equality compound lookup. if any(key_spec[f] not in (1, -1) for f in idx_fields): continue if len(idx_fields) < len(eff_fields): continue if set(idx_fields[: len(eff_fields)]) != eff_fields: continue if best is None or (len(list(best[1])) > len(idx_fields)): best = (name, dict(key_spec)) if len(idx_fields) == len(eff_fields): break return best def _try_compound_eq_id_keys( self, db: str, coll: str, filter: dict[str, Any] ) -> list[bytes] | None: """Bare-equality filter against a compound (or single-field) index prefix. Picks an index whose leading fields (set-wise) match the filter's fields, and runs an equality (full-cover) or prefix (strict-leading-prefix) scan against it. Per-field index direction is honoured by encoding each value with ``encode_value_directed``. """ picked = self._pick_compound_eq_index(db, coll, filter) if picked is None: return None name, key_spec = picked idx_fields = list(key_spec) # Build kb from the filter fields that are in the index (partial-filter # clauses live outside the key and are guaranteed by index population). prefix_fields = [f for f in idx_fields if f in filter] parts = [encode_value_directed(filter[f], int(key_spec[f])) for f in prefix_fields] kb = COMPOUND_SEP.join(parts) if len(parts) > 1 else parts[0] if len(prefix_fields) == len(idx_fields): return self._scan_index_for_id_keys(db, coll, name, kb) kb = kb + COMPOUND_SEP return self._scan_index_for_id_keys(db, coll, name, kb, prefix=True) def _partition_compound_range_filter( self, filter: dict[str, Any] ) -> tuple[dict[str, Any], str, dict[str, Any]] | None: """Split a filter into ``(eq_fields, operator_field, operator_ops)``. 
Returns ``None`` if the filter doesn't fit the compound-range shape (any number of bare-equality fields plus exactly one operator-form field whose ops are all in ``_RANGE_OPS``). """ eq_fields: dict[str, Any] = {} operator_field: str | None = None operator_ops: dict[str, Any] | None = None for f, v in filter.items(): if isinstance(v, dict): if not v or not all(k.startswith("$") for k in v): return None if not all(op in self._RANGE_OPS for op in v): return None if operator_field is not None: return None operator_field = f operator_ops = v else: eq_fields[f] = v if operator_field is None or not eq_fields: return None if operator_field in eq_fields: return None return eq_fields, operator_field, operator_ops or {} def _pick_compound_range_index( self, db: str, coll: str, filter: dict[str, Any] ) -> tuple[str, dict[str, Any]] | None: """Find the index that ``_try_compound_range_id_keys`` would walk.""" parts = self._partition_compound_range_filter(filter) if parts is None: return None eq_fields, operator_field, _operator_ops = parts eq_set = set(eq_fields) target_eq_count = len(eq_set) partials = self._partial_filters(db, coll) best: tuple[str, dict[str, Any]] | None = None for name, key_spec, _sparse, _unique in self._all_indexes(db, coll): pf = partials.get(name) if pf is not None and not self._query_implies_partial(filter, pf): continue idx_fields = list(key_spec.keys()) # Geo / hashed / text indexes (string direction values) can't # serve a compound prefix + trailing-operator lookup. 
if any(key_spec[f] not in (1, -1) for f in idx_fields): continue if len(idx_fields) <= target_eq_count: continue if set(idx_fields[:target_eq_count]) != eq_set: continue if idx_fields[target_eq_count] != operator_field: continue if best is None or len(list(best[1])) > len(idx_fields): best = (name, dict(key_spec)) if len(idx_fields) == target_eq_count + 1: break return best def _try_compound_range_id_keys( self, db: str, coll: str, filter: dict[str, Any] ) -> list[bytes] | None: """Compound-prefix lookup with a trailing operator field. Filters of the form ``{a: 5, b: 10, c: {$gt: 20}}`` (any number of leading bare-equality fields followed by exactly one operator-form field) walk the compound index by pinning the prefix from the equalities and applying the operator's bounds to the next field. """ parts = self._partition_compound_range_filter(filter) if parts is None: return None eq_fields, operator_field, operator_ops = parts picked = self._pick_compound_range_index(db, coll, filter) if picked is None: return None name, key_spec = picked idx_fields = list(key_spec) target_eq_count = len(eq_fields) eq_field_names = idx_fields[:target_eq_count] op_dir = int(key_spec[operator_field]) eq_parts = [encode_value_directed(eq_fields[f], int(key_spec[f])) for f in eq_field_names] prefix_kb = COMPOUND_SEP.join(eq_parts) if len(eq_parts) > 1 else eq_parts[0] prefix_with_sep = prefix_kb + COMPOUND_SEP if "$in" in operator_ops: if len(operator_ops) != 1 or not isinstance(operator_ops["$in"], list): return None seen: set[bytes] = set() id_keys: list[bytes] = [] for v in operator_ops["$in"]: if isinstance(v, dict): return None kb = prefix_with_sep + encode_value_directed(v, op_dir) use_prefix = len(idx_fields) > target_eq_count + 1 inner_kb = kb + COMPOUND_SEP if use_prefix else kb for id_k in self._scan_index_for_id_keys( db, coll, name, inner_kb, prefix=use_prefix ): if id_k not in seen: seen.add(id_k) id_keys.append(id_k) return id_keys if "$eq" in operator_ops: if 
len(operator_ops) != 1: return None kb = prefix_with_sep + encode_value_directed(operator_ops["$eq"], op_dir) use_prefix = len(idx_fields) > target_eq_count + 1 inner_kb = kb + COMPOUND_SEP if use_prefix else kb return self._scan_index_for_id_keys(db, coll, name, inner_kb, prefix=use_prefix) lower: bytes | None = None lower_inclusive = True upper: bytes | None = None upper_inclusive = True for op, bound in operator_ops.items(): if isinstance(bound, dict): return None full = prefix_with_sep + encode_value_directed(bound, op_dir) effective_op = op if op_dir == -1: effective_op = {"$gt": "$lt", "$gte": "$lte", "$lt": "$gt", "$lte": "$gte"}[op] if effective_op == "$gt": lower, lower_inclusive = full, False elif effective_op == "$gte": lower, lower_inclusive = full, True elif effective_op == "$lt": upper, upper_inclusive = full, False elif effective_op == "$lte": upper, upper_inclusive = full, True else: return None return self._range_scan_index( db, coll, name, lower, lower_inclusive, upper, upper_inclusive, prefix=prefix_with_sep, ) def _range_scan_index( self, db: str, coll: str, name: str, lower: bytes | None, lower_inclusive: bool, upper: bytes | None, upper_inclusive: bool, *, prefix: bytes | None = None, ) -> list[bytes]: """Range-scan the index entries for ``name``. Optional ``prefix`` constrains the scan to entries whose escaped kb starts with ``escape(prefix)`` — used by compound-index prefix+range queries where leading equalities pin part of the kb. 
""" c = self._cursor(_IDX_ENTRIES_TABLE) esc_prefix = _escape_kb(prefix) if prefix is not None else None esc_lower = _escape_kb(lower) if lower is not None else None esc_upper = _escape_kb(upper) if upper is not None else None if esc_lower is not None: seed = esc_lower + _ENTRY_SEP elif esc_prefix is not None: seed = esc_prefix else: seed = b"" c.set_key(db, coll, name, seed) rc = c.search_near() if rc == wt.WT_NOTFOUND: return [] if rc < 0 and c.next() != 0: return [] out: list[bytes] = [] while True: k = c.get_key() if (k[0], k[1], k[2]) != (db, coll, name): break packed = bytes(k[3]) row_esc, row_id = _unpack_entry(packed) if esc_prefix is not None and not row_esc.startswith(esc_prefix): break if esc_lower is not None and not lower_inclusive and row_esc == esc_lower: if c.next() != 0: break continue if esc_upper is not None: if upper_inclusive: if row_esc > esc_upper: break elif row_esc >= esc_upper: break out.append(row_id) if c.next() != 0: break return out def _range_scan_index_leading( self, db: str, coll: str, name: str, lower: bytes | None, lower_inclusive: bool, upper: bytes | None, upper_inclusive: bool, ) -> list[bytes]: """Range-scan a compound index using only its leading field. Each row's escaped kb is ``escape(enc(leading)) + escape(COMPOUND_SEP) + escape(enc(trailing...))``. Boundary detection uses ``startswith(esc_X + esc_compound_sep)`` to identify rows whose leading field equals ``X`` — the terminator bytes of an escaped numeric encoding can overlap with the start of the escaped compound separator, so a literal find/split on the separator is unreliable. 
""" esc_compound_sep = _escape_kb(COMPOUND_SEP) c = self._cursor(_IDX_ENTRIES_TABLE) esc_lower = _escape_kb(lower) if lower is not None else None esc_upper = _escape_kb(upper) if upper is not None else None seed = esc_lower if esc_lower is not None else b"" c.set_key(db, coll, name, seed) rc = c.search_near() if rc == wt.WT_NOTFOUND: return [] if rc < 0 and c.next() != 0: return [] lower_eq_prefix = esc_lower + esc_compound_sep if esc_lower is not None else None upper_eq_prefix = esc_upper + esc_compound_sep if esc_upper is not None else None out: list[bytes] = [] while True: k = c.get_key() if (k[0], k[1], k[2]) != (db, coll, name): break packed = bytes(k[3]) row_esc, row_id = _unpack_entry(packed) if ( lower_eq_prefix is not None and not lower_inclusive and row_esc.startswith(lower_eq_prefix) ): if c.next() != 0: break continue if esc_upper is not None: if upper_inclusive: if row_esc > esc_upper and not row_esc.startswith(upper_eq_prefix): break elif row_esc >= esc_upper: break out.append(row_id) if c.next() != 0: break return out