agnt5 0.3.2a1__cp310-abi3-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of agnt5 might be problematic. Click here for more details.
- agnt5/__init__.py +119 -0
- agnt5/_compat.py +16 -0
- agnt5/_core.abi3.so +0 -0
- agnt5/_retry_utils.py +196 -0
- agnt5/_schema_utils.py +312 -0
- agnt5/_sentry.py +515 -0
- agnt5/_telemetry.py +279 -0
- agnt5/agent/__init__.py +48 -0
- agnt5/agent/context.py +581 -0
- agnt5/agent/core.py +1782 -0
- agnt5/agent/decorator.py +112 -0
- agnt5/agent/handoff.py +105 -0
- agnt5/agent/registry.py +68 -0
- agnt5/agent/result.py +39 -0
- agnt5/checkpoint.py +246 -0
- agnt5/client.py +1556 -0
- agnt5/context.py +288 -0
- agnt5/emit.py +197 -0
- agnt5/entity.py +1230 -0
- agnt5/events.py +567 -0
- agnt5/exceptions.py +110 -0
- agnt5/function.py +330 -0
- agnt5/journal.py +212 -0
- agnt5/lm.py +1266 -0
- agnt5/memoization.py +379 -0
- agnt5/memory.py +521 -0
- agnt5/tool.py +721 -0
- agnt5/tracing.py +300 -0
- agnt5/types.py +111 -0
- agnt5/version.py +19 -0
- agnt5/worker.py +2094 -0
- agnt5/workflow.py +1632 -0
- agnt5-0.3.2a1.dist-info/METADATA +26 -0
- agnt5-0.3.2a1.dist-info/RECORD +35 -0
- agnt5-0.3.2a1.dist-info/WHEEL +4 -0
agnt5/memoization.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
1
|
+
"""Memoization support for AGNT5 agents.

This module provides auto-memoization of LLM and tool calls within agent loops,
enabling deterministic replay on crash recovery while preserving first-run
non-determinism.

Uses a hybrid keying approach:
- step_key: sequence-based key that orders replay (lm.0, lm.1, tool.search.0)
- content_hash: first 16 hex characters of a SHA-256 over the call inputs,
  checked on replay; a mismatch is logged as a warning, not treated as a miss

Integrates with the platform's step-level checkpoint system via CheckpointClient
for durable memoization that survives process crashes.
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import hashlib
|
|
18
|
+
import json
|
|
19
|
+
import logging
|
|
20
|
+
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from .context import Context
|
|
24
|
+
|
|
25
|
+
# Memoization is best-effort: checkpoint client and RPC failures in this module
# are reported through this logger instead of being raised to the caller.
logger = logging.getLogger(__name__)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class MemoizationManager:
    """
    Read-through cache for LLM and tool calls, backed by the platform's step
    checkpoint system.

    Uses a hybrid keying approach:
    - step_key: sequence-based key that selects which stored result is
      replayed (lm.0, lm.1, tool.search.0)
    - content_hash: first 16 hex characters of a SHA-256 over the call
      inputs, used on replay to detect inputs that changed between runs
      (e.g., the developer modified code)

    A content_hash mismatch is only logged as a warning; the stored result is
    still returned. Step keys come from per-instance counters, so a replay
    lines up with the first run only if calls are keyed in the same order.

    Storage goes through CheckpointClient (Checkpoint/GetMemoizedStep RPCs).
    Client and RPC failures are logged and degrade to a cache miss (reads) or
    a skipped write (writes) instead of propagating.

    Example:
        ```python
        memo = MemoizationManager(ctx)

        # For LLM calls
        step_key, content_hash = memo.lm_call_key(model, messages, config)
        cached = await memo.get_cached_lm_result(step_key, content_hash)
        if cached:
            return cached  # Skip execution
        # ... execute LLM call ...
        await memo.cache_lm_result(step_key, content_hash, result)
        ```
    """

    def __init__(self, ctx: "Context") -> None:
        """
        Initialize memoization manager.

        Args:
            ctx: Execution context; its ``run_id`` scopes every checkpoint
                read and write.
        """
        self._ctx = ctx
        self._lm_sequence = 0
        self._tool_sequences: Dict[str, int] = {}  # Per-tool sequence counters
        self._checkpoint_client = None  # Created lazily by _ensure_client()
        self._connected = False

    async def _ensure_client(self) -> bool:
        """
        Lazily create and connect the checkpoint client.

        Creation and connection are retried on every call until they succeed,
        which means a transient failure does not disable memoization for the
        rest of the run.

        Returns:
            True if the client is connected, False if it is unavailable.
        """
        if self._connected and self._checkpoint_client is not None:
            return True

        if self._checkpoint_client is None:
            try:
                from .checkpoint import CheckpointClient
                self._checkpoint_client = CheckpointClient()
            except ImportError as e:
                logger.debug("CheckpointClient not available: %s", e)
                return False
            except Exception as e:
                logger.debug("Failed to create CheckpointClient: %s", e)
                return False

        if not self._connected:
            try:
                await self._checkpoint_client.connect()
                self._connected = True
            except Exception as e:
                logger.debug("Failed to connect CheckpointClient: %s", e)
                return False

        return True

    def _content_hash(self, data: dict) -> str:
        """
        Hash ``data`` for replay validation.

        The dict is serialized as canonical JSON (sorted keys); values JSON
        cannot encode are passed through ``str()``. Only the first 16 hex
        characters (64 bits) of the SHA-256 digest are kept for compact
        storage.

        Args:
            data: Dictionary to hash.

        Returns:
            16-character lowercase hex string.
        """
        # NOTE(review): str() of an object without a custom __str__/__repr__
        # embeds its memory address, which differs between runs and would
        # produce spurious mismatch warnings for such inputs.
        content = json.dumps(data, sort_keys=True, default=str)
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    @staticmethod
    def _normalize_message(message: Any) -> Dict[str, Any]:
        """Reduce one chat message to a JSON-friendly dict for hashing."""
        if hasattr(message, "role") and hasattr(message, "content"):
            # Message object: roles exposing ``value`` (e.g. enums) contribute
            # that value, anything else its str().
            role = message.role
            return {
                "role": role.value if hasattr(role, "value") else str(role),
                "content": message.content,
            }
        if isinstance(message, dict):
            return {
                "role": message.get("role", "user"),
                "content": message.get("content", ""),
            }
        # Fallback for unknown shapes: hashed by string form only.
        return {"content": str(message)}

    def lm_call_key(
        self,
        model: str,
        messages: List[Any],
        config: Optional[Dict[str, Any]],
    ) -> Tuple[str, str]:
        """
        Generate (step_key, content_hash) for an LLM call.

        Only ``model``, the normalized messages, and the ``temperature`` /
        ``max_tokens`` config entries contribute to the hash; other config
        keys do not.

        Args:
            model: Model name/identifier.
            messages: Message objects (with ``role``/``content`` attributes),
                dicts, or other values (hashed via ``str()``). ``None`` is
                treated as an empty list.
            config: Generation config (temperature, max_tokens, etc.).
                ``None`` is treated as an empty config.

        Returns:
            Tuple of (step_key, content_hash):
            - step_key: Sequence-based key for journal lookup (e.g., "lm.0")
            - content_hash: Hash of inputs for replay validation
        """
        config = config or {}
        content_hash = self._content_hash({
            "model": model,
            "messages": [self._normalize_message(m) for m in messages or ()],
            "temperature": config.get("temperature"),
            "max_tokens": config.get("max_tokens"),
        })

        # Advance the counter only once hashing succeeded, so a call that
        # raises above does not shift every later step key.
        step_key = f"lm.{self._lm_sequence}"
        self._lm_sequence += 1
        return step_key, content_hash

    def tool_call_key(self, tool_name: str, kwargs: dict) -> Tuple[str, str]:
        """
        Generate (step_key, content_hash) for a tool call.

        Args:
            tool_name: Name of the tool being called.
            kwargs: Tool arguments.

        Returns:
            Tuple of (step_key, content_hash):
            - step_key: Sequence-based key for journal lookup
              (e.g., "tool.search.0"); each tool name has its own counter
            - content_hash: Hash of inputs for replay validation
        """
        seq = self._tool_sequences.get(tool_name, 0)
        content_hash = self._content_hash({
            "tool": tool_name,
            "args": kwargs,
        })
        # Same rule as lm_call_key: only advance after a successful hash.
        self._tool_sequences[tool_name] = seq + 1
        return f"tool.{tool_name}.{seq}", content_hash

    async def _read_step(
        self,
        run_id: Any,
        step_key: str,
        content_hash: str,
    ) -> Tuple[bool, Any]:
        """
        Fetch a memoized step payload and check its stored input hash.

        Expects a connected client. RPC errors, invalid JSON and non-object
        payloads propagate as exceptions to the caller.

        Returns:
            (found, output_data): found is False when nothing is stored for
            ``step_key``; otherwise output_data is the stored "output_data"
            value (possibly None).
        """
        cached_bytes = await self._checkpoint_client.get_memoized_step(
            run_id, step_key
        )
        if not cached_bytes:
            return False, None

        cached_data = json.loads(cached_bytes)
        if not isinstance(cached_data, dict):
            raise TypeError(
                f"expected a JSON object, got {type(cached_data).__name__}"
            )

        # A mismatch is reported but not fatal: replay follows the recorded
        # history even if the inputs changed.
        stored_hash = cached_data.get("input_hash")
        if stored_hash and stored_hash != content_hash:
            logger.warning(
                "Content mismatch on replay for %s: "
                "stored=%s, current=%s. "
                "Inputs may have changed between runs.",
                step_key,
                stored_hash,
                content_hash,
            )

        return True, cached_data.get("output_data")

    async def _write_step(
        self,
        run_id: Any,
        step_key: str,
        content_hash: str,
        output_data: Any,
        step_name: str,
        step_type: str,
        json_default: Any = None,
    ) -> None:
        """
        Serialize ``{"input_hash", "output_data"}`` and record it as a
        completed step via the client's ``step_completed`` call.

        Expects a connected client. Serialization and RPC errors propagate.

        Args:
            json_default: Passed to ``json.dumps(default=...)``; None means
                non-JSON values raise ``TypeError``.
        """
        cache_payload = json.dumps(
            {
                "input_hash": content_hash,
                "output_data": output_data,
            },
            default=json_default,
        ).encode()

        await self._checkpoint_client.step_completed(
            run_id=run_id,
            step_key=step_key,
            step_name=step_name,
            step_type=step_type,
            output_payload=cache_payload,
        )

    async def get_cached_lm_result(
        self,
        step_key: str,
        content_hash: str,
    ) -> Optional[Any]:
        """
        Check the platform checkpoint for a cached LLM result.

        A stored result is returned even when its input hash differs from
        ``content_hash``; the mismatch is only logged as a warning.

        Args:
            step_key: Sequence-based key (e.g., "lm.0").
            content_hash: Hash of current inputs for validation.

        Returns:
            The cached GenerateResponse, or None on a miss, an empty stored
            output, an unavailable client, or any lookup/decoding failure.
        """
        if not await self._ensure_client():
            return None

        run_id = self._ctx.run_id

        try:
            found, output_data = await self._read_step(
                run_id, step_key, content_hash
            )
            if found and output_data:
                logger.debug("Cache hit for LLM call %s", step_key)
                # Convert output_data back to GenerateResponse
                from .lm import GenerateResponse
                return GenerateResponse.from_dict(output_data)
        except Exception as e:
            logger.debug(
                "Failed to lookup cached LLM result for %s: %s", step_key, e
            )

        return None

    async def cache_lm_result(
        self,
        step_key: str,
        content_hash: str,
        result: Any,
    ) -> None:
        """
        Write an LLM result to the platform checkpoint for future replay.

        Best-effort: failures are logged as warnings, never raised.

        Args:
            step_key: Sequence-based key (e.g., "lm.0").
            content_hash: Hash of inputs for validation on replay.
            result: GenerateResponse (or anything with ``to_dict()``, or an
                already JSON-serializable value) to cache.
        """
        if not await self._ensure_client():
            return

        run_id = self._ctx.run_id

        try:
            output_data = result.to_dict() if hasattr(result, "to_dict") else result
            # No json default here (unlike tool results): a non-JSON value
            # fails the write instead of being stringified into data that
            # GenerateResponse.from_dict may not be able to rebuild.
            await self._write_step(
                run_id,
                step_key,
                content_hash,
                output_data,
                step_name="lm_call",
                step_type="llm",
            )
            logger.debug("Cached LLM result for %s", step_key)
        except Exception as e:
            logger.warning("Failed to cache LLM result for %s: %s", step_key, e)

    async def get_cached_tool_result(
        self,
        step_key: str,
        content_hash: str,
    ) -> Tuple[bool, Optional[Any]]:
        """
        Check the platform checkpoint for a cached tool result.

        A stored result is returned even when its input hash differs from
        ``content_hash``; the mismatch is only logged as a warning.

        Args:
            step_key: Sequence-based key (e.g., "tool.search.0").
            content_hash: Hash of current inputs for validation.

        Returns:
            Tuple of (found, result):
            - found: True if a cache entry exists
            - result: Cached result if found (may legitimately be None),
              None otherwise
        """
        if not await self._ensure_client():
            return False, None

        run_id = self._ctx.run_id

        try:
            found, output_data = await self._read_step(
                run_id, step_key, content_hash
            )
        except Exception as e:
            logger.debug(
                "Failed to lookup cached tool result for %s: %s", step_key, e
            )
            return False, None

        if found:
            logger.debug("Cache hit for tool call %s", step_key)
        return found, output_data

    async def cache_tool_result(
        self,
        step_key: str,
        content_hash: str,
        result: Any,
    ) -> None:
        """
        Write a tool result to the platform checkpoint for future replay.

        Best-effort: failures are logged as warnings, never raised. Values
        JSON cannot encode are stored as their ``str()``.

        Args:
            step_key: Sequence-based key (e.g., "tool.search.0").
            content_hash: Hash of inputs for validation on replay.
            result: Tool result to cache.
        """
        if not await self._ensure_client():
            return

        run_id = self._ctx.run_id

        try:
            await self._write_step(
                run_id,
                step_key,
                content_hash,
                result,
                step_name="tool_call",
                step_type="tool",
                json_default=str,
            )
            logger.debug("Cached tool result for %s", step_key)
        except Exception as e:
            logger.warning("Failed to cache tool result for %s: %s", step_key, e)

    def reset(self) -> None:
        """
        Reset all sequence counters.

        Call this when starting a new execution to ensure deterministic
        step_key generation. The checkpoint client connection is kept.
        """
        self._lm_sequence = 0
        self._tool_sequences.clear()
|