"""WiredTiger-backed document store.
WiredTiger is the default storage engine for MongoDB. We use the same
engine here so that on-disk semantics line up with what test code would
see against a real ``mongod``.
Indexes use a sidecar entries table (``table:secantus_index_entries``)
with a single trailing ``u`` column packing
``escape(sortkey) + b"\\x00\\x00" + id_key``. The sortkey comes from
``secantus.sortkey`` (typed, byte-sortable BSON encoding), so the WT
B-tree gives us ordered access for free. ``find_matching`` routes a wide
range of filter shapes through the index — equality, ``$eq``, ``$in``,
``$gt``/``$gte``/``$lt``/``$lte`` on a single field, plus compound
indexes when filter fields cover a leading prefix (with optional range
on the next field). Sort-by-indexed-field walks the B-tree in order.
"""
from __future__ import annotations
import contextlib
import datetime as _dt
import functools
import os
import re
import shutil
import tempfile
import threading
import time as _time
import uuid as _uuid
from collections.abc import Callable, Iterable, Mapping
from typing import Any
import bson
import wiredtiger as wt
from bson.int64 import Int64
from bson.timestamp import Timestamp
from secantus.diff import compute_update_description
from secantus.geo import GeoError, parse_doc_geometry, parse_query_geometry, validate_coordinates
from secantus.geo_index import (
encode_cell,
planar_2d_covering_ranges,
planar_2d_index_for_point,
s2_doc_covering,
s2_query_covering,
)
from secantus.paths import get_path, has_path
from secantus.projection import apply_projection
from secantus.query import matches
from secantus.sortkey import COMPOUND_SEP, encode_value, encode_value_directed
from secantus.update import apply_update, find_positional_matches
# String values that mark a key-spec field as a geo index (used in place
# of the usual ``1`` / ``-1`` direction). ``_GEO_TYPES`` is the closed set
# that ``_geo_type_of`` tests membership against.
_GEO_2DSPHERE = "2dsphere"
_GEO_2D = "2d"
_GEO_TYPES = frozenset({_GEO_2DSPHERE, _GEO_2D})
def _geo_type_of(key_spec: Mapping[str, Any]) -> tuple[str, str] | None:
"""Return ``(field, geo_type)`` if ``key_spec`` declares a geo index.
A geo index has exactly one field whose value is the string
``"2dsphere"`` or ``"2d"`` (rather than ``1`` / ``-1``). Compound
geo indexes (geo field + scalar trailing fields) are out of scope
in Phase 2; we treat any spec containing a geo field as geo-only
and ignore the trailing fields. The picker still works because
`$geoWithin` etc. are answered by the cell scan + verifier.
"""
for field, value in key_spec.items():
if isinstance(value, str) and value in _GEO_TYPES:
return field, value
return None
def _doc_geo_cells(
    doc: Mapping[str, Any],
    field: str,
    geo_type: str,
    options: Mapping[str, Any],
    *,
    index_name: str = "",
) -> list[bytes]:
    """Compute the encoded index cells for ``doc``'s geo ``field``.

    Args:
        doc: Document being indexed.
        field: Path of the geo field inside ``doc``.
        geo_type: ``"2dsphere"`` selects the S2 covering; any other value
            is handled as a planar ``2d`` index.
        options: Index options, forwarded to coordinate validation and to
            the planar 2d cell computation.
        index_name: Only used to label a raised ``GeoExtractError``.

    Returns:
        One encoded cell per covering cell for ``2dsphere``; exactly one
        encoded cell for ``2d``; an empty list when the field resolves to
        ``None`` (missing or explicit null), so the doc gets no entry.

    Raises:
        GeoExtractError: the value is present but is not a recognised
            geometry, fails coordinate validation, or is a non-point
            geometry under a ``2d`` index.
    """
    raw = get_path(dict(doc), field)
    if raw is None:
        # Nothing to index: missing and null are both treated as "no entry".
        return []

    def _extract_error(reason: str) -> GeoExtractError:
        # Every failure carries the same index / field / _id context.
        return GeoExtractError(index_name, field, doc.get("_id"), reason)

    shape = parse_doc_geometry(raw)
    if shape is None:
        raise _extract_error(f"value {raw!r} is not a recognised geometry")

    try:
        validate_coordinates(shape, geo_type=geo_type, options=options)
    except GeoError as exc:
        raise _extract_error(str(exc)) from exc

    if geo_type == _GEO_2DSPHERE:
        return [encode_cell(cell) for cell in s2_doc_covering(shape)]

    # Planar 2d indexes hold a single point per document.
    from shapely.geometry import Point as _Point

    if isinstance(shape, _Point):
        return [encode_cell(planar_2d_index_for_point(shape.x, shape.y, options))]
    raise _extract_error("2d index requires a point; got a non-point geometry")
# WiredTiger table URIs. Every table listed here is created (with its
# key / value format) by ``Storage.__init__``.
_COLL_TABLE = "table:secantus_collections"
_DOC_TABLE = "table:secantus_documents"
_IDX_TABLE = "table:secantus_indexes"
_IDX_ENTRIES_TABLE = "table:secantus_index_entries"
_OPLOG_TABLE = "table:secantus_oplog"
_PREIMAGE_TABLE = "table:secantus_preimages"
_OPLOG_META_TABLE = "table:secantus_oplog_meta"
_USERS_TABLE = "table:secantus_users"
_ROLES_TABLE = "table:secantus_roles"
_PROFILE_TABLE = "table:secantus_profile_settings"
_OPLOG_PRUNE_INTERVAL = 1000 # call prune_oplog every N emits
# Separator placed between the escaped sortkey and the id_key inside a
# packed index entry (see ``_pack_entry`` / ``_unpack_entry``).
_ENTRY_SEP = b"\x00\x00"
def _escape_kb(kb: bytes) -> bytes:
"""Order-preserving escape so ``\\x00\\x00`` is unambiguous as a separator."""
return kb.replace(b"\x00", b"\x00\xff")
def _pack_entry(kb: bytes, id_key: bytes) -> bytes:
    """Build the single-column payload for one index entry.

    The layout is ``escaped(kb) + separator + id_key``. Both parts share
    one trailing ``u`` column because WiredTiger length-prefixes ``u``
    columns that are not last in the key, which would break plain
    lexicographic comparison; packed this way the B-tree ordering is the
    sort order.
    """
    prefix = _escape_kb(kb) + _ENTRY_SEP
    return prefix + id_key
def _unpack_entry(packed: bytes) -> tuple[bytes, bytes]:
    """Split a packed index entry into ``(escaped_kb, id_key)``.

    The split happens at the first occurrence of the entry separator.
    The key part is returned still escaped; it is not un-escaped here.

    Raises:
        ValueError: ``packed`` contains no separator. Without this check
            a missing separator would be sliced as if it sat at index -1
            and return two meaningless fragments.
    """
    escaped_kb, sep, id_key = packed.partition(_ENTRY_SEP)
    if not sep:
        raise ValueError(f"malformed index entry (no separator): {packed!r}")
    return escaped_kb, id_key
def extract_backup_archive(
archive_path: str,
target_dir: str,
*,
allow_existing: bool = False,
) -> dict[str, int | str]:
"""Extract a SecantusDB backup archive into ``target_dir``.
Side-channel restore: the archive is unpacked into a fresh
directory that the caller then points a new ``SecantusDBServer`` at
(``SecantusDBServer(storage_path=<target_dir>)``). The function
does **not** touch any running server's storage — that mode of
"hot restore over a live WT connection" can't be done safely
without restructuring how connection threads cache WT sessions,
and isn't what real mongod's restore tooling supports either.
Returns ``{"targetDir": <abs>, "fileCount": <int>, "archive": <abs>}``
on success. Raises ``RuntimeError`` if:
* the archive doesn't exist,
* the archive doesn't contain a ``WiredTiger`` metadata file
(so it's not a SecantusDB / WT backup at all),
* ``target_dir`` already exists, is non-empty, and ``allow_existing``
is False (default).
The WT metadata check runs **before** extraction so a malformed
archive can't pollute ``target_dir``.
"""
import tarfile
abs_archive = os.path.abspath(archive_path)
abs_target = os.path.abspath(target_dir)
if not os.path.isfile(abs_archive):
raise RuntimeError(f"extract_backup_archive: archive not found: {abs_archive}")
if os.path.exists(abs_target):
if not os.path.isdir(abs_target):
raise RuntimeError(
f"extract_backup_archive: target exists and is not a directory: {abs_target}"
)
if os.listdir(abs_target) and not allow_existing:
raise RuntimeError(
"extract_backup_archive: target directory is not empty "
f"(pass allow_existing=True to overlay): {abs_target}"
)
else:
os.makedirs(abs_target)
with tarfile.open(abs_archive, "r:*") as tar:
names = tar.getnames()
if "WiredTiger" not in names:
raise RuntimeError(
f"extract_backup_archive: archive {abs_archive!r} is not "
"a SecantusDB backup (no WiredTiger metadata file inside)"
)
tar.extractall(abs_target, filter="data")
return {
"targetDir": abs_target,
"fileCount": len(names),
"archive": abs_archive,
}
class DuplicateKeyError(Exception):
    """An ``_id`` value collided with an existing one.

    Attributes:
        doc_id: The offending ``_id`` value.
    """

    def __init__(self, doc_id: Any) -> None:
        message = f"duplicate _id: {doc_id!r}"
        super().__init__(message)
        self.doc_id = doc_id
def _is_operator_expr(v: Any) -> bool:
"""True when ``v`` is a query OPERATOR expression (a non-empty dict
whose keys all start with ``$``, e.g. ``{$gt: 5}``) — as opposed to
a literal subdocument equality value (``{f: 1, f2: 2}``). Used by the
upsert seed extraction to tell the two apart."""
return isinstance(v, dict) and len(v) > 0 and all(k.startswith("$") for k in v)
def _id_key(doc_id: Any) -> bytes:
    """Encode an ``_id`` value into its canonical key bytes.

    This is a direct delegation to ``encode_value``, the same
    byte-sortable encoding used for secondary-index entries. As the
    original notes describe it, that has two consequences:

    * Numerically equal ids of different numeric types (``1``, ``1.0``,
      ``Decimal128("1")``) encode to the same bytes, so they address the
      same document and clash on uniqueness.
    * Walking the documents table in key order yields documents in BSON
      cross-type sort order.
    """
    canonical = encode_value(doc_id)
    return canonical
def _is_regex_value(v: Any) -> bool:
    """True when ``v`` is a compiled Python pattern or a ``bson.Regex``."""
    regex_types = (re.Pattern, bson.Regex)
    return isinstance(v, regex_types)
def _id_point_lookup_keys(spec: Any) -> list[bytes] | None:
"""id_keys for an ``{_id: <spec>}`` equality predicate, or ``None``.
The documents table is keyed by ``(db, coll, encode_value(_id))``, so
an ``_id`` equality is a direct primary-key point lookup rather than a
COLLSCAN. This returns the WT key bytes to fetch for:
* a scalar bare equality (``{_id: 5}``),
* ``{_id: {$eq: scalar}}``,
* ``{_id: {$in: [scalars]}}``.
Returns ``None`` (caller falls back to its normal routing / COLLSCAN)
for range operators, regex, subdocument or operator-valued equalities,
or anything else that isn't a pure point lookup. ``$in`` keys come back
deduplicated and in ascending byte (== ``_id``) order so the caller's
sort-acceleration can treat the result as already sorted on ``_id``.
An empty ``$in`` yields ``[]`` — a valid no-match point lookup.
"""
if isinstance(spec, Mapping):
keys = list(spec.keys())
if not keys or not all(isinstance(k, str) and k.startswith("$") for k in keys):
# Literal subdocument _id — leave to the normal path.
return None
if keys == ["$eq"]:
v = spec["$eq"]
if isinstance(v, Mapping) or _is_regex_value(v):
return None
return [_id_key(v)]
if keys == ["$in"]:
vals = spec["$in"]
if not isinstance(vals, (list, tuple)):
return None
if any(isinstance(v, Mapping) or _is_regex_value(v) for v in vals):
return None
return sorted({_id_key(v) for v in vals})
return None
if _is_regex_value(spec):
return None
return [_id_key(spec)]
def _parse_index_collation(spec: Any) -> Any:
"""Parse an index's stored ``collation`` option into a Collation.
Returns ``None`` for falsy / non-dict input, or for collations
that don't support index encoding (``numericOrdering``) — the
picker treats those as "index isn't usable for collation
lookups," falling back to COLLSCAN, while the write path writes
raw-codepoint entries unchanged.
Local import avoids the ``storage → collation → sortkey →
storage`` cycle that a top-level import would create.
"""
if not isinstance(spec, dict) or not spec:
return None
from secantus.collation import parse as _parse_coll
coll = _parse_coll(spec)
if coll is None or not coll.supports_index_encoding:
return None
return coll
def _doc_makes_multikey(doc: Mapping[str, Any], key_spec: Mapping[str, Any]) -> bool:
"""True if any field in ``key_spec`` resolves to a list value in ``doc``.
Such a value is encoded as a single composite array sortkey, so a
later scalar-equality query against this index would silently miss
the doc — the index must fall back to a full scan.
"""
return any(isinstance(get_path(dict(doc), field), list) for field in key_spec)
def _index_key(
    doc: Mapping[str, Any],
    key_spec: Mapping[str, Any],
    *,
    sparse: bool,
    collation: Any = None,
) -> bytes | None:
    """Encode the single canonical index key of ``doc`` under ``key_spec``.

    Every field's value is encoded with ``encode_value_directed`` using
    the field's direction from the spec (converted with ``int``) and the
    given ``collation``. A single-field spec returns that one encoding;
    a multi-field spec returns the per-field encodings joined with
    ``COMPOUND_SEP``.

    A list-valued field is encoded as a whole here (one key for the
    array itself); the per-element keys of a multikey index come from
    ``_index_key_variants``.

    Args:
        doc: Document to derive the key from.
        key_spec: Ordered ``field -> direction`` mapping.
        sparse: When true, a doc lacking any spec field yields ``None``.
        collation: Forwarded to the encoder for every field.

    Returns:
        The key bytes, or ``None`` for a sparse index when a field is
        absent from ``doc``.
    """
    if sparse and not all(has_path(dict(doc), name) for name in key_spec):
        return None

    encoded = [
        encode_value_directed(get_path(dict(doc), name), int(key_spec[name]), collation=collation)
        for name in key_spec
    ]
    if len(encoded) == 1:
        return encoded[0]
    return COMPOUND_SEP.join(encoded)
def _index_key_variants(
    doc: Mapping[str, Any],
    key_spec: Mapping[str, Any],
    *,
    sparse: bool,
    collation: Any = None,
) -> list[bytes]:
    """List every index key ``doc`` contributes under ``key_spec``.

    Candidate values are collected per field:

    * a non-list value contributes just itself;
    * a list value contributes each element whose encoding has not been
      seen earlier in that list, followed by the whole list.

    The keys are the encodings of every combination of one candidate per
    field (cartesian product, in spec order). A single-field spec uses
    the field encoding directly; a compound spec joins the per-field
    encodings with ``COMPOUND_SEP``. Duplicate keys are dropped, keeping
    the first occurrence, so ``[1, 1, 2]`` yields the keys for ``1``,
    ``2`` and the whole array.

    No limit is placed on how many fields may be list-valued; several
    multikey fields multiply the number of emitted keys.

    Args:
        doc: Document to derive keys from.
        key_spec: Ordered ``field -> direction`` mapping.
        sparse: When true, a doc lacking any spec field yields ``[]``.
        collation: Forwarded to the encoder for every value.

    Returns:
        The distinct key bytes in generation order.
    """
    from itertools import product

    names = list(key_spec)
    if sparse and any(not has_path(dict(doc), name) for name in names):
        return []

    directions = [int(key_spec[name]) for name in names]

    # One candidate list per field, in spec order.
    candidates: list[list[Any]] = []
    for name, direction in zip(names, directions):
        value = get_path(dict(doc), name)
        if not isinstance(value, list):
            candidates.append([value])
            continue
        distinct: list[Any] = []
        encoded_seen: set[bytes] = set()
        for item in value:
            item_key = encode_value_directed(item, direction, collation=collation)
            if item_key not in encoded_seen:
                encoded_seen.add(item_key)
                distinct.append(item)
        # The whole-array candidate can encode the same as an element;
        # the key-level de-duplication below absorbs that.
        candidates.append([*distinct, value])

    single_field = len(names) == 1
    out: list[bytes] = []
    emitted: set[bytes] = set()
    for combo in product(*candidates):
        parts = [
            encode_value_directed(val, direction, collation=collation)
            for val, direction in zip(combo, directions)
        ]
        key_bytes = parts[0] if single_field else COMPOUND_SEP.join(parts)
        if key_bytes not in emitted:
            emitted.add(key_bytes)
            out.append(key_bytes)
    return out
# The pure BSON sort comparator lives in ``secantus.ordering`` (no I/O, so it's
# importable without the WiredTiger extension). Re-exported here for the many
# existing ``from secantus.storage import sort_docs / _SortKey / _bson_lt`` call
# sites and ``find_matching``'s internal ``sort_docs`` calls below.
from secantus.ordering import ( # noqa: E402, F401 (re-exported for back-compat)
_bson_lt,
_bson_type_rank,
_SortKey,
_to_decimal,
sort_docs,
)
# NOTE(review): presumably the name of the built-in index on ``_id``; its
# uses are outside this part of the file — confirm.
_ID_INDEX_NAME = "_id_"
class IndexConflict(Exception):
    """A write would duplicate a key in an index (``E11000``).

    Attributes:
        index_name: Name of the index that rejected the write.
        doc_id: ``_id`` of the document involved.
        key_pattern: The index spec, when the raise-site has it.
        key_value: The conflicting field values, when available.

    ``key_pattern`` / ``key_value`` mirror the ``keyPattern`` and
    ``keyValue`` fields that, per the original note, real mongod returns
    in its duplicate-key error response and that driver tests assert on.
    They are optional because some raise-sites do not have the index
    spec at hand.
    """

    def __init__(
        self,
        index_name: str,
        doc_id: Any,
        *,
        key_pattern: dict[str, Any] | None = None,
        key_value: dict[str, Any] | None = None,
    ) -> None:
        message = f"E11000 duplicate key error in index {index_name}: _id={doc_id!r}"
        super().__init__(message)
        self.index_name, self.doc_id = index_name, doc_id
        self.key_pattern, self.key_value = key_pattern, key_value
class WriteConflictError(Exception):
    """Two transactions touched the same item (WiredTiger ``WT_ROLLBACK``).

    As described by the original notes: inside a user (multi-document)
    transaction this reaches the client as mongod's statement-time
    ``WriteConflict`` (code 112) carrying the
    ``TransientTransactionError`` label, and the transaction is aborted
    server-side. Outside a transaction the storage layer first retries
    the write for a short while, then gives up with this same error.
    """
def _is_wt_rollback(exc: BaseException) -> bool:
    """Decide whether ``exc`` is WiredTiger's rollback / conflict signal.

    An instance of ``wt.WiredTigerRollbackError`` qualifies directly.
    Anything else qualifies when its message contains ``WT_ROLLBACK`` or
    ``conflict between concurrent operations`` — a fallback for
    raise-sites that re-wrap the error into the base class.
    """
    if isinstance(exc, wt.WiredTigerRollbackError):
        return True
    text = str(exc)
    markers = ("WT_ROLLBACK", "conflict between concurrent operations")
    return any(marker in text for marker in markers)
# Write-conflict retry policy (the ``_WRITE_CONFLICT_RETRY_*`` constants,
# defined further down). Non-transactional writers that hit a user
# transaction's uncommitted write retry briefly instead of blocking:
# mongod blocks such writers until the transaction commits or aborts,
# which we approximate with a backoff loop bounded by a retry deadline
# (the transaction lifetime cap is 60s, but a multi-second stall already
# covers the overwhelmingly common test patterns; see tasks/backlog.md
# for the divergence note).
# ``MAX_BSON_OBJECT_SIZE``: mongod's per-document BSON cap (16 MiB).
# Duplicated from wire.py on purpose: storage must not import the wire
# layer, and both values pin the same protocol constant.
MAX_BSON_OBJECT_SIZE = 16 * 1024 * 1024
class DocumentTooLargeError(Exception):
    """A write produced a document larger than ``MAX_BSON_OBJECT_SIZE``.

    Attributes:
        code: The error code to report. Per the original note this is
            10334 (BSONObjectTooLarge) for inserts and update-grown
            documents and 17420 for upserts.

    The exception message is ``errmsg`` unchanged; the original note
    says it carries mongod's verbatim wording because driver tests
    assert on it.
    """

    def __init__(self, code: int, errmsg: str) -> None:
        self.code = code
        super().__init__(errmsg)
# Backoff parameters read by ``_retry_write_conflicts`` (all in seconds):
# the total retry budget counted from the first conflict, the initial
# sleep between attempts, and the ceiling the doubled sleep is clamped to.
_WRITE_CONFLICT_RETRY_DEADLINE_S = 5.0
_WRITE_CONFLICT_RETRY_DELAY_S = 0.005
_WRITE_CONFLICT_RETRY_DELAY_MAX_S = 0.02
def _retry_write_conflicts(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Decorator: re-run a ``Storage`` write method after a write conflict.

    The wrapped method is called again, from scratch, whenever it raises
    ``WriteConflictError`` or a ``wt.WiredTigerError`` that
    ``_is_wt_rollback`` recognises. Other ``WiredTigerError`` values
    propagate untouched.

    * If the calling thread has a user transaction (``self._tls.user_txn``
      is set), the conflict is re-raised immediately and never retried.
    * Otherwise attempts are separated by a sleep that starts at
      ``_WRITE_CONFLICT_RETRY_DELAY_S`` and doubles up to
      ``_WRITE_CONFLICT_RETRY_DELAY_MAX_S``. The time budget
      (``_WRITE_CONFLICT_RETRY_DEADLINE_S``) starts at the first conflict.
    * Once the budget is spent the conflict is raised as
      ``WriteConflictError`` — the original exception if it already has
      that type, else a new one chained from it.

    The original note states why re-running is safe: the failed attempt's
    ``_batch_transaction`` has already rolled everything back and the
    per-collection lock is released on the way out.
    """

    @functools.wraps(fn)
    def wrapper(self: Storage, *args: Any, **kwargs: Any) -> Any:
        give_up_at: float | None = None
        pause = _WRITE_CONFLICT_RETRY_DELAY_S
        while True:
            try:
                return fn(self, *args, **kwargs)
            except (WriteConflictError, wt.WiredTigerError) as exc:
                already_typed = isinstance(exc, WriteConflictError)
                if not (already_typed or _is_wt_rollback(exc)):
                    # Some other WiredTiger failure: not ours to retry.
                    raise
                if getattr(self._tls, "user_txn", None) is not None:
                    # Inside a user transaction the caller must see it.
                    raise
                now = _time.monotonic()
                if give_up_at is None:
                    # The clock starts at the first conflict, not at call time.
                    give_up_at = now + _WRITE_CONFLICT_RETRY_DEADLINE_S
                if now >= give_up_at:
                    if already_typed:
                        raise
                    raise WriteConflictError(str(exc)) from exc
                _time.sleep(pause)
                pause = min(pause * 2, _WRITE_CONFLICT_RETRY_DELAY_MAX_S)

    return wrapper
class UserTransactionHandle:
    """Storage-level state for one multi-document transaction.

    This object carries no ``lsid`` / ``txnNumber``; the original notes
    place that bookkeeping in the ``secantus.transactions`` registry.

    Attributes:
        session: The dedicated WT session handed to the constructor.
        cursors: Cursor cache keyed by ``(table, overwrite)``.
        began: Starts ``False``.
        closed: Starts ``False``.
        oplog_entries: Buffered oplog entries; per the original notes
            they are flushed by ``commit_user_transaction``.
        pre_images: Buffered pre-images (``bytes`` or ``None`` each),
            flushed alongside the oplog entries.
    """

    __slots__ = ("session", "cursors", "began", "closed", "oplog_entries", "pre_images")

    def __init__(self, session: Any) -> None:
        self.session = session
        # Lifecycle flags.
        self.began = False
        self.closed = False
        # Per-transaction caches and buffers; fresh containers per handle.
        self.cursors: dict[tuple[str, bool], Any] = {}
        self.oplog_entries: list[dict[str, Any]] = []
        self.pre_images: list[bytes | None] = []
class DocumentValidationError(Exception):
    """A written document failed the collection's ``validator``.

    Per the original note, the command layer catches this and reports a
    mongod-shaped writeError (code 121, ``DocumentValidationFailure``)
    whose ``errInfo.failingDocumentId`` driver tests assert on.

    Attributes:
        doc_id: ``_id`` of the document that failed validation.
    """

    def __init__(self, doc_id: Any) -> None:
        self.doc_id = doc_id
        super().__init__("Document failed validation")
class CreateIndexUnsupported(Exception):
    """``create_index`` was asked for an index type that is not supported.

    The original note lists ``text`` and ``hashed`` as the currently
    unsupported types, and says the command layer catches this to return
    a typed wire error instead of letting a later encoding step fail
    with an opaque internal exception.
    """
class IndexOptionsConflict(Exception):
    """``create_index`` reused an existing index name with other options.

    Raised when the named index already exists in the collection but the
    request differs in ``unique`` / ``sparse`` / ``hidden`` /
    ``expireAfterSeconds`` / ``partialFilterExpression``. Per the
    original note, real mongod rejects this with ``IndexOptionsConflict``
    (code 85) and driver specs assert on the rejection.
    """
class GeoExtractError(Exception):
    """A document's geo field is present but cannot be indexed.

    Per the original note this is raised from the geo-index write path
    (insert, update or index creation) for values such as malformed
    GeoJSON or out-of-range coordinates, and is reported at the
    command-layer write boundary as a write error with code 16572
    ("Can't extract geo keys").

    Attributes:
        index_name: Name of the geo index being written.
        field: The indexed geo field.
        doc_id: ``_id`` of the offending document.
        reason: Human-readable explanation, also embedded in the message.
    """

    def __init__(self, index_name: str, field: str, doc_id: Any, reason: str) -> None:
        message = f"Can't extract geo keys for index {index_name!r} on field {field!r}: {reason}"
        super().__init__(message)
        self.index_name, self.field = index_name, field
        self.doc_id, self.reason = doc_id, reason
class BadHint(Exception):
    """``find_matching`` received a ``hint`` that names no existing index."""
class MinMaxKeyError(Exception):
    """Cursor ``min`` / ``max`` bounds disagree with the hinted index.

    Raised when the bounds do not match the hinted index's key pattern;
    per the original note mongod reports this as error 51174.
    """
def _op_implies_bound(qop: str, qv: Any, pop: str, pv: Any) -> bool:
    """Check whether query constraint ``(qop, qv)`` guarantees ``(pop, pv)``.

    Both values are encoded with ``encode_value`` and the encodings are
    compared as bytes. The answer is ``False`` whenever implication
    cannot be shown: an unknown operator pairing, or any exception from
    encoding either value.

    With ``q`` / ``p`` standing for the encoded query / partial values:

    * ``pop`` is ``$lte`` or ``$lt`` (upper bound on the field):
      ``$eq`` and ``$lte`` queries need ``q <= p`` for ``$lte`` and
      ``q < p`` for ``$lt``; a ``$lt`` query needs only ``q <= p`` for
      either, because ``field < qv <= pv`` already gives ``field < pv``.
    * ``pop`` is ``$gte`` or ``$gt``: the mirror image, with ``$eq`` /
      ``$gte`` / ``$gt`` queries.
    * ``pop`` is ``$eq``: only an ``$eq`` query with equal encodings.
    """
    try:
        q_key = encode_value(qv)
        p_key = encode_value(pv)
    except Exception:
        # Unencodable operand: refuse to claim an implication.
        return False

    if pop == "$eq":
        return qop == "$eq" and q_key == p_key

    if pop in ("$lte", "$lt"):
        if qop == "$lt":
            # A strict query bound tolerates equality with the partial bound.
            return q_key <= p_key
        if qop in ("$eq", "$lte"):
            return q_key <= p_key if pop == "$lte" else q_key < p_key
        return False

    if pop in ("$gte", "$gt"):
        if qop == "$gt":
            return q_key >= p_key
        if qop in ("$eq", "$gte"):
            return q_key >= p_key if pop == "$gte" else q_key > p_key
        return False

    return False
def _clause_implies_bounds(qval: Any, pbound: Mapping[str, Any]) -> bool:
"""True if the query clause ``qval`` (a bare value or an operator
dict) guarantees every constraint in the partial operator dict
``pbound`` (e.g. ``{$lte: 1.5}``)."""
if isinstance(qval, Mapping) and qval and all(k.startswith("$") for k in qval):
q_constraints = list(qval.items())
else:
q_constraints = [("$eq", qval)]
for pop, pv in pbound.items():
if pop not in ("$eq", "$lt", "$lte", "$gt", "$gte"):
return False # partial filter uses an operator we can't reason about
if not any(_op_implies_bound(qop, qv, pop, pv) for qop, qv in q_constraints):
return False
return True
[docs]
class Storage:
def __init__(
    self,
    path: str = ":memory:",
    *,
    oplog_retention_seconds: float = 3600.0,
    oplog_max_entries: int = 100_000,
    time_func: Callable[[], float] | None = None,
    enable_oplog: bool = True,
    ttl_sweep_seconds: float = 60.0,
    noop_heartbeat_seconds: float = 0.0,
    cache_size: str = "1G",
    session_max: int = 1000,
    sync_on_commit: bool = False,
) -> None:
    """Open the WiredTiger connection, create the tables, start threads.

    Args:
        path: WT home directory, created if missing. The special value
            ``":memory:"`` opens an ``in_memory=true`` connection rooted
            in a freshly created temp directory.
        oplog_retention_seconds: Stored (as ``float``) on
            ``self.oplog_retention_seconds``.
        oplog_max_entries: Stored (as ``int``) on
            ``self.oplog_max_entries``.
        time_func: Clock returning seconds; defaults to ``time.time``.
        enable_oplog: When false, no oplog entries are written (the
            oplog tables are still created).
        ttl_sweep_seconds: Interval handed to the TTL sweeper; a value
            ``<= 0`` means the sweeper thread is not started.
        noop_heartbeat_seconds: A value ``> 0`` (with ``enable_oplog``
            true) starts the noop-heartbeat thread.
        cache_size: Passed to WT as ``cache_size=``.
        session_max: Passed to WT as ``session_max=``.
        sync_on_commit: On-disk mode only; selects
            ``transaction_sync=(enabled=true,...)`` instead of
            ``enabled=false``.

    Both background threads are daemon threads and are started before
    this method returns.
    """
    # When False, _emit_oplog short-circuits and writes nothing —
    # used in standalone (non-replica-set) mode to skip the per-write
    # BSON encode + WT cursor write cost of oplog entries that no
    # change-stream client will ever read. The oplog WT tables are
    # still created so toggling at runtime stays safe.
    self.enable_oplog = enable_oplog
    self._lock = threading.RLock()
    self._closed = False
    # Per-insert discriminator counter for timeseries doc keys (see
    # ``_timeseries_doc_suffix``). Only disambiguates inserts that land
    # in the same nanosecond; wall-clock restart-safety comes from the
    # ``time_ns`` prefix.
    self._ts_suffix_counter = 0
    self._tempdir: str | None = None
    # WT's own session_max default is ~120; each client connection
    # thread caches its own session in `threading.local()`, and cross-
    # thread oplog readers open additional short-lived sessions on
    # demand. With a few dozen concurrent client connections plus
    # active change-stream tailers, the default ceiling is hit
    # mid-handshake and surfaces as `out of sessions` /
    # WT_ERROR. mongod itself runs with session_max=33000 — 1000
    # is a generous floor for a single-node test surrogate while
    # still well under the WT hard limit.
    # WT's own cache_size default is 100 MB. With ``in_memory=true``
    # every write also lives in cache, so a workload that inserts a
    # handful of 16 MB documents (mongod's per-doc max) blows the
    # cap as ``WT_CACHE_FULL: operation would overflow cache``.
    # 1 GB gives generous headroom for tests + reasonable
    # in-process workloads while staying well under the limits
    # ``mongod`` itself runs with on a normal box.
    # Tracked so ``checkpoint()`` calls are skipped in in-memory
    # mode (WT's in_memory backend rejects them with a noisy
    # ``__wt_inmem_unsupported_op`` log line on every call).
    self._in_memory = path == ":memory:"
    # Stashed for reuse in restore-archive / explain output.
    self.cache_size = cache_size
    self.session_max = session_max
    self.sync_on_commit = sync_on_commit
    if path == ":memory:":
        self._tempdir = tempfile.mkdtemp(prefix="secantus_wt_")
        home = self._tempdir
        # in_memory=true disables the journal entirely (no files);
        # ephemeral by definition, so durability isn't a concern.
        config = f"create,in_memory=true,session_max={session_max},cache_size={cache_size}"
    else:
        os.makedirs(path, exist_ok=True)
        home = path
        # ``log=(enabled=true)`` turns on WT's redo journal: every
        # transaction commit writes a log record before it returns,
        # and recovery replays the log on reopen. Without this,
        # WT's only durability mechanism is checkpoints (default
        # cadence: every 60s, or on clean ``WT_CONNECTION->close``).
        # On SIGKILL between checkpoints, every write committed since
        # the last checkpoint is lost — which is exactly the failure
        # mode observed by ``bench/chaos.py`` (3-min chaos run,
        # 17 SIGKILLs: 432,881 acked / 1 persisted).
        #
        # ``transaction_sync`` is the per-commit durability knob.
        # Default ``enabled=false,method=fsync`` matches mongod's
        # default ``writeConcern: {w:1, j:false}`` — log records
        # land in the OS page cache, the OS flushes them on its
        # own schedule, SIGKILL is durable, true power-loss
        # between commits can lose data.
        #
        # ``sync_on_commit=True`` (config-file knob) bumps to
        # ``enabled=true,method=fsync``: every commit fsyncs the
        # log before returning, so the wire-protocol equivalent of
        # ``writeConcern: {j: true}`` is effectively enforced for
        # the whole connection. Throughput cost on small-doc
        # inserts is significant (1-2 orders of magnitude),
        # which is why it's opt-in.
        #
        # ``file_max=10MB`` bounds journal segment size; smaller
        # files churn the log more, larger files delay reclamation.
        # 10 MB matches mongod's WT default.
        sync_part = (
            "transaction_sync=(enabled=true,method=fsync)"
            if sync_on_commit
            else "transaction_sync=(enabled=false,method=fsync)"
        )
        config = (
            f"create,session_max={session_max},cache_size={cache_size},"
            f"log=(enabled=true,file_max=10MB),"
            f"{sync_part}"
        )
    # The on-disk WT home is stashed so ``create_archive`` can tar
    # it after a checkpoint without re-deriving the path.
    self.home_path = home
    self._conn = wt.wiredtiger_open(home, config)
    self._tls = threading.local()
    self._all_sessions: list[Any] = []
    # One short-lived bootstrap session creates every table, then closes.
    boot = self._conn.open_session()
    try:
        boot.create(_COLL_TABLE, "key_format=SS,value_format=u")
        boot.create(_DOC_TABLE, "key_format=SSu,value_format=u")
        boot.create(_IDX_TABLE, "key_format=SSS,value_format=u")
        boot.create(_IDX_ENTRIES_TABLE, "key_format=SSSu,value_format=u")
        boot.create(_OPLOG_TABLE, "key_format=q,value_format=u")
        boot.create(_PREIMAGE_TABLE, "key_format=q,value_format=u")
        boot.create(_OPLOG_META_TABLE, "key_format=S,value_format=u")
        boot.create(_USERS_TABLE, "key_format=SS,value_format=u")
        boot.create(_ROLES_TABLE, "key_format=SS,value_format=u")
        boot.create(_PROFILE_TABLE, "key_format=S,value_format=u")
    finally:
        boot.close()
    # Oplog state — durable across restart via _OPLOG_META_TABLE.
    self.oplog_retention_seconds = float(oplog_retention_seconds)
    self.oplog_max_entries = int(oplog_max_entries)
    self._time = time_func or _time.time
    self._oplog_cv = threading.Condition(threading.Lock())
    # Set by ``signal_shutdown()`` at server stop so tailable getMore
    # waiters stop blocking and their connection threads drain *before*
    # ``close()`` tears down the WT connection — a thread mid-WT-op when
    # the connection closes is a use-after-free / native crash.
    self._shutting_down = False
    self._oplog_emit_count = 0
    # Tiny fine-grained lock for seq + timestamp minting. Held in
    # microseconds while reserving the next seq range and bumping
    # the cluster-time counter. Carved out of ``_lock`` (Phase 2.1
    # of the WT concurrency plan) so concurrent writers can mint
    # without contending on the global storage lock.
    self._oplog_seq_lock = threading.Lock()
    # Per-collection RLocks for the CRUD path (Phase 2.4 of the WT
    # concurrency plan). Writes to *different* collections can now
    # run in parallel; writes to the *same* collection still
    # serialise (preserves unique-index correctness + the pre-check
    # racing windows that would otherwise need an architectural
    # refactor of the index-entries schema). DDL operations also
    # acquire the per-coll lock(s) they affect so they cannot
    # reshape schema mid-CRUD-write.
    self._coll_locks: dict[tuple[str, str], threading.RLock] = {}
    self._coll_locks_mutex = threading.Lock()
    with self._lock:
        self._next_seq, self._last_ts_secs, self._last_ts_ord = self._load_oplog_meta()
    # TTL sweeper. Real mongod runs ``ttlMonitor`` every 60s by
    # default; we mirror that. ``ttl_sweep_seconds <= 0`` disables
    # the thread entirely (tests that drive expiry deterministically
    # via ``prune_ttl(now=...)`` use that escape hatch). The
    # sweeper walks every (db, coll) and calls ``prune_ttl`` on
    # each — collections with no TTL index short-circuit cheaply
    # at the index-scan step, so the steady-state cost is small.
    self._ttl_sweep_seconds = float(ttl_sweep_seconds)
    self._ttl_stop = threading.Event()
    self._ttl_thread: threading.Thread | None = None
    if self._ttl_sweep_seconds > 0:
        self._ttl_thread = threading.Thread(
            target=self._ttl_sweep_loop, name="secantus-ttl-sweeper", daemon=True
        )
        self._ttl_thread.start()
    # Periodic noop heartbeat. Real mongod writes ``{op: "n"}``
    # entries to the oplog every ~10s (configurable via
    # ``periodicNoopIntervalSecs``) so cluster time advances and
    # change-stream resume tokens minted from the oplog don't fall
    # outside the retention window during quiet stretches. Default
    # disabled (0) — embedded test users typically don't need it
    # and the extra writes would noise up tight oplog assertions.
    # Set ``noop_heartbeat_seconds=10`` (mongod default) for
    # production-ish behaviour. ``enable_oplog=False`` short-
    # circuits anyway, so the heartbeat is a no-op in that mode.
    self._noop_heartbeat_seconds = float(noop_heartbeat_seconds)
    self._noop_stop = threading.Event()
    self._noop_thread: threading.Thread | None = None
    if self._noop_heartbeat_seconds > 0 and self.enable_oplog:
        self._noop_thread = threading.Thread(
            target=self._noop_heartbeat_loop, name="secantus-noop-heartbeat", daemon=True
        )
        self._noop_thread.start()
def _load_oplog_meta(self) -> tuple[int, int, int]:
    """Recover ``(next_seq, last_ts_secs, last_ts_ord)`` for startup.

    The persisted ``"state"`` meta row is treated as a hint only: it is
    not rewritten on every oplog emit, so after an unclean shutdown it
    can lag behind the oplog rows actually on disk. The result is the
    maximum of the meta row and what the oplog table itself shows, so a
    stale meta row can never cause a seq to be re-issued or the
    timestamp clock to move backwards.

    Returns ``(1, 0, 0)`` when both the meta row and the oplog are empty.
    """
    meta_next, meta_secs, meta_ord = 1, 0, 0
    c = self._cursor(_OPLOG_META_TABLE)
    c.set_key("state")
    if c.search() == 0:
        blob = bytes(c.get_value())
        if blob:
            state = bson.decode(blob)
            meta_next = int(state.get("next_seq", 1))
            meta_secs = int(state.get("last_ts_secs", 0))
            meta_ord = int(state.get("last_ts_ord", 0))
    # Reconcile against the oplog table. Walking backwards from the
    # end yields the highest seq first, then the newest entry that
    # carries a ``Timestamp``.
    c2 = self._cursor(_OPLOG_TABLE)
    c2.reset()
    last_seq = 0
    last_secs = 0
    last_ord = 0
    rc = c2.prev()
    if rc == 0:
        last_seq = int(c2.get_key())
    while rc == 0:
        blob = bytes(c2.get_value())
        if blob:
            ts = bson.decode(blob).get("ts")
            if isinstance(ts, Timestamp):
                last_secs, last_ord = ts.time, ts.inc
                break
        rc = c2.prev()
    # Release the cursor position so no page stays pinned.
    c2.reset()
    next_seq = max(meta_next, last_seq + 1)
    secs, ord_ = max((meta_secs, meta_ord), (last_secs, last_ord))
    return next_seq, secs, ord_
def _persist_oplog_meta(self) -> None:
    """Write the in-memory oplog counters to the ``"state"`` meta row."""
    snapshot = {
        "next_seq": self._next_seq,
        "last_ts_secs": self._last_ts_secs,
        "last_ts_ord": self._last_ts_ord,
    }
    meta_cursor = self._cursor(_OPLOG_META_TABLE)
    meta_cursor["state"] = bson.encode(snapshot)
def _mint_ts(self) -> Timestamp:
    """Produce the next strictly-increasing ``Timestamp(secs, ord)``.

    Caller must hold ``self._oplog_seq_lock``. When the wall clock has
    moved past the last recorded second, the second is adopted and
    ``ord`` restarts at 1; otherwise (same second, or the clock went
    backwards) the recorded second is kept and ``ord`` is bumped.
    """
    wall_secs = int(self._time())
    if wall_secs <= self._last_ts_secs:
        self._last_ts_ord += 1
    else:
        self._last_ts_secs = wall_secs
        self._last_ts_ord = 1
    return Timestamp(self._last_ts_secs, self._last_ts_ord)
def _coll_lock(self, db: str, coll: str) -> threading.RLock:
"""Return the per-collection RLock for ``(db, coll)``, creating it
on first reference. Phase 2.4 of the WT concurrency plan.
CRUD on a given collection serialises through this lock; CRUD on
*other* collections proceeds in parallel. DDL on this collection
also acquires this lock so schema changes cannot interleave with
in-flight writes.
"""
key = (db, coll)
# Fast path: lock already exists — read without any mutation,
# safe under GIL.
existing = self._coll_locks.get(key)
if existing is not None:
return existing
# Create-or-fetch under the small registry mutex. RLocks are
# never removed (collections come and go but the lock identity
# for a given (db, coll) stays stable across drop+recreate to
# avoid races with in-flight writers).
with self._coll_locks_mutex:
existing = self._coll_locks.get(key)
if existing is not None:
return existing
lock = threading.RLock()
self._coll_locks[key] = lock
return lock
def _mint_oplog_seq_and_ts(self, n: int) -> tuple[int, list[Timestamp]]:
"""Atomically reserve ``n`` consecutive oplog seq numbers and mint
``n`` strictly-monotonic timestamps. Returns ``(start_seq,
[ts_0, ..., ts_{n-1}])``.
Held only under ``_oplog_seq_lock`` (microseconds of work) — the
actual oplog cursor writes happen in the caller's WT session
without blocking other writers on this lock.
"""
with self._oplog_seq_lock:
start = self._next_seq
self._next_seq += n
timestamps = [self._mint_ts() for _ in range(n)]
return start, timestamps
def _collection_uuid(self, db: str, coll: str) -> _uuid.UUID:
"""Return the collection's UUID, minting and persisting on first call.
Fast path (UUID already present): no Python lock — straight WT
cursor read on the calling thread's session. This was a major
per-insert bottleneck before Phase 2.4: every write re-acquired
``self._lock`` here, defeating the per-collection lock split.
Slow path (mint a new UUID): take ``_coll_lock`` for the
namespace to serialise the persist; double-check inside the
lock so two racing callers can't mint different UUIDs for the
same collection.
"""
opts = self._coll_options(db, coll) or {}
existing = opts.get("uuid")
if isinstance(existing, _uuid.UUID):
return existing
if isinstance(existing, bson.Binary) and len(existing) == 16:
return _uuid.UUID(bytes=bytes(existing))
if isinstance(existing, bytes) and len(existing) == 16:
return _uuid.UUID(bytes=existing)
# Mint path — take the per-coll lock; re-read after acquiring
# so a racer that won the mint race is observed.
with self._coll_lock(db, coll):
opts = self._coll_options(db, coll) or {}
existing = opts.get("uuid")
if isinstance(existing, _uuid.UUID):
return existing
if isinstance(existing, bson.Binary) and len(existing) == 16:
return _uuid.UUID(bytes=bytes(existing))
if isinstance(existing, bytes) and len(existing) == 16:
return _uuid.UUID(bytes=existing)
new_uuid = _uuid.uuid4()
opts["uuid"] = new_uuid
self._write_coll_options(db, coll, opts)
return new_uuid
def collection_uuid(self, db: str, coll: str) -> _uuid.UUID:
    """Public entry point; delegates to ``_collection_uuid`` unchanged."""
    resolved = self._collection_uuid(db, coll)
    return resolved
def current_cluster_time(self) -> Timestamp:
    """Advance the cluster clock by one tick and return the new value."""
    with self._oplog_seq_lock:
        minted = self._mint_ts()
    # The meta write happens after the seq lock is released: it runs on
    # the calling thread's WT session and does not need the in-memory
    # counters held stable past the mint.
    self._persist_oplog_meta()
    return minted
def peek_cluster_time(self) -> Timestamp:
    """Report the last minted cluster time without advancing the clock.

    Used for reply gossip (``$clusterTime`` / ``operationTime``), which
    observes cluster time rather than moving it. If nothing has been
    minted yet (``_last_ts_secs`` is 0), one timestamp is minted via
    ``current_cluster_time`` so the gossiped value is never
    ``Timestamp(0, 0)``.
    """
    with self._oplog_seq_lock:
        secs = self._last_ts_secs
        if secs:
            return Timestamp(secs, self._last_ts_ord)
    # Virgin store: mint once. Done after releasing the seq lock, since
    # ``current_cluster_time`` acquires it itself.
    return self.current_cluster_time()
def _write_coll_options(self, db: str, coll: str, opts: Mapping[str, Any]) -> None:
    """Persist ``opts`` as the options blob for ``(db, coll)``.

    ``uuid.UUID`` values are stored as BSON Binary subtype 4 because
    bson cannot encode them without a codec. An empty mapping is stored
    as an empty byte string.
    """
    encoded: dict[str, Any] = {
        name: bson.Binary(value.bytes, subtype=4) if isinstance(value, _uuid.UUID) else value
        for name, value in opts.items()
    }
    table = self._cursor(_COLL_TABLE)
    table[db, coll] = bson.encode(encoded) if encoded else b""
def set_collection_options(self, db: str, coll: str, **opts: Any) -> None:
    """Overlay ``opts`` onto the stored options blob, creating the collection if needed."""
    with self._lock:
        self._ensure_collection(db, coll)
        merged = self._coll_options(db, coll) or {}
        for name, value in opts.items():
            merged[name] = value
        self._write_coll_options(db, coll, merged)
def get_collection_options(self, db: str, coll: str) -> dict[str, Any]:
    """Return the options blob for ``(db, coll)``, or ``{}`` when none is stored.

    With the oplog enabled, ``local.oplog.rs`` is synthetic and reports
    a capped-collection shape: ``max`` is the entry cap and ``size`` is
    a notional byte cap of 16 KiB per entry (real byte usage is not
    tracked). For stored blobs, a 16-byte Binary under ``"uuid"`` is
    decoded back into ``uuid.UUID``.
    """
    if self.enable_oplog and db == "local" and coll == "oplog.rs":
        entry_cap = self.oplog_max_entries
        return {
            "capped": True,
            "size": entry_cap * 16 * 1024,
            "max": entry_cap,
        }
    self._refresh_read_snapshot()
    with self._lock:
        stored = self._coll_options(db, coll) or {}
        return {
            name: (
                _uuid.UUID(bytes=bytes(value))
                if name == "uuid" and isinstance(value, bson.Binary) and len(value) == 16
                else value
            )
            for name, value in stored.items()
        }
def _is_oplog_rs(self, db: str, coll: str) -> bool:
"""``(local, oplog.rs)`` is the synthetic oplog view."""
return self.enable_oplog and db == "local" and coll == "oplog.rs"
def _scan_oplog_entries(self) -> list[dict[str, Any]]:
    """Decode and return every row currently in the oplog table.

    A private, short-lived session is opened for the scan so the read
    view reflects rows committed by writer threads on other
    connections (same pattern as ``read_oplog``). Rows with an empty
    value are skipped.
    """
    decoded: list[dict[str, Any]] = []
    with self._lock:
        session = self._conn.open_session()
        try:
            cursor = session.open_cursor(_OPLOG_TABLE, None)
            try:
                while cursor.next() == 0:
                    payload = bytes(cursor.get_value())
                    if payload:
                        decoded.append(bson.decode(payload))
            finally:
                # Best-effort close; a failure here must not mask the result.
                with contextlib.suppress(Exception):
                    cursor.close()
        finally:
            with contextlib.suppress(Exception):
                session.close()
    return decoded
def _find_oplog_rs(
    self,
    filter: dict[str, Any] | None,
    *,
    skip: int,
    limit: int,
    sort: Mapping[str, Any] | None,
    projection: Mapping[str, Any] | None,
    let: dict[str, Any] | None,
    collation: Any,
) -> list[dict[str, Any]]:
    """Serve ``find`` against the synthetic ``local.oplog.rs`` view.

    Entries are scanned in seq order, then filter, sort, skip, limit
    and projection are applied in that order using the pure-Python
    helpers. A ``limit`` of 0 or less means no limit.
    """
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    entries = self._scan_oplog_entries()
    if filter:
        entries = [e for e in entries if matches(e, filter, vars=let, collation=collation_obj)]
    if sort:
        # ``$natural`` is a pseudo-field, not a document field: the scan
        # is already in natural order, so 1 is the identity and -1 is a
        # plain reversal. Anything else goes through the generic sort.
        natural = sort.get("$natural") if isinstance(sort, Mapping) else None
        if natural is None:
            entries = sort_docs(entries, sort)
        elif int(natural) < 0:
            entries = entries[::-1]
    if skip:
        entries = entries[skip:]
    if limit > 0:
        entries = entries[:limit]
    if not projection:
        return entries
    return [apply_projection(e, projection) for e in entries]
def _is_system_users(self, db: str, coll: str) -> bool:
"""``admin.system.users`` is the synthetic view onto the user
store. Mongod surfaces user records there regardless of which
database the user was created against — the per-user ``db``
field of each record names the authentication database. Other
databases' ``system.users`` namespace exists but is empty (also
matches mongod)."""
return db == "admin" and coll == "system.users"
def _scan_user_records(self) -> list[dict[str, Any]]:
    """Decode and return every row in the users table, across all
    databases. Opens a private short-lived session for the same
    cross-thread visibility reason as :meth:`_scan_oplog_entries`.
    Rows with an empty value are skipped."""
    records: list[dict[str, Any]] = []
    with self._lock:
        session = self._conn.open_session()
        try:
            cursor = session.open_cursor(_USERS_TABLE, None)
            try:
                while cursor.next() == 0:
                    payload = bytes(cursor.get_value())
                    if payload:
                        records.append(bson.decode(payload))
            finally:
                # Best-effort close; a failure here must not mask the result.
                with contextlib.suppress(Exception):
                    cursor.close()
        finally:
            with contextlib.suppress(Exception):
                session.close()
    return records
def _find_system_users(
    self,
    filter: dict[str, Any] | None,
    *,
    skip: int,
    limit: int,
    sort: Mapping[str, Any] | None,
    projection: Mapping[str, Any] | None,
    let: dict[str, Any] | None,
    collation: Any,
) -> list[dict[str, Any]]:
    """Serve ``find`` against ``admin.system.users``.

    The stored user records are returned as-is, run through the usual
    filter, sort, skip, limit and projection steps in that order. A
    ``limit`` of 0 or less means no limit."""
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    users = self._scan_user_records()
    if filter:
        users = [u for u in users if matches(u, filter, vars=let, collation=collation_obj)]
    if sort:
        users = sort_docs(users, sort)
    if skip:
        users = users[skip:]
    if limit > 0:
        users = users[:limit]
    if not projection:
        return users
    return [apply_projection(u, projection) for u in users]
def _count_system_users(
    self,
    filter: dict[str, Any] | None,
    *,
    let: dict[str, Any] | None,
    collation: Any,
) -> int:
    """Count user records matching ``filter`` (all records when it is empty)."""
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    users = self._scan_user_records()
    if filter:
        hits = 0
        for user in users:
            if matches(user, filter, vars=let, collation=collation_obj):
                hits += 1
        return hits
    return len(users)
def _is_system_version(self, db: str, coll: str) -> bool:
"""``admin.system.version`` is the synthetic view that surfaces
the user-management auth-schema doc. Mongod stores other
cluster-state docs here too (e.g. the version-2-to-3 schema
upgrade snapshot from MongoDB 2.6 → 3.0), but in modern
deployments the only doc that tooling cares about is
``{_id: "authSchema", currentVersion: 5}`` — the version SCRAM
introduced. Surfacing just that doc is what driver tools
actually check on startup before issuing user-management
commands."""
return db == "admin" and coll == "system.version"
def _system_version_docs(self) -> list[dict[str, Any]]:
"""The fixed contents of ``admin.system.version``.
Mongod's ``authSchema`` currentVersion is ``5`` as of MongoDB
4.0 — the SCRAM-SHA-256 baseline. We advertise the same number
so tools that gate user-management on the schema version
proceed (we implement SCRAM-SHA-256 natively, so 5 is honest).
"""
return [{"_id": "authSchema", "currentVersion": 5}]
def _find_system_version(
    self,
    filter: dict[str, Any] | None,
    *,
    skip: int,
    limit: int,
    sort: Mapping[str, Any] | None,
    projection: Mapping[str, Any] | None,
    let: dict[str, Any] | None,
    collation: Any,
) -> list[dict[str, Any]]:
    """Serve ``find`` against the fixed-doc ``admin.system.version`` view.

    Filter, sort, skip, limit and projection are applied in that order;
    a ``limit`` of 0 or less means no limit."""
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    docs = self._system_version_docs()
    if filter:
        docs = [d for d in docs if matches(d, filter, vars=let, collation=collation_obj)]
    if sort:
        docs = sort_docs(docs, sort)
    if skip:
        docs = docs[skip:]
    if limit > 0:
        docs = docs[:limit]
    if not projection:
        return docs
    return [apply_projection(d, projection) for d in docs]
def _count_system_version(
    self,
    filter: dict[str, Any] | None,
    *,
    let: dict[str, Any] | None,
    collation: Any,
) -> int:
    """Count ``admin.system.version`` docs matching ``filter`` (all when empty)."""
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    docs = self._system_version_docs()
    if filter:
        hits = 0
        for doc in docs:
            if matches(doc, filter, vars=let, collation=collation_obj):
                hits += 1
        return hits
    return len(docs)
def _count_oplog_rs(
    self,
    filter: dict[str, Any] | None,
    *,
    let: dict[str, Any] | None,
    collation: Any,
) -> int:
    """Count oplog entries matching ``filter`` (all entries when it is empty)."""
    from secantus.collation import parse as _parse_collation

    # Parsed up front, before the scan, even when no filter is given.
    collation_obj = _parse_collation(collation)
    entries = self._scan_oplog_entries()
    if filter:
        hits = 0
        for entry in entries:
            if matches(entry, filter, vars=let, collation=collation_obj):
                hits += 1
        return hits
    return len(entries)
def _emit_oplog(
    self,
    entries: list[dict[str, Any]],
    pre_images: list[bytes | None] | None = None,
) -> int:
    """Append ``entries`` to the oplog table.

    NOTE(review): an earlier revision of this docstring said the append
    runs "under ``self._lock``"; this method itself only takes
    ``_oplog_seq_lock`` (via ``_mint_oplog_seq_and_ts``) and
    ``_oplog_cv`` — confirm what callers are expected to hold.

    ``pre_images`` is parallel to ``entries``; non-None elements are
    stored under the matching seq in ``_PREIMAGE_TABLE``. Returns the
    highest seq emitted (0 if ``entries`` is empty). Notifies waiters
    on ``self._oplog_cv`` once the rows have been written.

    If ``self.enable_oplog`` is False, returns 0 immediately — the
    caller's prebuilt ``entries`` list is discarded. The change-stream
    condvar is still notified so any tailable getMore wakes up and
    observes the (empty) state.

    When a user (multi-document) transaction is installed on this
    thread, entries are **buffered** on the transaction handle and
    nothing is written or notified: seqs must be minted at commit
    time, because a statement-time seq could become visible *behind*
    a concurrent change-stream reader's position and the event would
    be silently skipped. ``commit_user_transaction`` flushes the
    buffer through this same method (with the buffering hook
    disarmed) inside the transaction's WT session.

    Entries are shallow-copied before ``ts`` / ``wall`` are filled in,
    so the caller's dicts are not mutated; a caller-supplied ``ts`` or
    ``wall`` is kept as-is.
    """
    # Transaction path: buffer on the handle, write nothing, return 0.
    handle = getattr(self._tls, "user_txn", None)
    if handle is not None:
        if self.enable_oplog and entries:
            if pre_images is None:
                pre_images = [None] * len(entries)
            handle.oplog_entries.extend(entries)
            handle.pre_images.extend(pre_images)
        return 0
    # Oplog disabled: still wake waiters, then bail out.
    if not self.enable_oplog:
        with self._oplog_cv:
            self._oplog_cv.notify_all()
        return 0
    # Note: an empty batch returns without notifying the condvar.
    if not entries:
        return 0
    if pre_images is None:
        pre_images = [None] * len(entries)
    assert len(pre_images) == len(entries)
    # Reserve seq + ts range up-front under the tiny seq lock.
    # The actual cursor writes below run on this thread's WT
    # session without holding any cross-thread Python lock.
    n = len(entries)
    start_seq, ts_range = self._mint_oplog_seq_and_ts(n)
    op_cur = self._cursor(_OPLOG_TABLE)
    # The pre-image cursor is opened lazily, only if some entry has one.
    pre_cur = None
    last_seq = 0
    for i, (entry, pre) in enumerate(zip(entries, pre_images, strict=True)):
        seq = start_seq + i
        entry_with_ts = dict(entry)
        if "ts" not in entry_with_ts:
            entry_with_ts["ts"] = ts_range[i]
        if "wall" not in entry_with_ts:
            entry_with_ts["wall"] = _dt.datetime.now(_dt.timezone.utc)
        op_cur[seq] = bson.encode(entry_with_ts)
        if pre is not None:
            if pre_cur is None:
                pre_cur = self._cursor(_PREIMAGE_TABLE)
            pre_cur[seq] = pre
        last_seq = seq
    # ``_persist_oplog_meta`` was called here on every emit, but
    # under concurrent writers it WT-rollbacks half the time —
    # every writer hits the same single ``"state"`` meta row.
    # The meta row is purely a recovery optimisation, so it is no
    # longer written on this path; ``close`` persists it instead.
    # The seq mint itself is durable because the actual oplog rows
    # are written on every emit.
    # NOTE(review): an older comment here claimed the meta row is
    # also persisted on prune_oplog and that a stale row is corrected
    # by ``_load_oplog_meta``'s scan — verify both against those
    # methods.
    self._oplog_emit_count += len(entries)
    # Amortised pruning: run a prune pass once enough entries have
    # been emitted since the last one.
    if self._oplog_emit_count >= _OPLOG_PRUNE_INTERVAL:
        self._oplog_emit_count = 0
        self._prune_oplog_locked(now=self._time())
    with self._oplog_cv:
        self._oplog_cv.notify_all()
    return last_seq
def read_oplog(
    self,
    *,
    start_seq: int,
    limit: int,
    ns_filter: Callable[[str], bool] | None = None,
) -> list[tuple[int, dict[str, Any]]]:
    """Scan the oplog forward from ``start_seq`` (inclusive).

    Returns ``(seq, entry)`` pairs, stopping once ``limit`` pairs have
    been collected or the table is exhausted. When ``ns_filter`` is
    given, only entries whose ``ns`` (stringified, ``""`` if absent)
    passes it are returned.

    A private short-lived session is used so the read view always
    reflects rows committed by other sessions. The cached per-thread
    session's snapshot is sticky — reusing it across getMore polls
    would never observe oplog rows produced by a writer running on a
    different connection thread.
    """
    collected: list[tuple[int, dict[str, Any]]] = []
    with self._lock:
        session = self._conn.open_session()
        try:
            cursor = session.open_cursor(_OPLOG_TABLE, None)
            try:
                cursor.set_key(int(start_seq))
                where = cursor.search_near()
                # Test NOTFOUND before the sign check: its code is
                # itself a negative number.
                if where == wt.WT_NOTFOUND:
                    return collected
                # Landed on a smaller key: step to the first seq >= start_seq.
                if where < 0 and cursor.next() != 0:
                    return collected
                positioned = True
                while positioned:
                    seq = int(cursor.get_key())
                    payload = bytes(cursor.get_value())
                    if payload:
                        entry = bson.decode(payload)
                        if ns_filter is None or ns_filter(str(entry.get("ns", ""))):
                            collected.append((seq, entry))
                            if len(collected) >= limit:
                                break
                    positioned = cursor.next() == 0
            finally:
                with contextlib.suppress(Exception):
                    cursor.close()
        finally:
            with contextlib.suppress(Exception):
                session.close()
    return collected
def read_preimage(self, seq: int) -> dict[str, Any] | None:
    """Look up the pre-image stored for ``seq``.

    Returns the decoded doc, or ``None`` when no row exists or the
    stored value is empty. Uses a private session for cross-thread
    visibility (see ``read_oplog``).
    """
    with self._lock:
        session = self._conn.open_session()
        try:
            cursor = session.open_cursor(_PREIMAGE_TABLE, None)
            try:
                cursor.set_key(int(seq))
                if cursor.search() == 0:
                    payload = bytes(cursor.get_value())
                    if payload:
                        return bson.decode(payload)
                return None
            finally:
                with contextlib.suppress(Exception):
                    cursor.close()
        finally:
            with contextlib.suppress(Exception):
                session.close()
def oplog_tail_seq(self) -> int:
    """Last seq handed out, read under ``self._lock``; 0 if none has been."""
    with self._lock:
        tail = self._next_seq - 1
    return tail
def oplog_tail_seq_nolock(self) -> int:
    """Last seq handed out, read WITHOUT taking ``self._lock``.

    Intended **only** as the wake predicate of a tailable ``getMore``
    waiting on ``self._oplog_cv``. The write path's lock order is
    ``_lock`` -> ``_oplog_cv``; a waiter already holds ``_oplog_cv``
    (that is what ``cv.wait_for`` does), so taking ``_lock`` from the
    predicate would be an ABBA deadlock with any concurrent writer.
    The bare read is acceptable because ``int`` reads are atomic under
    the GIL and the cv is notified on every commit, so a momentarily
    stale value corrects itself on the predicate's next evaluation.
    """
    next_seq = self._next_seq
    return next_seq - 1
def oplog_floor_seq(self) -> int:
    """First (smallest) seq still present in the oplog table; 0 if empty.

    Uses a private session for cross-thread visibility.
    """
    with self._lock:
        session = self._conn.open_session()
        try:
            cursor = session.open_cursor(_OPLOG_TABLE, None)
            try:
                # A fresh cursor's first ``next()`` lands on the lowest key.
                if cursor.next() == 0:
                    return int(cursor.get_key())
                return 0
            finally:
                with contextlib.suppress(Exception):
                    cursor.close()
        finally:
            with contextlib.suppress(Exception):
                session.close()
def find_seq_for_ts(self, ts: Timestamp) -> int:
    """Return the first seq whose entry has ``ts >= target``.

    Entries are walked in seq order; rows with an empty value or a
    non-``Timestamp`` ``ts`` are skipped. If nothing qualifies,
    ``self._next_seq`` (tail + 1) is returned. Uses a private session
    for cross-thread visibility.
    """
    target = (ts.time, ts.inc)
    with self._lock:
        session = self._conn.open_session()
        try:
            cursor = session.open_cursor(_OPLOG_TABLE, None)
            try:
                while cursor.next() == 0:
                    seq = int(cursor.get_key())
                    payload = bytes(cursor.get_value())
                    if not payload:
                        continue
                    stamp = bson.decode(payload).get("ts")
                    # (time, inc) pairs compare lexicographically.
                    if isinstance(stamp, Timestamp) and (stamp.time, stamp.inc) >= target:
                        return seq
                return self._next_seq
            finally:
                with contextlib.suppress(Exception):
                    cursor.close()
        finally:
            with contextlib.suppress(Exception):
                session.close()
def prune_oplog(self, *, now: float | None = None) -> int:
    """Prune the oplog under ``self._lock`` and return the number of rows dropped.

    Rows are dropped when older than the retention window or when the
    entry cap is exceeded; see ``_prune_oplog_locked``.
    """
    with self._lock:
        removed = self._prune_oplog_locked(now=now)
    return removed
def _ns(self, db: str, coll: str) -> str:
return f"{db}.{coll}"
def _pre_post_images_enabled(self, db: str, coll: str) -> bool:
opts = self._coll_options(db, coll) or {}
cfg = opts.get("changeStreamPreAndPostImages")
return isinstance(cfg, Mapping) and bool(cfg.get("enabled"))
def _prune_oplog_locked(self, *, now: float | None = None) -> int:
    """Delete expired and over-cap oplog rows; return how many were selected.

    A row is expired when its ``ts`` is a ``Timestamp`` whose seconds
    fall before ``now - oplog_retention_seconds`` (``now`` defaults to
    ``self._time()``). If more than ``oplog_max_entries`` rows would
    remain afterwards, the oldest survivors are dropped as well. The
    matching pre-image rows are removed alongside.
    """
    reference = self._time() if now is None else now
    cutoff_secs = int(reference - self.oplog_retention_seconds)
    # Phase 1: scan and collect. Nothing is deleted while the scan
    # cursor is walking the table.
    expired: list[int] = []
    every_seq: list[int] = []
    scan = self._cursor(_OPLOG_TABLE)
    while scan.next() == 0:
        seq = int(scan.get_key())
        payload = bytes(scan.get_value())
        every_seq.append(seq)
        if not payload:
            continue
        stamp = bson.decode(payload).get("ts")
        if isinstance(stamp, Timestamp) and stamp.time < cutoff_secs:
            expired.append(seq)
    # Enforce the entry cap by also dooming the oldest surviving rows.
    surviving = len(every_seq) - len(expired)
    if surviving > self.oplog_max_entries:
        surplus = surviving - self.oplog_max_entries
        already = set(expired)
        for seq in every_seq:
            if surplus <= 0:
                break
            if seq in already:
                continue
            expired.append(seq)
            already.add(seq)
            surplus -= 1
    if not expired:
        return 0
    # Phase 2: delete from both tables. Removal errors (e.g. no
    # pre-image stored for a seq) are ignored.
    op_del = self._cursor(_OPLOG_TABLE)
    pre_del = self._cursor(_PREIMAGE_TABLE)
    for seq in expired:
        for deleter in (op_del, pre_del):
            deleter.set_key(seq)
            with contextlib.suppress(wt.WiredTigerError):
                deleter.remove()
            deleter.reset()
    return len(expired)
# --- Users (auth) ---
def add_user(
    self,
    db: str,
    username: str,
    record: Mapping[str, Any],
    *,
    replace: bool = False,
) -> bool:
    """Store a user record under ``(db, username)``.

    Returns ``False`` without writing when the user already exists and
    ``replace`` is false; otherwise writes (overwriting any existing
    row) and returns ``True``. ``record`` is copied into a dict and
    BSON-encoded as given; its shape is owned by the commands layer.
    """
    with self._lock:
        users = self._cursor(_USERS_TABLE)
        users.set_key(db, username)
        already_present = users.search() == 0
        if already_present and not replace:
            return False
        users.reset()
        users[db, username] = bson.encode(dict(record))
        return True
def get_user(self, db: str, username: str) -> dict[str, Any] | None:
    """Fetch the decoded user record for ``(db, username)``.

    Returns ``None`` when the row is missing or its value is empty.
    """
    with self._lock:
        users = self._cursor(_USERS_TABLE)
        users.set_key(db, username)
        if users.search() == 0:
            payload = bytes(users.get_value())
            if payload:
                return bson.decode(payload)
        return None
def drop_user(self, db: str, username: str) -> bool:
    """Remove the user row for ``(db, username)``; ``True`` if one existed."""
    with self._lock:
        users = self._cursor(_USERS_TABLE)
        users.set_key(db, username)
        existed = users.search() == 0
        if existed:
            users.remove()
        return existed
def list_users(
    self,
    db: str | None = None,
    *,
    skip: int = 0,
    limit: int = 100,
) -> list[dict[str, Any]]:
    """List user records page by page; ``db=None`` spans every database.

    ``skip`` counts rows that match ``db``. A ``limit`` that is
    non-positive or above 1000 is replaced by 1000. Rows with an empty
    value are not returned but still count towards ``skip``.
    """
    if limit <= 0 or limit > 1000:
        limit = 1000
    page: list[dict[str, Any]] = []
    with self._lock:
        users = self._cursor(_USERS_TABLE)
        matched = 0
        while users.next() == 0:
            # The key is (db, username); only the db part is needed here.
            row_db = users.get_key()[0]
            if db is not None and row_db != db:
                continue
            if matched >= skip:
                payload = bytes(users.get_value())
                if payload:
                    page.append(bson.decode(payload))
                    if len(page) >= limit:
                        break
            matched += 1
    return page
# ------------------------------------------------------------------
# Per-database profiling settings.
#
# Real mongod tracks (level, slowms, sampleRate) per database in
# memory + persists to the database's metadata. We persist in a
# dedicated WT table keyed by db name. The dispatch path reads
# these settings on every command — keep ``get_profile`` fast.
# ------------------------------------------------------------------
def get_profile(self, db: str) -> dict[str, Any]:
    """Look up the profiler settings stored for ``db``.

    Falls back to the mongod defaults (level 0 = off, slowms 100,
    sampleRate 1.0) when no row exists or the stored value is empty.
    A stored key that is missing or ``None`` also takes its default;
    falsy-but-valid values such as ``slowms=0`` or ``sampleRate=0.0``
    are returned as stored.
    """
    with self._lock:
        cursor = self._cursor(_PROFILE_TABLE)
        cursor.set_key(db)
        if cursor.search() != 0:
            return {"level": 0, "slowms": 100, "sampleRate": 1.0}
        payload = bytes(cursor.get_value())
        if not payload:
            return {"level": 0, "slowms": 100, "sampleRate": 1.0}
        stored = bson.decode(payload)

        def _pick(key: str, cast: Callable[[Any], Any], fallback: Any) -> Any:
            # Test against ``None`` rather than truthiness so that 0 and
            # 0.0 round-trip instead of collapsing to the default.
            raw = stored.get(key, fallback)
            return fallback if raw is None else cast(raw)

        return {
            "level": _pick("level", int, 0),
            "slowms": _pick("slowms", int, 100),
            "sampleRate": _pick("sampleRate", float, 1.0),
        }
def set_profile(
    self,
    db: str,
    *,
    level: int,
    slowms: int = 100,
    sample_rate: float = 1.0,
) -> None:
    """Validate and store the profiler settings for ``db``.

    Raises:
        ValueError: if ``level`` is not 0, 1 or 2, ``slowms`` is
            negative, or ``sample_rate`` is not within ``[0, 1]``.
    """
    allowed_levels = (0, 1, 2)
    if level not in allowed_levels:
        raise ValueError("level must be 0, 1, or 2")
    if slowms < 0:
        raise ValueError("slowms must be non-negative")
    # Kept as a negated chained comparison so that NaN is rejected too.
    rate_in_range = 0.0 <= sample_rate <= 1.0
    if not rate_in_range:
        raise ValueError("sampleRate must be in [0, 1]")
    settings = {
        "level": int(level),
        "slowms": int(slowms),
        "sampleRate": float(sample_rate),
    }
    with self._lock:
        table = self._cursor(_PROFILE_TABLE)
        table[db] = bson.encode(settings)
def ensure_profile_collection(self, db: str, *, size_bytes: int = 10 * 1024 * 1024) -> None:
    """Create ``<db>.system.profile`` as a capped collection if it is missing.

    ``size_bytes`` (default 10 MB) becomes the capped ``size`` option.
    An existing collection is left untouched.
    """
    profile_coll = "system.profile"
    if not self.collection_exists(db, profile_coll):
        self.create_collection(db, profile_coll)
        self.set_collection_options(db, profile_coll, capped=True, size=int(size_bytes))
# ------------------------------------------------------------------
# Custom roles. Storage layer is a thin BSON-blob CRUD; the commands
# layer owns the role-record shape (privileges + inherited roles)
# and ``secantus.rbac`` owns the privilege-check logic that walks
# the inheritance graph.
# ------------------------------------------------------------------
def add_role(
    self,
    db: str,
    name: str,
    record: Mapping[str, Any],
    *,
    replace: bool = False,
) -> bool:
    """Store a custom role record under ``(db, name)``.

    Returns ``False`` without writing when the role already exists and
    ``replace`` is false; otherwise writes (overwriting any existing
    row) and returns ``True``."""
    with self._lock:
        roles = self._cursor(_ROLES_TABLE)
        roles.set_key(db, name)
        already_present = roles.search() == 0
        if already_present and not replace:
            return False
        roles.reset()
        roles[db, name] = bson.encode(dict(record))
        return True
def get_role(self, db: str, name: str) -> dict[str, Any] | None:
    """Fetch the decoded role record for ``(db, name)``.

    Returns ``None`` when the row is missing or its value is empty.
    """
    # A private short-lived session guarantees cross-thread visibility:
    # another connection thread may have written the role, and this
    # thread's cached session carries a sticky snapshot that would not
    # observe that commit. Same pattern as ``read_oplog``; the extra
    # open/close per call is cheap next to the correctness win.
    with self._lock:
        session = self._conn.open_session()
        try:
            cursor = session.open_cursor(_ROLES_TABLE, None, None)
            try:
                cursor.set_key(db, name)
                if cursor.search() == 0:
                    payload = bytes(cursor.get_value())
                    if payload:
                        return bson.decode(payload)
                return None
            finally:
                with contextlib.suppress(Exception):
                    cursor.close()
        finally:
            with contextlib.suppress(Exception):
                session.close()
def drop_role(self, db: str, name: str) -> bool:
    """Remove the role row for ``(db, name)``; ``True`` if one existed."""
    with self._lock:
        roles = self._cursor(_ROLES_TABLE)
        roles.set_key(db, name)
        existed = roles.search() == 0
        if existed:
            roles.remove()
        return existed
def list_roles(
    self,
    db: str | None = None,
    *,
    skip: int = 0,
    limit: int = 100,
) -> list[dict[str, Any]]:
    """List custom-role records page by page; ``db=None`` spans every db.

    ``skip`` counts rows that match ``db``. A ``limit`` that is
    non-positive or above 1000 is replaced by 1000. Rows with an empty
    value are not returned but still count towards ``skip``.
    """
    if limit <= 0 or limit > 1000:
        limit = 1000
    page: list[dict[str, Any]] = []
    with self._lock:
        roles = self._cursor(_ROLES_TABLE)
        matched = 0
        while roles.next() == 0:
            # The key is (db, name); only the db part is needed here.
            row_db = roles.get_key()[0]
            if db is not None and row_db != db:
                continue
            if matched >= skip:
                payload = bytes(roles.get_value())
                if payload:
                    page.append(bson.decode(payload))
                    if len(page) >= limit:
                        break
            matched += 1
    return page
def signal_shutdown(self) -> None:
    """Flag the store as shutting down and wake every ``_oplog_cv`` waiter.

    Lets tailable getMore waiters return immediately so their
    connection threads can drain before :meth:`close` tears down
    WiredTiger. The flag is only ever set here, never cleared."""
    # Set the flag before notifying so woken waiters observe it.
    self._shutting_down = True
    wakeup = self._oplog_cv
    with wakeup:
        wakeup.notify_all()
def close(self) -> None:
    """Stop background threads, flush state, and tear down WiredTiger.

    Order: signal and join the TTL sweeper and noop heartbeat (2s
    timeout each), then under ``self._lock`` mark the store closed,
    persist the oplog meta row, checkpoint (skipped for in-memory
    stores), close every tracked session and the connection, and
    finally remove the temp directory if one was created. A second call
    returns early once ``_closed`` is set. Failures in the persist,
    checkpoint and close steps are suppressed; a failed tempdir removal
    is reported as a ``ResourceWarning``.
    """
    # Stop background threads before tearing down WT — both the
    # TTL sweeper and the noop heartbeat acquire ``self._lock``,
    # so racing them against close would deadlock or
    # use-after-close.
    self._ttl_stop.set()
    if self._ttl_thread is not None and self._ttl_thread.is_alive():
        self._ttl_thread.join(timeout=2.0)
    self._ttl_thread = None
    self._noop_stop.set()
    if self._noop_thread is not None and self._noop_thread.is_alive():
        self._noop_thread.join(timeout=2.0)
    self._noop_thread = None
    with self._lock:
        # Idempotence guard: only the first caller performs teardown.
        if self._closed:
            return
        self._closed = True
        # Persist the oplog meta one last time. We dropped the
        # per-emit persist in Phase 2.4 (it caused WT-rollback
        # storms under concurrent writers), so this is the
        # canonical place to write the in-memory ``_next_seq``
        # and timestamp counters down to disk before shutdown.
        with contextlib.suppress(Exception):
            self._persist_oplog_meta()
        # Force a checkpoint before tearing the connection down.
        # ``WT_CONNECTION->close`` does this implicitly, but only
        # when logging is off (or hits the connection's
        # close-time flush window). Driving it explicitly here
        # gives a durable on-disk image of the dataset at the
        # moment of shutdown regardless of journal state — the
        # behaviour callers reasonably expect from ``close()``.
        # Skip for in-memory backends: WT's in_memory engine
        # rejects checkpoint() with a noisy stderr log
        # (``__wt_inmem_unsupported_op``) on every call.
        if not self._in_memory:
            with contextlib.suppress(Exception):
                self._session().checkpoint()
        # Close every session tracked in ``_all_sessions``, then the
        # connection itself; each close is best-effort.
        for s in self._all_sessions:
            with contextlib.suppress(Exception):
                s.close()
        self._all_sessions.clear()
        with contextlib.suppress(Exception):
            self._conn.close()
        if self._tempdir is not None:
            # Don't follow symlinks during cleanup. A local attacker
            # racing the mkdtemp could replace `_tempdir` with a
            # symlink to elsewhere on the filesystem before close()
            # fires — `shutil.rmtree(symlink, ignore_errors=True)`
            # would then delete the symlink target. The mkdtemp
            # already creates with mode 0700 (owner-only), but the
            # parent /tmp is world-writable, so this is the
            # belt-and-braces guard. Failures during cleanup are
            # logged but not raised — close() must remain idempotent.
            try:
                if not os.path.islink(self._tempdir):
                    shutil.rmtree(self._tempdir)
            except OSError:
                # Best-effort: log via warnings rather than crash close().
                import warnings as _warn

                _warn.warn(
                    f"failed to remove WiredTiger tempdir {self._tempdir!r}",
                    ResourceWarning,
                    stacklevel=2,
                )
            # Cleared even when removal was skipped or failed.
            self._tempdir = None
def prune_ttl_all_collections(self, *, now: _dt.datetime | None = None) -> int:
    """Run :meth:`prune_ttl` over every collection and return the total pruned.

    The list of ``(db, coll)`` namespaces is snapshotted from the
    collections table while ``self._lock`` is held; each namespace is
    then pruned in turn. Public so the background sweeper, admin
    tooling and tests can drive a deterministic global pass.

    Callers using the cached per-thread session must call
    :meth:`_reset_thread_session` first — WiredTiger snapshots are
    sticky per-session, so reads otherwise miss rows committed by
    other threads. The sweeper does this on every iteration; one-shot
    user calls happen on the writer's thread and see their own writes.

    ``now`` is forwarded unchanged to :meth:`prune_ttl`.
    """
    namespaces: list[tuple[str, str]] = []
    with self._lock:
        cursor = self._cursor(_COLL_TABLE)
        while cursor.next() == 0:
            key = cursor.get_key()
            namespaces.append((key[0], key[1]))

    pruned = 0
    for db, coll in namespaces:
        # A collection may be dropped between the snapshot above and
        # this prune, which fails inside prune_ttl with a missing-coll
        # error. That must never crash the sweeper, so any per-collection
        # failure is swallowed and the pass moves on.
        with contextlib.suppress(Exception):
            pruned += self.prune_ttl(db, coll, now=now)
    return pruned
def _ttl_sweep_loop(self) -> None:
    """Background TTL sweeper body.

    Waits ``_ttl_sweep_seconds`` on ``_ttl_stop`` between passes and
    calls :meth:`prune_ttl_all_collections` each time the wait times
    out. Returns once ``_ttl_stop`` is set or the storage is closed.

    The per-thread WT session is dropped before every pass so the next
    cursor call opens a fresh one. WiredTiger sessions carry a sticky
    read snapshot — without the reset this thread would never observe
    rows committed by other writers and sweeps would always prune 0.
    """
    import logging

    logger = logging.getLogger("secantus.storage.ttl")
    while True:
        stop_requested = self._ttl_stop.wait(self._ttl_sweep_seconds)
        if stop_requested or self._closed:
            return
        self._reset_thread_session()
        try:
            self.prune_ttl_all_collections()
        except Exception:
            # A propagating failure would kill the daemon thread and
            # silently disable expiry, so log it and keep looping.
            logger.exception("ttl sweep failed")
def ensure_oplog_bootstrap(self) -> None:
    """Seed a bootstrap noop into a *fresh* oplog.

    Mirrors mongod, whose first oplog entry is the replica set's
    "initiating set" noop, so ``local.oplog.rs`` is never empty and a
    client tailing it on a brand-new server has something to read.

    Called by :class:`SecantusDBServer` at startup (replica-set
    initiation is a server concern; bare storage instances in unit
    tests keep an empty oplog). The entry is a noop (``op: "n"``).

    Idempotent: emits only when the oplog is enabled and
    ``_next_seq == 1``; in every other case it is a no-op.
    """
    with self._lock:
        if not self.enable_oplog or self._next_seq != 1:
            return
        bootstrap_entry = {"op": "n", "ns": "", "o": {"msg": "initiating set"}}
        self._emit_oplog([bootstrap_entry])
def emit_noop_heartbeat(self) -> int:
    """Append one ``{op: "n"}`` heartbeat to the oplog and return its seq.

    The entry mirrors mongod's periodic noop: ``op = "n"``, an empty
    namespace and a small ``o = {msg: "periodic noop"}`` payload.
    The value returned is whatever ``_emit_oplog`` returns for the
    single-entry batch.

    Public so callers (admin tooling, tests that drive heartbeats
    deterministically) can fire one explicitly.
    """
    heartbeat = {
        "op": "n",
        "ns": "",
        "o": {"msg": "periodic noop"},
    }
    with self._lock:
        seq = self._emit_oplog([heartbeat])
    return seq
def _noop_heartbeat_loop(self) -> None:
    """Background heartbeat body.

    Emits one ``{op: "n"}`` oplog entry via :meth:`emit_noop_heartbeat`
    each time a ``_noop_heartbeat_seconds`` wait on ``_noop_stop``
    times out. Returns once ``_noop_stop`` is set or the storage is
    closed. Failures are logged and swallowed — a transient WT error
    must not kill the daemon thread.
    """
    import logging

    logger = logging.getLogger("secantus.storage.noop")
    while True:
        stop_requested = self._noop_stop.wait(self._noop_heartbeat_seconds)
        if stop_requested or self._closed:
            return
        try:
            self.emit_noop_heartbeat()
        except Exception:
            logger.exception("noop heartbeat failed")
def _reset_thread_session(self) -> None:
    """Discard the calling thread's cached WT session and cursors.

    After this call the next ``_session()`` opens fresh ones. Needed
    when a thread reads in a loop and must observe writes from other
    threads (the snapshot is otherwise sticky). A thread with no
    cached session is left untouched.

    Errors raised while closing cursors or the session are ignored,
    as is the session being absent from ``_all_sessions``.
    """
    tls = self._tls
    session = getattr(tls, "session", None)
    if session is None:
        return

    open_cursors = getattr(tls, "cursors", {}) or {}
    for cursor in open_cursors.values():
        with contextlib.suppress(Exception):
            cursor.close()
    with contextlib.suppress(Exception):
        session.close()

    # Deregister from the shared session list under the lock.
    with self._lock:
        with contextlib.suppress(ValueError):
            self._all_sessions.remove(session)

    tls.session = None
    tls.cursors = {}
def checkpoint(self) -> None:
    """Force a WiredTiger checkpoint to flush dirty pages to disk.

    Backs the ``fsync`` command and the admin UI's maintenance slice.
    Runs under ``self._lock`` so concurrent commands wait their turn.

    A no-op when the storage is closed or the backend is in-memory
    (WT's in_memory engine has no disk to flush and rejects the call
    with a noisy stderr log).
    """
    with self._lock:
        if not (self._closed or self._in_memory):
            self._session().checkpoint()
def create_archive(self, output_path: str) -> dict[str, int | str]:
    """Force a checkpoint, then tar the consistent file set into ``output_path``.

    Returns ``{"path": <abs>, "sizeBytes": <int>}`` on success.

    Raises ``RuntimeError`` for in-memory backends (there is no on-disk
    state to archive) and when the storage is already closed. Any
    error raised while the archive is being written propagates to the
    caller; in that case the partially written archive is removed
    first, so a failed backup never leaves a truncated ``.tar.gz``
    behind that could be mistaken for a good one.

    Uses WT's dedicated ``backup:`` cursor to enumerate the files
    that constitute a consistent snapshot. WT promises during the
    cursor's lifetime that those files won't change and that they
    are read-shareable — the latter matters on Windows, where WT
    otherwise holds exclusive file locks that block ``tarfile``'s
    reads.

    Output is a single gzip-compressed tar (``"w:gz"``).
    """
    import tarfile

    if self._in_memory:
        raise RuntimeError(
            "create_archive: cannot archive an in-memory backend "
            "(WT in_memory engine has no on-disk state)"
        )
    # Resolve to absolute so the returned ``path`` is unambiguous
    # for the caller even if their cwd has shifted.
    abs_out = os.path.abspath(output_path)
    os.makedirs(os.path.dirname(abs_out) or ".", exist_ok=True)
    with self._lock:
        if self._closed:
            raise RuntimeError("create_archive: storage is closed")
        self._session().checkpoint()
        # A private session for the backup cursor so its lifecycle
        # doesn't interfere with the per-thread cached session
        # that handles regular work.
        backup_session = self._conn.open_session()
        try:
            cursor = backup_session.open_cursor("backup:", None, None)
            try:
                # Open the archive *outside* the cleanup guard: if the
                # open itself fails we never truncated anything, so
                # there is nothing of ours to delete.
                tar = tarfile.open(abs_out, "w:gz")
                try:
                    # Tar inline while the cursor is open: WT creates
                    # the ``WiredTiger.backup`` metadata file as part
                    # of the cursor's open state and removes it on
                    # close, so collecting filenames first then tarring
                    # would race the cleanup. Iterate-and-add keeps
                    # every file readable for the duration of the tar.
                    with tar:
                        while cursor.next() == 0:
                            rel = cursor.get_key()
                            full = os.path.join(self.home_path, rel)
                            tar.add(full, arcname=rel)
                except BaseException:
                    # The archive is incomplete (the ``with`` above has
                    # already closed it). Remove it, best-effort, and
                    # let the original error surface.
                    with contextlib.suppress(OSError):
                        os.remove(abs_out)
                    raise
            finally:
                cursor.close()
        finally:
            backup_session.close()
    return {"path": abs_out, "sizeBytes": os.path.getsize(abs_out)}
@contextlib.contextmanager
def _batch_transaction(self, *, sync: bool = False) -> Any:
    """Group multiple cursor writes into one WT transaction = one log record.

    WT auto-commits every individual ``cursor.insert()`` /
    ``cursor.update()`` etc., which means N writes produce N log
    records and N commit overheads. With this wrapper, the same N
    writes share a single commit (and therefore a single log record),
    with the same durability guarantee (all-or-nothing on commit).

    ``sync=True`` passes ``"sync=on"`` to ``commit_transaction`` so
    this individual commit fsyncs the log before returning — used to
    honour ``writeConcern: {j: true}`` on a single write even when the
    daemon otherwise runs with ``sync_on_commit=false``. ``sync=False``
    (default) passes no commit config and so inherits the
    connection's ``transaction_sync`` setting.

    Caller must already hold ``self._lock``. Reads within the
    transaction observe the in-progress writes — fine for our
    unique-conflict probes which need to see uncommitted siblings in
    the same batch.

    If *anything* is raised out of the block — including
    ``KeyboardInterrupt`` / ``SystemExit`` / ``GeneratorExit``, which
    are not ``Exception`` subclasses — the transaction is rolled back
    and the error re-raised. Leaving it open would wedge the cached
    per-thread session: the next ``begin_transaction`` on it would
    fail. Callers that accumulate per-doc errors (e.g.
    ``ordered=False`` insert) should NOT raise out of the block —
    they handle the per-doc errors locally and let the surviving
    writes commit.

    Inside a user (multi-document) transaction this is a no-op
    passthrough: WT doesn't nest transactions, and the statement's
    writes must stay uncommitted in the user transaction until its
    ``commitTransaction``.

    Yields the WT session the writes run on.
    """
    if getattr(self._tls, "user_txn", None) is not None:
        yield self._session()
        return
    session = self._session()
    # Cached cursors must be reset before begin_transaction so they
    # don't carry a stale snapshot from before the transaction
    # boundary. WT documents this requirement explicitly.
    for c in getattr(self._tls, "cursors", {}).values():
        with contextlib.suppress(Exception):
            c.reset()
    session.begin_transaction()
    try:
        yield session
    except BaseException:
        # BaseException, not Exception: an interrupt must still release
        # the transaction. Rollback failures are secondary to the
        # error already in flight, so they are suppressed.
        with contextlib.suppress(Exception):
            session.rollback_transaction()
        raise
    else:
        session.commit_transaction("sync=on" if sync else None)
def _session(self) -> Any:
    """Return the calling thread's WT session, opening it on first use.

    A newly opened session is cached on ``self._tls`` together with an
    empty cursor cache and appended to ``self._all_sessions`` under
    ``self._lock``.
    """
    cached = getattr(self._tls, "session", None)
    if cached is not None:
        return cached

    fresh = self._conn.open_session()
    self._tls.session = fresh
    self._tls.cursors = {}
    with self._lock:
        self._all_sessions.append(fresh)
    return fresh
def _refresh_read_snapshot(self) -> None:
    """Force the per-thread WT session to acquire a fresh read snapshot.

    WiredTiger's default snapshot isolation pins a session's read view
    at first cursor access, and the daemon reuses one session per
    connection thread for the whole lifetime of a TCP connection.
    Without an explicit refresh, a long-lived client connection that
    idles while another connection commits would keep reading the
    stale pre-commit view.

    ``session.reset_snapshot()`` releases the held snapshot so the
    next cursor read picks up the latest committed state. Intended to
    be called at the top of public read entry points.

    Does nothing when the thread is inside a user transaction (its
    whole point is the pinned snapshot) or has no cached session yet.
    """
    tls = self._tls
    if getattr(tls, "user_txn", None) is not None:
        # Reads inside a user transaction must keep seeing the
        # transaction's own view.
        return

    session = getattr(tls, "session", None)
    if session is not None:
        # ``reset_snapshot()`` errors if the session is in an explicit
        # transaction. Reads never run inside one
        # (``_batch_transaction`` is write-only), so suppressing the
        # exception here is purely defensive.
        with contextlib.suppress(Exception):
            session.reset_snapshot()
# -- user (multi-document) transactions --------------------------------
#
# A user transaction owns a dedicated WT session, NOT the connection
# thread's ``threading.local`` one: pymongo can legally send a
# transaction's statements and its retryable commit on different
# pooled connections (= different server threads). Statements run
# with the transaction's session/cursors swapped into ``_tls`` so
# every existing storage path (unique probes, index writes,
# ``_ensure_collection``, ``find_matching``) transparently executes
# inside the WT transaction — read-your-own-writes and the pinned
# snapshot fall out for free. The command layer serializes access
# per transaction; these primitives assume no two threads install
# the same handle concurrently.
def begin_user_transaction(self) -> UserTransactionHandle:
    """Open a dedicated WT session for a multi-document transaction.

    The WT ``begin_transaction`` itself happens lazily on the first
    ``use_user_transaction`` entry so the snapshot pins at the
    transaction's first statement (mongod semantics).

    Raises ``RuntimeError`` if the storage is closed.
    """
    with self._lock:
        if self._closed:
            raise RuntimeError("storage is closed")
        txn_session = self._conn.open_session()
        # Tracked alongside the per-thread sessions so ``close()``'s
        # sweep also reaches transactions that were never finished.
        self._all_sessions.append(txn_session)
        return UserTransactionHandle(txn_session)
@contextlib.contextmanager
def use_user_transaction(self, handle: UserTransactionHandle) -> Any:
    """Run the body with ``handle``'s session installed on this thread.

    On the first entry for a handle the WT transaction is begun and
    ``handle.began`` is set. While the body runs, ``_tls.user_txn``
    points at ``handle`` (arming the oplog buffering hook); it is reset
    to ``None`` on exit, whether or not the body raised.
    """
    first_entry = not handle.began
    if first_entry:
        handle.session.begin_transaction()
        handle.began = True

    tls = self._tls
    with self._install_txn_session(handle):
        tls.user_txn = handle
        try:
            yield
        finally:
            tls.user_txn = None
def commit_user_transaction(
    self,
    handle: UserTransactionHandle,
    *,
    lsid_doc: Mapping[str, Any] | None = None,
    txn_number: int | None = None,
) -> int:
    """Flush the buffered oplog + commit the WT transaction.

    All buffered entries get one shared commit ``Timestamp``
    (mongod stamps every op in a transaction with the commit time)
    plus ``lsid`` / ``txnNumber`` for change-stream events. The
    oplog/preimage rows are written through the transaction's own
    session *before* ``commit_transaction``, so data and oplog
    become visible atomically.

    Args:
        handle: The transaction to commit. If ``handle.began`` is
            false nothing is emitted or committed; the session is
            still released.
        lsid_doc: When given, a copy is stored on each buffered entry
            as ``lsid``.
        txn_number: When given, stored on each buffered entry as
            ``txnNumber`` (wrapped in ``Int64``).

    Returns:
        The value returned by ``_emit_oplog`` for the buffered
        entries, or ``0`` when nothing was emitted.

    Raises:
        Exception: Whatever the emit or the WT commit raised. The
            transaction is aborted first — a failed WT commit cannot
            be retried into success.
    """
    last_seq = 0
    try:
        if handle.began:
            entries = handle.oplog_entries
            pre_images = handle.pre_images
            if entries and self.enable_oplog:
                # Mint the shared commit timestamp before installing
                # the txn session: ``current_cluster_time`` persists
                # oplog meta through the calling thread's session and
                # that write must not ride inside the transaction.
                ts = self.current_cluster_time()
                wall = _dt.datetime.now(_dt.timezone.utc)
                for entry in entries:
                    # ``setdefault``: an entry that already carries its
                    # own ``ts`` / ``wall`` keeps it.
                    entry.setdefault("ts", ts)
                    entry.setdefault("wall", wall)
                    if lsid_doc is not None:
                        entry["lsid"] = dict(lsid_doc)
                    if txn_number is not None:
                        entry["txnNumber"] = Int64(txn_number)
            with self._install_txn_session(handle):
                # ``_tls.user_txn`` is deliberately NOT set here, so
                # ``_emit_oplog`` takes its real write path on the
                # transaction's session instead of re-buffering.
                if entries:
                    last_seq = self._emit_oplog(entries, pre_images)
                handle.session.commit_transaction()
    except Exception:
        # Roll back and release the session, then surface the error.
        self.abort_user_transaction(handle)
        raise
    # Success path: release the transaction's dedicated session.
    self._close_user_txn_session(handle)
    # ``_emit_oplog`` notified before the WT commit (same order as
    # the non-transactional write path); one more notify after the
    # commit guarantees tailable getMore waiters re-poll against
    # the now-visible rows.
    with self._oplog_cv:
        self._oplog_cv.notify_all()
    return last_seq
def abort_user_transaction(self, handle: UserTransactionHandle) -> None:
    """Roll back and release the transaction's WT session. Idempotent.

    A handle that is already closed is left alone. The rollback is
    only attempted when the WT transaction was actually begun, and any
    error it raises is ignored so the session is still released.
    """
    if not handle.closed:
        if handle.began:
            with contextlib.suppress(Exception):
                handle.session.rollback_transaction()
        self._close_user_txn_session(handle)
@contextlib.contextmanager
def _install_txn_session(self, handle: UserTransactionHandle) -> Any:
    """Temporarily swap ``handle``'s session/cursors into ``self._tls``.

    The thread's previous ``session`` and ``cursors`` (``None`` and
    ``{}`` when it had none) are put back on exit, whether or not the
    body raised.
    """
    tls = self._tls
    saved = (getattr(tls, "session", None), getattr(tls, "cursors", {}))
    tls.session, tls.cursors = handle.session, handle.cursors
    try:
        yield
    finally:
        tls.session, tls.cursors = saved
def _close_user_txn_session(self, handle: UserTransactionHandle) -> None:
    """Release a user transaction's cursors and WT session. Idempotent.

    Marks the handle closed, closes and forgets its cursors, closes its
    session and removes that session from ``self._all_sessions``.
    Close errors, and the session not being registered, are ignored.
    """
    if handle.closed:
        return
    handle.closed = True

    for cursor in handle.cursors.values():
        with contextlib.suppress(Exception):
            cursor.close()
    handle.cursors.clear()

    txn_session = handle.session
    with contextlib.suppress(Exception):
        txn_session.close()
    with self._lock:
        with contextlib.suppress(ValueError):
            self._all_sessions.remove(txn_session)
def _cursor(self, table: str, *, overwrite: bool = True) -> Any:
    """Return this thread's cached cursor on ``table``, opening it if needed.

    Cursors are cached per ``(table, overwrite)`` pair on the thread's
    session. A cached cursor is ``reset()`` before being handed back;
    a new one is opened with the ``"overwrite=false"`` config when
    ``overwrite`` is false and with no config otherwise.
    """
    self._session()  # make sure the thread has a session + cursor cache
    cache: dict[tuple[str, bool], Any] = self._tls.cursors
    cache_key = (table, overwrite)

    cached = cache.get(cache_key)
    if cached is not None:
        cached.reset()
        return cached

    config = "overwrite=false" if not overwrite else None
    opened = self._tls.session.open_cursor(table, None, config)
    cache[cache_key] = opened
    return opened
def _coll_options(self, db: str, coll: str) -> dict[str, Any] | None:
    """Look up the stored options document for ``db.coll``.

    Returns ``None`` when the collections table has no row for the
    namespace, ``{}`` when the row's value is empty, and the
    BSON-decoded value otherwise.
    """
    cursor = self._cursor(_COLL_TABLE)
    cursor.set_key(db, coll)
    if cursor.search() != 0:
        return None
    raw = bytes(cursor.get_value())
    if not raw:
        return {}
    return bson.decode(raw)
def _is_timeseries(self, db: str, coll: str) -> bool:
    """Return True when the collection's options contain a ``timeseries`` key."""
    options = self._coll_options(db, coll)
    if not options:
        # Missing collection (None) or empty options document.
        return False
    return "timeseries" in options
def _timeseries_doc_suffix(self) -> bytes:
    """Doc-table key discriminator for timeseries collections.

    Timeseries collections don't enforce ``_id`` uniqueness, but the
    doc table is keyed by the encoded ``_id`` — equal ``_id``s would
    structurally collide. Suffixing the key keeps duplicates adjacent
    while making every key distinct.

    The suffix is 10 bytes: the current ``time_ns`` as 8 big-endian
    bytes followed by a 2-byte big-endian counter. The counter is
    advanced on every call and wraps at ``0x10000``; it disambiguates
    same-nanosecond inserts.

    Reads decode and filter by content, so the suffix is invisible
    above storage — but the ``_id`` point-lookup fast path must not be
    used (it reconstructs the UNsuffixed key).
    """
    counter = (self._ts_suffix_counter + 1) % 0x10000
    self._ts_suffix_counter = counter
    stamp = _time.time_ns().to_bytes(8, "big")
    return stamp + counter.to_bytes(2, "big")
def _ensure_collection(self, db: str, coll: str) -> None:
    """Create an empty-options row for ``db.coll`` unless one already exists."""
    cursor = self._cursor(_COLL_TABLE)
    cursor.set_key(db, coll)
    if cursor.search() != 0:
        # No row yet: register the namespace with an empty options blob.
        cursor.reset()
        cursor[db, coll] = b""
def collection_exists(self, db: str, coll: str) -> bool:
    """Return True when ``db.coll`` has a row in the collections table."""
    with self._lock:
        options = self._coll_options(db, coll)
        return options is not None
def create_collection(self, db: str, coll: str) -> bool:
    """Create ``db.coll`` if it does not exist yet.

    Returns ``False`` without touching anything when the namespace is
    already present. Otherwise writes an empty-options row, obtains the
    collection's UUID and emits a ``create`` command entry
    (``op: "c"`` on ``<db>.$cmd``) to the oplog, then returns ``True``.
    """
    with self._lock:
        cursor = self._cursor(_COLL_TABLE)
        cursor.set_key(db, coll)
        if cursor.search() == 0:
            return False

        cursor.reset()
        cursor[db, coll] = b""
        self._collection_uuid(db, coll)  # mint and persist
        collection_uuid = self._collection_uuid(db, coll)
        create_entry = {
            "op": "c",
            "ns": f"{db}.$cmd",
            "ui": bson.Binary(collection_uuid.bytes, subtype=4),
            "o": {
                "create": coll,
                "idIndex": {"v": 2, "key": {"_id": 1}, "name": "_id_"},
            },
        }
        self._emit_oplog([create_entry])
        return True
def _scan_docs(self, db: str, coll: str) -> Iterable[tuple[bytes, bytes]]:
    """Yield ``(id_key, blob)`` for each doc-table row of ``db.coll``.

    Positions the doc-table cursor with ``search_near`` on
    ``(db, coll, b"")`` and walks forward in key order, stopping at
    the first row that belongs to another namespace or when the table
    is exhausted. Yields nothing when no row can be positioned on.
    """
    cursor = self._cursor(_DOC_TABLE)
    cursor.set_key(db, coll, b"")
    exact = cursor.search_near()
    if exact == wt.WT_NOTFOUND:
        return
    # A negative result means the cursor is sitting before the probe
    # key; step forward once to reach the first candidate row.
    if exact < 0 and cursor.next() != 0:
        return

    positioned = True
    while positioned:
        row_key = cursor.get_key()
        if not (row_key[0] == db and row_key[1] == coll):
            return
        yield bytes(row_key[2]), bytes(cursor.get_value())
        positioned = cursor.next() == 0
def _all_docs(self, db: str, coll: str) -> list[dict[str, Any]]:
    """Return every document of ``db.coll``, decoded, in scan order.

    Works in two stages to keep ``bson.decode`` out of ``self._lock``
    — otherwise an N-doc scan would block every other thread for the
    whole decode loop. The lock covers only the WT cursor walk.
    """
    blobs: list[bytes] = []
    with self._lock:
        for _id_key, blob in self._scan_docs(db, coll):
            blobs.append(blob)

    decoded: list[dict[str, Any]] = []
    for blob in blobs:
        decoded.append(bson.decode(blob))
    return decoded
def _all_docs_with_id_key(self, db: str, coll: str) -> list[tuple[dict[str, Any], bytes]]:
    """Return ``(doc, id_key)`` pairs for every row of ``db.coll``.

    Raw rows are gathered under ``self._lock``; decoding happens after
    the lock is released.
    """
    rows: list[tuple[bytes, bytes]] = []
    with self._lock:
        for id_key, blob in self._scan_docs(db, coll):
            rows.append((id_key, blob))

    pairs: list[tuple[dict[str, Any], bytes]] = []
    for id_key, blob in rows:
        pairs.append((bson.decode(blob), id_key))
    return pairs
def scan_docs_after_id_key(
    self, db: str, coll: str, after: bytes | None
) -> list[tuple[bytes, dict[str, Any]]]:
    """Return rows whose ``id_key`` is strictly greater than ``after``.

    The document table is scanned in natural (id_key) order. ``after``
    of ``None`` returns the entire collection.

    Used by the tailable-cursor producer to emit only the docs
    inserted since the last poll. Returns ``[(id_key, doc), ...]`` —
    callers update their ``after`` checkpoint to the last returned
    ``id_key`` for the next poll.
    """
    # Stage one: collect the raw rows while holding the lock.
    selected: list[tuple[bytes, bytes]] = []
    with self._lock:
        for id_key, blob in self._scan_docs(db, coll):
            if after is not None and not id_key > after:
                continue
            selected.append((id_key, blob))

    # Stage two: decode with the lock released.
    result: list[tuple[bytes, dict[str, Any]]] = []
    for id_key, blob in selected:
        result.append((id_key, bson.decode(blob)))
    return result
def collection_min_id_key(self, db: str, coll: str) -> bytes | None:
    """Smallest ``id_key`` in the collection, or ``None`` if it is empty.

    Used to detect capped-collection rollover for tailable cursors: if
    a cursor's last-returned ``id_key`` is below this, the document it
    was anchored on has been evicted, and mongod kills the cursor with
    ``CappedPositionLost``. ``_scan_docs`` yields in ``id_key`` order,
    so only its first row is consumed.
    """
    with self._lock:
        first_row = next(iter(self._scan_docs(db, coll)), None)
        if first_row is None:
            return None
        return bytes(first_row[0])
def collection_is_capped(self, db: str, coll: str) -> bool:
    """Public predicate: does the collection have ``capped: true`` set?

    The synthetic ``local.oplog.rs`` view always counts as capped
    (mongod models the oplog as a capped collection) even though it
    isn't materialised in the collections table — so tailable cursors
    over it are accepted. A missing collection is not capped.
    """
    with self._lock:
        if self._is_oplog_rs(db, coll):
            return True
        options = self._coll_options(db, coll)
        if not options:
            return False
        return bool(options.get("capped"))
@_retry_write_conflicts
def insert(
    self,
    db: str,
    coll: str,
    docs: Iterable[dict[str, Any]],
    *,
    ordered: bool = True,
    journal: bool = False,
) -> tuple[int, list[dict[str, Any]]]:
    """Insert ``docs`` into ``db.coll`` as one batch.

    The collection is created on demand. A document without ``_id``
    gets a fresh ``ObjectId`` assigned in place (the caller's dict is
    mutated). All writes of the batch — doc rows, index entries,
    capped-collection evictions and oplog rows — run inside a single
    ``_batch_transaction``.

    Args:
        db: Database name.
        coll: Collection name.
        docs: Documents to insert; materialised into a list up front.
        ordered: When true, stop at the first per-document error;
            when false, record the error and continue with the rest.
        journal: Forwarded as ``sync=`` to ``_batch_transaction``.

    Returns:
        ``(inserted, errors)`` — the number of documents written and a
        list of per-document error dicts carrying ``index`` (position
        in ``docs``), ``code`` and ``errmsg`` (plus ``keyPattern`` /
        ``keyValue`` for unique-index conflicts). Codes produced here:
        11000 (duplicate key), 16572 (geo validation), 10334 (document
        too large).

    Raises:
        WriteConflictError: When the doc-table insert fails with a WT
            rollback (a concurrency conflict rather than a duplicate).
    """
    # Materialized so the conflict-retry wrapper can safely re-run
    # the whole method (a generator would arrive exhausted).
    docs = list(docs)
    inserted = 0
    errors: list[dict[str, Any]] = []
    oplog_entries: list[dict[str, Any]] = []
    # Doc-table keys written by this batch; handed to the capped-bounds
    # pass so it never evicts what was just inserted.
    fresh_id_keys: set[bytes] = set()
    oplog_on = self.enable_oplog
    with self._coll_lock(db, coll), self._batch_transaction(sync=journal):
        # Per-collection lock (Phase 2.4): writes to other
        # collections proceed in parallel; same-collection writes
        # still serialise to keep the unique-index pre-check
        # race-free. _batch_transaction wraps the per-doc cursor
        # inserts (doc table + index entries + oplog) in one
        # explicit WT transaction so they share a single commit /
        # log record.
        self._ensure_collection(db, coll)
        ns = self._ns(db, coll) if oplog_on else ""
        ui = self._collection_uuid(db, coll) if oplog_on else None
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        multikey_names = self._multikey_index_names(db, coll)
        timeseries = self._is_timeseries(db, coll)
        for index, doc in enumerate(docs):
            if "_id" not in doc:
                doc["_id"] = bson.ObjectId()
            key = _id_key(doc["_id"])
            if timeseries:
                # Duplicate _ids are legal in timeseries collections —
                # see _timeseries_doc_suffix.
                key += self._timeseries_doc_suffix()
            # Unique-index pre-check before anything is written.
            conflict = self._unique_conflict(
                db, coll, doc, indexes, exclude_id_key=None, partials=partials
            )
            if conflict is not None:
                cname, kpat, kval = conflict
                errors.append(
                    {
                        "index": index,
                        "code": 11000,
                        "errmsg": (
                            f"E11000 duplicate key error in index {cname}: _id={doc['_id']!r}"
                        ),
                        "keyPattern": kpat,
                        "keyValue": kval,
                    }
                )
                if ordered:
                    break
                continue
            # Pre-flight every geo index: a bad geometry should reject
            # the doc *before* it lands in the doc table, so we don't
            # leave a half-indexed write behind. Validation is cheap;
            # _write_index_entries below recomputes the same cells.
            try:
                self._validate_geo_indexes(db, coll, doc, indexes, partials)
            except GeoExtractError as exc:
                errors.append({"index": index, "code": 16572, "errmsg": str(exc)})
                if ordered:
                    break
                continue
            blob = bson.encode(doc)
            if len(blob) > MAX_BSON_OBJECT_SIZE:
                # mongod rejects per-document at insert time with
                # BSONObjectTooLarge (10334) and this exact wording.
                errors.append(
                    {
                        "index": index,
                        "code": 10334,
                        "errmsg": (
                            f"object to insert too large. size in bytes: "
                            f"{len(blob)}, max size: {MAX_BSON_OBJECT_SIZE}"
                        ),
                    }
                )
                if ordered:
                    break
                continue
            # ``overwrite=False`` cursor: inserting an existing key
            # fails instead of silently replacing the row.
            doc_cur = self._cursor(_DOC_TABLE, overwrite=False)
            doc_cur.set_key(db, coll, key)
            doc_cur.set_value(blob)
            try:
                doc_cur.insert()
            except wt.WiredTigerError as exc:
                if _is_wt_rollback(exc):
                    # Concurrency conflict, not a duplicate key —
                    # surface for transaction/retry handling instead
                    # of lying with an E11000.
                    raise WriteConflictError(str(exc)) from exc
                # Any other WT error from the insert is reported as a
                # duplicate ``_id``.
                errors.append(
                    {
                        "index": index,
                        "code": 11000,
                        "errmsg": f"E11000 duplicate key error: _id {doc['_id']!r}",
                    }
                )
                if ordered:
                    break
                continue
            self._write_index_entries(db, coll, doc, indexes, partials, id_key_override=key)
            multikey_names = self._maybe_mark_multikey(db, coll, doc, indexes, multikey_names)
            inserted += 1
            if oplog_on:
                oplog_entries.append(
                    {
                        "op": "i",
                        "ns": ns,
                        "ui": bson.Binary(ui.bytes, subtype=4),
                        "o": dict(doc),
                        "o2": {"_id": doc["_id"]},
                    }
                )
            fresh_id_keys.add(key)
        # Capped collections: evict the oldest rows that push the
        # collection past its bounds, inside the same transaction.
        cap_entries, cap_pre_images = self._enforce_capped_bounds_locked(
            db, coll, fresh_id_keys, indexes, partials, oplog_on, ns, ui
        )
        if oplog_entries or cap_entries:
            # Insert entries carry no pre-image; eviction deletes may.
            pre_images = [None] * len(oplog_entries) + cap_pre_images
            self._emit_oplog(oplog_entries + cap_entries, pre_images)
    return inserted, errors
def _enforce_capped_bounds_locked(
    self,
    db: str,
    coll: str,
    fresh_id_keys: set[bytes],
    indexes: list[tuple[str, dict[str, Any], bool, bool]],
    partials: dict[str, dict[str, Any]],
    oplog_on: bool,
    ns: str,
    ui: _uuid.UUID | None,
) -> tuple[list[dict[str, Any]], list[bytes | None]]:
    """Evict oldest non-fresh docs from a capped collection until within bounds.

    "Oldest" is the natural-order walk over the doc table, which
    matches insertion order when ``_id`` is monotonic (e.g. the
    default ObjectId). For non-monotonic ``_id`` values the eviction
    order reflects ``_id`` byte order, not literal insertion order —
    capped users with custom ``_id`` should not rely on FIFO semantics.

    Bounds come from the collection options: ``size`` limits the sum
    of stored blob lengths, ``max`` limits the document count. A
    collection that is not capped, or has neither bound, is untouched.

    Returns ``(oplog_entries, pre_images)``: one delete entry per
    evicted document when ``oplog_on`` is true, paired with the
    document's encoded pre-image (or ``None`` when pre/post images are
    not enabled for the collection).
    """
    options = self._coll_options(db, coll) or {}
    if not options.get("capped"):
        return [], []
    size_limit = options.get("size")
    max_limit = options.get("max")
    if size_limit is None and max_limit is None:
        return [], []

    rows = list(self._scan_docs(db, coll))
    bytes_left = sum(len(blob) for _id_key, blob in rows)
    docs_left = len(rows)
    delete_entries: list[dict[str, Any]] = []
    delete_pre_images: list[bytes | None] = []
    preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)

    for id_key, blob in rows:
        over_size = size_limit is not None and bytes_left > size_limit
        over_max = max_limit is not None and docs_left > max_limit
        if not (over_size or over_max):
            break
        if id_key in fresh_id_keys:
            # Don't evict docs we just inserted in this batch — they
            # always sort to the tail with monotonic _ids, so reaching
            # one means everything left is fresh too.
            break

        victim = bson.decode(blob)
        self._delete_index_entries(db, coll, victim, indexes, partials)
        doc_cursor = self._cursor(_DOC_TABLE)
        doc_cursor.set_key(db, coll, id_key)
        doc_cursor.remove()
        bytes_left -= len(blob)
        docs_left -= 1

        if oplog_on:
            delete_entry: dict[str, Any] = {
                "op": "d",
                "ns": ns,
                "o": {"_id": victim["_id"]},
                "o2": {"_id": victim["_id"]},
            }
            if ui is not None:
                delete_entry["ui"] = bson.Binary(ui.bytes, subtype=4)
            delete_entries.append(delete_entry)
            delete_pre_images.append(bson.encode(victim) if preimages_on else None)
    return delete_entries, delete_pre_images
def find_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any] | None = None,
    *,
    skip: int = 0,
    limit: int = 0,
    sort: Mapping[str, Any] | None = None,
    projection: Mapping[str, Any] | None = None,
    hint: str | Mapping[str, Any] | None = None,
    let: dict[str, Any] | None = None,
    collation: Any = None,
    min_bound: Mapping[str, Any] | None = None,
    max_bound: Mapping[str, Any] | None = None,
) -> list[dict[str, Any]]:
    """Return the documents of ``db.coll`` that match ``filter``.

    Candidate documents come from a hinted index, an index lookup, an
    ordered index walk, or — as the fallback — a full collection scan.
    Every candidate is then re-checked with ``matches()``, so the
    index paths only narrow the set. The pipeline after that is:
    min/max bounds → sort (skipped when the candidates already arrived
    in the requested order) → skip → limit → projection.

    Args:
        db: Database name.
        coll: Collection name. The oplog, system-users and
            system-version namespaces are delegated to their dedicated
            finders before anything else happens.
        filter: Query document; ``None`` is treated as ``{}``.
        skip: Number of leading results to drop (``0`` drops none).
        limit: Maximum number of results; only applied when ``> 0``.
        sort: Sort specification, or ``None`` for no sort.
        projection: Projection applied to each returned document.
        hint: Index hint; when given, candidates come from the hinted
            index.
        let: Passed to ``matches()`` as ``vars``.
        collation: Raw collation spec; parsed once and threaded
            through index selection and matching.
        min_bound: Cursor ``min`` bound, applied via
            ``_apply_minmax_bounds``.
        max_bound: Cursor ``max`` bound, applied via
            ``_apply_minmax_bounds``.

    Returns:
        The matching documents as a list.
    """
    # Synthetic namespaces have their own finders. Note these
    # delegations do not forward ``hint`` or the min/max bounds.
    if self._is_oplog_rs(db, coll):
        return self._find_oplog_rs(
            filter,
            skip=skip,
            limit=limit,
            sort=sort,
            projection=projection,
            let=let,
            collation=collation,
        )
    if self._is_system_users(db, coll):
        return self._find_system_users(
            filter,
            skip=skip,
            limit=limit,
            sort=sort,
            projection=projection,
            let=let,
            collation=collation,
        )
    if self._is_system_version(db, coll):
        return self._find_system_version(
            filter,
            skip=skip,
            limit=limit,
            sort=sort,
            projection=projection,
            let=let,
            collation=collation,
        )
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    # Drop any sticky snapshot so this read sees the latest commits.
    self._refresh_read_snapshot()
    filter = filter or {}
    # True once the candidates are known to be in the requested sort
    # order already, which lets the Python-side sort be skipped.
    in_sort_order = False
    # Two-stage decode discipline: the lock is held only for the
    # WT cursor walk and any index routing; the COLLSCAN fallback
    # collects raw blobs while the lock is held and defers
    # ``bson.decode`` (and the ``matches()`` predicate, sorting,
    # projection) to *after* the lock releases. Concurrent readers
    # then decode in parallel even while a writer holds the lock
    # for inserts. Index-path candidates still come back already
    # decoded — that's a deeper refactor (Phase 2 territory).
    candidates: list[dict[str, Any]] | None = None
    raw_blobs: list[bytes] | None = None
    with self._lock:
        sort_field, sort_dir = self._single_sort_spec(sort)
        if hint is not None:
            resolved = self._resolve_hint(db, coll, hint)
            candidates, in_sort_order = self._candidates_from_hint(
                db, coll, resolved, sort_field, sort_dir
            )
        else:
            # Per-index collation: ``_try_index_lookup`` gates indexes
            # by exact match against ``collation_obj`` (None counts as
            # "no collation"), so the same code path covers both the
            # plain and the collation-bearing cases. Same applies to
            # the sort-acceleration pickers below — they all thread
            # ``collation_obj`` through so a sort on a collation-
            # indexed string field walks the index when the query's
            # collation matches and falls back to a Python sort
            # otherwise.
            candidates = self._try_index_lookup(db, coll, filter, collation=collation_obj)
            if candidates is not None and sort_field is not None:
                # Index hit on the very field being sorted on (a
                # single, non-operator filter key): treat the
                # candidates as already ordered by that field and only
                # flip them when the sort direction differs from the
                # index direction.
                if (
                    len(filter) == 1
                    and not next(iter(filter)).startswith("$")
                    and next(iter(filter)) == sort_field
                ):
                    in_sort_order = True
                    idx = self._find_leading_field_index(
                        db, coll, sort_field, filter, collation=collation_obj
                    )
                    # No index found: assume ascending.
                    idx_dir = idx[1] if idx else 1
                    if sort_dir != idx_dir:
                        candidates = list(reversed(candidates))
            elif candidates is None and not filter and sort_field is not None:
                # Unfiltered single-field sort: walk an index led by
                # the sort field instead of scanning and sorting.
                idx = self._find_leading_field_index(
                    db, coll, sort_field, filter, collation=collation_obj
                )
                if idx is not None:
                    idx_name, idx_dir, _is_compound = idx
                    # If the index direction matches the sort direction,
                    # walk forward; if it's opposite, walk backward.
                    reverse = sort_dir != idx_dir
                    candidates = self._walk_index_in_order(db, coll, idx_name, reverse=reverse)
                    in_sort_order = True
            # Multi-field sort acceleration: when sort has 2+ fields and
            # filter is empty, try to find a compound index whose key
            # spec exactly matches (or fully inverts) the sort. Walking
            # that index in the right direction yields the requested
            # order without a Python-side post-sort.
            if candidates is None and not filter and sort_field is None and sort:
                multi_spec = self._multi_sort_spec(sort)
                if multi_spec is not None and len(multi_spec) > 1:
                    match = self._compound_index_for_sort(
                        db, coll, multi_spec, collation=collation_obj
                    )
                    if match is not None:
                        idx_name, reverse = match
                        candidates = self._walk_index_in_order(
                            db, coll, idx_name, reverse=reverse
                        )
                        in_sort_order = True
        if candidates is None:
            # COLLSCAN fallback: grab raw blobs only; decode later.
            raw_blobs = [b for _, b in self._scan_docs(db, coll)]
    if candidates is None:
        assert raw_blobs is not None
        candidates = [bson.decode(b) for b in raw_blobs]
    # The full predicate is always applied, whatever produced the
    # candidates.
    out = [d for d in candidates if matches(d, filter, vars=let, collation=collation_obj)]
    if min_bound is not None or max_bound is not None:
        out = self._apply_minmax_bounds(
            db, coll, out, hint, min_bound, max_bound, collation_obj
        )
    if sort and not in_sort_order:
        out = sort_docs(out, sort)
    if skip:
        out = out[skip:]
    if limit > 0:
        out = out[:limit]
    if projection:
        out = [apply_projection(d, projection) for d in out]
    return out
def _apply_minmax_bounds(
    self,
    db: str,
    coll: str,
    docs: list[dict[str, Any]],
    hint: str | Mapping[str, Any] | None,
    min_bound: Mapping[str, Any] | None,
    max_bound: Mapping[str, Any] | None,
    collation: Any,
) -> list[dict[str, Any]]:
    """Filter ``docs`` by cursor ``min`` / ``max`` index bounds.

    ``min`` is an inclusive lower bound and ``max`` an exclusive upper
    bound, both evaluated on the hinted index's key (mongod semantics).
    Each bound document must name a leading prefix of the hinted index's
    key fields, in the same order — otherwise mongod raises 51174, which
    is mirrored here via ``MinMaxKeyError``. Bounds and documents are
    encoded with the same ``_index_key`` call using the index's per-field
    directions, and the encoded keys are compared directly.

    Args:
        db: Database name.
        coll: Collection name.
        docs: Already-decoded candidate documents to filter.
        hint: Index name or key pattern; required for ``min`` / ``max``.
        min_bound: Inclusive lower bound document, or ``None``.
        max_bound: Exclusive upper bound document, or ``None``.
        collation: Collation object forwarded to ``_index_key``.

    Returns:
        A new list holding the documents whose key lies within bounds.

    Raises:
        MinMaxKeyError: ``hint`` is ``None``, the hint resolves to
            something that is neither ``_id_`` nor a listed index, or a
            bound's fields are not a leading prefix of the key pattern.
        BadHint: propagated from ``_resolve_hint`` for an unknown hint.
    """
    if hint is None:
        raise MinMaxKeyError("min/max requires a hint")
    resolved = self._resolve_hint(db, coll, hint)
    key_spec: dict[str, Any] | None = None
    if resolved == _ID_INDEX_NAME:
        key_spec = {"_id": 1}
    else:
        for name, ks, _sparse, _unique in self._all_indexes(db, coll):
            if name == resolved:
                key_spec = dict(ks)
                break
    if key_spec is None:
        # Covers ``$natural`` and any name with no stored key spec.
        raise MinMaxKeyError("min/max hint does not correspond to an index")
    index_fields = list(key_spec)

    def _bound_spec(bound: Mapping[str, Any]) -> dict[str, Any]:
        """Return the index key spec truncated to ``bound``'s fields."""
        bound_fields = list(bound)
        if bound_fields != index_fields[: len(bound_fields)]:
            raise MinMaxKeyError(
                "The field order of the min/max query option does not "
                "match the order of the hinted index's key pattern"
            )
        return {f: key_spec[f] for f in bound_fields}

    # Validate each bound and build its truncated key spec exactly once.
    # The spec depends only on the bound, so rebuilding it per document
    # (as a per-doc closure call would) is wasted work.
    min_spec: dict[str, Any] | None = None
    min_key = None
    if min_bound is not None:
        min_doc = dict(min_bound)
        min_spec = _bound_spec(min_bound)
        min_key = _index_key(min_doc, min_spec, sparse=False, collation=collation)
    max_spec: dict[str, Any] | None = None
    max_key = None
    if max_bound is not None:
        max_doc = dict(max_bound)
        max_spec = _bound_spec(max_bound)
        max_key = _index_key(max_doc, max_spec, sparse=False, collation=collation)

    def _in_bounds(doc: dict[str, Any]) -> bool:
        """True when ``doc``'s encoded key satisfies both bounds."""
        if min_key is not None:
            dk = _index_key(doc, min_spec, sparse=False, collation=collation)
            if dk is None or dk < min_key:  # min is inclusive
                return False
        if max_key is not None:
            dk = _index_key(doc, max_spec, sparse=False, collation=collation)
            if dk is None or dk >= max_key:  # max is exclusive
                return False
        return True

    return [d for d in docs if _in_bounds(d)]
def _resolve_hint(self, db: str, coll: str, hint: str | Mapping[str, Any]) -> str:
"""Resolve ``hint`` to an index name (or ``$natural``).
``hint`` may be an index name string, a key-spec dict matching an
existing index, ``"$natural"``, or ``{"$natural": +/-1}``. Anything
else raises ``BadHint`` so the command layer can return a Mongo
``BadValue`` error.
"""
if isinstance(hint, str):
if hint == "$natural":
return "$natural"
if hint == _ID_INDEX_NAME:
return _ID_INDEX_NAME
for name, _key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name == hint:
return name
raise BadHint(f"hint {hint!r} does not correspond to an existing index")
if isinstance(hint, Mapping):
if list(hint) == ["$natural"]:
return "$natural"
if list(hint) == ["_id"] and int(hint["_id"]) == 1:
return _ID_INDEX_NAME
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if dict(key_spec) == dict(hint):
return name
raise BadHint(f"hint {dict(hint)!r} does not correspond to an existing index")
raise BadHint(f"invalid hint type: {type(hint).__name__}")
def _candidates_from_hint(
    self,
    db: str,
    coll: str,
    resolved: str,
    sort_field: str | None,
    sort_dir: int,
) -> tuple[list[dict[str, Any]], bool]:
    """Walk the index named by ``resolved`` (or full collection for $natural).

    Args:
        db: Database name.
        coll: Collection name.
        resolved: Index name as returned by ``_resolve_hint``.
        sort_field: Single sort field, or ``None`` when there is none.
        sort_dir: Requested sort direction (``1`` / ``-1``).

    Returns:
        ``(candidates, in_sort_order)``. ``in_sort_order`` is True when
        the hinted index's leading field is the sort field and that
        field has a numeric direction, in which case the candidates are
        already arranged for the requested direction and the caller can
        skip its post-sort step.
    """
    if resolved == "$natural":
        return [bson.decode(b) for _, b in self._scan_docs(db, coll)], False
    if resolved == _ID_INDEX_NAME:
        # The doc table is keyed by id_key; iterating it gives entries
        # sorted by encoded _id, which matches the _id_ index walk.
        docs = [bson.decode(b) for _, b in self._scan_docs(db, coll)]
        in_order = sort_field == "_id"
        if in_order and sort_dir == -1:
            docs = list(reversed(docs))
        return docs, in_order
    # Find the index's leading field and its direction.
    leading: str | None = None
    leading_dir: int | None = 1
    for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
        if name == resolved:
            leading = next(iter(key_spec))
            try:
                leading_dir = int(key_spec[leading])
            except (TypeError, ValueError):
                # Key specs may carry a type string such as "2dsphere"
                # instead of +/-1. Such an index has no sort direction,
                # so never report its walk as being in sort order.
                leading_dir = None
            break
    candidates = self._walk_index_in_order(db, coll, resolved, reverse=False)
    in_order = (
        sort_field is not None
        and sort_field == leading
        and leading_dir is not None
    )
    if in_order and sort_dir != leading_dir:
        candidates = list(reversed(candidates))
    return candidates, in_order
@staticmethod
def _single_sort_spec(sort: Mapping[str, Any] | None) -> tuple[str | None, int]:
"""Return ``(field, direction)`` if ``sort`` is single-field +/-1, else ``(None, 0)``."""
if not sort or len(sort) != 1:
return None, 0
f, d = next(iter(sort.items()))
if f.startswith("$"):
return None, 0
try:
di = int(d)
except (TypeError, ValueError):
return None, 0
if di not in (-1, 1):
return None, 0
return f, di
@staticmethod
def _multi_sort_spec(
sort: Mapping[str, Any] | None,
) -> list[tuple[str, int]] | None:
"""Return a list of ``(field, direction)`` pairs for a multi-field
sort spec, or ``None`` if any entry is operator-prefixed or has a
non-``±1`` direction.
Used for compound-index sort acceleration: an index whose key
spec exactly matches (or fully inverts) the returned list lets
``find_matching`` walk WT in the requested order and skip the
Python-side post-sort entirely.
"""
if not sort:
return None
out: list[tuple[str, int]] = []
for field, direction in sort.items():
if field.startswith("$"):
return None
try:
d = int(direction)
except (TypeError, ValueError):
return None
if d not in (-1, 1):
return None
out.append((field, d))
return out
def _compound_index_for_sort(
self,
db: str,
coll: str,
sort_fields: list[tuple[str, int]],
*,
collation: Any = None,
) -> tuple[str, bool] | None:
"""Find a compound index that satisfies ``sort_fields`` end-to-end.
Returns ``(index_name, reverse_walk)`` where ``reverse_walk`` is
True when the matching index is the *fully-inverted* permutation
of the sort (walking backward yields the requested order).
Multikey indexes are excluded — array values in the index could
produce row order that doesn't match the BSON cross-type sort
the user expects from a sort spec, so we'd fall back to Python
sort anyway.
Strict match only: the index key spec must have the same fields
in the same order with directions either matching the sort spec
or being the full inverse. Partial-prefix matches (sort uses 3
fields, index has 2) aren't accelerated; the savings on the
leading prefix are usually less than the cost of the trailing
Python sort over the materialised set.
``collation``: same exact-match gate as the filter pickers — an
index is only considered if its stored collation parses to the
same :class:`Collation` as the query's (or both None). A
no-collation sort against a collation-having index would walk
the index in collation order rather than codepoint order, which
is wrong for the user; the reverse is also wrong. So mismatched
indexes are skipped and the caller falls back to a Python sort.
"""
multikey = self._multikey_index_names(db, coll)
index_options = self._index_options_map(db, coll)
target = list(sort_fields)
inverted = [(f, -d) for f, d in target]
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name in multikey:
continue
try:
idx_pairs = [(f, int(d)) for f, d in key_spec.items()]
except (TypeError, ValueError):
continue
if any(d not in (-1, 1) for _, d in idx_pairs):
continue
idx_coll = _parse_index_collation(index_options.get(name, {}).get("collation"))
if idx_coll != collation:
continue
if idx_pairs == target:
return name, False
if idx_pairs == inverted:
return name, True
return None
def _single_field_index_for(self, db: str, coll: str, field: str) -> tuple[str, int] | None:
"""Return ``(index_name, direction)`` for a single-field index on
``field``, or ``None`` if no such index exists. Direction is the
index's stored sort direction (`+1` for ASC, `-1` for DESC)."""
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if list(key_spec.keys()) == [field]:
d = int(key_spec[field])
if d in (1, -1):
return name, d
return None
def _walk_index_in_order(
    self, db: str, coll: str, name: str, *, reverse: bool = False
) -> list[dict[str, Any]]:
    """Return the documents referenced by index ``name``, in entry order.

    Seeks the index-entries table to the first row for
    ``(db, coll, name)``, collects the row id from each entry until the
    key leaves that prefix, optionally reverses the collected ids, and
    resolves them through ``_docs_by_id_keys``.
    """
    cursor = self._cursor(_IDX_ENTRIES_TABLE)
    cursor.set_key(db, coll, name, b"")
    position = cursor.search_near()
    if position == wt.WT_NOTFOUND:
        return []
    # A negative result means the cursor sits before the seed key; step
    # forward once to reach the first candidate row.
    if position < 0 and cursor.next() != 0:
        return []
    owner = (db, coll, name)
    row_ids: list[bytes] = []
    has_row = True
    while has_row:
        key = cursor.get_key()
        if (key[0], key[1], key[2]) != owner:
            break
        _sortkey, row_id = _unpack_entry(bytes(key[3]))
        row_ids.append(row_id)
        has_row = cursor.next() == 0
    if reverse:
        row_ids = row_ids[::-1]
    return self._docs_by_id_keys(db, coll, row_ids)
[docs]
def explain_plan(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any] | None = None,
    *,
    sort: Mapping[str, Any] | None = None,
    hint: str | Mapping[str, Any] | None = None,
    collation: Any = None,
) -> dict[str, Any]:
    """Summarise the plan ``find_matching`` would choose for these args.

    Nothing is executed; only the routing decisions are mirrored. The
    result is either ``{"kind": "COLLSCAN"}`` or an ``IXSCAN`` dict with
    ``index_name``, ``key_pattern`` and ``direction``. ``direction`` is
    ``"forward"`` unless the sort spec runs against the chosen index.

    ``collation`` mirrors the runtime gate: it is parsed and forwarded
    to the index pickers, which only consider indexes whose stored
    collation matches. Mismatched indexes therefore yield COLLSCAN.

    Decision order: explicit ``hint``; then the filter-driven picker;
    then (empty filter only) a single-field sort index; then (empty
    filter only) a compound index matching a multi-field sort.
    """
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    filter = filter or {}
    with self._lock:
        sort_field, sort_dir = self._single_sort_spec(sort)

        if hint is not None:
            try:
                resolved = self._resolve_hint(db, coll, hint)
            except BadHint:
                return {"kind": "COLLSCAN"}
            if resolved == "$natural":
                return {"kind": "COLLSCAN"}
            if resolved == _ID_INDEX_NAME:
                walks_backward = sort_field == "_id" and sort_dir == -1
                return {
                    "kind": "IXSCAN",
                    "index_name": _ID_INDEX_NAME,
                    "key_pattern": {"_id": 1},
                    "direction": "backward" if walks_backward else "forward",
                }
            hinted_spec = self._key_spec_for(db, coll, resolved)
            if hinted_spec is None:
                return {"kind": "COLLSCAN"}
            return self._make_ixscan_plan(resolved, hinted_spec, sort_field, sort_dir)

        picked = self._pick_index_for_filter(db, coll, filter, collation=collation_obj)
        if picked is not None:
            name, key_spec = picked
            return self._make_ixscan_plan(name, key_spec, sort_field, sort_dir)

        # Every remaining route is sort-only acceleration, which applies
        # solely to an empty filter.
        if filter:
            return {"kind": "COLLSCAN"}

        if sort_field is not None:
            leading = self._find_leading_field_index(
                db, coll, sort_field, filter, collation=collation_obj
            )
            if leading is not None:
                name, _idx_dir, _is_compound = leading
                key_spec = self._key_spec_for(db, coll, name)
                if key_spec is not None:
                    return self._make_ixscan_plan(name, key_spec, sort_field, sort_dir)
        elif sort:
            # Multi-field sort: same rule as find_matching — the compound
            # key spec must exactly match or fully invert the sort.
            multi_spec = self._multi_sort_spec(sort)
            if multi_spec is not None and len(multi_spec) > 1:
                found = self._compound_index_for_sort(
                    db, coll, multi_spec, collation=collation_obj
                )
                if found is not None:
                    name, reverse = found
                    key_spec = self._key_spec_for(db, coll, name)
                    if key_spec is not None:
                        return {
                            "kind": "IXSCAN",
                            "index_name": name,
                            "key_pattern": key_spec,
                            "direction": "backward" if reverse else "forward",
                        }
        return {"kind": "COLLSCAN"}
def _key_spec_for(self, db: str, coll: str, name: str) -> dict[str, Any] | None:
for n, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if n == name:
return dict(key_spec)
return None
def _pick_geo_index_for_filter(
self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
"""Mirror :meth:`_try_geo_index_id_keys`'s index selection (no exec).
Returns ``(name, key_spec)`` if the filter has a geo operator on
a geo-indexed field; ``None`` otherwise. The picker is exact —
``_try_geo_index_id_keys`` may still bail (e.g. ``$near`` with no
max distance), but ``explain`` reports IXSCAN whenever an index
*could* serve the query, matching mongod's planner explain.
"""
for field, value in filter.items():
if not isinstance(value, dict):
continue
if not any(op in value for op in self._GEO_OPS):
continue
for name, key_spec, _opts in self._iter_indexes(db, coll):
geo = _geo_type_of(key_spec)
if geo is not None and geo[0] == field:
return name, dict(key_spec)
return None
def _pick_index_for_filter(
self,
db: str,
coll: str,
filter: dict[str, Any],
*,
collation: Any = None,
) -> tuple[str, dict[str, Any]] | None:
"""Mirror ``_try_index_lookup``'s index-selection (no execution).
``collation`` propagates from the query; when set, only
indexes with a matching stored ``collation`` option are
considered. Single-field, compound bare-eq, and compound
prefix + trailing-operator pickers all collation-match;
``numericOrdering`` queries fall through to COLLSCAN.
"""
if not filter:
return None
if any(f.startswith("$") for f in filter):
return None
# Mirror the _id point-lookup fast path: report it as an IXSCAN on
# the virtual _id_ index (key pattern {_id: 1}), matching mongod.
# Timeseries collections fall through to COLLSCAN (suffixed keys).
if (
len(filter) == 1
and "_id" in filter
and _id_point_lookup_keys(filter["_id"]) is not None
and not self._is_timeseries(db, coll)
):
return _ID_INDEX_NAME, {"_id": 1}
# Mirror `_try_index_id_keys`: geo dispatch first.
geo_pick = self._pick_geo_index_for_filter(db, coll, filter)
if geo_pick is not None:
return geo_pick
if all(not isinstance(v, dict) for v in filter.values()):
picked = self._pick_compound_eq_index(db, coll, filter, collation=collation)
if picked is not None:
return picked
if len(filter) >= 2:
picked = self._pick_compound_range_index(db, coll, filter, collation=collation)
if picked is not None:
return picked
if len(filter) == 1:
field, value = next(iter(filter.items()))
# Mirror the lookup: {field: {$exists: true}} → sparse index IXSCAN.
if isinstance(value, dict) and len(value) == 1 and value.get("$exists"):
name = self._sparse_index_for_exists(db, coll, field)
if name is None:
return None
key_spec = self._key_spec_for(db, coll, name)
return (name, key_spec) if key_spec is not None else None
idx_match = self._find_leading_field_index(db, coll, field, filter, collation=collation)
if idx_match is None:
return None
if isinstance(value, dict):
if not value or not all(k.startswith("$") for k in value):
return None
if not all(op in self._RANGE_OPS for op in value):
return None
name, _direction, _is_compound = idx_match
key_spec = self._key_spec_for(db, coll, name)
if key_spec is None:
return None
return name, key_spec
# Multi-field filter: mirror the lookup's single-field + partial-absorbed
# residual path so explain reports IXSCAN (with isPartial) where the
# query would actually use the index.
match = self._single_field_partial_residual_match(db, coll, filter, collation=collation)
if match is None:
return None
name = match[2][0]
key_spec = self._key_spec_for(db, coll, name)
if key_spec is None:
return None
return name, key_spec
@staticmethod
def _make_ixscan_plan(
name: str,
key_spec: Mapping[str, Any],
sort_field: str | None,
sort_dir: int,
) -> dict[str, Any]:
direction = "forward"
if sort_field is not None and sort_field in key_spec:
idx_dir = int(key_spec[sort_field])
if sort_dir != 0 and sort_dir != idx_dir:
direction = "backward"
return {
"kind": "IXSCAN",
"index_name": name,
"key_pattern": dict(key_spec),
"direction": direction,
}
def count_matching(
self,
db: str,
coll: str,
filter: dict[str, Any] | None = None,
*,
let: dict[str, Any] | None = None,
collation: Any = None,
) -> int:
if self._is_oplog_rs(db, coll):
return self._count_oplog_rs(filter, let=let, collation=collation)
if self._is_system_users(db, coll):
return self._count_system_users(filter, let=let, collation=collation)
if self._is_system_version(db, coll):
return self._count_system_version(filter, let=let, collation=collation)
from secantus.collation import parse as _parse_collation
collation_obj = _parse_collation(collation)
self._refresh_read_snapshot()
if not filter:
with self._lock:
return sum(1 for _ in self._scan_docs(db, coll))
return sum(
1
for doc in self._all_docs(db, coll)
if matches(doc, filter, vars=let, collation=collation_obj)
)
[docs]
def collection_data_size(self, db: str, coll: str) -> int:
    """Total BSON-encoded bytes of the documents stored in ``coll``.

    Used by ``collStats`` / ``dbStats`` for ``size`` / ``dataSize``.
    Best-effort estimate — doesn't include WT block overhead.
    """
    with self._lock:
        total = 0
        for _row_key, blob in self._scan_docs(db, coll):
            total += len(blob)
        return total
[docs]
def index_sizes(self, db: str, coll: str) -> dict[str, int]:
    """Map each index name to the summed size of its packed entry keys.

    ``_id_`` is reported separately as ``len(id_key)`` summed across the
    doc table (and omitted when that sum is zero), so callers can
    include it alongside secondary indexes for a ``totalIndexSize``.
    """
    with self._lock:
        # Refresh the read snapshot before scanning.
        # NOTE(review): presumably so rows committed by other threads
        # are visible to the scans below — confirm against
        # ``_refresh_read_snapshot``.
        self._refresh_read_snapshot()
        totals: dict[str, int] = {}
        id_bytes = 0
        for id_key_bytes, _blob in self._scan_docs(db, coll):
            id_bytes += len(id_key_bytes)
        if id_bytes:
            totals[_ID_INDEX_NAME] = id_bytes
        for key, _value in self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll)):
            index_name = key[2]
            totals[index_name] = totals.get(index_name, 0) + len(bytes(key[3]))
        return totals
@_retry_write_conflicts
def update_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any],
    update: dict[str, Any],
    *,
    multi: bool = False,
    upsert: bool = False,
    array_filters: list[dict[str, Any]] | None = None,
    let: dict[str, Any] | None = None,
    collation: Any = None,
    validator: dict[str, Any] | None = None,
    journal: bool = False,
) -> dict[str, Any]:
    """Apply ``update`` to the documents of ``db.coll`` matching ``filter``.

    Args:
        db: Database name.
        coll: Collection name.
        filter: Query selecting the documents to update.
        update: Operator document, replacement document, or pipeline
            (list of stages).
        multi: Update every match instead of stopping after the first.
        upsert: Insert a document built from ``filter`` + ``update``
            when nothing matched.
        array_filters: Forwarded to ``apply_update``.
        let: Variables forwarded to ``matches`` and ``apply_update`` on
            both the matched-document and the upsert path.
        collation: Raw collation spec; when it parses to a collation the
            candidates come from a full doc-table scan.
        validator: Collection validator; ``None`` skips validation.
        journal: Forwarded as ``sync`` to the batch transaction.

    Returns:
        ``{"matched", "modified", "upserted_id", "did_upsert"}``.

    Raises:
        DocumentValidationError: the resulting document fails
            ``validator``.
        IndexConflict: the resulting document collides on a unique
            index.
        DocumentTooLargeError: the resulting document exceeds
            ``MAX_BSON_OBJECT_SIZE``.
    """
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    # Release any sticky session snapshot before the write's
    # ``begin_transaction`` acquires a new one. Otherwise the
    # transaction inherits a stale view and the candidate scan
    # misses rows committed by other connections (the cross-
    # connection visibility fix applied to reads — see
    # ``_refresh_read_snapshot``).
    self._refresh_read_snapshot()
    matched = 0
    modified = 0
    upserted_id: Any = None
    did_upsert = False
    oplog_entries: list[dict[str, Any]] = []
    pre_images: list[bytes | None] = []
    oplog_on = self.enable_oplog
    with self._coll_lock(db, coll), self._batch_transaction(sync=journal):
        # Per-collection lock + one WT transaction per call. Every
        # doc-table write + index-entry delete/insert + oplog write
        # that lands in this method shares a single commit. Phase
        # 2.4: was self._lock; now per-coll so different
        # collections update in parallel.
        self._ensure_collection(db, coll)
        ns = self._ns(db, coll)
        ui = self._collection_uuid(db, coll) if oplog_on else None
        preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        multikey_names = self._multikey_index_names(db, coll)
        # Index-routed when the filter is covered (only matching id_keys
        # come back from the index walk); full scan otherwise. Either
        # way the doc cursor isn't held across writes — bytes are
        # eagerly buffered. Only matching docs pay ``bson.decode``.
        # With a collation in effect, fall back to a doc-table scan:
        # the index entries don't carry the collation's folding, so
        # an indexed equality probe would miss case-insensitive
        # matches. Always materialise the list — the update loop
        # rewrites the doc table via the cached cursor, which
        # invalidates a still-walking ``_scan_docs`` cursor on the
        # same session.
        if collation_obj is not None:
            candidates = list(self._scan_docs(db, coll))
        else:
            candidates = self._candidates_iter(db, coll, filter)
        for id_k, blob in candidates:
            doc = bson.decode(blob)
            if not matches(doc, filter, vars=let, collation=collation_obj):
                continue
            matched += 1
            pos = find_positional_matches(doc, filter)
            new = apply_update(
                doc,
                update,
                array_filters=array_filters,
                positional_matches=pos,
                let=let,
            )
            if new != doc:
                # Document-validator check: collection-level
                # ``validator`` (set via ``create`` / ``collMod``)
                # rejects updates whose result fails the predicate.
                # Caller passes ``None`` to skip
                # (``bypassDocumentValidation: true``).
                if validator is not None and not matches(new, validator):
                    raise DocumentValidationError(new.get("_id"))
                # _id is immutable, so the row's actual key is the right
                # write target. For ordinary collections that equals
                # _id_key(new["_id"]); for timeseries the row key carries
                # a uniqueness suffix that a recompute would drop —
                # writing at the recomputed key would strand the old row.
                new_id_key = id_k
                conflict = self._unique_conflict(
                    db, coll, new, indexes, exclude_id_key=id_k, partials=partials
                )
                if conflict is not None:
                    cname, kpat, kval = conflict
                    raise IndexConflict(cname, new["_id"], key_pattern=kpat, key_value=kval)
                # Geo validation must reject the update before any
                # write happens, otherwise we'd be left with a
                # half-deleted set of index entries.
                self._validate_geo_indexes(db, coll, new, indexes, partials)
                new_blob = bson.encode(new)
                if len(new_blob) > MAX_BSON_OBJECT_SIZE:
                    raise DocumentTooLargeError(
                        10334,
                        "Plan executor error during update :: caused by :: "
                        f"Resulting document after update is larger than "
                        f"{MAX_BSON_OBJECT_SIZE}",
                    )
                modified += 1
                self._delete_index_entries(
                    db, coll, doc, indexes, partials, id_key_override=id_k
                )
                doc_cur = self._cursor(_DOC_TABLE)
                doc_cur[db, coll, new_id_key] = new_blob
                self._write_index_entries(
                    db, coll, new, indexes, partials, id_key_override=id_k
                )
                multikey_names = self._maybe_mark_multikey(
                    db, coll, new, indexes, multikey_names
                )
                # Pipeline-form updates (a list of stages) are
                # diff-style in the oplog — mongod emits op "u" with
                # an update description (the unified "array
                # truncation" spec asserts operationType "update",
                # not "replace").
                is_replacement = not isinstance(update, list) and not any(
                    isinstance(k, str) and k.startswith("$") for k in update
                )
                if oplog_on:
                    if is_replacement:
                        o_field: dict[str, Any] = dict(new)
                    else:
                        o_field = {"$v": 2, "diff": compute_update_description(doc, new)}
                    oplog_entries.append(
                        {
                            "op": "u",
                            "ns": ns,
                            "ui": bson.Binary(ui.bytes, subtype=4),
                            "o": o_field,
                            "o2": {"_id": doc["_id"]},
                        }
                    )
                    pre_images.append(bson.encode(doc) if preimages_on else None)
            if not multi:
                break
        if matched == 0 and upsert:
            seed: dict[str, Any] = {}
            for k, v in filter.items():
                # Seed bare-equality predicates into the upserted doc.
                # A dict value is only skipped when it's an OPERATOR
                # expression ({$gt: 5}); a literal subdocument value
                # ({f: ..., f2: ...}, e.g. a compound ``_id``) is a
                # real equality and must be seeded — Python's
                # ``isinstance(v, dict)`` alone wrongly drops it,
                # generating a fresh ObjectId instead.
                if k.startswith("$") or _is_operator_expr(v):
                    continue
                seed[k] = v
            # Forward ``let`` exactly as the matched-document path does,
            # so updates that reference ``let`` variables behave the
            # same whether they modify an existing doc or upsert one.
            new = apply_update(
                seed,
                update,
                is_upsert=True,
                array_filters=array_filters,
                let=let,
            )
            if "_id" not in new:
                new["_id"] = bson.ObjectId()
            if validator is not None and not matches(new, validator):
                raise DocumentValidationError(new.get("_id"))
            upserted_id = new["_id"]
            did_upsert = True
            conflict = self._unique_conflict(
                db, coll, new, indexes, exclude_id_key=None, partials=partials
            )
            if conflict is not None:
                cname, kpat, kval = conflict
                raise IndexConflict(cname, new["_id"], key_pattern=kpat, key_value=kval)
            self._validate_geo_indexes(db, coll, new, indexes, partials)
            upsert_blob = bson.encode(new)
            if len(upsert_blob) > MAX_BSON_OBJECT_SIZE:
                raise DocumentTooLargeError(
                    17420,
                    "Plan executor error during update :: caused by :: "
                    f"Document to upsert is larger than {MAX_BSON_OBJECT_SIZE}",
                )
            doc_cur = self._cursor(_DOC_TABLE)
            doc_cur[db, coll, _id_key(upserted_id)] = upsert_blob
            self._write_index_entries(db, coll, new, indexes, partials)
            self._maybe_mark_multikey(db, coll, new, indexes, multikey_names)
            if oplog_on:
                oplog_entries.append(
                    {
                        "op": "i",
                        "ns": ns,
                        "ui": bson.Binary(ui.bytes, subtype=4),
                        "o": dict(new),
                        "o2": {"_id": upserted_id},
                    }
                )
                pre_images.append(None)
        cap_ns = ns if oplog_on else ""
        cap_entries, cap_pre = self._enforce_capped_bounds_locked(
            db, coll, set(), indexes, partials, oplog_on, cap_ns, ui
        )
        if cap_entries:
            oplog_entries.extend(cap_entries)
            pre_images.extend(cap_pre)
        if oplog_entries:
            self._emit_oplog(oplog_entries, pre_images)
    return {
        "matched": matched,
        "modified": modified,
        "upserted_id": upserted_id,
        "did_upsert": did_upsert,
    }
@_retry_write_conflicts
def delete_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any],
    *,
    limit: int = 0,
    let: dict[str, Any] | None = None,
    collation: Any = None,
    journal: bool = False,
) -> int:
    """Delete the documents of ``db.coll`` matching ``filter``.

    Args:
        db: Database name.
        coll: Collection name.
        filter: Query selecting the documents to delete.
        limit: Stop after this many deletions; ``0`` means no limit.
        let: Variables forwarded to ``matches``.
        collation: Raw collation spec; when it parses to a collation the
            candidates come from a full doc-table scan.
        journal: Forwarded as ``sync`` to the batch transaction.

    Returns:
        The number of documents removed.
    """
    from secantus.collation import parse as _parse_collation

    collation_obj = _parse_collation(collation)
    # Same reasoning as ``update_matching``: drop the sticky snapshot so
    # the candidate scan sees writes committed by other connections.
    self._refresh_read_snapshot()
    removed = 0
    oplog_entries: list[dict[str, Any]] = []
    pre_images: list[bytes | None] = []
    oplog_on = self.enable_oplog
    with self._coll_lock(db, coll), self._batch_transaction(sync=journal):
        # Per-collection lock (Phase 2.4) + one WT transaction: the
        # per-doc removes, index-entry deletes and oplog writes share a
        # single commit while other collections delete in parallel.
        ns = self._ns(db, coll) if oplog_on else ""
        preimages_on = oplog_on and self._pre_post_images_enabled(db, coll)
        if oplog_on and self._coll_options(db, coll) is not None:
            ui = self._collection_uuid(db, coll)
        else:
            ui = None
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        # Index-routed candidates when the filter is covered; full scan
        # otherwise. Collation forces a full scan — index entries don't
        # carry the collation's folding. The scan is materialised into a
        # list so the writes below don't invalidate the iteration cursor
        # (deletes via ``_cursor(_DOC_TABLE)`` share the cached cursor
        # with ``_scan_docs``).
        if collation_obj is None:
            rows = self._candidates_iter(db, coll, filter)
        else:
            rows = list(self._scan_docs(db, coll))
        for row_key, blob in rows:
            doc = bson.decode(blob)
            if not matches(doc, filter, vars=let, collation=collation_obj):
                continue
            self._delete_index_entries(
                db, coll, doc, indexes, partials, id_key_override=row_key
            )
            doc_cursor = self._cursor(_DOC_TABLE)
            doc_cursor.set_key(db, coll, row_key)
            doc_cursor.remove()
            removed += 1
            if oplog_on:
                oplog_entry: dict[str, Any] = {
                    "op": "d",
                    "ns": ns,
                    "o": {"_id": doc["_id"]},
                    "o2": {"_id": doc["_id"]},
                }
                if ui is not None:
                    oplog_entry["ui"] = bson.Binary(ui.bytes, subtype=4)
                oplog_entries.append(oplog_entry)
                pre_images.append(bson.encode(doc) if preimages_on else None)
            if 0 < limit <= removed:
                break
        if oplog_entries:
            self._emit_oplog(oplog_entries, pre_images)
    return removed
[docs]
def prune_ttl(
    self,
    db: str,
    coll: str,
    *,
    now: _dt.datetime | None = None,
) -> int:
    """Delete docs whose indexed Date field is older than ``now`` - TTL.

    Every index on ``coll`` with a non-negative numeric
    ``expireAfterSeconds`` option contributes its first key field. A
    document is removed when, for any such field, the value resolved by
    ``get_path`` is a ``datetime`` more than that many seconds before
    ``now``. Docs without the field, with non-date values, or with
    values inside the TTL window are left in place. Naive datetimes
    (``now`` and stored values alike) are treated as UTC.

    Real MongoDB runs this on a 60s background sweeper; SecantusDB
    invokes it explicitly so tests can drive expiry with an injected
    ``now``.

    Returns:
        The number of documents pruned.
    """
    ttl_fields: list[tuple[str, str, float]] = []
    for index_name, key_spec, options in self._iter_indexes(db, coll):
        seconds = options.get("expireAfterSeconds")
        if not isinstance(seconds, (int, float)) or seconds < 0:
            continue
        leading_field = next(iter(key_spec), None)
        if isinstance(leading_field, str):
            ttl_fields.append((index_name, leading_field, float(seconds)))
    if not ttl_fields:
        return 0

    utc = _dt.timezone.utc
    when = _dt.datetime.now(utc) if now is None else now
    if when.tzinfo is None:
        when = when.replace(tzinfo=utc)

    def _has_expired(doc: dict[str, Any]) -> bool:
        """True when any TTL field of ``doc`` lies outside its window."""
        for _index_name, field, ttl_seconds in ttl_fields:
            stamp = get_path(doc, field)
            if not isinstance(stamp, _dt.datetime):
                continue
            if not stamp.tzinfo:
                stamp = stamp.replace(tzinfo=utc)
            if (when - stamp).total_seconds() > ttl_seconds:
                return True
        return False

    removed = 0
    oplog_entries: list[dict[str, Any]] = []
    pre_images: list[bytes | None] = []
    with self._lock:
        ns = self._ns(db, coll)
        preimages_on = self._pre_post_images_enabled(db, coll)
        if self._coll_options(db, coll) is not None:
            ui = self._collection_uuid(db, coll)
        else:
            ui = None
        indexes = self._all_indexes(db, coll)
        partials = self._partial_filters(db, coll)
        # Buffer the whole scan first; rows are removed while looping.
        rows = list(self._scan_docs(db, coll))
        for row_key, blob in rows:
            doc = bson.decode(blob)
            if not _has_expired(doc):
                continue
            self._delete_index_entries(
                db, coll, doc, indexes, partials, id_key_override=row_key
            )
            doc_cursor = self._cursor(_DOC_TABLE)
            doc_cursor.set_key(db, coll, row_key)
            doc_cursor.remove()
            removed += 1
            oplog_entry: dict[str, Any] = {
                "op": "d",
                "ns": ns,
                "o": {"_id": doc["_id"]},
                "o2": {"_id": doc["_id"]},
            }
            if ui is not None:
                oplog_entry["ui"] = bson.Binary(ui.bytes, subtype=4)
            oplog_entries.append(oplog_entry)
            pre_images.append(bson.encode(doc) if preimages_on else None)
        if oplog_entries:
            self._emit_oplog(oplog_entries, pre_images)
    return removed
@staticmethod
def _table_kf(table: str) -> str:
    """Return the key-format string registered for ``table``.

    Raises ``KeyError`` for a table that is not one of the four known
    tables.
    """
    key_formats = {
        _COLL_TABLE: "SS",
        _DOC_TABLE: "SSu",
        _IDX_TABLE: "SSS",
        _IDX_ENTRIES_TABLE: "SSSu",
    }
    return key_formats[table]
@staticmethod
def _smallest_for_kf(kf: str) -> tuple[Any, ...]:
return tuple(b"" if c == "u" else "" for c in kf)
def _collect_prefix(
    self, table: str, prefix: tuple[Any, ...]
) -> list[tuple[tuple[Any, ...], Any]]:
    """Collect every ``(key, value)`` row of ``table`` whose key starts with ``prefix``.

    The seed key is ``prefix`` padded with the smallest value for each
    remaining key column. Rows are gathered until the key no longer
    begins with ``prefix``. ``bytes`` / ``bytearray`` values are copied
    into ``bytes``; other values are returned unchanged.
    """
    cursor = self._cursor(table)
    key_format = self._table_kf(table)
    padding = self._smallest_for_kf(key_format)[len(prefix) :]
    cursor.set_key(*(prefix + padding))
    position = cursor.search_near()
    if position == wt.WT_NOTFOUND:
        return []
    # A negative result means the cursor sits before the seed key; step
    # forward once to reach the first candidate row.
    if position < 0 and cursor.next() != 0:
        return []
    width = len(prefix)
    rows: list[tuple[tuple[Any, ...], Any]] = []
    while True:
        key = tuple(cursor.get_key())
        if key[:width] != prefix:
            break
        value = cursor.get_value()
        if isinstance(value, (bytes, bytearray)):
            value = bytes(value)
        rows.append((key, value))
        if cursor.next() != 0:
            break
    return rows
def _delete_keys(self, table: str, keys: list[tuple[Any, ...]]) -> None:
if not keys:
return
c = self._cursor(table)
for k in keys:
c.set_key(*k)
c.remove()
c.reset()
def drop_collection(self, db: str, coll: str) -> bool:
    """Remove ``db.coll`` with all of its documents, indexes and index entries.

    Emits a ``drop`` command oplog entry when the collection existed and
    has a UUID.

    Returns:
        True when the collection had stored options (i.e. existed)
        before the call.
    """
    with self._lock:
        # Mutating scanners read the current rows before deleting them;
        # a snapshot pinned by an earlier positioned cursor on this
        # connection thread would hide rows committed by other threads
        # and turn the scan into a silent partial no-op (the gauge's
        # drop-then-reinsert E11000 cluster).
        self._refresh_read_snapshot()
        existed = self._coll_options(db, coll) is not None
        ui = self._collection_uuid(db, coll) if existed else None
        namespace = (db, coll)
        for table in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE):
            doomed = [key for key, _value in self._collect_prefix(table, namespace)]
            self._delete_keys(table, doomed)
        coll_cursor = self._cursor(_COLL_TABLE)
        coll_cursor.set_key(db, coll)
        if coll_cursor.search() == 0:
            coll_cursor.remove()
        if existed and ui is not None:
            drop_entry = {
                "op": "c",
                "ns": f"{db}.$cmd",
                "ui": bson.Binary(ui.bytes, subtype=4),
                "o": {"drop": coll},
            }
            self._emit_oplog([drop_entry])
        return existed
def drop_database(self, db: str) -> None:
    """Delete every row belonging to ``db`` and log the drop.

    Collection UUIDs are captured first, then all four tables are purged
    of keys prefixed by ``db``. The oplog receives one ``drop`` entry per
    collection reported by ``list_collections`` followed by a single
    ``dropDatabase`` entry (emitted even when the database was empty).
    """
    with self._lock:
        # Mutating scanners read the current rows before deleting/rewriting
        # them; a snapshot pinned by an earlier positioned cursor on
        # this connection thread would hide rows committed by other
        # threads and turn the scan into a silent partial no-op
        # (the gauge's drop-then-reinsert E11000 cluster).
        self._refresh_read_snapshot()
        # UUIDs must be read before the catalog rows disappear.
        colls_with_ui: list[tuple[str, _uuid.UUID]] = [
            (coll_name, self._collection_uuid(db, coll_name))
            for coll_name in self.list_collections(db)
        ]
        for table in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE, _COLL_TABLE):
            doomed = [key for key, _value in self._collect_prefix(table, (db,))]
            self._delete_keys(table, doomed)
        cmd_ns = f"{db}.$cmd"
        entries: list[dict[str, Any]] = [
            {
                "op": "c",
                "ns": cmd_ns,
                "ui": bson.Binary(coll_ui.bytes, subtype=4),
                "o": {"drop": coll_name},
            }
            for coll_name, coll_ui in colls_with_ui
        ]
        entries.append({"op": "c", "ns": cmd_ns, "o": {"dropDatabase": 1}})
        self._emit_oplog(entries)
def rename_collection(
    self,
    src_db: str,
    src_coll: str,
    dst_db: str,
    dst_coll: str,
    *,
    drop_target: bool = False,
) -> tuple[bool, str | None]:
    """Move ``src_db.src_coll`` to ``dst_db.dst_coll``.

    Returns ``(ok, error_message)``. Fails when the source does not
    exist, or when the target exists and ``drop_target`` is false.
    Renaming a namespace onto itself succeeds without doing anything.
    With ``drop_target`` an existing target is purged first. Document,
    index-definition and index-entry rows are then re-keyed under the new
    namespace, the catalog rows are swapped, and the oplog receives an
    optional ``drop`` entry for the target followed by the
    ``renameCollection`` entry.
    """
    source_ns = (src_db, src_coll)
    target_ns = (dst_db, dst_coll)
    data_tables = (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE)
    with self._lock:
        # Mutating scanners read the current rows before deleting/rewriting
        # them; a snapshot pinned by an earlier positioned cursor on
        # this connection thread would hide rows committed by other
        # threads and turn the scan into a silent partial no-op
        # (the gauge's drop-then-reinsert E11000 cluster).
        self._refresh_read_snapshot()
        if self._coll_options(src_db, src_coll) is None:
            return False, f"source namespace does not exist: {src_db}.{src_coll}"
        if source_ns == target_ns:
            return True, None
        ui = self._collection_uuid(src_db, src_coll)
        dst_existed = self._coll_options(dst_db, dst_coll) is not None
        dst_ui = None
        if dst_existed:
            dst_ui = self._collection_uuid(dst_db, dst_coll)
            if not drop_target:
                return False, f"target namespace exists: {dst_db}.{dst_coll}"
            # Purge the target's data and catalog row before moving rows in.
            for table in data_tables:
                stale = [key for key, _value in self._collect_prefix(table, target_ns)]
                self._delete_keys(table, stale)
            catalog = self._cursor(_COLL_TABLE)
            catalog.set_key(dst_db, dst_coll)
            if catalog.search() == 0:
                catalog.remove()
        # Re-key each table: delete the source rows, then re-insert the
        # same values with the first two key columns swapped.
        for table in data_tables:
            moved = self._collect_prefix(table, source_ns)
            self._delete_keys(table, [key for key, _value in moved])
            writer = self._cursor(table)
            for old_key, value in moved:
                writer.set_key(*(target_ns + old_key[2:]))
                writer.set_value(value)
                writer.insert()
            writer.reset()
        ensure = self._cursor(_COLL_TABLE)
        ensure.set_key(dst_db, dst_coll)
        if ensure.search() != 0:
            ensure.reset()
            # NOTE(review): the target catalog row is written with an empty
            # value rather than the source row's value — confirm that
            # collection options do not need to carry over here.
            ensure[dst_db, dst_coll] = b""
        ensure.reset()
        ensure.set_key(src_db, src_coll)
        if ensure.search() == 0:
            ensure.remove()
        entries: list[dict[str, Any]] = []
        rename_o: dict[str, Any] = {
            "renameCollection": f"{src_db}.{src_coll}",
            "to": f"{dst_db}.{dst_coll}",
        }
        if dst_existed and dst_ui is not None:
            entries.append(
                {
                    "op": "c",
                    "ns": f"{dst_db}.$cmd",
                    "ui": bson.Binary(dst_ui.bytes, subtype=4),
                    "o": {"drop": dst_coll},
                }
            )
            # mongod records the dropped target's UUID under ``dropTarget``
            # in the rename oplog entry; the change-stream ``rename`` event
            # surfaces it under ``operationDescription.dropTarget`` when
            # ``showExpandedEvents`` is on.
            rename_o["dropTarget"] = bson.Binary(dst_ui.bytes, subtype=4)
        entries.append(
            {
                "op": "c",
                "ns": f"{src_db}.$cmd",
                "ui": bson.Binary(ui.bytes, subtype=4),
                "o": rename_o,
            }
        )
        self._emit_oplog(entries)
        return True, None
def record_collmod(self, db: str, coll: str, description: dict[str, Any]) -> None:
    """Emit a ``collMod`` command oplog entry for ``db.coll``.

    The entry lets change streams watching ``db`` / ``db.coll`` (with
    ``showExpandedEvents``) surface a ``modify`` event. ``description``
    carries the changed options (empty for a no-op ``collMod``) and is
    merged into the entry's ``o`` document after the ``collMod`` key. The
    option mutation itself is expected to have been applied already by
    the caller via :meth:`set_collection_options`.

    Nothing is emitted when the collection does not exist.
    """
    with self._lock:
        if self._coll_options(db, coll) is not None:
            ui = self._collection_uuid(db, coll)
            command: dict[str, Any] = {"collMod": coll}
            command.update(description)
            self._emit_oplog(
                [
                    {
                        "op": "c",
                        "ns": f"{db}.$cmd",
                        "ui": bson.Binary(ui.bytes, subtype=4),
                        "o": command,
                    }
                ]
            )
def list_collections(self, db: str) -> list[str]:
    """Return the sorted collection names recorded for ``db``.

    Walks the catalog table from ``(db, "")`` until the database column
    changes. For the ``local`` database with the oplog enabled,
    ``oplog.rs`` is added even though it has no catalog row.
    """
    self._refresh_read_snapshot()
    with self._lock:
        cursor = self._cursor(_COLL_TABLE)
        cursor.set_key(db, "")
        status = cursor.search_near()
        names: list[str] = []
        # Landing on a smaller key (negative status) needs one forward
        # step before the walk; running off the table means no rows.
        positioned = status != wt.WT_NOTFOUND and (status >= 0 or cursor.next() == 0)
        while positioned:
            key = cursor.get_key()
            if key[0] != db:
                break
            names.append(key[1])
            positioned = cursor.next() == 0
        # Synthesise ``local.oplog.rs`` for the ``local`` db whenever the
        # oplog is enabled. The collection isn't materialised in
        # ``_COLL_TABLE`` — it's a view over the oplog WT table — but
        # ``listCollections`` needs to surface it so pymongo clients can
        # discover it before querying.
        if self.enable_oplog and db == "local" and "oplog.rs" not in names:
            names.append("oplog.rs")
        return sorted(names)
def list_databases(self) -> list[str]:
    """Return the sorted names of all databases that own a catalog row.

    ``local`` is always included when the oplog is enabled.
    """
    self._refresh_read_snapshot()
    with self._lock:
        cursor = self._cursor(_COLL_TABLE)
        names: set[str] = set()
        # NOTE(review): the walk relies on the cursor being unpositioned
        # here so that ``next()`` starts at the first row — confirm that
        # ``_refresh_read_snapshot`` / ``_cursor`` guarantee that.
        while cursor.next() == 0:
            names.add(cursor.get_key()[0])
        # mongod always exposes the ``local`` database; mirror that
        # when the oplog is enabled so listDatabases includes it even
        # before any user-created collection lands in ``local``.
        if self.enable_oplog:
            names.add("local")
        return sorted(names)
def create_index(
    self,
    db: str,
    coll: str,
    name: str,
    key_spec: Mapping[str, Any],
    options: Mapping[str, Any] | None = None,
) -> bool:
    """Create index ``name`` on ``db.coll`` and backfill its entries.

    Returns ``True`` when a new index was written, ``False`` when the
    name is the ``_id`` index or an index of that name already exists
    with compatible options.

    Raises:
        CreateIndexUnsupported: a key field asks for ``text`` / ``hashed``.
        IndexOptionsConflict: the name exists with different
            ``unique`` / ``sparse`` / ``hidden`` / ``expireAfterSeconds`` /
            ``partialFilterExpression`` values.
        IndexConflict: ``unique`` was requested for a geo index, or
            existing documents share a key under a unique index.

    All documents are scanned and validated before anything is written,
    so a raised conflict leaves the index tables untouched.
    """
    if name == _ID_INDEX_NAME:
        return False
    # Text / hashed indexes are documented out-of-scope (CLAUDE.md
    # "Out of scope regardless: text / hashed / wildcard indexes").
    # Surface the rejection as a typed exception (caught in
    # ``commands._create_indexes``) instead of letting the geo
    # picker / encoder later fall over with an opaque internal
    # error. Mongo-node-driver's ``Find should correctly sort using
    # text search`` test expects a clean error here.
    for field_kind in key_spec.values():
        if field_kind in ("text", "hashed"):
            raise CreateIndexUnsupported(f"{field_kind} indexes are not supported by SecantusDB")
    options = dict(options or {})
    with self._lock:
        self._ensure_collection(db, coll)
        idx_cur = self._cursor(_IDX_TABLE)
        idx_cur.set_key(db, coll, name)
        if idx_cur.search() == 0:
            # Index exists. Mongo rejects re-creation with conflicting
            # options (different ``unique`` / ``sparse`` / ``hidden``
            # / ``expireAfterSeconds``). Silently succeeding hides
            # a bug surface that mongo-ruby-driver's ``Collection#
            # create_indexes when index creation fails`` test pins.
            stored_raw = bytes(idx_cur.get_value())
            stored = bson.decode(stored_raw) if stored_raw else {}
            stored_opts = dict(stored.get("options") or {})
            guarded_opts = (
                "unique",
                "sparse",
                "hidden",
                "expireAfterSeconds",
                "partialFilterExpression",
            )
            for opt in guarded_opts:
                mentioned = opt in options or opt in stored_opts
                if mentioned and options.get(opt) != stored_opts.get(opt):
                    raise IndexOptionsConflict(
                        f"Index with name '{name}' already exists with different options"
                    )
            return False

        sparse = bool(options.get("sparse"))
        unique = bool(options.get("unique"))
        partial_filter = options.get("partialFilterExpression")
        # Anything other than a non-empty mapping means "no partial filter".
        if not (isinstance(partial_filter, Mapping) and partial_filter):
            partial_filter = None
        key_spec_dict = dict(key_spec)
        geo = _geo_type_of(key_spec_dict)
        entries: list[tuple[bytes, bytes]] = []

        if geo is not None:
            # Geo indexes use the same entries table but write **multiple**
            # entries per doc (one per S2 cell or 2d bucket). They're inherently
            # multikey-style; uniqueness is meaningless for geo and is rejected
            # by mongod, so we mirror.
            if unique:
                raise IndexConflict(name, None)
            geo_field, geo_type = geo
            # Geo indexes are always multikey from the picker's perspective
            # — each doc may produce many cell entries. Mark it so the
            # regular pickers skip the index for non-geo queries.
            options["multikey"] = True
            for id_k, blob in self._scan_docs(db, coll):
                doc = bson.decode(blob)
                if partial_filter is not None and not matches(doc, partial_filter):
                    continue
                for cell_bytes in _doc_geo_cells(
                    doc, geo_field, geo_type, options, index_name=name
                ):
                    entries.append((cell_bytes, id_k))
        else:
            # Single doc-table walk: decode each blob once and fold all
            # three checks (uniqueness, multikey detection, entry build)
            # into one pass. Uniqueness is probed against the canonical
            # whole-doc key (``_index_key``); index entries are written
            # for every key variant (``_index_key_variants``) so per-
            # element multikey lookups land at IXSCAN.
            claimed: set[bytes] | None = set() if unique else None
            multikey = False
            coll_opt = _parse_index_collation(options.get("collation"))
            for id_k, blob in self._scan_docs(db, coll):
                doc = bson.decode(blob)
                if partial_filter is not None and not matches(doc, partial_filter):
                    continue
                if not multikey and _doc_makes_multikey(doc, key_spec_dict):
                    multikey = True
                if claimed is not None:
                    canonical = _index_key(
                        doc, key_spec_dict, sparse=sparse, collation=coll_opt
                    )
                    if canonical is not None:
                        if canonical in claimed:
                            raise IndexConflict(name, doc.get("_id"))
                        claimed.add(canonical)
                for key_bytes in _index_key_variants(
                    doc, key_spec_dict, sparse=sparse, collation=coll_opt
                ):
                    entries.append((key_bytes, id_k))
            if multikey:
                options["multikey"] = True

        # Both index kinds persist the same way: the definition row first,
        # then one empty-valued row per collected entry.
        payload = bson.encode({"key": dict(key_spec), "options": options})
        idx_cur.reset()
        idx_cur[db, coll, name] = payload
        entry_cur = self._cursor(_IDX_ENTRIES_TABLE)
        for key_bytes, id_k in entries:
            entry_cur.reset()
            entry_cur[db, coll, name, _pack_entry(key_bytes, id_k)] = b""

        ui = self._collection_uuid(db, coll)
        self._emit_oplog(
            [
                {
                    "op": "c",
                    "ns": f"{db}.$cmd",
                    "ui": bson.Binary(ui.bytes, subtype=4),
                    "o": {
                        "createIndexes": coll,
                        "indexes": [{"v": 2, "key": dict(key_spec), "name": name, **options}],
                    },
                }
            ]
        )
        return True
def list_indexes(self, db: str, coll: str) -> list[dict[str, Any]]:
    """Return index descriptions for ``db.coll``, sorted by name.

    A collection that does not exist yields ``[]``. Otherwise the list
    holds the implicit ``_id`` index plus one ``{"v", "key", "name"}``
    document per stored index, with that index's stored options merged
    on top.
    """
    self._refresh_read_snapshot()
    with self._lock:
        if self._coll_options(db, coll) is None:
            return []
        listing: list[dict[str, Any]] = [
            {"v": 2, "key": {"_id": 1}, "name": _ID_INDEX_NAME}
        ]
        listing.extend(
            {"v": 2, "key": key_spec, "name": name, **opts}
            for name, key_spec, opts in self._iter_indexes(db, coll)
        )
        return sorted(listing, key=lambda spec: spec["name"])
def _iter_indexes(
    self, db: str, coll: str
) -> Iterable[tuple[str, dict[str, Any], dict[str, Any]]]:
    """Yield ``(name, key_spec, options)`` for each stored index of ``db.coll``.

    Rows are read from the index-definition table starting at
    ``(db, coll, "")`` and decoded from BSON; a missing ``key`` or
    ``options`` field is yielded as ``{}``. Iteration stops at the first
    key outside the namespace.
    """
    cursor = self._cursor(_IDX_TABLE)
    cursor.set_key(db, coll, "")
    status = cursor.search_near()
    if status == wt.WT_NOTFOUND:
        return
    # Negative status: positioned before the seed, so advance once.
    if status < 0 and cursor.next() != 0:
        return
    positioned = True
    while positioned:
        key = cursor.get_key()
        if (key[0], key[1]) != (db, coll):
            return
        definition = bson.decode(bytes(cursor.get_value()))
        yield key[2], definition.get("key", {}), definition.get("options", {})
        positioned = cursor.next() == 0
def drop_index(self, db: str, coll: str, name: str) -> bool:
    """Drop index ``name`` from ``db.coll``.

    Returns ``False`` for the ``_id`` index or when no such index is
    stored; otherwise removes the definition row and all of the index's
    entries, emits a ``dropIndexes`` oplog entry and returns ``True``.
    """
    if name == _ID_INDEX_NAME:
        return False
    with self._lock:
        # Mutating scanners read the current rows before deleting/rewriting
        # them; a snapshot pinned by an earlier positioned cursor on
        # this connection thread would hide rows committed by other
        # threads and turn the scan into a silent partial no-op
        # (the gauge's drop-then-reinsert E11000 cluster).
        self._refresh_read_snapshot()
        definition_cur = self._cursor(_IDX_TABLE)
        definition_cur.set_key(db, coll, name)
        if definition_cur.search() != 0:
            return False
        definition_cur.remove()
        entry_keys = [
            key
            for key, _value in self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll, name))
        ]
        self._delete_keys(_IDX_ENTRIES_TABLE, entry_keys)
        ui = self._collection_uuid(db, coll)
        oplog_entry = {
            "op": "c",
            "ns": f"{db}.$cmd",
            "ui": bson.Binary(ui.bytes, subtype=4),
            "o": {"dropIndexes": coll, "index": name},
        }
        self._emit_oplog([oplog_entry])
        return True
def drop_all_indexes(self, db: str, coll: str) -> int:
    """Drop every stored index of ``db.coll`` and return how many were removed.

    Deletes all definition rows and all entry rows under the namespace.
    One ``dropIndexes`` oplog entry per removed index is emitted in a
    single batch; nothing is emitted when there were no indexes.
    """
    namespace = (db, coll)
    with self._lock:
        # Mutating scanners read the current rows before deleting/rewriting
        # them; a snapshot pinned by an earlier positioned cursor on
        # this connection thread would hide rows committed by other
        # threads and turn the scan into a silent partial no-op
        # (the gauge's drop-then-reinsert E11000 cluster).
        self._refresh_read_snapshot()
        definition_keys = [key for key, _value in self._collect_prefix(_IDX_TABLE, namespace)]
        names = [key[2] for key in definition_keys]
        self._delete_keys(_IDX_TABLE, definition_keys)
        entry_keys = [
            key for key, _value in self._collect_prefix(_IDX_ENTRIES_TABLE, namespace)
        ]
        self._delete_keys(_IDX_ENTRIES_TABLE, entry_keys)
        if names:
            ui = self._collection_uuid(db, coll)
            batch = []
            for index_name in names:
                batch.append(
                    {
                        "op": "c",
                        "ns": f"{db}.$cmd",
                        "ui": bson.Binary(ui.bytes, subtype=4),
                        "o": {"dropIndexes": coll, "index": index_name},
                    }
                )
            self._emit_oplog(batch)
        return len(definition_keys)
def _all_indexes(self, db: str, coll: str) -> list[tuple[str, dict[str, Any], bool, bool]]:
"""Every non-_id_ index: (name, key_spec, sparse, unique)."""
out: list[tuple[str, dict[str, Any], bool, bool]] = []
for name, key_spec, opts in list(self._iter_indexes(db, coll)):
out.append((name, key_spec, bool(opts.get("sparse")), bool(opts.get("unique"))))
return out
def _partial_filters(self, db: str, coll: str) -> dict[str, dict[str, Any]]:
"""Map of index name → ``partialFilterExpression`` for indexes that have one.
Indexes without a partial filter are absent from the dict.
"""
out: dict[str, dict[str, Any]] = {}
for name, _key_spec, opts in self._iter_indexes(db, coll):
pf = opts.get("partialFilterExpression")
if isinstance(pf, Mapping) and pf:
out[name] = dict(pf)
return out
@staticmethod
def _query_implies_partial(query: Mapping[str, Any], partial: Mapping[str, Any]) -> bool:
"""True if every document matching ``query`` is guaranteed to be in
a partial index whose filter is ``partial`` — i.e. ``query`` is at
least as restrictive as ``partial`` on every partial-filter field.
SOUNDNESS is the rule: using a partial index for a query that could
match documents the index doesn't contain returns wrong results, so
this errs to ``False`` (skip the index, full scan — correct but
slower) for anything it can't prove implied. Supports bare-equality
partial values and the ``$eq``/``$lt``/``$lte``/``$gt``/``$gte``
range operators on both sides (``{a: {$lte: 1.5}}`` is implied by a
query equality ``a: 1`` or ``a: {$lt: 1}``).
"""
for key, pval in partial.items():
if key not in query:
return False
qval = query[key]
p_is_ops = isinstance(pval, Mapping) and pval and all(k.startswith("$") for k in pval)
q_is_ops = isinstance(qval, Mapping) and qval and all(k.startswith("$") for k in qval)
if p_is_ops:
if not _clause_implies_bounds(qval, pval):
return False
elif q_is_ops:
# bare-value partial, operator-form query: only an exact
# ``$eq`` of the same value implies it.
if qval.get("$eq") != pval:
return False
elif qval != pval:
return False
return True
def _multikey_index_names(self, db: str, coll: str) -> set[str]:
"""Names of indexes flagged ``multikey`` (must fall back to scan).
Without true multi-key indexing, an index where any doc has a
list-valued field can't serve scalar-element matches — so the
pickers skip these names and ``find_matching`` falls back to a
full scan.
"""
return {
name for name, _key_spec, opts in self._iter_indexes(db, coll) if opts.get("multikey")
}
def _maybe_mark_multikey(
    self,
    db: str,
    coll: str,
    doc: Mapping[str, Any],
    indexes: list[tuple[str, dict[str, Any], bool, bool]],
    already_multikey: set[str],
) -> set[str]:
    """Persist the ``multikey`` flag on indexes that ``doc`` makes multikey.

    Indexes already named in ``already_multikey`` are skipped. For the
    rest, ``_doc_makes_multikey`` decides whether ``doc`` qualifies; if
    so, and the stored definition is not flagged yet, the definition is
    rewritten with ``options.multikey = True``. ``already_multikey`` is
    mutated in place and also returned so callers can avoid re-checking.
    """
    cursor = self._cursor(_IDX_TABLE)
    for name, key_spec, _sparse, _unique in indexes:
        if name in already_multikey or not _doc_makes_multikey(doc, key_spec):
            continue
        cursor.reset()
        cursor.set_key(db, coll, name)
        if cursor.search() != 0:
            # Definition row is gone; nothing to flag.
            continue
        definition = bson.decode(bytes(cursor.get_value()))
        opts = dict(definition.get("options") or {})
        if not opts.get("multikey"):
            opts["multikey"] = True
            definition["options"] = opts
            cursor.reset()
            cursor[db, coll, name] = bson.encode(definition)
        already_multikey.add(name)
    return already_multikey
def _write_index_entries(
self,
db: str,
coll: str,
doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
partials: dict[str, dict[str, Any]] | None = None,
*,
id_key_override: bytes | None = None,
) -> None:
if not indexes:
return
c = self._cursor(_IDX_ENTRIES_TABLE)
# Timeseries doc-table keys carry a uniqueness suffix; entries must
# point at the row's ACTUAL key or index lookups would miss it.
id_k = id_key_override if id_key_override is not None else _id_key(doc["_id"])
if partials is None:
partials = self._partial_filters(db, coll)
index_options = self._index_options_map(db, coll)
for name, key_spec, sparse, _unique in indexes:
pf = partials.get(name)
if pf is not None and not matches(doc, pf):
continue
geo = _geo_type_of(key_spec)
if geo is not None:
geo_field, geo_type = geo
opts = index_options.get(name, {})
for cell_bytes in _doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name):
c.reset()
c[db, coll, name, _pack_entry(cell_bytes, id_k)] = b""
continue
coll_opt = _parse_index_collation(index_options.get(name, {}).get("collation"))
for kb in _index_key_variants(doc, key_spec, sparse=sparse, collation=coll_opt):
c.reset()
c[db, coll, name, _pack_entry(kb, id_k)] = b""
def _delete_index_entries(
self,
db: str,
coll: str,
doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
partials: dict[str, dict[str, Any]] | None = None,
*,
id_key_override: bytes | None = None,
) -> None:
if not indexes:
return
c = self._cursor(_IDX_ENTRIES_TABLE)
# Timeseries doc-table keys carry a uniqueness suffix; entries must
# point at the row's ACTUAL key or index lookups would miss it.
id_k = id_key_override if id_key_override is not None else _id_key(doc["_id"])
if partials is None:
partials = self._partial_filters(db, coll)
index_options = self._index_options_map(db, coll)
for name, key_spec, sparse, _unique in indexes:
pf = partials.get(name)
if pf is not None and not matches(doc, pf):
continue
geo = _geo_type_of(key_spec)
if geo is not None:
geo_field, geo_type = geo
opts = index_options.get(name, {})
# On the delete path, swallow GeoExtractError. A doc that
# was inserted before geo validation became strict might
# have bad geometry; we still need to allow it to be
# deleted. The index may end up with stale entries we
# can't match, but the next compact / drop_index cleans
# those up. Insert/update remain strict.
try:
cells = _doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name)
except GeoExtractError:
continue
for cell_bytes in cells:
c.reset()
c.set_key(db, coll, name, _pack_entry(cell_bytes, id_k))
if c.search() == 0:
c.remove()
continue
coll_opt = _parse_index_collation(index_options.get(name, {}).get("collation"))
for kb in _index_key_variants(doc, key_spec, sparse=sparse, collation=coll_opt):
c.reset()
c.set_key(db, coll, name, _pack_entry(kb, id_k))
if c.search() == 0:
c.remove()
def _validate_geo_indexes(
self,
db: str,
coll: str,
doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
partials: dict[str, dict[str, Any]] | None = None,
) -> None:
"""Pre-flight every geo index for ``doc``: raise on bad geometry.
Used by the insert / update paths to reject docs *before* writing
them, so a single bad geo coordinate doesn't leave a half-indexed
document behind. The work duplicates ``_write_index_entries``'s
cell computation but is cheap (one Shapely parse + bounds check
per indexed field).
"""
if not indexes:
return
if partials is None:
partials = self._partial_filters(db, coll)
options_map = self._index_options_map(db, coll)
for name, key_spec, _sparse, _unique in indexes:
geo = _geo_type_of(key_spec)
if geo is None:
continue
pf = partials.get(name)
if pf is not None and not matches(doc, pf):
continue
geo_field, geo_type = geo
opts = options_map.get(name, {})
# Compute & discard — `_doc_geo_cells` raises GeoExtractError
# on bad shape or out-of-bounds coords; that's the signal we
# want to bubble up.
_doc_geo_cells(doc, geo_field, geo_type, opts, index_name=name)
def _index_options_map(self, db: str, coll: str) -> dict[str, dict[str, Any]]:
"""Map of index name → its full options blob.
Used by the geo write/delete paths: 2d indexes carry per-index
``bits`` / ``min`` / ``max`` settings that affect the cell encoder,
so we need the option blob to compute the right bucket. Cached
per call (the caller handles per-doc loops).
"""
return {name: dict(opts) for name, _key_spec, opts in self._iter_indexes(db, coll)}
def _unique_conflict(
self,
db: str,
coll: str,
candidate_doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
*,
exclude_id_key: bytes | None,
partials: dict[str, dict[str, Any]] | None = None,
) -> tuple[str, dict[str, Any], dict[str, Any]] | None:
# Returns ``(index_name, key_pattern, key_value)`` so callers
# can build a mongod-shaped dup-key error response with the
# ``keyPattern`` + ``keyValue`` fields drivers' errorResponse
# tests assert on. ``None`` when no conflict.
if not indexes:
return None
c = self._cursor(_IDX_ENTRIES_TABLE)
if partials is None:
partials = self._partial_filters(db, coll)
index_options = self._index_options_map(db, coll)
for name, key_spec, sparse, unique in indexes:
if not unique:
continue
pf = partials.get(name)
if pf is not None and not matches(candidate_doc, pf):
continue
coll_opt = _parse_index_collation(index_options.get(name, {}).get("collation"))
kb = _index_key(candidate_doc, key_spec, sparse=sparse, collation=coll_opt)
if kb is None:
continue
esc_kb = _escape_kb(kb)
seed = esc_kb + _ENTRY_SEP
c.reset()
c.set_key(db, coll, name, seed)
rc = c.search_near()
if rc == wt.WT_NOTFOUND:
continue
if rc < 0 and c.next() != 0:
continue
while True:
k = c.get_key()
if (k[0], k[1], k[2]) != (db, coll, name):
break
packed = bytes(k[3])
row_esc, row_id = _unpack_entry(packed)
if row_esc != esc_kb:
break
if exclude_id_key is None or row_id != exclude_id_key:
key_value = {
field: get_path(candidate_doc, field, default=None) for field in key_spec
}
return name, dict(key_spec), key_value
if c.next() != 0:
break
return None
def _scan_index_for_id_keys(
    self, db: str, coll: str, name: str, kb: bytes, *, prefix: bool = False
) -> list[bytes]:
    """Collect the id_keys of index ``name`` entries that match ``kb``.

    ``prefix=False`` (default) is an equality lookup: only rows whose
    escaped key equals ``escape(kb)`` are returned. ``prefix=True`` is a
    compound-prefix lookup: any row whose escaped key starts with
    ``escape(kb)`` is returned. Results come back in index order and may
    contain repeats.
    """
    cursor = self._cursor(_IDX_ENTRIES_TABLE)
    wanted = _escape_kb(kb)
    if prefix:
        seed = wanted
    else:
        seed = wanted + _ENTRY_SEP
    cursor.set_key(db, coll, name, seed)
    status = cursor.search_near()
    if status == wt.WT_NOTFOUND:
        return []
    if status < 0 and cursor.next() != 0:
        return []
    found: list[bytes] = []
    positioned = True
    while positioned:
        row = cursor.get_key()
        if (row[0], row[1], row[2]) != (db, coll, name):
            break
        row_esc, row_id = _unpack_entry(bytes(row[3]))
        if prefix:
            belongs = row_esc.startswith(wanted)
        else:
            belongs = row_esc == wanted
        if not belongs:
            break
        found.append(row_id)
        positioned = cursor.next() == 0
    return found
def _all_id_keys_for_index(self, db: str, coll: str, name: str) -> list[bytes]:
    """Return the id_key of every entry in index ``name`` — a full index scan.

    This serves ``{field: {$exists: true}}`` through a sparse index: such
    an index holds an entry for exactly the docs where the indexed field
    is present (missing-field docs are omitted; present-but-null keeps an
    entry), so the full entry set *is* the ``$exists: true`` match set.
    An id_key can repeat for multikey arrays (one entry per element); the
    caller's ``_docs_by_id_keys`` dedups.
    """
    cursor = self._cursor(_IDX_ENTRIES_TABLE)
    cursor.set_key(db, coll, name, b"")
    status = cursor.search_near()
    if status == wt.WT_NOTFOUND:
        return []
    if status < 0 and cursor.next() != 0:
        return []
    id_keys: list[bytes] = []
    positioned = True
    while positioned:
        row = cursor.get_key()
        if (row[0], row[1], row[2]) != (db, coll, name):
            break
        _row_esc, row_id = _unpack_entry(bytes(row[3]))
        id_keys.append(row_id)
        positioned = cursor.next() == 0
    return id_keys
def _docs_by_id_keys(self, db: str, coll: str, id_keys: list[bytes]) -> list[dict[str, Any]]:
if not id_keys:
return []
c = self._cursor(_DOC_TABLE)
# Two-stage: WT cursor walk first (raw bytes), then ``bson.decode``
# outside that loop. The cursor work is what needs lock scope;
# decode is pure CPU and benefits from running unsynchronised.
# Multikey indexes write per-element entries, so the same doc's
# id_key can appear more than once for queries that match
# multiple elements. Dedupe while preserving order.
raw: list[bytes] = []
for id_k in dict.fromkeys(id_keys):
c.reset()
c.set_key(db, coll, id_k)
if c.search() == 0:
raw.append(bytes(c.get_value()))
return [bson.decode(b) for b in raw]
# Comparison operators; presumably the set the non-geo index pickers
# accept on an indexed field — TODO confirm against the pickers.
_RANGE_OPS: tuple[str, ...] = ("$eq", "$gt", "$gte", "$lt", "$lte", "$in")
# Geo operators that ``_try_geo_index_id_keys`` looks for in a filter
# condition, probed in this order.
_GEO_OPS: tuple[str, ...] = ("$geoWithin", "$geoIntersects", "$near", "$nearSphere")
def _try_geo_index_id_keys(
self, db: str, coll: str, filter: dict[str, Any]
) -> list[bytes] | None:
"""If ``filter`` contains a geo operator on a geo-indexed field,
scan that index's covering cells and return candidate id_keys.
Returns ``None`` when no geo operator is present or no matching
geo index exists — caller falls through to regular pickers and
eventually a full scan. Returns a list (possibly empty) when a
geo index covers the query — caller short-circuits the regular
pickers.
The cell scan over-collects (cell-covering is a superset of the
true intersection); the caller's ``matches()`` step verifies via
:func:`secantus.geo.geo_within` / ``geo_intersects`` and removes
false positives.
"""
# Find a single field with a geo operator on it.
geo_field: str | None = None
geo_op: str | None = None
geo_arg: Any = None
geo_siblings: Mapping[str, Any] | None = None
for field, value in filter.items():
if not isinstance(value, dict):
continue
for op in self._GEO_OPS:
if op in value:
geo_field = field
geo_op = op
geo_arg = value[op]
# Capture the whole condition so `$near` /
# `$nearSphere` legacy 2d form (sibling
# `$maxDistance` / `$minDistance`) can scope the
# scan; without this the picker can't know the
# distance bound and falls back to full-scan.
geo_siblings = value
break
if geo_field is not None:
break
if geo_field is None:
return None
# Locate a geo index on that field.
chosen_name: str | None = None
chosen_type: str | None = None
chosen_opts: dict[str, Any] = {}
for name, key_spec, opts in self._iter_indexes(db, coll):
geo = _geo_type_of(key_spec)
if geo is None:
continue
if geo[0] == geo_field:
chosen_name = name
chosen_type = geo[1]
chosen_opts = dict(opts)
break
if chosen_name is None or chosen_type is None:
return None
# Build the query geometry from the operator arg.
cells = self._geo_query_cells(
geo_op, geo_arg, chosen_type, chosen_opts, siblings=geo_siblings
)
if cells is None:
# Couldn't compute a covering — defer to full scan.
return None
return self._collect_geo_candidates(db, coll, chosen_name, cells)
def _geo_query_cells(
    self,
    op: str,
    arg: Any,
    geo_type: str,
    options: Mapping[str, Any],
    *,
    siblings: Mapping[str, Any] | None = None,
) -> list[tuple[bytes, bytes]] | None:
    """Byte ranges covering the query geometry, as ``(lo, hi)`` pairs.

    For a 2dsphere index each covering cell from ``s2_query_covering``
    becomes a degenerate ``(cell, cell)`` pair; for a 2d index the pairs
    come from ``planar_2d_covering_ranges``. Callers scan either shape
    with :meth:`_scan_geo_range`.

    Returns ``None`` (caller falls back to a full scan) when ``op`` is
    not a supported geo operator, when a ``$geoWithin`` /
    ``$geoIntersects`` argument is not a mapping, when ``$near`` /
    ``$nearSphere`` has no max distance, when parsing raises
    ``GeoError``, or when a 2d index is queried with a geometry that is
    not a Shapely ``BaseGeometry``.
    """
    from secantus.geo import GeoError

    try:
        if op in ("$near", "$nearSphere"):
            # `$near` without a max distance: caller falls through to
            # full scan (signal None). With a max, expand into a cap
            # (2dsphere) or planar disk (2d).
            near_spec = self._near_query_geom(
                arg,
                default_spherical=(op == "$nearSphere"),
                siblings=siblings,
            )
            center, max_d, _min_d, spherical, legacy_form = near_spec
            if max_d is None:
                return None
            # Unit normalisation: legacy+spherical gives max in
            # radians-on-unit-sphere; legacy+planar gives max in
            # input units; GeoJSON gives max in meters. Index
            # picker for 2dsphere wants radians (so / EARTH_R);
            # picker for 2d wants planar (so leave alone for
            # legacy+planar; convert rad→degrees for
            # legacy+spherical via *180/π).
            import math as _math

            from shapely.geometry import Point as _Point

            from secantus.geo import EARTH_RADIUS_METERS, _SphericalCircle

            radians_input = legacy_form and spherical
            if geo_type == _GEO_2DSPHERE:
                # legacy+spherical → max already radians; otherwise
                # → meters → divide by Earth radius for radians.
                if radians_input:
                    radius_rad = max_d
                else:
                    radius_rad = max_d / EARTH_RADIUS_METERS
                geom = _SphericalCircle(center[0], center[1], radius_rad)
            else:
                # 2d planar — circular disk.
                # legacy+spherical → radians-on-unit-sphere → degrees
                # in planar input space (the conventional geographic
                # mapping that matches mongod's behaviour against a
                # 2d index). Otherwise the bound is already in input
                # units.
                if radians_input:
                    planar_radius = max_d * 180.0 / _math.pi
                else:
                    planar_radius = max_d
                geom = _Point(*center).buffer(planar_radius, quad_segs=16)
        elif op in ("$geoWithin", "$geoIntersects"):
            if not isinstance(arg, Mapping):
                return None
            geom, _ = parse_query_geometry(arg)
        else:
            return None
    except GeoError:
        return None

    if geo_type == _GEO_2DSPHERE:
        # Each cell becomes a degenerate (cell, cell) range so the
        # storage scanner does an exact point-lookup. Treating
        # 2dsphere uniformly as a list-of-ranges keeps the storage
        # path single-shaped.
        sphere_ranges: list[tuple[bytes, bytes]] = []
        for cell in s2_query_covering(geom):
            sphere_ranges.append((encode_cell(cell), encode_cell(cell)))
        return sphere_ranges

    # 2d: shape must be planar; convert to a list of tight Z-order
    # ranges via quadtree decomposition. For small bboxes this is
    # one range, same as the single-range path; for wider bboxes
    # the decomposition tightens the scan vs the old single coarse
    # range.
    from shapely.geometry.base import BaseGeometry as _BG

    if not isinstance(geom, _BG):
        return None
    planar_ranges: list[tuple[bytes, bytes]] = []
    for lo, hi in planar_2d_covering_ranges(geom, options):
        planar_ranges.append((encode_cell(lo), encode_cell(hi)))
    return planar_ranges
def _near_query_geom(
    self,
    arg: Any,
    *,
    default_spherical: bool = False,
    siblings: Mapping[str, Any] | None = None,
) -> tuple[tuple[float, float], float | None, float | None, bool, bool]:
    """Parse a ``$near`` / ``$nearSphere`` argument for the index picker.

    Thin pass-through to ``secantus.query._parse_near_spec``, so the
    operator handler and the picker agree on what a ``$near`` arg means.
    The result is ``(center, max_d, min_d, spherical, legacy_form)``;
    ``legacy_form`` lets the picker choose the right unit conversion
    (radians-on-unit-sphere vs meters vs input units) when it builds the
    index-side geometry.

    ``default_spherical`` must match the operator: ``$near`` → False,
    ``$nearSphere`` → True. Otherwise a legacy-form ``$nearSphere`` would
    be misread as planar and the picker would build the wrong geometry.

    ``siblings`` carries the parent condition dict so the legacy 2d shape
    ``{geo: {$near: [x, y], $maxDistance: r}}`` works.
    """
    from secantus.query import _parse_near_spec  # type: ignore[attr-defined]

    parsed = _parse_near_spec(arg, default_spherical=default_spherical, siblings=siblings)
    return parsed
def _collect_geo_candidates(
    self,
    db: str,
    coll: str,
    index_name: str,
    cells: list[tuple[bytes, bytes]],
) -> list[bytes]:
    """Scan each ``(lo, hi)`` range of ``cells`` and return deduplicated id_keys.

    A doc with N covering cells produces N index entries, yet each
    id_key is reported once, in first-seen order, across all ranges. The
    post-fetch verifier (in ``find_matching``'s ``matches()`` step)
    discards docs whose actual geometry doesn't match the query.
    """
    cursor = self._cursor(_IDX_ENTRIES_TABLE)
    already_seen: set[bytes] = set()
    candidates: list[bytes] = []
    for cell_range in cells:
        lo_bytes, hi_bytes = cell_range
        self._scan_geo_range(
            cursor, db, coll, index_name, lo_bytes, hi_bytes, already_seen, candidates
        )
    return candidates
def _scan_geo_range(
    self,
    c: Any,
    db: str,
    coll: str,
    name: str,
    lo_bytes: bytes,
    hi_bytes: bytes,
    seen: set[bytes],
    out: list[bytes],
) -> None:
    """Collect id_keys of index entries whose escaped key part is in range.

    The cursor is seeded at ``escape(lo_bytes)`` and walked forward
    until the key leaves ``(db, coll, name)`` or the escaped key part
    exceeds ``escape(hi_bytes)``. Each id_key not already in ``seen``
    is added to ``seen`` and appended to ``out``; both arguments are
    mutated in place. Entries without the entry separator are skipped.
    """
    range_start = _escape_kb(lo_bytes)
    range_end = _escape_kb(hi_bytes)
    c.reset()
    c.set_key(db, coll, name, range_start)
    exact = c.search_near()
    if exact == wt.WT_NOTFOUND:
        return
    # A negative result means the cursor landed before the seed; step
    # forward once so the walk starts at the first key >= the seed.
    positioned = c.next() == 0 if exact < 0 else True
    sep_len = len(_ENTRY_SEP)
    while positioned:
        key = c.get_key()
        if (key[0], key[1], key[2]) != (db, coll, name):
            return
        entry = bytes(key[3])
        cut = entry.find(_ENTRY_SEP)
        if cut >= 0:
            if entry[:cut] > range_end:
                return
            doc_key = entry[cut + sep_len :]
            if doc_key not in seen:
                seen.add(doc_key)
                out.append(doc_key)
        positioned = c.next() == 0
def _try_index_lookup(
self,
db: str,
coll: str,
filter: dict[str, Any],
*,
collation: Any = None,
) -> list[dict[str, Any]] | None:
id_keys = self._try_index_id_keys(db, coll, filter, collation=collation)
if id_keys is None:
return None
return self._docs_by_id_keys(db, coll, id_keys)
def _single_field_partial_residual_match(
self,
db: str,
coll: str,
filter: dict[str, Any],
*,
collation: Any = None,
) -> tuple[str, Any, tuple[str, int, bool]] | None:
"""For a *multi-field* filter, find a single-field index whose leading
field serves one clause while every **other** filter field is absorbed
by the index's (implied) partial filter.
e.g. ``find({x: {$gt: 1}, a: 1})`` against an index on ``x`` partial on
``{a: {$lte: 1.5}}``: ``x``'s range rides the index, the ``a: 1`` clause
is partial-implied (so the index's very existence guarantees it) and is
rechecked by the exact ``matches()`` pass in ``find_matching``. Returns
``(field, value, idx_match)`` or ``None``.
Conservative by design: only *partial* indexes get this treatment, and
only when the residual fields are exactly partial-filter fields — a
non-partial residual keeps the query on COLLSCAN, mirroring the
bare-equality path's ``eff_fields - set(pf)`` philosophy. Shared by the
lookup (``_try_index_id_keys``) and explain (``_pick_index_for_filter``)
dispatchers so they never diverge.
"""
partials = self._partial_filters(db, coll)
for field, value in filter.items():
if isinstance(value, dict) and (
not value or not all(op in self._RANGE_OPS for op in value)
):
continue
idx_match = self._find_leading_field_index(db, coll, field, filter, collation=collation)
if idx_match is None:
continue
name = idx_match[0]
pf = partials.get(name)
if pf is None:
continue
if not (set(filter) - {field}).issubset(set(pf)):
continue
return field, value, idx_match
return None
def _try_index_id_keys(
self,
db: str,
coll: str,
filter: dict[str, Any],
*,
collation: Any = None,
) -> list[bytes] | None:
"""Same dispatch as ``_try_index_lookup`` but returns id_keys instead
of materialised docs. Used by the write paths (update / delete) so
only matching docs pay ``bson.decode``.
``collation`` propagates from the query. When set, only indexes
whose stored ``collation`` option matches are considered;
non-matching indexes are skipped so the caller falls back to
COLLSCAN (the safe semantics). Single-field equality / range /
``$in`` and compound bare-eq / compound prefix + trailing
operator all thread collation through. ``numericOrdering``
collations never match any index (parse to None at the gate)
and fall through to COLLSCAN.
"""
if not filter:
return None
if any(f.startswith("$") for f in filter):
return None
# Fast path: equality on _id alone is a direct primary-key point
# lookup on the documents table (keyed by encode_value(_id)), not a
# COLLSCAN — the _id_ index is virtual (no entries table), so the
# generic pickers below never match it. Timeseries collections are
# excluded: their doc keys carry a uniqueness suffix (duplicate
# _ids are legal there), so the reconstructed unsuffixed key would
# never match a row.
if len(filter) == 1 and "_id" in filter and not self._is_timeseries(db, coll):
id_keys = _id_point_lookup_keys(filter["_id"])
if id_keys is not None:
return id_keys
# Geo dispatch first — a $geoWithin / $geoIntersects / $near clause
# on a field with a 2dsphere or 2d index uses the cell-covering
# path. The picker returns None if no geo index covers the query,
# and we fall through to the regular pickers below.
geo_ids = self._try_geo_index_id_keys(db, coll, filter)
if geo_ids is not None:
return geo_ids
# Bare-equality filters of any size can use a compound index whose
# leading fields cover the filter set.
if all(not isinstance(v, dict) for v in filter.values()):
result = self._try_compound_eq_id_keys(db, coll, filter, collation=collation)
if result is not None:
return result
# Compound prefix + trailing operator field (eq fields then range/in).
if len(filter) >= 2:
result = self._try_compound_range_id_keys(db, coll, filter, collation=collation)
if result is not None:
return result
if len(filter) == 1:
field, value = next(iter(filter.items()))
# {field: {$exists: true}} rides a sparse single-field index on
# ``field`` — every sparse entry is a doc where the field is
# present, exactly the $exists:true match set. No value bound:
# the whole index scans.
if isinstance(value, dict) and len(value) == 1 and value.get("$exists"):
name = self._sparse_index_for_exists(db, coll, field)
if name is None:
return None
return self._all_id_keys_for_index(db, coll, name)
idx_match = self._find_leading_field_index(db, coll, field, filter, collation=collation)
if idx_match is None:
return None
return self._lookup_id_keys_via_leading_field(
db, coll, idx_match, value, collation=collation
)
# Multi-field filter: a single-field index can still serve it when every
# other filter field is absorbed by the index's (implied) partial filter.
match = self._single_field_partial_residual_match(db, coll, filter, collation=collation)
if match is None:
return None
_field, value, idx_match = match
return self._lookup_id_keys_via_leading_field(
db, coll, idx_match, value, collation=collation
)
def _candidates_iter(
self, db: str, coll: str, filter: dict[str, Any] | None
) -> list[tuple[bytes, bytes]]:
"""Return (id_key, blob) pairs that the write paths should consider.
If an index covers the filter, only the indexed candidates are
fetched; otherwise the full doc table is scanned. Either way,
BSON decode is left to the caller so non-matching docs don't pay
for it. Caller still applies ``matches()`` to the decoded doc —
index lookups can produce false-positive candidates for partial
scans (multikey, prefix overlap, etc).
"""
if filter:
id_keys = self._try_index_id_keys(db, coll, filter)
if id_keys is not None:
c = self._cursor(_DOC_TABLE)
out: list[tuple[bytes, bytes]] = []
# Same dedup contract as ``_docs_by_id_keys``: multikey
# indexes can yield duplicate id_keys for one doc.
for id_k in dict.fromkeys(id_keys):
c.reset()
c.set_key(db, coll, id_k)
if c.search() == 0:
out.append((id_k, bytes(c.get_value())))
return out
return list(self._scan_docs(db, coll))
def _find_leading_field_index(
self,
db: str,
coll: str,
field: str,
query: Mapping[str, Any] | None = None,
*,
collation: Any = None,
) -> tuple[str, int, bool] | None:
"""Best index whose leading field is ``field``.
Returns ``(name, direction, is_compound)``. Single-field indexes
win over compound (tighter scan, no separator math). All fields
must be ASC or DESC. Partial indexes are skipped unless ``query``
implies their ``partialFilterExpression``.
Multikey indexes are not skipped — ``_index_key_variants`` writes
per-element entries, so equality / range / ``$in`` lookups on
the leading field hit at least all true matches. The geo
``2dsphere`` / ``2d`` indexes have non-numeric direction values
and are excluded by the ASC/DESC check below.
``collation``: when set, an index is only considered if its
stored ``collation`` option produces the same :class:`Collation`
as the query's (or both are None). Mismatched indexes are
skipped — the caller falls back to COLLSCAN, which uses
``matches()`` with the query's collation. Matches mongod's
per-index collation semantics.
"""
partials = self._partial_filters(db, coll)
index_options = self._index_options_map(db, coll)
query = query or {}
compound_fallback: tuple[str, int, bool] | None = None
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
pf = partials.get(name)
if pf is not None and not self._query_implies_partial(query, pf):
continue
idx_fields = list(key_spec)
if not idx_fields or idx_fields[0] != field:
continue
# Geo / hashed / text indexes carry string direction values
# ("2dsphere", "2d", "hashed", "text"); the bare equality
# picker can't drive them. Real numeric direction values are
# 1 / -1.
if any(key_spec[f] not in (1, -1) for f in idx_fields):
continue
# Collation gate: the index's stored collation must equal
# the query's effective collation (both None counts as a
# match). Indexes with a collation that doesn't support
# byte encoding (numericOrdering) parse to None here, so
# they're treated as "no collation" — correct for queries
# that also don't carry collation, wrong for queries that
# do. Conservative: gate by None-vs-None or exact match.
idx_coll = _parse_index_collation(index_options.get(name, {}).get("collation"))
if idx_coll != collation:
continue
d = int(key_spec[field])
if len(idx_fields) == 1:
return name, d, False
if compound_fallback is None:
compound_fallback = (name, d, True)
return compound_fallback
def _sparse_index_for_exists(self, db: str, coll: str, field: str) -> str | None:
"""Name of a sparse single-field index on ``field`` that can serve
``{field: {$exists: true}}`` at IXSCAN, or ``None``.
Only a **sparse** index qualifies: it omits docs missing the field,
so a full scan of its entries yields exactly the ``$exists: true``
matches. A non-sparse index has an entry per doc (missing fields
included), so it can't distinguish presence. Restricted to
single-field indexes — a compound sparse index in mongod drops a
doc only when *every* indexed field is missing, so its entries
don't line up with ``{leadingField: {$exists: true}}``. Collation-
independent: presence doesn't depend on string normalisation, so an
index of any collation serves the query (the post-scan ``matches()``
is the final arbiter regardless).
"""
for name, key_spec, sparse, _unique in self._all_indexes(db, coll):
if not sparse:
continue
idx_fields = list(key_spec)
if len(idx_fields) != 1 or idx_fields[0] != field:
continue
if key_spec[field] not in (1, -1):
continue
return name
return None
def _lookup_id_keys_via_leading_field(
self,
db: str,
coll: str,
idx_match: tuple[str, int, bool],
value: Any,
*,
collation: Any = None,
) -> list[bytes] | None:
name, direction, is_compound = idx_match
if not isinstance(value, dict):
return self._eq_id_keys_via_leading(
db, coll, name, direction, is_compound, value, collation=collation
)
if not value or not all(k.startswith("$") for k in value):
return None
if not all(op in self._RANGE_OPS for op in value):
return None
if "$in" in value:
if len(value) != 1 or not isinstance(value["$in"], list):
return None
seen: set[bytes] = set()
id_keys: list[bytes] = []
for v in value["$in"]:
if isinstance(v, dict):
return None
for id_k in self._eq_id_keys_via_leading(
db, coll, name, direction, is_compound, v, collation=collation
):
if id_k not in seen:
seen.add(id_k)
id_keys.append(id_k)
return id_keys
lower: bytes | None = None
lower_inclusive = True
upper: bytes | None = None
upper_inclusive = True
for op, bound in value.items():
if isinstance(bound, dict):
return None
if op == "$eq":
return self._eq_id_keys_via_leading(
db, coll, name, direction, is_compound, bound, collation=collation
)
kb = encode_value_directed(bound, direction, collation=collation)
# Operator semantics flip when stored bytes are inverted: in a
# DESC index, "x > 5" means we want stored bytes < enc_desc(5).
effective_op = op
if direction == -1:
effective_op = {"$gt": "$lt", "$gte": "$lte", "$lt": "$gt", "$lte": "$gte"}[op]
if effective_op == "$gt":
lower, lower_inclusive = kb, False
elif effective_op == "$gte":
lower, lower_inclusive = kb, True
elif effective_op == "$lt":
upper, upper_inclusive = kb, False
elif effective_op == "$lte":
upper, upper_inclusive = kb, True
if is_compound:
return self._range_scan_index_leading(
db, coll, name, lower, lower_inclusive, upper, upper_inclusive
)
return self._range_scan_index(
db, coll, name, lower, lower_inclusive, upper, upper_inclusive
)
def _eq_id_keys_via_leading(
    self,
    db: str,
    coll: str,
    name: str,
    direction: int,
    is_compound: bool,
    value: Any,
    *,
    collation: Any = None,
) -> list[bytes]:
    """Equality lookup on the leading field of index ``name``.

    ``value`` is encoded with ``encode_value_directed`` under
    ``direction`` and ``collation``. For a single-field index the
    encoded bytes are looked up as-is; for a compound index the bytes
    plus ``COMPOUND_SEP`` are used as a prefix scan, so every entry
    with that leading value is returned whatever its trailing fields.
    """
    encoded = encode_value_directed(value, direction, collation=collation)
    if not is_compound:
        return self._scan_index_for_id_keys(db, coll, name, encoded)
    return self._scan_index_for_id_keys(
        db, coll, name, encoded + COMPOUND_SEP, prefix=True
    )
def _pick_compound_eq_index(
self, db: str, coll: str, filter: dict[str, Any], *, collation: Any = None
) -> tuple[str, dict[str, Any]] | None:
"""Find the index that ``_try_compound_eq_id_keys`` would walk for ``filter``.
Returns ``(name, key_spec)`` of the chosen index, or ``None`` if no
index covers the filter as a leading prefix. Pure picker — does
not scan. Multikey indexes are eligible (per-element entries
cover equality lookups); the ASC/DESC direction check excludes
geo indexes.
``collation``: an index is only considered if its stored
``collation`` parses to the same :class:`Collation` as the
query's (or both None). Same exact-match gate as
``_find_leading_field_index``. Indexes whose stored collation
is ``numericOrdering`` parse to None here, so they look like
no-collation indexes — correct for no-collation queries,
wrong for numericOrdering queries; the latter fall through to
COLLSCAN regardless.
"""
filter_fields = set(filter)
partials = self._partial_filters(db, coll)
index_options = self._index_options_map(db, coll)
best: tuple[str, dict[str, Any]] | None = None
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
pf = partials.get(name)
if pf is not None:
if not self._query_implies_partial(filter, pf):
continue
# Partial-filter clauses are guaranteed by the index itself,
# so they don't have to appear in the index key.
eff_fields = filter_fields - set(pf)
else:
eff_fields = filter_fields
idx_fields = list(key_spec.keys())
# Geo / hashed / text indexes (string direction values) can't
# serve a bare-equality compound lookup.
if any(key_spec[f] not in (1, -1) for f in idx_fields):
continue
idx_coll = _parse_index_collation(index_options.get(name, {}).get("collation"))
if idx_coll != collation:
continue
if len(idx_fields) < len(eff_fields):
continue
if set(idx_fields[: len(eff_fields)]) != eff_fields:
continue
if best is None or (len(list(best[1])) > len(idx_fields)):
best = (name, dict(key_spec))
if len(idx_fields) == len(eff_fields):
break
return best
def _try_compound_eq_id_keys(
self, db: str, coll: str, filter: dict[str, Any], *, collation: Any = None
) -> list[bytes] | None:
"""Bare-equality filter against a compound (or single-field) index prefix.
Picks an index whose leading fields (set-wise) match the filter's
fields, and runs an equality (full-cover) or prefix
(strict-leading-prefix) scan against it. Per-field index direction
is honoured by encoding each value with ``encode_value_directed``.
``collation`` propagates from the query: only collation-matching
indexes are picked, and the lookup bytes are built under the same
collation so they hit the same row as the index-write side.
"""
picked = self._pick_compound_eq_index(db, coll, filter, collation=collation)
if picked is None:
return None
name, key_spec = picked
idx_fields = list(key_spec)
# Build kb from the filter fields that are in the index (partial-filter
# clauses live outside the key and are guaranteed by index population).
prefix_fields = [f for f in idx_fields if f in filter]
parts = [
encode_value_directed(filter[f], int(key_spec[f]), collation=collation)
for f in prefix_fields
]
kb = COMPOUND_SEP.join(parts) if len(parts) > 1 else parts[0]
if len(prefix_fields) == len(idx_fields):
return self._scan_index_for_id_keys(db, coll, name, kb)
kb = kb + COMPOUND_SEP
return self._scan_index_for_id_keys(db, coll, name, kb, prefix=True)
def _partition_compound_range_filter(
self, filter: dict[str, Any]
) -> tuple[dict[str, Any], str, dict[str, Any]] | None:
"""Split a filter into ``(eq_fields, operator_field, operator_ops)``.
Returns ``None`` if the filter doesn't fit the compound-range
shape (any number of bare-equality fields plus exactly one
operator-form field whose ops are all in ``_RANGE_OPS``).
"""
eq_fields: dict[str, Any] = {}
operator_field: str | None = None
operator_ops: dict[str, Any] | None = None
for f, v in filter.items():
if isinstance(v, dict):
if not v or not all(k.startswith("$") for k in v):
return None
if not all(op in self._RANGE_OPS for op in v):
return None
if operator_field is not None:
return None
operator_field = f
operator_ops = v
else:
eq_fields[f] = v
if operator_field is None or not eq_fields:
return None
if operator_field in eq_fields:
return None
return eq_fields, operator_field, operator_ops or {}
def _pick_compound_range_index(
self, db: str, coll: str, filter: dict[str, Any], *, collation: Any = None
) -> tuple[str, dict[str, Any]] | None:
"""Find the index that ``_try_compound_range_id_keys`` would walk.
``collation``: an index is only considered if its stored
collation parses to the same :class:`Collation` as the query's
(or both None). Same exact-match gate as
``_pick_compound_eq_index`` and ``_find_leading_field_index``.
"""
parts = self._partition_compound_range_filter(filter)
if parts is None:
return None
eq_fields, operator_field, _operator_ops = parts
eq_set = set(eq_fields)
target_eq_count = len(eq_set)
partials = self._partial_filters(db, coll)
index_options = self._index_options_map(db, coll)
best: tuple[str, dict[str, Any]] | None = None
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
pf = partials.get(name)
if pf is not None and not self._query_implies_partial(filter, pf):
continue
idx_fields = list(key_spec.keys())
# Geo / hashed / text indexes (string direction values) can't
# serve a compound prefix + trailing-operator lookup.
if any(key_spec[f] not in (1, -1) for f in idx_fields):
continue
idx_coll = _parse_index_collation(index_options.get(name, {}).get("collation"))
if idx_coll != collation:
continue
if len(idx_fields) <= target_eq_count:
continue
if set(idx_fields[:target_eq_count]) != eq_set:
continue
if idx_fields[target_eq_count] != operator_field:
continue
if best is None or len(list(best[1])) > len(idx_fields):
best = (name, dict(key_spec))
if len(idx_fields) == target_eq_count + 1:
break
return best
def _try_compound_range_id_keys(
self, db: str, coll: str, filter: dict[str, Any], *, collation: Any = None
) -> list[bytes] | None:
"""Compound-prefix lookup with a trailing operator field.
Filters of the form ``{a: 5, b: 10, c: {$gt: 20}}`` (any number of
leading bare-equality fields followed by exactly one operator-form
field) walk the compound index by pinning the prefix from the
equalities and applying the operator's bounds to the next field.
``collation`` propagates from the query: only collation-matching
indexes are picked, and every encoded value (prefix equalities
and trailing-operator bound) is built under the same collation.
"""
parts = self._partition_compound_range_filter(filter)
if parts is None:
return None
eq_fields, operator_field, operator_ops = parts
picked = self._pick_compound_range_index(db, coll, filter, collation=collation)
if picked is None:
return None
name, key_spec = picked
idx_fields = list(key_spec)
target_eq_count = len(eq_fields)
eq_field_names = idx_fields[:target_eq_count]
op_dir = int(key_spec[operator_field])
eq_parts = [
encode_value_directed(eq_fields[f], int(key_spec[f]), collation=collation)
for f in eq_field_names
]
prefix_kb = COMPOUND_SEP.join(eq_parts) if len(eq_parts) > 1 else eq_parts[0]
prefix_with_sep = prefix_kb + COMPOUND_SEP
if "$in" in operator_ops:
if len(operator_ops) != 1 or not isinstance(operator_ops["$in"], list):
return None
seen: set[bytes] = set()
id_keys: list[bytes] = []
for v in operator_ops["$in"]:
if isinstance(v, dict):
return None
kb = prefix_with_sep + encode_value_directed(v, op_dir, collation=collation)
use_prefix = len(idx_fields) > target_eq_count + 1
inner_kb = kb + COMPOUND_SEP if use_prefix else kb
for id_k in self._scan_index_for_id_keys(
db, coll, name, inner_kb, prefix=use_prefix
):
if id_k not in seen:
seen.add(id_k)
id_keys.append(id_k)
return id_keys
if "$eq" in operator_ops:
if len(operator_ops) != 1:
return None
kb = prefix_with_sep + encode_value_directed(
operator_ops["$eq"], op_dir, collation=collation
)
use_prefix = len(idx_fields) > target_eq_count + 1
inner_kb = kb + COMPOUND_SEP if use_prefix else kb
return self._scan_index_for_id_keys(db, coll, name, inner_kb, prefix=use_prefix)
lower: bytes | None = None
lower_inclusive = True
upper: bytes | None = None
upper_inclusive = True
for op, bound in operator_ops.items():
if isinstance(bound, dict):
return None
full = prefix_with_sep + encode_value_directed(bound, op_dir, collation=collation)
effective_op = op
if op_dir == -1:
effective_op = {"$gt": "$lt", "$gte": "$lte", "$lt": "$gt", "$lte": "$gte"}[op]
if effective_op == "$gt":
lower, lower_inclusive = full, False
elif effective_op == "$gte":
lower, lower_inclusive = full, True
elif effective_op == "$lt":
upper, upper_inclusive = full, False
elif effective_op == "$lte":
upper, upper_inclusive = full, True
else:
return None
return self._range_scan_index(
db,
coll,
name,
lower,
lower_inclusive,
upper,
upper_inclusive,
prefix=prefix_with_sep,
)
def _range_scan_index(
    self,
    db: str,
    coll: str,
    name: str,
    lower: bytes | None,
    lower_inclusive: bool,
    upper: bytes | None,
    upper_inclusive: bool,
    *,
    prefix: bytes | None = None,
) -> list[bytes]:
    """Range-scan the index entries for ``name`` and return their id_keys.

    ``lower`` / ``upper`` are unescaped key bytes (``None`` = open
    on that side); each comes with its own inclusive flag. Optional
    ``prefix`` confines the scan to entries whose escaped key starts
    with ``escape(prefix)``; it is used by compound prefix+range
    lookups where leading equalities pin part of the key.

    The cursor is seeded at the lower bound (or the prefix, or the
    start of the index) and walked forward until the key leaves
    ``(db, coll, name)``, leaves the prefix, or passes the upper bound.
    Results may contain duplicates and rows that only approximately
    satisfy the bounds; callers de-duplicate and re-check as needed.

    An inclusive upper bound also admits rows whose escaped key is
    ``escape(upper)`` followed by the escaped ``COMPOUND_SEP``. Those
    are rows of an index with key fields *after* the ranged field,
    whose ranged value equals the bound. They compare greater than
    ``escape(upper)`` byte-wise, so a plain ``>`` test would end the
    scan before them and drop every match sitting exactly on the bound.
    This mirrors the check in ``_range_scan_index_leading``.
    """
    c = self._cursor(_IDX_ENTRIES_TABLE)
    esc_prefix = _escape_kb(prefix) if prefix is not None else None
    esc_lower = _escape_kb(lower) if lower is not None else None
    esc_upper = _escape_kb(upper) if upper is not None else None
    # Escaped-key prefix shared by rows whose ranged value equals the
    # upper bound and that carry further key fields after it.
    upper_eq_prefix = (
        esc_upper + _escape_kb(COMPOUND_SEP) if esc_upper is not None else None
    )
    if esc_lower is not None:
        seed = esc_lower + _ENTRY_SEP
    elif esc_prefix is not None:
        seed = esc_prefix
    else:
        seed = b""
    c.set_key(db, coll, name, seed)
    rc = c.search_near()
    if rc == wt.WT_NOTFOUND:
        return []
    # rc < 0: the cursor landed before the seed; step to the next key.
    if rc < 0 and c.next() != 0:
        return []
    out: list[bytes] = []
    while True:
        k = c.get_key()
        if (k[0], k[1], k[2]) != (db, coll, name):
            break
        packed = bytes(k[3])
        row_esc, row_id = _unpack_entry(packed)
        if esc_prefix is not None and not row_esc.startswith(esc_prefix):
            break
        if esc_lower is not None and not lower_inclusive and row_esc == esc_lower:
            # Exclusive lower bound: skip rows sitting exactly on it.
            if c.next() != 0:
                break
            continue
        if esc_upper is not None:
            if upper_inclusive:
                if row_esc > esc_upper and not row_esc.startswith(upper_eq_prefix):
                    break
            elif row_esc >= esc_upper:
                break
        out.append(row_id)
        if c.next() != 0:
            break
    return out
def _range_scan_index_leading(
    self,
    db: str,
    coll: str,
    name: str,
    lower: bytes | None,
    lower_inclusive: bool,
    upper: bytes | None,
    upper_inclusive: bool,
) -> list[bytes]:
    """Range-scan a compound index by its leading field only.

    ``lower`` / ``upper`` are unescaped encodings of the leading-field
    bounds (``None`` = open on that side). Rows whose leading field
    equals a bound are recognised by
    ``startswith(escape(bound) + escape(COMPOUND_SEP))`` rather than
    by splitting on the separator:

    * with an exclusive lower bound such rows are skipped;
    * with an inclusive upper bound such rows are kept even though
      they compare greater than ``escape(upper)``;
    * with an exclusive upper bound the scan stops at the first row
      ``>= escape(upper)``.

    Returns the id_keys of all kept rows in index order; duplicates are
    not removed.
    """
    sep_esc = _escape_kb(COMPOUND_SEP)
    cursor = self._cursor(_IDX_ENTRIES_TABLE)
    lo = None if lower is None else _escape_kb(lower)
    hi = None if upper is None else _escape_kb(upper)

    cursor.set_key(db, coll, name, b"" if lo is None else lo)
    exact = cursor.search_near()
    if exact == wt.WT_NOTFOUND:
        return []
    # A negative result means the cursor landed before the seed.
    if exact < 0 and cursor.next() != 0:
        return []

    lo_equal = None if lo is None else lo + sep_esc
    hi_equal = None if hi is None else hi + sep_esc
    skip_lower_equal = lo_equal is not None and not lower_inclusive

    collected: list[bytes] = []
    more = True
    while more:
        key = cursor.get_key()
        if (key[0], key[1], key[2]) != (db, coll, name):
            break
        row_esc, row_id = _unpack_entry(bytes(key[3]))
        if not (skip_lower_equal and row_esc.startswith(lo_equal)):
            if hi is not None:
                if upper_inclusive:
                    past_upper = row_esc > hi and not row_esc.startswith(hi_equal)
                else:
                    past_upper = row_esc >= hi
                if past_upper:
                    break
            collected.append(row_id)
        more = cursor.next() == 0
    return collected