"""Source code for secantus.cursors: an in-memory registry of live cursors."""

from __future__ import annotations

import secrets
import threading
import time
import uuid as _uuid
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

DEFAULT_IDLE_TTL_SECONDS = 600.0  # matches MongoDB's 10-minute cursor TTL.
TAILABLE_IDLE_TTL_SECONDS = 1800.0  # 30 min — change-stream clients legitimately idle.

# Hard cap on simultaneous live cursors. Without this, a malicious or
# buggy client can open cursors with `no_cursor_timeout=True` (or a low
# request rate that stays under the prune cadence) and accumulate
# unbounded `remaining` doc lists, OOMing the server.
DEFAULT_MAX_CURSORS = 10_000


class CursorLimitExceeded(Exception):
    """Raised when register() would push the registry over its cap."""

    def __init__(self, limit: int) -> None:
        super().__init__(
            f"cursor registry full ({limit} cursors live); kill an existing "
            f"cursor or increase max_cursors before opening a new one"
        )
        self.limit = limit


[docs] class CursorNotFound(Exception): def __init__(self, cursor_id: int) -> None: super().__init__(f"cursor id {cursor_id} not found") self.cursor_id = cursor_id
@dataclass class _Entry: cursor_id: int namespace: str remaining: list[dict[str, Any]] last_access: float tailable: bool = False await_data: bool = False no_cursor_timeout: bool = False producer: Callable[[], list[dict[str, Any]]] | None = None position_seq: int = 0 collection_uuid: _uuid.UUID | None = None invalidated: bool = False final_event_pending: bool = False # Latest change-stream resume token observed by the producer. # Returned to the client as ``cursor.postBatchResumeToken`` on # every getMore so consumers on quiet collections can advance # their resume position even when no events are visible. None # for non-change-stream cursors. last_token: dict[str, str] | None = None
class CursorRegistry:
    """In-memory table of live cursors with idle-TTL pruning and a hard cap.

    Every public method holds ``self._lock`` for the duration of its work, so
    the table itself can be shared between threads. Expired cursors are
    evicted lazily: the O(N) walk is piggy-backed on ordinary operations and
    rate-limited by ``_prune_interval()``.

    Args:
        idle_ttl_seconds: Idle lifetime of ordinary cursors; ``<= 0`` disables
            expiry for them.
        time_func: Clock returning seconds as a float; defaults to
            ``time.monotonic``.
        tailable_idle_ttl_seconds: Idle lifetime of tailable cursors; ``<= 0``
            disables expiry for them.
        max_cursors: Maximum number of simultaneously registered cursors.
    """

    def __init__(
        self,
        idle_ttl_seconds: float = DEFAULT_IDLE_TTL_SECONDS,
        time_func: Callable[[], float] | None = None,
        tailable_idle_ttl_seconds: float = TAILABLE_IDLE_TTL_SECONDS,
        max_cursors: int = DEFAULT_MAX_CURSORS,
    ) -> None:
        self._cursors: dict[int, _Entry] = {}
        self._lock = threading.Lock()
        self.idle_ttl_seconds = idle_ttl_seconds
        self.tailable_idle_ttl_seconds = tailable_idle_ttl_seconds
        self.max_cursors = max_cursors
        self._time = time_func or time.monotonic
        # Last successful prune timestamp; -inf means never. Used to amortise
        # the O(N) prune walk across operations rather than running it on
        # every register / get / next_batch / kill / __len__ call.
        self._last_prune: float = float("-inf")

    def _prune_interval(self) -> float:
        """Return the minimum number of seconds between two prune walks.

        The cadence is capped at 60s and at one-tenth of the smallest
        *enabled* TTL, so no cursor outlives its TTL by more than ~10%.
        Both TTLs are considered: keying off ``idle_ttl_seconds`` alone
        would stop pruning altogether when ordinary expiry is disabled
        while tailable expiry is still on.
        """
        enabled = [
            ttl
            for ttl in (self.idle_ttl_seconds, self.tailable_idle_ttl_seconds)
            if ttl > 0
        ]
        if not enabled:
            # Nothing can ever expire, so there is no reason to walk.
            return float("inf")
        return min(60.0, min(enabled) / 10.0)

    def _prune_locked(self, *, force: bool = False) -> None:
        """Evict cursors idle for longer than their TTL. Caller holds the lock.

        Args:
            force: Run the walk even if the prune interval has not elapsed.
        """
        now = self._time()
        if not force and now - self._last_prune < self._prune_interval():
            return
        expired: list[int] = []
        for cid, entry in self._cursors.items():
            if entry.no_cursor_timeout:
                continue
            ttl = (
                self.tailable_idle_ttl_seconds
                if entry.tailable
                else self.idle_ttl_seconds
            )
            if ttl <= 0:
                continue
            if entry.last_access < now - ttl:
                expired.append(cid)
        # Delete after the walk; a dict must not change size while iterated.
        for cid in expired:
            del self._cursors[cid]
        self._last_prune = now

    def _reserve_slot_locked(self) -> None:
        """Ensure there is room for one more cursor. Caller holds the lock.

        Raises:
            CursorLimitExceeded: The registry is at ``max_cursors`` even
                after evicting every expired cursor.
        """
        self._prune_locked()
        if len(self._cursors) < self.max_cursors:
            return
        # The amortised prune may have been skipped, leaving cursors that are
        # already past their TTL in the table. Evict them before rejecting a
        # legitimate new cursor on their account.
        self._prune_locked(force=True)
        if len(self._cursors) >= self.max_cursors:
            raise CursorLimitExceeded(self.max_cursors)

    def _mint_id_locked(self, forced_bits: int, failure_message: str) -> int:
        """Return an unused random 63-bit id with ``forced_bits`` set.

        Gives up after 8 collisions and raises ``RuntimeError`` with
        ``failure_message``. Caller holds the lock.
        """
        for _attempt in range(8):
            candidate = secrets.randbits(63) | forced_bits
            if candidate not in self._cursors:
                return candidate
        raise RuntimeError(failure_message)

    def register(self, namespace: str, remaining: list[dict[str, Any]]) -> int:
        """Register an ordinary cursor over a copy of ``remaining``.

        Returns:
            The newly minted cursor id.

        Raises:
            CursorLimitExceeded: The registry is full.
            RuntimeError: No unused id could be minted.
        """
        with self._lock:
            self._reserve_slot_locked()
            # Cursor IDs were `itertools.count(1)` — sequential, predictable,
            # and trivially enumerable by another connection issuing
            # `getMore` against guessed IDs. Mint with `secrets.randbits`
            # like the tailable path so cross-session cursor hijacking
            # requires brute-forcing a 63-bit space.
            cursor_id = self._mint_id_locked(
                1,  # avoid 0
                "could not mint unique cursor id",
            )
            self._cursors[cursor_id] = _Entry(
                cursor_id, namespace, list(remaining), self._time()
            )
            return cursor_id

    def register_tailable(
        self,
        namespace: str,
        producer: Callable[[], list[dict[str, Any]]],
        *,
        await_data: bool = True,
        no_cursor_timeout: bool = False,
        position_seq: int = 0,
        collection_uuid: _uuid.UUID | None = None,
        initial_remaining: list[dict[str, Any]] | None = None,
    ) -> int:
        """Register a tailable cursor backed by a producer closure.

        ``initial_remaining`` lets the caller pre-seed the cursor's
        ``remaining`` queue with docs already matched at ``find`` time that
        didn't fit in ``firstBatch``. Subsequent ``getMore``s drain that queue
        first, then fall through to ``producer()`` for newly-inserted docs.

        Cursor IDs are int64-random (above 2**32) to match what real
        ``mongod`` issues for change streams; some drivers compare these as
        signed 64-bit and break on small sequential ints.

        Returns:
            The newly minted cursor id.

        Raises:
            CursorLimitExceeded: The registry is full.
            RuntimeError: No unused id could be minted.
        """
        with self._lock:
            self._reserve_slot_locked()
            cursor_id = self._mint_id_locked(
                1 << 32,  # ensure > 2**32
                "could not mint unique tailable cursor id",
            )
            self._cursors[cursor_id] = _Entry(
                cursor_id=cursor_id,
                namespace=namespace,
                remaining=list(initial_remaining) if initial_remaining else [],
                last_access=self._time(),
                tailable=True,
                await_data=await_data,
                no_cursor_timeout=no_cursor_timeout,
                producer=producer,
                position_seq=position_seq,
                collection_uuid=collection_uuid,
            )
            return cursor_id

    def get(self, cursor_id: int) -> _Entry:
        """Return the live entry for ``cursor_id`` and refresh its idle timer.

        The entry is the registry's own object, not a copy; changes the
        caller makes to it after this returns are not synchronised here.

        Raises:
            CursorNotFound: No cursor is registered under ``cursor_id``.
        """
        with self._lock:
            self._prune_locked()
            entry = self._cursors.get(cursor_id)
            if entry is None:
                raise CursorNotFound(cursor_id)
            entry.last_access = self._time()
            return entry

    def next_batch(
        self, cursor_id: int, batch_size: int
    ) -> tuple[list[dict[str, Any]], bool]:
        """Pop up to ``batch_size`` docs off the front of the cursor's queue.

        A ``batch_size`` of zero or less means "everything that is left".

        Returns:
            ``(batch, exhausted)``. For tailable cursors ``exhausted`` is
            always False and the cursor stays registered. An ordinary cursor
            is removed from the registry as soon as its queue is empty.

        Raises:
            CursorNotFound: No cursor is registered under ``cursor_id``.
        """
        with self._lock:
            self._prune_locked()
            entry = self._cursors.get(cursor_id)
            if entry is None:
                raise CursorNotFound(cursor_id)
            if batch_size <= 0:
                batch_size = len(entry.remaining)
            batch = entry.remaining[:batch_size]
            entry.remaining = entry.remaining[batch_size:]
            exhausted = not entry.remaining
            if entry.tailable:
                # Tailable cursors persist across empty batches.
                entry.last_access = self._time()
                return batch, False
            if exhausted:
                del self._cursors[cursor_id]
            else:
                entry.last_access = self._time()
            return batch, exhausted

    def kill(self, cursor_ids: list[int]) -> tuple[list[int], list[int]]:
        """Remove the given cursors.

        Returns:
            ``(killed, not_found)``, each in the order the ids were given.
            An id listed twice is killed once and reported not-found after.
        """
        killed: list[int] = []
        not_found: list[int] = []
        with self._lock:
            self._prune_locked()
            for cid in cursor_ids:
                if self._cursors.pop(cid, None) is not None:
                    killed.append(cid)
                else:
                    not_found.append(cid)
        return killed, not_found

    def __len__(self) -> int:
        """Return the number of registered cursors after a (rate-limited) prune."""
        with self._lock:
            self._prune_locked()
            return len(self._cursors)

    def snapshot(self) -> list[dict[str, Any]]:
        """Return a list of plain dicts describing the live cursors.

        Used by the ``currentOp`` command and the admin UI's cursors page.
        Each entry has ``cursor_id``, ``namespace``, ``remaining``,
        ``last_access``, ``tailable``, and ``await_data``; ``remaining`` is
        the number of queued docs, not the docs themselves. Entries are
        sorted by ``cursor_id``. We don't return ``_Entry`` instances so
        callers can't accidentally mutate registry state.
        """
        with self._lock:
            self._prune_locked()
            return [
                {
                    "cursor_id": e.cursor_id,
                    "namespace": e.namespace,
                    "remaining": len(e.remaining),
                    "last_access": e.last_access,
                    "tailable": e.tailable,
                    "await_data": e.await_data,
                }
                for e in sorted(self._cursors.values(), key=lambda x: x.cursor_id)
            ]
# Names exported by ``from secantus.cursors import *``.
# NOTE(review): ``_Entry`` is listed despite its underscore prefix —
# presumably so other modules can annotate against the type returned by
# ``CursorRegistry.get``; confirm before removing it.
__all__ = [
    "CursorLimitExceeded",
    "CursorNotFound",
    "CursorRegistry",
    "_Entry",
    "DEFAULT_IDLE_TTL_SECONDS",
    "DEFAULT_MAX_CURSORS",
    "TAILABLE_IDLE_TTL_SECONDS",
]