omkit 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,116 @@
1
+ """packages/omur-sdk/omkit/jobqueue/envelope.py — Cross-SDK envelope contract for job-queue payloads.
2
+
3
+ Every task enqueued via streaq (Python) or Asynq (Go) is wrapped in this
4
+ envelope. Workers unwrap on receive, validate, and run the handler under the
5
+ tenant's RLS scope.
6
+
7
+ exports: ENVELOPE_VERSION | class InvalidEnvelopeError | class Envelope | wrap(tenant_id, payload) | unwrap(data)
8
+ rules: The Envelope class must maintain strict tenant isolation and never allow cross-tenant data leakage. All envelope validation must be immutable and deterministic to ensure consistent task processing across distributed workers. The wrap/unwrap functions must handle all serialization edge cases including nested data structures and preserve original payload integrity during transformation.
9
+ agent: ollama/qwen3-coder:latest | ollama | 2026-05-01 | codedna-cli | initial CodeDNA annotation pass
10
+ message:
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import json
16
+ import uuid
17
+ from typing import Any
18
+
19
+ from pydantic import BaseModel, Field, ValidationError, field_validator
20
+
21
# Highest envelope schema version this SDK stamps (wrap) and accepts
# (Envelope rejects anything greater).
ENVELOPE_VERSION = 1


class InvalidEnvelopeError(ValueError):
    """Raised when an envelope cannot be built or parsed.

    A malformed envelope will not become valid on a later attempt, so
    workers should dead-letter it instead of retrying.
    """
26
+
27
+
28
class Envelope(BaseModel):
    """Tenant-scoped task envelope.

    ``payload`` is opaque here — handlers parse it into their own pydantic
    model. Cross-SDK contract: intended to match Go's
    packages/omur-go-sdk/jobqueue/Envelope field-for-field, so envelopes
    produced by wrap()/Wrap() round-trip between Python and Go workers.

    Validation rules:
      * ``version`` is required and must be a genuine integer in
        ``1..ENVELOPE_VERSION``. Strict mode rejects values such as ``true``,
        ``"1"`` or ``1.0`` that pydantic's lax mode would coerce to ``1``
        (a Go ``int`` field would not accept them either).
      * ``tenant_id`` must parse as a UUID; the original string is stored
        unchanged (it is not canonicalised).
      * ``payload`` must be a non-empty mapping.
      * Unknown keys are rejected (``extra="forbid"``). Instances are frozen
        (field reassignment is blocked), but the ``payload`` dict itself is
        still a mutable object.
    """

    model_config = {"frozen": True, "extra": "forbid"}

    # Required — no default. Go's Unwrap rejects envelopes with version==0
    # (the zero value of a missing field), so the Python side rejects the
    # missing-key case symmetrically. strict=True disables lax coercion of
    # bool/str/float inputs into an int.
    version: int = Field(strict=True)
    tenant_id: str
    payload: dict[str, Any]

    @field_validator("tenant_id")
    @classmethod
    def _validate_tenant(cls, v: str) -> str:
        """Reject tenant ids that ``uuid.UUID`` cannot parse."""
        try:
            uuid.UUID(v)
        except (ValueError, AttributeError, TypeError) as exc:
            raise ValueError(f"tenant_id not a valid uuid: {v!r}") from exc
        return v

    @field_validator("version")
    @classmethod
    def _validate_version(cls, v: int) -> int:
        """Accept only versions in ``1..ENVELOPE_VERSION``."""
        if v < 1:
            raise ValueError(f"envelope version must be >= 1, got {v}")
        if v > ENVELOPE_VERSION:
            raise ValueError(
                f"unsupported envelope version {v} (max {ENVELOPE_VERSION})"
            )
        return v

    @field_validator("payload")
    @classmethod
    def _validate_payload(cls, v: dict[str, Any]) -> dict[str, Any]:
        """Reject empty payloads (mirrors the Go side's dead-lettering)."""
        if not v:
            raise ValueError(
                "payload must not be empty — Go workers dead-letter empty payloads"
            )
        return v
75
+
76
+
77
def wrap(tenant_id: str, payload: dict[str, Any]) -> bytes:
    """Validate the inputs and return the envelope as UTF-8 JSON bytes.

    The envelope is stamped with ``ENVELOPE_VERSION`` and serialised with
    ``model_dump_json`` so it can be handed to streaq enqueue.

    Raises:
        InvalidEnvelopeError: if the envelope fails validation — e.g.
            ``tenant_id`` is not a UUID string, or ``payload`` is empty or
            not a mapping.
    """
    fields = {
        "version": ENVELOPE_VERSION,
        "tenant_id": tenant_id,
        "payload": payload,
    }
    try:
        envelope = Envelope(**fields)
    except ValidationError as err:
        raise InvalidEnvelopeError(str(err)) from err
    serialized = envelope.model_dump_json()
    return serialized.encode("utf-8")
93
+
94
+
95
def unwrap(data: bytes | str | dict[str, Any]) -> Envelope:
    """Parse and validate an inbound envelope.

    Accepts raw JSON bytes/str or a pre-parsed dict (streaq sometimes hands
    handlers the decoded payload directly). A pre-parsed dict goes through
    exactly the same validation as decoded JSON.

    Raises:
        InvalidEnvelopeError: on any failure — undecodable bytes, malformed
            or pathologically nested JSON, a top level that is not an
            object, or schema validation. Callers must dead-letter, not
            retry.
    """
    if isinstance(data, (bytes, str)):
        try:
            obj = json.loads(data)
        except (json.JSONDecodeError, UnicodeDecodeError, RecursionError) as exc:
            # json.loads decodes bytes itself: invalid UTF-8 surfaces as
            # UnicodeDecodeError (not JSONDecodeError), and very deep nesting
            # as RecursionError. Both must map to InvalidEnvelopeError so the
            # message is dead-lettered instead of escaping as a generic error.
            raise InvalidEnvelopeError(f"envelope not valid json: {exc}") from exc
    else:
        obj = data
    if not isinstance(obj, dict):
        raise InvalidEnvelopeError(f"envelope must be a json object, got {type(obj).__name__}")
    try:
        return Envelope.model_validate(obj)
    except ValidationError as exc:
        raise InvalidEnvelopeError(str(exc)) from exc
@@ -0,0 +1,267 @@
1
+ """packages/omur-sdk/omkit/jobqueue/streaq.py — streaq integration for Omur Python services.
2
+
3
+ Wraps the streaq Worker with the SDK's tenant + envelope contract so all
4
+ Python services have the same ergonomics as the Go-side `omkit.jobqueue`
5
+ helpers (which front Asynq).
6
+
7
+ Public surface:
8
+
9
+ make_worker(redis_url, queue_name, ...) -> streaq.Worker
10
+ tenant_middleware -> streaq middleware factory
11
+ enqueue(task, tenant_id, payload, ...) -> shorthand for envelope-wrapped enqueue
12
+ mount_streaq_ui(app, worker, prefix=...) -> mount the FastAPI UI router
13
+ StreaqPromCollector(worker) -> prometheus.Collector for worker.counters
14
+
15
+ Conventions:
16
+
17
+ - Workers serialize tasks as JSON. Required because (a) cross-language
18
+ round-trip with Go workers requires JSON and (b) the streaq UI renders
19
+ JSON arguments inline. streaq's default binary serializer is replaced.
20
+ Callers MUST pass JSON-safe payloads.
21
+ - Every task is tenant-scoped. The first positional argument of every
22
+ registered task is the envelope dict; `tenant_middleware` unwraps it,
23
+ binds `tenant.current()`, and passes the inner payload to the handler.
24
+ - Defaults match the SDK contract documented in
25
+ `docs/superpowers/specs/2026-04-29-job-queue-design.md`:
26
+ concurrency=4, max_tries=3, task_timeout=300s, ttl=48h.
27
+
28
+ exports: DEFAULT_CONCURRENCY | DEFAULT_MAX_TRIES | DEFAULT_TIMEOUT_SECONDS | DEFAULT_TTL | make_worker(redis_url, queue_name) | tenant_middleware(next_handler) | enqueue(task, tenant_id, payload) | mount_streaq_ui(app, worker) | _STREAQ_COUNTER_KEYS | class StreaqPromCollector
29
+ rules: The module requires all Redis-based job queue operations to be thread-safe and idempotent, as it's designed for high-concurrency worker environments where tasks may be retried or processed by multiple workers simultaneously.
30
+ agent: ollama/qwen3-coder:latest | ollama | 2026-05-01 | codedna-cli | initial CodeDNA annotation pass
31
+ message:
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ import json
37
+ import logging
38
+ from datetime import timedelta
39
+ from typing import Any, Awaitable, Callable
40
+
41
+ from omkit import tenant
42
+ from omkit.jobqueue.envelope import (
43
+ Envelope,
44
+ InvalidEnvelopeError,
45
+ unwrap,
46
+ wrap,
47
+ )
48
+
49
log = logging.getLogger(__name__)

# Worker/task defaults from the SDK job-queue contract (module docstring).
DEFAULT_CONCURRENCY = 4
DEFAULT_MAX_TRIES = 3
DEFAULT_TIMEOUT_SECONDS = 5 * 60  # seconds
DEFAULT_TTL = timedelta(days=2)
55
+
56
+
57
def _json_serializer(obj: Any) -> bytes:
    """Encode ``obj`` as compact (no-whitespace) UTF-8 JSON bytes.

    Values the json module cannot encode natively are converted with
    ``str()`` instead of raising TypeError.
    """
    text = json.dumps(obj, separators=(",", ":"), default=str)
    return text.encode("utf-8")
59
+
60
+
61
def _json_deserializer(data: bytes) -> Any:
    """Decode JSON ``data`` (as produced by ``_json_serializer``)."""
    decoded = json.loads(data)
    return decoded
63
+
64
+
65
def make_worker(
    redis_url: str,
    queue_name: str,
    *,
    concurrency: int = DEFAULT_CONCURRENCY,
    max_tries: int = DEFAULT_MAX_TRIES,
    task_timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
    ttl: timedelta = DEFAULT_TTL,
    handle_signals: bool = False,
    **worker_kwargs: Any,
):
    """Construct a streaq Worker pre-configured for Omur services.

    The worker uses this module's compact JSON serializer/deserializer
    unless ``serializer``/``deserializer`` are supplied in ``worker_kwargs``.

    Args:
        redis_url: Any URL form streaq accepts (``redis://valkey:6379/0`` or
            ``rediss://…``). For Valkey with a password, encode it in the
            URL: ``redis://:PASSWORD@valkey:6379/0``.
        queue_name: Name of the streaq queue.
        concurrency: Forwarded to ``streaq.Worker``.
        max_tries, task_timeout_seconds, ttl: Accepted for the SDK-contract
            signature but NOT currently passed to ``streaq.Worker`` by this
            function. NOTE(review): set them per task via
            ``@worker.task(...)`` until they are wired up — confirm the
            intended mechanism against the streaq API.
        handle_signals: Defaults to False because services run streaq
            alongside an HTTP server and own their SIGTERM handler; the
            lifespan-context-manager pattern stops the worker on app
            shutdown instead.
        **worker_kwargs: Forwarded to ``streaq.Worker(...)`` for advanced
            cases (sentinel/cluster, custom serializers, etc.). A
            ``serializer``/``deserializer`` given here replaces the JSON
            default. Keep in mind the module requires JSON for Go workers
            and the streaq UI.

    Returns:
        The constructed ``streaq.Worker``.
    """
    import streaq

    # Caller-supplied kwargs take precedence over the JSON codec defaults.
    # Previously passing serializer=/deserializer= raised a duplicate-keyword
    # TypeError despite being documented as supported.
    options: dict[str, Any] = {
        "serializer": _json_serializer,
        "deserializer": _json_deserializer,
    }
    options.update(worker_kwargs)

    return streaq.Worker(
        redis_url=redis_url,
        queue_name=queue_name,
        concurrency=concurrency,
        handle_signals=handle_signals,
        **options,
    )
106
+
107
+
108
+ # ─────────────────────────────────────────────────────────────────────
109
+ # Tenant + envelope middleware
110
+ # ─────────────────────────────────────────────────────────────────────
111
+
112
+
113
def tenant_middleware(next_handler: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    """streaq middleware: unwrap the envelope, bind the tenant, run the task.

    The first positional argument of each task call is treated as the
    envelope (anything ``unwrap`` accepts). The wrapper validates it, enters
    ``tenant.bind(env.tenant_id)`` and awaits
    ``next_handler(env.payload, *remaining_args, **kwargs)``.

    Register on the Worker::

        worker.middleware(tenant_middleware)

    and write tasks as::

        @worker.task(timeout=600)
        async def parse(payload: dict) -> None:
            doc_id = payload["doc_id"]
            assert tenant.current() is not None

    All enqueues should go through ``enqueue()`` so envelopes are
    well-formed; an invalid envelope at the worker boundary indicates a bug,
    not a transient fault.

    NOTE(review): the error is re-raised as an ordinary exception, so streaq
    presumably records a failure subject to ``max_tries``. That differs from
    ``InvalidEnvelopeError``'s own "dead-letter, no retry" guidance.

    Raises:
        InvalidEnvelopeError: if the task is called with no positional args
            or the envelope fails validation.
    """
    import functools

    @functools.wraps(next_handler)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        if not args:
            raise InvalidEnvelopeError("task called with no positional args")
        try:
            env: Envelope = unwrap(args[0])
        except InvalidEnvelopeError:
            # Log with the traceback so the validation detail is captured;
            # the bare event name alone made triage impossible.
            log.exception("streaq.invalid_envelope")
            raise
        with tenant.bind(env.tenant_id):
            return await next_handler(env.payload, *args[1:], **kwargs)

    return wrapper
149
+
150
+
151
+ # ─────────────────────────────────────────────────────────────────────
152
+ # Enqueue helper
153
+ # ─────────────────────────────────────────────────────────────────────
154
+
155
+
156
def enqueue(task: Any, tenant_id: str, payload: dict[str, Any], **opts: Any) -> Awaitable[Any]:
    """Envelope-wrap ``payload`` for ``tenant_id`` and enqueue it on ``task``.

    ``task`` is the object produced by ``@worker.task(...)``. Its
    ``enqueue`` is called with the envelope as a plain dict (the JSON from
    ``wrap()`` decoded again), followed by ``**opts``. Whatever
    ``task.enqueue(...)`` returns is handed back unawaited — the caller
    must ``await`` it.

    Raises:
        InvalidEnvelopeError: from ``wrap()`` when ``tenant_id`` is not a
            UUID or ``payload`` is empty.
    """
    # Round-trip through the validated JSON so the enqueued dict contains
    # exactly what the serialised envelope does.
    return task.enqueue(json.loads(wrap(tenant_id, payload)), **opts)
169
+
170
+
171
+ # ─────────────────────────────────────────────────────────────────────
172
+ # FastAPI UI mount
173
+ # ─────────────────────────────────────────────────────────────────────
174
+
175
+
176
def mount_streaq_ui(app: Any, worker: Any, *, prefix: str = "/queue/ui") -> None:
    """Mount streaq's built-in admin UI router on ``app`` under ``prefix``.

    streaq's UI looks the worker up through the FastAPI dependency
    ``get_worker``, which answers 412 unless overridden. It is overridden
    here to return ``worker`` so the UI can read queue state, results and
    counters.

    No authentication is added by this function — access is gated upstream
    (Caddy's oauth2-proxy forward-auth with Zitadel SSO).

    Raises:
        RuntimeError: if the ``streaq[web]`` extra is not installed.
    """
    try:
        from streaq.ui.deps import get_worker
        from streaq.ui.tasks import router as tasks_router
    except ImportError as exc:
        raise RuntimeError(
            "streaq UI requires `streaq[web]` extra (fastapi/jinja2/uvicorn)"
        ) from exc

    def _current_worker():
        return worker

    app.dependency_overrides[get_worker] = _current_worker
    app.include_router(tasks_router, prefix=prefix)
196
+
197
+
198
+ # ─────────────────────────────────────────────────────────────────────
199
+ # Prometheus bridge
200
+ # ─────────────────────────────────────────────────────────────────────
201
+
202
+
203
# Keys read from ``worker.counters``; each one is exported by
# StreaqPromCollector as a ``streaq_worker_<key>`` gauge.
_STREAQ_COUNTER_KEYS = tuple(
    "aborted completed failed relinquished retried running".split()
)
211
+
212
+
213
class StreaqPromCollector:
    """Prometheus collector exposing a streaq worker's ``counters`` as gauges.

    On every scrape it reads ``worker.counters`` and yields one gauge family
    per key, labelled by queue:

        streaq_worker_aborted{queue}
        streaq_worker_completed{queue}
        streaq_worker_failed{queue}
        streaq_worker_relinquished{queue}
        streaq_worker_retried{queue}
        streaq_worker_running{queue}

    Register once per process::

        from prometheus_client import REGISTRY
        REGISTRY.register(StreaqPromCollector(worker))

    Gauges (not counters) are used because the values reset when the worker
    restarts; dashboards de-dupe on the ``service`` instance.
    """

    def __init__(self, worker: Any) -> None:
        self._worker = worker

    def describe(self) -> Any:
        # Nothing is pre-declared; families are built fresh in collect().
        return iter(())

    def collect(self) -> Any:
        """Yield one GaugeMetricFamily per entry in ``_STREAQ_COUNTER_KEYS``.

        A worker without ``queue_name`` is labelled ``"default"``; missing or
        falsy ``counters`` are treated as empty, and absent keys report 0.
        """
        from prometheus_client.core import GaugeMetricFamily

        queue_label = getattr(self._worker, "queue_name", "default")
        snapshot = getattr(self._worker, "counters", {}) or {}
        for counter_key in _STREAQ_COUNTER_KEYS:
            family = GaugeMetricFamily(
                f"streaq_worker_{counter_key}",
                f"streaq worker {counter_key} count (since process start)",
                labels=["queue"],
            )
            family.add_metric([queue_label], float(snapshot.get(counter_key, 0)))
            yield family
255
+
256
+
257
# Public surface of this module, kept in sorted order.
__all__ = [
    "DEFAULT_CONCURRENCY", "DEFAULT_MAX_TRIES",
    "DEFAULT_TIMEOUT_SECONDS", "DEFAULT_TTL",
    "StreaqPromCollector",
    "enqueue", "make_worker", "mount_streaq_ui", "tenant_middleware",
]
omkit/logging.py ADDED
@@ -0,0 +1,77 @@
1
+ """packages/omur-sdk/omkit/logging.py — Shared structlog configuration for all Omur services.
2
+
3
+ Default output is JSON, suitable for production log aggregation. Set
4
+ ``LOG_FORMAT=console`` to switch to the human-readable renderer during dev.
5
+
6
+ Usage:
7
+ from omkit.logging import configure_logging
8
+ configure_logging("spine") # Call once at startup, before get_logger()
9
+
10
+ exports: configure_logging(service_name)
11
+ rules: The logging module must maintain backward compatibility with existing log format configurations and service name resolution patterns across all SDK versions. The module cannot introduce breaking changes to its public API or alter the default logging behavior without explicit versioned migration paths. All logging configurations must remain thread-safe and support concurrent service initialization without race conditions.
12
+ agent: ollama/qwen3-coder:latest | ollama | 2026-05-01 | codedna-cli | initial CodeDNA annotation pass
13
+ message:
14
+ """
15
+ from __future__ import annotations
16
+
17
+ import os
18
+
19
+ import structlog
20
+
21
+
22
def configure_logging(service_name: str) -> None:
    """Configure structlog: contextvars, service/correlation fields, level,
    ISO timestamps, and a renderer chosen by ``LOG_FORMAT``.

    Every log record emitted after this call carries a ``service`` field set
    to ``service_name`` (unless the call site overrides it explicitly). It
    also carries ``tenant_id`` / ``request_id`` when ``omkit.tenant``
    provides non-empty values.

    ``LOG_FORMAT`` values (case-insensitive, surrounding whitespace ignored):
      * ``json`` (default) — JSONRenderer for production / log aggregation.
        ``exc_info`` is turned into a formatted ``exception`` field first.
      * ``console`` — ConsoleRenderer for local development, which renders
        exceptions itself.
    Any other value falls back to ``json``.
    """
    fmt = os.environ.get("LOG_FORMAT", "json").strip().lower()

    def _add_service(_logger, _method, event_dict):
        event_dict.setdefault("service", service_name)
        return event_dict

    def _add_correlation(_logger, _method, event_dict):
        # Pull tenant + request_id off the SDK-managed contextvars so every
        # log record is auto-tagged with cross-service correlation fields.
        # Lazy import keeps this module decoupled from tenant; failures are
        # deliberately best-effort (logging must never crash the caller).
        try:
            from omkit.tenant import current_or_none, request_id
        except Exception:
            return event_dict
        try:
            tid = current_or_none()
            rid = request_id()
        except Exception:
            return event_dict
        if tid:
            event_dict.setdefault("tenant_id", tid)
        if rid:
            event_dict.setdefault("request_id", rid)
        return event_dict

    processors = [
        structlog.contextvars.merge_contextvars,
        _add_service,
        _add_correlation,
        structlog.stdlib.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
    ]
    if fmt == "console":
        # ConsoleRenderer pretty-prints exc_info itself; putting
        # format_exc_info ahead of it would defeat that.
        processors.append(structlog.dev.ConsoleRenderer())
    else:
        # JSONRenderer does not understand exc_info: without this processor
        # `log.error(..., exc_info=True)` would emit `"exc_info": true` and
        # the traceback would be lost.
        processors.append(structlog.processors.format_exc_info)
        processors.append(structlog.processors.JSONRenderer())

    structlog.configure(
        processors=processors,
        # Level 0 filters nothing: every level is emitted.
        wrapper_class=structlog.make_filtering_bound_logger(0),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
        cache_logger_on_first_use=True,
    )
omkit/metrics.py ADDED
@@ -0,0 +1,41 @@
1
+ """Shared Prometheus metrics wiring for FastAPI services.
2
+
3
+ Usage:
4
+ from omkit.metrics import mount_metrics
5
+ mount_metrics(app, "my-service") # exposes /metrics, instruments all routes
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING
11
+
12
+ if TYPE_CHECKING:
13
+ from fastapi import FastAPI
14
+
15
+
16
def mount_metrics(app: "FastAPI", service_name: str) -> None:
    """Instrument ``app`` with prometheus-fastapi-instrumentator and expose /metrics.

    Metrics are named ``omur_<service>_...``. ``service_name`` is sanitised
    into a valid Prometheus subsystem: every character outside
    ``[a-zA-Z0-9_]`` becomes ``_`` (``"my-service"`` -> ``my_service``).
    Hyphens and dots fall outside the classic Prometheus metric-name
    charset.

    Status codes are grouped, untemplated routes are ignored, and /metrics,
    /health and /ready are excluded from instrumentation.

    Idempotent per app: a second call on the same app is a no-op. This is
    tracked by setting ``_omkit_metrics_mounted`` on ``app``.

    Raises:
        ImportError: if prometheus-fastapi-instrumentator is not installed.
    """
    import re

    try:
        from prometheus_fastapi_instrumentator import Instrumentator
    except ImportError as e:
        raise ImportError(
            "prometheus-fastapi-instrumentator is required. "
            "Install with: pip install omkit[metrics]"
        ) from e

    if getattr(app, "_omkit_metrics_mounted", False):
        return

    subsystem = re.sub(r"[^a-zA-Z0-9_]", "_", service_name)

    Instrumentator(
        should_group_status_codes=True,
        should_ignore_untemplated=True,
        should_respect_env_var=False,
        excluded_handlers=["/metrics", "/health", "/ready"],
    ).instrument(app, metric_namespace="omur", metric_subsystem=subsystem).expose(app)

    app._omkit_metrics_mounted = True
@@ -0,0 +1,192 @@
1
+ """On-demand model loading with TTL-based idle unloading.
2
+
3
+ ModelLifecycle is an abstract base for lazy model loading. ModelRegistry
4
+ manages a set of lifecycles and reaps idle ones. Load/unload are serialised
5
+ per lifecycle by an asyncio.Lock (safe across tasks on one event loop, not across OS threads).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import abc
11
+ import asyncio
12
+ import gc
13
+ import time
14
+ from typing import Any
15
+
16
+ import structlog
17
+ from prometheus_client import Histogram, Counter, Gauge
18
+
19
log = structlog.get_logger()

# Prometheus metrics shared by every ModelLifecycle, labelled by model name.
MODEL_LOAD_DURATION = Histogram(
    name="model_load_duration_seconds",
    documentation="Time to load a model into memory",
    labelnames=("model",),
    buckets=(1, 2, 5, 10, 15, 20, 30, 60),
)
MODEL_LOAD_ERRORS = Counter(
    name="model_load_errors_total",
    documentation="Number of model load failures",
    labelnames=("model",),
)
MODEL_UNLOAD_TOTAL = Counter(
    name="model_unload_total",
    documentation="Number of model unloads",
    labelnames=("model", "reason"),  # reason: "idle" | "shutdown" in this module
)
MODEL_LOADED = Gauge(
    name="model_loaded",
    documentation="Whether a model is currently loaded (1=yes, 0=no)",
    labelnames=("model",),
)
+ )
42
+
43
+
44
class ModelLifecycle(abc.ABC):
    """Abstract base for on-demand model loading with idle tracking.

    Subclasses implement ``_do_load`` (return the model object) and
    ``_do_unload`` (release its resources). Both run in the event loop's
    default thread executor, so blocking work does not stall the loop.
    Load and unload are serialised by a per-instance ``asyncio.Lock``. That
    makes them safe across tasks on one event loop, not across OS threads.
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self._model: Any = None
        # time.monotonic() of the last load/use; reset to 0 on unload.
        self._last_used: float = 0
        self._lock = asyncio.Lock()

    @abc.abstractmethod
    def _do_load(self) -> Any:
        """Load the model into memory and return it (must not be None).

        Runs in a thread executor.
        """

    @abc.abstractmethod
    def _do_unload(self) -> None:
        """Release model resources. Runs in a thread executor.

        ``self.model`` is still set while this runs.
        """

    @property
    def is_loaded(self) -> bool:
        """True while a model object is held."""
        return self._model is not None

    @property
    def model(self) -> Any:
        """The loaded model object, or None when unloaded."""
        return self._model

    @property
    def last_used(self) -> float:
        """Monotonic timestamp of the last load/touch (0 when unloaded)."""
        return self._last_used

    def touch(self) -> None:
        """Mark the model as used now, postponing idle unloading."""
        self._last_used = time.monotonic()

    async def ensure_loaded(self) -> None:
        """Load the model if needed; otherwise just refresh its last-used time.

        Concurrent callers are serialised by the lock, so only the first one
        runs ``_do_load`` and the rest observe the loaded model.

        NOTE(review): nothing pins the model after this returns. An idle
        reaper may unload it while a caller is still using ``model``, so
        callers should ``touch()`` around use.

        Raises:
            RuntimeError: if ``_do_load`` returns None. Otherwise the gauge
                would report the model as loaded while ``is_loaded`` stays
                False.
            Exception: anything ``_do_load`` raises. Every load failure is
                counted in ``model_load_errors_total``.
        """
        async with self._lock:
            if self._model is not None:
                self._last_used = time.monotonic()
                return
            log.info("model.loading", model=self.name)
            t0 = time.monotonic()
            loop = asyncio.get_running_loop()
            try:
                model = await loop.run_in_executor(None, self._do_load)
                if model is None:
                    raise RuntimeError(
                        f"{type(self).__name__}._do_load() returned None "
                        f"for model {self.name!r}"
                    )
            except Exception:
                MODEL_LOAD_ERRORS.labels(model=self.name).inc()
                raise
            self._model = model
            duration = time.monotonic() - t0
            self._last_used = time.monotonic()
            MODEL_LOAD_DURATION.labels(model=self.name).observe(duration)
            MODEL_LOADED.labels(model=self.name).set(1)
            log.info("model.loaded", model=self.name, duration_s=round(duration, 2))

    async def unload(self) -> None:
        """Unload the model if loaded; a no-op when already unloaded.

        The ``model_loaded`` gauge is cleared only after ``_do_unload``
        succeeds. If ``_do_unload`` raises, the error propagates, the model
        stays referenced, and the gauge still reports it as loaded
        (consistent with ``is_loaded``).
        """
        async with self._lock:
            if self._model is None:
                return
            log.info("model.unloading", model=self.name)
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._do_unload)
            self._model = None
            self._last_used = 0
            MODEL_LOADED.labels(model=self.name).set(0)
            # Drop lingering reference cycles so model memory is reclaimed now.
            gc.collect()
            log.info("model.unloaded", model=self.name)
119
+
120
+
121
class ModelRegistry:
    """Manages a set of ModelLifecycle instances with a shared reaper task."""

    def __init__(self) -> None:
        self._models: dict[str, ModelLifecycle] = {}
        # Idle TTL in seconds; a value <= 0 disables idle unloading.
        self._ttl: int = 300
        self._reaper_task: asyncio.Task | None = None

    def register(self, name: str, lifecycle: ModelLifecycle) -> None:
        """Register ``lifecycle`` under ``name``, replacing any previous entry."""
        self._models[name] = lifecycle

    def status(self) -> dict[str, bool]:
        """Return ``{name: is_loaded}`` for every registered model."""
        return {name: lc.is_loaded for name, lc in self._models.items()}

    def set_ttl(self, ttl_seconds: int) -> None:
        """Change the idle TTL used by the reaper (<= 0 disables unloading)."""
        self._ttl = ttl_seconds
        log.info("registry.ttl_updated", ttl=ttl_seconds)

    def start_reaper(self, ttl_seconds: int, sweep_interval: float = 30) -> None:
        """Set the TTL and start (or restart) the background idle reaper.

        Any previous reaper is cancelled first. Must be called with a running
        event loop, because it uses ``asyncio.create_task``.
        """
        self._ttl = ttl_seconds
        if self._reaper_task and not self._reaper_task.done():
            self._reaper_task.cancel()
        self._reaper_task = asyncio.create_task(
            self._reap_loop(sweep_interval), name="model-reaper"
        )

    def stop_reaper(self) -> None:
        """Cancel the reaper task if it is running (does not wait for it)."""
        if self._reaper_task and not self._reaper_task.done():
            self._reaper_task.cancel()
        self._reaper_task = None

    async def unload_all(self) -> None:
        """Stop the reaper and unload every loaded model (shutdown path).

        Every model is attempted even if an earlier unload fails. Failures
        are logged, and the first one is re-raised after the sweep. The loop
        runs over a snapshot, so a ``register`` call during an await cannot
        break the iteration.
        """
        self.stop_reaper()
        first_error: Exception | None = None
        for name, lc in list(self._models.items()):
            if not lc.is_loaded:
                continue
            MODEL_UNLOAD_TOTAL.labels(model=name, reason="shutdown").inc()
            try:
                await lc.unload()
            except Exception as exc:
                log.error("registry.unload_failed", model=name, exc_info=True)
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    async def _reap_loop(self, interval: float) -> None:
        """Every ``interval`` seconds, unload models idle for >= the TTL."""
        try:
            while True:
                await asyncio.sleep(interval)
                if self._ttl <= 0:
                    continue
                now = time.monotonic()
                for name, lc in list(self._models.items()):
                    try:
                        if lc.is_loaded and (now - lc.last_used) >= self._ttl:
                            MODEL_UNLOAD_TOTAL.labels(model=name, reason="idle").inc()
                            # Shield the unload: cancelling the reaper while
                            # _do_unload runs in the executor would otherwise
                            # abort before `_model` is cleared. That would
                            # leave a released model looking loaded, and a
                            # later unload would release it twice.
                            await asyncio.shield(lc.unload())
                    except Exception:
                        log.error("reaper.unload_failed", model=name, exc_info=True)
        except asyncio.CancelledError:
            pass