omkit 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,18 @@
1
+ """omkit.platform — re-exports platform primitives.
2
+
3
+ Settings + sync notification + lazy model lifecycle helpers. Additive
4
+ grouping; flat imports still work.
5
+ """
6
+
7
+ from omkit.config import BaseServiceSettings
8
+ from omkit.model_lifecycle import ModelLifecycle, ModelRegistry
9
+ from omkit.settings import SettingsManager
10
+ from omkit.sync_notifier import SyncNotifier
11
+
12
+ __all__ = [
13
+ "BaseServiceSettings",
14
+ "ModelLifecycle",
15
+ "ModelRegistry",
16
+ "SettingsManager",
17
+ "SyncNotifier",
18
+ ]
@@ -0,0 +1,11 @@
1
+ """packages/omur-sdk/omkit/providers/__init__.py — Package init for providers.
2
+
3
+ exports: none
4
+ rules: none
5
+ agent: ollama/qwen3-coder:latest | ollama | 2026-05-01 | codedna-cli | initial CodeDNA annotation pass
6
+ message:
7
+ """
8
+ from .base import ProviderBase, ProviderDocument, ProviderMetric
9
+ from .registry import ProviderRegistry
10
+
11
+ __all__ = ["ProviderBase", "ProviderDocument", "ProviderMetric", "ProviderRegistry"]
@@ -0,0 +1,76 @@
1
+ """packages/omur-sdk/omkit/providers/base.py — ProviderBase ABC and shared data contracts for all Omur providers.
2
+
3
+ exports: class ProviderDocument | class ProviderMetric | class ProviderBase
4
+ rules: All provider classes must inherit from ProviderBase and implement the run() method; ProviderMetric values must be coercible to float.
5
+ agent: ollama/qwen3-coder:latest | ollama | 2026-05-01 | codedna-cli | initial CodeDNA annotation pass
6
+ message:
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from abc import ABC, abstractmethod
12
+ from typing import Any
13
+
14
+ from pydantic import BaseModel, Field, field_validator
15
+
16
+
17
class ProviderDocument(BaseModel):
    """Document emitted by an indexer provider.

    Four string fields are required (``source``, ``source_id``, ``title``,
    ``content``); the remaining fields are optional and default to ``None``
    or, for ``meta``, a fresh empty dict per instance.
    """

    # Required fields.
    source: str
    source_id: str
    title: str
    content: str

    # Optional descriptors; both are plain strings when present.
    doc_type: str | None = None
    doc_date: str | None = None  # NOTE(review): date format is not enforced here

    # default_factory gives every instance its own dict (no shared mutable default).
    meta: dict[str, Any] = Field(default_factory=dict)
26
+
27
+
28
class ProviderMetric(BaseModel):
    """Time-series metric emitted by a collector or sensor provider."""

    source: str
    metric: str
    value: float
    unit: str
    ts: int  # nanoseconds UTC epoch
    tenant_id: str
    # default_factory gives every instance its own dict (no shared mutable default).
    meta: dict[str, Any] = Field(default_factory=dict)

    @field_validator("value", mode="before")
    @classmethod
    def coerce_value(cls, v: Any) -> float:
        """Coerce the raw ``value`` input to ``float`` before field validation.

        Args:
            v: Raw input supplied for ``value`` (number, numeric string, ...).

        Returns:
            ``float(v)``.

        Raises:
            ValueError: If ``v`` cannot be converted. ``float()`` raises
                ``TypeError`` for inputs such as ``None`` or a list; that is
                re-raised as ``ValueError`` so every bad input is reported
                through the same validation-error path instead of escaping
                as a bare ``TypeError``.
        """
        try:
            return float(v)
        except TypeError as exc:
            raise ValueError(
                f"metric value must be numeric, got {type(v).__name__}"
            ) from exc
45
+
46
+
47
class ProviderBase(ABC):
    """Base class for all Omur data providers.

    Concrete subclasses must expose class-level ``kind`` and ``name`` strings
    and implement ``run()``. ``run()`` is called once per active tenant
    instance and must handle ``asyncio.CancelledError`` for clean shutdown.

    Abstract intermediate subclasses (those that still have unimplemented
    abstract methods) are exempt from the ``kind``/``name`` check, and a
    concrete subclass may inherit either attribute from such a base.
    """

    kind: str  # 'collector' | 'indexer' | 'sensor'
    name: str  # e.g. 'fitbit', 'gdrive', 'weather_station'

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Validate ``kind`` and ``name`` on every concrete subclass.

        Raises:
            TypeError: If a concrete subclass lacks a ``str`` value for
                ``kind`` or ``name``.
        """
        super().__init_subclass__(**kwargs)
        # ABCMeta only fills in ``__abstractmethods__`` after
        # ``__init_subclass__`` has run, so it cannot be consulted here.
        # Inspect the resolved members for the abstract marker instead.
        still_abstract = any(
            getattr(getattr(cls, member, None), "__isabstractmethod__", False)
            for member in dir(cls)
        )
        if still_abstract:
            return
        for attr in ("kind", "name"):
            # getattr (not cls.__dict__) so values set on a base class count.
            if not isinstance(getattr(cls, attr, None), str):
                raise TypeError(f"{cls.__name__} must define class attribute '{attr}' as a str")

    def __init__(self, tenant_id: str, config: dict[str, Any]) -> None:
        """Store the tenant identifier and provider configuration."""
        self.tenant_id = tenant_id
        self.config = config

    @abstractmethod
    async def run(self) -> None:
        """Main loop of the provider.

        Implementations must handle ``asyncio.CancelledError`` and perform
        any cleanup in response to cancellation.
        """
        ...
@@ -0,0 +1,263 @@
1
+ """packages/omur-sdk/omkit/providers/registry.py — loads providers from DB, manages asyncio tasks, hot-reloads via Valkey.
2
+
3
+ exports: DEFAULT_PROVIDERS_POLL_INTERVAL | class ProviderRegistry
4
+ rules: The module must maintain thread safety across all asyncio task management and ensure consistent state synchronization between PostgreSQL and Valkey for tenant-provider configurations.
5
+ agent: ollama/qwen3-coder:latest | ollama | 2026-05-01 | codedna-cli | initial CodeDNA annotation pass
6
+ message:
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import asyncio
12
+ import os
13
+ from typing import Any
14
+
15
+ import structlog
16
+
17
+ from .base import ProviderBase
18
+
19
+ log = structlog.get_logger()
20
+
21
+
22
def _backend_from_env() -> str:
    """Return the providers backend named by ``PROVIDERS_BACKEND``.

    Accepts ``"postgres"`` (the default when the variable is unset) and
    ``"redis"``. Any other value is logged and replaced by ``"postgres"``
    so that a typo does not block startup.
    """
    requested = os.getenv("PROVIDERS_BACKEND", "postgres")
    if requested in ("postgres", "redis"):
        return requested
    # Invalid values fall back to postgres to avoid blocking startup on a typo.
    log.warning("registry.invalid_backend_env", value=requested)
    return "postgres"
29
+
30
+
31
+ DEFAULT_PROVIDERS_POLL_INTERVAL = 10.0
32
+
33
+
34
class ProviderRegistry:
    """Manages one asyncio Task per active (tenant_id, provider_name) pair.

    Lifecycle:
      - start(): load all enabled providers from DB, start tasks, then start
        either the Valkey subscriber (backend "redis") or the poll loop
      - stop(): cancel all tasks cleanly
      - _reload_tenant(tenant_id): re-query DB, cancel the tenant's tasks,
        start new tasks

    Valkey channel: omur:providers:updated:{tenant_id}
    On (re)connect to Valkey the full providers table is re-read to reconcile
    any events missed during the outage (reconnect-with-exponential-backoff).
    """

    def __init__(
        self,
        kind: str,
        provider_classes: dict[str, type[ProviderBase]],
        postgres_dsn: str,
        valkey_url: str,
        *,
        poll_interval: float = DEFAULT_PROVIDERS_POLL_INTERVAL,
        backend: str | None = None,
    ) -> None:
        """Store configuration; no I/O happens until ``start()``.

        Args:
            kind: Provider kind this registry is responsible for.
            provider_classes: Maps provider name to its implementation class.
            postgres_dsn: DSN used to read the ``providers`` table.
            valkey_url: URL of the Valkey/Redis server for pub/sub.
            poll_interval: Seconds between reconciliations in poll mode.
            backend: ``"redis"`` selects pub/sub; any other value selects
                polling. Falls back to ``PROVIDERS_BACKEND`` when omitted.
        """
        self.kind = kind
        self.provider_classes = provider_classes
        self._postgres_dsn = postgres_dsn
        self._valkey_url = valkey_url
        self._tasks: dict[str, asyncio.Task] = {}  # key: "{tenant_id}:{name}"
        self._valkey_task: asyncio.Task | None = None
        self._poll_task: asyncio.Task | None = None
        self._table_missing: bool = False
        self._poll_interval = poll_interval
        self._backend = backend or _backend_from_env()
        self._stop = asyncio.Event()

    # ── Public API ────────────────────────────────────────────────

    async def start(self) -> None:
        """Start a task for every enabled provider, then the update watcher.

        A failure to read the providers table is logged and treated as an
        empty list, so the watcher still starts and can reconcile later.
        """
        try:
            rows = await self._fetch_providers()
        except Exception as exc:
            log.warning("registry.db_unavailable", kind=self.kind, error=str(exc))
            rows = []
        for row in rows:
            self._start_task(row["tenant_id"], row["name"], row["config"])
        if self._backend == "redis":
            self._valkey_task = asyncio.create_task(self._subscribe_valkey())
        else:
            self._poll_task = asyncio.create_task(self._poll_loop())
        log.info(
            "registry.started", kind=self.kind, tasks=len(self._tasks),
            backend=self._backend,
        )

    async def stop(self) -> None:
        """Signal shutdown, cancel the watcher task(s), then all provider tasks."""
        self._stop.set()
        if self._valkey_task:
            self._valkey_task.cancel()
            await asyncio.gather(self._valkey_task, return_exceptions=True)
        if self._poll_task:
            self._poll_task.cancel()
            await asyncio.gather(self._poll_task, return_exceptions=True)
        await self._cancel_tasks(list(self._tasks.keys()))
        log.info("registry.stopped", kind=self.kind)

    async def _poll_loop(self) -> None:
        """Reconcile running-vs-desired tasks every poll_interval by re-reading
        the providers table. Replaces the valkey pub/sub path when backend is
        set to postgres (default)."""
        while not self._stop.is_set():
            try:
                # Doubles as an interruptible sleep: returns early on stop().
                await asyncio.wait_for(self._stop.wait(), timeout=self._poll_interval)
            except asyncio.TimeoutError:
                pass
            if self._stop.is_set():
                return
            try:
                await self._reconcile_all()
            except Exception as exc:
                log.warning("registry.poll_failed", kind=self.kind, error=str(exc))

    async def _reload_tenant(self, tenant_id: str) -> None:
        """Restart every provider task of one tenant from fresh DB rows.

        Tenant ids shorter than 36 characters are rejected with a warning.
        The DB is queried *before* the running tasks are cancelled, so a
        failed query leaves the tenant's current tasks running.
        """
        if not tenant_id or len(tenant_id) < 36:
            log.warning("registry.invalid_tenant_id", tenant_id=tenant_id, hint="skipping reload")
            return

        # Re-query first: if this raises, nothing has been torn down yet.
        rows = await self._fetch_providers(tenant_id=tenant_id)

        # Cancel all tasks for this tenant
        tenant_keys = [k for k in self._tasks if k.startswith(f"{tenant_id}:")]
        await self._cancel_tasks(tenant_keys)

        # Restart from the fresh rows
        for row in rows:
            self._start_task(row["tenant_id"], row["name"], row["config"])
        log.info("registry.tenant_reloaded", tenant_id=tenant_id, new_tasks=len(rows))

    # ── Internal ─────────────────────────────────────────────────

    def _start_task(self, tenant_id: str, name: str, config: dict[str, Any]) -> None:
        """Instantiate the provider ``name`` and schedule its ``run()``.

        Unknown provider names are logged and skipped.

        Raises:
            RuntimeError: If a task for the same tenant/provider key is
                already tracked; it must be cancelled first.
        """
        cls = self.provider_classes.get(name)
        if cls is None:
            log.warning("registry.unknown_provider", name=name, tenant_id=tenant_id)
            return
        key = f"{tenant_id}:{name}"
        # Explicit raise rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently orphan the running task.
        if key in self._tasks:
            raise RuntimeError(f"Task {key!r} already running — cancel before starting")
        instance = cls(tenant_id=tenant_id, config=config)
        task = asyncio.create_task(instance.run(), name=key)
        # Bind ``key`` as a default so the callback does not capture a later value.
        task.add_done_callback(lambda t, k=key: self._on_task_done(k, t))
        self._tasks[key] = task
        log.info("registry.task_started", key=key)

    def _on_task_done(self, key: str, task: asyncio.Task) -> None:
        """Remove crashed tasks from _tasks so reconcile can restart them.

        Cancelled tasks are left alone; tasks that finished (with or without
        an exception) are dropped from ``_tasks``.
        """
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            log.error("registry.task_crashed", key=key, error=str(exc))
        self._tasks.pop(key, None)

    async def _cancel_tasks(self, keys: list[str]) -> None:
        """Untrack the given keys and cancel their still-running tasks.

        Waits for the cancelled tasks to finish; their exceptions are
        collected rather than raised.
        """
        tasks = []
        for key in keys:
            task = self._tasks.pop(key, None)
            if task and not task.done():
                task.cancel()
                tasks.append(task)
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _fetch_providers(self, tenant_id: str | None = None) -> list[dict[str, Any]]:
        """Query DB for enabled providers of this registry's kind.

        Args:
            tenant_id: When given, restrict the query to that tenant.

        Returns:
            A list of ``{"tenant_id", "name", "config"}`` dicts. Returns []
            and logs a warning if the providers table doesn't exist yet
            (pre-migration state); the warning is emitted once per outage.
        """
        import json

        import asyncpg
        # Connect directly to postgres (bypass PgBouncer) to avoid RLS issues
        # with SET ROLE omur_app when app.tenant_id isn't set.
        dsn = self._postgres_dsn.replace("postgresql+asyncpg://", "postgresql://")
        dsn = dsn.replace("pgbouncer:6432", "postgres:5432")
        conn = await asyncpg.connect(dsn)
        try:
            if tenant_id:
                rows = await conn.fetch(
                    "SELECT tenant_id::text, name, config FROM providers "
                    "WHERE kind = $1 AND enabled = TRUE AND tenant_id = $2::uuid",
                    self.kind, tenant_id,
                )
            else:
                rows = await conn.fetch(
                    "SELECT tenant_id::text, name, config FROM providers "
                    "WHERE kind = $1 AND enabled = TRUE",
                    self.kind,
                )
            if self._table_missing:
                log.info("registry.table_available", kind=self.kind)
                self._table_missing = False
            providers: list[dict[str, Any]] = []
            for r in rows:
                config = r["config"]
                # This connection registers no JSON codec, so asyncpg hands
                # json/jsonb columns back as text. Decode it here; values that
                # are already decoded (or NULL) pass through untouched.
                if isinstance(config, (str, bytes, bytearray)):
                    config = json.loads(config)
                providers.append(
                    {"tenant_id": r["tenant_id"], "name": r["name"], "config": config}
                )
            return providers
        except asyncpg.UndefinedTableError:
            if not self._table_missing:
                log.warning("registry.table_missing", kind=self.kind,
                            hint="run the migrate container to create the providers table")
                self._table_missing = True
            return []
        finally:
            await conn.close()

    async def _subscribe_valkey(self) -> None:
        """
        Subscribe to omur:providers:updated:* and reload tenants on events.
        Implements reconnect-with-exponential-backoff.
        On reconnect, performs a full provider reconciliation to catch missed events.
        """
        import redis.asyncio as redis

        backoff = 1.0
        while True:
            client = None
            pubsub = None
            try:
                client = redis.from_url(self._valkey_url)
                pubsub = client.pubsub()
                await pubsub.psubscribe("omur:providers:updated:*")
                log.info("registry.valkey_subscribed", kind=self.kind)
                backoff = 1.0

                # On (re)connect: full reconciliation to catch missed events
                await self._reconcile_all()

                async for message in pubsub.listen():
                    if message["type"] != "pmessage":
                        continue
                    raw_channel = message["channel"]
                    # Bytes by default; already str if the client decodes responses.
                    channel: str = (
                        raw_channel.decode() if isinstance(raw_channel, bytes) else raw_channel
                    )
                    tenant_id = channel.split(":")[-1]
                    await self._reload_tenant(tenant_id)

            except asyncio.CancelledError:
                return
            except Exception as exc:
                log.warning("registry.valkey_disconnected", error=str(exc), retry_in=backoff)
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, 60.0)
            finally:
                if pubsub is not None:
                    await pubsub.aclose()
                if client is not None:
                    await client.aclose()

    async def _reconcile_all(self) -> None:
        """Re-read all enabled providers from DB and sync running tasks.

        Only the set of keys is compared: tasks whose key is no longer in the
        DB are cancelled and keys not yet tracked are started.
        """
        rows = await self._fetch_providers()
        desired = {f"{r['tenant_id']}:{r['name']}": r for r in rows}
        current = set(self._tasks.keys())
        desired_keys = set(desired.keys())

        # Cancel tasks no longer in DB
        await self._cancel_tasks(list(current - desired_keys))

        # Start new tasks not yet running
        for key in desired_keys - current:
            row = desired[key]
            self._start_task(row["tenant_id"], row["name"], row["config"])
omkit/py.typed ADDED
File without changes
omkit/quota.py ADDED
@@ -0,0 +1,186 @@
1
+ """packages/omur-sdk/omkit/quota.py — Per-tenant quota helpers (plan 1.7).
2
+
3
+ Mirrors ``packages/omur-go-sdk/quota/quota.go`` so marrow (Python) and
4
+ spine (Go) enforce identical defaults. Absence of a ``tenant_quotas`` row
5
+ means "use defaults below". Limits are integers; bytes are BIGINT.
6
+
7
+ exports: DEFAULT_DOCS | DEFAULT_STORAGE_BYTES | DEFAULT_QUERIES_PER_MONTH | class Resource | class Limits | class Usage | class Decision | load(session) | get_usage(session) | check_upload(lim, usage, incoming_bytes) | check_query(lim, usage)
8
+ rules: The module must maintain backward compatibility with existing quota enforcement logic and cannot modify the public API of `load`, `get_usage`, `check_upload`, or `check_query` functions.
9
+ agent: ollama/qwen3-coder:latest | ollama | 2026-05-01 | codedna-cli | initial CodeDNA annotation pass
10
+ message:
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import enum
16
+ from dataclasses import dataclass
17
+ from datetime import datetime, timezone
18
+
19
+ from sqlalchemy import text
20
+ from sqlalchemy.ext.asyncio import AsyncSession
21
+
22
+ DEFAULT_DOCS = 100
23
+ DEFAULT_STORAGE_BYTES = 500 * 1024 * 1024 # 500 MiB
24
+ DEFAULT_QUERIES_PER_MONTH = 1000
25
+
26
+
27
class Resource(str, enum.Enum):
    """Quota dimension named in a ``Decision``.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value.
    """

    DOCS = "docs"
    STORAGE_BYTES = "storage_bytes"
    QUERIES_PER_MONTH = "queries_per_month"
31
+
32
+
33
@dataclass(frozen=True)
class Limits:
    """Immutable set of quota ceilings for one tenant."""

    docs: int  # maximum number of documents
    storage_bytes: int  # maximum total stored bytes
    queries_per_month: int  # maximum queries per calendar month
38
+
39
+
40
@dataclass(frozen=True)
class Usage:
    """Immutable snapshot of a tenant's current consumption."""

    docs: int  # documents currently stored
    storage_bytes: int  # bytes currently stored
    queries_this_month: int  # queries counted for the current month
45
+
46
+
47
@dataclass(frozen=True)
class Decision:
    """Outcome of a quota check.

    Only ``allowed`` is required; the other fields keep their defaults unless
    the check fills them in to describe the limit that was hit.
    """

    allowed: bool
    resource: Resource | None = None  # which quota was exceeded, if any
    limit: int = 0  # configured ceiling for that resource
    used: int = 0  # usage observed at check time
    retry_after: int = 0  # seconds; 0 means "no retry will help"
54
+
55
+
56
async def load(session: AsyncSession) -> Limits:
    """Read the caller's effective limits from ``tenant_quotas``.

    The caller is expected to have already invoked ``tenant.set_rls(session)``
    so that the ``app.tenant_id`` setting is populated.

    Args:
        session: Async SQLAlchemy session used to run the lookup.

    Returns:
        The tenant's stored limits, or the module defaults when the tenant
        has no ``tenant_quotas`` row.
    """
    # ``tenant_quotas`` intentionally has no RLS policy (operators need
    # cross-tenant visibility for capacity planning), so we filter by
    # the caller's app.tenant_id GUC explicitly. Without this WHERE the
    # helper would return an arbitrary row belonging to another tenant.
    statement = text(
        "SELECT docs_limit, storage_bytes_limit, queries_per_month_limit "
        "FROM tenant_quotas "
        "WHERE tenant_id = current_setting('app.tenant_id', true)::uuid "
        "LIMIT 1"
    )
    result = await session.execute(statement)
    record = result.first()
    if record is not None:
        return Limits(
            docs=int(record[0]),
            storage_bytes=int(record[1]),
            queries_per_month=int(record[2]),
        )
    # No row for this tenant: fall back to the module-level defaults.
    return Limits(
        docs=DEFAULT_DOCS,
        storage_bytes=DEFAULT_STORAGE_BYTES,
        queries_per_month=DEFAULT_QUERIES_PER_MONTH,
    )
89
+
90
+
91
async def get_usage(session: AsyncSession) -> Usage:
    """Read docs / storage_bytes / queries-this-month for the request tenant.

    Caller must have already set RLS on the session.

    The ``usage_log`` count runs inside a savepoint. If that query fails for
    any reason (for example a service without SELECT on ``usage_log``), the
    savepoint is rolled back and ``queries_this_month`` is reported as 0, so
    the document totals are still returned. Note that this also hides the
    underlying error from the caller.

    Args:
        session: Async SQLAlchemy session used to run both queries.

    Returns:
        A ``Usage`` snapshot built from ``document_files`` and ``usage_log``.
    """
    totals_result = await session.execute(
        text(
            "SELECT COUNT(*)::int, COALESCE(SUM(size_bytes), 0)::bigint "
            "FROM document_files"
        )
    )
    totals = totals_result.one()

    month_queries = 0
    savepoint = await session.begin_nested()
    try:
        counted = await session.execute(
            text(
                "SELECT COUNT(*)::int FROM usage_log "
                "WHERE created_at >= date_trunc('month', now())"
            )
        )
        month_queries = counted.scalar() or 0
        await savepoint.commit()
    except Exception:
        # Permission denied / table missing: upload-side callers don't
        # need queries_this_month. Rolling back the savepoint keeps the
        # outer transaction usable so callers can still read `docs`.
        await savepoint.rollback()

    return Usage(
        docs=int(totals[0]),
        storage_bytes=int(totals[1]),
        queries_this_month=int(month_queries),
    )
134
+
135
+
136
def check_upload(lim: Limits, usage: Usage, incoming_bytes: int) -> Decision:
    """Decide whether one more document of ``incoming_bytes`` fits the quota.

    The document count is checked first, then storage. ``incoming_bytes`` is
    not validated; a negative value lowers the projected storage total.

    Args:
        lim: Effective limits for the tenant.
        usage: Current usage for the tenant.
        incoming_bytes: Size of the upload being considered.

    Returns:
        ``Decision(allowed=True)`` when both checks pass; otherwise a denied
        decision naming the exceeded resource, with ``retry_after=0``.
    """
    docs_after_upload = usage.docs + 1
    if docs_after_upload > lim.docs:
        return Decision(
            allowed=False,
            resource=Resource.DOCS,
            limit=lim.docs,
            used=usage.docs,
            retry_after=0,
        )

    bytes_after_upload = usage.storage_bytes + incoming_bytes
    if bytes_after_upload > lim.storage_bytes:
        return Decision(
            allowed=False,
            resource=Resource.STORAGE_BYTES,
            limit=lim.storage_bytes,
            used=usage.storage_bytes,
            retry_after=0,
        )

    return Decision(allowed=True)
157
+
158
+
159
def check_query(lim: Limits, usage: Usage) -> Decision:
    """Decide whether one more query fits the monthly quota.

    Args:
        lim: Effective limits for the tenant.
        usage: Current usage for the tenant.

    Returns:
        ``Decision(allowed=True)`` while the next query stays within
        ``lim.queries_per_month``; otherwise a denied decision whose
        ``retry_after`` is the value of ``_seconds_until_next_month()``.
    """
    if usage.queries_this_month + 1 <= lim.queries_per_month:
        return Decision(allowed=True)
    return Decision(
        allowed=False,
        resource=Resource.QUERIES_PER_MONTH,
        limit=lim.queries_per_month,
        used=usage.queries_this_month,
        retry_after=_seconds_until_next_month(),
    )
172
+
173
+
174
def _cap_at_32_days(s: int) -> int:
    """Clamp a number of seconds to at most 32 days.

    Negative input yields 60; values above 32 days yield exactly 32 days;
    anything else is returned unchanged.
    """
    ceiling = 32 * 24 * 3600
    if s < 0:
        return 60
    return min(s, ceiling)
180
+
181
+
182
def _seconds_until_next_month() -> int:
    """Return whole seconds from now (UTC) until 00:00 UTC on the 1st of next month.

    The result is passed through ``_cap_at_32_days`` before being returned.
    """
    now = datetime.now(timezone.utc)
    # divmod rolls December over: (12 -> carry 1, index 0), others (carry 0, index m).
    year_carry, month_index = divmod(now.month, 12)
    first_of_next = datetime(
        now.year + year_carry, month_index + 1, 1, 0, 0, 0, tzinfo=timezone.utc
    )
    remaining = first_of_next - now
    return _cap_at_32_days(int(remaining.total_seconds()))
omkit/resilience.py ADDED
@@ -0,0 +1,122 @@
1
+ """packages/omur-sdk/omkit/resilience.py — HTTP resilience primitives: circuit breaker + retry with exponential backoff.
2
+
3
+ exports: T | class CircuitOpen | class CircuitBreaker | resilient(breaker)
4
+ rules: The circuit breaker must maintain thread safety across all state transitions and failure tracking operations. The breaker's state must be consistent between concurrent calls and failures, with proper synchronization to prevent race conditions during state changes. All external dependencies like httpx exceptions must be handled with specific type checking to ensure transient error detection works correctly.
5
+ agent: ollama/qwen3-coder:latest | ollama | 2026-05-01 | codedna-cli | initial CodeDNA annotation pass
6
+ message:
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import functools
11
+ import time
12
+ from enum import Enum
13
+ from typing import Awaitable, Callable, TypeVar
14
+
15
+ import httpx
16
+ from tenacity import retry, retry_if_exception, stop_after_attempt, wait_exponential
17
+
18
+ T = TypeVar("T")
19
+
20
+
21
class _State(Enum):
    """Internal circuit-breaker states; the values are the public state strings."""

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"
25
+
26
+
27
class CircuitOpen(Exception):
    """Signals that a call was refused because the circuit is open."""
29
+
30
+
31
class CircuitBreaker:
    """Async circuit breaker for single-process services.

    States: CLOSED (normal) → OPEN (fail-fast) → HALF_OPEN (probe) → CLOSED
    Intended for use within a single asyncio event loop (no cross-thread use).
    """

    def __init__(self, fail_max: int = 5, reset_timeout: int = 60, name: str = "") -> None:
        """Create a closed breaker.

        Args:
            fail_max: Number of recorded failures at which the breaker opens.
            reset_timeout: Seconds the breaker stays open before a probe is allowed.
            name: Label used in the ``CircuitOpen`` message.
        """
        self.fail_max = fail_max
        self.reset_timeout = reset_timeout
        self.name = name
        self._failures = 0
        self._state = _State.CLOSED
        self._opened_at: float = 0.0
        self._open_observed: bool = False  # True after first "open" state is returned

    @property
    def state(self) -> str:
        """Current state as a string; reading it may advance OPEN → HALF_OPEN.

        The first read after the breaker opens always reports ``"open"`` and
        marks it as observed. A later read moves to ``"half_open"`` once
        ``reset_timeout`` seconds have elapsed since opening.
        """
        if self._state == _State.OPEN:
            if self._open_observed and time.monotonic() - self._opened_at >= self.reset_timeout:
                self._state = _State.HALF_OPEN
            else:
                self._open_observed = True
        return self._state.value

    def record_success(self) -> None:
        """Reset the failure count and close the breaker, whatever its state."""
        self._failures = 0
        self._state = _State.CLOSED
        self._open_observed = False

    def record_failure(self) -> None:
        """Count one failure; open the breaker once ``fail_max`` is reached.

        Opening records the current monotonic time and clears the
        "open observed" flag.
        """
        self._failures += 1
        if self._failures >= self.fail_max:
            self._state = _State.OPEN
            self._opened_at = time.monotonic()
            self._open_observed = False

    async def call(self, coro: Awaitable[T]) -> T:
        """Await ``coro`` under the breaker and record the outcome.

        Args:
            coro: The awaitable to run (typically an already-created coroutine).

        Returns:
            Whatever ``coro`` returns.

        Raises:
            CircuitOpen: If the breaker reports ``"open"``; ``coro`` is not
                awaited in that case.
            Exception: Any exception raised by ``coro`` is recorded as a
                failure and re-raised.
        """
        if self.state == "open":
            # The caller already built the coroutine object. Close it so it is
            # not left un-awaited (which emits a RuntimeWarning and keeps its
            # frame alive until garbage collection).
            import inspect

            if inspect.iscoroutine(coro):
                coro.close()
            raise CircuitOpen(
                f"Circuit '{self.name}' is open — retry after {self.reset_timeout}s"
            )
        try:
            result = await coro
            self.record_success()
            return result
        except Exception:
            self.record_failure()
            raise
92
+
93
+
94
def _is_transient(exc: BaseException) -> bool:
    """Tell whether ``exc`` is worth retrying.

    Retryable: ``httpx.ConnectError``, ``httpx.TimeoutException``, and
    ``httpx.HTTPStatusError`` whose response status is 502, 503 or 504.
    """
    retryable_transport = (httpx.ConnectError, httpx.TimeoutException)
    if isinstance(exc, retryable_transport):
        return True
    has_response = isinstance(exc, httpx.HTTPStatusError) and exc.response is not None
    return has_response and exc.response.status_code in (502, 503, 504)
101
+
102
+
103
def resilient(breaker: CircuitBreaker) -> Callable:
    """Decorator factory: retry with exponential backoff behind a circuit breaker.

    Each attempt runs the wrapped coroutine function through ``breaker.call``.
    Up to 3 attempts are made; only exceptions accepted by ``_is_transient``
    trigger another attempt, and the last exception is re-raised as-is.

    Args:
        breaker: Circuit breaker guarding every attempt.

    Returns:
        A decorator for coroutine functions.
    """

    def decorator(fn: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
        @functools.wraps(fn)
        async def guarded(*args, **kwargs):
            # A fresh coroutine is created per attempt and handed to the breaker.
            return await breaker.call(fn(*args, **kwargs))

        with_retry = retry(
            stop=stop_after_attempt(3),
            wait=wait_exponential(multiplier=1, min=1, max=8),
            retry=retry_if_exception(_is_transient),
            reraise=True,
        )
        return with_retry(guarded)

    return decorator