chuk-tool-processor 0.1.6__py3-none-any.whl → 0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of chuk-tool-processor might be problematic; see the package's registry page for more details.

Files changed (46):
  1. chuk_tool_processor/core/processor.py +345 -132
  2. chuk_tool_processor/execution/strategies/inprocess_strategy.py +522 -71
  3. chuk_tool_processor/execution/strategies/subprocess_strategy.py +559 -64
  4. chuk_tool_processor/execution/tool_executor.py +282 -24
  5. chuk_tool_processor/execution/wrappers/caching.py +465 -123
  6. chuk_tool_processor/execution/wrappers/rate_limiting.py +199 -86
  7. chuk_tool_processor/execution/wrappers/retry.py +133 -23
  8. chuk_tool_processor/logging/__init__.py +83 -10
  9. chuk_tool_processor/logging/context.py +218 -22
  10. chuk_tool_processor/logging/formatter.py +56 -13
  11. chuk_tool_processor/logging/helpers.py +91 -16
  12. chuk_tool_processor/logging/metrics.py +75 -6
  13. chuk_tool_processor/mcp/mcp_tool.py +80 -35
  14. chuk_tool_processor/mcp/register_mcp_tools.py +74 -56
  15. chuk_tool_processor/mcp/setup_mcp_sse.py +41 -36
  16. chuk_tool_processor/mcp/setup_mcp_stdio.py +39 -37
  17. chuk_tool_processor/mcp/transport/sse_transport.py +351 -105
  18. chuk_tool_processor/models/execution_strategy.py +52 -3
  19. chuk_tool_processor/models/streaming_tool.py +110 -0
  20. chuk_tool_processor/models/tool_call.py +56 -4
  21. chuk_tool_processor/models/tool_result.py +115 -9
  22. chuk_tool_processor/models/validated_tool.py +15 -13
  23. chuk_tool_processor/plugins/discovery.py +115 -70
  24. chuk_tool_processor/plugins/parsers/base.py +13 -5
  25. chuk_tool_processor/plugins/parsers/{function_call_tool_plugin.py → function_call_tool.py} +39 -20
  26. chuk_tool_processor/plugins/parsers/json_tool.py +50 -0
  27. chuk_tool_processor/plugins/parsers/openai_tool.py +88 -0
  28. chuk_tool_processor/plugins/parsers/xml_tool.py +74 -20
  29. chuk_tool_processor/registry/__init__.py +46 -7
  30. chuk_tool_processor/registry/auto_register.py +92 -28
  31. chuk_tool_processor/registry/decorators.py +134 -11
  32. chuk_tool_processor/registry/interface.py +48 -14
  33. chuk_tool_processor/registry/metadata.py +52 -6
  34. chuk_tool_processor/registry/provider.py +75 -36
  35. chuk_tool_processor/registry/providers/__init__.py +49 -10
  36. chuk_tool_processor/registry/providers/memory.py +59 -48
  37. chuk_tool_processor/registry/tool_export.py +208 -39
  38. chuk_tool_processor/utils/validation.py +18 -13
  39. chuk_tool_processor-0.2.dist-info/METADATA +401 -0
  40. chuk_tool_processor-0.2.dist-info/RECORD +58 -0
  41. {chuk_tool_processor-0.1.6.dist-info → chuk_tool_processor-0.2.dist-info}/WHEEL +1 -1
  42. chuk_tool_processor/plugins/parsers/json_tool_plugin.py +0 -38
  43. chuk_tool_processor/plugins/parsers/openai_tool_plugin.py +0 -76
  44. chuk_tool_processor-0.1.6.dist-info/METADATA +0 -462
  45. chuk_tool_processor-0.1.6.dist-info/RECORD +0 -57
  46. {chuk_tool_processor-0.1.6.dist-info → chuk_tool_processor-0.2.dist-info}/top_level.txt +0 -0
@@ -1,234 +1,576 @@
1
1
  # chuk_tool_processor/execution/wrappers/caching.py
2
+ """
3
+ Async-native caching wrapper for tool execution.
4
+
5
+ This module provides:
6
+
7
+ * **CacheInterface** – abstract async cache contract for custom implementations
8
+ * **InMemoryCache** – simple, thread-safe in-memory cache with TTL support
9
+ * **CachingToolExecutor** – executor wrapper that transparently caches results
10
+
11
+ Results retrieved from cache are marked with `cached=True` and `machine="cache"`
12
+ for easy detection.
13
+ """
14
+ from __future__ import annotations
15
+
2
16
  import asyncio
3
17
  import hashlib
4
18
  import json
5
- import time
19
+ import logging
6
20
  from abc import ABC, abstractmethod
7
- from datetime import datetime, timedelta
8
- from functools import wraps
9
- from typing import Any, Dict, Optional, Tuple, List, Callable
10
- from pydantic import BaseModel
21
+ from datetime import datetime, timedelta, timezone
22
+ from typing import Any, Dict, List, Optional, Tuple, Set, Union
23
+
24
+ from pydantic import BaseModel, Field
11
25
 
12
- # imports
13
26
  from chuk_tool_processor.models.tool_call import ToolCall
14
27
  from chuk_tool_processor.models.tool_result import ToolResult
28
+ from chuk_tool_processor.logging import get_logger
15
29
 
30
+ logger = get_logger("chuk_tool_processor.execution.wrappers.caching")
16
31
 
32
+ # --------------------------------------------------------------------------- #
33
+ # Cache primitives
34
+ # --------------------------------------------------------------------------- #
17
35
  class CacheEntry(BaseModel):
18
36
  """
19
- Entry in the tool result cache.
37
+ Model representing a cached tool result.
38
+
39
+ Attributes:
40
+ tool: Name of the tool
41
+ arguments_hash: Hash of the tool arguments
42
+ result: The cached result value
43
+ created_at: When the entry was created
44
+ expires_at: When the entry expires (None = no expiration)
20
45
  """
21
- tool: str
22
- arguments_hash: str
23
- result: Any
24
- created_at: datetime
25
- expires_at: Optional[datetime] = None
46
+ tool: str = Field(..., description="Tool name")
47
+ arguments_hash: str = Field(..., description="MD5 hash of arguments")
48
+ result: Any = Field(..., description="Cached result value")
49
+ created_at: datetime = Field(..., description="Creation timestamp")
50
+ expires_at: Optional[datetime] = Field(None, description="Expiration timestamp")
26
51
 
27
52
 
28
53
  class CacheInterface(ABC):
29
54
  """
30
- Abstract interface for cache implementations.
55
+ Abstract interface for tool result caches.
56
+
57
+ All cache implementations must be async-native and thread-safe.
31
58
  """
59
+
32
60
  @abstractmethod
33
61
  async def get(self, tool: str, arguments_hash: str) -> Optional[Any]:
34
62
  """
35
- Get a cached result for a tool with given arguments hash.
63
+ Get a cached result by tool name and arguments hash.
64
+
65
+ Args:
66
+ tool: Tool name
67
+ arguments_hash: Hash of the arguments
68
+
69
+ Returns:
70
+ Cached result value or None if not found
36
71
  """
37
72
  pass
38
-
73
+
39
74
  @abstractmethod
40
75
  async def set(
41
- self,
42
- tool: str,
43
- arguments_hash: str,
44
- result: Any,
45
- ttl: Optional[int] = None
76
+ self,
77
+ tool: str,
78
+ arguments_hash: str,
79
+ result: Any,
80
+ *,
81
+ ttl: Optional[int] = None,
46
82
  ) -> None:
47
83
  """
48
- Set a cached result for a tool with given arguments hash.
84
+ Set a cache entry.
85
+
86
+ Args:
87
+ tool: Tool name
88
+ arguments_hash: Hash of the arguments
89
+ result: Result value to cache
90
+ ttl: Time-to-live in seconds (overrides default)
49
91
  """
50
92
  pass
51
-
93
+
52
94
  @abstractmethod
53
95
  async def invalidate(self, tool: str, arguments_hash: Optional[str] = None) -> None:
54
96
  """
55
- Invalidate cached results for a tool, optionally for specific arguments.
97
+ Invalidate cache entries.
98
+
99
+ Args:
100
+ tool: Tool name
101
+ arguments_hash: Optional arguments hash. If None, all entries for the tool are invalidated.
56
102
  """
57
103
  pass
104
+
105
+ async def clear(self) -> None:
106
+ """
107
+ Clear all cache entries.
108
+
109
+ Default implementation raises NotImplementedError.
110
+ Override in subclasses to provide an efficient implementation.
111
+ """
112
+ raise NotImplementedError("Cache clear not implemented")
113
+
114
+ async def get_stats(self) -> Dict[str, Any]:
115
+ """
116
+ Get cache statistics.
117
+
118
+ Returns:
119
+ Dict with cache statistics (implementation-specific)
120
+ """
121
+ return {"implemented": False}
58
122
 
59
123
 
60
124
  class InMemoryCache(CacheInterface):
61
125
  """
62
- In-memory implementation of the cache interface.
126
+ In-memory cache implementation with async thread-safety.
127
+
128
+ This cache uses a two-level dictionary structure with asyncio locks
129
+ to ensure thread safety. Entries can have optional TTL values.
63
130
  """
64
- def __init__(self, default_ttl: Optional[int] = 300):
131
+
132
+ def __init__(self, default_ttl: Optional[int] = 300) -> None:
133
+ """
134
+ Initialize the in-memory cache.
135
+
136
+ Args:
137
+ default_ttl: Default time-to-live in seconds (None = no expiration)
138
+ """
65
139
  self._cache: Dict[str, Dict[str, CacheEntry]] = {}
66
140
  self._default_ttl = default_ttl
67
141
  self._lock = asyncio.Lock()
68
-
142
+ self._stats: Dict[str, int] = {
143
+ "hits": 0,
144
+ "misses": 0,
145
+ "sets": 0,
146
+ "invalidations": 0,
147
+ "expirations": 0,
148
+ }
149
+
150
+ logger.debug(f"Initialized InMemoryCache with default_ttl={default_ttl}s")
151
+
152
+ # ---------------------- Helper methods ------------------------ #
153
+ def _is_expired(self, entry: CacheEntry) -> bool:
154
+ """Check if an entry is expired."""
155
+ return entry.expires_at is not None and entry.expires_at < datetime.now()
156
+
157
+ async def _prune_expired(self) -> int:
158
+ """
159
+ Remove all expired entries.
160
+
161
+ Returns:
162
+ Number of entries removed
163
+ """
164
+ now = datetime.now()
165
+ removed = 0
166
+
167
+ async with self._lock:
168
+ for tool in list(self._cache.keys()):
169
+ tool_cache = self._cache[tool]
170
+ for arg_hash in list(tool_cache.keys()):
171
+ entry = tool_cache[arg_hash]
172
+ if entry.expires_at and entry.expires_at < now:
173
+ del tool_cache[arg_hash]
174
+ removed += 1
175
+ self._stats["expirations"] += 1
176
+
177
+ # Remove empty tool caches
178
+ if not tool_cache:
179
+ del self._cache[tool]
180
+
181
+ return removed
182
+
183
+ # ---------------------- CacheInterface implementation ------------------------ #
69
184
  async def get(self, tool: str, arguments_hash: str) -> Optional[Any]:
185
+ """
186
+ Get a cached result, checking expiration.
187
+
188
+ Args:
189
+ tool: Tool name
190
+ arguments_hash: Hash of the arguments
191
+
192
+ Returns:
193
+ Cached result value or None if not found or expired
194
+ """
70
195
  async with self._lock:
71
- tool_cache = self._cache.get(tool)
72
- if not tool_cache:
73
- return None
74
- entry = tool_cache.get(arguments_hash)
196
+ entry = self._cache.get(tool, {}).get(arguments_hash)
197
+
75
198
  if not entry:
199
+ self._stats["misses"] += 1
76
200
  return None
77
- now = datetime.now()
78
- if entry.expires_at and entry.expires_at < now:
79
- del tool_cache[arguments_hash]
201
+
202
+ if self._is_expired(entry):
203
+ # Prune expired entry
204
+ del self._cache[tool][arguments_hash]
205
+ if not self._cache[tool]:
206
+ del self._cache[tool]
207
+
208
+ self._stats["expirations"] += 1
209
+ self._stats["misses"] += 1
80
210
  return None
211
+
212
+ self._stats["hits"] += 1
81
213
  return entry.result
82
-
214
+
83
215
  async def set(
84
- self,
85
- tool: str,
86
- arguments_hash: str,
87
- result: Any,
88
- ttl: Optional[int] = None
216
+ self,
217
+ tool: str,
218
+ arguments_hash: str,
219
+ result: Any,
220
+ *,
221
+ ttl: Optional[int] = None,
89
222
  ) -> None:
223
+ """
224
+ Set a cache entry with optional custom TTL.
225
+
226
+ Args:
227
+ tool: Tool name
228
+ arguments_hash: Hash of the arguments
229
+ result: Result value to cache
230
+ ttl: Time-to-live in seconds (overrides default)
231
+ """
90
232
  async with self._lock:
91
- if tool not in self._cache:
92
- self._cache[tool] = {}
93
233
  now = datetime.now()
94
- expires_at = None
95
- actual_ttl = ttl if ttl is not None else self._default_ttl
96
- if actual_ttl is not None:
97
- expires_at = now + timedelta(seconds=actual_ttl)
234
+
235
+ # Calculate expiration
236
+ use_ttl = ttl if ttl is not None else self._default_ttl
237
+ expires_at = now + timedelta(seconds=use_ttl) if use_ttl is not None else None
238
+
239
+ # Create entry
98
240
  entry = CacheEntry(
99
241
  tool=tool,
100
242
  arguments_hash=arguments_hash,
101
243
  result=result,
102
244
  created_at=now,
103
- expires_at=expires_at
245
+ expires_at=expires_at,
104
246
  )
105
- self._cache[tool][arguments_hash] = entry
106
-
247
+
248
+ # Store in cache
249
+ self._cache.setdefault(tool, {})[arguments_hash] = entry
250
+ self._stats["sets"] += 1
251
+
252
+ logger.debug(
253
+ f"Cached result for {tool} (TTL: "
254
+ f"{use_ttl if use_ttl is not None else 'none'}s)"
255
+ )
256
+
107
257
  async def invalidate(self, tool: str, arguments_hash: Optional[str] = None) -> None:
258
+ """
259
+ Invalidate cache entries for a tool.
260
+
261
+ Args:
262
+ tool: Tool name
263
+ arguments_hash: Optional arguments hash. If None, all entries for the tool are invalidated.
264
+ """
108
265
  async with self._lock:
109
266
  if tool not in self._cache:
110
267
  return
111
- if arguments_hash is not None:
268
+
269
+ if arguments_hash:
270
+ # Invalidate specific entry
112
271
  self._cache[tool].pop(arguments_hash, None)
272
+ if not self._cache[tool]:
273
+ del self._cache[tool]
274
+ self._stats["invalidations"] += 1
275
+ logger.debug(f"Invalidated specific cache entry for {tool}")
113
276
  else:
277
+ # Invalidate all entries for tool
278
+ count = len(self._cache[tool])
114
279
  del self._cache[tool]
280
+ self._stats["invalidations"] += count
281
+ logger.debug(f"Invalidated all cache entries for {tool} ({count} entries)")
282
+
283
+ async def clear(self) -> None:
284
+ """Clear all cache entries."""
285
+ async with self._lock:
286
+ count = sum(len(entries) for entries in self._cache.values())
287
+ self._cache.clear()
288
+ self._stats["invalidations"] += count
289
+ logger.debug(f"Cleared entire cache ({count} entries)")
290
+
291
+ async def get_stats(self) -> Dict[str, Any]:
292
+ """
293
+ Get cache statistics.
294
+
295
+ Returns:
296
+ Dict with hits, misses, sets, invalidations, and entry counts
297
+ """
298
+ async with self._lock:
299
+ stats = dict(self._stats)
300
+ stats["implemented"] = True
301
+ stats["entry_count"] = sum(len(entries) for entries in self._cache.values())
302
+ stats["tool_count"] = len(self._cache)
303
+
304
+ # Calculate hit rate
305
+ total_gets = stats["hits"] + stats["misses"]
306
+ stats["hit_rate"] = stats["hits"] / total_gets if total_gets > 0 else 0.0
307
+
308
+ return stats
115
309
 
116
-
310
+ # --------------------------------------------------------------------------- #
311
+ # Executor wrapper
312
+ # --------------------------------------------------------------------------- #
117
313
  class CachingToolExecutor:
118
314
  """
119
- Wrapper for a tool executor that caches results.
315
+ Executor wrapper that transparently caches successful tool results.
316
+
317
+ This wrapper intercepts tool calls, checks if results are available in cache,
318
+ and only executes uncached calls. Successful results are automatically stored
319
+ in the cache for future use.
120
320
  """
321
+
121
322
  def __init__(
122
323
  self,
123
324
  executor: Any,
124
325
  cache: CacheInterface,
326
+ *,
125
327
  default_ttl: Optional[int] = None,
126
328
  tool_ttls: Optional[Dict[str, int]] = None,
127
- cacheable_tools: Optional[List[str]] = None
128
- ):
329
+ cacheable_tools: Optional[List[str]] = None,
330
+ ) -> None:
331
+ """
332
+ Initialize the caching executor.
333
+
334
+ Args:
335
+ executor: The underlying executor to wrap
336
+ cache: Cache implementation to use
337
+ default_ttl: Default time-to-live in seconds
338
+ tool_ttls: Dict mapping tool names to custom TTL values
339
+ cacheable_tools: List of tool names that should be cached. If None, all tools are cacheable.
340
+ """
129
341
  self.executor = executor
130
342
  self.cache = cache
131
343
  self.default_ttl = default_ttl
132
344
  self.tool_ttls = tool_ttls or {}
133
- self.cacheable_tools = cacheable_tools
134
-
135
- def _get_arguments_hash(self, arguments: Dict[str, Any]) -> str:
136
- serialized = json.dumps(arguments, sort_keys=True)
137
- return hashlib.md5(serialized.encode()).hexdigest()
138
-
345
+ self.cacheable_tools = set(cacheable_tools) if cacheable_tools else None
346
+
347
+ logger.debug(
348
+ f"Initialized CachingToolExecutor with {len(self.tool_ttls)} custom TTLs, "
349
+ f"default TTL={default_ttl}s"
350
+ )
351
+
352
+ # ---------------------------- helpers ----------------------------- #
353
+ @staticmethod
354
+ def _hash_arguments(arguments: Dict[str, Any]) -> str:
355
+ """
356
+ Generate a stable hash for tool arguments.
357
+
358
+ Args:
359
+ arguments: Tool arguments dict
360
+
361
+ Returns:
362
+ MD5 hash of the sorted JSON representation
363
+ """
364
+ try:
365
+ blob = json.dumps(arguments, sort_keys=True, default=str)
366
+ return hashlib.md5(blob.encode()).hexdigest()
367
+ except Exception as e:
368
+ logger.warning(f"Error hashing arguments: {e}")
369
+ # Fallback to a string representation
370
+ return hashlib.md5(str(arguments).encode()).hexdigest()
371
+
139
372
  def _is_cacheable(self, tool: str) -> bool:
140
- if self.cacheable_tools is None:
141
- return True
142
- return tool in self.cacheable_tools
143
-
144
- def _get_ttl(self, tool: str) -> Optional[int]:
373
+ """
374
+ Check if a tool is cacheable.
375
+
376
+ Args:
377
+ tool: Tool name
378
+
379
+ Returns:
380
+ True if the tool should be cached, False otherwise
381
+ """
382
+ return self.cacheable_tools is None or tool in self.cacheable_tools
383
+
384
+ def _ttl_for(self, tool: str) -> Optional[int]:
385
+ """
386
+ Get the TTL for a specific tool.
387
+
388
+ Args:
389
+ tool: Tool name
390
+
391
+ Returns:
392
+ Tool-specific TTL or default TTL
393
+ """
145
394
  return self.tool_ttls.get(tool, self.default_ttl)
146
-
395
+
396
+ # ------------------------------ API ------------------------------- #
147
397
  async def execute(
148
398
  self,
149
399
  calls: List[ToolCall],
400
+ *,
150
401
  timeout: Optional[float] = None,
151
- use_cache: bool = True
402
+ use_cache: bool = True,
152
403
  ) -> List[ToolResult]:
153
- results: List[ToolResult] = []
154
- uncached_calls: List[Tuple[int, ToolCall]] = []
404
+ """
405
+ Execute tool calls with caching.
155
406
 
407
+ Args:
408
+ calls: List of tool calls to execute
409
+ timeout: Optional timeout for execution
410
+ use_cache: Whether to use cached results
411
+
412
+ Returns:
413
+ List of tool results in the same order as calls
414
+ """
415
+ # Handle empty calls
416
+ if not calls:
417
+ return []
418
+
419
+ # ------------------------------------------------------------------
420
+ # 1. Split calls into cached / uncached buckets
421
+ # ------------------------------------------------------------------
422
+ cached_hits: List[Tuple[int, ToolResult]] = []
423
+ uncached: List[Tuple[int, ToolCall]] = []
424
+
156
425
  if use_cache:
157
- for i, call in enumerate(calls):
426
+ for idx, call in enumerate(calls):
158
427
  if not self._is_cacheable(call.tool):
159
- uncached_calls.append((i, call))
428
+ logger.debug(f"Tool {call.tool} is not cacheable, executing directly")
429
+ uncached.append((idx, call))
160
430
  continue
161
- arguments_hash = self._get_arguments_hash(call.arguments)
162
- cached_result = await self.cache.get(call.tool, arguments_hash)
163
- if cached_result is not None:
164
- now = datetime.now()
165
- results.append(ToolResult(
166
- tool=call.tool,
167
- result=cached_result,
168
- error=None,
169
- start_time=now,
170
- end_time=now,
171
- machine="cache",
172
- pid=0,
173
- cached=True
174
- ))
431
+
432
+ h = self._hash_arguments(call.arguments)
433
+ cached_val = await self.cache.get(call.tool, h)
434
+
435
+ if cached_val is None:
436
+ # Cache miss
437
+ logger.debug(f"Cache miss for {call.tool}")
438
+ uncached.append((idx, call))
175
439
  else:
176
- uncached_calls.append((i, call))
440
+ # Cache hit
441
+ logger.debug(f"Cache hit for {call.tool}")
442
+ now = datetime.now(timezone.utc)
443
+ cached_hits.append(
444
+ (
445
+ idx,
446
+ ToolResult(
447
+ tool=call.tool,
448
+ result=cached_val,
449
+ error=None,
450
+ start_time=now,
451
+ end_time=now,
452
+ machine="cache",
453
+ pid=0,
454
+ cached=True,
455
+ ),
456
+ )
457
+ )
177
458
  else:
178
- uncached_calls = [(i, call) for i, call in enumerate(calls)]
459
+ # Skip cache entirely
460
+ logger.debug("Cache disabled for this execution")
461
+ uncached = list(enumerate(calls))
179
462
 
180
- # Early return if all served from cache
181
- if use_cache and not uncached_calls:
182
- return results
463
+ # Early-exit if every call was served from cache
464
+ if not uncached:
465
+ logger.debug(f"All {len(cached_hits)} calls served from cache")
466
+ return [res for _, res in sorted(cached_hits, key=lambda t: t[0])]
183
467
 
184
- if uncached_calls:
185
- uncached_results = await self.executor.execute(
186
- [call for _, call in uncached_calls],
187
- timeout=timeout
188
- )
468
+ # ------------------------------------------------------------------
469
+ # 2. Execute remaining calls via wrapped executor
470
+ # ------------------------------------------------------------------
471
+ logger.debug(f"Executing {len(uncached)} uncached calls")
472
+ # Pass use_cache=False to avoid potential double-caching if executor also has caching
473
+ executor_kwargs = {"timeout": timeout}
474
+ if hasattr(self.executor, "use_cache"):
475
+ executor_kwargs["use_cache"] = False
189
476
 
190
- if use_cache:
191
- for idx, result in enumerate(uncached_results):
192
- _, call = uncached_calls[idx]
193
- if result.error is None and self._is_cacheable(call.tool):
194
- arguments_hash = self._get_arguments_hash(call.arguments)
195
- ttl = self._get_ttl(call.tool)
196
- await self.cache.set(
197
- call.tool,
198
- arguments_hash,
199
- result.result,
200
- ttl=ttl
201
- )
202
- result.cached = False
203
-
204
- final_results: List[ToolResult] = [None] * len(calls)
205
- uncached_indices = {idx for idx, _ in uncached_calls}
206
- uncached_iter = iter(uncached_results)
207
- cache_iter = iter(results)
208
- for i in range(len(calls)):
209
- if i in uncached_indices:
210
- final_results[i] = next(uncached_iter)
477
+ uncached_results = await self.executor.execute(
478
+ [call for _, call in uncached], **executor_kwargs
479
+ )
480
+
481
+ # ------------------------------------------------------------------
482
+ # 3. Insert fresh results into cache
483
+ # ------------------------------------------------------------------
484
+ if use_cache:
485
+ cache_tasks = []
486
+ for (idx, call), result in zip(uncached, uncached_results):
487
+ if result.error is None and self._is_cacheable(call.tool):
488
+ ttl = self._ttl_for(call.tool)
489
+ logger.debug(f"Caching result for {call.tool} with TTL={ttl}s")
490
+
491
+ # Create task but don't await yet (for concurrent caching)
492
+ task = self.cache.set(
493
+ call.tool,
494
+ self._hash_arguments(call.arguments),
495
+ result.result,
496
+ ttl=ttl,
497
+ )
498
+ cache_tasks.append(task)
499
+
500
+ # Flag as non-cached so callers can tell
501
+ if hasattr(result, "cached"):
502
+ result.cached = False
211
503
  else:
212
- final_results[i] = next(cache_iter)
213
- return final_results
504
+ # For older ToolResult objects that might not have cached attribute
505
+ setattr(result, "cached", False)
506
+
507
+ # Wait for all cache operations to complete
508
+ if cache_tasks:
509
+ await asyncio.gather(*cache_tasks)
510
+
511
+ # ------------------------------------------------------------------
512
+ # 4. Merge cached-hits + fresh results in original order
513
+ # ------------------------------------------------------------------
514
+ merged: List[Optional[ToolResult]] = [None] * len(calls)
515
+ for idx, hit in cached_hits:
516
+ merged[idx] = hit
517
+ for (idx, _), fresh in zip(uncached, uncached_results):
518
+ merged[idx] = fresh
214
519
 
520
+ # If calls was empty, merged remains []
521
+ return [result for result in merged if result is not None]
215
522
 
523
+
524
+ # --------------------------------------------------------------------------- #
525
+ # Convenience decorators
526
+ # --------------------------------------------------------------------------- #
216
527
  def cacheable(ttl: Optional[int] = None):
528
+ """
529
+ Decorator to mark a tool class as cacheable.
530
+
531
+ Example:
532
+ @cacheable(ttl=600) # Cache for 10 minutes
533
+ class WeatherTool:
534
+ async def execute(self, location: str) -> Dict[str, Any]:
535
+ # Implementation
536
+
537
+ Args:
538
+ ttl: Optional custom time-to-live in seconds
539
+
540
+ Returns:
541
+ Decorated class with caching metadata
542
+ """
217
543
  def decorator(cls):
218
- cls._cacheable = True
544
+ cls._cacheable = True # Runtime flag picked up by higher-level code
219
545
  if ttl is not None:
220
546
  cls._cache_ttl = ttl
221
547
  return cls
548
+
222
549
  return decorator
223
550
 
224
551
 
225
552
  def invalidate_cache(tool: str, arguments: Optional[Dict[str, Any]] = None):
553
+ """
554
+ Create an async function that invalidates specific cache entries.
555
+
556
+ Example:
557
+ invalidator = invalidate_cache("weather", {"location": "London"})
558
+ await invalidator(cache) # Call with a cache instance
559
+
560
+ Args:
561
+ tool: Tool name
562
+ arguments: Optional arguments dict. If None, all entries for the tool are invalidated.
563
+
564
+ Returns:
565
+ Async function that takes a cache instance and invalidates entries
566
+ """
226
567
  async def _invalidate(cache: CacheInterface):
227
568
  if arguments is not None:
228
- arguments_hash = hashlib.md5(
229
- json.dumps(arguments, sort_keys=True).encode()
230
- ).hexdigest()
231
- await cache.invalidate(tool, arguments_hash)
569
+ h = hashlib.md5(json.dumps(arguments, sort_keys=True, default=str).encode()).hexdigest()
570
+ await cache.invalidate(tool, h)
571
+ logger.debug(f"Invalidated cache entry for {tool} with specific arguments")
232
572
  else:
233
573
  await cache.invalidate(tool)
234
- return _invalidate
574
+ logger.debug(f"Invalidated all cache entries for {tool}")
575
+
576
+ return _invalidate