"""WiredTiger-backed document store.
WiredTiger is the default storage engine for MongoDB. We use the same
engine here so that on-disk semantics line up with what test code would
see against a real ``mongod``.
Indexes use a sidecar entries table (``table:secantus_index_entries``)
with a single trailing ``u`` column packing
``escape(sortkey) + b"\\x00\\x00" + id_key``. The sortkey comes from
``secantus.sortkey`` (typed, byte-sortable BSON encoding), so the WT
B-tree gives us ordered access for free. ``find_matching`` routes a wide
range of filter shapes through the index — equality, ``$eq``, ``$in``,
``$gt``/``$gte``/``$lt``/``$lte`` on a single field, plus compound
indexes when filter fields cover a leading prefix (with optional range
on the next field). Sort-by-indexed-field walks the B-tree in order.
"""
from __future__ import annotations
import contextlib
import datetime as _dt
import os
import shutil
import tempfile
import threading
import time as _time
import uuid as _uuid
from collections.abc import Callable, Iterable, Mapping
from decimal import Decimal, InvalidOperation
from typing import Any
import bson
import wiredtiger as wt
from bson import Decimal128
from bson.timestamp import Timestamp
from secantus.diff import compute_update_description
from secantus.geo import GeoError, parse_doc_geometry, parse_query_geometry, validate_coordinates
from secantus.geo_index import (
encode_cell,
planar_2d_covering,
planar_2d_index_for_point,
s2_doc_covering,
s2_query_covering,
)
from secantus.paths import get_path, has_path
from secantus.projection import apply_projection
from secantus.query import matches
from secantus.sortkey import COMPOUND_SEP, encode_value, encode_value_directed
from secantus.update import apply_update, find_positional_matches
# Key-spec values that declare a geospatial index on a field, in place of the
# usual ``1`` / ``-1`` sort direction.
_GEO_2DSPHERE = "2dsphere"
_GEO_2D = "2d"
# Every recognised geo index type; ``_geo_type_of`` tests membership in this set.
_GEO_TYPES = frozenset({_GEO_2DSPHERE, _GEO_2D})
def _geo_type_of(key_spec: Mapping[str, Any]) -> tuple[str, str] | None:
    """Find the geo field declared by an index ``key_spec``, if any.

    A field is geo-indexed when its spec value is one of the strings in
    ``_GEO_TYPES`` (``"2dsphere"`` or ``"2d"``) instead of a ``1`` / ``-1``
    direction. Compound geo indexes (geo field + scalar trailing fields)
    are out of scope in Phase 2: any spec containing a geo field is
    treated as geo-only and the trailing fields are ignored. The picker
    still works because `$geoWithin` etc. are answered by the cell scan +
    verifier.

    Returns ``(field, geo_type)`` for the first geo field in the spec's
    iteration order, or ``None`` when the spec has no geo field.
    """
    geo_fields = (
        (name, kind)
        for name, kind in key_spec.items()
        if isinstance(kind, str) and kind in _GEO_TYPES
    )
    return next(geo_fields, None)
def _doc_geo_cells(
    doc: Mapping[str, Any],
    field: str,
    geo_type: str,
    options: Mapping[str, Any],
    *,
    index_name: str = "",
) -> list[bytes]:
    """Compute the encoded index cells for ``doc``'s geo field.

    A missing or null field yields an empty list (sparse-by-default
    semantics, matching mongod's 2dsphere/2d). A value that is *present*
    but can't be indexed — unparseable shape, coordinates rejected by
    ``validate_coordinates``, or a non-point under a 2d index — raises
    :class:`GeoExtractError`, which the caller propagates to the wire as a
    write error (code 16572 "Can't extract geo keys").

    For ``2dsphere`` every cell of the S2 covering is returned; for ``2d``
    exactly one cell for the point is returned.
    """

    def _unindexable(reason: str) -> GeoExtractError:
        # All rejection paths report the same index / field / _id triple.
        return GeoExtractError(index_name, field, doc.get("_id"), reason)

    raw = get_path(dict(doc), field)
    if raw is None:
        # Field missing or explicitly null — sparse semantics, no entry.
        return []

    shape = parse_doc_geometry(raw)
    if shape is None:
        raise _unindexable(f"value {raw!r} is not a recognised geometry")

    try:
        validate_coordinates(shape, geo_type=geo_type, options=options)
    except GeoError as exc:
        raise _unindexable(str(exc)) from exc

    if geo_type == _GEO_2DSPHERE:
        return list(map(encode_cell, s2_doc_covering(shape)))

    # 2d: single point only.
    from shapely.geometry import Point as _Point

    if isinstance(shape, _Point):
        return [encode_cell(planar_2d_index_for_point(shape.x, shape.y, options))]
    raise _unindexable("2d index requires a point; got a non-point geometry")
# WiredTiger table URIs. ``Storage.__init__`` creates each of these with its
# key/value format on startup.
_COLL_TABLE = "table:secantus_collections"
_DOC_TABLE = "table:secantus_documents"
_IDX_TABLE = "table:secantus_indexes"
_IDX_ENTRIES_TABLE = "table:secantus_index_entries"
_OPLOG_TABLE = "table:secantus_oplog"
_PREIMAGE_TABLE = "table:secantus_preimages"
_OPLOG_META_TABLE = "table:secantus_oplog_meta"
_USERS_TABLE = "table:secantus_users"
_ROLES_TABLE = "table:secantus_roles"
_PROFILE_TABLE = "table:secantus_profile_settings"
_OPLOG_PRUNE_INTERVAL = 1000  # call prune_oplog every N emits
# Separator between the escaped sortkey and the id key inside a packed index
# entry (see ``_pack_entry`` / ``_unpack_entry``).
_ENTRY_SEP = b"\x00\x00"
def _escape_kb(kb: bytes) -> bytes:
"""Order-preserving escape so ``\\x00\\x00`` is unambiguous as a separator."""
return kb.replace(b"\x00", b"\x00\xff")
def _pack_entry(kb: bytes, id_key: bytes) -> bytes:
    """Build the single-column index-entry payload for ``(kb, id_key)``.

    WiredTiger length-prefixes ``u`` columns when they're not last in the
    key, which breaks lexicographic comparison. Packing the escaped sortkey,
    the separator and the id key into one trailing ``u`` column lets the
    B-tree do the sort for us.
    """
    pieces = (_escape_kb(kb), _ENTRY_SEP, id_key)
    return b"".join(pieces)
def _unpack_entry(packed: bytes) -> tuple[bytes, bytes]:
    """Split a packed index entry into ``(escaped_kb, id_key)``.

    The split happens at the first ``_ENTRY_SEP``. The escaped sortkey can
    never contain the separator (every ``\\x00`` in it is followed by
    ``\\xff``), so the first occurrence is always the real boundary, even
    when ``id_key`` itself contains ``\\x00\\x00``.

    Raises:
        ValueError: ``packed`` contains no separator. Previously the ``-1``
            returned by ``find`` was used as a slice index, silently
            producing two garbage halves instead of flagging the entry.
    """
    escaped_kb, sep, id_key = packed.partition(_ENTRY_SEP)
    if not sep:
        raise ValueError(f"malformed index entry (no separator): {packed!r}")
    return escaped_kb, id_key
class DuplicateKeyError(Exception):
    """Signals a collision on a document ``_id``.

    The offending id is kept on ``doc_id`` and is also rendered into the
    exception message.
    """

    def __init__(self, doc_id: Any) -> None:
        message = f"duplicate _id: {doc_id!r}"
        super().__init__(message)
        self.doc_id = doc_id
def _id_key(doc_id: Any) -> bytes:
    """Return the byte-sortable canonical key bytes for an ``_id`` value.

    This is a direct call to ``encode_value`` — the same byte-sortable
    encoding the secondary-index entries table relies on. Two consequences
    worth knowing:

    * Cross-numeric collision: ``1 == 1.0 == Decimal128("1")`` produce
      identical bytes (so they hit the same doc / clash on uniqueness),
      because ``encode_value`` normalises numerics through ``Decimal``.
    * Natural iteration: walking the doc table in WT-key order yields
      docs in BSON cross-type sort order, which matches what real
      MongoDB calls "natural order" for non-capped collections.
    """
    return encode_value(doc_id)
def _doc_makes_multikey(doc: Mapping[str, Any], key_spec: Mapping[str, Any]) -> bool:
    """Return True if any field in ``key_spec`` resolves to a list in ``doc``.

    Such a value is encoded as a single composite array sortkey, so a
    later scalar-equality query against this index would silently miss
    the doc — the index must fall back to a full scan.
    """
    # Copy the document once instead of once per indexed field.
    snapshot = dict(doc)
    return any(isinstance(get_path(snapshot, field), list) for field in key_spec)
def _index_key(
    doc: Mapping[str, Any], key_spec: Mapping[str, Any], *, sparse: bool
) -> bytes | None:
    """Direction-aware byte-sortable encoding for an index ``key_spec``.

    Each field is encoded with ``encode_value_directed`` so ``-1``
    (descending) fields get bitwise-inverted bytes, making a forward
    B-tree walk yield values in descending order. Compound keys are
    joined with ``COMPOUND_SEP`` (``\\x00\\x00``) between components.

    For docs whose indexed field is array-valued, this returns the
    whole-array sortkey only — the single canonical "doc-shape" key
    used by uniqueness probes. The full set of multikey entries
    (per-element + whole-array) is produced by
    :func:`_index_key_variants`.

    Returns ``None`` when ``sparse`` is set and any indexed field is
    absent from ``doc``; otherwise the encoded key bytes.
    """
    # One shallow copy serves both the sparse probe and the encoding pass;
    # it used to be re-copied for every field in each of them.
    snapshot = dict(doc)
    if sparse and not all(has_path(snapshot, field) for field in key_spec):
        return None

    parts = [
        encode_value_directed(get_path(snapshot, field), int(direction))
        for field, direction in key_spec.items()
    ]
    if len(parts) == 1:
        # Single-field index: the component is the key, no separator involved.
        return parts[0]
    return COMPOUND_SEP.join(parts)
def _index_key_variants(
    doc: Mapping[str, Any], key_spec: Mapping[str, Any], *, sparse: bool
) -> list[bytes]:
    """All byte-keys this doc contributes to an index under ``key_spec``.

    For scalar-valued fields, returns one key — same as ``_index_key``.
    For array-valued fields, returns one key per array element *and*
    the whole-array key, mirroring real ``mongod``'s multikey index
    layout. This makes:

    * ``{tags: "python"}`` against ``{tags: ["python", "go"]}`` light
      up via the per-element entry for ``"python"``.
    * ``{tags: ["python", "go"]}`` (whole-array equality) light up via
      the whole-array entry — without this, the equality lookup would
      false-negative.
    * Range / ``$in`` queries on array fields hit at least all true
      matches (the post-index ``matches()`` filter discards
      false-positives).

    For compound indexes whose multiple fields are array-valued, the
    cartesian product is taken across each field's candidate values.
    Real mongod restricts compound indexes to one multikey field per
    doc; we don't enforce that — we just emit the cross-product, which
    is correct (over-includes; the post-filter discards) but pays a
    cardinality blow-up the user is then on the hook for.

    Returns an empty list when ``sparse`` and any field is missing.
    Per-element values are deduplicated against their encoded bytes,
    so ``[1, 1, 2]`` writes two element entries (``1`` and ``2``) plus
    the whole-array entry, not three. Keys are returned in first-seen
    order.
    """
    # One shallow copy for the whole function instead of one per field.
    snapshot = dict(doc)
    if sparse and not all(has_path(snapshot, field) for field in key_spec):
        return []

    # Per-field candidate keys, already encoded and deduplicated: a scalar
    # contributes [enc(val)]; an array contributes
    # [enc(unique elements)..., enc(whole array)]. Encoding happens exactly
    # once per candidate value.
    per_field: list[list[bytes]] = []
    for field, raw_direction in key_spec.items():
        direction = int(raw_direction)
        value = get_path(snapshot, field)
        candidates = [*value, value] if isinstance(value, list) else [value]
        # dict.fromkeys drops repeats while keeping first-seen order; it also
        # covers the case where the whole-array key collides with an element.
        encoded = dict.fromkeys(encode_value_directed(c, direction) for c in candidates)
        per_field.append(list(encoded))

    if len(per_field) == 1:
        return per_field[0]

    # Compound: cartesian product across per-field candidate lists. The
    # joined keys are deduplicated again in case two combinations join to
    # the same bytes.
    from itertools import product

    joined = dict.fromkeys(COMPOUND_SEP.join(combo) for combo in product(*per_field))
    return list(joined)
def _to_decimal(value: Any) -> Decimal:
    """Convert a numeric value to ``Decimal`` for cross-type comparison.

    ``Decimal128`` uses its own ``to_decimal()``. Floats go through
    ``repr`` so the conversion uses the float's printed form rather than
    its exact binary expansion. Anything else is handed to ``Decimal``
    as-is.
    """
    if isinstance(value, Decimal128):
        return value.to_decimal()
    source = repr(value) if isinstance(value, float) else value
    return Decimal(source)
def _bson_type_rank(value: Any) -> int:
    """Rank for MongoDB's cross-type sort order. Lower rank sorts first.

    Ranks: MinKey 1, null 2, numbers 3, string 4, document 5, array 6,
    binary 7, ObjectId 8, bool 9, datetime 10, Timestamp 11, Regex 12,
    MaxKey 13. Any type not listed falls back to 5.

    This runs twice per comparison while sorting, so it resolves the BSON
    types through the module-level ``bson`` / ``_dt`` / ``Timestamp`` names
    rather than executing import statements on every call.
    """
    if value is None:
        return 2
    if isinstance(value, bson.MinKey):
        return 1
    # bool must be tested before int: ``bool`` is a subclass of ``int``.
    if isinstance(value, bool):
        return 9
    if isinstance(value, (int, float, Decimal128)):
        return 3
    if isinstance(value, str):
        return 4
    if isinstance(value, Mapping):
        return 5
    if isinstance(value, list):
        return 6
    if isinstance(value, (bytes, bson.Binary)):
        return 7
    if isinstance(value, bson.ObjectId):
        return 8
    if isinstance(value, _dt.datetime):
        return 10
    if isinstance(value, Timestamp):
        return 11
    if isinstance(value, bson.Regex):
        return 12
    if isinstance(value, bson.MaxKey):
        return 13
    return 5
class _SortKey:
    """Sort-key wrapper that orders values by BSON rules via ``_bson_lt``.

    ``sort_docs`` builds one tuple of these per document. Python compares
    tuples by first finding the leftmost pair that is not ``==`` and only
    then applying ``<`` to that pair, so ``__eq__`` has to agree with the
    ordering. Plain Python ``==`` does not:

    * ``1 == True`` is true although numbers sort before booleans, so a
      mixed int/bool column was treated as tied and the decision fell
      through to later sort fields.
    * ``Decimal128("1") == 1`` is false although the two order as equal,
      so the comparison stopped on that column and later sort fields were
      never consulted.

    Two keys are therefore equal exactly when neither sorts before the
    other. Defining ``__eq__`` leaves the class unhashable, as before.
    """

    __slots__ = ("val", "_reverse")

    def __init__(self, val: Any, reverse: bool = False) -> None:
        self.val = val
        self._reverse = reverse

    def __lt__(self, other: _SortKey) -> bool:
        # Swap operands when this key is descending — the same comparison
        # logic then yields the correct order for desc fields, and the
        # equal-keys case still returns False on both sides (stable sort
        # preserves doc order). Both sides of the comparison must agree on
        # direction (they're in the same column), which our caller
        # guarantees.
        if self._reverse:
            a, b = other.val, self.val
        else:
            a, b = self.val, other.val
        return _bson_lt(a, b)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _SortKey):
            return False
        # Equal means "ties under the sort order"; direction is irrelevant
        # because a tie is symmetric.
        return not _bson_lt(self.val, other.val) and not _bson_lt(other.val, self.val)
def _bson_lt(a: Any, b: Any) -> bool:
    """Return True when ``a`` sorts strictly before ``b`` in BSON order.

    Covers the four cases ``__lt__`` used to inline: cross-type rank,
    Decimal128 widening, native ``<``, and the embedded-document /
    array recursion — mongo-node-driver's
    ``Aggregation ... pipeline using array`` test sorts grouped docs
    by an embedded ``_id`` field and the previous inline ``a < b``
    raised ``TypeError`` on Python's dicts.
    """
    rank_a, rank_b = _bson_type_rank(a), _bson_type_rank(b)
    if rank_a != rank_b:
        return rank_a < rank_b
    if a is None or b is None:
        return False

    if any(isinstance(side, Decimal128) for side in (a, b)):
        try:
            return bool(_to_decimal(a) < _to_decimal(b))
        except (InvalidOperation, ValueError):
            # Not comparable as decimals; fall through to the generic paths.
            pass

    # Embedded documents: compare field-by-field in insertion order,
    # first differing pair wins. Real BSON sort recurses; Python's dict
    # ``<`` raises ``TypeError`` so without this branch sort would be
    # a no-op on grouped ``_id`` keys.
    if isinstance(a, Mapping) and isinstance(b, Mapping):
        left, right = list(a.items()), list(b.items())
        for (left_key, left_val), (right_key, right_val) in zip(left, right):
            if left_key != right_key:
                return left_key < right_key
            if _bson_lt(left_val, right_val):
                return True
            if _bson_lt(right_val, left_val):
                return False
        # Shared prefix is identical: the shorter document sorts first.
        return len(left) < len(right)

    # Arrays: lexicographic, element-by-element. Same TypeError trap
    # as the dict case for arrays-of-mixed-types.
    if isinstance(a, list) and isinstance(b, list):
        for left_val, right_val in zip(a, b):
            if _bson_lt(left_val, right_val):
                return True
            if _bson_lt(right_val, left_val):
                return False
        return len(a) < len(b)

    try:
        return bool(a < b)
    except TypeError:
        # Same rank but not natively comparable: order by type name.
        return type(a).__name__ < type(b).__name__
def sort_docs(
    docs: list[dict[str, Any]], sort_spec: Mapping[str, Any] | None
) -> list[dict[str, Any]]:
    """Sort ``docs`` by a MongoDB-style ``{field: direction}`` spec.

    With an empty or ``None`` spec the input list itself is returned
    unchanged. Otherwise a new list is returned; a direction of ``-1``
    sorts that field descending and anything else ascending.
    """
    if not sort_spec:
        return docs

    directions = [(path, int(direction) == -1) for path, direction in sort_spec.items()]

    def composite_key(doc: dict[str, Any]) -> tuple[_SortKey, ...]:
        # Single sort over a precomputed tuple key rather than N stable
        # passes: one pass through Timsort, get_path called once per field
        # per doc.
        return tuple(
            _SortKey(get_path(doc, path), reverse=descending)
            for path, descending in directions
        )

    return sorted(docs, key=composite_key)
# Index name ``"_id_"`` — presumably the name of the implicit ``_id`` index;
# its use sites are outside this part of the file.
_ID_INDEX_NAME = "_id_"
class IndexConflict(Exception):
    """Duplicate-key (E11000) violation on an index.

    Carries the index name and the ``_id`` of the offending document,
    plus — when the raise-site knows them — the index's key pattern and
    the conflicting key values.
    """

    def __init__(
        self,
        index_name: str,
        doc_id: Any,
        *,
        key_pattern: dict[str, Any] | None = None,
        key_value: dict[str, Any] | None = None,
    ) -> None:
        message = f"E11000 duplicate key error in index {index_name}: _id={doc_id!r}"
        super().__init__(message)
        self.index_name, self.doc_id = index_name, doc_id
        # Real mongod returns ``keyPattern`` (the index spec) and
        # ``keyValue`` (the conflicting field values) in the dup-key
        # error response. Drivers expose them as ``errorResponse``
        # fields; mongo-java-driver's ``findOneAndUpdate-errorResponse``
        # asserts both. Optional because legacy raise-sites
        # (``_id`` collision before index machinery, recovery paths)
        # don't have the index spec handy.
        self.key_pattern, self.key_value = key_pattern, key_value
class DocumentValidationError(Exception):
    """A write produced a doc that didn't satisfy the collection's ``validator``.

    Caught at the command layer and surfaced as the mongod-shaped
    writeError (code 121, ``DocumentValidationFailure``) with the
    ``errInfo.failingDocumentId`` field drivers' errorResponse tests
    assert on. The failing document's id is kept on ``doc_id``.
    """

    def __init__(self, doc_id: Any) -> None:
        self.doc_id = doc_id
        super().__init__("Document failed validation")
class CreateIndexUnsupported(Exception):
    """``create_index`` was given an index type SecantusDB doesn't support.

    Currently that means ``text`` / ``hashed``. Caught at the command
    layer and surfaced as a typed wire error rather than letting the
    cell-encoder later trip over an opaque internal exception.
    """
class IndexOptionsConflict(Exception):
    """``create_index`` reused an existing index name with conflicting options.

    Raised when the name already exists in the collection but the request
    differs in ``unique`` / ``sparse`` / ``hidden`` /
    ``expireAfterSeconds`` / ``partialFilterExpression``. Real mongod
    rejects with ``IndexOptionsConflict`` (code 85); drivers
    (mongo-ruby-driver's ``Collection#create_indexes`` specs) assert on
    the rejection.
    """
class GeoExtractError(Exception):
    """Doc's geo field can't be indexed — bad shape or out-of-bounds coords.

    Raised from the geo-index write path when an insert / update / index
    creation hits a doc the geo extractor can't make sense of (bad
    GeoJSON, non-numeric coordinates, longitude / latitude outside the
    valid range, etc.). Caught at the command-layer write boundary and
    surfaced as a wire-level write error with mongod's documented code
    16572 ("Can't extract geo keys").

    The index name, field, document id and reason are all kept as
    attributes of the same names.
    """

    def __init__(self, index_name: str, field: str, doc_id: Any, reason: str) -> None:
        message = f"Can't extract geo keys for index {index_name!r} on field {field!r}: {reason}"
        super().__init__(message)
        self.index_name, self.field = index_name, field
        self.doc_id, self.reason = doc_id, reason
class BadHint(Exception):
    """The ``hint`` passed to ``find_matching`` doesn't name an existing index."""
[docs]
class Storage:
def __init__(
    self,
    path: str = ":memory:",
    *,
    oplog_retention_seconds: float = 3600.0,
    oplog_max_entries: int = 100_000,
    time_func: Callable[[], float] | None = None,
    enable_oplog: bool = True,
    ttl_sweep_seconds: float = 60.0,
    noop_heartbeat_seconds: float = 0.0,
) -> None:
    """Open (or create) the WiredTiger database and start background threads.

    ``path`` is the WiredTiger home directory, created if missing.
    The special value ``":memory:"`` opens an ``in_memory=true``
    connection rooted in a fresh temporary directory instead.

    ``oplog_retention_seconds`` and ``oplog_max_entries`` are stored on
    the instance as the oplog limits. ``time_func`` is the clock used
    when minting timestamps and defaults to ``time.time``.
    ``enable_oplog=False`` makes ``_emit_oplog`` a no-op.
    ``ttl_sweep_seconds`` is the TTL sweeper interval (``<= 0`` disables
    the thread). ``noop_heartbeat_seconds`` is the noop-heartbeat
    interval; the thread only starts when it is ``> 0`` and the oplog is
    enabled.

    If ``wiredtiger_open`` fails for an in-memory store, the temporary
    directory created for it is removed before the error propagates.
    """
    # When False, _emit_oplog short-circuits and writes nothing —
    # used in standalone (non-replica-set) mode to skip the per-write
    # BSON encode + WT cursor write cost of oplog entries that no
    # change-stream client will ever read. The oplog WT tables are
    # still created so toggling at runtime stays safe.
    self.enable_oplog = enable_oplog
    self._lock = threading.RLock()
    self._closed = False
    self._tempdir: str | None = None
    # session_max default is ~120; each client connection thread
    # caches its own session in `threading.local()`, and cross-
    # thread oplog readers open additional short-lived sessions on
    # demand. With a few dozen concurrent client connections plus
    # active change-stream tailers, the default ceiling is hit
    # mid-handshake and surfaces as `out of sessions` /
    # WT_ERROR. mongod itself runs with session_max=33000 — 1000
    # is a generous floor for a single-node test surrogate while
    # still well under the WT hard limit.
    # cache_size default is 100 MB. With ``in_memory=true`` every
    # write also lives in cache, so a workload that inserts a
    # handful of 16 MB documents (mongod's per-doc max) blows the
    # cap as ``WT_CACHE_FULL: operation would overflow cache``.
    # 1 GB gives generous headroom for tests + reasonable
    # in-process workloads while staying well under the limits
    # ``mongod`` itself runs with on a normal box.
    # Tracked so ``checkpoint()`` calls are skipped in in-memory
    # mode (WT's in_memory backend rejects them with a noisy
    # ``__wt_inmem_unsupported_op`` log line on every call).
    self._in_memory = path == ":memory:"
    if path == ":memory:":
        self._tempdir = tempfile.mkdtemp(prefix="secantus_wt_")
        home = self._tempdir
        # in_memory=true disables the journal entirely (no files);
        # ephemeral by definition, so durability isn't a concern.
        config = "create,in_memory=true,session_max=1000,cache_size=1G"
    else:
        os.makedirs(path, exist_ok=True)
        home = path
        # ``log=(enabled=true)`` turns on WT's redo journal: every
        # transaction commit writes a log record before it returns,
        # and recovery replays the log on reopen. Without this,
        # WT's only durability mechanism is checkpoints (default
        # cadence: every 60s, or on clean ``WT_CONNECTION->close``).
        # On SIGKILL between checkpoints, every uncommitted write
        # is lost — which is exactly the failure mode observed by
        # ``bench/chaos.py`` (3-min chaos run, 17 SIGKILLs:
        # 432,881 acked / 1 persisted).
        #
        # ``transaction_sync=(enabled=false,method=fsync)`` is
        # WT's default once logging is on: log records are written
        # to the OS but not fsynced per commit. Result: SIGKILL
        # of the process is durable (the OS still flushes its
        # page cache), only true power-loss between commits and
        # the next OS flush window can lose data. That matches
        # mongod's default ``writeConcern: {w:1, j:false}``.
        # Bumping to ``method=fsync,enabled=true`` would honour
        # ``j:true`` durability but we don't yet plumb that
        # write-concern down — backlog item.
        #
        # ``file_max=10MB`` bounds journal segment size; smaller
        # files churn the log more, larger files delay reclamation.
        # 10 MB matches mongod's WT default.
        config = (
            "create,session_max=1000,cache_size=1G,"
            "log=(enabled=true,file_max=10MB),"
            "transaction_sync=(enabled=false,method=fsync)"
        )
    try:
        self._conn = wt.wiredtiger_open(home, config)
    except Exception:
        # mkdtemp() leaves deletion to its caller, and no Storage object
        # survives a failed constructor to do it later, so remove the
        # temporary home here. A user-supplied ``path`` is never touched.
        if self._tempdir is not None:
            shutil.rmtree(self._tempdir, ignore_errors=True)
            self._tempdir = None
        raise
    self._tls = threading.local()
    self._all_sessions: list[Any] = []
    boot = self._conn.open_session()
    try:
        boot.create(_COLL_TABLE, "key_format=SS,value_format=u")
        boot.create(_DOC_TABLE, "key_format=SSu,value_format=u")
        boot.create(_IDX_TABLE, "key_format=SSS,value_format=u")
        boot.create(_IDX_ENTRIES_TABLE, "key_format=SSSu,value_format=u")
        boot.create(_OPLOG_TABLE, "key_format=q,value_format=u")
        boot.create(_PREIMAGE_TABLE, "key_format=q,value_format=u")
        boot.create(_OPLOG_META_TABLE, "key_format=S,value_format=u")
        boot.create(_USERS_TABLE, "key_format=SS,value_format=u")
        boot.create(_ROLES_TABLE, "key_format=SS,value_format=u")
        boot.create(_PROFILE_TABLE, "key_format=S,value_format=u")
    finally:
        boot.close()
    # Oplog state — durable across restart via _OPLOG_META_TABLE.
    self.oplog_retention_seconds = float(oplog_retention_seconds)
    self.oplog_max_entries = int(oplog_max_entries)
    self._time = time_func or _time.time
    self._oplog_cv = threading.Condition(threading.Lock())
    self._oplog_emit_count = 0
    # Tiny fine-grained lock for seq + timestamp minting. Held in
    # microseconds while reserving the next seq range and bumping
    # the cluster-time counter. Carved out of ``_lock`` (Phase 2.1
    # of the WT concurrency plan) so concurrent writers can mint
    # without contending on the global storage lock.
    self._oplog_seq_lock = threading.Lock()
    # Per-collection RLocks for the CRUD path (Phase 2.4 of the WT
    # concurrency plan). Writes to *different* collections can now
    # run in parallel; writes to the *same* collection still
    # serialise (preserves unique-index correctness + the pre-check
    # racing windows that would otherwise need an architectural
    # refactor of the index-entries schema). DDL operations also
    # acquire the per-coll lock(s) they affect so they cannot
    # reshape schema mid-CRUD-write.
    self._coll_locks: dict[tuple[str, str], threading.RLock] = {}
    self._coll_locks_mutex = threading.Lock()
    with self._lock:
        self._next_seq, self._last_ts_secs, self._last_ts_ord = self._load_oplog_meta()
    # TTL sweeper. Real mongod runs ``ttlMonitor`` every 60s by
    # default; we mirror that. ``ttl_sweep_seconds <= 0`` disables
    # the thread entirely (tests that drive expiry deterministically
    # via ``prune_ttl(now=...)`` use that escape hatch). The
    # sweeper walks every (db, coll) and calls ``prune_ttl`` on
    # each — collections with no TTL index short-circuit cheaply
    # at the index-scan step, so the steady-state cost is small.
    self._ttl_sweep_seconds = float(ttl_sweep_seconds)
    self._ttl_stop = threading.Event()
    self._ttl_thread: threading.Thread | None = None
    if self._ttl_sweep_seconds > 0:
        self._ttl_thread = threading.Thread(
            target=self._ttl_sweep_loop, name="secantus-ttl-sweeper", daemon=True
        )
        self._ttl_thread.start()
    # Periodic noop heartbeat. Real mongod writes ``{op: "n"}``
    # entries to the oplog every ~10s (configurable via
    # ``periodicNoopIntervalSecs``) so cluster time advances and
    # change-stream resume tokens minted from the oplog don't fall
    # outside the retention window during quiet stretches. Default
    # disabled (0) — embedded test users typically don't need it
    # and the extra writes would noise up tight oplog assertions.
    # Set ``noop_heartbeat_seconds=10`` (mongod default) for
    # production-ish behaviour. ``enable_oplog=False`` short-
    # circuits anyway, so the heartbeat is a no-op in that mode.
    self._noop_heartbeat_seconds = float(noop_heartbeat_seconds)
    self._noop_stop = threading.Event()
    self._noop_thread: threading.Thread | None = None
    if self._noop_heartbeat_seconds > 0 and self.enable_oplog:
        self._noop_thread = threading.Thread(
            target=self._noop_heartbeat_loop, name="secantus-noop-heartbeat", daemon=True
        )
        self._noop_thread.start()
def _load_oplog_meta(self) -> tuple[int, int, int]:
    """Recover ``(next_seq, last_ts_secs, last_ts_ord)`` on startup.

    The persisted ``"state"`` row in the oplog meta table is preferred.
    When it is missing or empty, the values are rebuilt by walking the
    oplog table: ``next_seq`` is one past the highest seq seen, and the
    timestamp parts come from the last entry that carries a
    ``Timestamp`` under ``"ts"``.
    """
    meta = self._cursor(_OPLOG_META_TABLE)
    meta.set_key("state")
    if meta.search() == 0:
        raw = bytes(meta.get_value())
        if raw:
            saved = bson.decode(raw)
            next_seq = int(saved.get("next_seq", 1))
            ts_secs = int(saved.get("last_ts_secs", 0))
            ts_ord = int(saved.get("last_ts_ord", 0))
            return next_seq, ts_secs, ts_ord

    # Fallback: scan oplog table for max key + reconstruct from entry.
    oplog = self._cursor(_OPLOG_TABLE)
    max_seq = ts_secs = ts_ord = 0
    while oplog.next() == 0:
        max_seq = max(max_seq, int(oplog.get_key()))
        raw = bytes(oplog.get_value())
        if not raw:
            continue
        ts = bson.decode(raw).get("ts")
        if isinstance(ts, Timestamp):
            ts_secs, ts_ord = ts.time, ts.inc
    return max_seq + 1, ts_secs, ts_ord
def _persist_oplog_meta(self) -> None:
    """Write the in-memory seq / timestamp counters to the ``"state"`` meta row."""
    snapshot = {
        "next_seq": self._next_seq,
        "last_ts_secs": self._last_ts_secs,
        "last_ts_ord": self._last_ts_ord,
    }
    cursor = self._cursor(_OPLOG_META_TABLE)
    cursor["state"] = bson.encode(snapshot)
def _mint_ts(self) -> Timestamp:
    """Mint the next strictly-monotonic ``Timestamp(secs, ord)``.

    Caller must hold ``self._oplog_seq_lock``. When the clock has moved
    past the last recorded second, that second is adopted and ``ord``
    restarts at 1; otherwise (same second, or the clock went backwards)
    the recorded second is kept and ``ord`` increments. Recovered state
    on startup ensures the first mint after restart is strictly greater
    than any previously-emitted timestamp.
    """
    wall_secs = int(self._time())
    if wall_secs <= self._last_ts_secs:
        self._last_ts_ord += 1
    else:
        self._last_ts_secs, self._last_ts_ord = wall_secs, 1
    return Timestamp(self._last_ts_secs, self._last_ts_ord)
def _coll_lock(self, db: str, coll: str) -> threading.RLock:
"""Return the per-collection RLock for ``(db, coll)``, creating it
on first reference. Phase 2.4 of the WT concurrency plan.
CRUD on a given collection serialises through this lock; CRUD on
*other* collections proceeds in parallel. DDL on this collection
also acquires this lock so schema changes cannot interleave with
in-flight writes.
"""
key = (db, coll)
# Fast path: lock already exists — read without any mutation,
# safe under GIL.
existing = self._coll_locks.get(key)
if existing is not None:
return existing
# Create-or-fetch under the small registry mutex. RLocks are
# never removed (collections come and go but the lock identity
# for a given (db, coll) stays stable across drop+recreate to
# avoid races with in-flight writers).
with self._coll_locks_mutex:
existing = self._coll_locks.get(key)
if existing is not None:
return existing
lock = threading.RLock()
self._coll_locks[key] = lock
return lock
def _mint_oplog_seq_and_ts(self, n: int) -> tuple[int, list[Timestamp]]:
"""Atomically reserve ``n`` consecutive oplog seq numbers and mint
``n`` strictly-monotonic timestamps. Returns ``(start_seq,
[ts_0, ..., ts_{n-1}])``.
Held only under ``_oplog_seq_lock`` (microseconds of work) — the
actual oplog cursor writes happen in the caller's WT session
without blocking other writers on this lock.
"""
with self._oplog_seq_lock:
start = self._next_seq
self._next_seq += n
timestamps = [self._mint_ts() for _ in range(n)]
return start, timestamps
def _collection_uuid(self, db: str, coll: str) -> _uuid.UUID:
    """Return the collection's UUID, minting and persisting on first call.

    Fast path (UUID already present): no Python lock — straight WT
    cursor read on the calling thread's session. This was a major
    per-insert bottleneck before Phase 2.4: every write re-acquired
    ``self._lock`` here, defeating the per-collection lock split.

    Slow path (mint a new UUID): take ``_coll_lock`` for the
    namespace to serialise the persist; double-check inside the
    lock so two racing callers can't mint different UUIDs for the
    same collection.
    """

    def as_uuid(raw: Any) -> _uuid.UUID | None:
        # The stored value may already be a UUID, or a 16-byte Binary /
        # bytes blob holding one. Anything else counts as "not set".
        if isinstance(raw, _uuid.UUID):
            return raw
        if isinstance(raw, (bson.Binary, bytes)) and len(raw) == 16:
            return _uuid.UUID(bytes=bytes(raw))
        return None

    stored = as_uuid((self._coll_options(db, coll) or {}).get("uuid"))
    if stored is not None:
        return stored

    # Mint path — take the per-coll lock; re-read after acquiring
    # so a racer that won the mint race is observed.
    with self._coll_lock(db, coll):
        opts = self._coll_options(db, coll) or {}
        stored = as_uuid(opts.get("uuid"))
        if stored is not None:
            return stored
        minted = _uuid.uuid4()
        opts["uuid"] = minted
        self._write_coll_options(db, coll, opts)
        return minted
def collection_uuid(self, db: str, coll: str) -> _uuid.UUID:
    """Public alias for ``_collection_uuid``.

    Returns whatever ``_collection_uuid(db, coll)`` returns: the
    collection's UUID, minted and persisted on first call.
    """
    return self._collection_uuid(db, coll)
def current_cluster_time(self) -> Timestamp:
    """Return a strictly-monotonic ``Timestamp`` advancing the cluster clock.

    Every call mints a new timestamp (so the clock moves even without a
    write) and then persists the updated counters to the oplog meta row.
    """
    with self._oplog_seq_lock:
        ts = self._mint_ts()
    # Meta persist uses the calling thread's WT session/cursor —
    # safe to do outside the seq lock since it doesn't depend on
    # the in-memory counters being held stable past the mint.
    self._persist_oplog_meta()
    return ts
def _write_coll_options(self, db: str, coll: str, opts: Mapping[str, Any]) -> None:
    """Persist ``opts`` as the options blob for ``(db, coll)``.

    ``uuid.UUID`` values are stored as ``Binary`` subtype 4 because bson
    can't encode a ``uuid.UUID`` directly without a codec. An empty
    options mapping is stored as an empty byte string.
    """
    storable: dict[str, Any] = {
        name: bson.Binary(value.bytes, subtype=4) if isinstance(value, _uuid.UUID) else value
        for name, value in opts.items()
    }
    blob = bson.encode(storable) if storable else b""
    cursor = self._cursor(_COLL_TABLE)
    cursor[db, coll] = blob
def set_collection_options(self, db: str, coll: str, **opts: Any) -> None:
    """Merge ``opts`` into the collection's options blob (creates if absent).

    Runs under ``self._lock``: ensures the collection exists, reads the
    current options, overlays ``opts`` (new values win on key clashes)
    and writes the merged result back.
    """
    with self._lock:
        self._ensure_collection(db, coll)
        current = self._coll_options(db, coll) or {}
        current.update(opts)
        self._write_coll_options(db, coll, current)
def get_collection_options(self, db: str, coll: str) -> dict[str, Any]:
    """Return the collection's options blob, or ``{}`` if absent.

    A stored 16-byte ``Binary`` under ``"uuid"`` is returned as a
    ``uuid.UUID``; every other value is passed through unchanged. The
    synthetic ``local.oplog.rs`` namespace reports a fixed
    capped-collection shape instead of reading storage.
    """
    is_oplog_view = self.enable_oplog and db == "local" and coll == "oplog.rs"
    if is_oplog_view:
        # Synthetic ``local.oplog.rs``: report the capped-collection
        # shape mongod uses so $collStats / listCollections options
        # match. ``size`` is a notional byte cap derived from the
        # entry cap × a conservative per-entry estimate; we don't
        # track real byte usage, only entry count.
        entry_cap = self.oplog_max_entries
        return {"capped": True, "size": entry_cap * 16 * 1024, "max": entry_cap}

    self._refresh_read_snapshot()
    with self._lock:
        stored = self._coll_options(db, coll) or {}

    def decode(name: str, value: Any) -> Any:
        # Decode UUID Binary back into uuid.UUID for callers.
        if name == "uuid" and isinstance(value, bson.Binary) and len(value) == 16:
            return _uuid.UUID(bytes=bytes(value))
        return value

    return {name: decode(name, value) for name, value in stored.items()}
def _is_oplog_rs(self, db: str, coll: str) -> bool:
"""``(local, oplog.rs)`` is the synthetic oplog view."""
return self.enable_oplog and db == "local" and coll == "oplog.rs"
def _scan_oplog_entries(self) -> list[dict[str, Any]]:
    """Decode and return every persisted oplog entry, in table order.

    Uses a private short-lived session so the read view always
    reflects rows committed by writer threads on other connections
    (same pattern as ``read_oplog``). Rows with an empty value are
    skipped. Errors while closing the cursor or session are suppressed.
    """
    decoded: list[dict[str, Any]] = []
    with self._lock:
        session = self._conn.open_session()
        try:
            cursor = session.open_cursor(_OPLOG_TABLE, None)
            try:
                while cursor.next() == 0:
                    raw = bytes(cursor.get_value())
                    if raw:
                        decoded.append(bson.decode(raw))
            finally:
                with contextlib.suppress(Exception):
                    cursor.close()
        finally:
            with contextlib.suppress(Exception):
                session.close()
    return decoded
def _find_oplog_rs(
    self,
    filter: dict[str, Any] | None,
    *,
    skip: int,
    limit: int,
    sort: Mapping[str, Any] | None,
    projection: Mapping[str, Any] | None,
    let: dict[str, Any] | None,
    collation: Any,
) -> list[dict[str, Any]]:
    """Read path for the synthetic ``local.oplog.rs`` view.

    Entries are walked in seq order (== ts order), then filter, sort,
    skip, limit and projection are applied in that order against the
    decoded entry docs via the existing pure-Python helpers. A
    ``limit`` of zero or less means "no limit".
    """
    from secantus.collation import parse as _parse_collation

    parsed_collation = _parse_collation(collation)
    entries = self._scan_oplog_entries()
    if filter:
        entries = [
            entry
            for entry in entries
            if matches(entry, filter, vars=let, collation=parsed_collation)
        ]
    if sort:
        entries = sort_docs(entries, sort)
    if skip:
        entries = entries[skip:]
    if limit > 0:
        entries = entries[:limit]
    if not projection:
        return entries
    return [apply_projection(entry, projection) for entry in entries]
def _count_oplog_rs(
    self,
    filter: dict[str, Any] | None,
    *,
    let: dict[str, Any] | None,
    collation: Any,
) -> int:
    """Count oplog entries matching ``filter`` (all entries when it is empty)."""
    from secantus.collation import parse as _parse_collation

    parsed_collation = _parse_collation(collation)
    entries = self._scan_oplog_entries()
    if not filter:
        return len(entries)
    hits = 0
    for entry in entries:
        if matches(entry, filter, vars=let, collation=parsed_collation):
            hits += 1
    return hits
def _emit_oplog(
self,
entries: list[dict[str, Any]],
pre_images: list[bytes | None] | None = None,
) -> int:
"""Append ``entries`` to the oplog table under ``self._lock``.
``pre_images`` is parallel to ``entries``; non-None elements are
stored under the matching seq in ``_PREIMAGE_TABLE``. Returns the
highest seq emitted (0 if ``entries`` is empty). Notifies waiters
on ``self._oplog_cv`` once writes have committed.
If ``self.enable_oplog`` is False, returns 0 immediately — the
caller's prebuilt ``entries`` list is discarded. The change-stream
condvar is still notified so any tailable getMore wakes up and
observes the (empty) state.
"""
if not self.enable_oplog:
with self._oplog_cv:
self._oplog_cv.notify_all()
return 0
if not entries:
return 0
if pre_images is None:
pre_images = [None] * len(entries)
assert len(pre_images) == len(entries)
# Reserve seq + ts range up-front under the tiny seq lock.
# The actual cursor writes below run on this thread's WT
# session without holding any cross-thread Python lock.
n = len(entries)
start_seq, ts_range = self._mint_oplog_seq_and_ts(n)
op_cur = self._cursor(_OPLOG_TABLE)
pre_cur = None
last_seq = 0
for i, (entry, pre) in enumerate(zip(entries, pre_images, strict=True)):
seq = start_seq + i
entry_with_ts = dict(entry)
if "ts" not in entry_with_ts:
entry_with_ts["ts"] = ts_range[i]
if "wall" not in entry_with_ts:
entry_with_ts["wall"] = _dt.datetime.now(_dt.timezone.utc)
op_cur[seq] = bson.encode(entry_with_ts)
if pre is not None:
if pre_cur is None:
pre_cur = self._cursor(_PREIMAGE_TABLE)
pre_cur[seq] = pre
last_seq = seq
# ``_persist_oplog_meta`` was called here on every emit, but
# under concurrent writers it WT-rollbacks half the time —
# every writer hits the same single ``"state"`` meta row.
# The meta row is purely a recovery optimisation; if it's
# stale, ``_load_oplog_meta``'s fallback scans the oplog
# table for the actual max seq. So we now persist only on
# close + on prune_oplog, both of which are rare. The seq
# mint itself is durable because the actual oplog rows are
# written on every emit.
self._oplog_emit_count += len(entries)
if self._oplog_emit_count >= _OPLOG_PRUNE_INTERVAL:
self._oplog_emit_count = 0
self._prune_oplog_locked(now=self._time())
with self._oplog_cv:
self._oplog_cv.notify_all()
return last_seq
def read_oplog(
    self,
    *,
    start_seq: int,
    limit: int,
    ns_filter: Callable[[str], bool] | None = None,
) -> list[tuple[int, dict[str, Any]]]:
    """Forward-scan the oplog from ``start_seq`` (inclusive).

    Returns up to ``limit`` ``(seq, entry)`` pairs. When ``ns_filter``
    is given, only entries whose ``ns`` (stringified, ``""`` when
    absent) satisfies it are returned. Rows with an empty value are
    skipped.

    Uses a private short-lived session so the read view always reflects
    rows committed by other sessions. The cached per-thread session's
    snapshot is sticky — under WiredTiger's MVCC, reusing it across
    getMore polls would never observe oplog rows produced by a writer
    running on a different connection thread.
    """
    collected: list[tuple[int, dict[str, Any]]] = []
    with self._lock:
        private = self._conn.open_session()
        try:
            cur = private.open_cursor(_OPLOG_TABLE, None)
            try:
                cur.set_key(int(start_seq))
                landed = cur.search_near()
                if landed == wt.WT_NOTFOUND:
                    return collected
                # A negative result means the cursor sits on a smaller
                # key; step forward once to reach the first seq >= start.
                if landed < 0 and cur.next() != 0:
                    return collected
                positioned = True
                while positioned:
                    seq = int(cur.get_key())
                    raw = bytes(cur.get_value())
                    if raw:
                        entry = bson.decode(raw)
                        if ns_filter is None or ns_filter(str(entry.get("ns", ""))):
                            collected.append((seq, entry))
                            if len(collected) >= limit:
                                break
                    positioned = cur.next() == 0
            finally:
                with contextlib.suppress(Exception):
                    cur.close()
        finally:
            with contextlib.suppress(Exception):
                private.close()
    return collected
def read_preimage(self, seq: int) -> dict[str, Any] | None:
    """Look up the pre-image stored under ``seq``.

    Returns the decoded document, or ``None`` when no row exists or the
    stored value is empty. Uses a private session for cross-thread
    visibility (see ``read_oplog``).
    """
    with self._lock:
        private = self._conn.open_session()
        try:
            cur = private.open_cursor(_PREIMAGE_TABLE, None)
            try:
                cur.set_key(int(seq))
                found = cur.search() == 0
                raw = bytes(cur.get_value()) if found else b""
                return bson.decode(raw) if raw else None
            finally:
                with contextlib.suppress(Exception):
                    cur.close()
        finally:
            with contextlib.suppress(Exception):
                private.close()
def oplog_tail_seq(self) -> int:
    """Return ``self._next_seq - 1``, read while holding ``self._lock``.

    That is the highest seq handed out so far; 0 when ``_next_seq`` is
    still 1.
    """
    with self._lock:
        upcoming = self._next_seq
        return upcoming - 1
def oplog_tail_seq_nolock(self) -> int:
    """Return the highest seq without touching ``self._lock``.

    Intended **only** as the wake predicate for a tailable ``getMore``
    blocked on ``self._oplog_cv``. The write path acquires ``_lock``
    and then ``_oplog_cv``; a waiter inside ``cv.wait_for`` already
    owns ``_oplog_cv``, so grabbing ``_lock`` there would be an ABBA
    deadlock against any concurrent writer.

    The unlocked read of ``_next_seq`` is acceptable because ``int``
    reads are atomic under the GIL, and because the cv is notified on
    every commit — a momentarily stale value is corrected the next
    time the ``wait_for`` predicate runs.
    """
    upcoming = self._next_seq
    return upcoming - 1
def oplog_floor_seq(self) -> int:
    """Return the key of the first row in the oplog table, or 0 if it is empty.

    Uses a private session for cross-thread visibility.
    """
    with self._lock:
        private = self._conn.open_session()
        try:
            cur = private.open_cursor(_OPLOG_TABLE, None)
            try:
                # The first ``next()`` on a fresh cursor lands on the
                # smallest key, if there is one.
                if cur.next() == 0:
                    return int(cur.get_key())
                return 0
            finally:
                with contextlib.suppress(Exception):
                    cur.close()
        finally:
            with contextlib.suppress(Exception):
                private.close()
def find_seq_for_ts(self, ts: Timestamp) -> int:
    """Return the smallest seq whose entry has ``ts >= target``.

    Entries are compared on ``(time, inc)``. Rows with an empty value
    or a non-``Timestamp`` ``ts`` are ignored. When nothing qualifies
    the result is ``self._next_seq`` (tail + 1).

    Uses a private session for cross-thread visibility.
    """
    target = (ts.time, ts.inc)
    with self._lock:
        private = self._conn.open_session()
        try:
            cur = private.open_cursor(_OPLOG_TABLE, None)
            try:
                while cur.next() == 0:
                    seq = int(cur.get_key())
                    raw = bytes(cur.get_value())
                    if not raw:
                        continue
                    stamp = bson.decode(raw).get("ts")
                    if not isinstance(stamp, Timestamp):
                        continue
                    if (stamp.time, stamp.inc) >= target:
                        return seq
                return self._next_seq
            finally:
                with contextlib.suppress(Exception):
                    cur.close()
        finally:
            with contextlib.suppress(Exception):
                private.close()
def prune_oplog(self, *, now: float | None = None) -> int:
"""Drop oplog rows older than retention or above the entry cap."""
with self._lock:
return self._prune_oplog_locked(now=now)
def _ns(self, db: str, coll: str) -> str:
return f"{db}.{coll}"
def _pre_post_images_enabled(self, db: str, coll: str) -> bool:
opts = self._coll_options(db, coll) or {}
cfg = opts.get("changeStreamPreAndPostImages")
return isinstance(cfg, Mapping) and bool(cfg.get("enabled"))
def _prune_oplog_locked(self, *, now: float | None = None) -> int:
    """Delete oplog rows past retention or beyond the entry cap.

    A row is expired when its ``ts.time`` is older than
    ``now - self.oplog_retention_seconds`` (``now`` defaults to
    ``self._time()``). If more than ``self.oplog_max_entries`` rows
    would still remain, the surplus is taken from the front of the
    table in scan order. Each removed seq is also removed from the
    pre-image table. Returns the number of seqs selected for removal.
    """
    reference = self._time() if now is None else now
    cutoff_secs = int(reference - self.oplog_retention_seconds)

    # Phase 1: scan and pick victims. Nothing is deleted while the
    # scan is in progress.
    victims: list[int] = []
    every_seq: list[int] = []
    scan = self._cursor(_OPLOG_TABLE)
    while scan.next() == 0:
        seq = int(scan.get_key())
        raw = bytes(scan.get_value())
        every_seq.append(seq)
        if not raw:
            continue
        stamp = bson.decode(raw).get("ts")
        if isinstance(stamp, Timestamp) and stamp.time < cutoff_secs:
            victims.append(seq)

    # Enforce the entry cap by also dooming the oldest survivors.
    surplus = (len(every_seq) - len(victims)) - self.oplog_max_entries
    if surplus > 0:
        chosen = set(victims)
        for seq in every_seq:
            if surplus <= 0:
                break
            if seq in chosen:
                continue
            victims.append(seq)
            chosen.add(seq)
            surplus -= 1

    if not victims:
        return 0

    # Phase 2: delete. A failed remove (e.g. no pre-image stored for
    # that seq) is ignored.
    op_del = self._cursor(_OPLOG_TABLE)
    pre_del = self._cursor(_PREIMAGE_TABLE)
    for seq in victims:
        for remover in (op_del, pre_del):
            remover.set_key(seq)
            with contextlib.suppress(wt.WiredTigerError):
                remover.remove()
            remover.reset()
    return len(victims)
# --- Users (auth) ---
def add_user(
    self,
    db: str,
    username: str,
    record: Mapping[str, Any],
    *,
    replace: bool = False,
) -> bool:
    """Store a user record keyed by ``(db, username)``.

    ``record`` is copied into a plain dict and BSON-encoded as-is; its
    shape is owned by the commands layer.

    Returns True once the record has been written. Returns False —
    without writing — when the user already exists and ``replace`` is
    false.
    """
    with self._lock:
        cur = self._cursor(_USERS_TABLE)
        cur.set_key(db, username)
        already_present = cur.search() == 0
        if already_present and not replace:
            return False
        cur.reset()
        cur[db, username] = bson.encode(dict(record))
        return True
def get_user(self, db: str, username: str) -> dict[str, Any] | None:
    """Fetch the record stored for ``(db, username)``.

    Returns ``None`` when no row exists or the stored value is empty.
    Reads through the calling thread's cached cursor.
    """
    with self._lock:
        cur = self._cursor(_USERS_TABLE)
        cur.set_key(db, username)
        if cur.search() == 0:
            raw = bytes(cur.get_value())
            if raw:
                return bson.decode(raw)
        return None
def drop_user(self, db: str, username: str) -> bool:
    """Delete the ``(db, username)`` record; True if one existed."""
    with self._lock:
        cur = self._cursor(_USERS_TABLE)
        cur.set_key(db, username)
        existed = cur.search() == 0
        if existed:
            cur.remove()
        return existed
def list_users(
    self,
    db: str | None = None,
    *,
    skip: int = 0,
    limit: int = 100,
) -> list[dict[str, Any]]:
    """Return one page of user records.

    ``db=None`` lists across every database; otherwise only rows whose
    key's first component equals ``db``. ``skip`` counts matching rows
    to pass over first. A ``limit`` outside ``1..1000`` is replaced by
    1000.
    """
    page_size = 1000 if (limit <= 0 or limit > 1000) else limit
    page: list[dict[str, Any]] = []
    with self._lock:
        cur = self._cursor(_USERS_TABLE)
        matched = 0  # rows that passed the db filter so far
        while cur.next() == 0:
            row_db = cur.get_key()[0]
            if db is not None and row_db != db:
                continue
            if matched >= skip:
                raw = bytes(cur.get_value())
                if raw:
                    page.append(bson.decode(raw))
                    if len(page) >= page_size:
                        break
            matched += 1
    return page
# ------------------------------------------------------------------
# Per-database profiling settings.
#
# Real mongod tracks (level, slowms, sampleRate) per database in
# memory + persists to the database's metadata. We persist in a
# dedicated WT table keyed by db name. The dispatch path reads
# these settings on every command — keep ``get_profile`` fast.
# ------------------------------------------------------------------
def get_profile(self, db: str) -> dict[str, Any]:
    """Return the profiling settings stored for ``db``.

    When nothing (or an empty value) is stored, the defaults are
    returned: level 0 (off), slowms 100, sampleRate 1.0. The same
    defaults fill in for any field that is absent or ``None`` in the
    stored document.
    """
    with self._lock:
        cur = self._cursor(_PROFILE_TABLE)
        cur.set_key(db)
        raw = bytes(cur.get_value()) if cur.search() == 0 else b""
        if not raw:
            return {"level": 0, "slowms": 100, "sampleRate": 1.0}
        stored = bson.decode(raw)
        # Deliberately test ``is None`` rather than truthiness:
        # slowms=0 and sampleRate=0.0 are legitimate settings that
        # must round-trip instead of collapsing to the defaults.
        settings: dict[str, Any] = {}
        for field, caster, fallback in (
            ("level", int, 0),
            ("slowms", int, 100),
            ("sampleRate", float, 1.0),
        ):
            value = stored.get(field, fallback)
            settings[field] = fallback if value is None else caster(value)
        return settings
def set_profile(
    self,
    db: str,
    *,
    level: int,
    slowms: int = 100,
    sample_rate: float = 1.0,
) -> None:
    """Validate and store the profiling settings for ``db``.

    Raises:
        ValueError: if ``level`` is not 0, 1 or 2, if ``slowms`` is
            negative, or if ``sample_rate`` lies outside ``[0, 1]``.
    """
    if level not in (0, 1, 2):
        raise ValueError("level must be 0, 1, or 2")
    if slowms < 0:
        raise ValueError("slowms must be non-negative")
    rate_in_range = 0.0 <= sample_rate <= 1.0
    if not rate_in_range:
        raise ValueError("sampleRate must be in [0, 1]")
    payload = {
        "level": int(level),
        "slowms": int(slowms),
        "sampleRate": float(sample_rate),
    }
    with self._lock:
        cur = self._cursor(_PROFILE_TABLE)
        cur[db] = bson.encode(payload)
def ensure_profile_collection(self, db: str, *, size_bytes: int = 10 * 1024 * 1024) -> None:
    """Create ``<db>.system.profile`` as a capped collection if it is missing.

    Does nothing when the collection already exists. ``size_bytes``
    (10 MB by default) becomes the capped ``size`` option.
    """
    profile_coll = "system.profile"
    if not self.collection_exists(db, profile_coll):
        self.create_collection(db, profile_coll)
        self.set_collection_options(db, profile_coll, capped=True, size=int(size_bytes))
# ------------------------------------------------------------------
# Custom roles. Storage layer is a thin BSON-blob CRUD; the commands
# layer owns the role-record shape (privileges + inherited roles)
# and ``secantus.rbac`` owns the privilege-check logic that walks
# the inheritance graph.
# ------------------------------------------------------------------
def add_role(
    self,
    db: str,
    name: str,
    record: Mapping[str, Any],
    *,
    replace: bool = False,
) -> bool:
    """Store a custom role record keyed by ``(db, name)``.

    Returns True once the record has been written. Returns False —
    without writing — when the role already exists and ``replace`` is
    false.
    """
    with self._lock:
        cur = self._cursor(_ROLES_TABLE)
        cur.set_key(db, name)
        already_present = cur.search() == 0
        if already_present and not replace:
            return False
        cur.reset()
        cur[db, name] = bson.encode(dict(record))
        return True
def get_role(self, db: str, name: str) -> dict[str, Any] | None:
    """Fetch the custom role stored for ``(db, name)``.

    Returns ``None`` when no row exists or the stored value is empty.
    """
    # A private short-lived session guarantees cross-thread
    # visibility: connection-thread A may have written a role while
    # we run on connection-thread B, whose cached session carries a
    # sticky snapshot that would not observe A's commit. Same pattern
    # as ``read_oplog``; one open_session + close per call is a
    # negligible price for that correctness.
    with self._lock:
        private = self._conn.open_session()
        try:
            cur = private.open_cursor(_ROLES_TABLE, None, None)
            try:
                cur.set_key(db, name)
                if cur.search() == 0:
                    raw = bytes(cur.get_value())
                    if raw:
                        return bson.decode(raw)
                return None
            finally:
                with contextlib.suppress(Exception):
                    cur.close()
        finally:
            with contextlib.suppress(Exception):
                private.close()
def drop_role(self, db: str, name: str) -> bool:
    """Delete the ``(db, name)`` role record; True if one existed."""
    with self._lock:
        cur = self._cursor(_ROLES_TABLE)
        cur.set_key(db, name)
        existed = cur.search() == 0
        if existed:
            cur.remove()
        return existed
def list_roles(
    self,
    db: str | None = None,
    *,
    skip: int = 0,
    limit: int = 100,
) -> list[dict[str, Any]]:
    """Return one page of custom-role records.

    ``db=None`` spans every database; otherwise only rows whose key's
    first component equals ``db``. ``skip`` counts matching rows to
    pass over first. A ``limit`` outside ``1..1000`` is replaced by
    1000.
    """
    page_size = 1000 if (limit <= 0 or limit > 1000) else limit
    page: list[dict[str, Any]] = []
    with self._lock:
        cur = self._cursor(_ROLES_TABLE)
        matched = 0  # rows that passed the db filter so far
        while cur.next() == 0:
            row_db = cur.get_key()[0]
            if db is not None and row_db != db:
                continue
            if matched >= skip:
                raw = bytes(cur.get_value())
                if raw:
                    page.append(bson.decode(raw))
                    if len(page) >= page_size:
                        break
            matched += 1
    return page
def close(self) -> None:
    """Shut the store down and release its WiredTiger resources.

    Steps, in order: signal and join the TTL sweeper and noop-heartbeat
    threads (waiting at most 2 s each); then, under ``self._lock``,
    mark the store closed, persist the oplog meta, checkpoint (skipped
    for in-memory stores), close every tracked session and the
    connection, and finally remove the temporary directory if one is
    recorded.

    A second call returns as soon as it sees ``self._closed`` set.
    Errors from the WiredTiger teardown calls are suppressed; a failed
    tempdir removal is reported as a ``ResourceWarning`` rather than
    raised.
    """
    # Stop background threads before tearing down WT — both the
    # TTL sweeper and the noop heartbeat acquire ``self._lock``,
    # so racing them against close would deadlock or
    # use-after-close.
    self._ttl_stop.set()
    if self._ttl_thread is not None and self._ttl_thread.is_alive():
        self._ttl_thread.join(timeout=2.0)
    self._ttl_thread = None
    self._noop_stop.set()
    if self._noop_thread is not None and self._noop_thread.is_alive():
        self._noop_thread.join(timeout=2.0)
    self._noop_thread = None
    with self._lock:
        if self._closed:
            return
        self._closed = True
        # Persist the oplog meta one last time. We dropped the
        # per-emit persist in Phase 2.4 (it caused WT-rollback
        # storms under concurrent writers), so this is the
        # canonical place to write the in-memory ``_next_seq``
        # and timestamp counters down to disk before shutdown.
        with contextlib.suppress(Exception):
            self._persist_oplog_meta()
        # Force a checkpoint before tearing the connection down.
        # ``WT_CONNECTION->close`` does this implicitly, but only
        # when logging is off (or hits the connection's
        # close-time flush window). Driving it explicitly here
        # gives a durable on-disk image of the dataset at the
        # moment of shutdown regardless of journal state — the
        # behaviour callers reasonably expect from ``close()``.
        # Skip for in-memory backends: WT's in_memory engine
        # rejects checkpoint() with a noisy stderr log
        # (``__wt_inmem_unsupported_op``) on every call.
        if not self._in_memory:
            with contextlib.suppress(Exception):
                self._session().checkpoint()
        # Close every session opened through ``_session()`` (on any
        # thread) before closing the connection itself.
        for s in self._all_sessions:
            with contextlib.suppress(Exception):
                s.close()
        self._all_sessions.clear()
        with contextlib.suppress(Exception):
            self._conn.close()
        if self._tempdir is not None:
            # Don't follow symlinks during cleanup. A local attacker
            # racing the mkdtemp could replace `_tempdir` with a
            # symlink to elsewhere on the filesystem before close()
            # fires — `shutil.rmtree(symlink, ignore_errors=True)`
            # would then delete the symlink target. The mkdtemp
            # already creates with mode 0700 (owner-only), but the
            # parent /tmp is world-writable, so this is the
            # belt-and-braces guard. Failures during cleanup are
            # surfaced as a warning but not raised — close() must
            # remain idempotent.
            try:
                if not os.path.islink(self._tempdir):
                    shutil.rmtree(self._tempdir)
            except OSError:
                # Best-effort: report via warnings rather than crash close().
                import warnings as _warn
                _warn.warn(
                    f"failed to remove WiredTiger tempdir {self._tempdir!r}",
                    ResourceWarning,
                    stacklevel=2,
                )
            # Forget the path whether or not removal succeeded.
            self._tempdir = None
def prune_ttl_all_collections(self, *, now: _dt.datetime | None = None) -> int:
    """Run :meth:`prune_ttl` over every collection and sum the results.

    Used by the background sweeper and exposed publicly so callers
    (admin tooling, tests) can drive a deterministic global pass.

    Callers on the cached per-thread session must call
    :meth:`_reset_thread_session` first — WiredTiger snapshots are
    sticky per session, so reads otherwise miss rows committed by
    other threads. The sweeper does this on every iteration; one-shot
    user calls happen on the writer's thread and see their own writes.
    """
    # Snapshot the namespace list under the lock, then prune each
    # namespace afterwards.
    with self._lock:
        cur = self._cursor(_COLL_TABLE)
        targets: list[tuple[str, str]] = []
        while cur.next() == 0:
            key = cur.get_key()
            targets.append((key[0], key[1]))
    pruned = 0
    for db_name, coll_name in targets:
        # A collection dropped between the snapshot and the prune
        # makes prune_ttl fail with a missing-collection error; the
        # sweeper must never crash the daemon, so skip and move on.
        with contextlib.suppress(Exception):
            pruned += self.prune_ttl(db_name, coll_name, now=now)
    return pruned
def _ttl_sweep_loop(self) -> None:
"""Background sweeper: every ``ttl_sweep_seconds`` walk all
collections and prune expired docs. Stops when ``_ttl_stop``
is set or the storage is closed.
Drops the per-thread WT session before each iteration so the
next cursor call opens a fresh session. WiredTiger sessions
carry a sticky read snapshot — without the reset, reads on
this thread would never observe rows committed by other
writers, and TTL sweeps would always return 0 even when
expired docs existed. Same pattern as ``read_oplog``.
"""
import logging
log = logging.getLogger("secantus.storage.ttl")
while not self._ttl_stop.wait(self._ttl_sweep_seconds):
if self._closed:
return
self._reset_thread_session()
try:
self.prune_ttl_all_collections()
except Exception:
# Sweeper failures must not propagate — they'd kill
# the daemon thread and silently disable expiry.
log.exception("ttl sweep failed")
def emit_noop_heartbeat(self) -> int:
    """Append one ``{op: "n"}`` heartbeat to the oplog and return its seq.

    The entry mirrors mongod's periodic noop: ``op = "n"``, an empty
    namespace and a small ``o = {msg: "periodic noop"}`` payload; no
    ``ts`` is set here, so ``_emit_oplog`` supplies it. Change-stream
    consumers skip ``op: "n"`` rows in projection but still advance
    their ``position_seq`` and ``last_token`` past them, so the resume
    token of a quiet collection stays current.

    Public so callers (admin tooling, tests that drive heartbeats
    deterministically) can fire one explicitly.
    """
    with self._lock:
        heartbeat = {
            "op": "n",
            "ns": "",
            "o": {"msg": "periodic noop"},
        }
        return self._emit_oplog([heartbeat])
def _noop_heartbeat_loop(self) -> None:
"""Background heartbeat: emit one ``{op: "n"}`` oplog entry every
``noop_heartbeat_seconds``. Stops when ``_noop_stop`` is set
or the storage is closed. Failures are logged and swallowed —
a transient WT error must not kill the daemon thread.
"""
import logging
log = logging.getLogger("secantus.storage.noop")
while not self._noop_stop.wait(self._noop_heartbeat_seconds):
if self._closed:
return
try:
self.emit_noop_heartbeat()
except Exception:
log.exception("noop heartbeat failed")
def _reset_thread_session(self) -> None:
"""Close the calling thread's cached WT session + cursors so
the next ``_session()`` call opens fresh ones. Needed when a
thread reads in a loop and must observe writes from other
threads (snapshot is otherwise sticky)."""
s = getattr(self._tls, "session", None)
if s is None:
return
cursors = getattr(self._tls, "cursors", {}) or {}
for c in cursors.values():
with contextlib.suppress(Exception):
c.close()
with contextlib.suppress(Exception):
s.close()
with self._lock, contextlib.suppress(ValueError):
self._all_sessions.remove(s)
self._tls.session = None
self._tls.cursors = {}
def checkpoint(self) -> None:
    """Force a WiredTiger checkpoint to flush dirty pages to disk.

    Backs the ``fsync`` command and the admin UI's maintenance slice.
    Runs under ``self._lock`` so concurrent commands wait their turn.
    Does nothing once the store is closed, or on in-memory backends
    (WT's in_memory engine has no disk to flush and rejects the call
    with a noisy stderr log).
    """
    with self._lock:
        nothing_to_flush = self._closed or self._in_memory
        if not nothing_to_flush:
            self._session().checkpoint()
@contextlib.contextmanager
def _batch_transaction(self) -> Any:
"""Group multiple cursor writes into one WT transaction = one log record.
WT auto-commits every individual ``cursor.insert()`` /
``cursor.update()`` etc., which means N writes produce N log
records and N commit overheads. With this wrapper, the same N
writes share a single commit (and therefore a single log
record): on a typical bulk insert that's a 2-5x throughput
win for ``--batch-size > 1`` on the wire side, with the same
durability guarantee (all-or-nothing on commit).
Caller must already hold ``self._lock``. Reads within the
transaction observe the in-progress writes — fine for our
unique-conflict probes which need to see uncommitted siblings
in the same batch.
On exception the transaction is rolled back. Callers that
accumulate per-doc errors (e.g. ``ordered=False`` insert)
should NOT raise out of the block — they handle the per-doc
errors locally and let the surviving writes commit.
"""
session = self._session()
# Cached cursors must be reset before begin_transaction so they
# don't carry a stale snapshot from before the transaction
# boundary. WT documents this requirement explicitly.
for c in getattr(self._tls, "cursors", {}).values():
with contextlib.suppress(Exception):
c.reset()
session.begin_transaction()
try:
yield session
except Exception:
with contextlib.suppress(Exception):
session.rollback_transaction()
raise
else:
session.commit_transaction()
def _session(self) -> Any:
s = getattr(self._tls, "session", None)
if s is None:
s = self._conn.open_session()
self._tls.session = s
self._tls.cursors = {}
with self._lock:
self._all_sessions.append(s)
return s
def _refresh_read_snapshot(self) -> None:
"""Force the per-thread WT session to acquire a fresh read snapshot.
WiredTiger's default snapshot isolation pins a session's read
view at first cursor access; subsequent reads on the same
session see exactly that point-in-time view until the session
commits / rolls back a transaction. That's correct for a single
in-flight operation, but our daemon reuses one session per
connection thread across the full lifetime of a TCP connection.
Without an explicit snapshot refresh, a long-lived client
connection (Java's ``ClusterFixture`` is the canonical case)
does an insert, idles while another connection commits a
write, then reads — and sees the stale pre-other-write view.
``session.reset_snapshot()`` releases the held snapshot so the
next cursor read picks up the latest committed state. Called at
the top of every public read entry point (``find_matching``,
``count_matching``, ``list_*``, ``explain_plan``) so cross-
connection visibility matches real ``mongod``.
"""
s = getattr(self._tls, "session", None)
if s is None:
return
with contextlib.suppress(Exception):
# ``reset_snapshot()`` errors if the session is in an
# explicit transaction. Reads never run inside one
# (``_batch_transaction`` is write-only), so the exception
# path is defensive — log via ``suppress`` and move on.
s.reset_snapshot()
def _cursor(self, table: str, *, overwrite: bool = True) -> Any:
self._session()
cursors: dict[tuple[str, bool], Any] = self._tls.cursors
key = (table, overwrite)
c = cursors.get(key)
if c is None:
cfg = None if overwrite else "overwrite=false"
c = self._tls.session.open_cursor(table, None, cfg)
cursors[key] = c
else:
c.reset()
return c
def _coll_options(self, db: str, coll: str) -> dict[str, Any] | None:
    """Read the stored options for ``(db, coll)``.

    Returns ``None`` when the collection has no row, ``{}`` when the
    row exists with an empty value, and the decoded document otherwise.
    Takes no lock itself.
    """
    cur = self._cursor(_COLL_TABLE)
    cur.set_key(db, coll)
    if cur.search() != 0:
        return None
    raw = bytes(cur.get_value())
    if not raw:
        return {}
    return bson.decode(raw)
def _ensure_collection(self, db: str, coll: str) -> None:
    """Insert an empty-valued row for ``(db, coll)`` unless one already exists.

    Takes no lock itself.
    """
    cur = self._cursor(_COLL_TABLE)
    cur.set_key(db, coll)
    missing = cur.search() != 0
    if missing:
        cur.reset()
        cur[db, coll] = b""
def collection_exists(self, db: str, coll: str) -> bool:
    """Report whether ``(db, coll)`` has a row in the collection table."""
    with self._lock:
        options = self._coll_options(db, coll)
        return options is not None
def create_collection(self, db: str, coll: str) -> bool:
    """Create ``(db, coll)`` and record the creation in the oplog.

    Returns False, changing nothing, when the collection already
    exists. Otherwise writes an empty-valued row, obtains the
    collection UUID, emits an ``op: "c"`` create entry on
    ``<db>.$cmd`` and returns True.
    """
    with self._lock:
        cur = self._cursor(_COLL_TABLE)
        cur.set_key(db, coll)
        already_there = cur.search() == 0
        if already_there:
            return False
        cur.reset()
        cur[db, coll] = b""
        # First call mints and persists the UUID; the second reads it back.
        self._collection_uuid(db, coll)
        coll_uuid = self._collection_uuid(db, coll)
        create_entry = {
            "op": "c",
            "ns": f"{db}.$cmd",
            "ui": bson.Binary(coll_uuid.bytes, subtype=4),
            "o": {
                "create": coll,
                "idIndex": {"v": 2, "key": {"_id": 1}, "name": "_id_"},
            },
        }
        self._emit_oplog([create_entry])
        return True
def _scan_docs(self, db: str, coll: str) -> Iterable[tuple[bytes, bytes]]:
    """Yield ``(id_key, blob)`` for each document row of ``(db, coll)``.

    Seeks near ``(db, coll, b"")`` and walks forward until the key's
    first two components stop matching. The walk runs on the thread's
    cached doc-table cursor, so fetching that cursor again (which
    resets it) while this generator is still being consumed would
    disturb the scan.
    """
    cur = self._cursor(_DOC_TABLE)
    cur.set_key(db, coll, b"")
    landed = cur.search_near()
    if landed == wt.WT_NOTFOUND:
        return
    # Landed on a smaller key: advance once to the first candidate row.
    if landed < 0 and cur.next() != 0:
        return
    in_range = True
    while in_range:
        key = cur.get_key()
        if key[0] != db or key[1] != coll:
            return
        yield bytes(key[2]), bytes(cur.get_value())
        in_range = cur.next() == 0
def _all_docs(self, db: str, coll: str) -> list[dict[str, Any]]:
# Two-stage to keep ``bson.decode`` out of ``self._lock`` —
# otherwise an N-doc scan blocks every other thread for the
# whole decode loop. Lock owns the WT cursor walk; decode
# happens after release.
with self._lock:
blobs = [blob for _id_k, blob in self._scan_docs(db, coll)]
return [bson.decode(blob) for blob in blobs]
def _all_docs_with_id_key(self, db: str, coll: str) -> list[tuple[dict[str, Any], bytes]]:
with self._lock:
raw = [(id_k, blob) for id_k, blob in self._scan_docs(db, coll)]
return [(bson.decode(blob), id_k) for id_k, blob in raw]
def scan_docs_after_id_key(
self, db: str, coll: str, after: bytes | None
) -> list[tuple[bytes, dict[str, Any]]]:
"""Scan the document table in natural (id_key) order, returning
only rows whose ``id_key`` is strictly greater than ``after``.
``after`` of ``None`` returns the entire collection.
Used by the tailable-cursor producer to emit only the docs
inserted since the last poll. Returns ``[(id_key, doc), ...]``
— callers update their ``after`` checkpoint to the last
returned ``id_key`` for the next poll.
"""
# Two-stage: collect raw bytes under the lock, decode after.
with self._lock:
raw: list[tuple[bytes, bytes]] = [
(id_k, blob)
for id_k, blob in self._scan_docs(db, coll)
if after is None or id_k > after
]
return [(id_k, bson.decode(blob)) for id_k, blob in raw]
def collection_is_capped(self, db: str, coll: str) -> bool:
    """Public predicate: is the collection's ``capped`` option truthy?

    False for a collection that does not exist or has no options.
    """
    with self._lock:
        options = self._coll_options(db, coll)
        if not options:
            return False
        return bool(options.get("capped"))
def insert(
    self, db: str, coll: str, docs: Iterable[dict[str, Any]], *, ordered: bool = True
) -> tuple[int, list[dict[str, Any]]]:
    """Insert ``docs`` into ``(db, coll)`` inside one WT transaction.

    The collection row is created if missing. A doc without ``_id``
    gets a new ``ObjectId`` assigned in place (the caller's dict is
    mutated). Each doc is checked against unique indexes and geo
    indexes before it is written; accepted docs get their index
    entries written and, when the oplog is enabled, an ``op: "i"``
    entry queued. After the loop, capped-collection bounds are
    enforced and all queued oplog entries are emitted together.

    Per-doc failures are reported, not raised: code 11000 for a
    duplicate key (unique-index conflict or a failed doc-table insert)
    and 16572 for a geo validation error. With ``ordered=True``
    processing stops at the first failure; otherwise it continues with
    the next doc.

    Returns:
        ``(inserted, errors)`` — the number of docs written and a list
        of error dicts, each carrying the failing doc's ``index``.
    """
    inserted = 0
    errors: list[dict[str, Any]] = []
    oplog_entries: list[dict[str, Any]] = []
    # id_keys written by this call; passed to the capped-bounds pass
    # so it never evicts docs from this same batch.
    fresh_id_keys: set[bytes] = set()
    oplog_on = self.enable_oplog
    with self._coll_lock(db, coll), self._batch_transaction():
        # Per-collection lock (Phase 2.4): writes to other
        # collections proceed in parallel; same-collection writes
        # still serialise to keep the unique-index pre-check
        # race-free. _batch_transaction wraps the per-doc cursor
        # inserts (doc table + index entries + oplog) in one
        # explicit WT transaction so they share a single commit /
        # log record.
        self._ensure_collection(db, coll)
        # Namespace and UUID are only needed for oplog entries.
        ns = self._ns(db, coll) if oplog_on else ""
        ui = self._collection_uuid(db, coll) if oplog_on else None
        # Index metadata is loaded once and reused for every doc.
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        multikey_names = self._multikey_index_names(db, coll)
        for index, doc in enumerate(docs):
            if "_id" not in doc:
                doc["_id"] = bson.ObjectId()
            key = _id_key(doc["_id"])
            conflict = self._unique_conflict(
                db, coll, doc, indexes, exclude_id_key=None, partials=partials
            )
            if conflict is not None:
                cname, kpat, kval = conflict
                errors.append(
                    {
                        "index": index,
                        "code": 11000,
                        "errmsg": (
                            f"E11000 duplicate key error in index {cname}: _id={doc['_id']!r}"
                        ),
                        "keyPattern": kpat,
                        "keyValue": kval,
                    }
                )
                if ordered:
                    break
                continue
            # Pre-flight every geo index: a bad geometry should reject
            # the doc *before* it lands in the doc table, so we don't
            # leave a half-indexed write behind. Validation is cheap;
            # _write_index_entries below recomputes the same cells.
            try:
                self._validate_geo_indexes(db, coll, doc, indexes, partials)
            except GeoExtractError as exc:
                errors.append({"index": index, "code": 16572, "errmsg": str(exc)})
                if ordered:
                    break
                continue
            blob = bson.encode(doc)
            # The cursor is opened with overwrite=False, so inserting
            # an existing key raises instead of replacing the row.
            doc_cur = self._cursor(_DOC_TABLE, overwrite=False)
            doc_cur.set_key(db, coll, key)
            doc_cur.set_value(blob)
            try:
                doc_cur.insert()
            except wt.WiredTigerError:
                # NOTE(review): every WiredTigerError raised by the
                # insert is reported as a duplicate ``_id``, not only
                # duplicate-key failures — confirm that is intended.
                errors.append(
                    {
                        "index": index,
                        "code": 11000,
                        "errmsg": f"E11000 duplicate key error: _id {doc['_id']!r}",
                    }
                )
                if ordered:
                    break
                continue
            self._write_index_entries(db, coll, doc, indexes, partials)
            multikey_names = self._maybe_mark_multikey(db, coll, doc, indexes, multikey_names)
            inserted += 1
            if oplog_on:
                oplog_entries.append(
                    {
                        "op": "i",
                        "ns": ns,
                        "ui": bson.Binary(ui.bytes, subtype=4),
                        "o": dict(doc),
                        "o2": {"_id": doc["_id"]},
                    }
                )
            fresh_id_keys.add(key)
        # Evict from a capped collection if this batch pushed it over
        # its bounds; evictions come back as extra oplog entries.
        cap_entries, cap_pre_images = self._enforce_capped_bounds_locked(
            db, coll, fresh_id_keys, indexes, partials, oplog_on, ns, ui
        )
        if oplog_entries or cap_entries:
            # Insert entries carry no pre-image; pad so the list stays
            # parallel to the combined entry list.
            pre_images = [None] * len(oplog_entries) + cap_pre_images
            self._emit_oplog(oplog_entries + cap_entries, pre_images)
    return inserted, errors
def _enforce_capped_bounds_locked(
self,
db: str,
coll: str,
fresh_id_keys: set[bytes],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
partials: dict[str, dict[str, Any]],
oplog_on: bool,
ns: str,
ui: _uuid.UUID | None,
) -> tuple[list[dict[str, Any]], list[bytes | None]]:
"""Evict oldest non-fresh docs from a capped collection until within bounds.
"Oldest" is the natural-order walk over the doc table, which matches
insertion order when ``_id`` is monotonic (e.g. the default
ObjectId). For non-monotonic ``_id`` values the eviction order
reflects ``_id`` byte order, not literal insertion order — capped
users with custom ``_id`` should not rely on FIFO semantics.
"""
raw = self._coll_options(db, coll) or {}
if not raw.get("capped"):
return [], []
size_limit = raw.get("size")
max_limit = raw.get("max")
if size_limit is None and max_limit is None:
return [], []
scanned = list(self._scan_docs(db, coll))
total = sum(len(blob) for _id_k, blob in scanned)
count = len(scanned)
oplog_entries: list[dict[str, Any]] = []
pre_images: list[bytes | None] = []
preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)
for id_k, blob in scanned:
over_size = size_limit is not None and total > size_limit
over_max = max_limit is not None and count > max_limit
if not over_size and not over_max:
break
if id_k in fresh_id_keys:
# Don't evict docs we just inserted in this batch — they
# always sort to the tail with monotonic _ids, so reaching
# one means everything left is fresh too.
break
doc = bson.decode(blob)
self._delete_index_entries(db, coll, doc, indexes, partials)
doc_cur = self._cursor(_DOC_TABLE)
doc_cur.set_key(db, coll, id_k)
doc_cur.remove()
total -= len(blob)
count -= 1
if oplog_on:
entry: dict[str, Any] = {
"op": "d",
"ns": ns,
"o": {"_id": doc["_id"]},
"o2": {"_id": doc["_id"]},
}
if ui is not None:
entry["ui"] = bson.Binary(ui.bytes, subtype=4)
oplog_entries.append(entry)
pre_images.append(bson.encode(doc) if preimages_on else None)
return oplog_entries, pre_images
def find_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any] | None = None,
    *,
    skip: int = 0,
    limit: int = 0,
    sort: Mapping[str, Any] | None = None,
    projection: Mapping[str, Any] | None = None,
    hint: str | Mapping[str, Any] | None = None,
    let: dict[str, Any] | None = None,
    collation: Any = None,
) -> list[dict[str, Any]]:
    """Return the documents of ``db.coll`` that satisfy ``filter``.

    Candidate documents come from one of several sources, chosen in this
    order: the oplog view (when the namespace is the oplog), the hinted
    index, a forced full scan when a collation is supplied, an index
    lookup for the filter, an in-order index walk for a sort over an
    empty filter, or finally a full collection scan. Every candidate is
    then re-checked with ``matches()``; the result is sorted (unless the
    candidate source already produced the requested order), sliced by
    ``skip`` / ``limit`` (``limit <= 0`` means no limit), and projected.

    ``_resolve_hint`` is called without a guard, so whatever it raises
    for an unusable ``hint`` propagates to the caller.
    """
    if self._is_oplog_rs(db, coll):
        return self._find_oplog_rs(
            filter,
            skip=skip,
            limit=limit,
            sort=sort,
            projection=projection,
            let=let,
            collation=collation,
        )
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    self._refresh_read_snapshot()
    filter = filter or {}
    # True once the candidate list is already in the requested sort order,
    # which lets the post-sort below be skipped.
    in_sort_order = False
    # Two-stage decode discipline: the lock is held only for the
    # WT cursor walk and any index routing; the COLLSCAN fallback
    # collects raw blobs while the lock is held and defers
    # ``bson.decode`` (and the ``matches()`` predicate, sorting,
    # projection) to *after* the lock releases. Concurrent readers
    # then decode in parallel even while a writer holds the lock
    # for inserts. Index-path candidates still come back already
    # decoded — that's a deeper refactor (Phase 2 territory).
    candidates: list[dict[str, Any]] | None = None
    raw_blobs: list[bytes] | None = None
    with self._lock:
        sort_field, sort_dir = self._single_sort_spec(sort)
        if hint is not None:
            resolved = self._resolve_hint(db, coll, hint)
            candidates, in_sort_order = self._candidates_from_hint(
                db, coll, resolved, sort_field, sort_dir
            )
        elif collation_obj is not None:
            # Index entries are byte-encoded under the default
            # codepoint ordering — they don't carry the collation's
            # case- or accent-folding. A case-insensitive equality
            # like ``{x: "PING"}`` would miss an indexed doc with
            # ``x: "ping"`` because the entries are stored under
            # the literal bytes of the inserted value. Force a
            # COLLSCAN so ``matches()`` does the comparison with
            # the collation in hand. ``mongod`` only honours an
            # index when its definition's collation matches the
            # query's; we don't support per-index collation yet,
            # so the safe path is always-COLLSCAN-when-collation.
            raw_blobs = [b for _, b in self._scan_docs(db, coll)]
        else:
            candidates = self._try_index_lookup(db, coll, filter)
            if candidates is not None and sort_field is not None:
                # The index lookup is trusted to be in sort order only
                # when the filter has exactly one plain (non-operator)
                # field and that field is also the sort field.
                if (
                    len(filter) == 1
                    and not next(iter(filter)).startswith("$")
                    and next(iter(filter)) == sort_field
                ):
                    in_sort_order = True
                    idx = self._find_leading_field_index(db, coll, sort_field, filter)
                    idx_dir = idx[1] if idx else 1
                    if sort_dir != idx_dir:
                        candidates = list(reversed(candidates))
            elif candidates is None and not filter and sort_field is not None:
                idx = self._find_leading_field_index(db, coll, sort_field, filter)
                if idx is not None:
                    idx_name, idx_dir, _is_compound = idx
                    # If the index direction matches the sort direction,
                    # walk forward; if it's opposite, walk backward.
                    reverse = sort_dir != idx_dir
                    candidates = self._walk_index_in_order(db, coll, idx_name, reverse=reverse)
                    in_sort_order = True
            # Multi-field sort acceleration: when sort has 2+ fields and
            # filter is empty, try to find a compound index whose key
            # spec exactly matches (or fully inverts) the sort. Walking
            # that index in the right direction yields the requested
            # order without a Python-side post-sort.
            if candidates is None and not filter and sort_field is None and sort:
                multi_spec = self._multi_sort_spec(sort)
                if multi_spec is not None and len(multi_spec) > 1:
                    match = self._compound_index_for_sort(db, coll, multi_spec)
                    if match is not None:
                        idx_name, reverse = match
                        candidates = self._walk_index_in_order(
                            db, coll, idx_name, reverse=reverse
                        )
                        in_sort_order = True
            if candidates is None:
                # No index route applied: buffer raw blobs for a COLLSCAN.
                raw_blobs = [b for _, b in self._scan_docs(db, coll)]
    if candidates is None:
        # Deferred decode of the COLLSCAN blobs, outside the lock.
        assert raw_blobs is not None
        candidates = [bson.decode(b) for b in raw_blobs]
    # Every candidate is re-validated against the full filter, whichever
    # route produced it.
    out = [d for d in candidates if matches(d, filter, vars=let, collation=collation_obj)]
    if sort and not in_sort_order:
        out = sort_docs(out, sort)
    if skip:
        out = out[skip:]
    if limit > 0:
        out = out[:limit]
    if projection:
        out = [apply_projection(d, projection) for d in out]
    return out
def _resolve_hint(self, db: str, coll: str, hint: str | Mapping[str, Any]) -> str:
"""Resolve ``hint`` to an index name (or ``$natural``).
``hint`` may be an index name string, a key-spec dict matching an
existing index, ``"$natural"``, or ``{"$natural": +/-1}``. Anything
else raises ``BadHint`` so the command layer can return a Mongo
``BadValue`` error.
"""
if isinstance(hint, str):
if hint == "$natural":
return "$natural"
if hint == _ID_INDEX_NAME:
return _ID_INDEX_NAME
for name, _key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name == hint:
return name
raise BadHint(f"hint {hint!r} does not correspond to an existing index")
if isinstance(hint, Mapping):
if list(hint) == ["$natural"]:
return "$natural"
if list(hint) == ["_id"] and int(hint["_id"]) == 1:
return _ID_INDEX_NAME
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if dict(key_spec) == dict(hint):
return name
raise BadHint(f"hint {dict(hint)!r} does not correspond to an existing index")
raise BadHint(f"invalid hint type: {type(hint).__name__}")
def _candidates_from_hint(
self,
db: str,
coll: str,
resolved: str,
sort_field: str | None,
sort_dir: int,
) -> tuple[list[dict[str, Any]], bool]:
"""Walk the index named by ``resolved`` (or full collection for $natural).
Returns ``(candidates, in_sort_order)`` where ``in_sort_order`` is
True when the hint's leading field matches the sort field — in
which case ``find_matching`` skips the post-sort step.
"""
if resolved == "$natural":
return [bson.decode(b) for _, b in self._scan_docs(db, coll)], False
if resolved == _ID_INDEX_NAME:
# The doc table is keyed by id_key; iterating it gives entries
# sorted by encoded _id, which matches the _id_ index walk.
docs = [bson.decode(b) for _, b in self._scan_docs(db, coll)]
in_order = sort_field == "_id"
if in_order and sort_dir == -1:
docs = list(reversed(docs))
return docs, in_order
# Find the index's leading field and its direction
leading: str | None = None
leading_dir = 1
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name == resolved:
first = next(iter(key_spec))
leading = first
leading_dir = int(key_spec[first])
break
candidates = self._walk_index_in_order(db, coll, resolved, reverse=False)
in_order = sort_field is not None and sort_field == leading
if in_order and sort_dir != leading_dir:
candidates = list(reversed(candidates))
return candidates, in_order
@staticmethod
def _single_sort_spec(sort: Mapping[str, Any] | None) -> tuple[str | None, int]:
"""Return ``(field, direction)`` if ``sort`` is single-field +/-1, else ``(None, 0)``."""
if not sort or len(sort) != 1:
return None, 0
f, d = next(iter(sort.items()))
if f.startswith("$"):
return None, 0
try:
di = int(d)
except (TypeError, ValueError):
return None, 0
if di not in (-1, 1):
return None, 0
return f, di
@staticmethod
def _multi_sort_spec(
sort: Mapping[str, Any] | None,
) -> list[tuple[str, int]] | None:
"""Return a list of ``(field, direction)`` pairs for a multi-field
sort spec, or ``None`` if any entry is operator-prefixed or has a
non-``±1`` direction.
Used for compound-index sort acceleration: an index whose key
spec exactly matches (or fully inverts) the returned list lets
``find_matching`` walk WT in the requested order and skip the
Python-side post-sort entirely.
"""
if not sort:
return None
out: list[tuple[str, int]] = []
for field, direction in sort.items():
if field.startswith("$"):
return None
try:
d = int(direction)
except (TypeError, ValueError):
return None
if d not in (-1, 1):
return None
out.append((field, d))
return out
def _compound_index_for_sort(
self, db: str, coll: str, sort_fields: list[tuple[str, int]]
) -> tuple[str, bool] | None:
"""Find a compound index that satisfies ``sort_fields`` end-to-end.
Returns ``(index_name, reverse_walk)`` where ``reverse_walk`` is
True when the matching index is the *fully-inverted* permutation
of the sort (walking backward yields the requested order).
Multikey indexes are excluded — array values in the index could
produce row order that doesn't match the BSON cross-type sort
the user expects from a sort spec, so we'd fall back to Python
sort anyway.
Strict match only: the index key spec must have the same fields
in the same order with directions either matching the sort spec
or being the full inverse. Partial-prefix matches (sort uses 3
fields, index has 2) aren't accelerated; the savings on the
leading prefix are usually less than the cost of the trailing
Python sort over the materialised set.
"""
multikey = self._multikey_index_names(db, coll)
target = list(sort_fields)
inverted = [(f, -d) for f, d in target]
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name in multikey:
continue
try:
idx_pairs = [(f, int(d)) for f, d in key_spec.items()]
except (TypeError, ValueError):
continue
if any(d not in (-1, 1) for _, d in idx_pairs):
continue
if idx_pairs == target:
return name, False
if idx_pairs == inverted:
return name, True
return None
def _single_field_index_for(self, db: str, coll: str, field: str) -> tuple[str, int] | None:
"""Return ``(index_name, direction)`` for a single-field index on
``field``, or ``None`` if no such index exists. Direction is the
index's stored sort direction (`+1` for ASC, `-1` for DESC)."""
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if list(key_spec.keys()) == [field]:
d = int(key_spec[field])
if d in (1, -1):
return name, d
return None
def _walk_index_in_order(
    self, db: str, coll: str, name: str, *, reverse: bool = False
) -> list[dict[str, Any]]:
    """Fetch docs by walking every entry of index ``name`` in key order.

    The entries cursor is positioned at the smallest possible key for
    ``(db, coll, name)`` and advanced until the key leaves that prefix.
    The row ids collected along the way (reversed when ``reverse`` is
    true) are handed to ``_docs_by_id_keys`` for the actual doc fetch.
    """
    cursor = self._cursor(_IDX_ENTRIES_TABLE)
    cursor.set_key(db, coll, name, b"")
    exact = cursor.search_near()
    if exact == wt.WT_NOTFOUND:
        return []
    # A negative result means the cursor landed before the seed key;
    # step forward once to reach the first candidate entry.
    if exact < 0 and cursor.next() != 0:
        return []
    scope = (db, coll, name)
    row_ids: list[bytes] = []
    status = 0
    while status == 0:
        key = cursor.get_key()
        if (key[0], key[1], key[2]) != scope:
            break
        _sortkey, row_id = _unpack_entry(bytes(key[3]))
        row_ids.append(row_id)
        status = cursor.next()
    if reverse:
        row_ids.reverse()
    return self._docs_by_id_keys(db, coll, row_ids)
[docs]
def explain_plan(
self,
db: str,
coll: str,
filter: dict[str, Any] | None = None,
*,
sort: Mapping[str, Any] | None = None,
hint: str | Mapping[str, Any] | None = None,
) -> dict[str, Any]:
"""Plan summary for what ``find_matching`` would do with these args.
No execution; mirrors the same routing decisions. Returns
``{"kind": "COLLSCAN"}`` or ``{"kind": "IXSCAN", "index_name",
"key_pattern", "direction"}``. ``direction`` is ``"forward"``
unless a sort spec inverts it relative to the chosen index.
"""
filter = filter or {}
with self._lock:
sort_field, sort_dir = self._single_sort_spec(sort)
if hint is not None:
try:
resolved = self._resolve_hint(db, coll, hint)
except BadHint:
return {"kind": "COLLSCAN"}
if resolved == "$natural":
return {"kind": "COLLSCAN"}
if resolved == _ID_INDEX_NAME:
direction = "forward"
if sort_field == "_id" and sort_dir == -1:
direction = "backward"
return {
"kind": "IXSCAN",
"index_name": _ID_INDEX_NAME,
"key_pattern": {"_id": 1},
"direction": direction,
}
key_spec = self._key_spec_for(db, coll, resolved)
if key_spec is None:
return {"kind": "COLLSCAN"}
return self._make_ixscan_plan(resolved, key_spec, sort_field, sort_dir)
picked = self._pick_index_for_filter(db, coll, filter)
if picked is not None:
name, key_spec = picked
return self._make_ixscan_plan(name, key_spec, sort_field, sort_dir)
if not filter and sort_field is not None:
idx = self._find_leading_field_index(db, coll, sort_field, filter)
if idx is not None:
name, _idx_dir, _is_compound = idx
key_spec = self._key_spec_for(db, coll, name)
if key_spec is not None:
return self._make_ixscan_plan(name, key_spec, sort_field, sort_dir)
# Multi-field sort acceleration mirrored in the planner: same
# rules as find_matching (compound key spec exactly matches
# or fully inverts the sort, filter empty).
if not filter and sort_field is None and sort:
multi_spec = self._multi_sort_spec(sort)
if multi_spec is not None and len(multi_spec) > 1:
match = self._compound_index_for_sort(db, coll, multi_spec)
if match is not None:
name, reverse = match
key_spec = self._key_spec_for(db, coll, name)
if key_spec is not None:
return {
"kind": "IXSCAN",
"index_name": name,
"key_pattern": key_spec,
"direction": "backward" if reverse else "forward",
}
return {"kind": "COLLSCAN"}
def _key_spec_for(self, db: str, coll: str, name: str) -> dict[str, Any] | None:
for n, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if n == name:
return dict(key_spec)
return None
def _pick_geo_index_for_filter(
self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
"""Mirror :meth:`_try_geo_index_id_keys`'s index selection (no exec).
Returns ``(name, key_spec)`` if the filter has a geo operator on
a geo-indexed field; ``None`` otherwise. The picker is exact —
``_try_geo_index_id_keys`` may still bail (e.g. ``$near`` with no
max distance), but ``explain`` reports IXSCAN whenever an index
*could* serve the query, matching mongod's planner explain.
"""
for field, value in filter.items():
if not isinstance(value, dict):
continue
if not any(op in value for op in self._GEO_OPS):
continue
for name, key_spec, _opts in self._iter_indexes(db, coll):
geo = _geo_type_of(key_spec)
if geo is not None and geo[0] == field:
return name, dict(key_spec)
return None
def _pick_index_for_filter(
self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
"""Mirror ``_try_index_lookup``'s index-selection (no execution)."""
if not filter:
return None
if any(f.startswith("$") for f in filter):
return None
# Mirror `_try_index_id_keys`: geo dispatch first.
geo_pick = self._pick_geo_index_for_filter(db, coll, filter)
if geo_pick is not None:
return geo_pick
if all(not isinstance(v, dict) for v in filter.values()):
picked = self._pick_compound_eq_index(db, coll, filter)
if picked is not None:
return picked
if len(filter) >= 2:
picked = self._pick_compound_range_index(db, coll, filter)
if picked is not None:
return picked
if len(filter) != 1:
return None
field, value = next(iter(filter.items()))
idx_match = self._find_leading_field_index(db, coll, field, filter)
if idx_match is None:
return None
if isinstance(value, dict):
if not value or not all(k.startswith("$") for k in value):
return None
if not all(op in self._RANGE_OPS for op in value):
return None
name, _direction, _is_compound = idx_match
key_spec = self._key_spec_for(db, coll, name)
if key_spec is None:
return None
return name, key_spec
@staticmethod
def _make_ixscan_plan(
name: str,
key_spec: Mapping[str, Any],
sort_field: str | None,
sort_dir: int,
) -> dict[str, Any]:
direction = "forward"
if sort_field is not None and sort_field in key_spec:
idx_dir = int(key_spec[sort_field])
if sort_dir != 0 and sort_dir != idx_dir:
direction = "backward"
return {
"kind": "IXSCAN",
"index_name": name,
"key_pattern": dict(key_spec),
"direction": direction,
}
def count_matching(
self,
db: str,
coll: str,
filter: dict[str, Any] | None = None,
*,
let: dict[str, Any] | None = None,
collation: Any = None,
) -> int:
if self._is_oplog_rs(db, coll):
return self._count_oplog_rs(filter, let=let, collation=collation)
from secantus.collation import parse as _parse_collation
collation_obj = _parse_collation(collation)
self._refresh_read_snapshot()
if not filter:
with self._lock:
return sum(1 for _ in self._scan_docs(db, coll))
return sum(
1
for doc in self._all_docs(db, coll)
if matches(doc, filter, vars=let, collation=collation_obj)
)
[docs]
def collection_data_size(self, db: str, coll: str) -> int:
    """Total length in bytes of every stored doc blob in ``coll``.

    Feeds ``collStats`` / ``dbStats`` ``size`` / ``dataSize``. This is a
    best-effort figure: WT block overhead is not included.
    """
    with self._lock:
        total = 0
        for _key, blob in self._scan_docs(db, coll):
            total += len(blob)
        return total
[docs]
def index_sizes(self, db: str, coll: str) -> dict[str, int]:
    """Report per-index size as the summed length of packed entry keys.

    ``_id_`` is derived separately, as the summed ``len(id_key)`` over
    the doc table, and is listed only when that sum is non-zero, so
    callers can add it to the secondary indexes for ``totalIndexSize``.
    """
    with self._lock:
        sizes: dict[str, int] = {}
        id_bytes = 0
        for id_key, _blob in self._scan_docs(db, coll):
            id_bytes += len(id_key)
        if id_bytes:
            sizes[_ID_INDEX_NAME] = id_bytes
        for entry_key, _value in self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll)):
            index_name = entry_key[2]
            sizes[index_name] = sizes.get(index_name, 0) + len(bytes(entry_key[3]))
        return sizes
def update_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any],
    update: dict[str, Any],
    *,
    multi: bool = False,
    upsert: bool = False,
    array_filters: list[dict[str, Any]] | None = None,
    let: dict[str, Any] | None = None,
    collation: Any = None,
    validator: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Apply ``update`` to documents of ``db.coll`` that match ``filter``.

    Processing stops after the first matching document unless ``multi``
    is true. When nothing matched and ``upsert`` is true, a new document
    is built from the filter's plain (non-operator, non-dict) fields plus
    the update, and inserted. All writes — doc table, index entries,
    capped-collection eviction and oplog emission — happen under the
    per-collection lock inside one batch transaction.

    Returns ``{"matched", "modified", "upserted_id"}``. ``modified`` only
    counts documents whose content actually changed.

    Raises:
        DocumentValidationError: a resulting document fails ``validator``
            (pass ``None`` to skip validation).
        IndexConflict: a resulting document collides with a unique index.
    """
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    # Release any sticky session snapshot before the write's
    # ``begin_transaction`` acquires a new one. Otherwise the
    # transaction inherits a stale view and the candidate scan
    # misses rows committed by other connections (the cross-
    # connection visibility fix applied to reads — see
    # ``_refresh_read_snapshot``).
    self._refresh_read_snapshot()
    matched = 0
    modified = 0
    upserted_id: Any = None
    # ``oplog_entries`` and ``pre_images`` are appended to in lockstep.
    oplog_entries: list[dict[str, Any]] = []
    pre_images: list[bytes | None] = []
    oplog_on = self.enable_oplog
    with self._coll_lock(db, coll), self._batch_transaction():
        # Per-collection lock + one WT transaction per call. Every
        # doc-table write + index-entry delete/insert + oplog write
        # that lands in this method shares a single commit. Phase
        # 2.4: was self._lock; now per-coll so different
        # collections update in parallel.
        self._ensure_collection(db, coll)
        ns = self._ns(db, coll)
        ui = self._collection_uuid(db, coll) if oplog_on else None
        preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        multikey_names = self._multikey_index_names(db, coll)
        # Index-routed when the filter is covered (only matching id_keys
        # come back from the index walk); full scan otherwise. Either
        # way the doc cursor isn't held across writes — bytes are
        # eagerly buffered. Only matching docs pay ``bson.decode``.
        # With a collation in effect, fall back to a doc-table scan:
        # the index entries don't carry the collation's folding, so
        # an indexed equality probe would miss case-insensitive
        # matches. Always materialise the list — the update loop
        # rewrites the doc table via the cached cursor, which
        # invalidates a still-walking ``_scan_docs`` cursor on the
        # same session.
        # NOTE(review): only the collation branch wraps in ``list()``;
        # this relies on ``_candidates_iter`` buffering eagerly — confirm.
        if collation_obj is not None:
            candidates = list(self._scan_docs(db, coll))
        else:
            candidates = self._candidates_iter(db, coll, filter)
        for id_k, blob in candidates:
            doc = bson.decode(blob)
            if not matches(doc, filter, vars=let, collation=collation_obj):
                continue
            matched += 1
            pos = find_positional_matches(doc, filter)
            new = apply_update(
                doc,
                update,
                array_filters=array_filters,
                positional_matches=pos,
                let=let,
            )
            if new != doc:
                # Document-validator check: collection-level
                # ``validator`` (set via ``create`` / ``collMod``)
                # rejects updates whose result fails the predicate.
                # Caller passes ``None`` to skip
                # (``bypassDocumentValidation: true``).
                if validator is not None and not matches(new, validator):
                    raise DocumentValidationError(new.get("_id"))
                new_id_key = _id_key(new["_id"])
                conflict = self._unique_conflict(
                    db, coll, new, indexes, exclude_id_key=id_k, partials=partials
                )
                if conflict is not None:
                    cname, kpat, kval = conflict
                    raise IndexConflict(cname, new["_id"], key_pattern=kpat, key_value=kval)
                # Geo validation must reject the update before any
                # write happens, otherwise we'd be left with a
                # half-deleted set of index entries.
                self._validate_geo_indexes(db, coll, new, indexes, partials)
                modified += 1
                # Swap index entries around the doc rewrite: drop the old
                # doc's entries, store the new doc, write its entries.
                self._delete_index_entries(db, coll, doc, indexes, partials)
                doc_cur = self._cursor(_DOC_TABLE)
                doc_cur[db, coll, new_id_key] = bson.encode(new)
                self._write_index_entries(db, coll, new, indexes, partials)
                multikey_names = self._maybe_mark_multikey(
                    db, coll, new, indexes, multikey_names
                )
                if oplog_on:
                    # An update document with no ``$``-prefixed key is a
                    # full replacement and is logged as the whole new doc;
                    # operator updates are logged as a ``$v: 2`` diff.
                    is_replacement = not any(
                        isinstance(k, str) and k.startswith("$") for k in update
                    )
                    if is_replacement:
                        o_field: dict[str, Any] = dict(new)
                    else:
                        o_field = {"$v": 2, "diff": compute_update_description(doc, new)}
                    oplog_entries.append(
                        {
                            "op": "u",
                            "ns": ns,
                            "ui": bson.Binary(ui.bytes, subtype=4),
                            "o": o_field,
                            "o2": {"_id": doc["_id"]},
                        }
                    )
                    pre_images.append(bson.encode(doc) if preimages_on else None)
            if not multi:
                # Single-document mode: stop after the first match, even
                # when that match was left unchanged.
                break
        if matched == 0 and upsert:
            # Seed the upserted doc from the filter's plain equality
            # fields; operator keys and dict-valued conditions are skipped.
            seed: dict[str, Any] = {}
            for k, v in filter.items():
                if not k.startswith("$") and not isinstance(v, dict):
                    seed[k] = v
            new = apply_update(seed, update, is_upsert=True, array_filters=array_filters)
            if "_id" not in new:
                new["_id"] = bson.ObjectId()
            if validator is not None and not matches(new, validator):
                raise DocumentValidationError(new.get("_id"))
            upserted_id = new["_id"]
            conflict = self._unique_conflict(
                db, coll, new, indexes, exclude_id_key=None, partials=partials
            )
            if conflict is not None:
                cname, kpat, kval = conflict
                raise IndexConflict(cname, new["_id"], key_pattern=kpat, key_value=kval)
            self._validate_geo_indexes(db, coll, new, indexes, partials)
            doc_cur = self._cursor(_DOC_TABLE)
            doc_cur[db, coll, _id_key(upserted_id)] = bson.encode(new)
            self._write_index_entries(db, coll, new, indexes, partials)
            self._maybe_mark_multikey(db, coll, new, indexes, multikey_names)
            if oplog_on:
                oplog_entries.append(
                    {
                        "op": "i",
                        "ns": ns,
                        "ui": bson.Binary(ui.bytes, subtype=4),
                        "o": dict(new),
                        "o2": {"_id": upserted_id},
                    }
                )
                pre_images.append(None)
        # Re-check capped bounds after the writes. The fresh-id set is
        # empty here, so no document is shielded from eviction.
        cap_ns = ns if oplog_on else ""
        cap_entries, cap_pre = self._enforce_capped_bounds_locked(
            db, coll, set(), indexes, partials, oplog_on, cap_ns, ui
        )
        if cap_entries:
            oplog_entries.extend(cap_entries)
            pre_images.extend(cap_pre)
        if oplog_entries:
            self._emit_oplog(oplog_entries, pre_images)
    return {"matched": matched, "modified": modified, "upserted_id": upserted_id}
def delete_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any],
    *,
    limit: int = 0,
    let: dict[str, Any] | None = None,
    collation: Any = None,
) -> int:
    """Delete documents of ``db.coll`` that match ``filter``.

    Each matching document has its index entries removed, then its
    doc-table row. ``limit > 0`` stops after that many deletions;
    ``limit <= 0`` deletes every match. When the oplog is enabled one
    ``"d"`` entry is emitted per deleted document. Returns the number of
    documents deleted.
    """
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    # See ``update_matching`` — release the sticky snapshot so the
    # candidate scan sees writes committed by other connections.
    self._refresh_read_snapshot()
    deleted = 0
    # ``oplog_entries`` and ``pre_images`` are appended to in lockstep.
    oplog_entries: list[dict[str, Any]] = []
    pre_images: list[bytes | None] = []
    oplog_on = self.enable_oplog
    with self._coll_lock(db, coll), self._batch_transaction():
        # Per-collection lock (Phase 2.4) + one WT transaction.
        # Groups the per-doc removes + index-entry deletes + oplog
        # writes into one commit. Other collections delete in
        # parallel.
        ns = self._ns(db, coll) if oplog_on else ""
        preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)
        # The collection UUID is looked up only when the oplog is on and
        # the collection has a catalog entry.
        ui = (
            self._collection_uuid(db, coll)
            if oplog_on and self._coll_options(db, coll) is not None
            else None
        )
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        # Index-routed candidates when the filter is covered; full scan
        # otherwise. See update_matching for the full-scan rationale.
        # Collation forces a full scan — index entries don't carry the
        # collation's folding. Always materialise into a list so the
        # delete loop's writes don't invalidate the iteration cursor
        # mid-scan (deletes via ``_cursor(_DOC_TABLE)`` share the
        # cached cursor with ``_scan_docs``).
        # NOTE(review): only the collation branch wraps in ``list()``;
        # this relies on ``_candidates_iter`` buffering eagerly — confirm.
        if collation_obj is not None:
            candidates = list(self._scan_docs(db, coll))
        else:
            candidates = self._candidates_iter(db, coll, filter)
        for id_k, blob in candidates:
            doc = bson.decode(blob)
            if not matches(doc, filter, vars=let, collation=collation_obj):
                continue
            # Index entries go first, then the doc row itself.
            self._delete_index_entries(db, coll, doc, indexes, partials)
            doc_cur = self._cursor(_DOC_TABLE)
            doc_cur.set_key(db, coll, id_k)
            doc_cur.remove()
            deleted += 1
            if oplog_on:
                entry: dict[str, Any] = {
                    "op": "d",
                    "ns": ns,
                    "o": {"_id": doc["_id"]},
                    "o2": {"_id": doc["_id"]},
                }
                if ui is not None:
                    entry["ui"] = bson.Binary(ui.bytes, subtype=4)
                oplog_entries.append(entry)
                pre_images.append(bson.encode(doc) if preimages_on else None)
            if limit > 0 and deleted >= limit:
                break
        if oplog_entries:
            self._emit_oplog(oplog_entries, pre_images)
    return deleted
[docs]
def prune_ttl(
self,
db: str,
coll: str,
*,
now: _dt.datetime | None = None,
) -> int:
"""Delete docs whose indexed Date field is older than now - TTL.
For every index on ``coll`` with an ``expireAfterSeconds`` option,
walks the collection and deletes docs whose indexed field resolves
to a ``datetime`` older than ``now - expireAfterSeconds``. Docs
without the field, with non-date values, or with values inside the
TTL window are left in place. Real MongoDB runs this on a 60s
background sweeper; SecantusDB invokes it explicitly so tests can
drive expiry with an injected ``now``. Returns the number of docs
pruned.
"""
ttl_indexes: list[tuple[str, str, float]] = []
for name, key_spec, opts in self._iter_indexes(db, coll):
ttl = opts.get("expireAfterSeconds")
if not isinstance(ttl, (int, float)) or ttl < 0:
continue
field = next(iter(key_spec), None)
if not isinstance(field, str):
continue
ttl_indexes.append((name, field, float(ttl)))
if not ttl_indexes:
return 0
when = now if now is not None else _dt.datetime.now(_dt.timezone.utc)
if when.tzinfo is None:
when = when.replace(tzinfo=_dt.timezone.utc)
pruned = 0
oplog_entries: list[dict[str, Any]] = []
pre_images: list[bytes | None] = []
with self._lock:
ns = self._ns(db, coll)
preimages_on = self._pre_post_images_enabled(db, coll)
ui = (
self._collection_uuid(db, coll)
if self._coll_options(db, coll) is not None
else None
)
indexes = self._all_indexes(db, coll)
partials = self._partial_filters(db, coll)
candidates = list(self._scan_docs(db, coll))
for id_k, blob in candidates:
doc = bson.decode(blob)
expired = False
for _name, field, ttl_seconds in ttl_indexes:
value = get_path(doc, field)
if not isinstance(value, _dt.datetime):
continue
value_aware = value if value.tzinfo else value.replace(tzinfo=_dt.timezone.utc)
if (when - value_aware).total_seconds() > ttl_seconds:
expired = True
break
if not expired:
continue
self._delete_index_entries(db, coll, doc, indexes, partials)
doc_cur = self._cursor(_DOC_TABLE)
doc_cur.set_key(db, coll, id_k)
doc_cur.remove()
pruned += 1
entry: dict[str, Any] = {
"op": "d",
"ns": ns,
"o": {"_id": doc["_id"]},
"o2": {"_id": doc["_id"]},
}
if ui is not None:
entry["ui"] = bson.Binary(ui.bytes, subtype=4)
oplog_entries.append(entry)
pre_images.append(bson.encode(doc) if preimages_on else None)
if oplog_entries:
self._emit_oplog(oplog_entries, pre_images)
return pruned
@staticmethod
def _table_kf(table: str) -> str:
    """Return the key-format string for ``table``.

    One character per key column; ``S`` and ``u`` are the only codes
    used. Raises ``KeyError`` for a table that is not listed here.
    """
    key_formats = {
        _COLL_TABLE: "SS",
        _DOC_TABLE: "SSu",
        _IDX_TABLE: "SSS",
        _IDX_ENTRIES_TABLE: "SSSu",
    }
    return key_formats[table]
@staticmethod
def _smallest_for_kf(kf: str) -> tuple[Any, ...]:
    """Return one empty value per column of key format ``kf``.

    A ``u`` column gets ``b""``; every other column gets ``""``.
    """
    smallest: list[Any] = []
    for column_code in kf:
        smallest.append(b"" if column_code == "u" else "")
    return tuple(smallest)
def _collect_prefix(
    self, table: str, prefix: tuple[Any, ...]
) -> list[tuple[tuple[Any, ...], Any]]:
    """Collect every ``(key, value)`` row of ``table`` whose key starts with ``prefix``.

    The cursor is seeded with ``prefix`` padded by the empty value for
    each remaining key column, then advanced until the key no longer
    starts with ``prefix``. Bytes-like values are copied into ``bytes``;
    other values are returned as the cursor gave them.
    """
    cursor = self._cursor(table)
    key_format = self._table_kf(table)
    padding = self._smallest_for_kf(key_format)[len(prefix) :]
    cursor.set_key(*(prefix + padding))
    exact = cursor.search_near()
    if exact == wt.WT_NOTFOUND:
        return []
    # Landed before the seed key: step forward to the first candidate.
    if exact < 0 and cursor.next() != 0:
        return []
    width = len(prefix)
    rows: list[tuple[tuple[Any, ...], Any]] = []
    status = 0
    while status == 0:
        key = tuple(cursor.get_key())
        if key[:width] != prefix:
            break
        value = cursor.get_value()
        if isinstance(value, (bytes, bytearray)):
            value = bytes(value)
        rows.append((key, value))
        status = cursor.next()
    return rows
def _delete_keys(self, table: str, keys: list[tuple[Any, ...]]) -> None:
    """Remove each key in ``keys`` from ``table``, then reset the cursor.

    An empty ``keys`` list is a no-op: no cursor is requested at all.
    """
    if keys:
        cursor = self._cursor(table)
        for key in keys:
            cursor.set_key(*key)
            cursor.remove()
        cursor.reset()
def drop_collection(self, db: str, coll: str) -> bool:
    """Delete all rows belonging to ``db.coll``; return whether it existed.

    Doc, index-definition and index-entry rows are removed by key prefix
    whether or not a catalog row exists, then the catalog row itself. A
    ``drop`` oplog command is emitted only when the collection existed
    and had a UUID.
    """
    with self._lock:
        existed = self._coll_options(db, coll) is not None
        ui = self._collection_uuid(db, coll) if existed else None
        scope = (db, coll)
        for table in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE):
            doomed = [key for key, _value in self._collect_prefix(table, scope)]
            self._delete_keys(table, doomed)
        catalog = self._cursor(_COLL_TABLE)
        catalog.set_key(db, coll)
        if catalog.search() == 0:
            catalog.remove()
        if not existed or ui is None:
            return existed
        drop_entry = {
            "op": "c",
            "ns": f"{db}.$cmd",
            "ui": bson.Binary(ui.bytes, subtype=4),
            "o": {"drop": coll},
        }
        self._emit_oplog([drop_entry])
        return existed
def drop_database(self, db: str) -> None:
    """Delete every row of database ``db`` and log the drops.

    All doc, index-definition, index-entry and catalog rows whose key
    starts with ``db`` are removed. One ``drop`` oplog command is emitted
    per collection that had a UUID, followed by a single ``dropDatabase``
    command.

    Collections reported by ``list_collections`` without a UUID (it can
    synthesise ``oplog.rs`` for ``local``, which has no catalog row) get
    no ``drop`` entry, mirroring the ``ui is not None`` guard in
    ``drop_collection``.
    """
    with self._lock:
        # Capture UUIDs before the catalog rows are deleted below.
        colls_with_ui: list[tuple[str, _uuid.UUID]] = []
        for c_name in self.list_collections(db):
            ui = self._collection_uuid(db, c_name)
            if ui is None:
                # Nothing to key a drop entry on; ``ui.bytes`` would raise.
                continue
            colls_with_ui.append((c_name, ui))
        for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE, _COLL_TABLE):
            rows = self._collect_prefix(tbl, (db,))
            self._delete_keys(tbl, [k for k, _ in rows])
        entries: list[dict[str, Any]] = [
            {
                "op": "c",
                "ns": f"{db}.$cmd",
                "ui": bson.Binary(ui.bytes, subtype=4),
                "o": {"drop": c_name},
            }
            for c_name, ui in colls_with_ui
        ]
        entries.append({"op": "c", "ns": f"{db}.$cmd", "o": {"dropDatabase": 1}})
        self._emit_oplog(entries)
def rename_collection(
    self,
    src_db: str,
    src_coll: str,
    dst_db: str,
    dst_coll: str,
    *,
    drop_target: bool = False,
) -> tuple[bool, str | None]:
    """Move every row of ``src_db.src_coll`` to ``dst_db.dst_coll``.

    Returns ``(True, None)`` on success (renaming a namespace onto itself
    is a successful no-op) or ``(False, message)`` when the source does
    not exist or the target exists and ``drop_target`` is false. With
    ``drop_target`` an existing target is removed first and logged as a
    ``drop`` before the ``renameCollection`` oplog command.

    The source's catalog row value is carried over to the target instead
    of being replaced by an empty value.
    NOTE(review): this assumes the catalog value holds the collection's
    stored options — confirm against ``_coll_options``.
    """
    with self._lock:
        if self._coll_options(src_db, src_coll) is None:
            return False, f"source namespace does not exist: {src_db}.{src_coll}"
        if (src_db, src_coll) == (dst_db, dst_coll):
            return True, None
        ui = self._collection_uuid(src_db, src_coll)
        dst_existed = self._coll_options(dst_db, dst_coll) is not None
        dst_ui = self._collection_uuid(dst_db, dst_coll) if dst_existed else None
        if dst_existed:
            if not drop_target:
                return False, f"target namespace exists: {dst_db}.{dst_coll}"
            for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE):
                rows = self._collect_prefix(tbl, (dst_db, dst_coll))
                self._delete_keys(tbl, [k for k, _ in rows])
            c = self._cursor(_COLL_TABLE)
            c.set_key(dst_db, dst_coll)
            if c.search() == 0:
                c.remove()
        # Read the source catalog value before its row is removed, so the
        # target row can be created with the same value.
        meta = self._cursor(_COLL_TABLE)
        meta.set_key(src_db, src_coll)
        src_meta: Any = b""
        if meta.search() == 0:
            raw_meta = meta.get_value()
            src_meta = bytes(raw_meta) if isinstance(raw_meta, (bytes, bytearray)) else raw_meta
        meta.reset()
        # Move rows table by table: buffer, delete under the old key
        # prefix, re-insert with the first two key columns swapped.
        for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE):
            rows = self._collect_prefix(tbl, (src_db, src_coll))
            self._delete_keys(tbl, [k for k, _ in rows])
            c = self._cursor(tbl)
            for k, v in rows:
                new_k = (dst_db, dst_coll) + k[2:]
                c.set_key(*new_k)
                c.set_value(v)
                c.insert()
            c.reset()
        ensure = self._cursor(_COLL_TABLE)
        ensure.set_key(dst_db, dst_coll)
        if ensure.search() != 0:
            ensure.reset()
            ensure[dst_db, dst_coll] = src_meta
        ensure.reset()
        ensure.set_key(src_db, src_coll)
        if ensure.search() == 0:
            ensure.remove()
        entries: list[dict[str, Any]] = []
        if dst_existed and dst_ui is not None:
            entries.append(
                {
                    "op": "c",
                    "ns": f"{dst_db}.$cmd",
                    "ui": bson.Binary(dst_ui.bytes, subtype=4),
                    "o": {"drop": dst_coll},
                }
            )
        rename_entry: dict[str, Any] = {
            "op": "c",
            "ns": f"{src_db}.$cmd",
            "o": {
                "renameCollection": f"{src_db}.{src_coll}",
                "to": f"{dst_db}.{dst_coll}",
            },
        }
        if ui is not None:
            # Only attach the UUID when one exists; ``ui.bytes`` on a
            # missing UUID would raise after the rows were already moved.
            rename_entry["ui"] = bson.Binary(ui.bytes, subtype=4)
        entries.append(rename_entry)
        self._emit_oplog(entries)
        return True, None
def list_collections(self, db: str) -> list[str]:
    """Return the sorted collection names recorded for ``db``.

    Names come from the catalog rows whose first key column is ``db``.
    For the ``local`` database with the oplog enabled, ``oplog.rs`` is
    added even though it has no catalog row — it is a view over the
    oplog WT table, and ``listCollections`` must surface it so pymongo
    clients can discover it before querying.
    """
    self._refresh_read_snapshot()
    names: list[str] = []
    with self._lock:
        cursor = self._cursor(_COLL_TABLE)
        cursor.set_key(db, "")
        exact = cursor.search_near()
        # Positioned on a usable row unless the table is empty, or the
        # cursor landed before the seed key and cannot step forward.
        positioned = exact != wt.WT_NOTFOUND and not (exact < 0 and cursor.next() != 0)
        while positioned:
            key = cursor.get_key()
            if key[0] != db:
                break
            names.append(key[1])
            positioned = cursor.next() == 0
    if self.enable_oplog and db == "local" and "oplog.rs" not in names:
        names.append("oplog.rs")
    return sorted(names)
def list_databases(self) -> list[str]:
    """Return the sorted names of every database that has a catalog row.

    ``local`` is always included when the oplog is enabled, mirroring
    mongod, which exposes it before any user-created collection lands
    there.
    """
    self._refresh_read_snapshot()
    with self._lock:
        c = self._cursor(_COLL_TABLE)
        # The cursor comes from ``_cursor`` and other methods leave such
        # cursors positioned. Reset it so the first ``next()`` starts at
        # the first catalog row rather than continuing from a stale
        # position and skipping rows.
        c.reset()
        seen: set[str] = set()
        rc = c.next()
        while rc == 0:
            k = c.get_key()
            seen.add(k[0])
            rc = c.next()
    # mongod always exposes the ``local`` database; mirror that
    # when the oplog is enabled so listDatabases includes it even
    # before any user-created collection lands in ``local``.
    if self.enable_oplog:
        seen.add("local")
    return sorted(seen)
def create_index(
    self,
    db: str,
    coll: str,
    name: str,
    key_spec: Mapping[str, Any],
    options: Mapping[str, Any] | None = None,
) -> bool:
    """Create index ``name`` on ``db.coll`` and back-fill its entries.

    Returns ``True`` when a new index was written, ``False`` when
    ``name`` is the built-in ``_id`` index or an index with that name
    already exists with compatible options.

    Raises:
        CreateIndexUnsupported: a key value is ``"text"`` or ``"hashed"``.
        IndexOptionsConflict: the name exists with different
            ``unique`` / ``sparse`` / ``hidden`` / ``expireAfterSeconds``
            / ``partialFilterExpression``.
        IndexConflict: ``unique`` was requested on a geo index, or two
            existing docs share the same canonical key.
    """
    if name == _ID_INDEX_NAME:
        return False
    # Text / hashed indexes are documented out-of-scope. Reject them
    # with a typed exception (caught in ``commands._create_indexes``)
    # rather than letting the geo picker / encoder fail later with an
    # opaque internal error.
    for spec_val in key_spec.values():
        if spec_val in ("text", "hashed"):
            raise CreateIndexUnsupported(f"{spec_val} indexes are not supported by SecantusDB")
    options = dict(options or {})
    with self._lock:
        self._ensure_collection(db, coll)
        meta_cur = self._cursor(_IDX_TABLE)
        meta_cur.set_key(db, coll, name)
        if meta_cur.search() == 0:
            # The name is taken. Mongo rejects re-creation with
            # conflicting options; silently succeeding would hide a
            # bug surface that driver test-suites pin.
            stored_raw = bytes(meta_cur.get_value())
            stored = bson.decode(stored_raw) if stored_raw else {}
            stored_opts = dict(stored.get("options") or {})
            guarded = (
                "unique",
                "sparse",
                "hidden",
                "expireAfterSeconds",
                "partialFilterExpression",
            )
            # An option missing on both sides compares ``None != None``
            # and therefore never trips the check.
            if any(options.get(opt) != stored_opts.get(opt) for opt in guarded):
                raise IndexOptionsConflict(
                    f"Index with name '{name}' already exists with different options"
                )
            return False

        sparse = bool(options.get("sparse"))
        unique = bool(options.get("unique"))
        partial_filter = options.get("partialFilterExpression")
        if not (isinstance(partial_filter, Mapping) and partial_filter):
            partial_filter = None
        key_spec_dict = dict(key_spec)
        geo = _geo_type_of(key_spec_dict)

        def eligible_docs():
            # Decode each stored doc once and drop those the partial
            # filter excludes from this index.
            for doc_key, blob in self._scan_docs(db, coll):
                decoded = bson.decode(blob)
                if partial_filter is None or matches(decoded, partial_filter):
                    yield doc_key, decoded

        # Entries are collected first and written afterwards, so a
        # failure during the walk leaves nothing half-written.
        entries: list[tuple[bytes, bytes]] = []
        if geo is not None:
            # Geo indexes share the entries table but emit several
            # entries per doc (one per S2 cell or 2d bucket).
            # Uniqueness is meaningless for them and is rejected.
            if unique:
                raise IndexConflict(name, None)
            geo_field, geo_type = geo
            # Flag as multikey so the regular pickers can tell this
            # index produces many entries per document.
            options["multikey"] = True
            for doc_key, decoded in eligible_docs():
                entries.extend(
                    (cell_bytes, doc_key)
                    for cell_bytes in _doc_geo_cells(
                        decoded, geo_field, geo_type, options, index_name=name
                    )
                )
        else:
            # One pass folds uniqueness probing, multikey detection and
            # entry building together. Uniqueness is checked on the
            # canonical whole-doc key; entries are written for every
            # key variant.
            owners: dict[bytes, Any] | None = {} if unique else None
            saw_array = False
            for doc_key, decoded in eligible_docs():
                if not saw_array and _doc_makes_multikey(decoded, key_spec_dict):
                    saw_array = True
                if owners is not None:
                    canonical = _index_key(decoded, key_spec_dict, sparse=sparse)
                    if canonical is not None:
                        if canonical in owners:
                            raise IndexConflict(name, decoded.get("_id"))
                        owners[canonical] = decoded.get("_id")
                entries.extend(
                    (kb, doc_key)
                    for kb in _index_key_variants(decoded, key_spec_dict, sparse=sparse)
                )
            if saw_array:
                options["multikey"] = True

        payload = bson.encode({"key": dict(key_spec), "options": options})
        meta_cur.reset()
        meta_cur[db, coll, name] = payload
        entry_cur = self._cursor(_IDX_ENTRIES_TABLE)
        for kb, doc_key in entries:
            entry_cur.reset()
            entry_cur[db, coll, name, _pack_entry(kb, doc_key)] = b""

        coll_uuid = self._collection_uuid(db, coll)
        index_doc = {"v": 2, "key": dict(key_spec), "name": name, **options}
        self._emit_oplog(
            [
                {
                    "op": "c",
                    "ns": f"{db}.$cmd",
                    "ui": bson.Binary(coll_uuid.bytes, subtype=4),
                    "o": {"createIndexes": coll, "indexes": [index_doc]},
                }
            ]
        )
        return True
def list_indexes(self, db: str, coll: str) -> list[dict[str, Any]]:
    """Describe every index on ``db.coll``, sorted by index name.

    Returns ``[]`` when ``_coll_options`` reports no such collection.
    Otherwise the listing always contains the built-in ``_id`` index;
    each stored index contributes ``v`` / ``key`` / ``name`` with its
    stored options merged on top.
    """
    self._refresh_read_snapshot()
    with self._lock:
        if self._coll_options(db, coll) is None:
            return []
        listing: list[dict[str, Any]] = [
            {"v": 2, "key": {"_id": 1}, "name": _ID_INDEX_NAME}
        ]
        listing.extend(
            {"v": 2, "key": key_spec, "name": idx_name, **opts}
            for idx_name, key_spec, opts in self._iter_indexes(db, coll)
        )
        return sorted(listing, key=lambda described: described["name"])
def _iter_indexes(
    self, db: str, coll: str
) -> Iterable[tuple[str, dict[str, Any], dict[str, Any]]]:
    """Yield ``(name, key_spec, options)`` for each stored index of ``db.coll``.

    Seeks the index table to ``(db, coll, "")`` and walks forward until
    the key leaves the ``(db, coll)`` prefix. The generator keeps the
    cursor positioned between yields.
    """
    cur = self._cursor(_IDX_TABLE)
    cur.set_key(db, coll, "")
    exact = cur.search_near()
    if exact == wt.WT_NOTFOUND:
        return
    # A negative result means we landed just before the seed key.
    positioned = cur.next() == 0 if exact < 0 else True
    while positioned:
        key = cur.get_key()
        if (key[0], key[1]) != (db, coll):
            return
        meta = bson.decode(bytes(cur.get_value()))
        yield key[2], meta.get("key", {}), meta.get("options", {})
        positioned = cur.next() == 0
def drop_index(self, db: str, coll: str, name: str) -> bool:
    """Remove index ``name`` and all of its entries.

    Returns ``False`` for the built-in ``_id`` index or when no such
    index is stored; otherwise deletes the metadata row plus every
    entry row under ``(db, coll, name)``, emits a ``dropIndexes`` oplog
    entry and returns ``True``.
    """
    if name == _ID_INDEX_NAME:
        return False
    with self._lock:
        meta_cur = self._cursor(_IDX_TABLE)
        meta_cur.set_key(db, coll, name)
        if meta_cur.search() != 0:
            return False
        meta_cur.remove()
        stale_keys = [
            key for key, _ in self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll, name))
        ]
        self._delete_keys(_IDX_ENTRIES_TABLE, stale_keys)
        coll_uuid = self._collection_uuid(db, coll)
        oplog_entry = {
            "op": "c",
            "ns": f"{db}.$cmd",
            "ui": bson.Binary(coll_uuid.bytes, subtype=4),
            "o": {"dropIndexes": coll, "index": name},
        }
        self._emit_oplog([oplog_entry])
        return True
def drop_all_indexes(self, db: str, coll: str) -> int:
    """Remove every stored index of ``db.coll``; return how many were dropped.

    Deletes all metadata rows and entry rows under the ``(db, coll)``
    prefix. One ``dropIndexes`` oplog entry is emitted per removed
    index name; nothing is emitted when there were none.
    """
    with self._lock:
        meta_keys = [key for key, _ in self._collect_prefix(_IDX_TABLE, (db, coll))]
        dropped_names = [key[2] for key in meta_keys]
        self._delete_keys(_IDX_TABLE, meta_keys)
        entry_keys = [
            key for key, _ in self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll))
        ]
        self._delete_keys(_IDX_ENTRIES_TABLE, entry_keys)
        if dropped_names:
            coll_uuid = self._collection_uuid(db, coll)
            oplog_entries = []
            for idx_name in dropped_names:
                oplog_entries.append(
                    {
                        "op": "c",
                        "ns": f"{db}.$cmd",
                        "ui": bson.Binary(coll_uuid.bytes, subtype=4),
                        "o": {"dropIndexes": coll, "index": idx_name},
                    }
                )
            self._emit_oplog(oplog_entries)
        return len(meta_keys)
def _all_indexes(self, db: str, coll: str) -> list[tuple[str, dict[str, Any], bool, bool]]:
"""Every non-_id_ index: (name, key_spec, sparse, unique)."""
out: list[tuple[str, dict[str, Any], bool, bool]] = []
for name, key_spec, opts in list(self._iter_indexes(db, coll)):
out.append((name, key_spec, bool(opts.get("sparse")), bool(opts.get("unique"))))
return out
def _partial_filters(self, db: str, coll: str) -> dict[str, dict[str, Any]]:
"""Map of index name → ``partialFilterExpression`` for indexes that have one.
Indexes without a partial filter are absent from the dict.
"""
out: dict[str, dict[str, Any]] = {}
for name, _key_spec, opts in self._iter_indexes(db, coll):
pf = opts.get("partialFilterExpression")
if isinstance(pf, Mapping) and pf:
out[name] = dict(pf)
return out
@staticmethod
def _query_implies_partial(query: Mapping[str, Any], partial: Mapping[str, Any]) -> bool:
"""True if ``query`` is at least as restrictive as ``partial`` — every
key/value pair in ``partial`` appears with the same bare value in
``query``. Conservative: anything more sophisticated (operator-form
clauses, $and, etc.) is treated as not implying the partial filter.
"""
for key, value in partial.items():
if key not in query:
return False
if query[key] != value:
return False
return True
def _multikey_index_names(self, db: str, coll: str) -> set[str]:
"""Names of indexes flagged ``multikey`` (must fall back to scan).
Without true multi-key indexing, an index where any doc has a
list-valued field can't serve scalar-element matches — so the
pickers skip these names and ``find_matching`` falls back to a
full scan.
"""
return {
name for name, _key_spec, opts in self._iter_indexes(db, coll) if opts.get("multikey")
}
def _maybe_mark_multikey(
    self,
    db: str,
    coll: str,
    doc: Mapping[str, Any],
    indexes: list[tuple[str, dict[str, Any], bool, bool]],
    already_multikey: set[str],
) -> set[str]:
    """Persist the ``multikey`` flag on indexes that ``doc`` makes multikey.

    Indexes already named in ``already_multikey`` are skipped. For the
    rest, when ``_doc_makes_multikey`` says so, the stored metadata row
    is rewritten with ``options.multikey = True`` unless it is already
    set. ``already_multikey`` is mutated in place and also returned so
    callers can carry it across documents without re-checking.
    """
    meta_cur = self._cursor(_IDX_TABLE)
    for idx_name, key_spec, _sparse, _unique in indexes:
        if idx_name in already_multikey or not _doc_makes_multikey(doc, key_spec):
            continue
        meta_cur.reset()
        meta_cur.set_key(db, coll, idx_name)
        if meta_cur.search() != 0:
            # No metadata row for this name; nothing to flag.
            continue
        meta = bson.decode(bytes(meta_cur.get_value()))
        stored_opts = dict(meta.get("options") or {})
        if not stored_opts.get("multikey"):
            stored_opts["multikey"] = True
            meta["options"] = stored_opts
            meta_cur.reset()
            meta_cur[db, coll, idx_name] = bson.encode(meta)
        already_multikey.add(idx_name)
    return already_multikey
def _write_index_entries(
self,
db: str,
coll: str,
doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
partials: dict[str, dict[str, Any]] | None = None,
) -> None:
if not indexes:
return
c = self._cursor(_IDX_ENTRIES_TABLE)
id_k = _id_key(doc["_id"])
if partials is None:
partials = self._partial_filters(db, coll)
index_options = self._index_options_map(db, coll)
for name, key_spec, sparse, _unique in indexes:
pf = partials.get(name)
if pf is not None and not matches(doc, pf):
continue
geo = _geo_type_of(key_spec)
if geo is not None:
geo_field, geo_type = geo
opts = index_options.get(name, {})
for cell_bytes in _doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name):
c.reset()
c[db, coll, name, _pack_entry(cell_bytes, id_k)] = b""
continue
for kb in _index_key_variants(doc, key_spec, sparse=sparse):
c.reset()
c[db, coll, name, _pack_entry(kb, id_k)] = b""
def _delete_index_entries(
self,
db: str,
coll: str,
doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
partials: dict[str, dict[str, Any]] | None = None,
) -> None:
if not indexes:
return
c = self._cursor(_IDX_ENTRIES_TABLE)
id_k = _id_key(doc["_id"])
if partials is None:
partials = self._partial_filters(db, coll)
index_options = self._index_options_map(db, coll)
for name, key_spec, sparse, _unique in indexes:
pf = partials.get(name)
if pf is not None and not matches(doc, pf):
continue
geo = _geo_type_of(key_spec)
if geo is not None:
geo_field, geo_type = geo
opts = index_options.get(name, {})
# On the delete path, swallow GeoExtractError. A doc that
# was inserted before geo validation became strict might
# have bad geometry; we still need to allow it to be
# deleted. The index may end up with stale entries we
# can't match, but the next compact / drop_index cleans
# those up. Insert/update remain strict.
try:
cells = _doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name)
except GeoExtractError:
continue
for cell_bytes in cells:
c.reset()
c.set_key(db, coll, name, _pack_entry(cell_bytes, id_k))
if c.search() == 0:
c.remove()
continue
for kb in _index_key_variants(doc, key_spec, sparse=sparse):
c.reset()
c.set_key(db, coll, name, _pack_entry(kb, id_k))
if c.search() == 0:
c.remove()
def _validate_geo_indexes(
self,
db: str,
coll: str,
doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
partials: dict[str, dict[str, Any]] | None = None,
) -> None:
"""Pre-flight every geo index for ``doc``: raise on bad geometry.
Used by the insert / update paths to reject docs *before* writing
them, so a single bad geo coordinate doesn't leave a half-indexed
document behind. The work duplicates ``_write_index_entries``'s
cell computation but is cheap (one Shapely parse + bounds check
per indexed field).
"""
if not indexes:
return
if partials is None:
partials = self._partial_filters(db, coll)
options_map = self._index_options_map(db, coll)
for name, key_spec, _sparse, _unique in indexes:
geo = _geo_type_of(key_spec)
if geo is None:
continue
pf = partials.get(name)
if pf is not None and not matches(doc, pf):
continue
geo_field, geo_type = geo
opts = options_map.get(name, {})
# Compute & discard — `_doc_geo_cells` raises GeoExtractError
# on bad shape or out-of-bounds coords; that's the signal we
# want to bubble up.
_doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name)
def _index_options_map(self, db: str, coll: str) -> dict[str, dict[str, Any]]:
"""Map of index name → its full options blob.
Used by the geo write/delete paths: 2d indexes carry per-index
``bits`` / ``min`` / ``max`` settings that affect the cell encoder,
so we need the option blob to compute the right bucket. Cached
per call (the caller handles per-doc loops).
"""
return {name: dict(opts) for name, _key_spec, opts in self._iter_indexes(db, coll)}
def _unique_conflict(
self,
db: str,
coll: str,
candidate_doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
*,
exclude_id_key: bytes | None,
partials: dict[str, dict[str, Any]] | None = None,
) -> tuple[str, dict[str, Any], dict[str, Any]] | None:
# Returns ``(index_name, key_pattern, key_value)`` so callers
# can build a mongod-shaped dup-key error response with the
# ``keyPattern`` + ``keyValue`` fields drivers' errorResponse
# tests assert on. ``None`` when no conflict.
if not indexes:
return None
c = self._cursor(_IDX_ENTRIES_TABLE)
if partials is None:
partials = self._partial_filters(db, coll)
for name, key_spec, sparse, unique in indexes:
if not unique:
continue
pf = partials.get(name)
if pf is not None and not matches(candidate_doc, pf):
continue
kb = _index_key(candidate_doc, key_spec, sparse=sparse)
if kb is None:
continue
esc_kb = _escape_kb(kb)
seed = esc_kb + _ENTRY_SEP
c.reset()
c.set_key(db, coll, name, seed)
rc = c.search_near()
if rc == wt.WT_NOTFOUND:
continue
if rc < 0 and c.next() != 0:
continue
while True:
k = c.get_key()
if (k[0], k[1], k[2]) != (db, coll, name):
break
packed = bytes(k[3])
row_esc, row_id = _unpack_entry(packed)
if row_esc != esc_kb:
break
if exclude_id_key is None or row_id != exclude_id_key:
key_value = {
field: get_path(candidate_doc, field, default=None) for field in key_spec
}
return name, dict(key_spec), key_value
if c.next() != 0:
break
return None
def _scan_index_for_id_keys(
    self, db: str, coll: str, name: str, kb: bytes, *, prefix: bool = False
) -> list[bytes]:
    """Collect the id keys of index ``name`` entries that match ``kb``.

    ``prefix=False`` (default) is an equality lookup: only rows whose
    escaped key equals ``escape(kb)`` are returned. ``prefix=True`` is a
    compound-prefix lookup: any row whose escaped key starts with
    ``escape(kb)`` is returned. Rows come back in index order and may
    repeat an id key.
    """
    entry_cur = self._cursor(_IDX_ENTRIES_TABLE)
    wanted = _escape_kb(kb)
    entry_cur.set_key(db, coll, name, wanted if prefix else wanted + _ENTRY_SEP)
    exact = entry_cur.search_near()
    if exact == wt.WT_NOTFOUND:
        return []
    positioned = entry_cur.next() == 0 if exact < 0 else True
    found: list[bytes] = []
    while positioned:
        row = entry_cur.get_key()
        if (row[0], row[1], row[2]) != (db, coll, name):
            break
        row_esc, row_id = _unpack_entry(bytes(row[3]))
        still_matching = row_esc.startswith(wanted) if prefix else row_esc == wanted
        if not still_matching:
            break
        found.append(row_id)
        positioned = entry_cur.next() == 0
    return found
def _docs_by_id_keys(self, db: str, coll: str, id_keys: list[bytes]) -> list[dict[str, Any]]:
if not id_keys:
return []
c = self._cursor(_DOC_TABLE)
# Two-stage: WT cursor walk first (raw bytes), then ``bson.decode``
# outside that loop. The cursor work is what needs lock scope;
# decode is pure CPU and benefits from running unsynchronised.
# Multikey indexes write per-element entries, so the same doc's
# id_key can appear more than once for queries that match
# multiple elements. Dedupe while preserving order.
raw: list[bytes] = []
for id_k in dict.fromkeys(id_keys):
c.reset()
c.set_key(db, coll, id_k)
if c.search() == 0:
raw.append(bytes(c.get_value()))
return [bson.decode(b) for b in raw]
# Operators an operator-form clause may contain for the leading-field and
# compound-range index pickers to consider it; any other operator makes
# those pickers decline the filter.
_RANGE_OPS: tuple[str, ...] = ("$eq", "$gt", "$gte", "$lt", "$lte", "$in")
# Geo operators ``_try_geo_index_id_keys`` looks for, in this order, when
# deciding whether a filter clause can be served by a geo index.
_GEO_OPS: tuple[str, ...] = ("$geoWithin", "$geoIntersects", "$near", "$nearSphere")
def _try_geo_index_id_keys(
self, db: str, coll: str, filter: dict[str, Any]
) -> list[bytes] | None:
"""If ``filter`` contains a geo operator on a geo-indexed field,
scan that index's covering cells and return candidate id_keys.
Returns ``None`` when no geo operator is present or no matching
geo index exists — caller falls through to regular pickers and
eventually a full scan. Returns a list (possibly empty) when a
geo index covers the query — caller short-circuits the regular
pickers.
The cell scan over-collects (cell-covering is a superset of the
true intersection); the caller's ``matches()`` step verifies via
:func:`secantus.geo.geo_within` / ``geo_intersects`` and removes
false positives.
"""
# Find a single field with a geo operator on it.
geo_field: str | None = None
geo_op: str | None = None
geo_arg: Any = None
for field, value in filter.items():
if not isinstance(value, dict):
continue
for op in self._GEO_OPS:
if op in value:
geo_field = field
geo_op = op
geo_arg = value[op]
break
if geo_field is not None:
break
if geo_field is None:
return None
# Locate a geo index on that field.
chosen_name: str | None = None
chosen_type: str | None = None
chosen_opts: dict[str, Any] = {}
for name, key_spec, opts in self._iter_indexes(db, coll):
geo = _geo_type_of(key_spec)
if geo is None:
continue
if geo[0] == geo_field:
chosen_name = name
chosen_type = geo[1]
chosen_opts = dict(opts)
break
if chosen_name is None or chosen_type is None:
return None
# Build the query geometry from the operator arg.
cells = self._geo_query_cells(geo_op, geo_arg, chosen_type, chosen_opts)
if cells is None:
# Couldn't compute a covering — defer to full scan.
return None
return self._collect_geo_candidates(db, coll, chosen_name, cells)
def _geo_query_cells(
    self, op: str, arg: Any, geo_type: str, options: Mapping[str, Any]
) -> list[tuple[bytes, bytes]] | None:
    """Turn a geo query into ``(lo, hi)`` byte ranges over index cells.

    Both index types return the same shape so the storage scanner has a
    single code path (see :meth:`_scan_geo_range`). For 2dsphere every
    covering cell becomes a degenerate ``(cell, cell)`` range, i.e. an
    exact lookup; for 2d the result is the single ``(lo, hi)`` range
    from ``planar_2d_covering``.

    Returns ``None`` (caller falls back to a full scan) when the
    operator is unsupported, the argument has the wrong shape, a
    ``$near`` has no max distance, parsing raises ``GeoError``, or a 2d
    index is given a non-planar geometry.
    """
    from secantus.geo import GeoError

    shape_ops = ("$geoWithin", "$geoIntersects")
    near_ops = ("$near", "$nearSphere")
    if op not in shape_ops and op not in near_ops:
        return None
    try:
        if op in shape_ops:
            if not isinstance(arg, Mapping):
                return None
            geom, _ = parse_query_geometry(arg)
        else:
            # Without a max distance there is nothing to bound the
            # covering with. With one, expand into a spherical cap
            # (2dsphere) or a planar disk (2d).
            center, max_d, _min_d, _spherical = self._near_query_geom(arg)
            if max_d is None:
                return None
            from shapely.geometry import Point as _Point
            from secantus.geo import _SphericalCircle

            if geo_type == _GEO_2DSPHERE:
                from secantus.geo import EARTH_RADIUS_METERS

                radius_rad = max_d / EARTH_RADIUS_METERS
                geom = _SphericalCircle(center[0], center[1], radius_rad)
            else:
                geom = _Point(*center).buffer(max_d, quad_segs=16)
    except GeoError:
        return None

    if geo_type != _GEO_2DSPHERE:
        # 2d: the shape must be planar; it collapses to one range.
        from shapely.geometry.base import BaseGeometry as _BG

        if not isinstance(geom, _BG):
            return None
        lo, hi = planar_2d_covering(geom, options)
        return [(encode_cell(lo), encode_cell(hi))]
    return [(encode_cell(cell), encode_cell(cell)) for cell in s2_query_covering(geom)]
def _near_query_geom(
    self, arg: Any
) -> tuple[tuple[float, float], float | None, float | None, bool]:
    """Parse a ``$near`` argument with :mod:`secantus.query`'s own parser.

    Delegating to ``_parse_near_spec`` keeps the spec semantics in one
    place, so the operator handler and the index picker agree on what a
    ``$near`` argument means. The call passes ``default_spherical=False``.
    """
    from secantus.query import _parse_near_spec  # type: ignore[attr-defined]

    parsed = _parse_near_spec(arg, default_spherical=False)
    return parsed
def _collect_geo_candidates(
    self,
    db: str,
    coll: str,
    index_name: str,
    cells: list[tuple[bytes, bytes]],
) -> list[bytes]:
    """Gather de-duplicated id keys from every ``(lo, hi)`` cell range.

    A doc covered by N cells owns N index entries, but is reported only
    once, in first-seen order. The result is still a superset: the
    caller's ``matches()`` step discards docs whose real geometry does
    not satisfy the query.
    """
    entry_cur = self._cursor(_IDX_ENTRIES_TABLE)
    already_seen: set[bytes] = set()
    candidates: list[bytes] = []
    for cell_range in cells:
        lo_bytes, hi_bytes = cell_range
        self._scan_geo_range(
            entry_cur, db, coll, index_name, lo_bytes, hi_bytes, already_seen, candidates
        )
    return candidates
def _scan_geo_range(
    self,
    c: Any,
    db: str,
    coll: str,
    name: str,
    lo_bytes: bytes,
    hi_bytes: bytes,
    seen: set[bytes],
    out: list[bytes],
) -> None:
    """Append unseen id keys for entries with cell id in ``[lo_bytes, hi_bytes]``.

    Starts a forward cursor walk at the escaped lower bound and stops at
    the first entry whose escaped cell id sorts after the escaped upper
    bound, or that belongs to another index. Per the original notes,
    cell ids are fixed 8-byte big-endian, so byte order over the escaped
    form matches numeric cell-id order. ``seen`` and ``out`` are mutated
    in place so one pair can be shared across ranges.
    """
    range_start = _escape_kb(lo_bytes)
    range_stop = _escape_kb(hi_bytes)
    c.reset()
    c.set_key(db, coll, name, range_start)
    exact = c.search_near()
    if exact == wt.WT_NOTFOUND:
        return
    positioned = c.next() == 0 if exact < 0 else True
    while positioned:
        row = c.get_key()
        if (row[0], row[1], row[2]) != (db, coll, name):
            return
        packed = bytes(row[3])
        cut = packed.find(_ENTRY_SEP)
        # Rows without a separator are skipped, not treated as a stop.
        if cut >= 0:
            if packed[:cut] > range_stop:
                return
            doc_key = packed[cut + len(_ENTRY_SEP) :]
            if doc_key not in seen:
                seen.add(doc_key)
                out.append(doc_key)
        positioned = c.next() == 0
def _try_index_lookup(
self, db: str, coll: str, filter: dict[str, Any]
) -> list[dict[str, Any]] | None:
id_keys = self._try_index_id_keys(db, coll, filter)
if id_keys is None:
return None
return self._docs_by_id_keys(db, coll, id_keys)
def _try_index_id_keys(self, db: str, coll: str, filter: dict[str, Any]) -> list[bytes] | None:
"""Same dispatch as ``_try_index_lookup`` but returns id_keys instead
of materialised docs. Used by the write paths (update / delete) so
only matching docs pay ``bson.decode``.
"""
if not filter:
return None
if any(f.startswith("$") for f in filter):
return None
# Geo dispatch first — a $geoWithin / $geoIntersects / $near clause
# on a field with a 2dsphere or 2d index uses the cell-covering
# path. The picker returns None if no geo index covers the query,
# and we fall through to the regular pickers below.
geo_ids = self._try_geo_index_id_keys(db, coll, filter)
if geo_ids is not None:
return geo_ids
# Bare-equality filters of any size can use a compound index whose
# leading fields cover the filter set.
if all(not isinstance(v, dict) for v in filter.values()):
result = self._try_compound_eq_id_keys(db, coll, filter)
if result is not None:
return result
# Compound prefix + trailing operator field (eq fields then range/in).
if len(filter) >= 2:
result = self._try_compound_range_id_keys(db, coll, filter)
if result is not None:
return result
if len(filter) != 1:
return None
field, value = next(iter(filter.items()))
idx_match = self._find_leading_field_index(db, coll, field, filter)
if idx_match is None:
return None
return self._lookup_id_keys_via_leading_field(db, coll, idx_match, value)
def _candidates_iter(
self, db: str, coll: str, filter: dict[str, Any] | None
) -> list[tuple[bytes, bytes]]:
"""Return (id_key, blob) pairs that the write paths should consider.
If an index covers the filter, only the indexed candidates are
fetched; otherwise the full doc table is scanned. Either way,
BSON decode is left to the caller so non-matching docs don't pay
for it. Caller still applies ``matches()`` to the decoded doc —
index lookups can produce false-positive candidates for partial
scans (multikey, prefix overlap, etc).
"""
if filter:
id_keys = self._try_index_id_keys(db, coll, filter)
if id_keys is not None:
c = self._cursor(_DOC_TABLE)
out: list[tuple[bytes, bytes]] = []
# Same dedup contract as ``_docs_by_id_keys``: multikey
# indexes can yield duplicate id_keys for one doc.
for id_k in dict.fromkeys(id_keys):
c.reset()
c.set_key(db, coll, id_k)
if c.search() == 0:
out.append((id_k, bytes(c.get_value())))
return out
return list(self._scan_docs(db, coll))
def _find_leading_field_index(
self,
db: str,
coll: str,
field: str,
query: Mapping[str, Any] | None = None,
) -> tuple[str, int, bool] | None:
"""Best index whose leading field is ``field``.
Returns ``(name, direction, is_compound)``. Single-field indexes
win over compound (tighter scan, no separator math). All fields
must be ASC or DESC. Partial indexes are skipped unless ``query``
implies their ``partialFilterExpression``.
Multikey indexes are not skipped — ``_index_key_variants`` writes
per-element entries, so equality / range / ``$in`` lookups on
the leading field hit at least all true matches. The geo
``2dsphere`` / ``2d`` indexes have non-numeric direction values
and are excluded by the ASC/DESC check below.
"""
partials = self._partial_filters(db, coll)
query = query or {}
compound_fallback: tuple[str, int, bool] | None = None
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
pf = partials.get(name)
if pf is not None and not self._query_implies_partial(query, pf):
continue
idx_fields = list(key_spec)
if not idx_fields or idx_fields[0] != field:
continue
# Geo / hashed / text indexes carry string direction values
# ("2dsphere", "2d", "hashed", "text"); the bare equality
# picker can't drive them. Real numeric direction values are
# 1 / -1.
if any(key_spec[f] not in (1, -1) for f in idx_fields):
continue
d = int(key_spec[field])
if len(idx_fields) == 1:
return name, d, False
if compound_fallback is None:
compound_fallback = (name, d, True)
return compound_fallback
def _lookup_id_keys_via_leading_field(
self,
db: str,
coll: str,
idx_match: tuple[str, int, bool],
value: Any,
) -> list[bytes] | None:
name, direction, is_compound = idx_match
if not isinstance(value, dict):
return self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, value)
if not value or not all(k.startswith("$") for k in value):
return None
if not all(op in self._RANGE_OPS for op in value):
return None
if "$in" in value:
if len(value) != 1 or not isinstance(value["$in"], list):
return None
seen: set[bytes] = set()
id_keys: list[bytes] = []
for v in value["$in"]:
if isinstance(v, dict):
return None
for id_k in self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, v):
if id_k not in seen:
seen.add(id_k)
id_keys.append(id_k)
return id_keys
lower: bytes | None = None
lower_inclusive = True
upper: bytes | None = None
upper_inclusive = True
for op, bound in value.items():
if isinstance(bound, dict):
return None
if op == "$eq":
return self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, bound)
kb = encode_value_directed(bound, direction)
# Operator semantics flip when stored bytes are inverted: in a
# DESC index, "x > 5" means we want stored bytes < enc_desc(5).
effective_op = op
if direction == -1:
effective_op = {"$gt": "$lt", "$gte": "$lte", "$lt": "$gt", "$lte": "$gte"}[op]
if effective_op == "$gt":
lower, lower_inclusive = kb, False
elif effective_op == "$gte":
lower, lower_inclusive = kb, True
elif effective_op == "$lt":
upper, upper_inclusive = kb, False
elif effective_op == "$lte":
upper, upper_inclusive = kb, True
if is_compound:
return self._range_scan_index_leading(
db, coll, name, lower, lower_inclusive, upper, upper_inclusive
)
return self._range_scan_index(
db, coll, name, lower, lower_inclusive, upper, upper_inclusive
)
def _eq_id_keys_via_leading(
    self,
    db: str,
    coll: str,
    name: str,
    direction: int,
    is_compound: bool,
    value: Any,
) -> list[bytes]:
    """Equality lookup of ``value`` on the leading field of index ``name``.

    The value is encoded with the field's ``direction``. For a compound
    index the lookup is a prefix scan on ``encoded + COMPOUND_SEP``;
    for a single-field index it is an exact-key scan.
    """
    encoded = encode_value_directed(value, direction)
    if not is_compound:
        return self._scan_index_for_id_keys(db, coll, name, encoded)
    leading_prefix = encoded + COMPOUND_SEP
    return self._scan_index_for_id_keys(db, coll, name, leading_prefix, prefix=True)
def _pick_compound_eq_index(
self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
"""Find the index that ``_try_compound_eq_id_keys`` would walk for ``filter``.
Returns ``(name, key_spec)`` of the chosen index, or ``None`` if no
index covers the filter as a leading prefix. Pure picker — does
not scan. Multikey indexes are eligible (per-element entries
cover equality lookups); the ASC/DESC direction check excludes
geo indexes.
"""
filter_fields = set(filter)
partials = self._partial_filters(db, coll)
best: tuple[str, dict[str, Any]] | None = None
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
pf = partials.get(name)
if pf is not None:
if not self._query_implies_partial(filter, pf):
continue
# Partial-filter clauses are guaranteed by the index itself,
# so they don't have to appear in the index key.
eff_fields = filter_fields - set(pf)
else:
eff_fields = filter_fields
idx_fields = list(key_spec.keys())
# Geo / hashed / text indexes (string direction values) can't
# serve a bare-equality compound lookup.
if any(key_spec[f] not in (1, -1) for f in idx_fields):
continue
if len(idx_fields) < len(eff_fields):
continue
if set(idx_fields[: len(eff_fields)]) != eff_fields:
continue
if best is None or (len(list(best[1])) > len(idx_fields)):
best = (name, dict(key_spec))
if len(idx_fields) == len(eff_fields):
break
return best
def _try_compound_eq_id_keys(
self, db: str, coll: str, filter: dict[str, Any]
) -> list[bytes] | None:
"""Bare-equality filter against a compound (or single-field) index prefix.
Picks an index whose leading fields (set-wise) match the filter's
fields, and runs an equality (full-cover) or prefix
(strict-leading-prefix) scan against it. Per-field index direction
is honoured by encoding each value with ``encode_value_directed``.
"""
picked = self._pick_compound_eq_index(db, coll, filter)
if picked is None:
return None
name, key_spec = picked
idx_fields = list(key_spec)
# Build kb from the filter fields that are in the index (partial-filter
# clauses live outside the key and are guaranteed by index population).
prefix_fields = [f for f in idx_fields if f in filter]
parts = [encode_value_directed(filter[f], int(key_spec[f])) for f in prefix_fields]
kb = COMPOUND_SEP.join(parts) if len(parts) > 1 else parts[0]
if len(prefix_fields) == len(idx_fields):
return self._scan_index_for_id_keys(db, coll, name, kb)
kb = kb + COMPOUND_SEP
return self._scan_index_for_id_keys(db, coll, name, kb, prefix=True)
def _partition_compound_range_filter(
self, filter: dict[str, Any]
) -> tuple[dict[str, Any], str, dict[str, Any]] | None:
"""Split a filter into ``(eq_fields, operator_field, operator_ops)``.
Returns ``None`` if the filter doesn't fit the compound-range
shape (any number of bare-equality fields plus exactly one
operator-form field whose ops are all in ``_RANGE_OPS``).
"""
eq_fields: dict[str, Any] = {}
operator_field: str | None = None
operator_ops: dict[str, Any] | None = None
for f, v in filter.items():
if isinstance(v, dict):
if not v or not all(k.startswith("$") for k in v):
return None
if not all(op in self._RANGE_OPS for op in v):
return None
if operator_field is not None:
return None
operator_field = f
operator_ops = v
else:
eq_fields[f] = v
if operator_field is None or not eq_fields:
return None
if operator_field in eq_fields:
return None
return eq_fields, operator_field, operator_ops or {}
def _pick_compound_range_index(
    self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
    """Choose the index a compound prefix + operator lookup should walk.

    An index qualifies when:

    * its partial filter (if any) is implied by ``filter``;
    * every direction in its key spec is ``1`` or ``-1``;
    * its leading fields equal, as a set, the filter's bare-equality
      fields; and
    * the field right after that prefix is the operator field.

    Among qualifying indexes the one with the fewest fields wins, the
    earliest one on a tie. Returns ``(name, key_spec_copy)`` or ``None``.
    """
    split = self._partition_compound_range_filter(filter)
    if split is None:
        return None
    equalities, range_field, _ = split
    eq_names = set(equalities)
    n_eq = len(eq_names)
    partial_by_name = self._partial_filters(db, coll)
    chosen: tuple[str, dict[str, Any]] | None = None
    chosen_width = 0
    for index_name, spec, _sparse, _unique in self._all_indexes(db, coll):
        partial = partial_by_name.get(index_name)
        if partial is not None and not self._query_implies_partial(filter, partial):
            continue
        fields = list(spec.keys())
        # String-valued directions (geo / hashed / text) cannot serve an
        # ordered prefix + trailing-operator walk.
        if not all(spec[f] in (1, -1) for f in fields):
            continue
        # The index needs at least one field beyond the equality prefix.
        if len(fields) <= n_eq:
            continue
        if set(fields[:n_eq]) != eq_names or fields[n_eq] != range_field:
            continue
        if chosen is None or chosen_width > len(fields):
            chosen = (index_name, dict(spec))
            chosen_width = len(fields)
            # Prefix plus operator field is the narrowest possible fit.
            if len(fields) == n_eq + 1:
                break
    return chosen
def _try_compound_range_id_keys(
    self, db: str, coll: str, filter: dict[str, Any]
) -> list[bytes] | None:
    """Compound-prefix lookup with a trailing operator field.

    Filters of the form ``{a: 5, b: 10, c: {$gt: 20}}`` (any number of
    leading bare-equality fields followed by exactly one operator-form
    field) walk the compound index by pinning the prefix from the
    equalities and applying the operator's bounds to the next field.

    Supported operator shapes:

    * ``{$in: [...]}`` alone — one equality/prefix scan per listed value,
      with id keys de-duplicated in first-seen order;
    * ``{$eq: v}`` alone — a single equality/prefix scan;
    * any combination of ``$gt``/``$gte``/``$lt``/``$lte`` — one range
      scan. For a descending (``-1``) operator field the operators are
      mirrored, because the stored byte order is reversed.

    Returns the id keys found, or ``None`` when the filter or the
    available indexes do not fit this strategy.
    """
    parts = self._partition_compound_range_filter(filter)
    if parts is None:
        return None
    eq_fields, operator_field, operator_ops = parts
    picked = self._pick_compound_range_index(db, coll, filter)
    if picked is None:
        return None
    name, key_spec = picked
    idx_fields = list(key_spec)
    target_eq_count = len(eq_fields)
    eq_field_names = idx_fields[:target_eq_count]
    op_dir = int(key_spec[operator_field])
    eq_parts = [encode_value_directed(eq_fields[f], int(key_spec[f])) for f in eq_field_names]
    prefix_kb = COMPOUND_SEP.join(eq_parts) if len(eq_parts) > 1 else eq_parts[0]
    prefix_with_sep = prefix_kb + COMPOUND_SEP
    # True when the index has more fields after the operator field; stored
    # keys then continue with COMPOUND_SEP + trailing values.
    has_trailing = len(idx_fields) > target_eq_count + 1

    if "$in" in operator_ops:
        if len(operator_ops) != 1 or not isinstance(operator_ops["$in"], list):
            return None
        seen: set[bytes] = set()
        id_keys: list[bytes] = []
        for v in operator_ops["$in"]:
            if isinstance(v, dict):
                return None
            kb = prefix_with_sep + encode_value_directed(v, op_dir)
            inner_kb = kb + COMPOUND_SEP if has_trailing else kb
            for id_k in self._scan_index_for_id_keys(
                db, coll, name, inner_kb, prefix=has_trailing
            ):
                if id_k not in seen:
                    seen.add(id_k)
                    id_keys.append(id_k)
        return id_keys

    if "$eq" in operator_ops:
        if len(operator_ops) != 1:
            return None
        kb = prefix_with_sep + encode_value_directed(operator_ops["$eq"], op_dir)
        inner_kb = kb + COMPOUND_SEP if has_trailing else kb
        return self._scan_index_for_id_keys(db, coll, name, inner_kb, prefix=has_trailing)

    # On a descending field the byte order is reversed, so each bound
    # swaps sides.
    flipped = {"$gt": "$lt", "$gte": "$lte", "$lt": "$gt", "$lte": "$gte"}
    lower: bytes | None = None
    lower_inclusive = True
    upper: bytes | None = None
    upper_inclusive = True
    for op, bound in operator_ops.items():
        if isinstance(bound, dict):
            return None
        full = prefix_with_sep + encode_value_directed(bound, op_dir)
        # ``.get`` keeps an unexpected operator on the "unsupported -> None"
        # path below instead of raising KeyError on descending fields only.
        effective_op = flipped.get(op, op) if op_dir == -1 else op
        if effective_op == "$gt":
            lower, lower_inclusive = full, False
        elif effective_op == "$gte":
            lower, lower_inclusive = full, True
        elif effective_op == "$lt":
            upper, upper_inclusive = full, False
        elif effective_op == "$lte":
            upper, upper_inclusive = full, True
        else:
            return None
    return self._range_scan_index(
        db,
        coll,
        name,
        lower,
        lower_inclusive,
        upper,
        upper_inclusive,
        prefix=prefix_with_sep,
        # Without this, rows equal to a bound in a wider index would be
        # dropped by an inclusive upper bound and kept by an exclusive
        # lower bound.
        trailing_fields=has_trailing,
    )

def _range_scan_index(
    self,
    db: str,
    coll: str,
    name: str,
    lower: bytes | None,
    lower_inclusive: bool,
    upper: bytes | None,
    upper_inclusive: bool,
    *,
    prefix: bytes | None = None,
    trailing_fields: bool = False,
) -> list[bytes]:
    """Range-scan the index entries for ``name``.

    ``lower`` / ``upper`` are unescaped key bytes (``None`` means
    unbounded on that side); the ``*_inclusive`` flags say whether a row
    equal to the bound is returned.

    Optional ``prefix`` constrains the scan to entries whose escaped
    kb starts with ``escape(prefix)`` — used by compound-index
    prefix+range queries where leading equalities pin part of the kb.

    ``trailing_fields`` should be true when the stored keys carry further
    compound fields after the bounded one. A row then counts as equal to
    a bound if it is exactly the bound or starts with
    ``escape(bound) + escape(COMPOUND_SEP)``. This assumes ``_escape_kb``
    distributes over concatenation. With the default ``False`` only
    exact matches count as equal to a bound.

    Returns the id keys of the matching rows in cursor order.
    """
    c = self._cursor(_IDX_ENTRIES_TABLE)
    esc_prefix = _escape_kb(prefix) if prefix is not None else None
    esc_lower = _escape_kb(lower) if lower is not None else None
    esc_upper = _escape_kb(upper) if upper is not None else None
    # Prefixes identifying rows whose bounded field equals the bound but
    # which continue with more compound fields.
    lower_eq_prefix: bytes | None = None
    upper_eq_prefix: bytes | None = None
    if trailing_fields:
        esc_compound_sep = _escape_kb(COMPOUND_SEP)
        if esc_lower is not None:
            lower_eq_prefix = esc_lower + esc_compound_sep
        if esc_upper is not None:
            upper_eq_prefix = esc_upper + esc_compound_sep
    if esc_lower is not None:
        seed = esc_lower + _ENTRY_SEP
    elif esc_prefix is not None:
        seed = esc_prefix
    else:
        seed = b""
    c.set_key(db, coll, name, seed)
    rc = c.search_near()
    if rc == wt.WT_NOTFOUND:
        return []
    # A negative result means the cursor sits on a smaller key; step
    # forward onto the first candidate.
    if rc < 0 and c.next() != 0:
        return []
    out: list[bytes] = []
    while True:
        k = c.get_key()
        if (k[0], k[1], k[2]) != (db, coll, name):
            break
        row_esc, row_id = _unpack_entry(bytes(k[3]))
        if esc_prefix is not None and not row_esc.startswith(esc_prefix):
            break
        if (
            esc_lower is not None
            and not lower_inclusive
            and (
                row_esc == esc_lower
                or (lower_eq_prefix is not None and row_esc.startswith(lower_eq_prefix))
            )
        ):
            # Exclusive lower bound: skip rows sitting on the bound.
            if c.next() != 0:
                break
            continue
        if esc_upper is not None:
            if upper_inclusive:
                at_bound = upper_eq_prefix is not None and row_esc.startswith(
                    upper_eq_prefix
                )
                if row_esc > esc_upper and not at_bound:
                    break
            elif row_esc >= esc_upper:
                break
        out.append(row_id)
        if c.next() != 0:
            break
    return out
def _range_scan_index_leading(
    self,
    db: str,
    coll: str,
    name: str,
    lower: bytes | None,
    lower_inclusive: bool,
    upper: bytes | None,
    upper_inclusive: bool,
) -> list[bytes]:
    """Range-scan a compound index by bounds on its leading field only.

    A stored row's escaped kb has the layout
    ``escape(enc(leading)) + escape(COMPOUND_SEP) + escape(enc(trailing...))``.
    A row's leading field is taken to equal a bound ``X`` when the row
    starts with ``escape(X) + escape(COMPOUND_SEP)``. A ``startswith``
    test is used rather than locating the separator inside the row,
    because the tail of an escaped numeric encoding can overlap the
    head of the escaped separator, so find/split is unreliable.

    ``lower`` / ``upper`` are unescaped encodings of the leading field
    (``None`` = unbounded). Returns the id keys of the rows in range, in
    cursor order.
    """
    sep_esc = _escape_kb(COMPOUND_SEP)
    cursor = self._cursor(_IDX_ENTRIES_TABLE)
    lo_esc = None if lower is None else _escape_kb(lower)
    hi_esc = None if upper is None else _escape_kb(upper)
    cursor.set_key(db, coll, name, b"" if lo_esc is None else lo_esc)
    near = cursor.search_near()
    if near == wt.WT_NOTFOUND:
        return []
    # Landed on a smaller key: advance to the first candidate row.
    if near < 0 and cursor.next() != 0:
        return []
    lo_equal = None if lo_esc is None else lo_esc + sep_esc
    hi_equal = None if hi_esc is None else hi_esc + sep_esc
    skip_lower_equal = lo_equal is not None and not lower_inclusive
    target = (db, coll, name)
    ids: list[bytes] = []
    positioned = True
    while positioned:
        key = cursor.get_key()
        if (key[0], key[1], key[2]) != target:
            break
        row_esc, row_id = _unpack_entry(bytes(key[3]))
        # Rows on an exclusive lower bound are stepped over, not emitted.
        if not (skip_lower_equal and row_esc.startswith(lo_equal)):
            if hi_esc is not None:
                if upper_inclusive:
                    past_upper = row_esc > hi_esc and not row_esc.startswith(hi_equal)
                else:
                    past_upper = row_esc >= hi_esc
                if past_upper:
                    break
            ids.append(row_id)
        positioned = cursor.next() == 0
    return ids