agentflow-client 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. agentflow/__init__.py +14 -0
  2. agentflow/_compat.py +26 -0
  3. agentflow/async_client.py +564 -0
  4. agentflow/circuit_breaker.py +65 -0
  5. agentflow/cli.py +505 -0
  6. agentflow/client.py +562 -0
  7. agentflow/exceptions.py +27 -0
  8. agentflow/models.py +250 -0
  9. agentflow/py.typed +0 -0
  10. agentflow/retry.py +41 -0
  11. agentflow/templates/basic/.env.example.tmpl +2 -0
  12. agentflow/templates/basic/README.md.tmpl +15 -0
  13. agentflow/templates/basic/main.py.tmpl +25 -0
  14. agentflow/templates/basic/requirements.txt.tmpl +1 -0
  15. agentflow/templates/crewai/.env.example.tmpl +2 -0
  16. agentflow/templates/crewai/README.md.tmpl +12 -0
  17. agentflow/templates/crewai/main.py.tmpl +28 -0
  18. agentflow/templates/crewai/requirements.txt.tmpl +4 -0
  19. agentflow/templates/langchain/.env.example.tmpl +3 -0
  20. agentflow/templates/langchain/README.md.tmpl +13 -0
  21. agentflow/templates/langchain/main.py.tmpl +28 -0
  22. agentflow/templates/langchain/requirements.txt.tmpl +4 -0
  23. agentflow/templates/vercel-ai/.env.example.tmpl +3 -0
  24. agentflow/templates/vercel-ai/README.md.tmpl +12 -0
  25. agentflow/templates/vercel-ai/app/api/chat/route.ts.tmpl +37 -0
  26. agentflow/templates/vercel-ai/app/layout.tsx.tmpl +9 -0
  27. agentflow/templates/vercel-ai/app/page.tsx.tmpl +46 -0
  28. agentflow/templates/vercel-ai/next-env.d.ts.tmpl +2 -0
  29. agentflow/templates/vercel-ai/next.config.mjs.tmpl +3 -0
  30. agentflow/templates/vercel-ai/package.json.tmpl +24 -0
  31. agentflow/templates/vercel-ai/tsconfig.json.tmpl +19 -0
  32. agentflow_client-1.1.0.dist-info/METADATA +72 -0
  33. agentflow_client-1.1.0.dist-info/RECORD +35 -0
  34. agentflow_client-1.1.0.dist-info/WHEEL +4 -0
  35. agentflow_client-1.1.0.dist-info/entry_points.txt +2 -0
agentflow/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from agentflow.async_client import AsyncAgentFlowClient
2
+ from agentflow.circuit_breaker import CircuitOpenError
3
+ from agentflow.client import AgentFlowClient
4
+ from agentflow.exceptions import PermissionDeniedError
5
+
6
# Package version string; should match the distribution version
# (agentflow_client 1.1.0 in the wheel metadata).
__version__ = "1.1.0"

# Names exported by ``from agentflow import *`` — the public API surface
# re-exported from the submodules imported above.
__all__ = [
    "AgentFlowClient",
    "AsyncAgentFlowClient",
    "PermissionDeniedError",
    "CircuitOpenError",
    "__version__",
]
agentflow/_compat.py ADDED
@@ -0,0 +1,26 @@
1
+ import warnings
2
+ from collections.abc import Callable
3
+ from functools import wraps
4
+ from typing import Any
5
+
6
+
7
def deprecated(
    replacement: str,
    removed_in: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Return a decorator that marks a callable as deprecated.

    On every call, the decorated callable issues a ``DeprecationWarning``.
    The warning names ``replacement`` and the ``removed_in`` release and is
    attributed to the caller's line (``stacklevel=2``). The original callable
    is then invoked with the same arguments.

    Args:
        replacement: Name of the API users should switch to.
        removed_in: Version in which the deprecated callable will be removed.

    Returns:
        A decorator that wraps its target with ``functools.wraps``.
    """

    def mark(target: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(target)
        def shim(*args: Any, **kwargs: Any) -> Any:
            # Built per call (not at decoration time) so the name lookup
            # happens exactly when the original implementation did it.
            notice = (
                f"{target.__name__} is deprecated and will be removed in "
                f"{removed_in}. Use {replacement} instead."
            )
            warnings.warn(notice, DeprecationWarning, stacklevel=2)
            return target(*args, **kwargs)

        return shim

    return mark
agentflow/async_client.py ADDED
@@ -0,0 +1,564 @@
1
+ import asyncio
2
+ from datetime import UTC, datetime
3
+ from typing import Any, AsyncIterator, TypeVar, cast
4
+ from uuid import uuid4
5
+
6
+ import httpx
7
+ from pydantic import BaseModel
8
+
9
+ from agentflow.circuit_breaker import CircuitBreaker
10
+ from agentflow.exceptions import (
11
+ AgentFlowError,
12
+ AuthError,
13
+ DataFreshnessError,
14
+ EntityNotFoundError,
15
+ PermissionDeniedError,
16
+ RateLimitError,
17
+ )
18
+ from agentflow.models import (
19
+ CatalogResponse,
20
+ Changelog,
21
+ ContractDiff,
22
+ ContractSummary,
23
+ ContractValidation,
24
+ EntityEnvelope,
25
+ EntityContract,
26
+ HealthStatus,
27
+ Lineage,
28
+ MetricResult,
29
+ OrderEntity,
30
+ ProductEntity,
31
+ QueryExplanation,
32
+ QueryResult,
33
+ SearchResults,
34
+ SessionEntity,
35
+ UserEntity,
36
+ )
37
+ from agentflow.retry import RETRYABLE_STATUS, RetryPolicy, is_retryable_method
38
+
39
# Pydantic model type produced by the typed entity helpers (``_get_entity``
# and the get_order/get_user/... wrappers built on it).
EntityModelT = TypeVar("EntityModelT", bound=BaseModel)
40
+
41
+
42
+ def _normalize_as_of(as_of: datetime | str | None) -> str | None:
43
+ if as_of is None:
44
+ return None
45
+ if isinstance(as_of, datetime):
46
+ value = as_of
47
+ else:
48
+ try:
49
+ value = datetime.fromisoformat(as_of.replace("Z", "+00:00"))
50
+ except ValueError:
51
+ return as_of
52
+ if value.tzinfo is None:
53
+ value = value.replace(tzinfo=UTC)
54
+ return value.astimezone(UTC).replace(microsecond=0).isoformat().replace("+00:00", "Z")
55
+
56
+
57
+ class _LegacyResilienceInitCompat(type):
58
+ def __call__(cls, *args: Any, **kwargs: Any) -> Any:
59
+ retry_policy = kwargs.pop("retry_policy", None)
60
+ circuit_breaker = kwargs.pop("circuit_breaker", None)
61
+ client = super().__call__(*args, **kwargs)
62
+ if retry_policy is not None or circuit_breaker is not None:
63
+ client.configure_resilience(
64
+ retry_policy=retry_policy,
65
+ circuit_breaker=circuit_breaker,
66
+ )
67
+ return client
68
+
69
+
70
class AsyncAgentFlowClient(metaclass=_LegacyResilienceInitCompat):
    """Asynchronous client for the AgentFlow HTTP API.

    Wraps an ``httpx.AsyncClient`` and adds:

    - retries driven by ``self.retry_policy``;
    - a ``self.circuit_breaker``;
    - mapping of HTTP error statuses to the package's exception types;
    - recording of the AgentFlow version headers of the most recent response.

    Use it as an async context manager, or call :meth:`aclose` when finished.
    The legacy ``retry_policy=`` / ``circuit_breaker=`` constructor keywords
    are still accepted via the metaclass.
    """

    def __init__(
        self,
        base_url: str,
        api_key: str,
        timeout: float = 10.0,
        contract_version: str | None = None,
        api_version: str | None = None,
    ):
        """Create the client.

        Args:
            base_url: API root URL; a trailing ``/`` is stripped.
            api_key: Sent on every request as the ``X-API-Key`` header.
            timeout: Timeout passed to ``httpx.AsyncClient``.
            contract_version: Optional ``"<entity>:<version>"`` pin. Entity
                payloads of that type are checked and filtered against the
                pinned contract. A leading ``v`` on the version is dropped.
            api_version: Optional value for the ``X-AgentFlow-Version``
                request header.

        Raises:
            ValueError: If ``contract_version`` is not ``<entity>:<version>``.
        """
        # Validate first so a bad argument does not leave an unclosed
        # httpx.AsyncClient behind.
        self._contract_versions = self._parse_contract_versions(contract_version)
        headers = {"X-API-Key": api_key}
        if api_version is not None:
            headers["X-AgentFlow-Version"] = api_version
        self._http = httpx.AsyncClient(
            base_url=base_url.rstrip("/"),
            timeout=timeout,
            headers=headers,
        )
        self._contract_cache: dict[tuple[str, str], dict[str, Any]] = {}
        self._last_server_version: str | None = None
        self._last_latest_version: str | None = None
        self._last_deprecated: str | None = None
        self._last_deprecation_warning: str | None = None
        self.retry_policy = RetryPolicy()
        self.circuit_breaker = CircuitBreaker()

    @property
    def last_server_version(self) -> str | None:
        """``X-AgentFlow-Version`` of the last response, or ``None``."""
        return self._last_server_version

    @property
    def last_deprecation_warning(self) -> str | None:
        """``X-AgentFlow-Deprecation-Warning`` of the last response, or ``None``."""
        return self._last_deprecation_warning

    def configure_resilience(
        self,
        retry_policy: RetryPolicy | None = None,
        circuit_breaker: CircuitBreaker | None = None,
    ) -> "AsyncAgentFlowClient":
        """Replace the retry policy and/or circuit breaker.

        Arguments left as ``None`` keep the current object. Returns ``self``
        so calls can be chained.
        """
        if retry_policy is not None:
            self.retry_policy = retry_policy
        if circuit_breaker is not None:
            self.circuit_breaker = circuit_breaker
        return self

    async def _request(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: dict[str, Any] | None = None,
        headers: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """Perform one API call (with retries) and return the JSON object body.

        Raises:
            CircuitOpenError: When the circuit breaker rejects the call.
            AuthError: On HTTP 401.
            PermissionDeniedError: On HTTP 403.
            RateLimitError: On HTTP 429 (with ``retry_after`` in whole seconds).
            EntityNotFoundError: On HTTP 404 for ``/v1/entity/<type>/<id>``.
            AgentFlowError: On any other 4xx/5xx, on transport or protocol
                failures, or when a successful response is not a JSON object.
        """
        response = await self._send_with_retries(
            method,
            path,
            params=params,
            json=json,
            headers=headers,
        )

        self._record_version_headers(response.headers)

        # Only server-side errors count against the breaker; a 4xx proves the
        # service is reachable and answering.
        if response.status_code >= 500:
            self.circuit_breaker.record_failure()
        else:
            self.circuit_breaker.record_success()

        decoded = self._decode_payload(response)
        status = response.status_code

        if status >= 400:
            # Error bodies are not guaranteed to be JSON (e.g. an HTML page
            # from a proxy); fall back to default messages instead of
            # crashing while decoding.
            payload = decoded or {}

            if status == 401:
                raise AuthError(payload.get("detail", "Unauthorized"))

            if status == 403:
                raise PermissionDeniedError(payload.get("detail", "Forbidden"))

            if status == 429:
                detail = payload.get("detail", "Rate limit exceeded")
                retry_after = self._parse_retry_after(
                    response.headers.get("Retry-After")
                )
                raise RateLimitError(
                    detail,
                    retry_after=int(retry_after) if retry_after is not None else 0,
                )

            if status == 404:
                detail = payload.get("detail", "Resource not found")
                parts = path.strip("/").split("/")
                if len(parts) >= 4 and parts[1] == "entity":
                    raise EntityNotFoundError(parts[2], parts[3], detail)
                raise AgentFlowError(detail)

            raise AgentFlowError(payload.get("detail", response.text))

        if decoded is None:
            raise AgentFlowError(
                f"Unexpected non-JSON-object response from {path} (HTTP {status})"
            )
        return decoded

    async def _send_with_retries(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None,
        json: dict[str, Any] | None,
        headers: dict[str, str] | None,
    ) -> httpx.Response:
        """Send the request, retrying per ``self.retry_policy``; return the final response.

        Retries happen only when ``is_retryable_method`` allows it. They cover
        transport errors and statuses in ``RETRYABLE_STATUS``, for up to
        ``retry_policy.max_attempts`` attempts in total.

        Raises:
            CircuitOpenError: From ``circuit_breaker.before_call``.
            AgentFlowError: When the request cannot be completed. The breaker
                records a failure first, so a half-open probe slot is always
                released.
        """
        attempt = 0
        can_retry = is_retryable_method(method, headers=headers)
        self.circuit_breaker.before_call()
        while True:
            try:
                response = await self._http.request(
                    method,
                    path,
                    params=params,
                    json=json,
                    headers=headers,
                )
            except httpx.TransportError as exc:
                if can_retry and attempt < self.retry_policy.max_attempts - 1:
                    delay = self.retry_policy.compute_delay(attempt)
                    attempt += 1
                    await asyncio.sleep(delay)
                    continue
                self.circuit_breaker.record_failure()
                raise AgentFlowError(f"Request failed: {exc}") from exc
            except httpx.HTTPError as exc:
                # Non-transport protocol errors (e.g. body decoding, redirect
                # loops). These are not retried. Without this record_failure
                # they would leave a half-open breaker waiting on a probe
                # that never reports back.
                self.circuit_breaker.record_failure()
                raise AgentFlowError(f"Request failed: {exc}") from exc

            if (
                can_retry
                and response.status_code in RETRYABLE_STATUS
                and attempt < self.retry_policy.max_attempts - 1
            ):
                retry_after = self._parse_retry_after(
                    response.headers.get("Retry-After")
                )
                delay = self.retry_policy.compute_delay(attempt, retry_after)
                attempt += 1
                await asyncio.sleep(delay)
                continue
            return response

    @staticmethod
    def _parse_retry_after(value: str | None) -> float | None:
        """Convert a ``Retry-After`` header value to a delay in seconds.

        Both forms allowed by RFC 9110 are accepted: a number of seconds or an
        HTTP-date (a date in the past yields ``0.0``). Returns ``None`` when
        the header is absent or unusable (unparseable, negative, NaN or
        infinite).
        """
        if value is None:
            return None
        try:
            seconds = float(value)
        except ValueError:
            # Not numeric: try the HTTP-date form.
            from email.utils import parsedate_to_datetime

            try:
                when = parsedate_to_datetime(value)
            except (TypeError, ValueError, IndexError):
                return None
            if when.tzinfo is None:
                when = when.replace(tzinfo=UTC)
            return max((when - datetime.now(UTC)).total_seconds(), 0.0)
        # The chained comparison rejects negatives, NaN and +inf in one test.
        if not 0.0 <= seconds < float("inf"):
            return None
        return seconds

    @staticmethod
    def _decode_payload(response: httpx.Response) -> dict[str, Any] | None:
        """Return the response body if it is a JSON object, otherwise ``None``."""
        try:
            body = response.json()
        except ValueError:  # JSONDecodeError and UnicodeDecodeError subclass it
            return None
        return cast(dict[str, Any], body) if isinstance(body, dict) else None

    def _record_version_headers(self, headers: httpx.Headers) -> None:
        """Store the AgentFlow version headers of the latest response.

        Each attribute becomes ``None`` when its header is absent.
        """
        self._last_server_version = headers.get("X-AgentFlow-Version")
        self._last_latest_version = headers.get("X-AgentFlow-Latest-Version")
        self._last_deprecated = headers.get("X-AgentFlow-Deprecated")
        self._last_deprecation_warning = headers.get(
            "X-AgentFlow-Deprecation-Warning"
        )

    async def _get_entity(
        self,
        entity_type: str,
        entity_id: str,
        model: type[EntityModelT],
        as_of: datetime | str | None = None,
    ) -> EntityModelT:
        """Fetch an entity and validate its ``data`` into ``model``."""
        envelope = await self.get_entity(entity_type, entity_id, as_of=as_of)
        return cast(EntityModelT, model.model_validate(envelope.data))

    def _parse_contract_versions(
        self,
        contract_version: str | None,
    ) -> dict[str, str]:
        """Parse ``"<entity>:<version>"`` into ``{entity: version}``.

        A leading ``v`` is stripped from the version. ``None`` yields ``{}``.

        Raises:
            ValueError: If either side of the ``:`` is missing.
        """
        if contract_version is None:
            return {}
        entity, separator, version = contract_version.partition(":")
        if not separator or not entity or not version:
            raise ValueError(
                "contract_version must use '<entity>:<version>' format."
            )
        return {entity: version[1:] if version.startswith("v") else version}

    async def _apply_contract_version(
        self,
        entity_type: str,
        payload: dict[str, Any],
    ) -> dict[str, Any]:
        """Enforce the pinned contract (if any) for ``entity_type`` on ``payload``.

        Returns ``payload`` unchanged when no contract is pinned. Otherwise
        returns only the fields the contract declares.

        Raises:
            AgentFlowError: If a field marked ``required`` is missing.
        """
        version = self._contract_versions.get(entity_type)
        if version is None:
            return payload
        contract = await self._get_contract(entity_type, version)
        fields = contract.get("fields", [])
        required_fields = [
            field["name"]
            for field in fields
            if field.get("required")
        ]
        missing_fields = [
            field_name
            for field_name in required_fields
            if field_name not in payload
        ]
        if missing_fields:
            raise AgentFlowError(
                "Contract validation failed. Missing required fields: "
                + ", ".join(missing_fields)
            )
        allowed_fields = {field["name"] for field in fields}
        return {
            name: value
            for name, value in payload.items()
            if name in allowed_fields
        }

    async def _get_contract(self, entity_type: str, version: str) -> dict[str, Any]:
        """Return the raw contract document, caching it per (entity, version)."""
        cache_key = (entity_type, version)
        cached = self._contract_cache.get(cache_key)
        if cached is not None:
            return cached
        contract = await self._request("GET", f"/v1/contracts/{entity_type}/{version}")
        self._contract_cache[cache_key] = contract
        return contract

    async def get_entity(
        self,
        entity_type: str,
        entity_id: str,
        *,
        as_of: datetime | str | None = None,
    ) -> EntityEnvelope:
        """Fetch an entity envelope, optionally as of a point in time.

        The envelope's ``data`` is checked and filtered against the pinned
        contract, if one was configured for ``entity_type``.
        """
        params: dict[str, Any] | None = None
        normalized_as_of = _normalize_as_of(as_of)
        if normalized_as_of is not None:
            params = {"as_of": normalized_as_of}
        payload = await self._request(
            "GET",
            f"/v1/entity/{entity_type}/{entity_id}",
            params=params,
        )
        envelope = EntityEnvelope.model_validate(payload)
        data = await self._apply_contract_version(entity_type, envelope.data)
        return envelope.model_copy(update={"data": data})

    async def get_order(
        self,
        order_id: str,
        *,
        as_of: datetime | str | None = None,
    ) -> OrderEntity:
        """Fetch an ``order`` entity as an ``OrderEntity``."""
        return await self._get_entity("order", order_id, OrderEntity, as_of=as_of)

    async def get_user(
        self,
        user_id: str,
        *,
        as_of: datetime | str | None = None,
    ) -> UserEntity:
        """Fetch a ``user`` entity as a ``UserEntity``."""
        return await self._get_entity("user", user_id, UserEntity, as_of=as_of)

    async def get_product(
        self,
        product_id: str,
        *,
        as_of: datetime | str | None = None,
    ) -> ProductEntity:
        """Fetch a ``product`` entity as a ``ProductEntity``."""
        return await self._get_entity("product", product_id, ProductEntity, as_of=as_of)

    async def get_session(
        self,
        session_id: str,
        *,
        as_of: datetime | str | None = None,
    ) -> SessionEntity:
        """Fetch a ``session`` entity as a ``SessionEntity``."""
        return await self._get_entity("session", session_id, SessionEntity, as_of=as_of)

    async def get_metric(
        self,
        name: str,
        window: str = "1h",
        *,
        as_of: datetime | str | None = None,
    ) -> MetricResult:
        """Fetch metric ``name`` over ``window``, optionally as of a point in time."""
        params: dict[str, Any] = {"window": window}
        normalized_as_of = _normalize_as_of(as_of)
        if normalized_as_of is not None:
            params["as_of"] = normalized_as_of
        payload = await self._request("GET", f"/v1/metrics/{name}", params=params)
        return MetricResult.model_validate(payload)

    def _normalize_query_payload(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Reshape a raw ``/v1/query`` response into ``QueryResult`` fields.

        ``answer`` falls back to ``rows``. Top-level pagination keys are
        folded into ``metadata``.
        """
        # `or {}` also covers an explicit `"metadata": null` in the response.
        metadata = dict(payload.get("metadata") or {})
        for key in ("total_count", "next_cursor", "has_more", "page_size"):
            if key in payload:
                metadata[key] = payload[key]
        return {
            "answer": payload.get("answer", payload.get("rows", [])),
            "sql": payload.get("sql"),
            "metadata": metadata,
        }

    async def _query_page(
        self,
        question: str,
        *,
        limit: int | None = None,
        cursor: str | None = None,
        idempotency_key: str | None = None,
    ) -> dict[str, Any]:
        """POST one ``/v1/query`` page and return the raw response payload."""
        payload: dict[str, Any] = {"question": question}
        if limit is not None:
            payload["limit"] = limit
        if cursor is not None:
            payload["cursor"] = cursor
        headers = (
            {"Idempotency-Key": idempotency_key}
            if idempotency_key is not None
            else None
        )
        return await self._request("POST", "/v1/query", json=payload, headers=headers)

    async def query(
        self,
        question: str,
        limit: int | None = None,
        cursor: str | None = None,
        idempotency_key: str | None = None,
    ) -> QueryResult:
        """Run a natural-language query and return one page as a ``QueryResult``."""
        payload = await self._query_page(
            question,
            limit=limit,
            cursor=cursor,
            idempotency_key=idempotency_key,
        )
        return QueryResult.model_validate(self._normalize_query_payload(payload))

    async def explain_query(
        self,
        question: str,
        contract_version: str | None = None,
    ) -> QueryExplanation:
        """Ask the server how it would answer ``question``, without running it."""
        payload: dict[str, Any] = {"question": question}
        if contract_version is not None:
            payload["contract_version"] = contract_version
        response = await self._request("POST", "/v1/query/explain", json=payload)
        return QueryExplanation.model_validate(response)

    async def search(
        self,
        query: str,
        *,
        limit: int = 10,
        entity_types: list[str] | None = None,
    ) -> SearchResults:
        """Search for ``query``, optionally restricted to ``entity_types``."""
        params: dict[str, Any] = {"q": query, "limit": limit}
        if entity_types is not None:
            params["entity_types"] = entity_types
        payload = await self._request("GET", "/v1/search", params=params)
        return SearchResults.model_validate(payload)

    async def list_contracts(self) -> list[ContractSummary]:
        """Return summaries of all contracts listed under ``contracts``."""
        payload = await self._request("GET", "/v1/contracts")
        return [
            ContractSummary.model_validate(contract)
            for contract in payload.get("contracts", [])
        ]

    async def get_contract(
        self,
        entity: str,
        version: str | None = None,
    ) -> EntityContract:
        """Fetch the contract for ``entity`` (a specific ``version`` if given)."""
        path = f"/v1/contracts/{entity}"
        if version is not None:
            path = f"{path}/{version}"
        payload = await self._request("GET", path)
        return EntityContract.model_validate(payload)

    async def diff_contracts(
        self,
        entity: str,
        from_version: str,
        to_version: str,
    ) -> ContractDiff:
        """Fetch the diff between two contract versions of ``entity``."""
        payload = await self._request(
            "GET",
            f"/v1/contracts/{entity}/diff/{from_version}/{to_version}",
        )
        return ContractDiff.model_validate(payload)

    async def validate_contract(
        self,
        entity: str,
        payload: dict[str, Any],
        *,
        idempotency_key: str | None = None,
    ) -> ContractValidation:
        """Validate ``payload`` against ``entity``'s contract on the server."""
        headers = (
            {"Idempotency-Key": idempotency_key}
            if idempotency_key is not None
            else None
        )
        response = await self._request(
            "POST",
            f"/v1/contracts/{entity}/validate",
            json=payload,
            headers=headers,
        )
        return ContractValidation.model_validate(response)

    async def get_lineage(self, entity_type: str, entity_id: str) -> Lineage:
        """Fetch lineage information for one entity."""
        payload = await self._request("GET", f"/v1/lineage/{entity_type}/{entity_id}")
        return Lineage.model_validate(payload)

    async def get_changelog(self) -> Changelog:
        """Fetch the API changelog."""
        payload = await self._request("GET", "/v1/changelog")
        return Changelog.model_validate(payload)

    async def paginate(
        self,
        question: str,
        page_size: int = 100,
    ) -> AsyncIterator[list[dict[str, Any]]]:
        """Yield successive pages of rows for ``question``.

        Follows ``next_cursor`` while the response reports a truthy
        ``has_more``, and stops when either is missing.
        """
        cursor: str | None = None
        while True:
            payload = await self._query_page(question, limit=page_size, cursor=cursor)
            rows = cast(list[dict[str, Any]], payload.get("rows", payload.get("answer", [])))
            yield rows
            if not payload.get("has_more"):
                break
            cursor = cast(str | None, payload.get("next_cursor"))
            if cursor is None:
                break

    async def health(self) -> HealthStatus:
        """Fetch the service health status."""
        payload = await self._request("GET", "/v1/health")
        return HealthStatus.model_validate(payload)

    async def is_fresh(self, max_age_seconds: int = 60) -> bool:
        """Return whether reported data freshness is under ``max_age_seconds``.

        Raises:
            DataFreshnessError: If status is not ``"healthy"`` or the
                freshness metric is missing.
        """
        health = await self.health()
        if health.status != "healthy":
            raise DataFreshnessError(
                f"Pipeline is {health.status}; freshness check cannot be trusted"
            )
        if health.freshness_seconds is None:
            raise DataFreshnessError("Pipeline freshness metric is unavailable")
        return health.freshness_seconds < max_age_seconds

    async def catalog(self) -> CatalogResponse:
        """Fetch the data catalog."""
        payload = await self._request("GET", "/v1/catalog")
        return CatalogResponse.model_validate(payload)

    async def batch(
        self,
        requests: list[dict[str, Any]],
        *,
        idempotency_key: str | None = None,
    ) -> dict[str, Any]:
        """POST several sub-requests (see ``batch_*`` builders) in one call."""
        headers = (
            {"Idempotency-Key": idempotency_key}
            if idempotency_key is not None
            else None
        )
        return await self._request(
            "POST",
            "/v1/batch",
            json={"requests": requests},
            headers=headers,
        )

    def batch_entity(
        self,
        entity_type: str,
        entity_id: str,
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Build an ``entity`` sub-request for :meth:`batch` (random id if none given)."""
        return {
            "id": request_id or f"entity-{uuid4().hex[:8]}",
            "type": "entity",
            "params": {
                "entity_type": entity_type,
                "entity_id": entity_id,
            },
        }

    def batch_metric(
        self,
        name: str,
        window: str = "1h",
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Build a ``metric`` sub-request for :meth:`batch` (random id if none given)."""
        return {
            "id": request_id or f"metric-{uuid4().hex[:8]}",
            "type": "metric",
            "params": {
                "name": name,
                "window": window,
            },
        }

    def batch_query(
        self,
        question: str,
        context: dict[str, Any] | None = None,
        request_id: str | None = None,
    ) -> dict[str, Any]:
        """Build a ``query`` sub-request for :meth:`batch` (random id if none given)."""
        params: dict[str, Any] = {"question": question}
        if context is not None:
            params["context"] = context
        return {
            "id": request_id or f"query-{uuid4().hex[:8]}",
            "type": "query",
            "params": params,
        }

    async def aclose(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._http.aclose()

    async def __aenter__(self) -> "AsyncAgentFlowClient":
        return self

    async def __aexit__(self, *args: object) -> None:
        await self.aclose()
agentflow/circuit_breaker.py ADDED
@@ -0,0 +1,65 @@
1
+ import time
2
+ from dataclasses import dataclass, field
3
+ from enum import Enum
4
+ from threading import Lock
5
+
6
+ from agentflow.exceptions import AgentFlowError
7
+
8
+
9
class CircuitState(Enum):
    """States of a CircuitBreaker; each value is the lowercase state name."""

    CLOSED = "closed"  # calls pass through; failures are counted
    OPEN = "open"  # calls are rejected until the reset timeout elapses
    HALF_OPEN = "half_open"  # a limited number of probe calls are let through
13
+
14
+
15
class CircuitOpenError(AgentFlowError):
    """Raised by ``CircuitBreaker.before_call`` when the circuit rejects a call.

    This happens while the circuit is open, or while it is half-open and its
    probe quota is already used.
    """

    pass
17
+
18
+
19
@dataclass
class CircuitBreaker:
    """Three-state circuit breaker (closed → open → half-open → closed).

    - ``failure_threshold`` consecutive failures (a success resets the count)
      open the circuit.
    - While open, ``before_call`` raises ``CircuitOpenError`` until
      ``reset_timeout_s`` seconds (``time.monotonic``) have passed.
    - The circuit then turns half-open and admits at most
      ``half_open_max_calls`` probe calls.
    - A recorded success closes the circuit; a recorded failure while
      half-open reopens it with a fresh timeout.

    State changes are made under a ``threading.Lock``.
    """

    failure_threshold: int = 5
    reset_timeout_s: float = 30.0
    half_open_max_calls: int = 1

    _state: CircuitState = field(default=CircuitState.CLOSED, init=False)
    _failure_count: int = field(default=0, init=False)
    # time.monotonic() timestamp of the most recent transition to OPEN.
    _opened_at: float = field(default=0.0, init=False)
    _half_open_calls: int = field(default=0, init=False)
    # Lock objects have a noisy, address-based repr; keep it out of __repr__.
    _lock: Lock = field(default_factory=Lock, init=False, repr=False)

    def before_call(self) -> None:
        """Admit or reject a call.

        Call this before every request. An open circuit whose timeout has
        elapsed moves to half-open here.

        Raises:
            CircuitOpenError: If the circuit is open, or half-open with its
                probe quota already used.
        """
        with self._lock:
            if self._state == CircuitState.OPEN:
                if time.monotonic() - self._opened_at >= self.reset_timeout_s:
                    self._state = CircuitState.HALF_OPEN
                    self._half_open_calls = 0
                else:
                    raise CircuitOpenError("circuit is open")
            if self._state == CircuitState.HALF_OPEN:
                if self._half_open_calls >= self.half_open_max_calls:
                    raise CircuitOpenError("circuit is half-open, probe in flight")
                self._half_open_calls += 1

    def record_success(self) -> None:
        """Close the circuit and reset all counters."""
        with self._lock:
            self._state = CircuitState.CLOSED
            self._failure_count = 0
            self._half_open_calls = 0

    def record_failure(self) -> None:
        """Record a failed call.

        The circuit opens when the threshold is reached, or immediately if
        it is half-open.
        """
        with self._lock:
            if self._state == CircuitState.OPEN:
                # A straggler admitted before the circuit opened must not
                # push the reset deadline further into the future.
                return
            if self._state == CircuitState.HALF_OPEN:
                self._trip()
                return
            self._failure_count += 1
            if self._failure_count >= self.failure_threshold:
                self._trip()

    def _trip(self) -> None:
        """Move to OPEN starting now. Caller must hold ``self._lock``."""
        self._state = CircuitState.OPEN
        self._opened_at = time.monotonic()
        self._half_open_calls = 0

    @property
    def state(self) -> CircuitState:
        """Last recorded state.

        An OPEN circuit is only re-evaluated against its timeout in
        ``before_call``, so this may still report OPEN after the timeout.
        """
        return self._state