chuk-tool-processor 0.6.29__py3-none-any.whl → 0.8__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of chuk-tool-processor might be problematic (see the registry page for details).

@@ -0,0 +1,343 @@
+ # chuk_tool_processor/execution/wrappers/circuit_breaker.py
+ """
+ Circuit breaker pattern for tool execution.
+
+ Prevents cascading failures by tracking failure rates and temporarily
+ blocking calls to failing tools. Implements a state machine:
+
+     CLOSED → OPEN → HALF_OPEN → CLOSED (or back to OPEN)
+
+ States:
+ - CLOSED: Normal operation, requests pass through
+ - OPEN: Too many failures, requests blocked immediately
+ - HALF_OPEN: Testing if service recovered, limited requests allowed
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import time
+ from datetime import UTC, datetime
+ from enum import Enum
+ from typing import Any
+
+ from chuk_tool_processor.core.exceptions import ToolCircuitOpenError
+ from chuk_tool_processor.logging import get_logger
+ from chuk_tool_processor.models.tool_call import ToolCall
+ from chuk_tool_processor.models.tool_result import ToolResult
+
+ logger = get_logger("chuk_tool_processor.execution.wrappers.circuit_breaker")
+
+
+ # --------------------------------------------------------------------------- #
+ # Circuit breaker state
+ # --------------------------------------------------------------------------- #
+ class CircuitState(str, Enum):
+     """Circuit breaker states."""
+
+     CLOSED = "closed"  # Normal operation
+     OPEN = "open"  # Blocking requests due to failures
+     HALF_OPEN = "half_open"  # Testing recovery with limited requests
+
+
+ class CircuitBreakerConfig:
+     """Configuration for circuit breaker behavior."""
+
+     def __init__(
+         self,
+         failure_threshold: int = 5,
+         success_threshold: int = 2,
+         reset_timeout: float = 60.0,
+         half_open_max_calls: int = 1,
+         timeout_threshold: float | None = None,
+     ):
+         """
+         Initialize circuit breaker configuration.
+
+         Args:
+             failure_threshold: Number of failures before opening circuit
+             success_threshold: Number of successes in HALF_OPEN to close circuit
+             reset_timeout: Seconds to wait before trying HALF_OPEN
+             half_open_max_calls: Max concurrent calls in HALF_OPEN state
+             timeout_threshold: Optional timeout (s) to consider as failure
+         """
+         self.failure_threshold = failure_threshold
+         self.success_threshold = success_threshold
+         self.reset_timeout = reset_timeout
+         self.half_open_max_calls = half_open_max_calls
+         self.timeout_threshold = timeout_threshold
+
+
+ class CircuitBreakerState:
+     """Per-tool circuit breaker state tracking."""
+
+     def __init__(self, config: CircuitBreakerConfig):
+         self.config = config
+         self.state = CircuitState.CLOSED
+         self.failure_count = 0
+         self.success_count = 0
+         self.last_failure_time: float | None = None
+         self.opened_at: float | None = None
+         self.half_open_calls = 0
+         self._lock = asyncio.Lock()
+
+     async def record_success(self) -> None:
+         """Record a successful call."""
+         async with self._lock:
+             if self.state == CircuitState.HALF_OPEN:
+                 self.success_count += 1
+                 logger.debug(f"Circuit HALF_OPEN: success {self.success_count}/{self.config.success_threshold}")
+
+                 # Enough successes? Close the circuit
+                 if self.success_count >= self.config.success_threshold:
+                     logger.info("Circuit breaker: Transitioning to CLOSED (service recovered)")
+                     self.state = CircuitState.CLOSED
+                     self.failure_count = 0
+                     self.success_count = 0
+                     self.opened_at = None
+                     self.half_open_calls = 0
+             else:
+                 # In CLOSED state, just reset failure count
+                 self.failure_count = 0
+
+     async def record_failure(self) -> None:
+         """Record a failed call."""
+         async with self._lock:
+             self.failure_count += 1
+             self.last_failure_time = time.monotonic()
+             logger.debug(f"Circuit: failure {self.failure_count}/{self.config.failure_threshold}")
+
+             if self.state == CircuitState.CLOSED:
+                 # Check if we should open
+                 if self.failure_count >= self.config.failure_threshold:
+                     logger.warning(f"Circuit breaker: OPENING after {self.failure_count} failures")
+                     self.state = CircuitState.OPEN
+                     self.opened_at = time.monotonic()
+             elif self.state == CircuitState.HALF_OPEN:
+                 # Failed during test → back to OPEN
+                 logger.warning("Circuit breaker: Back to OPEN (test failed)")
+                 self.state = CircuitState.OPEN
+                 self.success_count = 0
+                 self.opened_at = time.monotonic()
+                 self.half_open_calls = 0
+
+     async def can_execute(self) -> bool:
+         """Check if a call should be allowed through."""
+         async with self._lock:
+             if self.state == CircuitState.CLOSED:
+                 return True
+
+             if self.state == CircuitState.HALF_OPEN:
+                 # Limit concurrent calls in HALF_OPEN
+                 if self.half_open_calls < self.config.half_open_max_calls:
+                     self.half_open_calls += 1
+                     return True
+                 return False
+
+             # OPEN state: check if we should try HALF_OPEN
+             if self.opened_at is not None:
+                 elapsed = time.monotonic() - self.opened_at
+                 if elapsed >= self.config.reset_timeout:
+                     logger.info("Circuit breaker: Transitioning to HALF_OPEN (testing recovery)")
+                     self.state = CircuitState.HALF_OPEN
+                     self.half_open_calls = 1
+                     self.success_count = 0
+                     return True
+
+             return False
+
+     async def release_half_open_slot(self) -> None:
+         """Release a HALF_OPEN slot after call completes."""
+         async with self._lock:
+             if self.state == CircuitState.HALF_OPEN:
+                 self.half_open_calls = max(0, self.half_open_calls - 1)
+
+     def get_state(self) -> dict[str, Any]:
+         """Get current state as dict."""
+         return {
+             "state": self.state.value,
+             "failure_count": self.failure_count,
+             "success_count": self.success_count,
+             "opened_at": self.opened_at,
+             "time_until_half_open": (
+                 max(0, self.config.reset_timeout - (time.monotonic() - self.opened_at))
+                 if self.opened_at and self.state == CircuitState.OPEN
+                 else None
+             ),
+         }
+
+
+ # --------------------------------------------------------------------------- #
+ # Circuit breaker executor wrapper
+ # --------------------------------------------------------------------------- #
+ class CircuitBreakerExecutor:
+     """
+     Executor wrapper that implements circuit breaker pattern.
+
+     Tracks failures per tool and opens circuit breakers to prevent
+     cascading failures when tools are consistently failing.
+     """
+
+     def __init__(
+         self,
+         executor: Any,
+         *,
+         default_config: CircuitBreakerConfig | None = None,
+         tool_configs: dict[str, CircuitBreakerConfig] | None = None,
+     ):
+         """
+         Initialize circuit breaker executor.
+
+         Args:
+             executor: Underlying executor to wrap
+             default_config: Default circuit breaker configuration
+             tool_configs: Per-tool circuit breaker configurations
+         """
+         self.executor = executor
+         self.default_config = default_config or CircuitBreakerConfig()
+         self.tool_configs = tool_configs or {}
+         self._states: dict[str, CircuitBreakerState] = {}
+         self._states_lock = asyncio.Lock()
+
+     async def _get_state(self, tool: str) -> CircuitBreakerState:
+         """Get or create circuit breaker state for a tool."""
+         if tool not in self._states:
+             async with self._states_lock:
+                 if tool not in self._states:
+                     config = self.tool_configs.get(tool, self.default_config)
+                     self._states[tool] = CircuitBreakerState(config)
+         return self._states[tool]
+
+     async def execute(
+         self,
+         calls: list[ToolCall],
+         *,
+         timeout: float | None = None,
+         use_cache: bool = True,
+     ) -> list[ToolResult]:
+         """
+         Execute tool calls with circuit breaker protection.
+
+         Args:
+             calls: List of tool calls to execute
+             timeout: Optional timeout for execution
+             use_cache: Whether to use cached results
+
+         Returns:
+             List of tool results
+         """
+         if not calls:
+             return []
+
+         results: list[ToolResult] = []
+
+         for call in calls:
+             state = await self._get_state(call.tool)
+
+             # Check if circuit allows execution
+             can_execute = await state.can_execute()
+
+             if not can_execute:
+                 # Circuit is OPEN - reject immediately
+                 state_info = state.get_state()
+                 logger.warning(f"Circuit breaker OPEN for {call.tool} (failures: {state.failure_count})")
+
+                 reset_time = state_info.get("time_until_half_open")
+                 error = ToolCircuitOpenError(
+                     tool_name=call.tool,
+                     failure_count=state.failure_count,
+                     reset_timeout=reset_time,
+                 )
+
+                 now = datetime.now(UTC)
+                 results.append(
+                     ToolResult(
+                         tool=call.tool,
+                         result=None,
+                         error=str(error),
+                         start_time=now,
+                         end_time=now,
+                         machine="circuit_breaker",
+                         pid=0,
+                     )
+                 )
+                 continue
+
+             # Execute the call
+             start_time = time.monotonic()
+             try:
+                 # Execute single call
+                 executor_kwargs = {"timeout": timeout}
+                 if hasattr(self.executor, "use_cache"):
+                     executor_kwargs["use_cache"] = use_cache
+
+                 result_list = await self.executor.execute([call], **executor_kwargs)
+                 result = result_list[0]
+
+                 # Check if successful
+                 duration = time.monotonic() - start_time
+
+                 # Determine success/failure
+                 is_timeout = state.config.timeout_threshold is not None and duration > state.config.timeout_threshold
+                 is_error = result.error is not None
+
+                 if is_error or is_timeout:
+                     await state.record_failure()
+                 else:
+                     await state.record_success()
+
+                 results.append(result)
+
+             except Exception as e:
+                 # Exception during execution
+                 await state.record_failure()
+
+                 now = datetime.now(UTC)
+                 results.append(
+                     ToolResult(
+                         tool=call.tool,
+                         result=None,
+                         error=f"Circuit breaker caught exception: {str(e)}",
+                         start_time=now,
+                         end_time=now,
+                         machine="circuit_breaker",
+                         pid=0,
+                     )
+                 )
+
+             finally:
+                 # Release HALF_OPEN slot if applicable
+                 if state.state == CircuitState.HALF_OPEN:
+                     await state.release_half_open_slot()
+
+         return results
+
+     async def get_circuit_states(self) -> dict[str, dict[str, Any]]:
+         """
+         Get current state of all circuit breakers.
+
+         Returns:
+             Dict mapping tool name to state info
+         """
+         states = {}
+         async with self._states_lock:
+             for tool, state in self._states.items():
+                 states[tool] = state.get_state()
+         return states
+
+     async def reset_circuit(self, tool: str) -> None:
+         """
+         Manually reset a circuit breaker.
+
+         Args:
+             tool: Tool name to reset
+         """
+         if tool in self._states:
+             state = self._states[tool]
+             async with state._lock:
+                 state.state = CircuitState.CLOSED
+                 state.failure_count = 0
+                 state.success_count = 0
+                 state.opened_at = None
+                 state.half_open_calls = 0
+             logger.info(f"Manually reset circuit breaker for {tool}")
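For orientation, here is a minimal usage sketch of the new wrapper, not taken from the package. The `FlakyExecutor` stub is hypothetical and stands in for any object exposing an async `execute(calls, *, timeout=None)` that returns `list[ToolResult]`; the other names come from the hunk above.

```python
import asyncio
from datetime import UTC, datetime

from chuk_tool_processor.execution.wrappers.circuit_breaker import (
    CircuitBreakerConfig,
    CircuitBreakerExecutor,
)
from chuk_tool_processor.models.tool_call import ToolCall
from chuk_tool_processor.models.tool_result import ToolResult


class FlakyExecutor:
    """Hypothetical stand-in executor that always fails, to drive CLOSED → OPEN."""

    async def execute(self, calls, *, timeout=None):
        now = datetime.now(UTC)
        return [
            ToolResult(
                tool=c.tool, result=None, error="boom",
                start_time=now, end_time=now, machine="demo", pid=0,
            )
            for c in calls
        ]


async def main() -> None:
    breaker = CircuitBreakerExecutor(
        FlakyExecutor(),
        default_config=CircuitBreakerConfig(failure_threshold=3, reset_timeout=5.0),
    )
    call = ToolCall(tool="search", arguments={"q": "hello"})
    for attempt in range(5):
        results = await breaker.execute([call])
        # After three recorded failures the breaker opens and subsequent calls
        # are rejected immediately with a ToolCircuitOpenError message.
        print(attempt, results[0].error)
    print(await breaker.get_circuit_states())


asyncio.run(main())
```

Per-tool overrides can be supplied via `tool_configs={"tool_name": CircuitBreakerConfig(...)}`; tools without an entry fall back to `default_config`.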
@@ -36,6 +36,7 @@ class RetryConfig:
        jitter: bool = True,
        retry_on_exceptions: list[type[Exception]] | None = None,
        retry_on_error_substrings: list[str] | None = None,
+       skip_retry_on_error_substrings: list[str] | None = None,
    ):
        if max_retries < 0:
            raise ValueError("max_retries cannot be negative")
@@ -45,6 +46,7 @@ class RetryConfig:
        self.jitter = jitter
        self.retry_on_exceptions = retry_on_exceptions or []
        self.retry_on_error_substrings = retry_on_error_substrings or []
+       self.skip_retry_on_error_substrings = skip_retry_on_error_substrings or []

    # --------------------------------------------------------------------- #
    # Decision helpers
@@ -60,6 +62,14 @@ class RetryConfig:
        if attempt >= self.max_retries:
            return False

+       # Check skip list first - these errors should never be retried
+       # (e.g., OAuth errors that need to be handled at transport layer)
+       if error_str and self.skip_retry_on_error_substrings:
+           error_lower = error_str.lower()
+           if any(skip_pattern.lower() in error_lower for skip_pattern in self.skip_retry_on_error_substrings):
+               logger.debug(f"Skipping retry for error matching skip pattern: {error_str[:100]}")
+               return False
+
        # Nothing specified → always retry until max_retries reached
        if not self.retry_on_exceptions and not self.retry_on_error_substrings:
            return True
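To make the precedence concrete, here is a small sketch (not from the package) that builds a `RetryConfig` with both lists and reproduces the case-insensitive substring test used above. It assumes the remaining constructor parameters keep their defaults; the real decision lives in the `should_retry` helper shown in this hunk.

```python
from chuk_tool_processor.execution.wrappers.retry import RetryConfig

config = RetryConfig(
    max_retries=3,
    retry_on_error_substrings=["timeout", "temporarily unavailable"],
    skip_retry_on_error_substrings=["invalid_token", "unauthorized"],
)


def hits_skip_list(error_str: str) -> bool:
    """Same substring check as above: skip patterns win before any retry rule."""
    error_lower = error_str.lower()
    return any(p.lower() in error_lower for p in config.skip_retry_on_error_substrings)


print(hits_skip_list("401 Unauthorized"))    # True  -> never retried here
print(hits_skip_list("connection timeout"))  # False -> normal retry rules apply
```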
@@ -246,6 +256,7 @@ def retryable(
    jitter: bool = True,
    retry_on_exceptions: list[type[Exception]] | None = None,
    retry_on_error_substrings: list[str] | None = None,
+   skip_retry_on_error_substrings: list[str] | None = None,
):
    """
    Class decorator that attaches a :class:`RetryConfig` to a *tool* class.
@@ -267,6 +278,7 @@ def retryable(
            jitter=jitter,
            retry_on_exceptions=retry_on_exceptions,
            retry_on_error_substrings=retry_on_error_substrings,
+           skip_retry_on_error_substrings=skip_retry_on_error_substrings,
        )
        return cls

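A hedged sketch of the decorator with the new keyword. The `FetchTool` class and its `execute` method are illustrative only; per the docstring above, the decorator simply attaches a `RetryConfig` to the class, and only keywords visible in these hunks are used.

```python
from chuk_tool_processor.execution.wrappers.retry import retryable


@retryable(
    retry_on_error_substrings=["timeout"],
    skip_retry_on_error_substrings=["unauthorized", "invalid_token"],
)
class FetchTool:
    """Illustrative tool: transient timeouts are retried, auth errors are not."""

    async def execute(self, url: str) -> str:
        ...
```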
@@ -41,8 +41,8 @@ async def setup_mcp_http_streamable(
    enable_rate_limiting: bool = False,
    global_rate_limit: int | None = None,
    tool_rate_limits: dict[str, tuple] | None = None,
-   enable_retries: bool = True,
-   max_retries: int = 3,
+   enable_retries: bool = True,  # CHANGED: Enabled with OAuth errors excluded
+   max_retries: int = 2,  # Retry non-OAuth errors (OAuth handled at transport level)
    namespace: str = "http",
    oauth_refresh_callback: any | None = None,  # NEW: OAuth token refresh callback
) -> tuple[ToolProcessor, StreamManager]:
@@ -102,6 +102,34 @@ async def setup_mcp_http_streamable(
    registered = await register_mcp_tools(stream_manager, namespace=namespace)

    # 3️⃣ build a processor instance configured to your taste
+   # IMPORTANT: Retries are enabled but OAuth errors are excluded
+   # OAuth refresh happens at transport level with automatic retry
+
+   # Import RetryConfig to configure OAuth error exclusion
+   from chuk_tool_processor.execution.wrappers.retry import RetryConfig
+
+   # Define OAuth error patterns that should NOT be retried at this level
+   # These will be handled by the transport layer's OAuth refresh mechanism
+   oauth_error_patterns = [
+       "invalid_token",
+       "oauth validation",
+       "unauthorized",
+       "expired token",
+       "token expired",
+       "authentication failed",
+       "invalid access token",
+   ]
+
+   # Create retry config that skips OAuth errors
+   retry_config = (
+       RetryConfig(
+           max_retries=max_retries,
+           skip_retry_on_error_substrings=oauth_error_patterns,
+       )
+       if enable_retries
+       else None
+   )
+
    processor = ToolProcessor(
        default_timeout=default_timeout,
        max_concurrency=max_concurrency,
@@ -112,6 +140,7 @@ async def setup_mcp_http_streamable(
        tool_rate_limits=tool_rate_limits,
        enable_retries=enable_retries,
        max_retries=max_retries,
+       retry_config=retry_config,  # NEW: Pass OAuth-aware retry config
    )

    logger.debug(
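The same OAuth-aware configuration can be reproduced outside the facade. A sketch under assumptions: it presumes `ToolProcessor` is importable from the package root and can be constructed with only the keyword arguments visible in these hunks; check the actual API before relying on it.

```python
from chuk_tool_processor import ToolProcessor  # import path assumed
from chuk_tool_processor.execution.wrappers.retry import RetryConfig

oauth_error_patterns = ["invalid_token", "unauthorized", "token expired"]

processor = ToolProcessor(
    default_timeout=30.0,
    enable_retries=True,
    max_retries=2,
    # OAuth failures fall through untouched so the transport layer's
    # refresh callback can handle them; everything else is retried.
    retry_config=RetryConfig(
        max_retries=2,
        skip_retry_on_error_substrings=oauth_error_patterns,
    ),
)
```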
@@ -37,8 +37,8 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
    enable_rate_limiting: bool = False,
    global_rate_limit: int | None = None,
    tool_rate_limits: dict[str, tuple] | None = None,
-   enable_retries: bool = True,
-   max_retries: int = 3,
+   enable_retries: bool = True,  # CHANGED: Enabled with OAuth errors excluded
+   max_retries: int = 2,  # Retry non-OAuth errors (OAuth handled at transport level)
    namespace: str = "sse",
    oauth_refresh_callback: any | None = None,  # NEW: OAuth token refresh callback
) -> tuple[ToolProcessor, StreamManager]:
@@ -81,6 +81,34 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
    registered = await register_mcp_tools(stream_manager, namespace=namespace)

    # 3️⃣ build a processor instance configured to your taste
+   # IMPORTANT: Retries are enabled but OAuth errors are excluded
+   # OAuth refresh happens at transport level with automatic retry
+
+   # Import RetryConfig to configure OAuth error exclusion
+   from chuk_tool_processor.execution.wrappers.retry import RetryConfig
+
+   # Define OAuth error patterns that should NOT be retried at this level
+   # These will be handled by the transport layer's OAuth refresh mechanism
+   oauth_error_patterns = [
+       "invalid_token",
+       "oauth validation",
+       "unauthorized",
+       "expired token",
+       "token expired",
+       "authentication failed",
+       "invalid access token",
+   ]
+
+   # Create retry config that skips OAuth errors
+   retry_config = (
+       RetryConfig(
+           max_retries=max_retries,
+           skip_retry_on_error_substrings=oauth_error_patterns,
+       )
+       if enable_retries
+       else None
+   )
+
    processor = ToolProcessor(
        default_timeout=default_timeout,
        max_concurrency=max_concurrency,
@@ -91,6 +119,7 @@ async def setup_mcp_sse( # noqa: C901 - long but just a config facade
        tool_rate_limits=tool_rate_limits,
        enable_retries=enable_retries,
        max_retries=max_retries,
+       retry_config=retry_config,  # NEW: Pass OAuth-aware retry config
    )

    logger.debug(
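For completeness, here is how a caller might use the SSE facade with the new defaults. This is a sketch only: the import path is assumed, and the `servers=` argument and `refresh_token` callback are hypothetical placeholders; only the keywords shown in the hunks above come from the diff.

```python
import asyncio

from chuk_tool_processor.mcp import setup_mcp_sse  # import path assumed


async def main() -> None:
    async def refresh_token() -> str:  # hypothetical refresh callback
        return "new-access-token"

    processor, stream_manager = await setup_mcp_sse(
        servers=[{"name": "demo", "url": "https://example.com/sse"}],  # hypothetical shape
        enable_retries=True,   # non-OAuth errors retried; OAuth handled by the transport
        max_retries=2,
        namespace="sse",
        oauth_refresh_callback=refresh_token,
    )


asyncio.run(main())
```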
@@ -1 +1,21 @@
  # chuk_tool_processor/models/__init__.py
+ """Data models for the tool processor."""
+
+ from chuk_tool_processor.models.execution_strategy import ExecutionStrategy
+ from chuk_tool_processor.models.streaming_tool import StreamingTool
+ from chuk_tool_processor.models.tool_call import ToolCall
+ from chuk_tool_processor.models.tool_result import ToolResult
+ from chuk_tool_processor.models.tool_spec import ToolCapability, ToolSpec, tool_spec
+ from chuk_tool_processor.models.validated_tool import ValidatedTool, with_validation
+
+ __all__ = [
+     "ExecutionStrategy",
+     "StreamingTool",
+     "ToolCall",
+     "ToolResult",
+     "ToolSpec",
+     "ToolCapability",
+     "tool_spec",
+     "ValidatedTool",
+     "with_validation",
+ ]
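With these re-exports in place, callers can import the models from the package namespace rather than from the individual modules, for example:

```python
# Previously these had to be imported from their individual modules.
from chuk_tool_processor.models import ToolCall, ToolResult, ValidatedTool

call = ToolCall(tool="echo", arguments={"text": "hi"})
```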
@@ -5,10 +5,12 @@ Model representing a tool call with arguments.

  from __future__ import annotations

+ import hashlib
+ import json
  import uuid
  from typing import Any

- from pydantic import BaseModel, ConfigDict, Field
+ from pydantic import BaseModel, ConfigDict, Field, model_validator


  class ToolCall(BaseModel):
@@ -20,6 +22,7 @@ class ToolCall(BaseModel):
        tool: Name of the tool to call
        namespace: Namespace the tool belongs to
        arguments: Arguments to pass to the tool
+       idempotency_key: Optional key for deduplicating duplicate calls (auto-generated)
    """

    model_config = ConfigDict(extra="ignore")
@@ -28,6 +31,36 @@ class ToolCall(BaseModel):
    tool: str = Field(..., min_length=1, description="Name of the tool to call; must be non-empty")
    namespace: str = Field(default="default", description="Namespace the tool belongs to")
    arguments: dict[str, Any] = Field(default_factory=dict, description="Arguments to pass to the tool")
+   idempotency_key: str | None = Field(
+       None,
+       description="Idempotency key for deduplication. Auto-generated if not provided.",
+   )
+
+   @model_validator(mode="after")
+   def generate_idempotency_key(self) -> ToolCall:
+       """Generate idempotency key if not provided."""
+       if self.idempotency_key is None:
+           self.idempotency_key = self._compute_idempotency_key()
+       return self
+
+   def _compute_idempotency_key(self) -> str:
+       """
+       Compute a stable idempotency key from tool name, namespace, and arguments.
+
+       Uses SHA256 hash of the sorted JSON representation.
+       Returns first 16 characters of the hex digest for brevity.
+       """
+       # Create a stable representation
+       payload = {
+           "tool": self.tool,
+           "namespace": self.namespace,
+           "arguments": self.arguments,
+       }
+       # Sort keys for stability
+       json_str = json.dumps(payload, sort_keys=True, default=str)
+       # Hash it
+       hash_obj = hashlib.sha256(json_str.encode(), usedforsecurity=False)
+       return hash_obj.hexdigest()[:16]  # Use first 16 chars for brevity

    async def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary for serialization."""