"""WiredTiger-backed document store.
WiredTiger is the default storage engine for MongoDB. We use the same
engine here so that on-disk semantics line up with what test code would
see against a real ``mongod``.
Indexes use a sidecar entries table (``table:secantus_index_entries``)
with a single trailing ``u`` column packing
``escape(sortkey) + b"\\x00\\x00" + id_key``. The sortkey comes from
``secantus.sortkey`` (typed, byte-sortable BSON encoding), so the WT
B-tree gives us ordered access for free. ``find_matching`` routes a wide
range of filter shapes through the index — equality, ``$eq``, ``$in``,
``$gt``/``$gte``/``$lt``/``$lte`` on a single field, plus compound
indexes when filter fields cover a leading prefix (with optional range
on the next field). Sort-by-indexed-field walks the B-tree in order.
"""
from __future__ import annotations
import contextlib
import datetime as _dt
import os
import shutil
import tempfile
import threading
from collections.abc import Iterable, Mapping
from decimal import Decimal, InvalidOperation
from typing import Any
import bson
import wiredtiger as wt
from bson import Decimal128
from secantus.paths import get_path, has_path
from secantus.projection import apply_projection
from secantus.query import matches
from secantus.sortkey import COMPOUND_SEP, encode_value, encode_value_directed
from secantus.update import apply_update, find_positional_matches
# One row per collection, keyed by (db, coll); the value is an optional BSON blob.
_COLL_TABLE = "table:secantus_collections"
# BSON-encoded documents, keyed by (db, coll, id_key).
_DOC_TABLE = "table:secantus_documents"
# Index definitions (key spec + options), keyed by (db, coll, index_name).
_IDX_TABLE = "table:secantus_indexes"
# Sidecar index entries, keyed by (db, coll, index_name, packed_entry).
_IDX_ENTRIES_TABLE = "table:secantus_index_entries"
# Separator placed between the escaped sortkey and the id_key in a packed entry.
_ENTRY_SEP = b"\x00\x00"
def _escape_kb(kb: bytes) -> bytes:
"""Order-preserving escape so ``\\x00\\x00`` is unambiguous as a separator."""
return kb.replace(b"\x00", b"\x00\xff")
def _pack_entry(kb: bytes, id_key: bytes) -> bytes:
    """Build the single trailing ``u`` key column for an index entry.

    The payload is the escaped sortkey, then the entry separator, then the
    document's ``id_key``. Both fields share one column because WiredTiger
    length-prefixes ``u`` columns that are not last in the key, which would
    break lexicographic comparison; packed this way the B-tree sorts the
    entries for us.
    """
    return b"".join((_escape_kb(kb), _ENTRY_SEP, id_key))
def _unpack_entry(packed: bytes) -> tuple[bytes, bytes]:
    """Split a packed index entry into ``(escaped_kb, id_key)``.

    The split happens at the first occurrence of the entry separator, which
    is where ``_pack_entry`` put it (the escaped sortkey never contains two
    consecutive NUL bytes).

    Raises ``ValueError`` when ``packed`` contains no separator at all.
    Previously ``bytes.find`` returned ``-1`` in that case and the slices
    silently produced a truncated key and a bogus id.
    """
    escaped_kb, sep, id_key = packed.partition(_ENTRY_SEP)
    if not sep:
        raise ValueError("index entry is missing the sortkey/id separator")
    return escaped_kb, id_key
class DuplicateKeyError(Exception):
    """Raised for a duplicate ``_id``; the offending value is kept on ``doc_id``."""

    def __init__(self, doc_id: Any) -> None:
        self.doc_id = doc_id
        message = f"duplicate _id: {doc_id!r}"
        super().__init__(message)
def _id_key(doc_id: Any) -> bytes:
    """Return the byte-sortable key under which ``doc_id`` is stored.

    The bytes come straight from ``encode_value``, the same encoding the
    secondary-index entries table relies on. Two consequences noted by the
    original author:

    * Numerically equal ids (``1``, ``1.0``, ``Decimal128("1")``) encode to
      the same bytes, so they address the same doc and clash on uniqueness.
    * Walking the doc table in key order yields docs in BSON cross-type
      sort order, which is what MongoDB calls "natural order" for
      non-capped collections.
    """
    encoded = encode_value(doc_id)
    return encoded
def _doc_makes_multikey(doc: Mapping[str, Any], key_spec: Mapping[str, Any]) -> bool:
    """Report whether ``doc`` holds a list under any field named in ``key_spec``.

    A list value is encoded as one composite array sortkey, so a later
    scalar-equality query against the index would silently miss the doc;
    callers use this to flag the index as needing a full-scan fallback.
    """
    plain = dict(doc)
    for field in key_spec:
        if isinstance(get_path(plain, field), list):
            return True
    return False
def _index_key(
    doc: Mapping[str, Any], key_spec: Mapping[str, Any], *, sparse: bool
) -> bytes | None:
    """Encode ``doc``'s values for ``key_spec`` as one byte-sortable index key.

    Each field goes through ``encode_value_directed`` with the direction
    from ``key_spec``, so descending (``-1``) fields sort in reverse under a
    forward B-tree walk. A single-field key is returned as-is; compound
    keys are joined with ``COMPOUND_SEP``.

    Returns ``None`` when ``sparse`` is set and any indexed path is missing
    from ``doc``, meaning the doc gets no entry in the index.
    """
    plain = dict(doc)
    if sparse and not all(has_path(plain, field) for field in key_spec):
        return None
    fields = list(key_spec)
    if len(fields) == 1:
        (only,) = fields
        direction = int(key_spec[only])
        return encode_value_directed(get_path(plain, only), direction)
    return COMPOUND_SEP.join(
        encode_value_directed(get_path(plain, name), int(key_spec[name])) for name in fields
    )
def _to_decimal(value: Any) -> Decimal:
if isinstance(value, Decimal128):
return value.to_decimal()
if isinstance(value, float):
return Decimal(repr(value))
return Decimal(value)
def _bson_type_rank(value: Any) -> int:
    """Return the rank of ``value`` in MongoDB's cross-type sort order.

    Lower ranks sort first. Values of an unrecognised type get rank 5, the
    same rank as mappings.
    """
    import datetime as _dt

    from bson import Binary, MaxKey, MinKey, ObjectId, Regex, Timestamp

    if value is None:
        return 2
    # Checked top to bottom; ``bool`` must precede the numeric row because
    # ``bool`` is a subclass of ``int``.
    ladder: tuple[tuple[Any, int], ...] = (
        (MinKey, 1),
        (bool, 9),
        ((int, float, Decimal128), 3),
        (str, 4),
        (Mapping, 5),
        (list, 6),
        ((bytes, Binary), 7),
        (ObjectId, 8),
        (_dt.datetime, 10),
        (Timestamp, 11),
        (Regex, 12),
        (MaxKey, 13),
    )
    for kinds, rank in ladder:
        if isinstance(value, kinds):
            return rank
    return 5
class _SortKey:
    """Sort-key wrapper that orders arbitrary values by BSON type rank first.

    Values of different rank compare by rank. Within a rank: two ``None``
    values are never "less than" each other, comparisons involving a
    ``Decimal128`` are attempted through ``Decimal``, and anything the
    native ``<`` rejects falls back to comparing type names.
    """

    __slots__ = ("val",)

    def __init__(self, val: Any) -> None:
        self.val = val

    def __lt__(self, other: _SortKey) -> bool:
        left, right = self.val, other.val
        left_rank, right_rank = _bson_type_rank(left), _bson_type_rank(right)
        if left_rank != right_rank:
            return left_rank < right_rank
        if left is None or right is None:
            return False
        if any(isinstance(side, Decimal128) for side in (left, right)):
            # A failed decimal conversion/comparison is not fatal; fall
            # through to the native comparison below.
            with contextlib.suppress(InvalidOperation, ValueError):
                return bool(_to_decimal(left) < _to_decimal(right))
        try:
            return bool(left < right)
        except TypeError:
            return type(left).__name__ < type(right).__name__

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _SortKey):
            return False
        return self.val == other.val
def sort_docs(
    docs: list[dict[str, Any]], sort_spec: Mapping[str, Any] | None
) -> list[dict[str, Any]]:
    """Return ``docs`` ordered by ``sort_spec`` (field -> ``1`` / ``-1``).

    With no sort spec the input list itself is returned untouched;
    otherwise a sorted copy is returned. Values are compared through
    ``_SortKey``.
    """
    if not sort_spec:
        return docs
    ordered = list(docs)
    # Stable sorts applied from the least significant field to the most
    # significant one give the combined multi-field ordering.
    for field in reversed(list(sort_spec)):
        descending = int(sort_spec[field]) == -1
        ordered.sort(
            key=lambda doc, path=field: _SortKey(get_path(doc, path)),
            reverse=descending,
        )
    return ordered
_ID_INDEX_NAME = "_id_"
class IndexConflict(Exception):
    """Raised when a write would violate a unique index.

    ``index_name`` names the index and ``doc_id`` is the ``_id`` of the
    document that triggered the conflict.
    """

    def __init__(self, index_name: str, doc_id: Any) -> None:
        self.index_name = index_name
        self.doc_id = doc_id
        message = f"E11000 duplicate key error in index {index_name}: _id={doc_id!r}"
        super().__init__(message)
class BadHint(Exception):
    """The ``hint`` passed to ``find_matching`` doesn't name an existing index.

    Also raised for a hint that is neither a string nor a mapping.
    """
[docs]
class Storage:
def __init__(self, path: str = ":memory:") -> None:
    """Open (or create) the WiredTiger database and make sure all tables exist.

    ``path`` is the WiredTiger home directory and is created if missing.
    The special value ``":memory:"`` opens an in-memory database whose home
    is a fresh temporary directory, removed again by ``close()``.

    If opening the connection or creating the tables fails, the connection
    (if any) is closed and the temporary directory is removed before the
    error propagates, so a failed constructor leaves nothing behind.
    """
    self._lock = threading.RLock()
    self._closed = False
    self._tempdir: str | None = None
    # Sessions are per-thread (see ``_session``); the list lets ``close()``
    # shut every one of them down.
    self._tls = threading.local()
    self._all_sessions: list[Any] = []
    if path == ":memory:":
        self._tempdir = tempfile.mkdtemp(prefix="secantus_wt_")
        home = self._tempdir
        config = "create,in_memory=true"
    else:
        os.makedirs(path, exist_ok=True)
        home = path
        config = "create"
    conn = None
    try:
        conn = wt.wiredtiger_open(home, config)
        boot = conn.open_session()
        try:
            boot.create(_COLL_TABLE, "key_format=SS,value_format=u")
            boot.create(_DOC_TABLE, "key_format=SSu,value_format=u")
            boot.create(_IDX_TABLE, "key_format=SSS,value_format=u")
            boot.create(_IDX_ENTRIES_TABLE, "key_format=SSSu,value_format=u")
        finally:
            boot.close()
    except BaseException:
        # Nobody will ever call close() on a half-built instance, so undo
        # what was acquired so far.
        if conn is not None:
            with contextlib.suppress(Exception):
                conn.close()
        if self._tempdir is not None:
            shutil.rmtree(self._tempdir, ignore_errors=True)
            self._tempdir = None
        raise
    self._conn = conn
def close(self) -> None:
with self._lock:
if self._closed:
return
self._closed = True
for s in self._all_sessions:
with contextlib.suppress(Exception):
s.close()
self._all_sessions.clear()
with contextlib.suppress(Exception):
self._conn.close()
if self._tempdir is not None:
shutil.rmtree(self._tempdir, ignore_errors=True)
self._tempdir = None
def _session(self) -> Any:
s = getattr(self._tls, "session", None)
if s is None:
s = self._conn.open_session()
self._tls.session = s
self._tls.cursors = {}
with self._lock:
self._all_sessions.append(s)
return s
def _cursor(self, table: str, *, overwrite: bool = True) -> Any:
self._session()
cursors: dict[tuple[str, bool], Any] = self._tls.cursors
key = (table, overwrite)
c = cursors.get(key)
if c is None:
cfg = None if overwrite else "overwrite=false"
c = self._tls.session.open_cursor(table, None, cfg)
cursors[key] = c
else:
c.reset()
return c
def _coll_options(self, db: str, coll: str) -> dict[str, Any] | None:
    """Return the stored options of ``db.coll``, or ``None`` if it doesn't exist.

    An existing collection whose stored value is empty yields ``{}``.
    """
    cur = self._cursor(_COLL_TABLE)
    cur.set_key(db, coll)
    if cur.search() != 0:
        return None
    raw = bytes(cur.get_value())
    if not raw:
        return {}
    return bson.decode(raw)
def _ensure_collection(self, db: str, coll: str) -> None:
    """Register ``db.coll`` with an empty value unless it is already present."""
    cur = self._cursor(_COLL_TABLE)
    cur.set_key(db, coll)
    if cur.search() != 0:
        cur.reset()
        cur[db, coll] = b""
def collection_exists(self, db: str, coll: str) -> bool:
    """Return True when ``db.coll`` has a row in the collections table."""
    with self._lock:
        options = self._coll_options(db, coll)
    return options is not None
def create_collection(self, db: str, coll: str) -> bool:
    """Create ``db.coll``; return False if it already existed, True otherwise."""
    with self._lock:
        cur = self._cursor(_COLL_TABLE)
        cur.set_key(db, coll)
        already_there = cur.search() == 0
        if not already_there:
            cur.reset()
            cur[db, coll] = b""
        return not already_there
def _scan_docs(self, db: str, coll: str) -> Iterable[tuple[bytes, bytes]]:
    """Yield ``(id_key, bson_blob)`` for each document of ``db.coll`` in key order.

    This is a generator over the thread's shared doc-table cursor: consume
    it fully before anything else asks for that cursor.
    """
    cur = self._cursor(_DOC_TABLE)
    # The empty id_key is the smallest possible key for this collection.
    cur.set_key(db, coll, b"")
    exact = cur.search_near()
    if exact == wt.WT_NOTFOUND:
        return
    # Landed on a smaller key: step forward to the first candidate row.
    if exact < 0 and cur.next() != 0:
        return
    more = True
    while more:
        key = cur.get_key()
        if key[0] != db or key[1] != coll:
            return
        yield bytes(key[2]), bytes(cur.get_value())
        more = cur.next() == 0
def _all_docs(self, db: str, coll: str) -> list[dict[str, Any]]:
    """Return every document of ``db.coll``, decoded, in key order."""
    with self._lock:
        decoded: list[dict[str, Any]] = []
        for _id_k, blob in self._scan_docs(db, coll):
            decoded.append(bson.decode(blob))
        return decoded
def _all_docs_with_id_key(self, db: str, coll: str) -> list[tuple[dict[str, Any], bytes]]:
    """Return ``(decoded_doc, id_key)`` pairs for every document of ``db.coll``."""
    with self._lock:
        pairs: list[tuple[dict[str, Any], bytes]] = []
        for id_k, blob in self._scan_docs(db, coll):
            pairs.append((bson.decode(blob), id_k))
        return pairs
def insert(
    self, db: str, coll: str, docs: Iterable[dict[str, Any]], *, ordered: bool = True
) -> tuple[int, list[dict[str, Any]]]:
    """Insert ``docs`` into ``db.coll`` and return ``(inserted_count, write_errors)``.

    A doc without ``_id`` gets a fresh ``ObjectId`` (the input dict is
    mutated). A doc that conflicts with a unique index, or whose ``_id``
    row cannot be inserted, produces an error entry
    ``{"index", "code": 11000, "errmsg"}`` rather than an exception. With
    ``ordered=True`` the first error stops the batch; otherwise the
    remaining docs are still attempted.
    """
    inserted = 0
    errors: list[dict[str, Any]] = []
    with self._lock:
        self._ensure_collection(db, coll)
        indexes = self._all_indexes(db, coll)
        multikey_names = self._multikey_index_names(db, coll)
        for position, doc in enumerate(docs):
            if "_id" not in doc:
                doc["_id"] = bson.ObjectId()
            key = _id_key(doc["_id"])
            errmsg: str | None = None
            conflict = self._unique_conflict(db, coll, doc, indexes, exclude_id_key=None)
            if conflict is not None:
                errmsg = f"E11000 duplicate key error in index {conflict}: _id={doc['_id']!r}"
            else:
                blob = bson.encode(doc)
                # overwrite=False makes the insert fail when the _id row
                # already exists instead of replacing it.
                doc_cur = self._cursor(_DOC_TABLE, overwrite=False)
                doc_cur.set_key(db, coll, key)
                doc_cur.set_value(blob)
                try:
                    doc_cur.insert()
                except wt.WiredTigerError:
                    errmsg = f"E11000 duplicate key error: _id {doc['_id']!r}"
            if errmsg is not None:
                errors.append({"index": position, "code": 11000, "errmsg": errmsg})
                if ordered:
                    break
                continue
            self._write_index_entries(db, coll, doc, indexes)
            multikey_names = self._maybe_mark_multikey(db, coll, doc, indexes, multikey_names)
            inserted += 1
    return inserted, errors
def find_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any] | None = None,
    *,
    skip: int = 0,
    limit: int = 0,
    sort: Mapping[str, Any] | None = None,
    projection: Mapping[str, Any] | None = None,
    hint: str | Mapping[str, Any] | None = None,
) -> list[dict[str, Any]]:
    """Return the documents of ``db.coll`` that match ``filter``.

    Candidates come from, in order of preference: the hinted index (or a
    natural scan) when ``hint`` is given; ``_try_index_lookup``; an
    in-order walk of an index led by the sort field when the filter is
    empty and ``sort`` is a single ``+/-1`` field; otherwise a full scan.
    Every candidate is re-checked with ``matches``, so index paths only
    narrow the candidate set.

    The result is then sorted (skipped when the candidates were already
    produced in sort order), sliced by ``skip`` / ``limit`` (a ``limit``
    of 0 or less means no limit) and finally projected.

    A ``hint`` that cannot be resolved propagates whatever
    ``_resolve_hint`` raises.
    """
    filter = filter or {}
    # True once the candidates are known to be in the requested order, so
    # the post-sort can be skipped.
    in_sort_order = False
    with self._lock:
        sort_field, sort_dir = self._single_sort_spec(sort)
        if hint is not None:
            resolved = self._resolve_hint(db, coll, hint)
            candidates, in_sort_order = self._candidates_from_hint(
                db, coll, resolved, sort_field, sort_dir
            )
        else:
            candidates = self._try_index_lookup(db, coll, filter)
            if candidates is not None and sort_field is not None:
                # Only a filter on exactly the sort field is treated as
                # already ordered by the index that served it.
                if (
                    len(filter) == 1
                    and not next(iter(filter)).startswith("$")
                    and next(iter(filter)) == sort_field
                ):
                    in_sort_order = True
                    idx = self._find_leading_field_index(db, coll, sort_field, filter)
                    idx_dir = idx[1] if idx else 1
                    if sort_dir != idx_dir:
                        candidates = list(reversed(candidates))
            elif candidates is None and not filter and sort_field is not None:
                idx = self._find_leading_field_index(db, coll, sort_field, filter)
                if idx is not None:
                    idx_name, idx_dir, _is_compound = idx
                    # If the index direction matches the sort direction,
                    # walk forward; if it's opposite, walk backward.
                    reverse = sort_dir != idx_dir
                    candidates = self._walk_index_in_order(db, coll, idx_name, reverse=reverse)
                    in_sort_order = True
        if candidates is None:
            # No usable index: fall back to scanning the whole collection.
            candidates = [bson.decode(b) for _, b in self._scan_docs(db, coll)]
        out = [d for d in candidates if matches(d, filter)]
    if sort and not in_sort_order:
        out = sort_docs(out, sort)
    if skip:
        out = out[skip:]
    if limit > 0:
        out = out[:limit]
    if projection:
        out = [apply_projection(d, projection) for d in out]
    return out
def _resolve_hint(self, db: str, coll: str, hint: str | Mapping[str, Any]) -> str:
"""Resolve ``hint`` to an index name (or ``$natural``).
``hint`` may be an index name string, a key-spec dict matching an
existing index, ``"$natural"``, or ``{"$natural": +/-1}``. Anything
else raises ``BadHint`` so the command layer can return a Mongo
``BadValue`` error.
"""
if isinstance(hint, str):
if hint == "$natural":
return "$natural"
if hint == _ID_INDEX_NAME:
return _ID_INDEX_NAME
for name, _key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name == hint:
return name
raise BadHint(f"hint {hint!r} does not correspond to an existing index")
if isinstance(hint, Mapping):
if list(hint) == ["$natural"]:
return "$natural"
if list(hint) == ["_id"] and int(hint["_id"]) == 1:
return _ID_INDEX_NAME
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if dict(key_spec) == dict(hint):
return name
raise BadHint(f"hint {dict(hint)!r} does not correspond to an existing index")
raise BadHint(f"invalid hint type: {type(hint).__name__}")
def _candidates_from_hint(
    self,
    db: str,
    coll: str,
    resolved: str,
    sort_field: str | None,
    sort_dir: int,
) -> tuple[list[dict[str, Any]], bool]:
    """Produce candidate docs for an already-resolved hint.

    Returns ``(candidates, in_sort_order)``. ``$natural`` scans the whole
    collection and never reports sorted output. ``_id_`` scans the doc
    table, which is keyed by the encoded ``_id``. Any other name walks
    that index's entries. ``in_sort_order`` is True when the sort field is
    the hinted index's leading field (or ``_id`` for ``_id_``); the
    candidates are then reversed if the sort direction opposes the index
    direction, and ``find_matching`` skips its post-sort.
    """

    def scan_all() -> list[dict[str, Any]]:
        return [bson.decode(b) for _, b in self._scan_docs(db, coll)]

    if resolved == "$natural":
        return scan_all(), False
    if resolved == _ID_INDEX_NAME:
        docs = scan_all()
        in_order = sort_field == "_id"
        if in_order and sort_dir == -1:
            docs.reverse()
        return docs, in_order
    # Look up the hinted index's leading field and its direction.
    leading: str | None = None
    leading_dir = 1
    for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
        if name != resolved:
            continue
        leading = next(iter(key_spec))
        leading_dir = int(key_spec[leading])
        break
    candidates = self._walk_index_in_order(db, coll, resolved, reverse=False)
    in_order = sort_field is not None and sort_field == leading
    if in_order and sort_dir != leading_dir:
        candidates = list(reversed(candidates))
    return candidates, in_order
@staticmethod
def _single_sort_spec(sort: Mapping[str, Any] | None) -> tuple[str | None, int]:
"""Return ``(field, direction)`` if ``sort`` is single-field +/-1, else ``(None, 0)``."""
if not sort or len(sort) != 1:
return None, 0
f, d = next(iter(sort.items()))
if f.startswith("$"):
return None, 0
try:
di = int(d)
except (TypeError, ValueError):
return None, 0
if di not in (-1, 1):
return None, 0
return f, di
def _single_field_index_for(self, db: str, coll: str, field: str) -> tuple[str, int] | None:
"""Return ``(index_name, direction)`` for a single-field index on
``field``, or ``None`` if no such index exists. Direction is the
index's stored sort direction (`+1` for ASC, `-1` for DESC)."""
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if list(key_spec.keys()) == [field]:
d = int(key_spec[field])
if d in (1, -1):
return name, d
return None
def _walk_index_in_order(
    self, db: str, coll: str, name: str, *, reverse: bool = False
) -> list[dict[str, Any]]:
    """Fetch the docs referenced by index ``name``, in index-entry order.

    The entries table is walked forward from the smallest possible entry of
    the index; with ``reverse=True`` the collected id keys are flipped
    before the docs are fetched through ``_docs_by_id_keys``.
    """
    cur = self._cursor(_IDX_ENTRIES_TABLE)
    cur.set_key(db, coll, name, b"")
    exact = cur.search_near()
    if exact == wt.WT_NOTFOUND:
        return []
    # Landed before the index's first entry: step onto it.
    if exact < 0 and cur.next() != 0:
        return []
    owner = (db, coll, name)
    id_keys: list[bytes] = []
    while True:
        key = cur.get_key()
        if (key[0], key[1], key[2]) != owner:
            break
        _escaped, row_id = _unpack_entry(bytes(key[3]))
        id_keys.append(row_id)
        if cur.next() != 0:
            break
    if reverse:
        id_keys.reverse()
    return self._docs_by_id_keys(db, coll, id_keys)
[docs]
def explain_plan(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any] | None = None,
    *,
    sort: Mapping[str, Any] | None = None,
    hint: str | Mapping[str, Any] | None = None,
) -> dict[str, Any]:
    """Summarise the plan ``find_matching`` would choose, without running it.

    Returns ``{"kind": "COLLSCAN"}`` or ``{"kind": "IXSCAN", "index_name",
    "key_pattern", "direction"}``. ``direction`` is ``"forward"`` unless a
    single-field sort runs against the chosen index's direction. A hint
    that cannot be resolved is reported as ``COLLSCAN`` rather than raised.
    """
    filter = filter or {}
    collscan: dict[str, Any] = {"kind": "COLLSCAN"}
    with self._lock:
        sort_field, sort_dir = self._single_sort_spec(sort)
        if hint is not None:
            try:
                resolved = self._resolve_hint(db, coll, hint)
            except BadHint:
                return collscan
            if resolved == "$natural":
                return collscan
            if resolved == _ID_INDEX_NAME:
                backward = sort_field == "_id" and sort_dir == -1
                return {
                    "kind": "IXSCAN",
                    "index_name": _ID_INDEX_NAME,
                    "key_pattern": {"_id": 1},
                    "direction": "backward" if backward else "forward",
                }
            hinted_spec = self._key_spec_for(db, coll, resolved)
            if hinted_spec is None:
                return collscan
            return self._make_ixscan_plan(resolved, hinted_spec, sort_field, sort_dir)
        picked = self._pick_index_for_filter(db, coll, filter)
        if picked is not None:
            name, key_spec = picked
            return self._make_ixscan_plan(name, key_spec, sort_field, sort_dir)
        # An index can still serve a pure sort, but only with an empty filter.
        if filter or sort_field is None:
            return collscan
        idx = self._find_leading_field_index(db, coll, sort_field, filter)
        if idx is None:
            return collscan
        name, _idx_dir, _is_compound = idx
        sort_spec = self._key_spec_for(db, coll, name)
        if sort_spec is None:
            return collscan
        return self._make_ixscan_plan(name, sort_spec, sort_field, sort_dir)
def _key_spec_for(self, db: str, coll: str, name: str) -> dict[str, Any] | None:
for n, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if n == name:
return dict(key_spec)
return None
def _pick_index_for_filter(
    self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
    """Choose the index ``explain_plan`` reports for ``filter``, without executing.

    Intended to mirror ``_try_index_lookup``'s index selection. Returns
    ``(index_name, key_spec)`` or ``None`` when no index applies. Tried in
    order: a compound-equality index (only when no filter value is a
    dict), a compound-range index (two or more filter fields), then an
    index led by the filter's only field; an operator document on that
    field must be non-empty and contain only operators from
    ``self._RANGE_OPS``.
    """
    if not filter or any(key.startswith("$") for key in filter):
        return None
    if not any(isinstance(value, dict) for value in filter.values()):
        by_equality = self._pick_compound_eq_index(db, coll, filter)
        if by_equality is not None:
            return by_equality
    if len(filter) >= 2:
        by_range = self._pick_compound_range_index(db, coll, filter)
        if by_range is not None:
            return by_range
    if len(filter) != 1:
        return None
    ((field, value),) = filter.items()
    idx_match = self._find_leading_field_index(db, coll, field, filter)
    if idx_match is None:
        return None
    if isinstance(value, dict):
        if not value or not all(op.startswith("$") for op in value):
            return None
        if any(op not in self._RANGE_OPS for op in value):
            return None
    name, _direction, _is_compound = idx_match
    key_spec = self._key_spec_for(db, coll, name)
    return None if key_spec is None else (name, key_spec)
@staticmethod
def _make_ixscan_plan(
name: str,
key_spec: Mapping[str, Any],
sort_field: str | None,
sort_dir: int,
) -> dict[str, Any]:
direction = "forward"
if sort_field is not None and sort_field in key_spec:
idx_dir = int(key_spec[sort_field])
if sort_dir != 0 and sort_dir != idx_dir:
direction = "backward"
return {
"kind": "IXSCAN",
"index_name": name,
"key_pattern": dict(key_spec),
"direction": direction,
}
def count_matching(self, db: str, coll: str, filter: dict[str, Any] | None = None) -> int:
    """Count the documents of ``db.coll`` matching ``filter``.

    Without a filter the rows are counted straight off the doc table,
    with no decoding.
    """
    if filter:
        return sum(1 for doc in self._all_docs(db, coll) if matches(doc, filter))
    with self._lock:
        total = 0
        for _row in self._scan_docs(db, coll):
            total += 1
        return total
[docs]
def collection_data_size(self, db: str, coll: str) -> int:
    """Return the total size in bytes of the stored BSON blobs of ``db.coll``.

    Used by ``collStats`` / ``dbStats`` for ``size`` / ``dataSize``. A
    best-effort estimate that doesn't include WT block overhead.
    """
    with self._lock:
        total = 0
        for _id_k, blob in self._scan_docs(db, coll):
            total += len(blob)
        return total
[docs]
def index_sizes(self, db: str, coll: str) -> dict[str, int]:
    """Map each index name of ``db.coll`` to the summed byte length of its keys.

    Secondary indexes are measured by their packed entry keys. ``_id_`` is
    reported as the summed ``id_key`` lengths of the doc table and is
    omitted when that sum is zero, so callers can add everything up for
    ``totalIndexSize``.
    """
    with self._lock:
        sizes: dict[str, int] = {}
        id_total = 0
        for id_k, _blob in self._scan_docs(db, coll):
            id_total += len(id_k)
        if id_total:
            sizes[_ID_INDEX_NAME] = id_total
        for key, _value in self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll)):
            index_name = key[2]
            sizes[index_name] = sizes.get(index_name, 0) + len(bytes(key[3]))
        return sizes
def update_matching(
    self,
    db: str,
    coll: str,
    filter: dict[str, Any],
    update: dict[str, Any],
    *,
    multi: bool = False,
    upsert: bool = False,
    array_filters: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
    """Apply ``update`` to the documents of ``db.coll`` matching ``filter``.

    Only the first match is updated unless ``multi`` is set. A match whose
    updated form equals the original counts as matched but not modified.
    When nothing matches and ``upsert`` is set, a new document is built
    from the filter's plain equality fields plus ``update`` and inserted.

    Returns ``{"matched", "modified", "upserted_id"}``; ``upserted_id`` is
    ``None`` unless an upsert happened.

    Raises ``IndexConflict`` when a write would violate a unique index.
    This now includes an upsert whose ``_id`` already exists: it is
    rejected as a conflict on ``_id_`` instead of silently replacing the
    existing document and leaving its old index entries behind.
    """
    matched = 0
    modified = 0
    upserted_id: Any = None
    with self._lock:
        self._ensure_collection(db, coll)
        indexes = self._all_indexes(db, coll)
        multikey_names = self._multikey_index_names(db, coll)
        for doc in self._all_docs(db, coll):
            if not matches(doc, filter):
                continue
            matched += 1
            pos = find_positional_matches(doc, filter)
            new = apply_update(doc, update, array_filters=array_filters, positional_matches=pos)
            if new != doc:
                new_id_key = _id_key(new["_id"])
                # The doc being rewritten must not count as its own conflict.
                conflict = self._unique_conflict(
                    db, coll, new, indexes, exclude_id_key=_id_key(doc["_id"])
                )
                if conflict is not None:
                    raise IndexConflict(conflict, new["_id"])
                modified += 1
                self._delete_index_entries(db, coll, doc, indexes)
                doc_cur = self._cursor(_DOC_TABLE)
                doc_cur[db, coll, new_id_key] = bson.encode(new)
                self._write_index_entries(db, coll, new, indexes)
                multikey_names = self._maybe_mark_multikey(
                    db, coll, new, indexes, multikey_names
                )
            if not multi:
                break
        if matched == 0 and upsert:
            # Seed the new doc from plain equality fields only; operator
            # documents and top-level $-operators contribute nothing.
            seed: dict[str, Any] = {
                k: v
                for k, v in filter.items()
                if not k.startswith("$") and not isinstance(v, dict)
            }
            new = apply_update(seed, update, is_upsert=True, array_filters=array_filters)
            if "_id" not in new:
                new["_id"] = bson.ObjectId()
            upserted_id = new["_id"]
            conflict = self._unique_conflict(db, coll, new, indexes, exclude_id_key=None)
            if conflict is not None:
                raise IndexConflict(conflict, new["_id"])
            # Same approach as insert(): overwrite=False makes the write
            # fail when a row with this _id already exists.
            doc_cur = self._cursor(_DOC_TABLE, overwrite=False)
            doc_cur.set_key(db, coll, _id_key(upserted_id))
            doc_cur.set_value(bson.encode(new))
            try:
                doc_cur.insert()
            except wt.WiredTigerError as exc:
                raise IndexConflict(_ID_INDEX_NAME, upserted_id) from exc
            self._write_index_entries(db, coll, new, indexes)
            self._maybe_mark_multikey(db, coll, new, indexes, multikey_names)
    return {"matched": matched, "modified": modified, "upserted_id": upserted_id}
def delete_matching(self, db: str, coll: str, filter: dict[str, Any], *, limit: int = 0) -> int:
    """Delete the documents of ``db.coll`` matching ``filter``; return the count.

    A positive ``limit`` stops after that many deletions; ``0`` (or less)
    means no limit. Each doc's index entries are removed before its row.
    """
    removed = 0
    with self._lock:
        indexes = self._all_indexes(db, coll)
        victims = (doc for doc in self._all_docs(db, coll) if matches(doc, filter))
        for doc in victims:
            self._delete_index_entries(db, coll, doc, indexes)
            doc_cur = self._cursor(_DOC_TABLE)
            doc_cur.set_key(db, coll, _id_key(doc["_id"]))
            doc_cur.remove()
            removed += 1
            if 0 < limit <= removed:
                break
    return removed
[docs]
def prune_ttl(
    self,
    db: str,
    coll: str,
    *,
    now: _dt.datetime | None = None,
) -> int:
    """Delete docs whose indexed Date field is older than now - TTL.

    For every index on ``coll`` with a non-negative numeric
    ``expireAfterSeconds`` option, walks the collection and deletes docs
    whose value under the index's first field is a ``datetime`` older than
    ``now - expireAfterSeconds``. Docs without the field, with non-date
    values, or with values inside the TTL window are left in place. Naive
    datetimes (both ``now`` and stored values) are treated as UTC.

    Real MongoDB runs this on a 60s background sweeper; SecantusDB invokes
    it explicitly so tests can drive expiry with an injected ``now``.
    Returns the number of docs pruned.

    The whole operation, including reading the index definitions, runs
    under the storage lock like every other public method.
    """
    pruned = 0
    with self._lock:
        ttl_indexes: list[tuple[str, str, float]] = []
        for name, key_spec, opts in self._iter_indexes(db, coll):
            ttl = opts.get("expireAfterSeconds")
            if not isinstance(ttl, (int, float)) or ttl < 0:
                continue
            field = next(iter(key_spec), None)
            if not isinstance(field, str):
                continue
            ttl_indexes.append((name, field, float(ttl)))
        if not ttl_indexes:
            return 0
        when = now if now is not None else _dt.datetime.now(_dt.UTC)
        if when.tzinfo is None:
            when = when.replace(tzinfo=_dt.UTC)
        indexes = self._all_indexes(db, coll)
        # _all_docs returns a materialised list, so deleting while looping is safe.
        for doc in self._all_docs(db, coll):
            expired = False
            for _name, field, ttl_seconds in ttl_indexes:
                value = get_path(doc, field)
                if not isinstance(value, _dt.datetime):
                    continue
                value_aware = value if value.tzinfo else value.replace(tzinfo=_dt.UTC)
                if (when - value_aware).total_seconds() > ttl_seconds:
                    expired = True
                    break
            if not expired:
                continue
            self._delete_index_entries(db, coll, doc, indexes)
            doc_cur = self._cursor(_DOC_TABLE)
            doc_cur.set_key(db, coll, _id_key(doc["_id"]))
            doc_cur.remove()
            pruned += 1
    return pruned
@staticmethod
def _table_kf(table: str) -> str:
    """Return the WiredTiger ``key_format`` of ``table``.

    Raises ``KeyError`` for a table this store does not manage.
    """
    key_formats = {
        _COLL_TABLE: "SS",
        _DOC_TABLE: "SSu",
        _IDX_TABLE: "SSS",
        _IDX_ENTRIES_TABLE: "SSSu",
    }
    return key_formats[table]
@staticmethod
def _smallest_for_kf(kf: str) -> tuple[Any, ...]:
return tuple(b"" if c == "u" else "" for c in kf)
def _collect_prefix(
    self, table: str, prefix: tuple[Any, ...]
) -> list[tuple[tuple[Any, ...], Any]]:
    """Return every ``(key, value)`` row of ``table`` whose key starts with ``prefix``.

    ``prefix`` is a tuple of leading key columns compared for exact
    equality. Rows come back in key order; bytes-like values are
    normalised to ``bytes``.
    """
    cur = self._cursor(table)
    # Pad the prefix with the smallest value for each remaining column so
    # search_near lands at (or just before) the first matching row.
    padding = self._smallest_for_kf(self._table_kf(table))[len(prefix) :]
    cur.set_key(*(prefix + padding))
    exact = cur.search_near()
    if exact == wt.WT_NOTFOUND:
        return []
    if exact < 0 and cur.next() != 0:
        return []
    width = len(prefix)
    collected: list[tuple[tuple[Any, ...], Any]] = []
    while True:
        key = tuple(cur.get_key())
        if key[:width] != prefix:
            break
        value = cur.get_value()
        if isinstance(value, (bytes, bytearray)):
            value = bytes(value)
        collected.append((key, value))
        if cur.next() != 0:
            break
    return collected
def _delete_keys(self, table: str, keys: list[tuple[Any, ...]]) -> None:
    """Remove each of ``keys`` from ``table``; an empty list opens no cursor."""
    if keys:
        cur = self._cursor(table)
        for key in keys:
            cur.set_key(*key)
            cur.remove()
            cur.reset()
def drop_collection(self, db: str, coll: str) -> bool:
    """Remove ``db.coll`` with its documents, index definitions and index entries.

    Returns True when the collection existed beforehand.
    """
    with self._lock:
        existed = self._coll_options(db, coll) is not None
        namespace = (db, coll)
        for table in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE):
            doomed = [key for key, _value in self._collect_prefix(table, namespace)]
            self._delete_keys(table, doomed)
        cur = self._cursor(_COLL_TABLE)
        cur.set_key(db, coll)
        if cur.search() == 0:
            cur.remove()
        return existed
def drop_database(self, db: str) -> None:
    """Remove every row belonging to ``db`` from all four tables."""
    owner = (db,)
    with self._lock:
        for table in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE, _COLL_TABLE):
            doomed = [key for key, _value in self._collect_prefix(table, owner)]
            self._delete_keys(table, doomed)
def rename_collection(
    self,
    src_db: str,
    src_coll: str,
    dst_db: str,
    dst_coll: str,
    *,
    drop_target: bool = False,
) -> tuple[bool, str | None]:
    """Move a collection (docs, index definitions, index entries) to a new namespace.

    Returns ``(True, None)`` on success, or ``(False, message)`` when the
    source does not exist or the target exists and ``drop_target`` is
    False. Renaming a collection onto itself succeeds without doing
    anything. With ``drop_target=True`` an existing target is emptied and
    removed first.
    """
    with self._lock:
        if self._coll_options(src_db, src_coll) is None:
            return False, f"source namespace does not exist: {src_db}.{src_coll}"
        if (src_db, src_coll) == (dst_db, dst_coll):
            return True, None
        if self._coll_options(dst_db, dst_coll) is not None:
            if not drop_target:
                return False, f"target namespace exists: {dst_db}.{dst_coll}"
            # Drop the existing target: its rows first, then its collection row.
            for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE):
                rows = self._collect_prefix(tbl, (dst_db, dst_coll))
                self._delete_keys(tbl, [k for k, _ in rows])
            c = self._cursor(_COLL_TABLE)
            c.set_key(dst_db, dst_coll)
            if c.search() == 0:
                c.remove()
        for tbl in (_DOC_TABLE, _IDX_TABLE, _IDX_ENTRIES_TABLE):
            # Rows are collected in memory, deleted under the old
            # namespace, then re-inserted with (dst_db, dst_coll) replacing
            # the first two key columns.
            rows = self._collect_prefix(tbl, (src_db, src_coll))
            self._delete_keys(tbl, [k for k, _ in rows])
            c = self._cursor(tbl)
            for k, v in rows:
                new_k = (dst_db, dst_coll) + k[2:]
                c.set_key(*new_k)
                c.set_value(v)
                c.insert()
                c.reset()
        # NOTE(review): the destination row is created with an empty value;
        # the source row's stored value is not copied. Confirm nothing
        # stores collection options that should survive a rename.
        ensure = self._cursor(_COLL_TABLE)
        ensure.set_key(dst_db, dst_coll)
        if ensure.search() != 0:
            ensure.reset()
            ensure[dst_db, dst_coll] = b""
        ensure.reset()
        ensure.set_key(src_db, src_coll)
        if ensure.search() == 0:
            ensure.remove()
    return True, None
def list_collections(self, db: str) -> list[str]:
    """Return the sorted names of all collections registered under ``db``."""
    with self._lock:
        cur = self._cursor(_COLL_TABLE)
        cur.set_key(db, "")
        exact = cur.search_near()
        # Nothing at or after the seed key means there is nothing to list.
        if exact == wt.WT_NOTFOUND or (exact < 0 and cur.next() != 0):
            return []
        names: list[str] = []
        while True:
            key = cur.get_key()
            if key[0] != db:
                break
            names.append(key[1])
            if cur.next() != 0:
                break
        return sorted(names)
def list_databases(self) -> list[str]:
    """Return the sorted names of all databases with at least one collection row."""
    with self._lock:
        cur = self._cursor(_COLL_TABLE)
        names: set[str] = set()
        # The cursor starts unpositioned, so next() walks the whole table.
        while cur.next() == 0:
            names.add(cur.get_key()[0])
        return sorted(names)
def create_index(
    self,
    db: str,
    coll: str,
    name: str,
    key_spec: Mapping[str, Any],
    options: Mapping[str, Any] | None = None,
) -> bool:
    """Create index ``name`` on ``db.coll`` and populate it from existing docs.

    Returns False without changing anything when ``name`` is the reserved
    ``_id_`` index or an index of that name already exists; True once the
    index is created. The collection is created if missing.

    Recognised options: ``sparse`` (docs missing an indexed path get no
    entry), ``unique`` and ``partialFilterExpression`` (only matching docs
    are indexed). If any indexed doc has a list under an indexed field,
    the stored options gain ``multikey: True``.

    Raises ``IndexConflict`` when ``unique`` is set and two existing docs
    share an index key.

    The collection is read once. All keys are computed and checked before
    the index definition is written, so a failure leaves no half-built
    index behind.
    """
    if name == _ID_INDEX_NAME:
        return False
    options = dict(options or {})
    with self._lock:
        self._ensure_collection(db, coll)
        c = self._cursor(_IDX_TABLE)
        c.set_key(db, coll, name)
        if c.search() == 0:
            return False
        sparse = bool(options.get("sparse"))
        unique = bool(options.get("unique"))
        partial_filter = options.get("partialFilterExpression")
        if not isinstance(partial_filter, Mapping) or not partial_filter:
            partial_filter = None
        eligible = [
            d
            for d in self._all_docs(db, coll)
            if partial_filter is None or matches(d, partial_filter)
        ]
        # One pass: compute each doc's key and enforce uniqueness.
        keyed: list[tuple[dict[str, Any], bytes]] = []
        seen: set[bytes] = set()
        for d in eligible:
            kb = _index_key(d, key_spec, sparse=sparse)
            if kb is None:
                continue
            if unique:
                if kb in seen:
                    raise IndexConflict(name, d.get("_id"))
                seen.add(kb)
            keyed.append((d, kb))
        if any(_doc_makes_multikey(d, key_spec) for d in eligible):
            options["multikey"] = True
        payload = bson.encode({"key": dict(key_spec), "options": options})
        c.reset()
        c[db, coll, name] = payload
        entry_cur = self._cursor(_IDX_ENTRIES_TABLE)
        for d, kb in keyed:
            entry_cur.reset()
            entry_cur[db, coll, name, _pack_entry(kb, _id_key(d["_id"]))] = b""
    return True
def list_indexes(self, db: str, coll: str) -> list[dict[str, Any]]:
    """Describe the indexes of ``db.coll``, sorted by name.

    Returns ``[]`` for a collection that does not exist. Otherwise the
    list always contains the implicit ``_id_`` index; each secondary index
    is reported as ``{"v": 2, "key", "name"}`` merged with its stored
    options.
    """
    with self._lock:
        if self._coll_options(db, coll) is None:
            return []
        listing: list[dict[str, Any]] = [{"v": 2, "key": {"_id": 1}, "name": _ID_INDEX_NAME}]
        listing.extend(
            {"v": 2, "key": key_spec, "name": name, **opts}
            for name, key_spec, opts in self._iter_indexes(db, coll)
        )
        return sorted(listing, key=lambda entry: entry["name"])
def _iter_indexes(
    self, db: str, coll: str
) -> Iterable[tuple[str, dict[str, Any], dict[str, Any]]]:
    """Yield ``(name, key_spec, options)`` for each stored index of ``db.coll``.

    This is a generator over the thread's shared index-table cursor:
    consume it fully before anything else asks for that cursor. A missing
    ``key`` or ``options`` field in a stored definition yields ``{}``.
    """
    cur = self._cursor(_IDX_TABLE)
    cur.set_key(db, coll, "")
    exact = cur.search_near()
    if exact == wt.WT_NOTFOUND:
        return
    if exact < 0 and cur.next() != 0:
        return
    more = True
    while more:
        key = cur.get_key()
        if key[0] != db or key[1] != coll:
            return
        definition = bson.decode(bytes(cur.get_value()))
        yield key[2], definition.get("key", {}), definition.get("options", {})
        more = cur.next() == 0
def drop_index(self, db: str, coll: str, name: str) -> bool:
    """Drop index ``name`` and all of its entries.

    Returns False for the reserved ``_id_`` index or an unknown name,
    True once the index has been removed.
    """
    if name == _ID_INDEX_NAME:
        return False
    with self._lock:
        cur = self._cursor(_IDX_TABLE)
        cur.set_key(db, coll, name)
        found = cur.search() == 0
        if found:
            cur.remove()
            stale = [
                key
                for key, _value in self._collect_prefix(_IDX_ENTRIES_TABLE, (db, coll, name))
            ]
            self._delete_keys(_IDX_ENTRIES_TABLE, stale)
        return found
def drop_all_indexes(self, db: str, coll: str) -> int:
    """Drop every secondary index of ``db.coll``; return how many definitions were removed."""
    namespace = (db, coll)
    with self._lock:
        definitions = self._collect_prefix(_IDX_TABLE, namespace)
        self._delete_keys(_IDX_TABLE, [key for key, _value in definitions])
        entries = self._collect_prefix(_IDX_ENTRIES_TABLE, namespace)
        self._delete_keys(_IDX_ENTRIES_TABLE, [key for key, _value in entries])
        return len(definitions)
def _all_indexes(self, db: str, coll: str) -> list[tuple[str, dict[str, Any], bool, bool]]:
    """List every non-``_id_`` index as ``(name, key_spec, sparse, unique)``."""
    return [
        (name, key_spec, bool(opts.get("sparse")), bool(opts.get("unique")))
        for name, key_spec, opts in self._iter_indexes(db, coll)
    ]
def _partial_filters(self, db: str, coll: str) -> dict[str, dict[str, Any]]:
    """Map index name to its ``partialFilterExpression``.

    Only indexes with a non-empty mapping for that option are included.
    """
    filters: dict[str, dict[str, Any]] = {}
    for name, _key_spec, opts in self._iter_indexes(db, coll):
        expression = opts.get("partialFilterExpression")
        if not isinstance(expression, Mapping) or not expression:
            continue
        filters[name] = dict(expression)
    return filters
@staticmethod
def _query_implies_partial(query: Mapping[str, Any], partial: Mapping[str, Any]) -> bool:
"""True if ``query`` is at least as restrictive as ``partial`` — every
key/value pair in ``partial`` appears with the same bare value in
``query``. Conservative: anything more sophisticated (operator-form
clauses, $and, etc.) is treated as not implying the partial filter.
"""
for key, value in partial.items():
if key not in query:
return False
if query[key] != value:
return False
return True
def _multikey_index_names(self, db: str, coll: str) -> set[str]:
"""Names of indexes flagged ``multikey`` (must fall back to scan).
Without true multi-key indexing, an index where any doc has a
list-valued field can't serve scalar-element matches — so the
pickers skip these names and ``find_matching`` falls back to a
full scan.
"""
return {
name for name, _key_spec, opts in self._iter_indexes(db, coll) if opts.get("multikey")
}
def _maybe_mark_multikey(
    self,
    db: str,
    coll: str,
    doc: Mapping[str, Any],
    indexes: list[tuple[str, dict[str, Any], bool, bool]],
    already_multikey: set[str],
) -> set[str]:
    """Persist the ``multikey`` option on indexes that ``doc`` makes multikey.

    Indexes already named in ``already_multikey`` are skipped, as are those
    for which ``_doc_makes_multikey(doc, key_spec)`` is false or whose row
    cannot be found. ``already_multikey`` is updated in place and returned,
    so the caller can pass it back in and avoid re-checking.
    """
    cursor = self._cursor(_IDX_TABLE)
    for index_name, spec, _sparse, _unique in indexes:
        if index_name in already_multikey or not _doc_makes_multikey(doc, spec):
            continue
        cursor.reset()
        cursor.set_key(db, coll, index_name)
        if cursor.search() != 0:
            continue
        stored = bson.decode(bytes(cursor.get_value()))
        options = dict(stored.get("options") or {})
        if not options.get("multikey"):
            # First array value seen for this index: rewrite its row.
            options["multikey"] = True
            stored["options"] = options
            cursor.reset()
            cursor[db, coll, index_name] = bson.encode(stored)
        already_multikey.add(index_name)
    return already_multikey
def _write_index_entries(
self,
db: str,
coll: str,
doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
) -> None:
if not indexes:
return
c = self._cursor(_IDX_ENTRIES_TABLE)
id_k = _id_key(doc["_id"])
partials = self._partial_filters(db, coll)
for name, key_spec, sparse, _unique in indexes:
pf = partials.get(name)
if pf is not None and not matches(doc, pf):
continue
kb = _index_key(doc, key_spec, sparse=sparse)
if kb is None:
continue
c.reset()
c[db, coll, name, _pack_entry(kb, id_k)] = b""
def _delete_index_entries(
self,
db: str,
coll: str,
doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
) -> None:
if not indexes:
return
c = self._cursor(_IDX_ENTRIES_TABLE)
id_k = _id_key(doc["_id"])
partials = self._partial_filters(db, coll)
for name, key_spec, sparse, _unique in indexes:
pf = partials.get(name)
if pf is not None and not matches(doc, pf):
continue
kb = _index_key(doc, key_spec, sparse=sparse)
if kb is None:
continue
c.reset()
c.set_key(db, coll, name, _pack_entry(kb, id_k))
if c.search() == 0:
c.remove()
def _unique_conflict(
self,
db: str,
coll: str,
candidate_doc: dict[str, Any],
indexes: list[tuple[str, dict[str, Any], bool, bool]],
*,
exclude_id_key: bytes | None,
) -> str | None:
if not indexes:
return None
c = self._cursor(_IDX_ENTRIES_TABLE)
partials = self._partial_filters(db, coll)
for name, key_spec, sparse, unique in indexes:
if not unique:
continue
pf = partials.get(name)
if pf is not None and not matches(candidate_doc, pf):
continue
kb = _index_key(candidate_doc, key_spec, sparse=sparse)
if kb is None:
continue
esc_kb = _escape_kb(kb)
seed = esc_kb + _ENTRY_SEP
c.reset()
c.set_key(db, coll, name, seed)
rc = c.search_near()
if rc == wt.WT_NOTFOUND:
continue
if rc < 0 and c.next() != 0:
continue
while True:
k = c.get_key()
if (k[0], k[1], k[2]) != (db, coll, name):
break
packed = bytes(k[3])
row_esc, row_id = _unpack_entry(packed)
if row_esc != esc_kb:
break
if exclude_id_key is None or row_id != exclude_id_key:
return name
if c.next() != 0:
break
return None
def _scan_index_for_id_keys(
    self, db: str, coll: str, name: str, kb: bytes, *, prefix: bool = False
) -> list[bytes]:
    """Collect the id keys of index ``name`` entries that match ``kb``.

    ``prefix=False`` (the default) is an equality lookup: only entries
    whose escaped sort key equals ``escape(kb)`` are returned.
    ``prefix=True`` is a compound-prefix lookup: every entry whose escaped
    sort key starts with ``escape(kb)`` is returned. Id keys come back in
    cursor order; the walk stops at the first non-matching entry.
    """
    cursor = self._cursor(_IDX_ENTRIES_TABLE)
    wanted = _escape_kb(kb)
    cursor.set_key(db, coll, name, wanted if prefix else wanted + _ENTRY_SEP)
    exact = cursor.search_near()
    if exact == wt.WT_NOTFOUND:
        return []
    # A negative result is followed by one step forward before reading.
    if exact < 0 and cursor.next() != 0:
        return []
    found: list[bytes] = []
    positioned = True
    while positioned:
        key = cursor.get_key()
        if (key[0], key[1], key[2]) != (db, coll, name):
            break
        stored_kb, stored_id = _unpack_entry(bytes(key[3]))
        hit = stored_kb.startswith(wanted) if prefix else stored_kb == wanted
        if not hit:
            break
        found.append(stored_id)
        positioned = cursor.next() == 0
    return found
def _docs_by_id_keys(self, db: str, coll: str, id_keys: list[bytes]) -> list[dict[str, Any]]:
if not id_keys:
return []
c = self._cursor(_DOC_TABLE)
out: list[dict[str, Any]] = []
for id_k in id_keys:
c.reset()
c.set_key(db, coll, id_k)
if c.search() == 0:
out.append(bson.decode(bytes(c.get_value())))
return out
# Operators an operator-form clause may contain and still be considered for
# an index lookup; the lookup helpers return ``None`` for clauses that use
# any operator outside this tuple.
_RANGE_OPS: tuple[str, ...] = ("$eq", "$gt", "$gte", "$lt", "$lte", "$in")
def _try_index_lookup(
self, db: str, coll: str, filter: dict[str, Any]
) -> list[dict[str, Any]] | None:
if not filter:
return None
if any(f.startswith("$") for f in filter):
return None
# Bare-equality filters of any size can use a compound index whose
# leading fields cover the filter set.
if all(not isinstance(v, dict) for v in filter.values()):
result = self._try_compound_eq_lookup(db, coll, filter)
if result is not None:
return result
# Compound prefix + trailing operator field (eq fields then range/in).
if len(filter) >= 2:
result = self._try_compound_range_lookup(db, coll, filter)
if result is not None:
return result
if len(filter) != 1:
return None
field, value = next(iter(filter.items()))
idx_match = self._find_leading_field_index(db, coll, field, filter)
if idx_match is None:
return None
return self._lookup_via_leading_field(db, coll, idx_match, value)
def _find_leading_field_index(
self,
db: str,
coll: str,
field: str,
query: Mapping[str, Any] | None = None,
) -> tuple[str, int, bool] | None:
"""Best index whose leading field is ``field``.
Returns ``(name, direction, is_compound)``. Single-field indexes
win over compound (tighter scan, no separator math). All fields
must be ASC or DESC. Partial indexes are skipped unless ``query``
implies their ``partialFilterExpression``.
"""
multikey = self._multikey_index_names(db, coll)
partials = self._partial_filters(db, coll)
query = query or {}
compound_fallback: tuple[str, int, bool] | None = None
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name in multikey:
continue
pf = partials.get(name)
if pf is not None and not self._query_implies_partial(query, pf):
continue
idx_fields = list(key_spec)
if not idx_fields or idx_fields[0] != field:
continue
if any(int(key_spec[f]) not in (1, -1) for f in idx_fields):
continue
d = int(key_spec[field])
if len(idx_fields) == 1:
return name, d, False
if compound_fallback is None:
compound_fallback = (name, d, True)
return compound_fallback
def _lookup_via_leading_field(
self,
db: str,
coll: str,
idx_match: tuple[str, int, bool],
value: Any,
) -> list[dict[str, Any]] | None:
name, direction, is_compound = idx_match
if not isinstance(value, dict):
return self._docs_by_id_keys(
db,
coll,
self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, value),
)
if not value or not all(k.startswith("$") for k in value):
return None
if not all(op in self._RANGE_OPS for op in value):
return None
if "$in" in value:
if len(value) != 1 or not isinstance(value["$in"], list):
return None
seen: set[bytes] = set()
id_keys: list[bytes] = []
for v in value["$in"]:
if isinstance(v, dict):
return None
for id_k in self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, v):
if id_k not in seen:
seen.add(id_k)
id_keys.append(id_k)
return self._docs_by_id_keys(db, coll, id_keys)
lower: bytes | None = None
lower_inclusive = True
upper: bytes | None = None
upper_inclusive = True
for op, bound in value.items():
if isinstance(bound, dict):
return None
if op == "$eq":
return self._docs_by_id_keys(
db,
coll,
self._eq_id_keys_via_leading(db, coll, name, direction, is_compound, bound),
)
kb = encode_value_directed(bound, direction)
# Operator semantics flip when stored bytes are inverted: in a
# DESC index, "x > 5" means we want stored bytes < enc_desc(5).
effective_op = op
if direction == -1:
effective_op = {"$gt": "$lt", "$gte": "$lte", "$lt": "$gt", "$lte": "$gte"}[op]
if effective_op == "$gt":
lower, lower_inclusive = kb, False
elif effective_op == "$gte":
lower, lower_inclusive = kb, True
elif effective_op == "$lt":
upper, upper_inclusive = kb, False
elif effective_op == "$lte":
upper, upper_inclusive = kb, True
if is_compound:
id_keys = self._range_scan_index_leading(
db, coll, name, lower, lower_inclusive, upper, upper_inclusive
)
else:
id_keys = self._range_scan_index(
db, coll, name, lower, lower_inclusive, upper, upper_inclusive
)
return self._docs_by_id_keys(db, coll, id_keys)
def _eq_id_keys_via_leading(
    self,
    db: str,
    coll: str,
    name: str,
    direction: int,
    is_compound: bool,
    value: Any,
) -> list[bytes]:
    """Id keys of entries in index ``name`` whose leading field equals ``value``.

    ``value`` is encoded with the index's ``direction``. A single-field
    index gets an exact sort-key match; a compound index gets a prefix scan
    on the encoded value followed by ``COMPOUND_SEP``.
    """
    encoded = encode_value_directed(value, direction)
    if not is_compound:
        return self._scan_index_for_id_keys(db, coll, name, encoded)
    return self._scan_index_for_id_keys(db, coll, name, encoded + COMPOUND_SEP, prefix=True)
def _pick_compound_eq_index(
self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
"""Find the index that ``_try_compound_eq_lookup`` would walk for ``filter``.
Returns ``(name, key_spec)`` of the chosen index, or ``None`` if no
index covers the filter as a leading prefix. Pure picker — does
not scan.
"""
filter_fields = set(filter)
multikey = self._multikey_index_names(db, coll)
partials = self._partial_filters(db, coll)
best: tuple[str, dict[str, Any]] | None = None
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name in multikey:
continue
pf = partials.get(name)
if pf is not None:
if not self._query_implies_partial(filter, pf):
continue
# Partial-filter clauses are guaranteed by the index itself,
# so they don't have to appear in the index key.
eff_fields = filter_fields - set(pf)
else:
eff_fields = filter_fields
idx_fields = list(key_spec.keys())
if any(int(key_spec[f]) not in (1, -1) for f in idx_fields):
continue
if len(idx_fields) < len(eff_fields):
continue
if set(idx_fields[: len(eff_fields)]) != eff_fields:
continue
if best is None or (len(list(best[1])) > len(idx_fields)):
best = (name, dict(key_spec))
if len(idx_fields) == len(eff_fields):
break
return best
def _try_compound_eq_lookup(
self, db: str, coll: str, filter: dict[str, Any]
) -> list[dict[str, Any]] | None:
"""Bare-equality filter against a compound (or single-field) index prefix.
Picks an index whose leading fields (set-wise) match the filter's
fields, and runs an equality (full-cover) or prefix
(strict-leading-prefix) scan against it. Per-field index direction
is honoured by encoding each value with ``encode_value_directed``.
"""
picked = self._pick_compound_eq_index(db, coll, filter)
if picked is None:
return None
name, key_spec = picked
idx_fields = list(key_spec)
# Build kb from the filter fields that are in the index (partial-filter
# clauses live outside the key and are guaranteed by index population).
prefix_fields = [f for f in idx_fields if f in filter]
parts = [encode_value_directed(filter[f], int(key_spec[f])) for f in prefix_fields]
kb = COMPOUND_SEP.join(parts) if len(parts) > 1 else parts[0]
if len(prefix_fields) == len(idx_fields):
id_keys = self._scan_index_for_id_keys(db, coll, name, kb)
else:
kb = kb + COMPOUND_SEP
id_keys = self._scan_index_for_id_keys(db, coll, name, kb, prefix=True)
return self._docs_by_id_keys(db, coll, id_keys)
def _partition_compound_range_filter(
self, filter: dict[str, Any]
) -> tuple[dict[str, Any], str, dict[str, Any]] | None:
"""Split a filter into ``(eq_fields, operator_field, operator_ops)``.
Returns ``None`` if the filter doesn't fit the compound-range
shape (any number of bare-equality fields plus exactly one
operator-form field whose ops are all in ``_RANGE_OPS``).
"""
eq_fields: dict[str, Any] = {}
operator_field: str | None = None
operator_ops: dict[str, Any] | None = None
for f, v in filter.items():
if isinstance(v, dict):
if not v or not all(k.startswith("$") for k in v):
return None
if not all(op in self._RANGE_OPS for op in v):
return None
if operator_field is not None:
return None
operator_field = f
operator_ops = v
else:
eq_fields[f] = v
if operator_field is None or not eq_fields:
return None
if operator_field in eq_fields:
return None
return eq_fields, operator_field, operator_ops or {}
def _pick_compound_range_index(
self, db: str, coll: str, filter: dict[str, Any]
) -> tuple[str, dict[str, Any]] | None:
"""Find the index that ``_try_compound_range_lookup`` would walk."""
parts = self._partition_compound_range_filter(filter)
if parts is None:
return None
eq_fields, operator_field, _operator_ops = parts
eq_set = set(eq_fields)
target_eq_count = len(eq_set)
multikey = self._multikey_index_names(db, coll)
partials = self._partial_filters(db, coll)
best: tuple[str, dict[str, Any]] | None = None
for name, key_spec, _sparse, _unique in self._all_indexes(db, coll):
if name in multikey:
continue
pf = partials.get(name)
if pf is not None and not self._query_implies_partial(filter, pf):
continue
idx_fields = list(key_spec.keys())
if any(int(key_spec[f]) not in (1, -1) for f in idx_fields):
continue
if len(idx_fields) <= target_eq_count:
continue
if set(idx_fields[:target_eq_count]) != eq_set:
continue
if idx_fields[target_eq_count] != operator_field:
continue
if best is None or len(list(best[1])) > len(idx_fields):
best = (name, dict(key_spec))
if len(idx_fields) == target_eq_count + 1:
break
return best
def _try_compound_range_lookup(
self, db: str, coll: str, filter: dict[str, Any]
) -> list[dict[str, Any]] | None:
"""Compound-prefix lookup with a trailing operator field.
Filters of the form ``{a: 5, b: 10, c: {$gt: 20}}`` (any number of
leading bare-equality fields followed by exactly one operator-form
field) walk the compound index by pinning the prefix from the
equalities and applying the operator's bounds to the next field.
"""
parts = self._partition_compound_range_filter(filter)
if parts is None:
return None
eq_fields, operator_field, operator_ops = parts
picked = self._pick_compound_range_index(db, coll, filter)
if picked is None:
return None
name, key_spec = picked
idx_fields = list(key_spec)
target_eq_count = len(eq_fields)
eq_field_names = idx_fields[:target_eq_count]
op_dir = int(key_spec[operator_field])
eq_parts = [encode_value_directed(eq_fields[f], int(key_spec[f])) for f in eq_field_names]
prefix_kb = COMPOUND_SEP.join(eq_parts) if len(eq_parts) > 1 else eq_parts[0]
prefix_with_sep = prefix_kb + COMPOUND_SEP
if "$in" in operator_ops:
if len(operator_ops) != 1 or not isinstance(operator_ops["$in"], list):
return None
seen: set[bytes] = set()
id_keys: list[bytes] = []
for v in operator_ops["$in"]:
if isinstance(v, dict):
return None
kb = prefix_with_sep + encode_value_directed(v, op_dir)
use_prefix = len(idx_fields) > target_eq_count + 1
inner_kb = kb + COMPOUND_SEP if use_prefix else kb
for id_k in self._scan_index_for_id_keys(
db, coll, name, inner_kb, prefix=use_prefix
):
if id_k not in seen:
seen.add(id_k)
id_keys.append(id_k)
return self._docs_by_id_keys(db, coll, id_keys)
if "$eq" in operator_ops:
if len(operator_ops) != 1:
return None
kb = prefix_with_sep + encode_value_directed(operator_ops["$eq"], op_dir)
use_prefix = len(idx_fields) > target_eq_count + 1
inner_kb = kb + COMPOUND_SEP if use_prefix else kb
id_keys = self._scan_index_for_id_keys(db, coll, name, inner_kb, prefix=use_prefix)
return self._docs_by_id_keys(db, coll, id_keys)
lower: bytes | None = None
lower_inclusive = True
upper: bytes | None = None
upper_inclusive = True
for op, bound in operator_ops.items():
if isinstance(bound, dict):
return None
full = prefix_with_sep + encode_value_directed(bound, op_dir)
effective_op = op
if op_dir == -1:
effective_op = {"$gt": "$lt", "$gte": "$lte", "$lt": "$gt", "$lte": "$gte"}[op]
if effective_op == "$gt":
lower, lower_inclusive = full, False
elif effective_op == "$gte":
lower, lower_inclusive = full, True
elif effective_op == "$lt":
upper, upper_inclusive = full, False
elif effective_op == "$lte":
upper, upper_inclusive = full, True
else:
return None
id_keys = self._range_scan_index(
db,
coll,
name,
lower,
lower_inclusive,
upper,
upper_inclusive,
prefix=prefix_with_sep,
)
return self._docs_by_id_keys(db, coll, id_keys)
def _range_scan_index(
    self,
    db: str,
    coll: str,
    name: str,
    lower: bytes | None,
    lower_inclusive: bool,
    upper: bytes | None,
    upper_inclusive: bool,
    *,
    prefix: bytes | None = None,
) -> list[bytes]:
    """Range-scan the index entries for ``name``.

    ``lower`` / ``upper`` are unescaped sort-key bounds (``None`` = open
    on that side); the ``*_inclusive`` flags say whether a row whose
    escaped sort key equals the bound is kept. Returns id keys in cursor
    order.

    Optional ``prefix`` constrains the scan to entries whose escaped
    kb starts with ``escape(prefix)`` — used by compound-index
    prefix+range queries where leading equalities pin part of the kb.

    NOTE(review): both boundary tests compare the whole escaped sort key
    against the bound. A row whose key continues past the bounded value
    (more compound fields after it) is therefore never "equal" to a bound.
    """
    c = self._cursor(_IDX_ENTRIES_TABLE)
    esc_prefix = _escape_kb(prefix) if prefix is not None else None
    esc_lower = _escape_kb(lower) if lower is not None else None
    esc_upper = _escape_kb(upper) if upper is not None else None
    # Seek target: the lower bound if there is one, else the start of the
    # prefix, else the start of this index's entries.
    if esc_lower is not None:
        seed = esc_lower + _ENTRY_SEP
    elif esc_prefix is not None:
        seed = esc_prefix
    else:
        seed = b""
    c.set_key(db, coll, name, seed)
    rc = c.search_near()
    if rc == wt.WT_NOTFOUND:
        return []
    # A negative result is followed by one step forward before reading.
    if rc < 0 and c.next() != 0:
        return []
    out: list[bytes] = []
    while True:
        k = c.get_key()
        # Stop once the cursor leaves this index's entries.
        if (k[0], k[1], k[2]) != (db, coll, name):
            break
        packed = bytes(k[3])
        row_esc, row_id = _unpack_entry(packed)
        if esc_prefix is not None and not row_esc.startswith(esc_prefix):
            break
        # Exclusive lower bound: skip rows sitting exactly on it.
        if esc_lower is not None and not lower_inclusive and row_esc == esc_lower:
            if c.next() != 0:
                break
            continue
        # Upper bound: stop at the first row past it (or on it, if exclusive).
        if esc_upper is not None:
            if upper_inclusive:
                if row_esc > esc_upper:
                    break
            elif row_esc >= esc_upper:
                break
        out.append(row_id)
        if c.next() != 0:
            break
    return out
def _range_scan_index_leading(
    self,
    db: str,
    coll: str,
    name: str,
    lower: bytes | None,
    lower_inclusive: bool,
    upper: bytes | None,
    upper_inclusive: bool,
) -> list[bytes]:
    """Range-scan a compound index using only its leading field.

    ``lower`` / ``upper`` are unescaped encodings of the leading-field
    bounds (``None`` = open on that side). Returns id keys in cursor order.

    Each row's escaped kb is
    ``escape(enc(leading)) + escape(COMPOUND_SEP) + escape(enc(trailing...))``.
    Boundary detection uses ``startswith(esc_X + esc_compound_sep)`` to
    identify rows whose leading field equals ``X`` — the terminator
    bytes of an escaped numeric encoding can overlap with the start of
    the escaped compound separator, so a literal find/split on the
    separator is unreliable.
    """
    esc_compound_sep = _escape_kb(COMPOUND_SEP)
    c = self._cursor(_IDX_ENTRIES_TABLE)
    esc_lower = _escape_kb(lower) if lower is not None else None
    esc_upper = _escape_kb(upper) if upper is not None else None
    # Seek to the lower bound, or to the start of this index's entries.
    seed = esc_lower if esc_lower is not None else b""
    c.set_key(db, coll, name, seed)
    rc = c.search_near()
    if rc == wt.WT_NOTFOUND:
        return []
    # A negative result is followed by one step forward before reading.
    if rc < 0 and c.next() != 0:
        return []
    # Prefixes identifying rows whose leading field equals a bound.
    lower_eq_prefix = esc_lower + esc_compound_sep if esc_lower is not None else None
    upper_eq_prefix = esc_upper + esc_compound_sep if esc_upper is not None else None
    out: list[bytes] = []
    while True:
        k = c.get_key()
        # Stop once the cursor leaves this index's entries.
        if (k[0], k[1], k[2]) != (db, coll, name):
            break
        packed = bytes(k[3])
        row_esc, row_id = _unpack_entry(packed)
        # Exclusive lower bound: skip rows whose leading field equals it.
        if (
            lower_eq_prefix is not None
            and not lower_inclusive
            and row_esc.startswith(lower_eq_prefix)
        ):
            if c.next() != 0:
                break
            continue
        if esc_upper is not None:
            if upper_inclusive:
                # Keep rows equal on the leading field; stop once past them.
                if row_esc > esc_upper and not row_esc.startswith(upper_eq_prefix):
                    break
            elif row_esc >= esc_upper:
                break
        out.append(row_id)
        if c.next() != 0:
            break
    return out