ai-lib-python 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_lib_python/__init__.py +43 -0
- ai_lib_python/batch/__init__.py +15 -0
- ai_lib_python/batch/collector.py +244 -0
- ai_lib_python/batch/executor.py +224 -0
- ai_lib_python/cache/__init__.py +26 -0
- ai_lib_python/cache/backends.py +380 -0
- ai_lib_python/cache/key.py +237 -0
- ai_lib_python/cache/manager.py +332 -0
- ai_lib_python/client/__init__.py +37 -0
- ai_lib_python/client/builder.py +528 -0
- ai_lib_python/client/cancel.py +368 -0
- ai_lib_python/client/core.py +433 -0
- ai_lib_python/client/response.py +134 -0
- ai_lib_python/embeddings/__init__.py +36 -0
- ai_lib_python/embeddings/client.py +339 -0
- ai_lib_python/embeddings/types.py +234 -0
- ai_lib_python/embeddings/vectors.py +246 -0
- ai_lib_python/errors/__init__.py +41 -0
- ai_lib_python/errors/base.py +316 -0
- ai_lib_python/errors/classification.py +210 -0
- ai_lib_python/guardrails/__init__.py +35 -0
- ai_lib_python/guardrails/base.py +336 -0
- ai_lib_python/guardrails/filters.py +583 -0
- ai_lib_python/guardrails/validators.py +475 -0
- ai_lib_python/pipeline/__init__.py +55 -0
- ai_lib_python/pipeline/accumulate.py +248 -0
- ai_lib_python/pipeline/base.py +240 -0
- ai_lib_python/pipeline/decode.py +281 -0
- ai_lib_python/pipeline/event_map.py +506 -0
- ai_lib_python/pipeline/fan_out.py +284 -0
- ai_lib_python/pipeline/select.py +297 -0
- ai_lib_python/plugins/__init__.py +32 -0
- ai_lib_python/plugins/base.py +294 -0
- ai_lib_python/plugins/hooks.py +296 -0
- ai_lib_python/plugins/middleware.py +285 -0
- ai_lib_python/plugins/registry.py +294 -0
- ai_lib_python/protocol/__init__.py +71 -0
- ai_lib_python/protocol/loader.py +317 -0
- ai_lib_python/protocol/manifest.py +385 -0
- ai_lib_python/protocol/validator.py +460 -0
- ai_lib_python/py.typed +1 -0
- ai_lib_python/resilience/__init__.py +102 -0
- ai_lib_python/resilience/backpressure.py +225 -0
- ai_lib_python/resilience/circuit_breaker.py +318 -0
- ai_lib_python/resilience/executor.py +343 -0
- ai_lib_python/resilience/fallback.py +341 -0
- ai_lib_python/resilience/preflight.py +413 -0
- ai_lib_python/resilience/rate_limiter.py +291 -0
- ai_lib_python/resilience/retry.py +299 -0
- ai_lib_python/resilience/signals.py +283 -0
- ai_lib_python/routing/__init__.py +118 -0
- ai_lib_python/routing/manager.py +593 -0
- ai_lib_python/routing/strategy.py +345 -0
- ai_lib_python/routing/types.py +397 -0
- ai_lib_python/structured/__init__.py +33 -0
- ai_lib_python/structured/json_mode.py +281 -0
- ai_lib_python/structured/schema.py +316 -0
- ai_lib_python/structured/validator.py +334 -0
- ai_lib_python/telemetry/__init__.py +127 -0
- ai_lib_python/telemetry/exporters/__init__.py +9 -0
- ai_lib_python/telemetry/exporters/prometheus.py +111 -0
- ai_lib_python/telemetry/feedback.py +446 -0
- ai_lib_python/telemetry/health.py +409 -0
- ai_lib_python/telemetry/logger.py +389 -0
- ai_lib_python/telemetry/metrics.py +496 -0
- ai_lib_python/telemetry/tracer.py +473 -0
- ai_lib_python/tokens/__init__.py +25 -0
- ai_lib_python/tokens/counter.py +282 -0
- ai_lib_python/tokens/estimator.py +286 -0
- ai_lib_python/transport/__init__.py +34 -0
- ai_lib_python/transport/auth.py +141 -0
- ai_lib_python/transport/http.py +364 -0
- ai_lib_python/transport/pool.py +425 -0
- ai_lib_python/types/__init__.py +41 -0
- ai_lib_python/types/events.py +343 -0
- ai_lib_python/types/message.py +332 -0
- ai_lib_python/types/tool.py +191 -0
- ai_lib_python/utils/__init__.py +21 -0
- ai_lib_python/utils/tool_call_assembler.py +317 -0
- ai_lib_python-0.5.0.dist-info/METADATA +837 -0
- ai_lib_python-0.5.0.dist-info/RECORD +84 -0
- ai_lib_python-0.5.0.dist-info/WHEEL +4 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-APACHE +201 -0
- ai_lib_python-0.5.0.dist-info/licenses/LICENSE-MIT +21 -0
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Resilient executor combining all resilience patterns.
|
|
3
|
+
|
|
4
|
+
Provides a unified interface for executing operations with
|
|
5
|
+
retry, rate limiting, circuit breaking, and backpressure.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from typing import TYPE_CHECKING, Any, TypeVar
|
|
12
|
+
|
|
13
|
+
from ai_lib_python.resilience.backpressure import Backpressure, BackpressureConfig
|
|
14
|
+
from ai_lib_python.resilience.circuit_breaker import (
|
|
15
|
+
CircuitBreaker,
|
|
16
|
+
CircuitBreakerConfig,
|
|
17
|
+
)
|
|
18
|
+
from ai_lib_python.resilience.rate_limiter import RateLimiter, RateLimiterConfig
|
|
19
|
+
from ai_lib_python.resilience.retry import RetryConfig, RetryPolicy, RetryResult
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from collections.abc import Awaitable, Callable
|
|
23
|
+
|
|
24
|
+
T = TypeVar("T")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class ResilientConfig:
    """Bundle of optional settings, one per resilience layer.

    A field left as ``None`` means the matching layer is not built when the
    config is passed to ``ResilientExecutor``.

    Attributes:
        retry: Retry policy settings, or ``None`` to disable retries
        rate_limit: Rate limiter settings, or ``None`` to disable it
        circuit_breaker: Circuit breaker settings, or ``None`` to disable it
        backpressure: Concurrency-limit settings, or ``None`` to disable it
    """

    retry: RetryConfig | None = None
    rate_limit: RateLimiterConfig | None = None
    circuit_breaker: CircuitBreakerConfig | None = None
    backpressure: BackpressureConfig | None = None

    @classmethod
    def default(cls) -> ResilientConfig:
        """Build a config with every layer on, each using its stock settings."""
        layers = {
            "retry": RetryConfig(),
            "rate_limit": RateLimiterConfig(),
            "circuit_breaker": CircuitBreakerConfig(),
            "backpressure": BackpressureConfig(),
        }
        return cls(**layers)

    @classmethod
    def minimal(cls) -> ResilientConfig:
        """Build a config that only retries (up to two extra attempts)."""
        retry_settings = RetryConfig(max_retries=2)
        return cls(retry=retry_settings)

    @classmethod
    def production(cls) -> ResilientConfig:
        """Build a config with tuned values for every layer."""
        retry_settings = RetryConfig(
            max_retries=3,
            min_delay_ms=1000,
            max_delay_ms=30000,
        )
        limiter_settings = RateLimiterConfig.from_rps(10)
        breaker_settings = CircuitBreakerConfig(
            failure_threshold=5,
            cooldown_seconds=30,
        )
        pressure_settings = BackpressureConfig(max_concurrent=10)
        return cls(
            retry=retry_settings,
            rate_limit=limiter_settings,
            circuit_breaker=breaker_settings,
            backpressure=pressure_settings,
        )
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass
class ExecutionStats:
    """Per-call record filled in by ``ResilientExecutor.execute_with_stats``.

    Attributes:
        success: Set to True once the wrapped operation has returned normally
        retry_result: Outcome object reported by the retry policy, if one ran
        rate_limit_wait_ms: Milliseconds spent waiting on the rate limiter
        circuit_state: Circuit breaker state label observed for this call
        inflight_at_start: In-flight operation count when the call began
    """

    # The only required field; everything below has a neutral default.
    success: bool
    retry_result: RetryResult | None = None
    rate_limit_wait_ms: float = 0.0
    # Placeholder label until the executor records an actual state.
    circuit_state: str = "unknown"
    inflight_at_start: int = 0
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class ResilientExecutor:
    """Executor combining all resilience patterns.

    Layers are applied outermost-first:

    1. Backpressure control (concurrency limiting)
    2. Rate limiting
    3. Circuit breaker
    4. Retry with exponential backoff

    Each layer is optional: a layer whose config is ``None`` is not built
    and is skipped. The circuit breaker wraps the whole retry sequence, so
    it observes the outcome of all retry attempts together rather than each
    individual attempt.

    Example:
        >>> config = ResilientConfig.production()
        >>> executor = ResilientExecutor(config)
        >>> result = await executor.execute(async_operation)
    """

    def __init__(
        self,
        config: ResilientConfig | None = None,
        name: str = "default",
    ) -> None:
        """Initialize resilient executor.

        Args:
            config: Combined resilience configuration. When omitted, an
                empty ``ResilientConfig()`` is used, which enables no layers.
            name: Identifier for this executor
        """
        self._config = config or ResilientConfig()
        self._name = name

        # Build only the components whose config is present.
        self._retry = (
            RetryPolicy(self._config.retry) if self._config.retry else None
        )
        self._rate_limiter = (
            RateLimiter(self._config.rate_limit)
            if self._config.rate_limit
            else None
        )
        self._circuit_breaker = (
            CircuitBreaker(self._config.circuit_breaker)
            if self._config.circuit_breaker
            else None
        )
        self._backpressure = (
            Backpressure(self._config.backpressure)
            if self._config.backpressure
            else None
        )

    @property
    def name(self) -> str:
        """Get executor name."""
        return self._name

    @property
    def circuit_state(self) -> str:
        """Get current circuit breaker state, or ``"disabled"`` if none."""
        if self._circuit_breaker:
            return self._circuit_breaker.state.value
        return "disabled"

    @property
    def current_inflight(self) -> int:
        """Get current in-flight count, or 0 when backpressure is disabled."""
        if self._backpressure:
            return self._backpressure.current_inflight
        return 0

    async def execute(
        self,
        operation: Callable[[], Awaitable[T]],
        on_retry: Callable[[int, Exception, float], None] | None = None,
    ) -> T:
        """Execute an operation with all resilience patterns.

        Args:
            operation: Zero-argument async callable to execute
            on_retry: Optional callback forwarded to the retry policy

        Returns:
            Operation result

        Raises:
            CircuitOpenError: If circuit is open
            Exception: Original exception if all retries fail
        """
        # 1. Backpressure control: the slot is held for the whole call,
        # including any rate-limit wait and every retry attempt.
        if self._backpressure:
            async with self._backpressure.acquire():
                return await self._execute_inner(operation, on_retry)
        return await self._execute_inner(operation, on_retry)

    async def _execute_inner(
        self,
        operation: Callable[[], Awaitable[T]],
        on_retry: Callable[[int, Exception, float], None] | None = None,
    ) -> T:
        """Execute with rate limiting, circuit breaker, and retry.

        Args:
            operation: Async operation
            on_retry: Retry callback

        Returns:
            Operation result
        """
        # 2. Rate limiting (once per call, not once per retry attempt).
        if self._rate_limiter:
            await self._rate_limiter.acquire()

        # 3. Circuit breaker + 4. Retry
        if self._circuit_breaker:
            return await self._circuit_breaker.execute(
                lambda: self._execute_with_retry(operation, on_retry)
            )
        return await self._execute_with_retry(operation, on_retry)

    async def _execute_with_retry(
        self,
        operation: Callable[[], Awaitable[T]],
        on_retry: Callable[[int, Exception, float], None] | None = None,
    ) -> T:
        """Execute with retry, or call the operation once if retry is disabled.

        Args:
            operation: Async operation
            on_retry: Retry callback

        Returns:
            Operation result
        """
        if self._retry:
            result = await self._retry.execute(operation, on_retry)
            return self._unwrap_retry_result(result)
        return await operation()

    @staticmethod
    def _unwrap_retry_result(result: RetryResult) -> Any:
        """Return a successful retry result's value, or raise its error.

        Args:
            result: Outcome reported by the retry policy

        Returns:
            ``result.value`` when ``result.success`` is true

        Raises:
            Exception: The error recorded on a failed result
            RuntimeError: If a failed result carries no error. Without this
                check, ``raise None`` would surface as an unrelated
                ``TypeError``.
        """
        if result.success:
            return result.value
        if result.error is None:
            raise RuntimeError("Retry policy reported failure without an error")
        raise result.error

    async def execute_with_stats(
        self,
        operation: Callable[[], Awaitable[T]],
        on_retry: Callable[[int, Exception, float], None] | None = None,
    ) -> tuple[T, ExecutionStats]:
        """Execute and return execution statistics.

        Args:
            operation: Async operation
            on_retry: Retry callback

        Returns:
            Tuple of (result, stats). On failure the exception propagates
            and no stats object is returned.
        """
        stats = ExecutionStats(
            success=False,
            circuit_state=self.circuit_state,
            inflight_at_start=self.current_inflight,
        )

        # 1. Backpressure
        if self._backpressure:
            async with self._backpressure.acquire():
                result = await self._execute_inner_with_stats(
                    operation, on_retry, stats
                )
        else:
            result = await self._execute_inner_with_stats(
                operation, on_retry, stats
            )

        # Only reached when the operation returned normally.
        stats.success = True
        return result, stats

    async def _execute_inner_with_stats(
        self,
        operation: Callable[[], Awaitable[T]],
        on_retry: Callable[[int, Exception, float], None] | None,
        stats: ExecutionStats,
    ) -> T:
        """Execute with stats collection.

        Args:
            operation: Async operation
            on_retry: Retry callback
            stats: Stats object to populate

        Returns:
            Operation result
        """
        # Rate limiting
        if self._rate_limiter:
            wait_time = await self._rate_limiter.acquire()
            # NOTE(review): assumes acquire() returns the wait in seconds as a
            # number; confirm against RateLimiter.acquire.
            stats.rate_limit_wait_ms = wait_time * 1000

        # Circuit breaker + Retry
        async def inner() -> T:
            if self._retry:
                result = await self._retry.execute(operation, on_retry)
                stats.retry_result = result
                return self._unwrap_retry_result(result)
            return await operation()

        if self._circuit_breaker:
            # Record the breaker state as seen right before this call.
            stats.circuit_state = self._circuit_breaker.state.value
            return await self._circuit_breaker.execute(inner)
        return await inner()

    def get_stats(self) -> dict[str, Any]:
        """Get current statistics from all configured components.

        Returns:
            Dict with the executor name plus one entry per enabled component
        """
        stats: dict[str, Any] = {"name": self._name}

        if self._rate_limiter:
            stats["rate_limiter"] = {
                "available_tokens": self._rate_limiter.available_tokens,
                "is_limited": self._rate_limiter.is_limited,
            }

        if self._circuit_breaker:
            cb_stats = self._circuit_breaker.get_stats()
            stats["circuit_breaker"] = {
                "state": self._circuit_breaker.state.value,
                "total_requests": cb_stats.total_requests,
                "failed_requests": cb_stats.failed_requests,
                "rejected_requests": cb_stats.rejected_requests,
            }

        if self._backpressure:
            stats["backpressure"] = self._backpressure.get_stats()

        return stats

    def reset(self) -> None:
        """Reset the circuit breaker, if one is configured.

        Only the circuit breaker is reset. Rate limiter and backpressure
        state are left untouched.
        """
        if self._circuit_breaker:
            self._circuit_breaker.reset()
|
|
@@ -0,0 +1,341 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Fallback chain for multi-model degradation.
|
|
3
|
+
|
|
4
|
+
Provides automatic failover between multiple AI providers/models.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import TYPE_CHECKING, Any, TypeVar
|
|
12
|
+
|
|
13
|
+
from ai_lib_python.errors import AiLibError, is_fallbackable
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from collections.abc import Awaitable, Callable
|
|
17
|
+
|
|
18
|
+
T = TypeVar("T")
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class FallbackTarget:
    """One candidate entry inside a ``FallbackChain``.

    Attributes:
        name: Label identifying the target in results and callbacks
        operation: Async callable invoked with the chain's call arguments
        weight: Ordering priority; larger weights are tried first
        enabled: When False, the chain skips this target
    """

    # Identity and the work to perform.
    name: str
    operation: Callable[..., Awaitable[T]]
    # Scheduling knobs.
    weight: float = 1.0
    enabled: bool = True
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class FallbackConfig:
    """Tuning knobs for ``FallbackChain``.

    Attributes:
        retry_all: Whether to retry all targets on failure.
            NOTE(review): ``FallbackChain`` in this module never reads this
            flag; confirm whether anything else uses it.
        max_attempts_per_target: How many times a single target is called
            before the chain gives up on it
        delay_between_targets_ms: Pause, in milliseconds, between fallback
            attempts
    """

    retry_all: bool = True
    max_attempts_per_target: int = 1
    delay_between_targets_ms: int = 0
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@dataclass
class FallbackResult:
    """Outcome of running a ``FallbackChain``.

    Attributes:
        success: True when some target returned without raising
        value: What the successful target returned (``None`` otherwise)
        target_used: Name of the target that produced ``value``
        targets_tried: Target names in the order they were attempted
        errors: Most recent exception from each failing target, keyed by
            target name; the chain may also store a ``"_chain"`` entry
    """

    success: bool
    value: Any = None
    target_used: str | None = None
    # Mutable containers use factories so instances never share them.
    targets_tried: list[str] = field(default_factory=list)
    errors: dict[str, Exception] = field(default_factory=dict)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class FallbackChain:
    """Fallback chain for automatic failover.

    Executes operations through a chain of targets, falling back to the
    next target on failure. Enabled targets are tried in descending
    ``weight`` order; targets with equal weight keep insertion order
    (``sorted`` is stable).

    Example:
        >>> chain = FallbackChain()
        >>> chain.add_target("gpt-4", lambda: call_openai("gpt-4"))
        >>> chain.add_target("claude", lambda: call_anthropic("claude"))
        >>> result = await chain.execute()
    """

    def __init__(self, config: FallbackConfig | None = None) -> None:
        """Initialize fallback chain.

        Args:
            config: Fallback configuration (defaults to ``FallbackConfig()``)
        """
        self._config = config or FallbackConfig()
        self._targets: list[FallbackTarget] = []

    def add_target(
        self,
        name: str,
        operation: Callable[..., Awaitable[T]],
        weight: float = 1.0,
        enabled: bool = True,
    ) -> FallbackChain:
        """Add a target to the fallback chain.

        Args:
            name: Target identifier
            operation: Async callable; receives the arguments given to
                ``execute``
            weight: Priority weight (higher = tried earlier)
            enabled: Whether target is enabled

        Returns:
            Self for chaining
        """
        self._targets.append(
            FallbackTarget(
                name=name,
                operation=operation,
                weight=weight,
                enabled=enabled,
            )
        )
        return self

    def remove_target(self, name: str) -> bool:
        """Remove the first target with the given name.

        Args:
            name: Target name to remove

        Returns:
            True if removed, False if not found
        """
        for i, target in enumerate(self._targets):
            if target.name == name:
                self._targets.pop(i)
                return True
        return False

    def set_enabled(self, name: str, enabled: bool) -> bool:
        """Enable or disable the first target with the given name.

        Args:
            name: Target name
            enabled: Whether to enable

        Returns:
            True if target found, False otherwise
        """
        for target in self._targets:
            if target.name == name:
                target.enabled = enabled
                return True
        return False

    def _ordered_targets(self) -> list[FallbackTarget]:
        """Return enabled targets sorted by descending weight (stable on ties)."""
        return sorted(
            (t for t in self._targets if t.enabled),
            key=lambda t: t.weight,
            reverse=True,
        )

    def get_targets(self) -> list[str]:
        """Get list of enabled target names in priority order.

        Returns:
            List of target names
        """
        return [t.name for t in self._ordered_targets()]

    def _should_fallback(self, error: Exception) -> bool:
        """Check if error should trigger fallback.

        Args:
            error: The exception

        Returns:
            True if should fallback
        """
        # Prefer the error's own classification when it carries one.
        if hasattr(error, "error_class"):
            return is_fallbackable(error.error_class)

        # Default: fallback on any AiLibError
        return isinstance(error, AiLibError)

    async def execute(
        self,
        *args: Any,
        on_fallback: Callable[[str, str, Exception], None] | None = None,
        **kwargs: Any,
    ) -> FallbackResult:
        """Execute operation through fallback chain.

        Each target is called up to ``max_attempts_per_target`` times (at
        least once). A non-fallbackable error stops the chain immediately.
        ``on_fallback`` and the inter-target delay are applied only when
        another target follows the one that just failed.

        Args:
            *args: Arguments to pass to operations
            on_fallback: Callback when falling back (from, to, error)
            **kwargs: Keyword arguments to pass to operations

        Returns:
            FallbackResult with outcome
        """
        targets = self._ordered_targets()

        if not targets:
            return FallbackResult(
                success=False,
                errors={"_chain": ValueError("No enabled targets in chain")},
            )

        # A value < 1 would call no target at all; give each one real call.
        attempts_per_target = max(1, self._config.max_attempts_per_target)
        errors: dict[str, Exception] = {}
        targets_tried: list[str] = []

        for index, target in enumerate(targets):
            targets_tried.append(target.name)

            for _attempt in range(attempts_per_target):
                try:
                    result = await target.operation(*args, **kwargs)
                except Exception as e:
                    # Keep only the most recent error for this target.
                    errors[target.name] = e
                    if not self._should_fallback(e):
                        # Non-fallbackable error, stop chain
                        return FallbackResult(
                            success=False,
                            errors=errors,
                            targets_tried=targets_tried,
                        )
                else:
                    return FallbackResult(
                        success=True,
                        value=result,
                        target_used=target.name,
                        targets_tried=targets_tried,
                        errors=errors,
                    )

            # This target is exhausted; stop if it was the last one.
            if index + 1 >= len(targets):
                break
            next_target = targets[index + 1]

            # Callback before falling back (including from the first target).
            if on_fallback:
                on_fallback(target.name, next_target.name, errors[target.name])

            # Delay between targets
            if self._config.delay_between_targets_ms > 0:
                await asyncio.sleep(self._config.delay_between_targets_ms / 1000)

        # All targets failed
        return FallbackResult(
            success=False,
            errors=errors,
            targets_tried=targets_tried,
        )
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class MultiFallback:
    """Multi-strategy fallback manager.

    Manages multiple named fallback chains for different scenarios.

    Example:
        >>> mf = MultiFallback()
        >>> mf.register_chain("chat", chat_chain)
        >>> mf.register_chain("embed", embed_chain)
        >>> result = await mf.execute("chat", messages=[...])
    """

    def __init__(self) -> None:
        """Initialize multi-fallback manager."""
        self._chains: dict[str, FallbackChain] = {}

    def register_chain(self, name: str, chain: FallbackChain) -> MultiFallback:
        """Register a fallback chain, replacing any chain with the same name.

        Args:
            name: Chain identifier
            chain: FallbackChain instance

        Returns:
            Self for chaining
        """
        self._chains[name] = chain
        return self

    def get_chain(self, name: str) -> FallbackChain | None:
        """Get a registered chain.

        Args:
            name: Chain name

        Returns:
            Chain or None if not found
        """
        return self._chains.get(name)

    async def execute(
        self,
        chain_name: str,
        *args: Any,
        **kwargs: Any,
    ) -> FallbackResult:
        """Execute operation through a named chain.

        Args:
            chain_name: Name of chain to use
            *args: Arguments to pass
            **kwargs: Keyword arguments to pass (e.g. ``on_fallback``)

        Returns:
            FallbackResult

        Raises:
            ValueError: If chain not found
        """
        chain = self._chains.get(chain_name)
        # Test presence explicitly; a registered chain object must never be
        # treated as missing because of its truthiness.
        if chain is None:
            raise ValueError(f"Unknown fallback chain: {chain_name}")

        return await chain.execute(*args, **kwargs)

    def list_chains(self) -> list[str]:
        """Get list of registered chain names in registration order.

        Returns:
            List of chain names
        """
        return list(self._chains)
|