aloop 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of aloop has been flagged as possibly problematic; see the registry page for details.
- agent/__init__.py +0 -0
- agent/agent.py +182 -0
- agent/base.py +406 -0
- agent/context.py +126 -0
- agent/todo.py +149 -0
- agent/tool_executor.py +54 -0
- agent/verification.py +135 -0
- aloop-0.1.0.dist-info/METADATA +246 -0
- aloop-0.1.0.dist-info/RECORD +62 -0
- aloop-0.1.0.dist-info/WHEEL +5 -0
- aloop-0.1.0.dist-info/entry_points.txt +2 -0
- aloop-0.1.0.dist-info/licenses/LICENSE +21 -0
- aloop-0.1.0.dist-info/top_level.txt +9 -0
- cli.py +19 -0
- config.py +146 -0
- interactive.py +865 -0
- llm/__init__.py +51 -0
- llm/base.py +26 -0
- llm/compat.py +226 -0
- llm/content_utils.py +309 -0
- llm/litellm_adapter.py +450 -0
- llm/message_types.py +245 -0
- llm/model_manager.py +265 -0
- llm/retry.py +95 -0
- main.py +246 -0
- memory/__init__.py +20 -0
- memory/compressor.py +554 -0
- memory/manager.py +538 -0
- memory/serialization.py +82 -0
- memory/short_term.py +88 -0
- memory/token_tracker.py +203 -0
- memory/types.py +51 -0
- tools/__init__.py +6 -0
- tools/advanced_file_ops.py +557 -0
- tools/base.py +51 -0
- tools/calculator.py +50 -0
- tools/code_navigator.py +975 -0
- tools/explore.py +254 -0
- tools/file_ops.py +150 -0
- tools/git_tools.py +791 -0
- tools/notify.py +69 -0
- tools/parallel_execute.py +420 -0
- tools/session_manager.py +205 -0
- tools/shell.py +147 -0
- tools/shell_background.py +470 -0
- tools/smart_edit.py +491 -0
- tools/todo.py +130 -0
- tools/web_fetch.py +673 -0
- tools/web_search.py +61 -0
- utils/__init__.py +15 -0
- utils/logger.py +105 -0
- utils/model_pricing.py +49 -0
- utils/runtime.py +75 -0
- utils/terminal_ui.py +422 -0
- utils/tui/__init__.py +39 -0
- utils/tui/command_registry.py +49 -0
- utils/tui/components.py +306 -0
- utils/tui/input_handler.py +393 -0
- utils/tui/model_ui.py +204 -0
- utils/tui/progress.py +292 -0
- utils/tui/status_bar.py +178 -0
- utils/tui/theme.py +165 -0
memory/manager.py
ADDED
|
@@ -0,0 +1,538 @@
|
|
|
1
|
+
"""Core memory manager that orchestrates all memory operations."""
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
|
5
|
+
|
|
6
|
+
from config import Config
|
|
7
|
+
from llm.content_utils import content_has_tool_calls
|
|
8
|
+
from llm.message_types import LLMMessage
|
|
9
|
+
from utils import terminal_ui
|
|
10
|
+
from utils.tui.progress import AsyncSpinner
|
|
11
|
+
|
|
12
|
+
from .compressor import WorkingMemoryCompressor
|
|
13
|
+
from .short_term import ShortTermMemory
|
|
14
|
+
from .token_tracker import TokenTracker
|
|
15
|
+
from .types import CompressedMemory, CompressionStrategy
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from llm import LiteLLMAdapter
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MemoryManager:
    """Central memory management system with built-in persistence.

    The persistence store is fully owned by MemoryManager and should not
    be created or passed in from outside.
    """

    def __init__(
        self,
        llm: "LiteLLMAdapter",
        session_id: Optional[str] = None,
    ):
        """Initialize memory manager.

        Args:
            llm: LLM instance for compression
            session_id: Optional session ID (if resuming session)
        """
        self.llm = llm

        # Store is fully owned by MemoryManager (imported locally to keep
        # module import light and avoid circular imports).
        from .store import YamlFileMemoryStore

        self._store = YamlFileMemoryStore()

        # Lazy session creation: only create when first message is added.
        # If session_id is provided (resuming), use it immediately.
        if session_id is not None:
            self.session_id = session_id
            self._session_created = True
        else:
            self.session_id = None
            self._session_created = False

        # Initialize components using Config directly
        self.short_term = ShortTermMemory(max_size=Config.MEMORY_SHORT_TERM_SIZE)
        self.compressor = WorkingMemoryCompressor(llm)
        self.token_tracker = TokenTracker()

        # Storage for system messages
        self.system_messages: List[LLMMessage] = []

        # State tracking
        self.current_tokens = 0
        self.was_compressed_last_iteration = False
        self.last_compression_savings = 0
        self.compression_count = 0

        # Optional callback to get current todo context for compression
        self._todo_context_provider: Optional[Callable[[], Optional[str]]] = None

    @classmethod
    async def from_session(
        cls,
        session_id: str,
        llm: "LiteLLMAdapter",
    ) -> "MemoryManager":
        """Load a MemoryManager from a saved session.

        Args:
            session_id: Session ID to load
            llm: LLM instance for compression

        Returns:
            MemoryManager instance with loaded state

        Raises:
            ValueError: If the session does not exist in the store
        """
        manager = cls(llm=llm, session_id=session_id)

        # Load session data
        session_data = await manager._store.load_session(session_id)
        if not session_data:
            raise ValueError(f"Session {session_id} not found")

        # Restore state
        manager.system_messages = session_data["system_messages"]

        # Add messages to short-term memory (including any summary messages)
        for msg in session_data["messages"]:
            manager.short_term.add_message(msg)

        # Recalculate tokens
        manager.current_tokens = manager._recalculate_current_tokens()

        logger.info(
            f"Loaded session {session_id}: "
            f"{len(session_data['messages'])} messages, "
            f"{manager.current_tokens} tokens"
        )

        return manager

    @staticmethod
    async def list_sessions(limit: int = 50) -> List[Dict[str, Any]]:
        """List saved sessions.

        Args:
            limit: Maximum number of sessions to return

        Returns:
            List of session summaries
        """
        from .store import YamlFileMemoryStore

        store = YamlFileMemoryStore()
        return await store.list_sessions(limit=limit)

    @staticmethod
    async def find_latest_session() -> Optional[str]:
        """Find the most recently updated session ID.

        Returns:
            Session ID or None if no sessions exist
        """
        from .store import YamlFileMemoryStore

        store = YamlFileMemoryStore()
        return await store.find_latest_session()

    @staticmethod
    async def find_session_by_prefix(prefix: str) -> Optional[str]:
        """Find a session by ID prefix.

        Args:
            prefix: Prefix of session UUID

        Returns:
            Full session ID or None
        """
        from .store import YamlFileMemoryStore

        store = YamlFileMemoryStore()
        return await store.find_session_by_prefix(prefix)

    async def _ensure_session(self) -> None:
        """Lazily create session when first needed.

        This avoids creating empty sessions when MemoryManager is instantiated
        but no messages are ever added (e.g., user exits before running any task).

        Raises:
            RuntimeError: If session creation fails
        """
        if not self._session_created:
            try:
                self.session_id = await self._store.create_session()
                self._session_created = True
                logger.info(f"Created new session: {self.session_id}")
            except Exception as e:
                logger.error(f"Failed to create session: {e}")
                raise RuntimeError(f"Failed to create memory session: {e}") from e

    async def add_message(
        self, message: LLMMessage, actual_tokens: Optional[Dict[str, int]] = None
    ) -> None:
        """Add a message to memory and trigger compression if needed.

        Args:
            message: Message to add
            actual_tokens: Optional dict with actual token counts from LLM response
                           Format: {"input": int, "output": int}
        """
        # Ensure session exists before adding messages
        await self._ensure_session()

        # Track system messages separately
        if message.role == "system":
            self.system_messages.append(message)
            return

        # Count tokens (use actual if provided, otherwise estimate)
        if actual_tokens:
            # Use actual token counts from LLM response.
            # Note: input_tokens includes full context sent to API, not just new content
            input_tokens = actual_tokens.get("input", 0)
            output_tokens = actual_tokens.get("output", 0)

            self.token_tracker.add_input_tokens(input_tokens)
            self.token_tracker.add_output_tokens(output_tokens)

            # Log API usage separately
            logger.debug(
                f"API usage: input={input_tokens}, output={output_tokens}, "
                f"total={input_tokens + output_tokens}"
            )
        else:
            # Estimate token count for non-API messages (tool results, etc.)
            provider = self.llm.provider_name.lower()
            model = self.llm.model
            tokens = self.token_tracker.count_message_tokens(message, provider, model)

            # Update token tracker
            if message.role == "assistant":
                self.token_tracker.add_output_tokens(tokens)
            else:
                self.token_tracker.add_input_tokens(tokens)

        # Add to short-term memory
        self.short_term.add_message(message)

        # Recalculate current tokens based on actual stored content.
        # This gives accurate count for compression decisions.
        self.current_tokens = self._recalculate_current_tokens()

        # Log memory state (stored content size, not API usage)
        logger.debug(
            f"Memory state: {self.current_tokens} stored tokens, "
            f"{self.short_term.count()}/{Config.MEMORY_SHORT_TERM_SIZE} messages, "
            f"full={self.short_term.is_full()}"
        )

        # Check if compression is needed
        self.was_compressed_last_iteration = False
        should_compress, reason = self._should_compress()
        if should_compress:
            logger.info(f"🗜️ Triggering compression: {reason}")
            await self.compress()
        else:
            # Log compression check details
            logger.debug(
                f"Compression check: stored={self.current_tokens}, "
                f"threshold={Config.MEMORY_COMPRESSION_THRESHOLD}, "
                f"short_term_full={self.short_term.is_full()}"
            )

    def get_context_for_llm(self) -> List[LLMMessage]:
        """Get optimized context for LLM call.

        Returns:
            List of messages: system messages + short-term messages (which includes summaries)
        """
        context = []

        # 1. Add system messages (always included)
        context.extend(self.system_messages)

        # 2. Add short-term memory (includes summary messages and recent messages)
        context.extend(self.short_term.get_messages())

        return context

    def set_todo_context_provider(self, provider: Callable[[], Optional[str]]) -> None:
        """Set a callback to provide current todo context for compression.

        The provider should return a formatted string of current todo items,
        or None if no todos exist. This context will be injected into
        compression summaries to preserve task state.

        Args:
            provider: Callable that returns current todo context string or None
        """
        self._todo_context_provider = provider

    async def compress(self, strategy: Optional[str] = None) -> Optional[CompressedMemory]:
        """Compress current short-term memory.

        After compression, the compressed messages (including any summary as user message)
        are put back into short_term as regular messages.

        Args:
            strategy: Compression strategy (None = auto-select)

        Returns:
            CompressedMemory object if compression was performed
        """
        messages = self.short_term.get_messages()
        message_count = len(messages)

        if not messages:
            logger.warning("No messages to compress")
            return None

        # Auto-select strategy if not specified
        if strategy is None:
            strategy = self._select_strategy(messages)

        logger.info(f"🗜️ Compressing {message_count} messages using {strategy} strategy")

        try:
            # Get todo context if provider is set
            todo_context = None
            if self._todo_context_provider:
                todo_context = self._todo_context_provider()

            # Perform compression with TUI spinner
            async with AsyncSpinner(terminal_ui.console, "Compressing memory..."):
                compressed = await self.compressor.compress(
                    messages,
                    strategy=strategy,
                    target_tokens=self._calculate_target_tokens(),
                    todo_context=todo_context,
                )

            # Track compression results
            self.compression_count += 1
            self.was_compressed_last_iteration = True
            self.last_compression_savings = compressed.token_savings

            # Update token tracker
            self.token_tracker.add_compression_savings(compressed.token_savings)
            self.token_tracker.add_compression_cost(compressed.compressed_tokens)

            # Remove compressed messages from short-term memory
            self.short_term.remove_first(message_count)

            # Get any remaining messages (added after compression started)
            remaining_messages = self.short_term.get_messages()
            self.short_term.clear()

            # Add compressed messages (summary + preserved, already combined by compressor)
            for msg in compressed.messages:
                self.short_term.add_message(msg)

            # Add any remaining messages
            for msg in remaining_messages:
                self.short_term.add_message(msg)

            # Update current token count
            old_tokens = self.current_tokens
            self.current_tokens = self._recalculate_current_tokens()

            # Log compression results
            logger.info(
                f"✅ Compression complete: {compressed.original_tokens} → {compressed.compressed_tokens} tokens "
                f"({compressed.savings_percentage:.1f}% saved, ratio: {compressed.compression_ratio:.2f}), "
                f"context: {old_tokens} → {self.current_tokens} tokens, "
                f"short_term now has {self.short_term.count()} messages"
            )

            return compressed

        except Exception as e:
            # logger.exception preserves the traceback for post-mortem debugging;
            # compression failure is non-fatal, so we swallow and return None.
            logger.exception("Compression failed: %s", e)
            return None

    def _should_compress(self) -> tuple[bool, Optional[str]]:
        """Check if compression should be triggered.

        Returns:
            Tuple of (should_compress, reason)
        """
        if not Config.MEMORY_ENABLED:
            return False, "compression_disabled"

        # Hard limit: must compress
        if self.current_tokens > Config.MEMORY_COMPRESSION_THRESHOLD:
            return (
                True,
                f"hard_limit ({self.current_tokens} > {Config.MEMORY_COMPRESSION_THRESHOLD})",
            )

        # CRITICAL: Compress when short-term memory is full to prevent eviction.
        # If we don't compress, the next message will cause deque to evict the oldest message,
        # which may break tool_use/tool_result pairs.
        if self.short_term.is_full():
            return (
                True,
                f"short_term_full ({self.short_term.count()}/{Config.MEMORY_SHORT_TERM_SIZE} messages, "
                f"current tokens: {self.current_tokens})",
            )

        return False, None

    def _select_strategy(self, messages: List[LLMMessage]) -> str:
        """Auto-select compression strategy based on message characteristics.

        Args:
            messages: Messages to analyze

        Returns:
            Strategy name
        """
        # Check for tool calls
        has_tool_calls = any(self._message_has_tool_calls(msg) for msg in messages)

        # Select strategy
        if has_tool_calls:
            # Preserve tool calls
            return CompressionStrategy.SELECTIVE
        elif len(messages) < 5:
            # Too few messages, just delete
            return CompressionStrategy.DELETION
        else:
            # Default: sliding window
            return CompressionStrategy.SLIDING_WINDOW

    def _message_has_tool_calls(self, message: LLMMessage) -> bool:
        """Check if message contains tool calls.

        Handles both new format (tool_calls field) and legacy format (content blocks).

        Args:
            message: Message to check

        Returns:
            True if contains tool calls
        """
        # New format: check tool_calls field
        if hasattr(message, "tool_calls") and message.tool_calls:
            return True

        # New format: tool role message
        if message.role == "tool":
            return True

        # Legacy/centralized check on content
        return content_has_tool_calls(message.content)

    def _calculate_target_tokens(self) -> int:
        """Calculate target token count for compression.

        Returns:
            Target token count
        """
        original_tokens = self.current_tokens
        target = int(original_tokens * Config.MEMORY_COMPRESSION_RATIO)
        return max(target, 500)  # Minimum 500 tokens for summary

    def _recalculate_current_tokens(self) -> int:
        """Recalculate current token count from scratch.

        Returns:
            Current token count
        """
        provider = self.llm.provider_name.lower()
        model = self.llm.model

        total = 0

        # Count system messages
        for msg in self.system_messages:
            total += self.token_tracker.count_message_tokens(msg, provider, model)

        # Count short-term messages (includes summary messages)
        for msg in self.short_term.get_messages():
            total += self.token_tracker.count_message_tokens(msg, provider, model)

        return total

    def get_stats(self) -> Dict[str, Any]:
        """Get memory statistics.

        Returns:
            Dict with statistics
        """
        return {
            "current_tokens": self.current_tokens,
            "total_input_tokens": self.token_tracker.total_input_tokens,
            "total_output_tokens": self.token_tracker.total_output_tokens,
            "compression_count": self.compression_count,
            "total_savings": self.token_tracker.compression_savings,
            "compression_cost": self.token_tracker.compression_cost,
            "net_savings": self.token_tracker.compression_savings
            - self.token_tracker.compression_cost,
            "short_term_count": self.short_term.count(),
            "total_cost": self.token_tracker.get_total_cost(self.llm.model),
        }

    async def save_memory(self) -> None:
        """Save current memory state to store.

        This saves the complete memory state including:
        - System messages
        - Short-term messages (which includes summary messages after compression)

        Call this method after completing a task or at key checkpoints.
        """
        # Skip if no session was created (no messages were ever added)
        if not self._store or not self._session_created or not self.session_id:
            logger.debug("Skipping save_memory: no session created")
            return

        messages = self.short_term.get_messages()

        # Skip saving if there are no messages (empty conversation)
        if not messages and not self.system_messages:
            logger.debug(f"Skipping save_memory: no messages to save for session {self.session_id}")
            return

        await self._store.save_memory(
            session_id=self.session_id,
            system_messages=self.system_messages,
            messages=messages,
        )
        logger.info(f"Saved memory state for session {self.session_id}")

    def reset(self) -> None:
        """Reset memory manager state."""
        self.short_term.clear()
        self.system_messages.clear()
        self.token_tracker.reset()
        self.current_tokens = 0
        self.was_compressed_last_iteration = False
        self.last_compression_savings = 0
        self.compression_count = 0

    def rollback_incomplete_exchange(self) -> None:
        """Rollback the last incomplete assistant response with tool_calls.

        This is used when a task is interrupted before tool execution completes.
        It removes the assistant message if it contains tool_calls but no results.
        The user message is preserved so the agent can see the original question.

        This prevents API errors about missing tool responses on the next turn.
        """
        messages = self.short_term.get_messages()
        if not messages:
            return

        # Check if last message is an assistant message with tool_calls
        last_msg = messages[-1]
        if last_msg.role == "assistant" and self._message_has_tool_calls(last_msg):
            # Remove only the assistant message with tool_calls.
            # Keep the user message so the agent can still see the question.
            self.short_term.remove_last(1)
            logger.debug("Removed incomplete assistant message with tool_calls")

            # Recalculate token count
            self.current_tokens = self._recalculate_current_tokens()
|
memory/serialization.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Shared serialization logic for memory persistence.
|
|
2
|
+
|
|
3
|
+
Provides serialize/deserialize functions for LLMMessage objects,
|
|
4
|
+
used by all persistence backends.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
from typing import Any, Dict
|
|
9
|
+
|
|
10
|
+
from llm.message_types import LLMMessage
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def serialize_content(content: Any) -> Any:
    """Serialize message content, handling complex objects.

    Args:
        content: Message content (can be string, list, dict, or None)

    Returns:
        JSON-serializable content
    """
    # None and plain strings pass through unchanged.
    if content is None or isinstance(content, str):
        return content

    # Containers are kept as-is only when they round-trip through JSON;
    # otherwise fall back to their string representation.
    if isinstance(content, (list, dict)):
        try:
            json.dumps(content)
        except (TypeError, ValueError):
            return str(content)
        return content

    # Any other object degrades to its string form.
    return str(content)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def serialize_message(message: LLMMessage) -> Dict[str, Any]:
    """Serialize an LLMMessage to a JSON/YAML-serializable dict.

    Args:
        message: LLMMessage to serialize

    Returns:
        Serializable dict
    """
    payload: Dict[str, Any] = {
        "role": message.role,
        "content": serialize_content(message.content),
    }

    tool_calls = getattr(message, "tool_calls", None)
    if message.role == "assistant":
        # Assistant messages always carry the key (None when empty) for completeness.
        payload["tool_calls"] = tool_calls or None
    elif tool_calls:
        payload["tool_calls"] = tool_calls

    if getattr(message, "tool_call_id", None):
        payload["tool_call_id"] = message.tool_call_id

    if getattr(message, "name", None):
        payload["name"] = message.name

    return payload
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def deserialize_message(data: Dict[str, Any]) -> LLMMessage:
    """Deserialize a dict to an LLMMessage.

    Args:
        data: Dict with message data

    Returns:
        LLMMessage instance
    """
    # "role" is required; every other field defaults to None when absent.
    optional_fields = ("content", "tool_calls", "tool_call_id", "name")
    kwargs = {field: data.get(field) for field in optional_fields}
    return LLMMessage(role=data["role"], **kwargs)
|
memory/short_term.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
"""Short-term memory management with fixed-size window."""
|
|
2
|
+
|
|
3
|
+
from collections import deque
|
|
4
|
+
from typing import List
|
|
5
|
+
|
|
6
|
+
from llm.base import LLMMessage
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ShortTermMemory:
    """Manages recent messages in a fixed-size sliding window."""

    def __init__(self, max_size: int = 20):
        """Initialize short-term memory.

        Args:
            max_size: Maximum number of messages to keep
        """
        self.max_size = max_size
        # deque with maxlen silently drops the oldest entry once full.
        self.messages = deque(maxlen=max_size)

    def add_message(self, message: LLMMessage) -> None:
        """Add a message to short-term memory.

        Automatically evicts oldest message if at capacity.

        Args:
            message: LLMMessage to add
        """
        self.messages.append(message)

    def get_messages(self) -> List[LLMMessage]:
        """Get all messages in short-term memory.

        Returns:
            List of messages, oldest to newest
        """
        return list(self.messages)

    def clear(self) -> List[LLMMessage]:
        """Clear all messages and return them.

        Returns:
            List of all messages that were cleared
        """
        drained = list(self.messages)
        self.messages.clear()
        return drained

    def remove_first(self, count: int) -> List[LLMMessage]:
        """Remove the first N messages (oldest) from memory.

        This is useful after compression to remove only the compressed messages
        while preserving any new messages that arrived during compression.

        Args:
            count: Number of messages to remove from the front

        Returns:
            List of removed messages
        """
        removed: List[LLMMessage] = []
        # Never pop more entries than are actually stored.
        for _ in range(min(count, len(self.messages))):
            removed.append(self.messages.popleft())
        return removed

    def is_full(self) -> bool:
        """Check if short-term memory is at capacity.

        Returns:
            True if at max capacity
        """
        return self.count() >= self.max_size

    def count(self) -> int:
        """Get current message count.

        Returns:
            Number of messages in short-term memory
        """
        return len(self.messages)

    def remove_last(self, count: int = 1) -> None:
        """Remove the last N messages (newest) from memory.

        This is useful for rolling back incomplete exchanges (e.g., after interruption).

        Args:
            count: Number of messages to remove from the end (default: 1)
        """
        remaining = min(count, len(self.messages))
        while remaining > 0:
            self.messages.pop()
            remaining -= 1
|