agent-tool-resilience 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
1
+ """
2
+ agent-tool-resilience: Production-grade resilience for AI agent tool calls.
3
+
4
+ Provides smart retries, fallbacks, circuit breakers, and result validation
5
+ to prevent silent failures in AI agent workflows.
6
+ """
7
+
8
+ from .retry import RetryPolicy, RetryError
9
+ from .circuit_breaker import CircuitBreaker, CircuitBreakerOpen
10
+ from .fallback import FallbackChain, FallbackError
11
+ from .validator import ResultValidator, ValidationError
12
+ from .tracer import ToolExecutionTracer, ExecutionEvent
13
+ from .rate_limit import RateLimitHandler, RateLimitExceeded
14
+ from .resilient_tool import ResilientTool, resilient_tool
15
+
16
+ __version__ = "0.1.0"
17
+ __all__ = [
18
+ "ResilientTool",
19
+ "resilient_tool",
20
+ "RetryPolicy",
21
+ "RetryError",
22
+ "CircuitBreaker",
23
+ "CircuitBreakerOpen",
24
+ "FallbackChain",
25
+ "FallbackError",
26
+ "ResultValidator",
27
+ "ValidationError",
28
+ "ToolExecutionTracer",
29
+ "ExecutionEvent",
30
+ "RateLimitHandler",
31
+ "RateLimitExceeded",
32
+ ]
@@ -0,0 +1,246 @@
1
+ """
2
+ Circuit breaker pattern to prevent cascade failures.
3
+ """
4
+
5
+ import asyncio
6
+ import threading
7
+ import time
8
+ from dataclasses import dataclass, field
9
+ from enum import Enum
10
+ from typing import Any, Callable, Optional
11
+
12
+
13
class CircuitState(Enum):
    """Circuit breaker states."""
    CLOSED = "closed"        # Normal operation, requests pass through
    OPEN = "open"            # Circuit is open, requests fail immediately
    HALF_OPEN = "half_open"  # Testing if service has recovered


class CircuitBreakerOpen(Exception):
    """
    Raised when the circuit breaker refuses a call.

    Attributes:
        reset_time: ``last failure timestamp + reset_timeout`` as a
            ``time.time()`` value, or None when no failure timestamp exists.
    """

    def __init__(self, message: str, reset_time: Optional[float] = None):
        super().__init__(message)
        self.reset_time = reset_time


@dataclass
class CircuitBreaker:
    """
    Circuit breaker to prevent cascade failures.

    The circuit breaker has three states:
    - CLOSED: Normal operation, requests pass through
    - OPEN: Too many failures, requests fail immediately
    - HALF_OPEN: Testing recovery, limited requests allowed

    Attributes:
        failure_threshold: Number of failures before opening circuit
        success_threshold: Number of successes in half-open to close circuit
        reset_timeout: Seconds to wait before transitioning from open to half-open
        half_open_max_calls: Maximum concurrent calls in half-open state
        exclude_exceptions: Exception types that don't count as failures.
            Any iterable of types (or a single type) is accepted and stored
            as a tuple.
        on_state_change: Callback for state transitions, called as
            ``on_state_change(old_state, new_state)`` while the breaker's
            re-entrant lock is held. It may read the breaker's properties,
            but should be quick. Exceptions it raises propagate to the caller.
    """
    failure_threshold: int = 5
    success_threshold: int = 2
    reset_timeout: float = 60.0
    half_open_max_calls: int = 3
    exclude_exceptions: tuple[type, ...] = field(default_factory=tuple)
    on_state_change: Optional[Callable[[CircuitState, CircuitState], None]] = None

    # Internal state (not part of config)
    _state: CircuitState = field(default=CircuitState.CLOSED, init=False)
    _failure_count: int = field(default=0, init=False)
    _success_count: int = field(default=0, init=False)
    _last_failure_time: Optional[float] = field(default=None, init=False)
    # Number of half-open probe calls currently in flight.
    _half_open_calls: int = field(default=0, init=False)
    # Incremented on every transition into HALF_OPEN so that a probe admitted
    # during an earlier half-open period cannot release a slot it no longer owns.
    _half_open_epoch: int = field(default=0, init=False)
    # Re-entrant: on_state_change runs with the lock held and may call back
    # into ``state`` / ``failure_count`` / ``get_stats()`` on the same thread.
    _lock: Any = field(default_factory=threading.RLock, init=False)

    def __post_init__(self) -> None:
        # isinstance() needs a type or a tuple of types; normalise lists, sets
        # and a bare exception class so _record_failure cannot raise TypeError.
        excluded = self.exclude_exceptions
        if isinstance(excluded, type):
            excluded = (excluded,)
        self.exclude_exceptions = tuple(excluded)

    @property
    def state(self) -> CircuitState:
        """Get current circuit state (may move OPEN -> HALF_OPEN on timeout)."""
        with self._lock:
            self._maybe_transition_to_half_open()
            return self._state

    @property
    def failure_count(self) -> int:
        """Get current failure count."""
        with self._lock:
            return self._failure_count

    @property
    def is_closed(self) -> bool:
        """Check if circuit is closed (normal operation)."""
        return self.state == CircuitState.CLOSED

    @property
    def is_open(self) -> bool:
        """Check if circuit is open (blocking requests)."""
        return self.state == CircuitState.OPEN

    @property
    def is_half_open(self) -> bool:
        """Check if circuit is half-open (testing recovery)."""
        return self.state == CircuitState.HALF_OPEN

    def _maybe_transition_to_half_open(self) -> None:
        """Move OPEN -> HALF_OPEN once reset_timeout has elapsed. Lock must be held."""
        # ``is not None`` rather than truthiness: the timestamp is a float.
        if self._state == CircuitState.OPEN and self._last_failure_time is not None:
            # NOTE(review): wall-clock time is used because the timestamp is
            # exposed through reset_time / get_stats(); a system clock
            # adjustment will lengthen or shorten the timeout accordingly.
            elapsed = time.time() - self._last_failure_time
            if elapsed >= self.reset_timeout:
                self._transition_to(CircuitState.HALF_OPEN)

    def _transition_to(self, new_state: CircuitState) -> None:
        """Switch state, reset the matching counters and notify. Lock must be held."""
        old_state = self._state
        self._state = new_state

        if new_state == CircuitState.CLOSED:
            self._failure_count = 0
            self._success_count = 0
            self._half_open_calls = 0
        elif new_state == CircuitState.HALF_OPEN:
            self._success_count = 0
            self._half_open_calls = 0
            self._half_open_epoch += 1
        elif new_state == CircuitState.OPEN:
            self._last_failure_time = time.time()

        if self.on_state_change and old_state != new_state:
            self.on_state_change(old_state, new_state)

    def _record_success(self) -> None:
        """Record a successful call."""
        with self._lock:
            if self._state == CircuitState.HALF_OPEN:
                self._success_count += 1
                if self._success_count >= self.success_threshold:
                    self._transition_to(CircuitState.CLOSED)
            elif self._state == CircuitState.CLOSED:
                # Reset failure count on success
                self._failure_count = 0

    def _record_failure(self, exception: Exception) -> None:
        """Record a failed call, unless its type is in exclude_exceptions."""
        if isinstance(exception, self.exclude_exceptions):
            return

        with self._lock:
            self._failure_count += 1

            if self._state == CircuitState.HALF_OPEN:
                # Any failure in half-open immediately opens the circuit
                self._transition_to(CircuitState.OPEN)
            elif self._state == CircuitState.CLOSED:
                if self._failure_count >= self.failure_threshold:
                    self._transition_to(CircuitState.OPEN)

    def _admit(self) -> tuple[bool, Optional[int]]:
        """
        Decide whether a call may proceed.

        Returns:
            ``(allowed, probe_epoch)``. ``probe_epoch`` is the half-open epoch
            when the call occupies a half-open slot, otherwise None.
        """
        with self._lock:
            self._maybe_transition_to_half_open()

            if self._state == CircuitState.CLOSED:
                return True, None
            if (
                self._state == CircuitState.HALF_OPEN
                and self._half_open_calls < self.half_open_max_calls
            ):
                self._half_open_calls += 1
                return True, self._half_open_epoch
            return False, None

    def _release_probe(self, probe_epoch: Optional[int]) -> None:
        """Free the half-open slot taken by ``_admit`` once the call is done."""
        if probe_epoch is None:
            return
        with self._lock:
            # Skip if the breaker left that half-open period in the meantime:
            # the counter was already reset by the transition.
            if (
                self._state == CircuitState.HALF_OPEN
                and probe_epoch == self._half_open_epoch
                and self._half_open_calls > 0
            ):
                self._half_open_calls -= 1

    def _allow_request(self) -> bool:
        """
        Check if a request should be allowed.

        Kept for callers that only need the yes/no answer. A half-open slot
        taken through this method is not released by it.
        """
        return self._admit()[0]

    def _rejection(self) -> CircuitBreakerOpen:
        """Build the exception raised when a call is refused."""
        with self._lock:
            state = self._state
            last_failure = self._last_failure_time
        reset_time = None
        if last_failure is not None:
            reset_time = last_failure + self.reset_timeout
        return CircuitBreakerOpen(
            f"Circuit breaker is {state.value}",
            reset_time=reset_time
        )

    def execute(
        self,
        func: Callable[..., Any],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        """
        Execute a function with circuit breaker protection.

        Args:
            func: Function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function's return value

        Raises:
            CircuitBreakerOpen: If the circuit is open, or half-open with all
                probe slots in use
        """
        allowed, probe_epoch = self._admit()
        if not allowed:
            raise self._rejection()

        try:
            result = func(*args, **kwargs)
        except Exception as e:
            self._record_failure(e)
            raise
        else:
            # Outside the try: an error raised by on_state_change while
            # recording success must not be counted as a failure of ``func``.
            self._record_success()
            return result
        finally:
            # Runs for every outcome, including excluded exceptions and
            # BaseException, so a probe slot can never leak.
            self._release_probe(probe_epoch)

    async def execute_async(
        self,
        func: Callable[..., Any],
        *args: Any,
        **kwargs: Any
    ) -> Any:
        """
        Execute an async function with circuit breaker protection.

        Args:
            func: Async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Function's return value

        Raises:
            CircuitBreakerOpen: If the circuit is open, or half-open with all
                probe slots in use
        """
        allowed, probe_epoch = self._admit()
        if not allowed:
            raise self._rejection()

        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            self._record_failure(e)
            raise
        else:
            self._record_success()
            return result
        finally:
            # Task cancellation is not an ``Exception``; without this the
            # cancelled probe would hold its half-open slot forever.
            self._release_probe(probe_epoch)

    def reset(self) -> None:
        """Manually reset the circuit breaker to closed state."""
        with self._lock:
            self._transition_to(CircuitState.CLOSED)

    def get_stats(self) -> dict:
        """Get circuit breaker statistics as a plain dict snapshot."""
        with self._lock:
            return {
                "state": self._state.value,
                "failure_count": self._failure_count,
                "success_count": self._success_count,
                "last_failure_time": self._last_failure_time,
                "half_open_calls": self._half_open_calls,
            }
@@ -0,0 +1,188 @@
1
+ """
2
+ Fallback strategies for graceful degradation.
3
+ """
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Any, Callable, Optional, Sequence
7
+
8
+
9
class FallbackError(Exception):
    """
    Signals that every fallback in a chain failed.

    Attributes:
        errors: ``(name, exception)`` pairs collected while running the chain.
    """

    def __init__(self, message: str, errors: list[tuple[str, Exception]]):
        # Keep the collected failures available to the handler.
        self.errors = errors
        super().__init__(message)
15
+
16
+
17
@dataclass
class FallbackChain:
    """
    Chain of fallback functions to try in sequence.

    Attributes:
        fallbacks: Sequence of fallback functions to try
        names: Optional names for each fallback (for error reporting).
            Defaults to ``fallback_0``, ``fallback_1``, ...
        on_fallback: Optional callback invoked as
            ``on_fallback(index, name, last_error)`` after a fallback at
            index > 0 succeeds. Exceptions it raises propagate to the caller.
    """
    fallbacks: Sequence[Callable[..., Any]]
    names: Optional[Sequence[str]] = None
    on_fallback: Optional[Callable[[int, str, Exception], None]] = None

    def __post_init__(self):
        if self.names is None:
            self.names = [f"fallback_{i}" for i in range(len(self.fallbacks))]
        elif len(self.names) != len(self.fallbacks):
            raise ValueError("Number of names must match number of fallbacks")

    @staticmethod
    def _initial_errors(
        primary_exception: Optional[Exception],
    ) -> list[tuple[str, Exception]]:
        """Start the error log, seeded with the primary failure when given."""
        # Identity check: an exception type defining __len__/__bool__ could be falsy.
        if primary_exception is not None:
            return [("primary", primary_exception)]
        return []

    def _notify(
        self,
        index: int,
        name: str,
        errors: list[tuple[str, Exception]],
    ) -> None:
        """Report a successful fallback to ``on_fallback``."""
        # NOTE(review): the first fallback (index 0) is never reported, even
        # when primary_exception was supplied. Kept as in the original;
        # confirm whether that is intended.
        if self.on_fallback and index > 0:
            last_error = errors[-1][1] if errors else None
            self.on_fallback(index, name, last_error)

    def _all_failed(self, errors: list[tuple[str, Exception]]) -> "FallbackError":
        """Build the error raised once every fallback has been tried."""
        return FallbackError(
            f"All {len(self.fallbacks)} fallbacks failed",
            errors=errors
        )

    def execute(
        self,
        *args: Any,
        primary_exception: Optional[Exception] = None,
        **kwargs: Any
    ) -> Any:
        """
        Execute fallbacks in sequence until one succeeds.

        Args:
            *args: Positional arguments passed to fallbacks
            primary_exception: The exception from the primary function
            **kwargs: Keyword arguments passed to fallbacks

        Returns:
            Result from the first successful fallback

        Raises:
            FallbackError: If all fallbacks fail
        """
        errors = self._initial_errors(primary_exception)

        for i, (fallback, name) in enumerate(zip(self.fallbacks, self.names)):
            try:
                result = fallback(*args, **kwargs)
            except Exception as e:
                errors.append((name, e))
                continue

            # Outside the try: a failing callback must not be recorded as a
            # failure of this fallback, nor trigger the next one.
            self._notify(i, name, errors)
            return result

        raise self._all_failed(errors)

    async def execute_async(
        self,
        *args: Any,
        primary_exception: Optional[Exception] = None,
        **kwargs: Any
    ) -> Any:
        """
        Execute fallbacks in sequence until one succeeds, awaiting async ones.

        Fallbacks may be coroutine functions or plain callables. The result
        is awaited only when it is awaitable.

        Args:
            *args: Positional arguments passed to fallbacks
            primary_exception: The exception from the primary function
            **kwargs: Keyword arguments passed to fallbacks

        Returns:
            Result from the first successful fallback

        Raises:
            FallbackError: If all fallbacks fail
        """
        import inspect

        errors = self._initial_errors(primary_exception)

        for i, (fallback, name) in enumerate(zip(self.fallbacks, self.names)):
            try:
                result = fallback(*args, **kwargs)
                if inspect.isawaitable(result):
                    result = await result
            except Exception as e:
                errors.append((name, e))
                continue

            self._notify(i, name, errors)
            return result

        raise self._all_failed(errors)
120
+
121
+
122
@dataclass
class CachedFallback:
    """
    Serve previously stored results when the primary call fails.

    Attributes:
        cache: Maps a key to a ``(value, stored_at)`` pair, where
            ``stored_at`` is a ``time.time()`` timestamp
        default: Returned when the key is absent or its entry has expired
        ttl_seconds: Maximum age of an entry in seconds (None = never expires)
    """
    cache: dict[Any, tuple[Any, float]] = field(default_factory=dict)
    default: Any = None
    ttl_seconds: Optional[float] = None

    def get(self, key: Any) -> Any:
        """Look up ``key``; evict and ignore the entry if it is older than the TTL."""
        import time

        absent = object()
        entry = self.cache.get(key, absent)
        if entry is absent:
            return self.default

        stored_value, stored_at = entry

        has_expired = (
            self.ttl_seconds is not None
            and time.time() - stored_at > self.ttl_seconds
        )
        if has_expired:
            del self.cache[key]
            return self.default

        return stored_value

    def set(self, key: Any, value: Any) -> None:
        """Store ``value`` under ``key``, stamped with the current time."""
        import time
        now = time.time()
        self.cache[key] = (value, now)

    def as_fallback(self) -> Callable[[Any], Any]:
        """Return a fallback function that uses this cache."""
        def fallback(*args, **kwargs):
            # The lookup key is the positional args plus the keyword items
            # sorted by name, so keyword order does not matter.
            keyword_part = tuple(sorted(kwargs.items()))
            return self.get((args, keyword_part))
        return fallback
164
+
165
+
166
def static_fallback(value: Any) -> Callable[..., Any]:
    """
    Build a fallback that ignores its arguments and returns ``value``.

    The same ``value`` object is returned on every call (it is not copied).
    """
    def fallback(*_args, **_kwargs):
        return value

    return fallback
171
+
172
+
173
def error_fallback(
    error_type: str = "service_unavailable",
    include_args: bool = False
) -> Callable[..., dict]:
    """
    Build a fallback that returns an error-response dict instead of raising.

    Args:
        error_type: Value stored under the ``"error_type"`` key
        include_args: When true, the call's positional and keyword arguments
            are echoed back under ``"args"`` and ``"kwargs"``

    Returns:
        A callable accepting any arguments and returning a new dict per call
    """
    def fallback(*args, **kwargs):
        # Echo the call arguments only on request.
        echoed = {"args": args, "kwargs": kwargs} if include_args else {}
        return {
            "error": True,
            "error_type": error_type,
            "message": "Service temporarily unavailable",
            **echoed,
        }

    return fallback