aloop 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of aloop might be problematic. Click here for more details.

Files changed (62)
  1. agent/__init__.py +0 -0
  2. agent/agent.py +182 -0
  3. agent/base.py +406 -0
  4. agent/context.py +126 -0
  5. agent/todo.py +149 -0
  6. agent/tool_executor.py +54 -0
  7. agent/verification.py +135 -0
  8. aloop-0.1.0.dist-info/METADATA +246 -0
  9. aloop-0.1.0.dist-info/RECORD +62 -0
  10. aloop-0.1.0.dist-info/WHEEL +5 -0
  11. aloop-0.1.0.dist-info/entry_points.txt +2 -0
  12. aloop-0.1.0.dist-info/licenses/LICENSE +21 -0
  13. aloop-0.1.0.dist-info/top_level.txt +9 -0
  14. cli.py +19 -0
  15. config.py +146 -0
  16. interactive.py +865 -0
  17. llm/__init__.py +51 -0
  18. llm/base.py +26 -0
  19. llm/compat.py +226 -0
  20. llm/content_utils.py +309 -0
  21. llm/litellm_adapter.py +450 -0
  22. llm/message_types.py +245 -0
  23. llm/model_manager.py +265 -0
  24. llm/retry.py +95 -0
  25. main.py +246 -0
  26. memory/__init__.py +20 -0
  27. memory/compressor.py +554 -0
  28. memory/manager.py +538 -0
  29. memory/serialization.py +82 -0
  30. memory/short_term.py +88 -0
  31. memory/token_tracker.py +203 -0
  32. memory/types.py +51 -0
  33. tools/__init__.py +6 -0
  34. tools/advanced_file_ops.py +557 -0
  35. tools/base.py +51 -0
  36. tools/calculator.py +50 -0
  37. tools/code_navigator.py +975 -0
  38. tools/explore.py +254 -0
  39. tools/file_ops.py +150 -0
  40. tools/git_tools.py +791 -0
  41. tools/notify.py +69 -0
  42. tools/parallel_execute.py +420 -0
  43. tools/session_manager.py +205 -0
  44. tools/shell.py +147 -0
  45. tools/shell_background.py +470 -0
  46. tools/smart_edit.py +491 -0
  47. tools/todo.py +130 -0
  48. tools/web_fetch.py +673 -0
  49. tools/web_search.py +61 -0
  50. utils/__init__.py +15 -0
  51. utils/logger.py +105 -0
  52. utils/model_pricing.py +49 -0
  53. utils/runtime.py +75 -0
  54. utils/terminal_ui.py +422 -0
  55. utils/tui/__init__.py +39 -0
  56. utils/tui/command_registry.py +49 -0
  57. utils/tui/components.py +306 -0
  58. utils/tui/input_handler.py +393 -0
  59. utils/tui/model_ui.py +204 -0
  60. utils/tui/progress.py +292 -0
  61. utils/tui/status_bar.py +178 -0
  62. utils/tui/theme.py +165 -0
memory/manager.py ADDED
@@ -0,0 +1,538 @@
1
+ """Core memory manager that orchestrates all memory operations."""
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
5
+
6
+ from config import Config
7
+ from llm.content_utils import content_has_tool_calls
8
+ from llm.message_types import LLMMessage
9
+ from utils import terminal_ui
10
+ from utils.tui.progress import AsyncSpinner
11
+
12
+ from .compressor import WorkingMemoryCompressor
13
+ from .short_term import ShortTermMemory
14
+ from .token_tracker import TokenTracker
15
+ from .types import CompressedMemory, CompressionStrategy
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ if TYPE_CHECKING:
20
+ from llm import LiteLLMAdapter
21
+
22
+
23
class MemoryManager:
    """Central memory management system with built-in persistence.

    The persistence store is fully owned by MemoryManager and should not
    be created or passed in from outside.

    Responsibilities visible in this class:
      - routes system messages into ``system_messages`` and everything else
        into a fixed-size ``ShortTermMemory`` window
      - tracks token usage via ``TokenTracker`` (actual API counts when
        provided, estimates otherwise)
      - triggers LLM-based compression of the window when a token threshold
        is exceeded or the window is full
      - lazily creates / saves a session through ``YamlFileMemoryStore``
    """

    def __init__(
        self,
        llm: "LiteLLMAdapter",
        session_id: Optional[str] = None,
    ):
        """Initialize memory manager.

        Args:
            llm: LLM instance for compression
            session_id: Optional session ID (if resuming session)
        """
        self.llm = llm

        # Store is fully owned by MemoryManager
        # NOTE(review): local import — presumably to avoid a circular import
        # between memory.manager and memory.store; confirm before hoisting.
        from .store import YamlFileMemoryStore

        self._store = YamlFileMemoryStore()

        # Lazy session creation: only create when first message is added
        # If session_id is provided (resuming), use it immediately
        if session_id is not None:
            self.session_id: Optional[str] = session_id
            self._session_created = True
        else:
            self.session_id = None
            self._session_created = False

        # Initialize components using Config directly
        self.short_term = ShortTermMemory(max_size=Config.MEMORY_SHORT_TERM_SIZE)
        self.compressor = WorkingMemoryCompressor(llm)
        self.token_tracker = TokenTracker()

        # Storage for system messages (kept out of the sliding window so they
        # are always included in the LLM context and never compressed)
        self.system_messages: List[LLMMessage] = []

        # State tracking
        self.current_tokens: int = 0
        self.was_compressed_last_iteration: bool = False
        self.last_compression_savings: int = 0
        self.compression_count: int = 0

        # Optional callback to get current todo context for compression
        self._todo_context_provider: Optional[Callable[[], Optional[str]]] = None

    @classmethod
    async def from_session(
        cls,
        session_id: str,
        llm: "LiteLLMAdapter",
    ) -> "MemoryManager":
        """Load a MemoryManager from a saved session.

        Args:
            session_id: Session ID to load
            llm: LLM instance for compression

        Returns:
            MemoryManager instance with loaded state

        Raises:
            ValueError: If no session data exists for ``session_id``
        """
        manager = cls(llm=llm, session_id=session_id)

        # Load session data
        session_data = await manager._store.load_session(session_id)
        if not session_data:
            raise ValueError(f"Session {session_id} not found")

        # Restore state
        manager.system_messages = session_data["system_messages"]

        # Add messages to short-term memory (including any summary messages)
        for msg in session_data["messages"]:
            manager.short_term.add_message(msg)

        # Recalculate tokens from the restored content rather than trusting
        # any persisted counts
        manager.current_tokens = manager._recalculate_current_tokens()

        logger.info(
            f"Loaded session {session_id}: "
            f"{len(session_data['messages'])} messages, "
            f"{manager.current_tokens} tokens"
        )

        return manager

    @staticmethod
    async def list_sessions(limit: int = 50) -> List[Dict[str, Any]]:
        """List saved sessions.

        Args:
            limit: Maximum number of sessions to return

        Returns:
            List of session summaries
        """
        from .store import YamlFileMemoryStore

        store = YamlFileMemoryStore()
        return await store.list_sessions(limit=limit)

    @staticmethod
    async def find_latest_session() -> Optional[str]:
        """Find the most recently updated session ID.

        Returns:
            Session ID or None if no sessions exist
        """
        from .store import YamlFileMemoryStore

        store = YamlFileMemoryStore()
        return await store.find_latest_session()

    @staticmethod
    async def find_session_by_prefix(prefix: str) -> Optional[str]:
        """Find a session by ID prefix.

        Args:
            prefix: Prefix of session UUID

        Returns:
            Full session ID or None
        """
        from .store import YamlFileMemoryStore

        store = YamlFileMemoryStore()
        return await store.find_session_by_prefix(prefix)

    async def _ensure_session(self) -> None:
        """Lazily create session when first needed.

        This avoids creating empty sessions when MemoryManager is instantiated
        but no messages are ever added (e.g., user exits before running any task).

        Raises:
            RuntimeError: If session creation fails
        """
        if not self._session_created:
            try:
                self.session_id = await self._store.create_session()
                self._session_created = True
                logger.info(f"Created new session: {self.session_id}")
            except Exception as e:
                logger.error(f"Failed to create session: {e}")
                raise RuntimeError(f"Failed to create memory session: {e}") from e

    async def add_message(
        self, message: LLMMessage, actual_tokens: Optional[Dict[str, int]] = None
    ) -> None:
        """Add a message to memory and trigger compression if needed.

        Args:
            message: Message to add
            actual_tokens: Optional dict with actual token counts from LLM response
                Format: {"input": int, "output": int}
        """
        # Ensure session exists before adding messages
        await self._ensure_session()

        # Track system messages separately (they bypass the sliding window
        # and the compression checks below)
        if message.role == "system":
            self.system_messages.append(message)
            return

        # Count tokens (use actual if provided, otherwise estimate)
        if actual_tokens:
            # Use actual token counts from LLM response
            # Note: input_tokens includes full context sent to API, not just new content
            input_tokens = actual_tokens.get("input", 0)
            output_tokens = actual_tokens.get("output", 0)

            self.token_tracker.add_input_tokens(input_tokens)
            self.token_tracker.add_output_tokens(output_tokens)

            # Log API usage separately
            logger.debug(
                f"API usage: input={input_tokens}, output={output_tokens}, "
                f"total={input_tokens + output_tokens}"
            )
        else:
            # Estimate token count for non-API messages (tool results, etc.)
            provider = self.llm.provider_name.lower()
            model = self.llm.model
            tokens = self.token_tracker.count_message_tokens(message, provider, model)

            # Update token tracker
            if message.role == "assistant":
                self.token_tracker.add_output_tokens(tokens)
            else:
                self.token_tracker.add_input_tokens(tokens)

        # Add to short-term memory
        self.short_term.add_message(message)

        # Recalculate current tokens based on actual stored content
        # This gives accurate count for compression decisions
        self.current_tokens = self._recalculate_current_tokens()

        # Log memory state (stored content size, not API usage)
        logger.debug(
            f"Memory state: {self.current_tokens} stored tokens, "
            f"{self.short_term.count()}/{Config.MEMORY_SHORT_TERM_SIZE} messages, "
            f"full={self.short_term.is_full()}"
        )

        # Check if compression is needed
        self.was_compressed_last_iteration = False
        should_compress, reason = self._should_compress()
        if should_compress:
            logger.info(f"🗜️ Triggering compression: {reason}")
            await self.compress()
        else:
            # Log compression check details
            logger.debug(
                f"Compression check: stored={self.current_tokens}, "
                f"threshold={Config.MEMORY_COMPRESSION_THRESHOLD}, "
                f"short_term_full={self.short_term.is_full()}"
            )

    def get_context_for_llm(self) -> List[LLMMessage]:
        """Get optimized context for LLM call.

        Returns:
            List of messages: system messages + short-term messages (which includes summaries)
        """
        context = []

        # 1. Add system messages (always included)
        context.extend(self.system_messages)

        # 2. Add short-term memory (includes summary messages and recent messages)
        context.extend(self.short_term.get_messages())

        return context

    def set_todo_context_provider(self, provider: Callable[[], Optional[str]]) -> None:
        """Set a callback to provide current todo context for compression.

        The provider should return a formatted string of current todo items,
        or None if no todos exist. This context will be injected into
        compression summaries to preserve task state.

        Args:
            provider: Callable that returns current todo context string or None
        """
        self._todo_context_provider = provider

    async def compress(self, strategy: Optional[str] = None) -> Optional[CompressedMemory]:
        """Compress current short-term memory.

        After compression, the compressed messages (including any summary as user message)
        are put back into short_term as regular messages.

        Args:
            strategy: Compression strategy (None = auto-select)

        Returns:
            CompressedMemory object if compression was performed, None on
            failure or when there is nothing to compress
        """
        messages = self.short_term.get_messages()
        # Snapshot the count now: messages added while the (awaited)
        # compression runs are preserved via remove_first/remaining below
        message_count = len(messages)

        if not messages:
            logger.warning("No messages to compress")
            return None

        # Auto-select strategy if not specified
        if strategy is None:
            strategy = self._select_strategy(messages)

        logger.info(f"🗜️ Compressing {message_count} messages using {strategy} strategy")

        try:
            # Get todo context if provider is set
            todo_context = None
            if self._todo_context_provider:
                todo_context = self._todo_context_provider()

            # Perform compression with TUI spinner
            async with AsyncSpinner(terminal_ui.console, "Compressing memory..."):
                compressed = await self.compressor.compress(
                    messages,
                    strategy=strategy,
                    target_tokens=self._calculate_target_tokens(),
                    todo_context=todo_context,
                )

            # Track compression results
            self.compression_count += 1
            self.was_compressed_last_iteration = True
            self.last_compression_savings = compressed.token_savings

            # Update token tracker
            self.token_tracker.add_compression_savings(compressed.token_savings)
            self.token_tracker.add_compression_cost(compressed.compressed_tokens)

            # Remove compressed messages from short-term memory
            self.short_term.remove_first(message_count)

            # Get any remaining messages (added after compression started)
            remaining_messages = self.short_term.get_messages()
            self.short_term.clear()

            # Add compressed messages (summary + preserved, already combined by compressor)
            for msg in compressed.messages:
                self.short_term.add_message(msg)

            # Add any remaining messages
            for msg in remaining_messages:
                self.short_term.add_message(msg)

            # Update current token count
            old_tokens = self.current_tokens
            self.current_tokens = self._recalculate_current_tokens()

            # Log compression results
            logger.info(
                f"✅ Compression complete: {compressed.original_tokens} → {compressed.compressed_tokens} tokens "
                f"({compressed.savings_percentage:.1f}% saved, ratio: {compressed.compression_ratio:.2f}), "
                f"context: {old_tokens} → {self.current_tokens} tokens, "
                f"short_term now has {self.short_term.count()} messages"
            )

            return compressed

        except Exception as e:
            # Best-effort: a failed compression leaves the window unchanged
            logger.error(f"Compression failed: {e}")
            return None

    def _should_compress(self) -> tuple[bool, Optional[str]]:
        """Check if compression should be triggered.

        Returns:
            Tuple of (should_compress, reason)
        """
        if not Config.MEMORY_ENABLED:
            return False, "compression_disabled"

        # Hard limit: must compress
        if self.current_tokens > Config.MEMORY_COMPRESSION_THRESHOLD:
            return (
                True,
                f"hard_limit ({self.current_tokens} > {Config.MEMORY_COMPRESSION_THRESHOLD})",
            )

        # CRITICAL: Compress when short-term memory is full to prevent eviction
        # If we don't compress, the next message will cause deque to evict the oldest message,
        # which may break tool_use/tool_result pairs
        if self.short_term.is_full():
            return (
                True,
                f"short_term_full ({self.short_term.count()}/{Config.MEMORY_SHORT_TERM_SIZE} messages, "
                f"current tokens: {self.current_tokens})",
            )

        return False, None

    def _select_strategy(self, messages: List[LLMMessage]) -> str:
        """Auto-select compression strategy based on message characteristics.

        Args:
            messages: Messages to analyze

        Returns:
            Strategy name
        """
        # Check for tool calls
        has_tool_calls = any(self._message_has_tool_calls(msg) for msg in messages)

        # Select strategy
        if has_tool_calls:
            # Preserve tool calls
            return CompressionStrategy.SELECTIVE
        elif len(messages) < 5:
            # Too few messages, just delete
            return CompressionStrategy.DELETION
        else:
            # Default: sliding window
            return CompressionStrategy.SLIDING_WINDOW

    def _message_has_tool_calls(self, message: LLMMessage) -> bool:
        """Check if message contains tool calls.

        Handles both new format (tool_calls field) and legacy format (content blocks).

        Args:
            message: Message to check

        Returns:
            True if contains tool calls
        """
        # New format: check tool_calls field
        if hasattr(message, "tool_calls") and message.tool_calls:
            return True

        # New format: tool role message
        if message.role == "tool":
            return True

        # Legacy/centralized check on content
        return content_has_tool_calls(message.content)

    def _calculate_target_tokens(self) -> int:
        """Calculate target token count for compression.

        Returns:
            Target token count
        """
        # NOTE(review): current_tokens includes system messages, so the
        # target is derived from the whole context, not just the window —
        # confirm this is intended.
        original_tokens = self.current_tokens
        target = int(original_tokens * Config.MEMORY_COMPRESSION_RATIO)
        return max(target, 500)  # Minimum 500 tokens for summary

    def _recalculate_current_tokens(self) -> int:
        """Recalculate current token count from scratch.

        Returns:
            Current token count
        """
        provider = self.llm.provider_name.lower()
        model = self.llm.model

        total = 0

        # Count system messages
        for msg in self.system_messages:
            total += self.token_tracker.count_message_tokens(msg, provider, model)

        # Count short-term messages (includes summary messages)
        for msg in self.short_term.get_messages():
            total += self.token_tracker.count_message_tokens(msg, provider, model)

        return total

    def get_stats(self) -> Dict[str, Any]:
        """Get memory statistics.

        Returns:
            Dict with statistics
        """
        return {
            "current_tokens": self.current_tokens,
            "total_input_tokens": self.token_tracker.total_input_tokens,
            "total_output_tokens": self.token_tracker.total_output_tokens,
            "compression_count": self.compression_count,
            "total_savings": self.token_tracker.compression_savings,
            "compression_cost": self.token_tracker.compression_cost,
            "net_savings": self.token_tracker.compression_savings
            - self.token_tracker.compression_cost,
            "short_term_count": self.short_term.count(),
            "total_cost": self.token_tracker.get_total_cost(self.llm.model),
        }

    async def save_memory(self) -> None:
        """Save current memory state to store.

        This saves the complete memory state including:
        - System messages
        - Short-term messages (which includes summary messages after compression)

        Call this method after completing a task or at key checkpoints.
        """
        # Skip if no session was created (no messages were ever added)
        if not self._store or not self._session_created or not self.session_id:
            logger.debug("Skipping save_memory: no session created")
            return

        messages = self.short_term.get_messages()

        # Skip saving if there are no messages (empty conversation)
        if not messages and not self.system_messages:
            logger.debug(f"Skipping save_memory: no messages to save for session {self.session_id}")
            return

        await self._store.save_memory(
            session_id=self.session_id,
            system_messages=self.system_messages,
            messages=messages,
        )
        logger.info(f"Saved memory state for session {self.session_id}")

    def reset(self) -> None:
        """Reset memory manager state."""
        # Note: session identity (session_id/_session_created) is NOT reset here
        self.short_term.clear()
        self.system_messages.clear()
        self.token_tracker.reset()
        self.current_tokens = 0
        self.was_compressed_last_iteration = False
        self.last_compression_savings = 0
        self.compression_count = 0

    def rollback_incomplete_exchange(self) -> None:
        """Rollback the last incomplete assistant response with tool_calls.

        This is used when a task is interrupted before tool execution completes.
        It removes the assistant message if it contains tool_calls but no results.
        The user message is preserved so the agent can see the original question.

        This prevents API errors about missing tool responses on the next turn.
        """
        messages = self.short_term.get_messages()
        if not messages:
            return

        # Check if last message is an assistant message with tool_calls
        last_msg = messages[-1]
        if last_msg.role == "assistant" and self._message_has_tool_calls(last_msg):
            # Remove only the assistant message with tool_calls
            # Keep the user message so the agent can still see the question
            self.short_term.remove_last(1)
            logger.debug("Removed incomplete assistant message with tool_calls")

            # Recalculate token count
            self.current_tokens = self._recalculate_current_tokens()
@@ -0,0 +1,82 @@
1
+ """Shared serialization logic for memory persistence.
2
+
3
+ Provides serialize/deserialize functions for LLMMessage objects,
4
+ used by all persistence backends.
5
+ """
6
+
7
+ import json
8
+ from typing import Any, Dict
9
+
10
+ from llm.message_types import LLMMessage
11
+
12
+
13
def serialize_content(content: Any) -> Any:
    """Return a JSON/YAML-serializable form of *content*.

    ``None`` and strings pass through unchanged. Lists and dicts are kept
    as-is when they survive a round-trip through ``json.dumps``; anything
    that cannot be serialized falls back to ``str()``.

    Args:
        content: Message content (can be string, list, dict, or None)

    Returns:
        JSON-serializable content
    """
    if content is None or isinstance(content, str):
        return content
    if isinstance(content, (list, dict)):
        try:
            # Probe serializability without keeping the encoded text
            json.dumps(content)
        except (TypeError, ValueError):
            return str(content)
        return content
    return str(content)


def serialize_message(message: LLMMessage) -> Dict[str, Any]:
    """Serialize an LLMMessage to a JSON/YAML-serializable dict.

    Args:
        message: LLMMessage to serialize

    Returns:
        Serializable dict
    """
    payload: Dict[str, Any] = {
        "role": message.role,
        "content": serialize_content(message.content),
    }

    tool_calls = getattr(message, "tool_calls", None)
    if message.role == "assistant":
        # Assistant messages always carry the key (None when no calls)
        # for completeness
        payload["tool_calls"] = tool_calls if tool_calls else None
    elif tool_calls:
        payload["tool_calls"] = tool_calls

    tool_call_id = getattr(message, "tool_call_id", None)
    if tool_call_id:
        payload["tool_call_id"] = tool_call_id

    name = getattr(message, "name", None)
    if name:
        payload["name"] = name

    return payload
65
+
66
+
67
def deserialize_message(data: Dict[str, Any]) -> LLMMessage:
    """Deserialize a dict to an LLMMessage.

    ``role`` is required; every other field defaults to ``None`` when
    absent, mirroring what :func:`serialize_message` may omit.

    Args:
        data: Dict with message data

    Returns:
        LLMMessage instance
    """
    optional_fields = ("content", "tool_calls", "tool_call_id", "name")
    kwargs = {field: data.get(field) for field in optional_fields}
    return LLMMessage(role=data["role"], **kwargs)
memory/short_term.py ADDED
@@ -0,0 +1,88 @@
1
+ """Short-term memory management with fixed-size window."""
2
+
3
+ from collections import deque
4
+ from typing import List
5
+
6
+ from llm.base import LLMMessage
7
+
8
+
9
class ShortTermMemory:
    """Fixed-size sliding window over the most recent conversation messages."""

    def __init__(self, max_size: int = 20):
        """Create the window.

        Args:
            max_size: Maximum number of messages to keep
        """
        self.max_size = max_size
        # deque(maxlen=...) silently evicts the oldest entry on overflow
        self.messages = deque(maxlen=max_size)

    def add_message(self, message: LLMMessage) -> None:
        """Append *message*, evicting the oldest entry when at capacity.

        Args:
            message: LLMMessage to add
        """
        self.messages.append(message)

    def get_messages(self) -> List[LLMMessage]:
        """Return all buffered messages, oldest to newest."""
        return [*self.messages]

    def clear(self) -> List[LLMMessage]:
        """Empty the buffer and return everything it held.

        Returns:
            List of all messages that were cleared
        """
        drained = [*self.messages]
        self.messages.clear()
        return drained

    def remove_first(self, count: int) -> List[LLMMessage]:
        """Pop up to *count* messages from the oldest end.

        Useful after compression: drop only the compressed messages while
        preserving anything that arrived in the meantime.

        Args:
            count: Number of messages to remove from the front

        Returns:
            List of removed messages
        """
        removed: List[LLMMessage] = []
        remaining = min(count, len(self.messages))
        while remaining > 0:
            removed.append(self.messages.popleft())
            remaining -= 1
        return removed

    def is_full(self) -> bool:
        """Return True when the buffer is at (or beyond) capacity."""
        return self.count() >= self.max_size

    def count(self) -> int:
        """Return the number of messages currently buffered."""
        return len(self.messages)

    def remove_last(self, count: int = 1) -> None:
        """Drop up to *count* messages from the newest end.

        Useful for rolling back incomplete exchanges (e.g., after interruption).

        Args:
            count: Number of messages to remove from the end (default: 1)
        """
        n_to_drop = min(count, len(self.messages))
        for _ in range(n_to_drop):
            self.messages.pop()