agnt5-0.3.2a1-cp310-abi3-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of agnt5 might be problematic.

agnt5/memoization.py ADDED
@@ -0,0 +1,379 @@
+"""Memoization support for AGNT5 agents.
+
+This module provides auto-memoization of LLM and tool calls within agent loops,
+enabling deterministic replay on crash recovery while preserving first-run
+non-determinism.
+
+Uses a hybrid hashing approach:
+- step_key: sequence-based, for ordering (lm.0, lm.1, tool.search.0)
+- content_hash: SHA256 of inputs, for validation on replay
+
+Integrates with the platform's step-level checkpoint system via CheckpointClient
+for durable memoization that survives process crashes.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+import logging
+from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .context import Context
+
+logger = logging.getLogger(__name__)
+
+
+class MemoizationManager:
+    """
+    Handles read-through caching of LLM and tool calls via the platform checkpoint system.
+
+    Uses a hybrid approach:
+    - step_key: sequence-based, for ordering (lm.0, lm.1, tool.search.0)
+    - content_hash: SHA256 of inputs, for validation on replay
+
+    The step_key ensures correct replay order, while the content_hash detects
+    whether inputs changed between runs (e.g., a developer modified the code).
+
+    Integrates with CheckpointClient for durable storage via the platform's
+    step checkpoint system (Checkpoint/GetMemoizedStep RPCs).
+
+    Example:
+        ```python
+        memo = MemoizationManager(ctx)
+
+        # For LLM calls
+        step_key, content_hash = memo.lm_call_key(model, messages, config)
+        cached = await memo.get_cached_lm_result(step_key, content_hash)
+        if cached:
+            return cached  # Skip execution
+        # ... execute LLM call ...
+        await memo.cache_lm_result(step_key, content_hash, result)
+        ```
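+
+    The tool-call path mirrors this; a sketch, where the "search" tool and its
+    argument are illustrative rather than part of this API:
+        ```python
+        step_key, content_hash = memo.tool_call_key("search", {"query": query})
+        found, cached = await memo.get_cached_tool_result(step_key, content_hash)
+        if found:
+            return cached  # Skip execution
+        # ... execute the tool ...
+        await memo.cache_tool_result(step_key, content_hash, result)
+        ```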
+    """
+
+    def __init__(self, ctx: "Context") -> None:
+        """
+        Initialize memoization manager.
+
+        Args:
+            ctx: Execution context for run_id and platform access
+        """
+        self._ctx = ctx
+        self._lm_sequence = 0
+        self._tool_sequences: Dict[str, int] = {}  # Per-tool sequence counters
+        self._checkpoint_client = None
+        self._connected = False
+
+    async def _ensure_client(self) -> bool:
+        """
+        Lazily initialize and connect the checkpoint client.
+
+        Returns:
+            True if the client is ready, False if unavailable
+        """
+        if self._connected and self._checkpoint_client is not None:
+            return True
+
+        if self._checkpoint_client is None:
+            try:
+                from .checkpoint import CheckpointClient
+                self._checkpoint_client = CheckpointClient()
+            except ImportError as e:
+                logger.debug(f"CheckpointClient not available: {e}")
+                return False
+            except Exception as e:
+                logger.debug(f"Failed to create CheckpointClient: {e}")
+                return False
+
+        if not self._connected:
+            try:
+                await self._checkpoint_client.connect()
+                self._connected = True
+            except Exception as e:
+                logger.debug(f"Failed to connect CheckpointClient: {e}")
+                return False
+
+        return True
+
+    def _content_hash(self, data: dict) -> str:
+        """
+        Generate SHA256 hash of content for validation.
+
+        Uses the first 16 characters of the hex digest for compact storage.
+
+        Args:
+            data: Dictionary to hash (will be JSON serialized)
+
+        Returns:
+            16-character hex hash string
+        """
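+        # json.dumps with sort_keys=True canonicalizes key order, and default=str
+        # covers values that are not JSON-serializable, so equal inputs always
+        # produce the same digest.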
+        content = json.dumps(data, sort_keys=True, default=str)
+        return hashlib.sha256(content.encode()).hexdigest()[:16]
+
+    def lm_call_key(
+        self,
+        model: str,
+        messages: List[Any],
+        config: dict,
+    ) -> Tuple[str, str]:
+        """
+        Generate (step_key, content_hash) for an LLM call.
+
+        Args:
+            model: Model name/identifier
+            messages: List of message objects or dicts
+            config: Generation config (temperature, max_tokens, etc.)
+
+        Returns:
+            Tuple of (step_key, content_hash):
+            - step_key: Sequence-based key for journal lookup (e.g., "lm.0")
+            - content_hash: Hash of inputs for replay validation
+        """
+        step_key = f"lm.{self._lm_sequence}"
+        self._lm_sequence += 1
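+        # Successive LLM calls within a run get "lm.0", "lm.1", ..., so replay
+        # consumes cached results in the order they were originally produced.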
+
+        # Build a hashable representation of the messages
+        messages_data = []
+        for m in messages:
+            if hasattr(m, 'role') and hasattr(m, 'content'):
+                # Message object
+                role = m.role.value if hasattr(m.role, 'value') else str(m.role)
+                messages_data.append({"role": role, "content": m.content})
+            elif isinstance(m, dict):
+                # Dict format
+                messages_data.append({
+                    "role": m.get('role', 'user'),
+                    "content": m.get('content', '')
+                })
+            else:
+                # Fallback
+                messages_data.append({"content": str(m)})
+
+        content_hash = self._content_hash({
+            "model": model,
+            "messages": messages_data,
+            "temperature": config.get("temperature"),
+            "max_tokens": config.get("max_tokens"),
+        })
+
+        return step_key, content_hash
+
+    def tool_call_key(self, tool_name: str, kwargs: dict) -> Tuple[str, str]:
+        """
+        Generate (step_key, content_hash) for a tool call.
+
+        Args:
+            tool_name: Name of the tool being called
+            kwargs: Tool arguments
+
+        Returns:
+            Tuple of (step_key, content_hash):
+            - step_key: Sequence-based key for journal lookup (e.g., "tool.search.0")
+            - content_hash: Hash of inputs for replay validation
+        """
+        seq = self._tool_sequences.get(tool_name, 0)
+        step_key = f"tool.{tool_name}.{seq}"
+        self._tool_sequences[tool_name] = seq + 1
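+        # Each tool name keeps its own counter: repeated "search" calls become
+        # "tool.search.0", "tool.search.1", ..., independent of other tools.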
+
+        content_hash = self._content_hash({
+            "tool": tool_name,
+            "args": kwargs,
+        })
+
+        return step_key, content_hash
+
+    async def get_cached_lm_result(
+        self,
+        step_key: str,
+        content_hash: str,
+    ) -> Optional[Any]:
+        """
+        Check the platform checkpoint for a cached LLM result.
+
+        Args:
+            step_key: Sequence-based key (e.g., "lm.0")
+            content_hash: Hash of current inputs for validation
+
+        Returns:
+            Cached GenerateResponse if found and valid, None otherwise
+        """
+        if not await self._ensure_client():
+            return None
+
+        run_id = self._ctx.run_id
+
+        try:
+            # Use the platform's GetMemoizedStep RPC
+            cached_bytes = await self._checkpoint_client.get_memoized_step(
+                run_id, step_key
+            )
+
+            if cached_bytes:
+                # Deserialize the cached result
+                cached_data = json.loads(cached_bytes)
+
+                # Validate that the content hash matches (warn on mismatch)
+                stored_hash = cached_data.get("input_hash")
+                if stored_hash and stored_hash != content_hash:
+                    logger.warning(
+                        f"Content mismatch on replay for {step_key}: "
+                        f"stored={stored_hash}, current={content_hash}. "
+                        "Inputs may have changed between runs."
+                    )
+
+                # Return the cached output
+                output_data = cached_data.get("output_data")
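+                # A missing or falsy output is treated as a cache miss: control
+                # falls through to `return None` and the caller re-executes.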
+                if output_data:
+                    logger.debug(f"Cache hit for LLM call {step_key}")
+                    # Convert output_data back to GenerateResponse
+                    from .lm import GenerateResponse
+                    return GenerateResponse.from_dict(output_data)
+
+        except Exception as e:
+            logger.debug(f"Failed to look up cached LLM result for {step_key}: {e}")
+
+        return None
+
+    async def cache_lm_result(
+        self,
+        step_key: str,
+        content_hash: str,
+        result: Any,
+    ) -> None:
+        """
+        Write an LLM result to the platform checkpoint for future replay.
+
+        Args:
+            step_key: Sequence-based key (e.g., "lm.0")
+            content_hash: Hash of inputs for validation on replay
+            result: GenerateResponse to cache
+        """
+        if not await self._ensure_client():
+            return
+
+        run_id = self._ctx.run_id
+
+        try:
+            # Convert the result to a dict for storage
+            output_data = result.to_dict() if hasattr(result, 'to_dict') else result
+
+            # Build the cache payload with the hash for validation
+            cache_payload = json.dumps({
+                "input_hash": content_hash,
+                "output_data": output_data,
+            }).encode()
+
+            # Use the platform's Checkpoint RPC via step_completed
+            await self._checkpoint_client.step_completed(
+                run_id=run_id,
+                step_key=step_key,
+                step_name="lm_call",
+                step_type="llm",
+                output_payload=cache_payload,
+            )
+            logger.debug(f"Cached LLM result for {step_key}")
+
+        except Exception as e:
+            logger.warning(f"Failed to cache LLM result for {step_key}: {e}")
+
+    async def get_cached_tool_result(
+        self,
+        step_key: str,
+        content_hash: str,
+    ) -> Tuple[bool, Optional[Any]]:
+        """
+        Check the platform checkpoint for a cached tool result.
+
+        Args:
+            step_key: Sequence-based key (e.g., "tool.search.0")
+            content_hash: Hash of current inputs for validation
+
+        Returns:
+            Tuple of (found, result):
+            - found: True if a cache entry exists
+            - result: Cached result if found, None otherwise
+        """
+        if not await self._ensure_client():
+            return False, None
+
+        run_id = self._ctx.run_id
+
+        try:
+            # Use the platform's GetMemoizedStep RPC
+            cached_bytes = await self._checkpoint_client.get_memoized_step(
+                run_id, step_key
+            )
+
+            if cached_bytes:
+                # Deserialize the cached result
+                cached_data = json.loads(cached_bytes)
+
+                # Validate that the content hash matches (warn on mismatch)
+                stored_hash = cached_data.get("input_hash")
+                if stored_hash and stored_hash != content_hash:
+                    logger.warning(
+                        f"Content mismatch on replay for {step_key}: "
+                        f"stored={stored_hash}, current={content_hash}. "
+                        "Inputs may have changed between runs."
+                    )
+
+                # Return the cached output
+                output_data = cached_data.get("output_data")
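+                # Unlike the LLM path, found=True is returned even when the stored
+                # output is falsy, so legitimate None/empty tool results replay.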
+                logger.debug(f"Cache hit for tool call {step_key}")
+                return True, output_data
+
+        except Exception as e:
+            logger.debug(f"Failed to look up cached tool result for {step_key}: {e}")
+
+        return False, None
+
+    async def cache_tool_result(
+        self,
+        step_key: str,
+        content_hash: str,
+        result: Any,
+    ) -> None:
+        """
+        Write a tool result to the platform checkpoint for future replay.
+
+        Args:
+            step_key: Sequence-based key (e.g., "tool.search.0")
+            content_hash: Hash of inputs for validation on replay
+            result: Tool result to cache
+        """
+        if not await self._ensure_client():
+            return
+
+        run_id = self._ctx.run_id
+
+        try:
+            # Build the cache payload with the hash for validation
+            cache_payload = json.dumps({
+                "input_hash": content_hash,
+                "output_data": result,
+            }, default=str).encode()
+
+            # Use the platform's Checkpoint RPC via step_completed
+            await self._checkpoint_client.step_completed(
+                run_id=run_id,
+                step_key=step_key,
+                step_name="tool_call",
+                step_type="tool",
+                output_payload=cache_payload,
+            )
+            logger.debug(f"Cached tool result for {step_key}")
+
+        except Exception as e:
+            logger.warning(f"Failed to cache tool result for {step_key}: {e}")
+
+    def reset(self) -> None:
+        """
+        Reset sequence counters.
+
+        Call this when starting a new execution to ensure deterministic
+        step_key generation.
+        """
+        self._lm_sequence = 0
+        self._tool_sequences.clear()
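
How the pieces fit together in an agent loop, as a minimal sketch: `call_llm`, `run_tool`, and `ctx` below are illustrative stand-ins rather than part of this package's public API; only the `MemoizationManager` methods are defined in this file.

```python
from agnt5.memoization import MemoizationManager

async def run_agent(ctx, model, messages, config):
    memo = MemoizationManager(ctx)
    memo.reset()  # fresh execution: sequence counters start at zero

    # LLM call with read-through memoization
    step_key, content_hash = memo.lm_call_key(model, messages, config)
    response = await memo.get_cached_lm_result(step_key, content_hash)
    if response is None:
        response = await call_llm(model, messages, config)  # hypothetical executor
        await memo.cache_lm_result(step_key, content_hash, response)

    # Tool call with read-through memoization
    step_key, content_hash = memo.tool_call_key("search", {"query": "agnt5"})
    found, result = await memo.get_cached_tool_result(step_key, content_hash)
    if not found:
        result = await run_tool("search", query="agnt5")  # hypothetical executor
        await memo.cache_tool_result(step_key, content_hash, result)

    return response, result
```

On replay after a crash, the same sequence of `lm_call_key` / `tool_call_key` calls regenerates the same step keys, so cached outputs are consumed in their original order while the content hashes flag any inputs that changed since the first run.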