chuk-tool-processor 0.7.0__py3-none-any.whl → 0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of chuk-tool-processor might be problematic. Click here for more details.
- chuk_tool_processor/core/__init__.py +31 -0
- chuk_tool_processor/core/exceptions.py +218 -12
- chuk_tool_processor/core/processor.py +30 -1
- chuk_tool_processor/execution/wrappers/__init__.py +42 -0
- chuk_tool_processor/execution/wrappers/caching.py +7 -3
- chuk_tool_processor/execution/wrappers/circuit_breaker.py +343 -0
- chuk_tool_processor/execution/wrappers/retry.py +12 -0
- chuk_tool_processor/mcp/setup_mcp_http_streamable.py +31 -2
- chuk_tool_processor/mcp/setup_mcp_sse.py +31 -2
- chuk_tool_processor/models/__init__.py +20 -0
- chuk_tool_processor/models/tool_call.py +34 -1
- chuk_tool_processor/models/tool_spec.py +350 -0
- chuk_tool_processor/models/validated_tool.py +22 -2
- {chuk_tool_processor-0.7.0.dist-info → chuk_tool_processor-0.8.dist-info}/METADATA +197 -6
- {chuk_tool_processor-0.7.0.dist-info → chuk_tool_processor-0.8.dist-info}/RECORD +17 -15
- {chuk_tool_processor-0.7.0.dist-info → chuk_tool_processor-0.8.dist-info}/WHEEL +0 -0
- {chuk_tool_processor-0.7.0.dist-info → chuk_tool_processor-0.8.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
# chuk_tool_processor/execution/wrappers/circuit_breaker.py
|
|
2
|
+
"""
|
|
3
|
+
Circuit breaker pattern for tool execution.
|
|
4
|
+
|
|
5
|
+
Prevents cascading failures by tracking failure rates and temporarily
|
|
6
|
+
blocking calls to failing tools. Implements a state machine:
|
|
7
|
+
|
|
8
|
+
CLOSED → OPEN → HALF_OPEN → CLOSED (or back to OPEN)
|
|
9
|
+
|
|
10
|
+
States:
|
|
11
|
+
- CLOSED: Normal operation, requests pass through
|
|
12
|
+
- OPEN: Too many failures, requests blocked immediately
|
|
13
|
+
- HALF_OPEN: Testing if service recovered, limited requests allowed
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import asyncio
|
|
19
|
+
import time
|
|
20
|
+
from datetime import UTC, datetime
|
|
21
|
+
from enum import Enum
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from chuk_tool_processor.core.exceptions import ToolCircuitOpenError
|
|
25
|
+
from chuk_tool_processor.logging import get_logger
|
|
26
|
+
from chuk_tool_processor.models.tool_call import ToolCall
|
|
27
|
+
from chuk_tool_processor.models.tool_result import ToolResult
|
|
28
|
+
|
|
29
|
+
logger = get_logger("chuk_tool_processor.execution.wrappers.circuit_breaker")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# --------------------------------------------------------------------------- #
|
|
33
|
+
# Circuit breaker state
|
|
34
|
+
# --------------------------------------------------------------------------- #
|
|
35
|
+
class CircuitState(str, Enum):
    """Phases of the circuit breaker state machine.

    Members subclass ``str``, so each one compares equal to its raw value.
    """

    # Calls pass straight through to the wrapped executor.
    CLOSED = "closed"
    # Calls are rejected because the tool has been failing.
    OPEN = "open"
    # A limited number of probe calls are admitted to test recovery.
    HALF_OPEN = "half_open"
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class CircuitBreakerConfig:
    """Thresholds and timings that govern a single circuit breaker."""

    def __init__(
        self,
        failure_threshold: int = 5,
        success_threshold: int = 2,
        reset_timeout: float = 60.0,
        half_open_max_calls: int = 1,
        timeout_threshold: float | None = None,
    ):
        """
        Store the supplied settings on the instance; no validation is performed.

        Args:
            failure_threshold: Failure count at which a CLOSED circuit opens
            success_threshold: Successes required in HALF_OPEN before closing
            reset_timeout: Seconds an OPEN circuit waits before moving to HALF_OPEN
            half_open_max_calls: Upper bound on concurrent calls while HALF_OPEN
            timeout_threshold: Optional duration (s) above which a call counts as a failure
        """
        # Counts that trigger state transitions
        self.failure_threshold, self.success_threshold = failure_threshold, success_threshold

        # Recovery timing and probe concurrency
        self.reset_timeout, self.half_open_max_calls = reset_timeout, half_open_max_calls

        # ``None`` disables duration-based failure detection
        self.timeout_threshold = timeout_threshold
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class CircuitBreakerState:
    """Per-tool circuit breaker state tracking.

    Holds the current ``CircuitState`` together with the counters and
    timestamps that drive transitions. Every coroutine that mutates those
    fields takes ``self._lock`` first.
    """

    def __init__(self, config: CircuitBreakerConfig):
        """
        Start in CLOSED with all counters zeroed.

        Args:
            config: Thresholds and timings that govern this circuit
        """
        self.config = config
        self.state = CircuitState.CLOSED
        self.failure_count = 0
        self.success_count = 0
        # Both timestamps come from time.monotonic(); None means "never happened".
        self.last_failure_time: float | None = None
        self.opened_at: float | None = None
        self.half_open_calls = 0
        self._lock = asyncio.Lock()

    async def record_success(self) -> None:
        """Record a successful call.

        In HALF_OPEN the success counts toward ``success_threshold`` and closes
        the circuit once the threshold is reached. In CLOSED it clears the
        failure count. In OPEN nothing changes.
        """
        async with self._lock:
            if self.state == CircuitState.HALF_OPEN:
                self.success_count += 1
                logger.debug(f"Circuit HALF_OPEN: success {self.success_count}/{self.config.success_threshold}")

                # Enough successes? Close the circuit
                if self.success_count >= self.config.success_threshold:
                    logger.info("Circuit breaker: Transitioning to CLOSED (service recovered)")
                    self.state = CircuitState.CLOSED
                    self.failure_count = 0
                    self.success_count = 0
                    self.opened_at = None
                    self.half_open_calls = 0
            elif self.state == CircuitState.CLOSED:
                # In CLOSED state, just reset failure count
                self.failure_count = 0
            # OPEN: a success reported by a call admitted before the circuit
            # opened must not erase the failure count that opened it.

    async def record_failure(self) -> None:
        """Record a failed call.

        Always increments ``failure_count`` and stamps ``last_failure_time``.
        Opens the circuit from CLOSED once ``failure_threshold`` is reached, and
        re-opens it immediately from HALF_OPEN.
        """
        async with self._lock:
            self.failure_count += 1
            self.last_failure_time = time.monotonic()
            logger.debug(f"Circuit: failure {self.failure_count}/{self.config.failure_threshold}")

            if self.state == CircuitState.CLOSED:
                # Check if we should open
                if self.failure_count >= self.config.failure_threshold:
                    logger.warning(f"Circuit breaker: OPENING after {self.failure_count} failures")
                    self.state = CircuitState.OPEN
                    self.opened_at = time.monotonic()
            elif self.state == CircuitState.HALF_OPEN:
                # Failed during test → back to OPEN
                logger.warning("Circuit breaker: Back to OPEN (test failed)")
                self.state = CircuitState.OPEN
                self.success_count = 0
                # Restart the reset timer from this failure.
                self.opened_at = time.monotonic()
                self.half_open_calls = 0

    async def can_execute(self) -> bool:
        """Check if a call should be allowed through.

        Returns:
            True when the circuit is CLOSED, when a HALF_OPEN slot was claimed
            for the caller, or when an OPEN circuit has waited at least
            ``reset_timeout`` seconds (which moves it to HALF_OPEN and claims
            the first slot). False otherwise.
        """
        async with self._lock:
            if self.state == CircuitState.CLOSED:
                return True

            if self.state == CircuitState.HALF_OPEN:
                # Limit concurrent calls in HALF_OPEN
                if self.half_open_calls < self.config.half_open_max_calls:
                    self.half_open_calls += 1
                    return True
                return False

            # OPEN state: check if we should try HALF_OPEN
            if self.opened_at is not None:
                elapsed = time.monotonic() - self.opened_at
                if elapsed >= self.config.reset_timeout:
                    logger.info("Circuit breaker: Transitioning to HALF_OPEN (testing recovery)")
                    self.state = CircuitState.HALF_OPEN
                    # The caller that triggered the transition takes the first slot.
                    self.half_open_calls = 1
                    self.success_count = 0
                    return True

            return False

    async def release_half_open_slot(self) -> None:
        """Release a HALF_OPEN slot after call completes.

        Does nothing unless the circuit is currently HALF_OPEN; the slot count
        never drops below zero.
        """
        async with self._lock:
            if self.state == CircuitState.HALF_OPEN:
                self.half_open_calls = max(0, self.half_open_calls - 1)

    def get_state(self) -> dict[str, Any]:
        """Get current state as dict.

        Reads the fields without taking the lock.

        Returns:
            Dict with ``state``, ``failure_count``, ``success_count``,
            ``opened_at`` and ``time_until_half_open``. The last entry is the
            remaining seconds before an OPEN circuit may move to HALF_OPEN
            (never negative), or None when the circuit is not OPEN.
        """
        # Compare against None explicitly: a monotonic timestamp of 0.0 is a
        # valid "opened at" value and must not be treated as "never opened".
        is_waiting = self.opened_at is not None and self.state == CircuitState.OPEN
        return {
            "state": self.state.value,
            "failure_count": self.failure_count,
            "success_count": self.success_count,
            "opened_at": self.opened_at,
            "time_until_half_open": (
                max(0, self.config.reset_timeout - (time.monotonic() - self.opened_at)) if is_waiting else None
            ),
        }
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
# --------------------------------------------------------------------------- #
|
|
171
|
+
# Circuit breaker executor wrapper
|
|
172
|
+
# --------------------------------------------------------------------------- #
|
|
173
|
+
class CircuitBreakerExecutor:
    """
    Executor wrapper that implements circuit breaker pattern.

    Tracks failures per tool and opens circuit breakers to prevent
    cascading failures when tools are consistently failing. Calls passed to
    :meth:`execute` are forwarded to the wrapped executor one at a time.
    """

    def __init__(
        self,
        executor: Any,
        *,
        default_config: CircuitBreakerConfig | None = None,
        tool_configs: dict[str, CircuitBreakerConfig] | None = None,
    ):
        """
        Initialize circuit breaker executor.

        Args:
            executor: Underlying executor to wrap
            default_config: Default circuit breaker configuration; a
                ``CircuitBreakerConfig()`` with default values is used when omitted
            tool_configs: Per-tool circuit breaker configurations, keyed by tool name
        """
        self.executor = executor
        self.default_config = default_config or CircuitBreakerConfig()
        self.tool_configs = tool_configs or {}
        self._states: dict[str, CircuitBreakerState] = {}
        self._states_lock = asyncio.Lock()

    async def _get_state(self, tool: str) -> CircuitBreakerState:
        """Get or create circuit breaker state for a tool."""
        if tool not in self._states:
            async with self._states_lock:
                # Re-check under the lock: another task may have created the
                # entry while this one was waiting to acquire it.
                if tool not in self._states:
                    config = self.tool_configs.get(tool, self.default_config)
                    self._states[tool] = CircuitBreakerState(config)
        return self._states[tool]

    @staticmethod
    def _error_result(tool: str, error: str) -> ToolResult:
        """Build a ToolResult for an outcome produced by the breaker itself."""
        # The wrapped executor never ran (or raised), so start and end share one timestamp.
        now = datetime.now(UTC)
        return ToolResult(
            tool=tool,
            result=None,
            error=error,
            start_time=now,
            end_time=now,
            machine="circuit_breaker",
            pid=0,
        )

    async def execute(
        self,
        calls: list[ToolCall],
        *,
        timeout: float | None = None,
        use_cache: bool = True,
    ) -> list[ToolResult]:
        """
        Execute tool calls with circuit breaker protection.

        Each call is checked against its tool's circuit. Rejected calls and
        calls whose execution raises produce an error ``ToolResult`` instead
        of propagating; every other call returns the wrapped executor's result.

        Args:
            calls: List of tool calls to execute
            timeout: Optional timeout for execution
            use_cache: Whether to use cached results (forwarded only when the
                wrapped executor has a ``use_cache`` attribute)

        Returns:
            List of tool results, one per call, in input order
        """
        if not calls:
            return []

        results: list[ToolResult] = []

        for call in calls:
            state = await self._get_state(call.tool)

            # Check if circuit allows execution
            if not await state.can_execute():
                # Circuit is OPEN (or HALF_OPEN with no free slot) - reject immediately
                state_info = state.get_state()
                logger.warning(f"Circuit breaker OPEN for {call.tool} (failures: {state.failure_count})")

                error = ToolCircuitOpenError(
                    tool_name=call.tool,
                    failure_count=state.failure_count,
                    reset_timeout=state_info.get("time_until_half_open"),
                )
                results.append(self._error_result(call.tool, str(error)))
                continue

            # Remember whether THIS call claimed a HALF_OPEN slot. Reading the
            # state right after admission (no suspension point in between)
            # tells us under which state the call was let through. Deciding
            # this at completion time instead would let a call admitted under
            # CLOSED free a slot that belongs to a concurrent probe.
            holds_half_open_slot = state.state == CircuitState.HALF_OPEN

            # Execute the call
            start_time = time.monotonic()
            try:
                # Execute single call
                executor_kwargs: dict[str, Any] = {"timeout": timeout}
                if hasattr(self.executor, "use_cache"):
                    executor_kwargs["use_cache"] = use_cache

                result_list = await self.executor.execute([call], **executor_kwargs)
                result = result_list[0]

                duration = time.monotonic() - start_time

                # Determine success/failure: an error result, or a call slower
                # than the configured threshold, counts against the circuit.
                is_timeout = state.config.timeout_threshold is not None and duration > state.config.timeout_threshold
                is_error = result.error is not None

                if is_error or is_timeout:
                    await state.record_failure()
                else:
                    await state.record_success()

                results.append(result)

            except Exception as e:
                # Exception during execution
                await state.record_failure()
                results.append(self._error_result(call.tool, f"Circuit breaker caught exception: {str(e)}"))

            finally:
                # Release the HALF_OPEN slot only if this call took one.
                # release_half_open_slot() is a no-op once the circuit has
                # left HALF_OPEN, where the slot counter was already reset.
                if holds_half_open_slot:
                    await state.release_half_open_slot()

        return results

    async def get_circuit_states(self) -> dict[str, dict[str, Any]]:
        """
        Get current state of all circuit breakers.

        Returns:
            Dict mapping tool name to state info
        """
        async with self._states_lock:
            return {tool: state.get_state() for tool, state in self._states.items()}

    async def reset_circuit(self, tool: str) -> None:
        """
        Manually reset a circuit breaker.

        Forces the tool's circuit back to CLOSED and clears its counters.
        Does nothing when no state exists for the tool.

        Args:
            tool: Tool name to reset
        """
        if tool not in self._states:
            return

        state = self._states[tool]
        async with state._lock:
            state.state = CircuitState.CLOSED
            state.failure_count = 0
            state.success_count = 0
            state.opened_at = None
            state.half_open_calls = 0
            logger.info(f"Manually reset circuit breaker for {tool}")
|
|
@@ -36,6 +36,7 @@ class RetryConfig:
|
|
|
36
36
|
jitter: bool = True,
|
|
37
37
|
retry_on_exceptions: list[type[Exception]] | None = None,
|
|
38
38
|
retry_on_error_substrings: list[str] | None = None,
|
|
39
|
+
skip_retry_on_error_substrings: list[str] | None = None,
|
|
39
40
|
):
|
|
40
41
|
if max_retries < 0:
|
|
41
42
|
raise ValueError("max_retries cannot be negative")
|
|
@@ -45,6 +46,7 @@ class RetryConfig:
|
|
|
45
46
|
self.jitter = jitter
|
|
46
47
|
self.retry_on_exceptions = retry_on_exceptions or []
|
|
47
48
|
self.retry_on_error_substrings = retry_on_error_substrings or []
|
|
49
|
+
self.skip_retry_on_error_substrings = skip_retry_on_error_substrings or []
|
|
48
50
|
|
|
49
51
|
# --------------------------------------------------------------------- #
|
|
50
52
|
# Decision helpers
|
|
@@ -60,6 +62,14 @@ class RetryConfig:
|
|
|
60
62
|
if attempt >= self.max_retries:
|
|
61
63
|
return False
|
|
62
64
|
|
|
65
|
+
# Check skip list first - these errors should never be retried
|
|
66
|
+
# (e.g., OAuth errors that need to be handled at transport layer)
|
|
67
|
+
if error_str and self.skip_retry_on_error_substrings:
|
|
68
|
+
error_lower = error_str.lower()
|
|
69
|
+
if any(skip_pattern.lower() in error_lower for skip_pattern in self.skip_retry_on_error_substrings):
|
|
70
|
+
logger.debug(f"Skipping retry for error matching skip pattern: {error_str[:100]}")
|
|
71
|
+
return False
|
|
72
|
+
|
|
63
73
|
# Nothing specified → always retry until max_retries reached
|
|
64
74
|
if not self.retry_on_exceptions and not self.retry_on_error_substrings:
|
|
65
75
|
return True
|
|
@@ -246,6 +256,7 @@ def retryable(
|
|
|
246
256
|
jitter: bool = True,
|
|
247
257
|
retry_on_exceptions: list[type[Exception]] | None = None,
|
|
248
258
|
retry_on_error_substrings: list[str] | None = None,
|
|
259
|
+
skip_retry_on_error_substrings: list[str] | None = None,
|
|
249
260
|
):
|
|
250
261
|
"""
|
|
251
262
|
Class decorator that attaches a :class:`RetryConfig` to a *tool* class.
|
|
@@ -267,6 +278,7 @@ def retryable(
|
|
|
267
278
|
jitter=jitter,
|
|
268
279
|
retry_on_exceptions=retry_on_exceptions,
|
|
269
280
|
retry_on_error_substrings=retry_on_error_substrings,
|
|
281
|
+
skip_retry_on_error_substrings=skip_retry_on_error_substrings,
|
|
270
282
|
)
|
|
271
283
|
return cls
|
|
272
284
|
|
|
@@ -41,8 +41,8 @@ async def setup_mcp_http_streamable(
|
|
|
41
41
|
enable_rate_limiting: bool = False,
|
|
42
42
|
global_rate_limit: int | None = None,
|
|
43
43
|
tool_rate_limits: dict[str, tuple] | None = None,
|
|
44
|
-
enable_retries: bool =
|
|
45
|
-
max_retries: int =
|
|
44
|
+
enable_retries: bool = True, # CHANGED: Enabled with OAuth errors excluded
|
|
45
|
+
max_retries: int = 2, # Retry non-OAuth errors (OAuth handled at transport level)
|
|
46
46
|
namespace: str = "http",
|
|
47
47
|
oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
|
|
48
48
|
) -> tuple[ToolProcessor, StreamManager]:
|
|
@@ -102,6 +102,34 @@ async def setup_mcp_http_streamable(
|
|
|
102
102
|
registered = await register_mcp_tools(stream_manager, namespace=namespace)
|
|
103
103
|
|
|
104
104
|
# 3️⃣ build a processor instance configured to your taste
|
|
105
|
+
# IMPORTANT: Retries are enabled but OAuth errors are excluded
|
|
106
|
+
# OAuth refresh happens at transport level with automatic retry
|
|
107
|
+
|
|
108
|
+
# Import RetryConfig to configure OAuth error exclusion
|
|
109
|
+
from chuk_tool_processor.execution.wrappers.retry import RetryConfig
|
|
110
|
+
|
|
111
|
+
# Define OAuth error patterns that should NOT be retried at this level
|
|
112
|
+
# These will be handled by the transport layer's OAuth refresh mechanism
|
|
113
|
+
oauth_error_patterns = [
|
|
114
|
+
"invalid_token",
|
|
115
|
+
"oauth validation",
|
|
116
|
+
"unauthorized",
|
|
117
|
+
"expired token",
|
|
118
|
+
"token expired",
|
|
119
|
+
"authentication failed",
|
|
120
|
+
"invalid access token",
|
|
121
|
+
]
|
|
122
|
+
|
|
123
|
+
# Create retry config that skips OAuth errors
|
|
124
|
+
retry_config = (
|
|
125
|
+
RetryConfig(
|
|
126
|
+
max_retries=max_retries,
|
|
127
|
+
skip_retry_on_error_substrings=oauth_error_patterns,
|
|
128
|
+
)
|
|
129
|
+
if enable_retries
|
|
130
|
+
else None
|
|
131
|
+
)
|
|
132
|
+
|
|
105
133
|
processor = ToolProcessor(
|
|
106
134
|
default_timeout=default_timeout,
|
|
107
135
|
max_concurrency=max_concurrency,
|
|
@@ -112,6 +140,7 @@ async def setup_mcp_http_streamable(
|
|
|
112
140
|
tool_rate_limits=tool_rate_limits,
|
|
113
141
|
enable_retries=enable_retries,
|
|
114
142
|
max_retries=max_retries,
|
|
143
|
+
retry_config=retry_config, # NEW: Pass OAuth-aware retry config
|
|
115
144
|
)
|
|
116
145
|
|
|
117
146
|
logger.debug(
|
|
@@ -37,8 +37,8 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
|
|
|
37
37
|
enable_rate_limiting: bool = False,
|
|
38
38
|
global_rate_limit: int | None = None,
|
|
39
39
|
tool_rate_limits: dict[str, tuple] | None = None,
|
|
40
|
-
enable_retries: bool =
|
|
41
|
-
max_retries: int =
|
|
40
|
+
enable_retries: bool = True, # CHANGED: Enabled with OAuth errors excluded
|
|
41
|
+
max_retries: int = 2, # Retry non-OAuth errors (OAuth handled at transport level)
|
|
42
42
|
namespace: str = "sse",
|
|
43
43
|
oauth_refresh_callback: any | None = None, # NEW: OAuth token refresh callback
|
|
44
44
|
) -> tuple[ToolProcessor, StreamManager]:
|
|
@@ -81,6 +81,34 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
|
|
|
81
81
|
registered = await register_mcp_tools(stream_manager, namespace=namespace)
|
|
82
82
|
|
|
83
83
|
# 3️⃣ build a processor instance configured to your taste
|
|
84
|
+
# IMPORTANT: Retries are enabled but OAuth errors are excluded
|
|
85
|
+
# OAuth refresh happens at transport level with automatic retry
|
|
86
|
+
|
|
87
|
+
# Import RetryConfig to configure OAuth error exclusion
|
|
88
|
+
from chuk_tool_processor.execution.wrappers.retry import RetryConfig
|
|
89
|
+
|
|
90
|
+
# Define OAuth error patterns that should NOT be retried at this level
|
|
91
|
+
# These will be handled by the transport layer's OAuth refresh mechanism
|
|
92
|
+
oauth_error_patterns = [
|
|
93
|
+
"invalid_token",
|
|
94
|
+
"oauth validation",
|
|
95
|
+
"unauthorized",
|
|
96
|
+
"expired token",
|
|
97
|
+
"token expired",
|
|
98
|
+
"authentication failed",
|
|
99
|
+
"invalid access token",
|
|
100
|
+
]
|
|
101
|
+
|
|
102
|
+
# Create retry config that skips OAuth errors
|
|
103
|
+
retry_config = (
|
|
104
|
+
RetryConfig(
|
|
105
|
+
max_retries=max_retries,
|
|
106
|
+
skip_retry_on_error_substrings=oauth_error_patterns,
|
|
107
|
+
)
|
|
108
|
+
if enable_retries
|
|
109
|
+
else None
|
|
110
|
+
)
|
|
111
|
+
|
|
84
112
|
processor = ToolProcessor(
|
|
85
113
|
default_timeout=default_timeout,
|
|
86
114
|
max_concurrency=max_concurrency,
|
|
@@ -91,6 +119,7 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
|
|
|
91
119
|
tool_rate_limits=tool_rate_limits,
|
|
92
120
|
enable_retries=enable_retries,
|
|
93
121
|
max_retries=max_retries,
|
|
122
|
+
retry_config=retry_config, # NEW: Pass OAuth-aware retry config
|
|
94
123
|
)
|
|
95
124
|
|
|
96
125
|
logger.debug(
|
|
@@ -1 +1,21 @@
|
|
|
1
1
|
# chuk_tool_processor/models/__init__.py
|
|
2
|
+
"""Data models for the tool processor."""
|
|
3
|
+
|
|
4
|
+
from chuk_tool_processor.models.execution_strategy import ExecutionStrategy
|
|
5
|
+
from chuk_tool_processor.models.streaming_tool import StreamingTool
|
|
6
|
+
from chuk_tool_processor.models.tool_call import ToolCall
|
|
7
|
+
from chuk_tool_processor.models.tool_result import ToolResult
|
|
8
|
+
from chuk_tool_processor.models.tool_spec import ToolCapability, ToolSpec, tool_spec
|
|
9
|
+
from chuk_tool_processor.models.validated_tool import ValidatedTool, with_validation
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"ExecutionStrategy",
|
|
13
|
+
"StreamingTool",
|
|
14
|
+
"ToolCall",
|
|
15
|
+
"ToolResult",
|
|
16
|
+
"ToolSpec",
|
|
17
|
+
"ToolCapability",
|
|
18
|
+
"tool_spec",
|
|
19
|
+
"ValidatedTool",
|
|
20
|
+
"with_validation",
|
|
21
|
+
]
|
|
@@ -5,10 +5,12 @@ Model representing a tool call with arguments.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
+
import hashlib
|
|
9
|
+
import json
|
|
8
10
|
import uuid
|
|
9
11
|
from typing import Any
|
|
10
12
|
|
|
11
|
-
from pydantic import BaseModel, ConfigDict, Field
|
|
13
|
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
12
14
|
|
|
13
15
|
|
|
14
16
|
class ToolCall(BaseModel):
|
|
@@ -20,6 +22,7 @@ class ToolCall(BaseModel):
|
|
|
20
22
|
tool: Name of the tool to call
|
|
21
23
|
namespace: Namespace the tool belongs to
|
|
22
24
|
arguments: Arguments to pass to the tool
|
|
25
|
+
idempotency_key: Optional key for deduplicating duplicate calls (auto-generated)
|
|
23
26
|
"""
|
|
24
27
|
|
|
25
28
|
model_config = ConfigDict(extra="ignore")
|
|
@@ -28,6 +31,36 @@ class ToolCall(BaseModel):
|
|
|
28
31
|
tool: str = Field(..., min_length=1, description="Name of the tool to call; must be non-empty")
|
|
29
32
|
namespace: str = Field(default="default", description="Namespace the tool belongs to")
|
|
30
33
|
arguments: dict[str, Any] = Field(default_factory=dict, description="Arguments to pass to the tool")
|
|
34
|
+
idempotency_key: str | None = Field(
|
|
35
|
+
None,
|
|
36
|
+
description="Idempotency key for deduplication. Auto-generated if not provided.",
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
@model_validator(mode="after")
def generate_idempotency_key(self) -> ToolCall:
    """Fill in ``idempotency_key`` with a computed value when none was supplied."""
    # A key that is already set is left untouched.
    if self.idempotency_key is not None:
        return self

    self.idempotency_key = self._compute_idempotency_key()
    return self
|
|
45
|
+
|
|
46
|
+
def _compute_idempotency_key(self) -> str:
|
|
47
|
+
"""
|
|
48
|
+
Compute a stable idempotency key from tool name, namespace, and arguments.
|
|
49
|
+
|
|
50
|
+
Uses SHA256 hash of the sorted JSON representation.
|
|
51
|
+
Returns first 16 characters of the hex digest for brevity.
|
|
52
|
+
"""
|
|
53
|
+
# Create a stable representation
|
|
54
|
+
payload = {
|
|
55
|
+
"tool": self.tool,
|
|
56
|
+
"namespace": self.namespace,
|
|
57
|
+
"arguments": self.arguments,
|
|
58
|
+
}
|
|
59
|
+
# Sort keys for stability
|
|
60
|
+
json_str = json.dumps(payload, sort_keys=True, default=str)
|
|
61
|
+
# Hash it
|
|
62
|
+
hash_obj = hashlib.sha256(json_str.encode(), usedforsecurity=False)
|
|
63
|
+
return hash_obj.hexdigest()[:16] # Use first 16 chars for brevity
|
|
31
64
|
|
|
32
65
|
async def to_dict(self) -> dict[str, Any]:
|
|
33
66
|
"""Convert to a dictionary for serialization."""
|