chuk-ai-session-manager 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. chuk_ai_session_manager/__init__.py +84 -40
  2. chuk_ai_session_manager/api/__init__.py +1 -1
  3. chuk_ai_session_manager/api/simple_api.py +53 -59
  4. chuk_ai_session_manager/exceptions.py +31 -17
  5. chuk_ai_session_manager/guards/__init__.py +118 -0
  6. chuk_ai_session_manager/guards/bindings.py +217 -0
  7. chuk_ai_session_manager/guards/cache.py +163 -0
  8. chuk_ai_session_manager/guards/manager.py +819 -0
  9. chuk_ai_session_manager/guards/models.py +498 -0
  10. chuk_ai_session_manager/guards/ungrounded.py +159 -0
  11. chuk_ai_session_manager/infinite_conversation.py +86 -79
  12. chuk_ai_session_manager/memory/__init__.py +247 -0
  13. chuk_ai_session_manager/memory/artifacts_bridge.py +469 -0
  14. chuk_ai_session_manager/memory/context_packer.py +347 -0
  15. chuk_ai_session_manager/memory/fault_handler.py +507 -0
  16. chuk_ai_session_manager/memory/manifest.py +307 -0
  17. chuk_ai_session_manager/memory/models.py +1084 -0
  18. chuk_ai_session_manager/memory/mutation_log.py +186 -0
  19. chuk_ai_session_manager/memory/pack_cache.py +206 -0
  20. chuk_ai_session_manager/memory/page_table.py +275 -0
  21. chuk_ai_session_manager/memory/prefetcher.py +192 -0
  22. chuk_ai_session_manager/memory/tlb.py +247 -0
  23. chuk_ai_session_manager/memory/vm_prompts.py +238 -0
  24. chuk_ai_session_manager/memory/working_set.py +574 -0
  25. chuk_ai_session_manager/models/__init__.py +21 -9
  26. chuk_ai_session_manager/models/event_source.py +3 -1
  27. chuk_ai_session_manager/models/event_type.py +10 -1
  28. chuk_ai_session_manager/models/session.py +103 -68
  29. chuk_ai_session_manager/models/session_event.py +69 -68
  30. chuk_ai_session_manager/models/session_metadata.py +9 -10
  31. chuk_ai_session_manager/models/session_run.py +21 -22
  32. chuk_ai_session_manager/models/token_usage.py +76 -76
  33. chuk_ai_session_manager/procedural_memory/__init__.py +70 -0
  34. chuk_ai_session_manager/procedural_memory/formatter.py +407 -0
  35. chuk_ai_session_manager/procedural_memory/manager.py +523 -0
  36. chuk_ai_session_manager/procedural_memory/models.py +371 -0
  37. chuk_ai_session_manager/sample_tools.py +79 -46
  38. chuk_ai_session_manager/session_aware_tool_processor.py +27 -16
  39. chuk_ai_session_manager/session_manager.py +259 -232
  40. chuk_ai_session_manager/session_prompt_builder.py +163 -111
  41. chuk_ai_session_manager/session_storage.py +45 -52
  42. {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/METADATA +80 -4
  43. chuk_ai_session_manager-0.8.1.dist-info/RECORD +45 -0
  44. {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/WHEEL +1 -1
  45. chuk_ai_session_manager-0.7.1.dist-info/RECORD +0 -22
  46. {chuk_ai_session_manager-0.7.1.dist-info → chuk_ai_session_manager-0.8.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,523 @@
1
+ # chuk_ai_session_manager/procedural_memory/manager.py
2
+ """
3
+ Tool Memory Manager - orchestrates procedural memory operations.
4
+
5
+ Handles:
6
+ - Recording tool invocations
7
+ - Detecting fix relationships (failure -> success)
8
+ - Updating aggregated patterns
9
+ - Searching and retrieving history
10
+ - Session persistence integration
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import hashlib
16
+ import json
17
+ import logging
18
+ from datetime import datetime, timezone
19
+ from typing import Any, Callable, Optional, TYPE_CHECKING
20
+
21
+ from pydantic import BaseModel, Field
22
+
23
+ from chuk_ai_session_manager.procedural_memory.models import (
24
+ ProceduralMemory,
25
+ ToolFixRelation,
26
+ ToolLogEntry,
27
+ ToolOutcome,
28
+ ToolPattern,
29
+ )
30
+
31
+ if TYPE_CHECKING:
32
+ from chuk_ai_session_manager.models.session import Session
33
+
34
+ log = logging.getLogger(__name__)
35
+
36
+
37
class ToolMemoryManager(BaseModel):
    """
    Manages procedural memory - tool traces, patterns, and fixes.

    This is the main interface for:
    - Recording tool calls and their outcomes
    - Detecting when a success fixes a prior failure
    - Building aggregated patterns from traces
    - Retrieving relevant history for context injection
    - Persisting to/from Session state

    All recorded state lives in ``self.memory`` (a ``ProceduralMemory``); the
    other fields only configure how that state is maintained.
    """

    memory: ProceduralMemory

    # Configuration
    # Once memory.tool_log grows past this size, the oldest entries are dropped.
    max_log_entries: int = Field(default=1000)
    # NOTE(review): not read anywhere in this class; confirm whether the
    # pattern models use it, otherwise it is dead configuration.
    max_patterns_per_tool: int = Field(default=10)
    # Number of preceding log entries scanned when looking for a failure that
    # a new success may have fixed. Values <= 0 disable fix detection.
    fix_detection_window: int = Field(default=10)

    # Optional callbacks
    # Called synchronously with every detected ToolFixRelation; excluded from
    # serialization. Exceptions raised by it propagate out of record_call().
    on_fix_detected: Optional[Callable[[ToolFixRelation], None]] = Field(
        default=None, exclude=True
    )

    model_config = {"arbitrary_types_allowed": True}

    @classmethod
    def create(cls, session_id: str, **kwargs: Any) -> "ToolMemoryManager":
        """
        Create a manager with empty procedural memory for a session.

        Args:
            session_id: Session ID stored on the new ProceduralMemory
            **kwargs: Additional configuration fields (e.g. max_log_entries)

        Returns:
            A new ToolMemoryManager
        """
        memory = ProceduralMemory(session_id=session_id)
        return cls(memory=memory, **kwargs)

    @classmethod
    async def from_session(
        cls, session: "Session", **kwargs: Any
    ) -> "ToolMemoryManager":
        """
        Load procedural memory from a session's state.

        Reads the "procedural_memory" state key. Restoring is best-effort:
        if the key is missing/empty, or the stored data fails validation
        (which is logged as a warning), a fresh, empty memory is created.

        Args:
            session: The Session to load from
            **kwargs: Additional configuration

        Returns:
            ToolMemoryManager with restored (or fresh) state
        """
        stored = await session.get_state("procedural_memory")

        if stored:
            try:
                memory = ProceduralMemory.model_validate(stored)
            except Exception as e:
                # Corrupt or schema-incompatible state must not block the
                # session; fall through and start with fresh memory.
                log.warning("Failed to restore procedural memory: %s", e)
            else:
                log.info("Restored procedural memory for session %s", session.id)
                return cls(memory=memory, **kwargs)

        # Create fresh
        return cls.create(session_id=session.id, **kwargs)

    async def save_to_session(self, session: "Session") -> None:
        """
        Persist procedural memory to a session's state.

        The memory is dumped in JSON mode under the "procedural_memory" key,
        which is the key from_session() reads back.

        Args:
            session: The Session to save to
        """
        await session.set_state(
            "procedural_memory", self.memory.model_dump(mode="json")
        )
        log.debug("Saved procedural memory to session %s", session.id)

    @property
    def session_id(self) -> str:
        """Get the session ID."""
        return self.memory.session_id

    # --- Recording ---

    async def record_call(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        result: Any,
        outcome: ToolOutcome,
        context_goal: Optional[str] = None,
        error_type: Optional[str] = None,
        error_message: Optional[str] = None,
        execution_time_ms: Optional[int] = None,
        preceding_call_id: Optional[str] = None,
    ) -> ToolLogEntry:
        """
        Record a tool invocation.

        Appends a ToolLogEntry to the log (trimming it to max_log_entries),
        updates the tool's aggregated pattern, and then:
        - on SUCCESS: tries to link it as the fix for a recent failure and,
          if linked, records a success pattern;
        - on FAILURE/TIMEOUT: records an error pattern.

        Args:
            tool_name: Name of the tool called
            arguments: Arguments passed to the tool
            result: The result (will be summarized)
            outcome: Success/failure status
            context_goal: What the user was trying to accomplish
            error_type: Type of error if failure
            error_message: Error details if failure
            execution_time_ms: How long the call took
            preceding_call_id: ID of the call that came before this in a chain

        Returns:
            The created ToolLogEntry
        """
        call_id = self.memory.allocate_call_id()

        entry = ToolLogEntry(
            id=call_id,
            timestamp=datetime.now(timezone.utc),
            tool_name=tool_name,
            arguments=arguments,
            arguments_hash=self._hash_arguments(arguments),
            context_goal=context_goal,
            preceding_call_id=preceding_call_id,
            outcome=outcome,
            result_summary=self._summarize_result(result),
            result_type=self._classify_result(result),
            execution_time_ms=execution_time_ms,
            error_type=error_type,
            error_message=error_message,
        )

        self.memory.tool_log.append(entry)
        self._trim_log()

        # memory.get_pattern is used as get-or-create here.
        pattern = self.memory.get_pattern(tool_name)
        pattern.record_call(entry)

        if outcome == ToolOutcome.SUCCESS:
            # May set entry.fix_for / entry.delta_args.
            await self._check_if_fixes_prior(entry)

            if entry.is_fix():
                pattern.add_success_pattern(
                    goal_match=context_goal,
                    delta_that_fixed=entry.delta_args,
                    example_call_id=call_id,
                )

        elif outcome in (ToolOutcome.FAILURE, ToolOutcome.TIMEOUT):
            pattern.add_error_pattern(
                error_type=error_type or "unknown",
                context=context_goal,
                example_args=arguments,
            )

        self.memory.updated_at = datetime.now(timezone.utc)

        log.debug("Recorded tool call: %s", entry.format_compact())
        return entry

    def _trim_log(self) -> None:
        """Drop the oldest log entries so at most max_log_entries remain."""
        # Deleting all excess (not just one) also handles max_log_entries
        # having been lowered after entries were recorded.
        excess = len(self.memory.tool_log) - self.max_log_entries
        if excess > 0:
            del self.memory.tool_log[:excess]

    async def _check_if_fixes_prior(self, success_entry: ToolLogEntry) -> None:
        """
        Check if this successful call fixes a recent failure.

        Scans the previous ``fix_detection_window`` log entries, newest first,
        for an unfixed failure of the same tool whose arguments differ from
        the successful call. At most one failure (the most recent match) is
        linked: both entries are cross-referenced, a ToolFixRelation is
        stored, the error pattern is updated, and on_fix_detected is invoked.
        """
        if self.fix_detection_window <= 0:
            # A negative window would otherwise slice the whole log.
            return

        # success_entry was just appended, so it is the last element (given
        # max_log_entries >= 1); look at the entries immediately before it.
        window = self.memory.tool_log[-(self.fix_detection_window + 1) : -1]

        for prior in reversed(window):
            if prior.tool_name != success_entry.tool_name:
                continue
            if not prior.is_failure() or prior.was_fixed():
                continue

            delta = self._compute_arg_delta(prior.arguments, success_entry.arguments)
            if not delta:
                # Same arguments: the success cannot be attributed to an
                # argument change, so keep searching older failures.
                continue

            prior.fixed_by = success_entry.id
            success_entry.fix_for = prior.id
            success_entry.delta_args = delta

            relation = ToolFixRelation(
                failed_call_id=prior.id,
                success_call_id=success_entry.id,
                delta_args=delta,
            )
            self.memory.fix_relations.append(relation)

            if prior.error_type:
                pattern = self.memory.get_pattern(success_entry.tool_name)
                pattern.record_fix(prior.error_type, delta)

            log.info(
                "Detected fix: %s -> %s, delta: %s",
                prior.id,
                success_entry.id,
                delta,
            )

            if self.on_fix_detected:
                self.on_fix_detected(relation)

            # Only link to the most recent failure.
            return

    def _compute_arg_delta(
        self, failed_args: dict[str, Any], success_args: dict[str, Any]
    ) -> Optional[dict[str, Any]]:
        """
        Compute what changed between a failed and a successful call.

        Returns:
            A dict with any of the keys "added" ({key: new value}),
            "removed" ([keys]) and "changed" ({key: {"from": .., "to": ..}}),
            or None when the arguments are equal. Keys are emitted in sorted
            order so the delta is deterministic across runs (set iteration
            order is not).
        """
        delta: dict[str, Any] = {}

        failed_keys = set(failed_args)
        success_keys = set(success_args)

        added = sorted(success_keys - failed_keys, key=str)
        if added:
            delta["added"] = {k: success_args[k] for k in added}

        removed = sorted(failed_keys - success_keys, key=str)
        if removed:
            delta["removed"] = removed

        changed = {
            k: {"from": failed_args[k], "to": success_args[k]}
            for k in sorted(failed_keys & success_keys, key=str)
            if not self._values_equal(failed_args[k], success_args[k])
        }
        if changed:
            delta["changed"] = changed

        return delta or None

    def _values_equal(self, a: Any, b: Any) -> bool:
        """Check if two values are equal (handling nested structures)."""
        # Canonical JSON makes dict key order irrelevant; fall back to ==
        # for values json cannot serialize.
        try:
            return json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True)
        except (TypeError, ValueError):
            return a == b

    def _hash_arguments(self, arguments: dict[str, Any]) -> str:
        """Create a short (12 hex chars) hash of arguments for quick comparison."""
        try:
            args_str = json.dumps(arguments, sort_keys=True, default=str)
            return hashlib.sha256(args_str.encode()).hexdigest()[:12]
        except (TypeError, ValueError):
            # e.g. keys of mixed types cannot be sorted
            return ""

    def _truncate(self, text: str, max_len: int) -> str:
        """Shorten text to at most max_len characters, ending in '...' if cut."""
        if len(text) <= max_len:
            return text
        if max_len < 3:
            # No room for an ellipsis.
            return text[: max(max_len, 0)]
        return text[: max_len - 3] + "..."

    def _summarize_result(self, result: Any, max_len: int = 100) -> str:
        """
        Create a compact summary of a result.

        Strings and dict "status"/"error" summaries are truncated to max_len;
        lists and other dicts are summarized by size; a dict "result" value
        is summarized recursively.
        """
        if result is None:
            return "null"

        if isinstance(result, str):
            return self._truncate(result, max_len)

        # bool must be checked before int (bool is a subclass of int).
        if isinstance(result, bool):
            return "true" if result else "false"

        if isinstance(result, (int, float)):
            return str(result)

        if isinstance(result, list):
            return f"list[{len(result)}]"

        if isinstance(result, dict):
            # Check for common patterns
            if "status" in result:
                return self._truncate(f"status={result['status']}", max_len)
            if "error" in result:
                return self._truncate(f"error: {result['error']}", max_len)
            if "result" in result:
                return self._summarize_result(result["result"], max_len)
            return f"object[{len(result)} keys]"

        return str(type(result).__name__)

    def _classify_result(self, result: Any) -> str:
        """
        Classify the type of a result.

        Dicts carrying a non-null "status" are classified by that status,
        converted to str so the return type always matches the annotation.
        """
        if result is None:
            return "null"
        if isinstance(result, bool):
            return "boolean"
        if isinstance(result, (int, float)):
            return "number"
        if isinstance(result, str):
            return "string"
        if isinstance(result, list):
            return "list"
        if isinstance(result, dict):
            status = result.get("status")
            if status is None:
                return "object"
            return status if isinstance(status, str) else str(status)
        return "unknown"

    # --- Retrieval ---

    def get_recent_calls(
        self,
        tool_name: Optional[str] = None,
        limit: int = 5,
        outcome: Optional[ToolOutcome] = None,
    ) -> list[ToolLogEntry]:
        """
        Get recent tool calls.

        Args:
            tool_name: Filter by tool name (None = all tools)
            limit: Maximum entries to return (<= 0 returns an empty list)
            outcome: Filter by outcome (None = all outcomes)

        Returns:
            List of entries, most recent first
        """
        # entries[-0:] would be the whole list, so guard non-positive limits.
        if limit <= 0:
            return []

        entries = self.memory.tool_log

        if tool_name:
            entries = [e for e in entries if e.tool_name == tool_name]

        if outcome is not None:
            entries = [e for e in entries if e.outcome == outcome]

        return list(reversed(entries[-limit:]))

    def get_pattern(self, tool_name: str) -> Optional[ToolPattern]:
        """Get the aggregated pattern for a tool, or None if never recorded."""
        return self.memory.tool_patterns.get(tool_name)

    def get_all_patterns(self) -> dict[str, ToolPattern]:
        """Get a shallow copy of all tool patterns, keyed by tool name."""
        return self.memory.tool_patterns.copy()

    def search_calls(
        self,
        tool_name: Optional[str] = None,
        goal_contains: Optional[str] = None,
        outcome: Optional[ToolOutcome] = None,
        error_type: Optional[str] = None,
        only_fixes: bool = False,
        only_fixed: bool = False,
        limit: int = 10,
    ) -> list[ToolLogEntry]:
        """
        Search tool log with filters.

        Args:
            tool_name: Filter by tool name
            goal_contains: Filter by goal containing text (case-insensitive)
            outcome: Filter by outcome
            error_type: Filter by error type
            only_fixes: Only return calls that fixed something
            only_fixed: Only return failures that were fixed
            limit: Maximum results (<= 0 returns an empty list)

        Returns:
            Matching entries, most recent first
        """
        if limit <= 0:
            return []

        needle = goal_contains.lower() if goal_contains else None
        results: list[ToolLogEntry] = []

        for entry in reversed(self.memory.tool_log):
            if tool_name and entry.tool_name != tool_name:
                continue

            if needle is not None:
                if not entry.context_goal or needle not in entry.context_goal.lower():
                    continue

            if outcome is not None and entry.outcome != outcome:
                continue

            if error_type and entry.error_type != error_type:
                continue

            if only_fixes and not entry.is_fix():
                continue

            if only_fixed and not entry.was_fixed():
                continue

            results.append(entry)

            if len(results) >= limit:
                break

        return results

    def get_fix_for_error(
        self, tool_name: str, error_type: str
    ) -> Optional[dict[str, Any]]:
        """
        Get the typical fix for an error type.

        Returns the fix_delta of the first error pattern for this tool whose
        error_type matches and that has a recorded fix, or None.
        """
        pattern = self.get_pattern(tool_name)
        if not pattern:
            return None

        for ep in pattern.error_patterns:
            if ep.error_type == error_type and ep.fix_delta:
                return ep.fix_delta

        return None

    def get_successful_args_for_goal(
        self, tool_name: str, goal: str
    ) -> Optional[dict[str, Any]]:
        """
        Get argument hints for a goal based on past successes.

        Returns the arg_hints of the first success pattern whose goal_match
        contains ``goal`` (case-insensitive), or None.
        """
        pattern = self.get_pattern(tool_name)
        if not pattern:
            return None

        goal_lower = goal.lower()
        for sp in pattern.success_patterns:
            if sp.goal_match and goal_lower in sp.goal_match.lower():
                return sp.arg_hints

        return None

    # --- Statistics ---

    def get_stats(self) -> dict[str, Any]:
        """Get memory statistics (counts, success rate, timestamps)."""
        tool_log = self.memory.tool_log
        total_calls = len(tool_log)

        success_count = sum(1 for e in tool_log if e.outcome == ToolOutcome.SUCCESS)
        failure_count = sum(1 for e in tool_log if e.is_failure())

        return {
            "session_id": self.session_id,
            "total_calls": total_calls,
            "success_count": success_count,
            "failure_count": failure_count,
            # Always a float, including the empty-log case.
            "success_rate": success_count / total_calls if total_calls else 0.0,
            "total_fixes_detected": len(self.memory.fix_relations),
            "tools_tracked": len(self.memory.tool_patterns),
            "created_at": self.memory.created_at.isoformat(),
            "updated_at": self.memory.updated_at.isoformat(),
        }

    # --- Persistence hooks ---

    def to_dict(self) -> dict[str, Any]:
        """Serialize the memory (not the configuration) to a JSON-safe dict."""
        return self.memory.model_dump(mode="json")

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> "ToolMemoryManager":
        """
        Restore from a dictionary produced by to_dict().

        Unlike from_session(), validation errors are not caught.
        """
        memory = ProceduralMemory.model_validate(data)
        return cls(memory=memory, **kwargs)

    def reset(self) -> None:
        """Clear log, patterns and fix relations; restart call IDs at 1."""
        self.memory.tool_log.clear()
        self.memory.tool_patterns.clear()
        self.memory.fix_relations.clear()
        self.memory.next_call_id = 1
        self.memory.updated_at = datetime.now(timezone.utc)