tinyagent-py 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinyagent/hooks/rich_code_ui_callback.py +434 -0
- tinyagent/memory_manager.py +1067 -0
- tinyagent/storage/__init__.py +2 -4
- tinyagent/storage/sqlite_storage.py +4 -1
- tinyagent/tiny_agent.py +3 -2
- {tinyagent_py-0.0.6.dist-info → tinyagent_py-0.0.7.dist-info}/METADATA +2 -2
- {tinyagent_py-0.0.6.dist-info → tinyagent_py-0.0.7.dist-info}/RECORD +10 -10
- {tinyagent_py-0.0.6.dist-info → tinyagent_py-0.0.7.dist-info}/WHEEL +1 -1
- tinyagent/hooks/agno_storage_hook.py +0 -128
- tinyagent/storage/agno_storage.py +0 -114
- {tinyagent_py-0.0.6.dist-info → tinyagent_py-0.0.7.dist-info}/licenses/LICENSE +0 -0
- {tinyagent_py-0.0.6.dist-info → tinyagent_py-0.0.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1067 @@
|
|
1
|
+
#
|
2
|
+
#
|
3
|
+
#
|
4
|
+
#
|
5
|
+
#
|
6
|
+
# tool call and tool error with same tool id should have the same importance level, otherwise LLM would reject it.
|
7
|
+
#- tool call ==> tool error, should be MEDIUM . if there is no pair of tool call ==> tool error after that (It is the last error)
|
8
|
+
#- should be LOW, if another pair of tool call ==> tool response (response without error) happens after it.
|
9
|
+
#- if this happens at the end of conversation, the rule of HIGH importance will overrule everything, so they would be HIGH priority.
|
10
|
+
# Last message pairs should be high priority.
|
11
|
+
#
|
12
|
+
# tool_call => tool is a pair, and share the same importance level
|
13
|
+
#
|
14
|
+
#
|
15
|
+
# if 'role': 'assistant',
|
16
|
+
# 'content': '',
|
17
|
+
# 'tool_calls => function ==> name
|
18
|
+
#
|
19
|
+
# should share the same level of importance for its response with role = tool and the same tool_call_id
|
20
|
+
|
21
|
+
# last 4 pair in the history should have HIGH importance
|
22
|
+
#
|
23
|
+
#
|
24
|
+
#
|
25
|
+
# memory_manager.py
|
26
|
+
import logging
|
27
|
+
from typing import Dict, List, Optional, Any, Tuple, Set
|
28
|
+
from dataclasses import dataclass, field
|
29
|
+
from enum import Enum
|
30
|
+
import json
|
31
|
+
import time
|
32
|
+
from abc import ABC, abstractmethod
|
33
|
+
|
34
|
+
logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
class MessageImportance(Enum):
    """Defines the importance levels for messages.

    Levels run from most protected (CRITICAL) to most disposable (TEMP);
    the memory manager and its strategies use them to decide which
    messages survive pruning and summarization.
    """
    CRITICAL = "critical"  # Must always be kept (system, final answers, etc.)
    HIGH = "high"          # Important context, keep unless absolutely necessary
    MEDIUM = "medium"      # Standard conversation, can be summarized
    LOW = "low"            # Tool errors, failed attempts, can be removed
    TEMP = "temp"          # Temporary messages, remove after success
|
43
|
+
|
44
|
+
class MessageType(Enum):
    """Categorizes different types of messages.

    The type drives tool call/response pairing and the base importance
    assigned during categorization.
    """
    SYSTEM = "system"                          # System prompt message
    USER_QUERY = "user_query"                  # Message with role == "user"
    ASSISTANT_RESPONSE = "assistant_response"  # Plain assistant text (also the fallback for unknown roles)
    TOOL_CALL = "tool_call"                    # Assistant message carrying tool_calls
    TOOL_RESPONSE = "tool_response"            # Successful tool output (role == "tool")
    TOOL_ERROR = "tool_error"                  # Tool output whose content starts with "error"
    FINAL_ANSWER = "final_answer"              # Assistant call to final_answer/ask_question tools
    QUESTION_TO_USER = "question_to_user"      # NOTE(review): not assigned anywhere in the visible code — confirm intended use
|
54
|
+
|
55
|
+
@dataclass
class MessageMetadata:
    """Metadata for tracking message importance and lifecycle.

    One instance is kept per conversation message (parallel list in
    MemoryManager.message_metadata); field order is part of the
    constructor interface and must not change.
    """
    message_type: MessageType      # Category of the underlying message
    importance: MessageImportance  # Current importance; recalculated dynamically
    created_at: float              # time.time() timestamp when metadata was created
    token_count: int = 0           # Token footprint of the message
    is_error: bool = False         # True for tool responses flagged as errors
    error_resolved: bool = False   # Set once a later success supersedes the error
    part_of_task: Optional[str] = None  # Task/subtask identifier
    task_completed: bool = False   # Whether that task is in the completed set
    can_summarize: bool = True     # False for system / final-answer messages
    summary: Optional[str] = None  # Cached summary text, if one was produced
    related_messages: List[int] = field(default_factory=list)  # Indices of related messages
    tool_call_id: Optional[str] = None  # To track tool call/response pairs
|
70
|
+
|
71
|
+
class MemoryStrategy(ABC):
    """Abstract base class for memory management strategies.

    Concrete strategies (conservative/aggressive/balanced) differ in how
    readily they drop messages and how they weight recency and errors.
    """

    @abstractmethod
    def should_keep_message(self, message: Dict[str, Any], metadata: MessageMetadata,
                            context: Dict[str, Any]) -> bool:
        """Determine if a message should be kept in memory.

        Args:
            message: Raw chat message dict.
            metadata: Tracking metadata for the message.
            context: Runtime context; callers supply 'memory_pressure' in [0, 1].

        Returns:
            True to retain the message, False to allow its removal.
        """
        pass

    @abstractmethod
    def get_priority_score(self, message: Dict[str, Any], metadata: MessageMetadata) -> float:
        """Get priority score for message ranking (higher means keep longer)."""
        pass
|
84
|
+
|
85
|
+
class ConservativeStrategy(MemoryStrategy):
    """Conservative strategy - keeps more messages, summarizes less aggressively."""

    # Base ranking per importance level; higher means keep longer.
    _BASE_SCORES = {
        MessageImportance.CRITICAL: 1000,
        MessageImportance.HIGH: 100,
        MessageImportance.MEDIUM: 50,
        MessageImportance.LOW: 10,
        MessageImportance.TEMP: 1,
    }

    def should_keep_message(self, message: Dict[str, Any], metadata: MessageMetadata,
                            context: Dict[str, Any]) -> bool:
        """Decide whether a message survives pruning under this lenient policy."""
        pressure = context.get('memory_pressure', 0)

        # Critical messages are untouchable.
        if metadata.importance is MessageImportance.CRITICAL:
            return True

        # High-importance content is dropped only under severe pressure.
        if metadata.importance is MessageImportance.HIGH:
            return pressure < 0.8

        # Anything from the last 5 minutes stays regardless of importance.
        if time.time() - metadata.created_at < 300:
            return True

        # Temporary messages and resolved errors are disposable.
        if metadata.importance is MessageImportance.TEMP:
            return False
        if metadata.is_error and metadata.error_resolved:
            return False

        return pressure < 0.6

    def get_priority_score(self, message: Dict[str, Any], metadata: MessageMetadata) -> float:
        """Rank a message: base importance scaled by recency, penalized for errors."""
        age_seconds = time.time() - metadata.created_at
        recency = max(0.1, 1.0 - age_seconds / 3600)           # gentle decay over an hour
        penalty = 0.5 if metadata.is_error else 1.0            # mild error penalty
        return self._BASE_SCORES[metadata.importance] * recency * penalty
|
127
|
+
|
128
|
+
class AggressiveStrategy(MemoryStrategy):
    """Aggressive strategy - removes more messages, summarizes more aggressively."""

    # Base ranking per importance level; deliberately low for non-critical tiers.
    _BASE_SCORES = {
        MessageImportance.CRITICAL: 1000,
        MessageImportance.HIGH: 80,
        MessageImportance.MEDIUM: 30,
        MessageImportance.LOW: 5,
        MessageImportance.TEMP: 1,
    }

    def should_keep_message(self, message: Dict[str, Any], metadata: MessageMetadata,
                            context: Dict[str, Any]) -> bool:
        """Decide whether a message survives pruning under this strict policy."""
        pressure = context.get('memory_pressure', 0)
        age_seconds = time.time() - metadata.created_at

        # Critical messages are untouchable.
        if metadata.importance is MessageImportance.CRITICAL:
            return True

        # High importance survives only when both fresh and under low pressure.
        if metadata.importance is MessageImportance.HIGH:
            return pressure < 0.5 and age_seconds < 600

        # Medium importance must be very recent (3 minutes).
        if metadata.importance is MessageImportance.MEDIUM:
            return age_seconds < 180

        # Low importance and temp messages are dropped immediately.
        return False

    def get_priority_score(self, message: Dict[str, Any], metadata: MessageMetadata) -> float:
        """Rank a message with a strong recency bias and a heavy error penalty."""
        age_seconds = time.time() - metadata.created_at
        recency = max(0.05, 1.0 - age_seconds / 1800)          # fast decay over 30 minutes
        penalty = 0.2 if metadata.is_error else 1.0            # heavy error penalty
        return self._BASE_SCORES[metadata.importance] * recency * penalty
|
164
|
+
|
165
|
+
class BalancedStrategy(MemoryStrategy):
    """Balanced strategy - moderate approach to memory management."""

    # Base ranking per importance level; midway between the other strategies.
    _BASE_SCORES = {
        MessageImportance.CRITICAL: 1000,
        MessageImportance.HIGH: 90,
        MessageImportance.MEDIUM: 40,
        MessageImportance.LOW: 8,
        MessageImportance.TEMP: 2,
    }

    def should_keep_message(self, message: Dict[str, Any], metadata: MessageMetadata,
                            context: Dict[str, Any]) -> bool:
        """Decide whether a message survives pruning under this moderate policy."""
        pressure = context.get('memory_pressure', 0)
        age_seconds = time.time() - metadata.created_at

        # Critical messages are untouchable.
        if metadata.importance is MessageImportance.CRITICAL:
            return True

        # High importance survives unless pressure is already high.
        if metadata.importance is MessageImportance.HIGH:
            return pressure < 0.7

        # Medium importance survives for 7.5 minutes.
        if metadata.importance is MessageImportance.MEDIUM:
            return age_seconds < 450

        # Resolved errors are dropped; temp messages last one minute.
        if metadata.is_error and metadata.error_resolved:
            return False
        if metadata.importance is MessageImportance.TEMP:
            return age_seconds < 60

        return pressure < 0.4

    def get_priority_score(self, message: Dict[str, Any], metadata: MessageMetadata) -> float:
        """Rank a message with moderate recency bias and error penalty."""
        age_seconds = time.time() - metadata.created_at
        recency = max(0.1, 1.0 - age_seconds / 2400)           # decay over 40 minutes
        penalty = 0.3 if metadata.is_error else 1.0            # moderate error penalty
        return self._BASE_SCORES[metadata.importance] * recency * penalty
|
207
|
+
|
208
|
+
class MemoryManager:
|
209
|
+
"""
|
210
|
+
Advanced memory management system for TinyAgent.
|
211
|
+
|
212
|
+
Features:
|
213
|
+
- Message importance tracking with dynamic positioning
|
214
|
+
- Intelligent message removal and summarization
|
215
|
+
- Multiple memory management strategies
|
216
|
+
- Task-based message grouping
|
217
|
+
- Error recovery tracking
|
218
|
+
- Tool call/response pair integrity
|
219
|
+
"""
|
220
|
+
|
221
|
+
_DEFAULT_NUM_RECENT_PAIRS_HIGH_IMPORTANCE = 3
|
222
|
+
_DEFAULT_NUM_INITIAL_PAIRS_CRITICAL = 3
|
223
|
+
|
224
|
+
def __init__(
    self,
    max_tokens: int = 8000,
    target_tokens: int = 6000,
    strategy: Optional[MemoryStrategy] = None,
    enable_summarization: bool = True,
    logger: Optional[logging.Logger] = None,
    num_recent_pairs_high_importance: Optional[int] = None,
    num_initial_pairs_critical: Optional[int] = None
):
    """Initialize the memory manager.

    Args:
        max_tokens: Hard ceiling used when computing memory pressure.
        target_tokens: Soft threshold; exceeding it triggers optimization.
        strategy: Pruning strategy; defaults to BalancedStrategy.
        enable_summarization: Whether message summarization is allowed.
        logger: Logger to use; defaults to this module's logger.
        num_recent_pairs_high_importance: How many trailing message pairs
            are forced to HIGH importance (default 3).
        num_initial_pairs_critical: How many leading message pairs are
            forced to CRITICAL importance (default 3).
    """
    self.max_tokens = max_tokens
    self.target_tokens = target_tokens
    self.strategy = strategy or BalancedStrategy()
    self.enable_summarization = enable_summarization
    self.logger = logger or logging.getLogger(__name__)

    # Configure importance thresholds
    self._num_recent_pairs_for_high_importance = (
        num_recent_pairs_high_importance
        if num_recent_pairs_high_importance is not None
        else self._DEFAULT_NUM_RECENT_PAIRS_HIGH_IMPORTANCE
    )

    self._num_initial_pairs_critical = (
        num_initial_pairs_critical
        if num_initial_pairs_critical is not None
        else self._DEFAULT_NUM_INITIAL_PAIRS_CRITICAL
    )

    # Message metadata storage (parallel to the conversation message list)
    self.message_metadata: List[MessageMetadata] = []

    # Task tracking
    self.active_tasks: Set[str] = set()
    self.completed_tasks: Set[str] = set()

    # Summary storage
    self.conversation_summary: Optional[str] = None
    self.task_summaries: Dict[str, str] = {}

    # Statistics
    self.stats = {
        'messages_removed': 0,
        'messages_summarized': 0,
        'tokens_saved': 0,
        'memory_optimizations': 0
    }

    # Tool call tracking for proper pairing
    self._tool_call_pairs: Dict[str, Tuple[int, int]] = {}  # tool_call_id -> (call_index, response_index)
    self._resolved_errors: Set[str] = set()  # Track resolved error tool_call_ids
|
275
|
+
|
276
|
+
def _count_message_tokens(self, message: Dict[str, Any], token_counter: callable) -> int:
|
277
|
+
"""
|
278
|
+
Properly count tokens in a message, including tool calls.
|
279
|
+
|
280
|
+
Args:
|
281
|
+
message: The message to count tokens for
|
282
|
+
token_counter: Function to count tokens in text
|
283
|
+
|
284
|
+
Returns:
|
285
|
+
Total token count for the message
|
286
|
+
"""
|
287
|
+
total_tokens = 0
|
288
|
+
|
289
|
+
# Count content tokens
|
290
|
+
content = message.get('content', '')
|
291
|
+
if content:
|
292
|
+
total_tokens += token_counter(str(content))
|
293
|
+
|
294
|
+
# Count tool call tokens
|
295
|
+
if 'tool_calls' in message and message['tool_calls']:
|
296
|
+
for tool_call in message['tool_calls']:
|
297
|
+
# Count function name
|
298
|
+
if isinstance(tool_call, dict):
|
299
|
+
if 'function' in tool_call:
|
300
|
+
func_data = tool_call['function']
|
301
|
+
if 'name' in func_data:
|
302
|
+
total_tokens += token_counter(func_data['name'])
|
303
|
+
if 'arguments' in func_data:
|
304
|
+
total_tokens += token_counter(str(func_data['arguments']))
|
305
|
+
# Count tool call ID
|
306
|
+
if 'id' in tool_call:
|
307
|
+
total_tokens += token_counter(str(tool_call['id']))
|
308
|
+
elif hasattr(tool_call, 'function'):
|
309
|
+
# Handle object-style tool calls
|
310
|
+
if hasattr(tool_call.function, 'name'):
|
311
|
+
total_tokens += token_counter(tool_call.function.name)
|
312
|
+
if hasattr(tool_call.function, 'arguments'):
|
313
|
+
total_tokens += token_counter(str(tool_call.function.arguments))
|
314
|
+
if hasattr(tool_call, 'id'):
|
315
|
+
total_tokens += token_counter(str(tool_call.id))
|
316
|
+
|
317
|
+
# Count tool call ID for tool responses
|
318
|
+
if 'tool_call_id' in message and message['tool_call_id']:
|
319
|
+
total_tokens += token_counter(str(message['tool_call_id']))
|
320
|
+
|
321
|
+
# Count tool name for tool responses
|
322
|
+
if 'name' in message and message.get('role') == 'tool':
|
323
|
+
total_tokens += token_counter(str(message['name']))
|
324
|
+
|
325
|
+
return total_tokens
|
326
|
+
|
327
|
+
def _calculate_dynamic_importance(
    self,
    message: Dict[str, Any],
    index: int,
    total_messages: int,
    message_pairs: List[Tuple[int, int]]
) -> MessageImportance:
    """
    Calculate dynamic importance based on position, content, and context.

    Args:
        message: The message to evaluate
        index: Position of the message in the conversation
        total_messages: Total number of messages
        message_pairs: List of (start, end) index ranges of logical message pairs

    Returns:
        MessageImportance level
    """
    role = message.get('role', '')
    content = str(message.get('content', ''))

    # System messages are always CRITICAL
    if role == 'system':
        return MessageImportance.CRITICAL

    # final_answer / ask_question tool calls are HIGH importance.
    # Fix: accept both dict-style and object-style tool calls, matching
    # _count_message_tokens — previously an object-style entry raised
    # AttributeError on tc.get().
    if role == 'assistant' and message.get('tool_calls'):
        for tc in message.get('tool_calls', []):
            if isinstance(tc, dict):
                name = tc.get('function', {}).get('name')
            else:
                name = getattr(getattr(tc, 'function', None), 'name', None)
            if name in ('final_answer', 'ask_question'):
                return MessageImportance.HIGH

    # Error responses stay HIGH until resolved
    if self._is_tool_error_response(message):
        return MessageImportance.HIGH

    # For short conversations, keep everything at MEDIUM or higher
    if total_messages <= 10:
        return MessageImportance.MEDIUM

    # Find which pair this message belongs to
    current_pair_index = None
    for pair_idx, (start_idx, end_idx) in enumerate(message_pairs):
        if start_idx <= index <= end_idx:
            current_pair_index = pair_idx
            break

    if current_pair_index is not None:
        # First N pairs are CRITICAL
        if current_pair_index < self._num_initial_pairs_critical:
            return MessageImportance.CRITICAL
        # Last N pairs are HIGH
        if current_pair_index >= len(message_pairs) - self._num_recent_pairs_for_high_importance:
            return MessageImportance.HIGH

    # Content-based importance adjustments
    if role == 'user':
        # User queries are generally important
        return MessageImportance.MEDIUM
    elif role == 'assistant':
        # Long responses are assumed to carry more context
        return MessageImportance.MEDIUM if len(content) > 500 else MessageImportance.LOW
    elif role == 'tool':
        # Tool responses are generally MEDIUM unless they're errors
        return MessageImportance.MEDIUM

    # Default importance
    return MessageImportance.LOW
|
401
|
+
|
402
|
+
def categorize_message(self, message: Dict[str, Any], index: int, total_messages: int) -> Tuple[MessageType, MessageImportance]:
    """
    Categorize a message and determine its base importance.

    Args:
        message: The message to categorize
        index: Position of the message in the conversation
        total_messages: Total number of messages in the conversation

    Returns:
        Tuple of (MessageType, MessageImportance)
    """
    role = message.get('role', '')
    content = message.get('content', '')

    # Determine message type
    if role == 'system':
        msg_type = MessageType.SYSTEM
    elif role == 'user':
        msg_type = MessageType.USER_QUERY
    elif role == 'tool':
        # Tool output whose content starts with "error" is tracked separately
        if self._is_tool_error_response(message):
            msg_type = MessageType.TOOL_ERROR
        else:
            msg_type = MessageType.TOOL_RESPONSE
    elif role == 'assistant':
        if message.get('tool_calls'):
            # Check if this is a final_answer or ask_question tool call
            # NOTE(review): assumes dict-style tool calls; an object-style
            # entry (as handled in _count_message_tokens) would raise
            # AttributeError on tc.get — confirm upstream normalization.
            tool_calls = message.get('tool_calls', [])
            if any(tc.get('function', {}).get('name') in ['final_answer', 'ask_question']
                   for tc in tool_calls):
                msg_type = MessageType.FINAL_ANSWER
            else:
                msg_type = MessageType.TOOL_CALL
        else:
            msg_type = MessageType.ASSISTANT_RESPONSE
    else:
        # Unknown roles fall back to assistant-response semantics
        msg_type = MessageType.ASSISTANT_RESPONSE

    # Calculate message pairs for dynamic importance
    message_pairs = self._calculate_message_pairs()

    # Calculate dynamic importance
    importance = self._calculate_dynamic_importance(message, index, total_messages, message_pairs)

    return msg_type, importance
|
448
|
+
|
449
|
+
def add_message_metadata(
    self,
    message: Dict[str, Any],
    token_count: int,
    position: int,
    total_messages: int
) -> None:
    """
    Add metadata for a message and update tool call pairs.

    Appends one MessageMetadata entry (parallel to the message list) and
    refreshes pair/resolved-error bookkeeping for the whole conversation.

    Args:
        message: The message to add metadata for
        token_count: Number of tokens in the message
        position: Position of the message in the conversation
        total_messages: Total number of messages in the conversation
    """
    # Categorize the message
    msg_type, base_importance = self.categorize_message(message, position, total_messages)

    # Extract task information
    task_id = self._extract_task_id(message)
    if task_id:
        self.active_tasks.add(task_id)

    # Check if this is an error message
    is_error = self._is_tool_error_response(message)

    # Extract tool call ID - handle both tool calls and tool responses
    tool_call_id = None
    if message.get('role') == 'tool':
        # Tool response - get tool_call_id directly
        tool_call_id = message.get('tool_call_id')
    elif message.get('role') == 'assistant' and message.get('tool_calls'):
        # Tool call - get the first tool call ID (assuming single tool call per message)
        # NOTE(review): only the first call's id is tracked; parallel tool
        # calls in one message would be paired incompletely — confirm.
        tool_calls = message.get('tool_calls', [])
        if tool_calls:
            tool_call_id = tool_calls[0].get('id')

    # Create metadata
    metadata = MessageMetadata(
        message_type=msg_type,
        importance=base_importance,  # Will be recalculated dynamically
        created_at=time.time(),
        token_count=token_count,
        is_error=is_error,
        error_resolved=False,
        part_of_task=task_id,
        task_completed=task_id in self.completed_tasks if task_id else False,
        tool_call_id=tool_call_id,
        # System and final-answer messages must never be summarized away
        can_summarize=msg_type not in [MessageType.SYSTEM, MessageType.FINAL_ANSWER],
        summary=None
    )

    # Add to metadata list
    self.message_metadata.append(metadata)

    # Update tool call pairs
    self._update_tool_call_pairs()

    # Update resolved errors
    self._update_resolved_errors()

    # Synchronize tool call pair importance levels
    self._synchronize_tool_call_pairs()

    self.logger.debug(f"Added metadata for message at position {position}: {msg_type.value}, {base_importance.value}, tool_call_id: {tool_call_id}")
|
515
|
+
|
516
|
+
def _update_tool_call_pairs(self) -> None:
    """Rebuild the tool_call_id -> (call_index, response_index) mapping."""
    self._tool_call_pairs.clear()

    call_types = (MessageType.TOOL_CALL, MessageType.FINAL_ANSWER)
    reply_types = (MessageType.TOOL_RESPONSE, MessageType.TOOL_ERROR)

    for call_idx, call_meta in enumerate(self.message_metadata):
        if not call_meta.tool_call_id:
            continue
        if call_meta.message_type not in call_types:
            continue
        # Pair this call with the first matching response that follows it.
        for reply_idx in range(call_idx + 1, len(self.message_metadata)):
            candidate = self.message_metadata[reply_idx]
            if (candidate.tool_call_id == call_meta.tool_call_id
                    and candidate.message_type in reply_types):
                self._tool_call_pairs[call_meta.tool_call_id] = (call_idx, reply_idx)
                break
|
531
|
+
|
532
|
+
def _recalculate_all_importance_levels(self) -> None:
    """Recalculate importance levels for all messages based on current context."""
    if not self.message_metadata:
        return

    # Positional context: logical pair spans and conversation length.
    pairs = self._calculate_message_pairs()
    total_messages = len(self.message_metadata)

    # Positional recalculation only — the raw messages are not retained
    # here, so importance is derived from type/position metadata.
    for idx, meta in enumerate(self.message_metadata):
        meta.importance = self._calculate_positional_importance(idx, total_messages, pairs, meta)

    # Keep each tool call and its response at a matching level.
    self._synchronize_tool_call_pairs()

    self.logger.debug(f"Recalculated importance levels for {total_messages} messages")
|
552
|
+
|
553
|
+
def _calculate_positional_importance(
    self,
    index: int,
    total_messages: int,
    message_pairs: List[Tuple[int, int]],
    metadata: MessageMetadata
) -> MessageImportance:
    """Calculate importance based on position and message type."""

    # Fixed rules first: system prompts, final answers, unresolved errors.
    if metadata.message_type is MessageType.SYSTEM:
        return MessageImportance.CRITICAL
    if metadata.message_type is MessageType.FINAL_ANSWER:
        return MessageImportance.HIGH
    if metadata.is_error and not metadata.error_resolved:
        return MessageImportance.HIGH

    # Short conversations keep everything at MEDIUM or above.
    if total_messages <= 10:
        return MessageImportance.MEDIUM

    # Locate the logical pair containing this message, if any.
    pair_position = next(
        (pos for pos, (lo, hi) in enumerate(message_pairs) if lo <= index <= hi),
        None,
    )
    if pair_position is not None:
        # Leading pairs anchor the conversation -> CRITICAL.
        if pair_position < self._num_initial_pairs_critical:
            return MessageImportance.CRITICAL
        # Trailing pairs are the active context -> HIGH.
        if pair_position >= len(message_pairs) - self._num_recent_pairs_for_high_importance:
            return MessageImportance.HIGH

    # Everything else defaults by message type.
    if metadata.message_type in (MessageType.USER_QUERY, MessageType.TOOL_RESPONSE):
        return MessageImportance.MEDIUM
    return MessageImportance.LOW
|
599
|
+
|
600
|
+
def _calculate_message_pairs(self) -> List[Tuple[int, int]]:
    """Group metadata indices into logical (start, end) spans for positional scoring."""
    spans: List[Tuple[int, int]] = []
    total = len(self.message_metadata)
    idx = 0

    while idx < total:
        meta = self.message_metadata[idx]

        # A system message is its own span.
        if meta.message_type == MessageType.SYSTEM:
            spans.append((idx, idx))
            idx += 1
            continue

        # A user message pairs with an immediately following assistant turn.
        if meta.message_type == MessageType.USER_QUERY:
            follower = self.message_metadata[idx + 1] if idx + 1 < total else None
            if follower is not None and follower.message_type in (
                MessageType.ASSISTANT_RESPONSE,
                MessageType.TOOL_CALL,
            ):
                spans.append((idx, idx + 1))
                idx += 2
            else:
                # Unanswered user message stands alone.
                spans.append((idx, idx))
                idx += 1
            continue

        # A tool call spans through its recorded response.
        if meta.tool_call_id and meta.tool_call_id in self._tool_call_pairs:
            call_idx, reply_idx = self._tool_call_pairs[meta.tool_call_id]
            if idx == call_idx:
                spans.append((call_idx, reply_idx))
                idx = reply_idx + 1
                continue

        # Anything else is a singleton span.
        spans.append((idx, idx))
        idx += 1

    return spans
|
641
|
+
|
642
|
+
def _update_resolved_errors(self) -> None:
    """Update the set of resolved error tool call IDs.

    Walks every known tool call/response pair; when the response was an
    error but the same function later succeeded, the error is marked
    resolved (so strategies may drop it).

    Note: the original implementation also built `error_tool_calls` /
    `success_tool_calls` sets that were never read — that dead code has
    been removed.
    """
    self._resolved_errors.clear()

    for tool_call_id, (call_idx, response_idx) in self._tool_call_pairs.items():
        # Guard against stale indices after message list changes.
        if call_idx >= len(self.message_metadata) or response_idx >= len(self.message_metadata):
            continue

        call_meta = self.message_metadata[call_idx]
        response_meta = self.message_metadata[response_idx]

        # Only error responses can become "resolved".
        if not response_meta.is_error:
            continue

        # Resolved means: a later successful call to the same function exists.
        function_name = self._extract_function_name(call_meta, call_idx)
        if function_name and self._has_later_success(function_name, call_idx):
            self._resolved_errors.add(tool_call_id)
            response_meta.error_resolved = True
|
674
|
+
|
675
|
+
def _extract_function_name(self, metadata: MessageMetadata, message_index: int) -> Optional[str]:
    """Extract function name from a tool call message.

    NOTE(review): placeholder implementation — metadata does not retain
    the original message payload, so the real function name is not
    available here. The synthetic name is only used as an opaque key by
    the resolved-error bookkeeping; implement properly once messages are
    stored alongside their metadata.
    """
    # This would need access to the actual message content
    # For now, return a placeholder - this should be implemented based on message structure
    return f"function_{message_index}"  # Placeholder
|
680
|
+
|
681
|
+
def _has_later_success(self, function_name: str, error_position: int) -> bool:
    """Check if there's a later successful call to the same function.

    NOTE(review): `function_name` is currently unused — ANY successful
    tool response after `error_position` counts as a match. This is a
    deliberately simplified check; tighten it once function names are
    tracked per response.
    """
    # Look for successful calls to the same function after the error
    for i in range(error_position + 1, len(self.message_metadata)):
        metadata = self.message_metadata[i]
        if (metadata.message_type == MessageType.TOOL_RESPONSE and
                not metadata.is_error):
            # Check if this is the same function (simplified check)
            return True
    return False
|
691
|
+
|
692
|
+
def _synchronize_tool_call_pairs(self) -> None:
    """Ensure tool call pairs have synchronized importance levels.

    A tool call and its response must share one importance level (the
    higher of the two), otherwise the LLM would see a call without its
    result or vice versa after pruning.
    """
    # Hoisted out of the loop (was rebuilt per iteration): ranking of
    # importance levels, lowest first.
    importance_order = [
        MessageImportance.TEMP,
        MessageImportance.LOW,
        MessageImportance.MEDIUM,
        MessageImportance.HIGH,
        MessageImportance.CRITICAL,
    ]

    for tool_call_id, (call_idx, response_idx) in self._tool_call_pairs.items():
        # Guard against stale indices after message list changes.
        if call_idx >= len(self.message_metadata) or response_idx >= len(self.message_metadata):
            continue

        call_meta = self.message_metadata[call_idx]
        response_meta = self.message_metadata[response_idx]

        # Use the higher importance level for both members of the pair.
        call_priority = importance_order.index(call_meta.importance)
        response_priority = importance_order.index(response_meta.importance)
        target_importance = importance_order[max(call_priority, response_priority)]

        call_meta.importance = target_importance
        response_meta.importance = target_importance

        self.logger.debug(f"Synchronized tool call pair {tool_call_id}: both set to {target_importance.value}")
|
720
|
+
|
721
|
+
def _extract_task_id(self, message: Dict[str, Any]) -> Optional[str]:
|
722
|
+
"""Extract task identifier from message content."""
|
723
|
+
# Simple implementation - could be enhanced with more sophisticated parsing
|
724
|
+
content = str(message.get('content', ''))
|
725
|
+
|
726
|
+
# Look for task patterns
|
727
|
+
if 'task:' in content.lower():
|
728
|
+
parts = content.lower().split('task:')
|
729
|
+
if len(parts) > 1:
|
730
|
+
task_part = parts[1].split()[0] if parts[1].split() else None
|
731
|
+
return f"task_{task_part}" if task_part else None
|
732
|
+
|
733
|
+
return None
|
734
|
+
|
735
|
+
def _is_tool_error_response(self, message: Dict[str, Any]) -> bool:
|
736
|
+
"""
|
737
|
+
Check if a tool response message represents an error.
|
738
|
+
|
739
|
+
Args:
|
740
|
+
message: The tool response message to check
|
741
|
+
|
742
|
+
Returns:
|
743
|
+
True if the message represents a tool error
|
744
|
+
"""
|
745
|
+
if message.get('role') != 'tool':
|
746
|
+
return False
|
747
|
+
|
748
|
+
content = str(message.get('content', '')).strip().lower()
|
749
|
+
|
750
|
+
# Check if content starts with "error"
|
751
|
+
return content.startswith('error')
|
752
|
+
|
753
|
+
def calculate_memory_pressure(self, total_tokens: int) -> float:
    """Return current memory pressure as a fraction (0.0 to 1.0) of the token budget."""
    ratio = total_tokens / self.max_tokens
    # Clamp at 1.0 once the budget is exceeded.
    return ratio if ratio < 1.0 else 1.0
|
756
|
+
|
757
|
+
def should_optimize_memory(self, total_tokens: int) -> bool:
    """Report whether the conversation has outgrown the target token budget."""
    over_budget = total_tokens > self.target_tokens
    return over_budget
|
760
|
+
|
761
|
+
def optimize_messages(
    self,
    messages: List[Dict[str, Any]],
    token_counter: callable
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """
    Optimize message list by removing/summarizing less important messages
    while preserving tool call/response pairs and maintaining conversation integrity.

    Args:
        messages: Full conversation history as provider-style message dicts.
        token_counter: Callable mapping a string to its token count.

    Returns:
        Tuple of (optimized message list, info dict). The info dict has
        ``action`` == 'none' when within limits, an ``error`` key on a
        metadata mismatch, or detailed before/after counters otherwise.
    """
    # Ensure metadata is up to date: lazily create metadata entries for any
    # messages appended since the last call.
    if len(messages) > len(self.message_metadata):
        for i in range(len(self.message_metadata), len(messages)):
            msg = messages[i]
            token_count = self._count_message_tokens(msg, token_counter)
            self.add_message_metadata(msg, token_count, i, len(messages))

    # If counts still disagree (e.g. messages were removed externally) we
    # cannot index metadata by position safely — bail out unchanged.
    if len(messages) != len(self.message_metadata):
        self.logger.warning("Message count mismatch with metadata")
        return messages, {"error": "Message metadata mismatch"}

    # Recalculate importance levels based on current conversation state
    self._recalculate_all_importance_levels()

    # Calculate current token usage using proper token counting
    total_tokens = sum(self._count_message_tokens(msg, token_counter) for msg in messages)

    if not self.should_optimize_memory(total_tokens):
        return messages, {'action': 'none', 'reason': 'within_limits'}

    memory_pressure = self.calculate_memory_pressure(total_tokens)
    # Context passed to the strategy so keep/drop decisions can adapt to pressure.
    context = {'memory_pressure': memory_pressure}

    self.logger.info(f"Memory optimization needed. Total tokens: {total_tokens}, pressure: {memory_pressure:.2f}")

    # Find all tool call/response pairs
    tool_call_pairs = self._tool_call_pairs

    # Create sets of message indices that must be kept together
    protected_indices = set()
    pair_groups = {}  # Maps group_id to set of indices

    for tool_call_id, (call_idx, response_idx) in tool_call_pairs.items():
        group_id = f"pair_{tool_call_id}"
        pair_groups[group_id] = {call_idx, response_idx}
        protected_indices.update({call_idx, response_idx})

    # Always protect system message and recent critical messages
    for i, meta in enumerate(self.message_metadata):
        if meta.importance == MessageImportance.CRITICAL:
            protected_indices.add(i)

    # Build optimized message list
    optimized_messages = []
    optimized_metadata = []
    tokens_used = 0
    tokens_saved = 0
    messages_removed = 0
    messages_summarized = 0

    # Process messages in order, respecting pairs and importance
    i = 0
    while i < len(messages):
        msg = messages[i]
        meta = self.message_metadata[i]
        msg_tokens = self._count_message_tokens(msg, token_counter)

        # Check if this message is part of a protected pair
        current_group = None
        for group_id, indices in pair_groups.items():
            if i in indices:
                current_group = group_id
                break

        if current_group:
            # Process the entire group: a call and its response are kept or
            # dropped atomically so the provider never sees an orphaned half.
            group_indices = sorted(pair_groups[current_group])
            group_messages = [messages[idx] for idx in group_indices]
            group_metadata = [self.message_metadata[idx] for idx in group_indices]
            group_tokens = sum(self._count_message_tokens(messages[idx], token_counter) for idx in group_indices)

            # Check if we should keep this group
            group_importance_values = [self.message_metadata[idx].importance.value for idx in group_indices]
            group_importance = max(group_importance_values, key=lambda x:
                {"critical": 4, "high": 3, "medium": 2, "low": 1, "temp": 0}.get(x, 0)
            )
            # NOTE(review): group_importance is a *string* (an enum .value), so
            # the comparison with MessageImportance.CRITICAL below only matches
            # if MessageImportance is a str-valued enum — confirm, otherwise
            # this branch can never be True.
            should_keep_group = (
                group_importance == MessageImportance.CRITICAL or
                self.strategy.should_keep_message(msg, meta, context)
            )

            if should_keep_group and tokens_used + group_tokens <= self.target_tokens:
                # Keep the entire group
                optimized_messages.extend(group_messages)
                optimized_metadata.extend(group_metadata)
                tokens_used += group_tokens
                self.logger.debug(f"Kept tool call pair group: {group_indices}")
            else:
                # Skip the entire group
                tokens_saved += group_tokens
                messages_removed += len(group_indices)
                self.logger.debug(f"Removed tool call pair group: {group_indices}")

            # Skip to after this group
            i = max(group_indices) + 1
            continue

        # Single message processing
        if i in protected_indices:
            # Always keep protected messages
            optimized_messages.append(msg)
            optimized_metadata.append(meta)
            tokens_used += msg_tokens
        elif self.strategy.should_keep_message(msg, meta, context) and tokens_used + msg_tokens <= self.target_tokens:
            # Keep this message
            optimized_messages.append(msg)
            optimized_metadata.append(meta)
            tokens_used += msg_tokens
        elif self.enable_summarization and meta.can_summarize and not meta.summary:
            # Try to summarize instead of dropping outright.
            summary = self._summarize_message(msg)
            summary_tokens = token_counter(summary)

            if tokens_used + summary_tokens <= self.target_tokens:
                # Create summarized message
                summarized_msg = msg.copy()
                summarized_msg['content'] = summary
                optimized_messages.append(summarized_msg)

                # Update metadata so the summary is not recomputed later.
                meta.summary = summary
                optimized_metadata.append(meta)

                tokens_used += summary_tokens
                tokens_saved += msg_tokens - summary_tokens
                messages_summarized += 1
            else:
                # Skip this message: even the summary does not fit the budget.
                tokens_saved += msg_tokens
                messages_removed += 1
        else:
            # Skip this message
            tokens_saved += msg_tokens
            messages_removed += 1

        i += 1

    # Update metadata list
    self.message_metadata = optimized_metadata

    # Update statistics
    self.stats['messages_removed'] += messages_removed
    self.stats['messages_summarized'] += messages_summarized
    self.stats['tokens_saved'] += tokens_saved
    self.stats['memory_optimizations'] += 1

    optimization_info = {
        'action': 'optimized',
        'original_tokens': total_tokens,
        'final_tokens': tokens_used,
        'tokens_saved': tokens_saved,
        'messages_removed': messages_removed,
        'messages_summarized': messages_summarized,
        'memory_pressure_before': memory_pressure,
        'memory_pressure_after': self.calculate_memory_pressure(tokens_used),
        # NOTE(review): this counts pairs *found* before optimization, not
        # pairs actually kept — rename or recompute if that distinction matters.
        'tool_pairs_preserved': len(tool_call_pairs)
    }

    self.logger.info(f"Memory optimization completed: {optimization_info}")

    # Final validation: ensure tool call integrity is maintained
    # (presumably _tool_call_pairs is recomputed from current metadata — verify)
    final_pairs = self._tool_call_pairs
    if len(final_pairs) != len([pair for pair in tool_call_pairs.values() if all(idx < len(optimized_messages) for idx in pair)]):
        self.logger.warning("Tool call/response integrity may be compromised")

    return optimized_messages, optimization_info
|
936
|
+
|
937
|
+
def _summarize_message(self, message: Dict[str, Any]) -> str:
|
938
|
+
"""Create a summary of a message."""
|
939
|
+
content = str(message.get('content', ''))
|
940
|
+
role = message.get('role', '')
|
941
|
+
|
942
|
+
# Simple summarization - could be enhanced with LLM-based summarization
|
943
|
+
if role == 'tool':
|
944
|
+
tool_name = message.get('name', 'unknown')
|
945
|
+
if len(content) > 200:
|
946
|
+
return f"[SUMMARY] Tool {tool_name} executed: {content[:100]}... [truncated]"
|
947
|
+
return content
|
948
|
+
|
949
|
+
if role == 'assistant' and len(content) > 300:
|
950
|
+
return f"[SUMMARY] Assistant response: {content[:150]}... [truncated]"
|
951
|
+
|
952
|
+
if len(content) > 200:
|
953
|
+
return f"[SUMMARY] {content[:100]}... [truncated]"
|
954
|
+
|
955
|
+
return content
|
956
|
+
|
957
|
+
def get_memory_stats(self) -> Dict[str, Any]:
    """Return a snapshot combining the running counters with message tallies."""
    critical = 0
    errors = 0
    resolved = 0
    # One pass over the metadata instead of three generator scans.
    for meta in self.message_metadata:
        if meta.importance == MessageImportance.CRITICAL:
            critical += 1
        if meta.is_error:
            errors += 1
            if meta.error_resolved:
                resolved += 1

    snapshot = dict(self.stats)
    snapshot.update({
        'active_tasks': len(self.active_tasks),
        'completed_tasks': len(self.completed_tasks),
        'total_messages': len(self.message_metadata),
        'critical_messages': critical,
        'error_messages': errors,
        'resolved_errors': resolved,
    })
    return snapshot
|
968
|
+
|
969
|
+
def reset_stats(self) -> None:
    """Zero out all memory-management counters."""
    self.stats = dict.fromkeys(
        ('messages_removed', 'messages_summarized', 'tokens_saved', 'memory_optimizations'),
        0,
    )
|
977
|
+
|
978
|
+
def clear_completed_tasks(self) -> None:
    """
    Clear metadata for completed tasks to free up memory.

    Drops metadata entries whose task is finished, whose importance is
    below HIGH, and that are more than 30 minutes old; everything else
    is retained.
    """
    # Hoist the clock read: the original called time.time() on every
    # iteration, so the effective age cutoff drifted during the scan.
    cutoff = time.time() - 1800  # 30 minutes old

    kept_metadata = []
    removed_count = 0

    for metadata in self.message_metadata:
        if (metadata.task_completed and
                metadata.importance not in (MessageImportance.CRITICAL, MessageImportance.HIGH) and
                metadata.created_at < cutoff):
            removed_count += 1
        else:
            kept_metadata.append(metadata)

    self.message_metadata = kept_metadata
    self.logger.info(f"Cleared {removed_count} completed task metadata entries")
|
994
|
+
|
995
|
+
def to_dict(self) -> Dict[str, Any]:
    """Serialize memory manager state to a plain, JSON-friendly dict."""

    def _meta_entry(meta) -> Dict[str, Any]:
        # Flatten one metadata record; enums are stored by their .value.
        return {
            'message_type': meta.message_type.value,
            'importance': meta.importance.value,
            'created_at': meta.created_at,
            'token_count': meta.token_count,
            'is_error': meta.is_error,
            'error_resolved': meta.error_resolved,
            'part_of_task': meta.part_of_task,
            'task_completed': meta.task_completed,
            'can_summarize': meta.can_summarize,
            'summary': meta.summary,
            'related_messages': meta.related_messages,
            'tool_call_id': meta.tool_call_id,
        }

    state = {
        'max_tokens': self.max_tokens,
        'target_tokens': self.target_tokens,
        'enable_summarization': self.enable_summarization,
        'active_tasks': list(self.active_tasks),
        'completed_tasks': list(self.completed_tasks),
        'conversation_summary': self.conversation_summary,
        'task_summaries': self.task_summaries,
        'stats': self.stats,
        'message_metadata': [_meta_entry(meta) for meta in self.message_metadata],
    }
    return state
|
1024
|
+
|
1025
|
+
@classmethod
def from_dict(
    cls,
    data: Dict[str, Any],
    strategy: Optional[MemoryStrategy] = None,
    logger: Optional[logging.Logger] = None
) -> 'MemoryManager':
    """
    Deserialize memory manager state produced by :meth:`to_dict`.

    Args:
        data: Serialized state; missing keys fall back to constructor defaults
              (max_tokens=8000, target_tokens=6000, enable_summarization=True).
        strategy: Optional memory strategy to install on the new instance.
        logger: Optional logger to install on the new instance.

    Returns:
        A new MemoryManager with tasks, summaries, stats, and per-message
        metadata restored.
    """
    manager = cls(
        max_tokens=data.get('max_tokens', 8000),
        target_tokens=data.get('target_tokens', 6000),
        strategy=strategy,
        enable_summarization=data.get('enable_summarization', True),
        logger=logger
    )

    # Scalar / collection state. Sets are rebuilt from the serialized lists.
    manager.active_tasks = set(data.get('active_tasks', []))
    manager.completed_tasks = set(data.get('completed_tasks', []))
    manager.conversation_summary = data.get('conversation_summary')
    manager.task_summaries = data.get('task_summaries', {})
    manager.stats = data.get('stats', manager.stats)

    # Restore message metadata; enum fields are reconstructed from their
    # serialized .value strings.
    metadata_list = data.get('message_metadata', [])
    manager.message_metadata = [
        MessageMetadata(
            message_type=MessageType(meta['message_type']),
            importance=MessageImportance(meta['importance']),
            created_at=meta['created_at'],
            token_count=meta['token_count'],
            is_error=meta['is_error'],
            error_resolved=meta['error_resolved'],
            part_of_task=meta['part_of_task'],
            task_completed=meta['task_completed'],
            can_summarize=meta['can_summarize'],
            summary=meta['summary'],
            related_messages=meta['related_messages'],
            tool_call_id=meta['tool_call_id']
        )
        for meta in metadata_list
    ]

    return manager
|