chuk-ai-session-manager 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chuk_ai_session_manager/__init__.py +84 -40
- chuk_ai_session_manager/api/__init__.py +1 -1
- chuk_ai_session_manager/api/simple_api.py +53 -59
- chuk_ai_session_manager/exceptions.py +31 -17
- chuk_ai_session_manager/guards/__init__.py +118 -0
- chuk_ai_session_manager/guards/bindings.py +217 -0
- chuk_ai_session_manager/guards/cache.py +163 -0
- chuk_ai_session_manager/guards/manager.py +819 -0
- chuk_ai_session_manager/guards/models.py +498 -0
- chuk_ai_session_manager/guards/ungrounded.py +159 -0
- chuk_ai_session_manager/infinite_conversation.py +86 -79
- chuk_ai_session_manager/memory/__init__.py +247 -0
- chuk_ai_session_manager/memory/artifacts_bridge.py +469 -0
- chuk_ai_session_manager/memory/context_packer.py +347 -0
- chuk_ai_session_manager/memory/fault_handler.py +507 -0
- chuk_ai_session_manager/memory/manifest.py +307 -0
- chuk_ai_session_manager/memory/models.py +1084 -0
- chuk_ai_session_manager/memory/mutation_log.py +186 -0
- chuk_ai_session_manager/memory/pack_cache.py +206 -0
- chuk_ai_session_manager/memory/page_table.py +275 -0
- chuk_ai_session_manager/memory/prefetcher.py +192 -0
- chuk_ai_session_manager/memory/tlb.py +247 -0
- chuk_ai_session_manager/memory/vm_prompts.py +238 -0
- chuk_ai_session_manager/memory/working_set.py +574 -0
- chuk_ai_session_manager/models/__init__.py +21 -9
- chuk_ai_session_manager/models/event_source.py +3 -1
- chuk_ai_session_manager/models/event_type.py +10 -1
- chuk_ai_session_manager/models/session.py +103 -68
- chuk_ai_session_manager/models/session_event.py +69 -68
- chuk_ai_session_manager/models/session_metadata.py +9 -10
- chuk_ai_session_manager/models/session_run.py +21 -22
- chuk_ai_session_manager/models/token_usage.py +76 -76
- chuk_ai_session_manager/procedural_memory/__init__.py +70 -0
- chuk_ai_session_manager/procedural_memory/formatter.py +407 -0
- chuk_ai_session_manager/procedural_memory/manager.py +523 -0
- chuk_ai_session_manager/procedural_memory/models.py +371 -0
- chuk_ai_session_manager/sample_tools.py +79 -46
- chuk_ai_session_manager/session_aware_tool_processor.py +27 -16
- chuk_ai_session_manager/session_manager.py +259 -232
- chuk_ai_session_manager/session_prompt_builder.py +163 -111
- chuk_ai_session_manager/session_storage.py +45 -52
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/METADATA +80 -4
- chuk_ai_session_manager-0.8.1.dist-info/RECORD +45 -0
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/WHEEL +1 -1
- chuk_ai_session_manager-0.7.1.dist-info/RECORD +0 -22
- {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
# chuk_ai_session_manager/procedural_memory/manager.py
|
|
2
|
+
"""
|
|
3
|
+
Tool Memory Manager - orchestrates procedural memory operations.
|
|
4
|
+
|
|
5
|
+
Handles:
|
|
6
|
+
- Recording tool invocations
|
|
7
|
+
- Detecting fix relationships (failure -> success)
|
|
8
|
+
- Updating aggregated patterns
|
|
9
|
+
- Searching and retrieving history
|
|
10
|
+
- Session persistence integration
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import hashlib
|
|
16
|
+
import json
|
|
17
|
+
import logging
|
|
18
|
+
from datetime import datetime, timezone
|
|
19
|
+
from typing import Any, Callable, Optional, TYPE_CHECKING
|
|
20
|
+
|
|
21
|
+
from pydantic import BaseModel, Field
|
|
22
|
+
|
|
23
|
+
from chuk_ai_session_manager.procedural_memory.models import (
|
|
24
|
+
ProceduralMemory,
|
|
25
|
+
ToolFixRelation,
|
|
26
|
+
ToolLogEntry,
|
|
27
|
+
ToolOutcome,
|
|
28
|
+
ToolPattern,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
if TYPE_CHECKING:
|
|
32
|
+
from chuk_ai_session_manager.models.session import Session
|
|
33
|
+
|
|
34
|
+
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class ToolMemoryManager(BaseModel):
    """
    Manages procedural memory - tool traces, patterns, and fixes.

    This is the main interface for:
    - Recording tool calls and their outcomes
    - Detecting when a success fixes a prior failure
    - Building aggregated patterns from traces
    - Retrieving relevant history for context injection
    - Persisting to/from Session state

    Attributes:
        memory: The underlying ProceduralMemory (tool log, per-tool patterns,
            fix relations).
        max_log_entries: Cap on ``memory.tool_log``; the oldest entries are
            dropped when a new call pushes the log past this size.
        max_patterns_per_tool: Configuration value. Not referenced by any
            method of this class (presumably consumed elsewhere - verify).
        fix_detection_window: Number of preceding log entries scanned when
            checking whether a new success fixes an earlier failure.
        on_fix_detected: Optional callback invoked with each detected
            ToolFixRelation. Excluded from model serialization.
    """

    memory: ProceduralMemory

    # Configuration
    max_log_entries: int = Field(default=1000)
    max_patterns_per_tool: int = Field(default=10)
    fix_detection_window: int = Field(default=10)  # Look back N calls for fixes

    # Optional callbacks
    on_fix_detected: Optional[Callable[[ToolFixRelation], None]] = Field(
        default=None, exclude=True
    )

    model_config = {"arbitrary_types_allowed": True}

    @classmethod
    def create(cls, session_id: str, **kwargs: Any) -> "ToolMemoryManager":
        """
        Create a new manager with empty procedural memory for a session.

        Args:
            session_id: ID of the session the memory belongs to.
            **kwargs: Extra configuration fields forwarded to the constructor.

        Returns:
            A fresh ToolMemoryManager.
        """
        memory = ProceduralMemory(session_id=session_id)
        return cls(memory=memory, **kwargs)

    @classmethod
    async def from_session(
        cls, session: "Session", **kwargs: Any
    ) -> "ToolMemoryManager":
        """
        Load procedural memory from a session's state.

        Restoring is best-effort: if stored state is missing or fails
        validation, a warning is logged and a fresh manager is returned.

        Args:
            session: The Session to load from
            **kwargs: Additional configuration

        Returns:
            ToolMemoryManager with restored state
        """
        state_key = "procedural_memory"
        stored = await session.get_state(state_key)

        if stored:
            try:
                memory = ProceduralMemory.model_validate(stored)
                log.info("Restored procedural memory for session %s", session.id)
                return cls(memory=memory, **kwargs)
            except Exception as e:
                # Deliberate best-effort: corrupt state falls back to fresh memory.
                log.warning("Failed to restore procedural memory: %s", e)

        # Create fresh
        return cls.create(session_id=session.id, **kwargs)

    async def save_to_session(self, session: "Session") -> None:
        """
        Persist procedural memory to a session's state.

        Args:
            session: The Session to save to
        """
        state_key = "procedural_memory"
        await session.set_state(state_key, self.memory.model_dump(mode="json"))
        log.debug("Saved procedural memory to session %s", session.id)

    @property
    def session_id(self) -> str:
        """Get the session ID."""
        return self.memory.session_id

    # --- Recording ---

    async def record_call(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        result: Any,
        outcome: ToolOutcome,
        context_goal: Optional[str] = None,
        error_type: Optional[str] = None,
        error_message: Optional[str] = None,
        execution_time_ms: Optional[int] = None,
        preceding_call_id: Optional[str] = None,
    ) -> ToolLogEntry:
        """
        Record a tool invocation.

        Appends a ToolLogEntry to the log (trimming it to
        ``max_log_entries``), updates the tool's pattern, and - on success -
        tries to link this call as the fix for a recent failure.

        Args:
            tool_name: Name of the tool called
            arguments: Arguments passed to the tool
            result: The result (will be summarized)
            outcome: Success/failure status
            context_goal: What the user was trying to accomplish
            error_type: Type of error if failure
            error_message: Error details if failure
            execution_time_ms: How long the call took
            preceding_call_id: ID of the call that came before this in a chain

        Returns:
            The created ToolLogEntry
        """
        call_id = self.memory.allocate_call_id()

        entry = ToolLogEntry(
            id=call_id,
            timestamp=datetime.now(timezone.utc),
            tool_name=tool_name,
            arguments=arguments,
            arguments_hash=self._hash_arguments(arguments),
            context_goal=context_goal,
            preceding_call_id=preceding_call_id,
            outcome=outcome,
            result_summary=self._summarize_result(result),
            result_type=self._classify_result(result),
            execution_time_ms=execution_time_ms,
            error_type=error_type,
            error_message=error_message,
        )

        # Add to log, then drop every entry beyond the cap. Removing all excess
        # (not just one) lets an over-sized log - e.g. restored state or a
        # lowered max_log_entries - converge back to the configured size.
        tool_log = self.memory.tool_log
        tool_log.append(entry)
        excess = len(tool_log) - self.max_log_entries
        if excess > 0:
            del tool_log[:excess]

        # Update patterns
        pattern = self.memory.get_pattern(tool_name)
        pattern.record_call(entry)

        # Handle success - check if it fixes a prior failure
        if outcome == ToolOutcome.SUCCESS:
            await self._check_if_fixes_prior(entry)

            # Record success pattern if it was a fix
            if entry.is_fix():
                pattern.add_success_pattern(
                    goal_match=context_goal,
                    delta_that_fixed=entry.delta_args,
                    example_call_id=call_id,
                )

        # Handle failure - record error pattern
        elif outcome in (ToolOutcome.FAILURE, ToolOutcome.TIMEOUT):
            pattern.add_error_pattern(
                error_type=error_type or "unknown",
                context=context_goal,
                example_args=arguments,
            )

        self.memory.updated_at = datetime.now(timezone.utc)

        log.debug("Recorded tool call: %s", entry.format_compact())
        return entry

    async def _check_if_fixes_prior(self, success_entry: ToolLogEntry) -> None:
        """
        Check if this successful call fixes a recent failure.

        Scans up to ``fix_detection_window`` entries preceding the newest one,
        most recent first, for an unfixed failure of the same tool whose
        arguments differ from ``success_entry``'s. The first match is linked
        both ways, a ToolFixRelation is recorded, and the error pattern is
        updated. Failures with identical arguments are skipped.
        """
        # Look at recent calls (excluding this one, which is last in the log)
        window = self.memory.tool_log[-(self.fix_detection_window + 1) : -1]

        for prior in reversed(window):
            # Must be same tool, a failure, and not already fixed
            if prior.tool_name != success_entry.tool_name:
                continue
            if not prior.is_failure():
                continue
            if prior.was_fixed():
                continue

            # Compute what changed
            delta = self._compute_arg_delta(prior.arguments, success_entry.arguments)
            if not delta:
                continue

            # Link the fix
            prior.fixed_by = success_entry.id
            success_entry.fix_for = prior.id
            success_entry.delta_args = delta

            # Record the fix relation
            relation = ToolFixRelation(
                failed_call_id=prior.id,
                success_call_id=success_entry.id,
                delta_args=delta,
            )
            self.memory.fix_relations.append(relation)

            # Update error pattern with fix info
            if prior.error_type:
                pattern = self.memory.get_pattern(success_entry.tool_name)
                pattern.record_fix(prior.error_type, delta)

            log.info(
                "Detected fix: %s -> %s, delta: %s",
                prior.id,
                success_entry.id,
                delta,
            )

            # Callback. The fix is already linked in memory at this point, so
            # a faulty user callback must not abort record_call halfway.
            if self.on_fix_detected is not None:
                try:
                    self.on_fix_detected(relation)
                except Exception:
                    log.exception("on_fix_detected callback failed")

            # Only link to most recent failure
            break

    def _compute_arg_delta(
        self, failed_args: dict[str, Any], success_args: dict[str, Any]
    ) -> Optional[dict[str, Any]]:
        """
        Compute what changed between a failed and a successful call.

        Returns a dict with any of the keys ``"added"`` (key -> new value),
        ``"removed"`` (list of keys), and ``"changed"`` (key ->
        ``{"from": old, "to": new}``), or None if the arguments are equal.
        """
        delta: dict[str, Any] = {}

        # Iterate the dicts rather than sets so the delta's key order is
        # deterministic (set order depends on the per-process hash seed).
        added = {k: v for k, v in success_args.items() if k not in failed_args}
        if added:
            delta["added"] = added

        removed = [k for k in failed_args if k not in success_args]
        if removed:
            delta["removed"] = removed

        changed = {
            k: {"from": failed_args[k], "to": success_args[k]}
            for k in failed_args
            if k in success_args
            and not self._values_equal(failed_args[k], success_args[k])
        }
        if changed:
            delta["changed"] = changed

        return delta or None

    def _values_equal(self, a: Any, b: Any) -> bool:
        """Check if two values are equal, comparing nested structures via
        canonical JSON and falling back to ``==`` when not JSON-serializable."""
        try:
            return json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)
        except (TypeError, ValueError):
            return a == b

    def _hash_arguments(self, arguments: dict[str, Any]) -> str:
        """Create a short (12 hex chars) SHA-256 hash of the arguments for
        quick comparison; returns "" if they cannot be serialized."""
        try:
            args_str = json.dumps(arguments, sort_keys=True, default=str)
            return hashlib.sha256(args_str.encode()).hexdigest()[:12]
        except (TypeError, ValueError):
            return ""

    @staticmethod
    def _truncate(text: str, max_len: int) -> str:
        """Clip ``text`` to at most ``max_len`` chars, marking clipping with '...'."""
        if len(text) <= max_len:
            return text
        if max_len < 3:
            # No room for the ellipsis; hard-cut instead of exceeding max_len.
            return text[: max(max_len, 0)]
        return text[: max_len - 3] + "..."

    def _summarize_result(self, result: Any, max_len: int = 100) -> str:
        """
        Create a compact summary of a result.

        Strings and dict ``status``/``error`` summaries are bounded by
        ``max_len``; collections are summarized by size.
        """
        if result is None:
            return "null"

        if isinstance(result, str):
            return self._truncate(result, max_len)

        # bool must be checked before int (bool is a subclass of int)
        if isinstance(result, bool):
            return "true" if result else "false"

        if isinstance(result, (int, float)):
            return str(result)

        if isinstance(result, list):
            return f"list[{len(result)}]"

        if isinstance(result, dict):
            # Check for common patterns
            if "status" in result:
                return self._truncate(f"status={result['status']}", max_len)
            if "error" in result:
                return self._truncate(f"error: {result['error']}", max_len)
            if "result" in result:
                return self._summarize_result(result["result"], max_len)
            return f"object[{len(result)} keys]"

        return str(type(result).__name__)

    def _classify_result(self, result: Any) -> str:
        """
        Classify the type of result.

        For dicts carrying a non-None ``status`` the status (as a string) is
        used as the classification; it is coerced with ``str`` so non-string
        statuses (e.g. HTTP codes) honour the ``str`` return type.
        """
        if result is None:
            return "null"
        if isinstance(result, bool):
            return "boolean"
        if isinstance(result, (int, float)):
            return "number"
        if isinstance(result, str):
            return "string"
        if isinstance(result, list):
            return "list"
        if isinstance(result, dict):
            status = result.get("status")
            if status is not None:
                return str(status)
            return "object"
        return "unknown"

    # --- Retrieval ---

    def get_recent_calls(
        self,
        tool_name: Optional[str] = None,
        limit: int = 5,
        outcome: Optional[ToolOutcome] = None,
    ) -> list[ToolLogEntry]:
        """
        Get recent tool calls.

        Args:
            tool_name: Filter by tool name (None = all tools)
            limit: Maximum entries to return (<= 0 returns an empty list)
            outcome: Filter by outcome (None = all outcomes)

        Returns:
            List of entries, most recent first
        """
        # Guard: entries[-0:] would slice the whole list, not nothing.
        if limit <= 0:
            return []

        entries = self.memory.tool_log

        if tool_name:
            entries = [e for e in entries if e.tool_name == tool_name]

        if outcome:
            entries = [e for e in entries if e.outcome == outcome]

        return list(reversed(entries[-limit:]))

    def get_pattern(self, tool_name: str) -> Optional[ToolPattern]:
        """Get aggregated pattern for a tool, or None if it has none."""
        return self.memory.tool_patterns.get(tool_name)

    def get_all_patterns(self) -> dict[str, ToolPattern]:
        """Get a shallow copy of all tool patterns, keyed by tool name."""
        return self.memory.tool_patterns.copy()

    def search_calls(
        self,
        tool_name: Optional[str] = None,
        goal_contains: Optional[str] = None,
        outcome: Optional[ToolOutcome] = None,
        error_type: Optional[str] = None,
        only_fixes: bool = False,
        only_fixed: bool = False,
        limit: int = 10,
    ) -> list[ToolLogEntry]:
        """
        Search tool log with filters.

        Args:
            tool_name: Filter by tool name
            goal_contains: Filter by goal containing text (case-insensitive)
            outcome: Filter by outcome
            error_type: Filter by error type
            only_fixes: Only return calls that fixed something
            only_fixed: Only return failures that were fixed
            limit: Maximum results (<= 0 returns an empty list)

        Returns:
            Matching entries, most recent first
        """
        results: list[ToolLogEntry] = []
        if limit <= 0:
            return results

        # Loop-invariant: lowercase the needle once.
        needle = goal_contains.lower() if goal_contains else None

        for entry in reversed(self.memory.tool_log):
            # Apply filters
            if tool_name and entry.tool_name != tool_name:
                continue

            if needle is not None:
                if not entry.context_goal:
                    continue
                if needle not in entry.context_goal.lower():
                    continue

            if outcome and entry.outcome != outcome:
                continue

            if error_type and entry.error_type != error_type:
                continue

            if only_fixes and not entry.is_fix():
                continue

            if only_fixed and not entry.was_fixed():
                continue

            results.append(entry)

            if len(results) >= limit:
                break

        return results

    def get_fix_for_error(
        self, tool_name: str, error_type: str
    ) -> Optional[dict[str, Any]]:
        """
        Get the typical fix for an error type.

        Returns the ``fix_delta`` of the first error pattern for
        ``tool_name`` matching ``error_type`` that has one, based on
        observed fix relationships; None otherwise.
        """
        pattern = self.get_pattern(tool_name)
        if not pattern:
            return None

        for ep in pattern.error_patterns:
            if ep.error_type == error_type and ep.fix_delta:
                return ep.fix_delta

        return None

    def get_successful_args_for_goal(
        self, tool_name: str, goal: str
    ) -> Optional[dict[str, Any]]:
        """
        Get argument hints for a goal based on past successes.

        Returns the ``arg_hints`` of the first success pattern whose
        ``goal_match`` contains ``goal`` (case-insensitive); None otherwise.
        """
        pattern = self.get_pattern(tool_name)
        if not pattern:
            return None

        goal_lower = goal.lower()
        for sp in pattern.success_patterns:
            if sp.goal_match and goal_lower in sp.goal_match.lower():
                return sp.arg_hints

        return None

    # --- Statistics ---

    def get_stats(self) -> dict[str, Any]:
        """Get memory statistics (counts, success rate, timestamps)."""
        tool_log = self.memory.tool_log
        total_calls = len(tool_log)
        total_fixes = len(self.memory.fix_relations)
        tools_tracked = len(self.memory.tool_patterns)

        success_count = sum(1 for e in tool_log if e.outcome == ToolOutcome.SUCCESS)
        failure_count = sum(1 for e in tool_log if e.is_failure())

        return {
            "session_id": self.session_id,
            "total_calls": total_calls,
            "success_count": success_count,
            "failure_count": failure_count,
            # Always a float, including the empty-log case.
            "success_rate": success_count / total_calls if total_calls > 0 else 0.0,
            "total_fixes_detected": total_fixes,
            "tools_tracked": tools_tracked,
            "created_at": self.memory.created_at.isoformat(),
            "updated_at": self.memory.updated_at.isoformat(),
        }

    # --- Persistence hooks ---

    def to_dict(self) -> dict[str, Any]:
        """Serialize the procedural memory to a JSON-compatible dictionary."""
        return self.memory.model_dump(mode="json")

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> "ToolMemoryManager":
        """Restore a manager from a dictionary produced by ``to_dict``.

        Raises whatever ``ProceduralMemory.model_validate`` raises on bad data.
        """
        memory = ProceduralMemory.model_validate(data)
        return cls(memory=memory, **kwargs)

    def reset(self) -> None:
        """Clear the log, patterns and fix relations, and restart call IDs at 1."""
        self.memory.tool_log.clear()
        self.memory.tool_patterns.clear()
        self.memory.fix_relations.clear()
        self.memory.next_call_id = 1
        self.memory.updated_at = datetime.now(timezone.utc)
|