tinyagent-py 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1067 @@
1
+ #
2
+ #
3
+ #
4
+ #
5
+ #
6
+ # A tool call and a tool error with the same tool_call_id should have the same importance level; otherwise the LLM would reject it.
7
+ #- A tool call ==> tool error pair should be MEDIUM if no later tool call ==> tool error pair follows it (i.e. it is the last error).
8
+ #- It should be LOW if a later tool call ==> tool response pair (a response without an error) follows it.
9
+ #- If this happens at the end of the conversation, the HIGH-importance rule for recent messages overrules everything, so the pair becomes HIGH priority.
10
+ # Last message pairs should be high priority.
11
+ #
12
+ # A tool_call => tool message pair shares the same importance level.
13
+ #
14
+ #
15
+ # If a message has 'role': 'assistant',
16
+ # an empty 'content' (''), and
17
+ # 'tool_calls' => function ==> name,
18
+ #
19
+ # it should share the same importance level as its tool response with role = 'tool' and the same tool_call_id.
20
+
21
+ # The last 4 pairs in the history should have HIGH importance.
22
+ #
23
+ #
24
+ #
25
+ # memory_manager.py
26
+ import logging
27
+ from typing import Dict, List, Optional, Any, Tuple, Set, Callable
28
+ from dataclasses import dataclass, field
29
+ from enum import Enum
30
+ import json
31
+ import time
32
+ from abc import ABC, abstractmethod
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+ class MessageImportance(Enum):
37
+ """Defines the importance levels for messages."""
38
+ CRITICAL = "critical" # Must always be kept (system, final answers, etc.)
39
+ HIGH = "high" # Important context, keep unless absolutely necessary
40
+ MEDIUM = "medium" # Standard conversation, can be summarized
41
+ LOW = "low" # Tool errors, failed attempts, can be removed
42
+ TEMP = "temp" # Temporary messages, remove after success
43
+
44
+ class MessageType(Enum):
45
+ """Categorizes different types of messages."""
46
+ SYSTEM = "system"
47
+ USER_QUERY = "user_query"
48
+ ASSISTANT_RESPONSE = "assistant_response"
49
+ TOOL_CALL = "tool_call"
50
+ TOOL_RESPONSE = "tool_response"
51
+ TOOL_ERROR = "tool_error"
52
+ FINAL_ANSWER = "final_answer"
53
+ QUESTION_TO_USER = "question_to_user"
54
+
55
+ @dataclass
56
+ class MessageMetadata:
57
+ """Metadata for tracking message importance and lifecycle."""
58
+ message_type: MessageType
59
+ importance: MessageImportance
60
+ created_at: float
61
+ token_count: int = 0
62
+ is_error: bool = False
63
+ error_resolved: bool = False
64
+ part_of_task: Optional[str] = None # Task/subtask identifier
65
+ task_completed: bool = False
66
+ can_summarize: bool = True
67
+ summary: Optional[str] = None
68
+ related_messages: List[int] = field(default_factory=list) # Indices of related messages
69
+ tool_call_id: Optional[str] = None # To track tool call/response pairs
70
+
71
+ class MemoryStrategy(ABC):
72
+ """Abstract base class for memory management strategies."""
73
+
74
+ @abstractmethod
75
+ def should_keep_message(self, message: Dict[str, Any], metadata: MessageMetadata,
76
+ context: Dict[str, Any]) -> bool:
77
+ """Determine if a message should be kept in memory."""
78
+ pass
79
+
80
+ @abstractmethod
81
+ def get_priority_score(self, message: Dict[str, Any], metadata: MessageMetadata) -> float:
82
+ """Get priority score for message ranking."""
83
+ pass
84
+
85
+ class ConservativeStrategy(MemoryStrategy):
86
+ """Conservative strategy - keeps more messages, summarizes less aggressively."""
87
+
88
+ def should_keep_message(self, message: Dict[str, Any], metadata: MessageMetadata,
89
+ context: Dict[str, Any]) -> bool:
90
+ # Always keep critical messages
91
+ if metadata.importance == MessageImportance.CRITICAL:
92
+ return True
93
+
94
+ # Keep high importance messages unless we're really tight on space
95
+ if metadata.importance == MessageImportance.HIGH:
96
+ return context.get('memory_pressure', 0) < 0.8
97
+
98
+ # Keep recent messages
99
+ if time.time() - metadata.created_at < 300: # 5 minutes
100
+ return True
101
+
102
+ # Remove resolved errors and temp messages
103
+ if metadata.importance == MessageImportance.TEMP:
104
+ return False
105
+
106
+ if metadata.is_error and metadata.error_resolved:
107
+ return False
108
+
109
+ return context.get('memory_pressure', 0) < 0.6
110
+
111
+ def get_priority_score(self, message: Dict[str, Any], metadata: MessageMetadata) -> float:
112
+ base_score = {
113
+ MessageImportance.CRITICAL: 1000,
114
+ MessageImportance.HIGH: 100,
115
+ MessageImportance.MEDIUM: 50,
116
+ MessageImportance.LOW: 10,
117
+ MessageImportance.TEMP: 1
118
+ }[metadata.importance]
119
+
120
+ # Boost recent messages
121
+ age_factor = max(0.1, 1.0 - (time.time() - metadata.created_at) / 3600)
122
+
123
+ # Penalize errors
124
+ error_penalty = 0.5 if metadata.is_error else 1.0
125
+
126
+ return base_score * age_factor * error_penalty
127
+
128
+ class AggressiveStrategy(MemoryStrategy):
129
+ """Aggressive strategy - removes more messages, summarizes more aggressively."""
130
+
131
+ def should_keep_message(self, message: Dict[str, Any], metadata: MessageMetadata,
132
+ context: Dict[str, Any]) -> bool:
133
+ # Always keep critical messages
134
+ if metadata.importance == MessageImportance.CRITICAL:
135
+ return True
136
+
137
+ # Be more selective with high importance
138
+ if metadata.importance == MessageImportance.HIGH:
139
+ return context.get('memory_pressure', 0) < 0.5 and (time.time() - metadata.created_at < 600)
140
+
141
+ # Only keep very recent medium importance messages
142
+ if metadata.importance == MessageImportance.MEDIUM:
143
+ return time.time() - metadata.created_at < 180 # 3 minutes
144
+
145
+ # Remove low importance and temp messages quickly
146
+ return False
147
+
148
+ def get_priority_score(self, message: Dict[str, Any], metadata: MessageMetadata) -> float:
149
+ base_score = {
150
+ MessageImportance.CRITICAL: 1000,
151
+ MessageImportance.HIGH: 80,
152
+ MessageImportance.MEDIUM: 30,
153
+ MessageImportance.LOW: 5,
154
+ MessageImportance.TEMP: 1
155
+ }[metadata.importance]
156
+
157
+ # Strong recency bias
158
+ age_factor = max(0.05, 1.0 - (time.time() - metadata.created_at) / 1800)
159
+
160
+ # Heavy error penalty
161
+ error_penalty = 0.2 if metadata.is_error else 1.0
162
+
163
+ return base_score * age_factor * error_penalty
164
+
165
+ class BalancedStrategy(MemoryStrategy):
166
+ """Balanced strategy - moderate approach to memory management."""
167
+
168
+ def should_keep_message(self, message: Dict[str, Any], metadata: MessageMetadata,
169
+ context: Dict[str, Any]) -> bool:
170
+ # Always keep critical messages
171
+ if metadata.importance == MessageImportance.CRITICAL:
172
+ return True
173
+
174
+ # Keep high importance messages unless high memory pressure
175
+ if metadata.importance == MessageImportance.HIGH:
176
+ return context.get('memory_pressure', 0) < 0.7
177
+
178
+ # Keep recent medium importance messages
179
+ if metadata.importance == MessageImportance.MEDIUM:
180
+ return time.time() - metadata.created_at < 450 # 7.5 minutes
181
+
182
+ # Remove resolved errors and temp messages
183
+ if metadata.is_error and metadata.error_resolved:
184
+ return False
185
+
186
+ if metadata.importance == MessageImportance.TEMP:
187
+ return time.time() - metadata.created_at < 60 # 1 minute
188
+
189
+ return context.get('memory_pressure', 0) < 0.4
190
+
191
+ def get_priority_score(self, message: Dict[str, Any], metadata: MessageMetadata) -> float:
192
+ base_score = {
193
+ MessageImportance.CRITICAL: 1000,
194
+ MessageImportance.HIGH: 90,
195
+ MessageImportance.MEDIUM: 40,
196
+ MessageImportance.LOW: 8,
197
+ MessageImportance.TEMP: 2
198
+ }[metadata.importance]
199
+
200
+ # Moderate recency bias
201
+ age_factor = max(0.1, 1.0 - (time.time() - metadata.created_at) / 2400)
202
+
203
+ # Moderate error penalty
204
+ error_penalty = 0.3 if metadata.is_error else 1.0
205
+
206
+ return base_score * age_factor * error_penalty
207
+
208
+ class MemoryManager:
209
+ """
210
+ Advanced memory management system for TinyAgent.
211
+
212
+ Features:
213
+ - Message importance tracking with dynamic positioning
214
+ - Intelligent message removal and summarization
215
+ - Multiple memory management strategies
216
+ - Task-based message grouping
217
+ - Error recovery tracking
218
+ - Tool call/response pair integrity
219
+ """
220
+
221
+ _DEFAULT_NUM_RECENT_PAIRS_HIGH_IMPORTANCE = 3
222
+ _DEFAULT_NUM_INITIAL_PAIRS_CRITICAL = 3
223
+
224
+ def __init__(
225
+ self,
226
+ max_tokens: int = 8000,
227
+ target_tokens: int = 6000,
228
+ strategy: MemoryStrategy = None,
229
+ enable_summarization: bool = True,
230
+ logger: Optional[logging.Logger] = None,
231
+ num_recent_pairs_high_importance: Optional[int] = None,
232
+ num_initial_pairs_critical: Optional[int] = None
233
+ ):
234
+ self.max_tokens = max_tokens
235
+ self.target_tokens = target_tokens
236
+ self.strategy = strategy or BalancedStrategy()
237
+ self.enable_summarization = enable_summarization
238
+ self.logger = logger or logging.getLogger(__name__)
239
+
240
+ # Configure importance thresholds
241
+ self._num_recent_pairs_for_high_importance = (
242
+ num_recent_pairs_high_importance
243
+ if num_recent_pairs_high_importance is not None
244
+ else self._DEFAULT_NUM_RECENT_PAIRS_HIGH_IMPORTANCE
245
+ )
246
+
247
+ self._num_initial_pairs_critical = (
248
+ num_initial_pairs_critical
249
+ if num_initial_pairs_critical is not None
250
+ else self._DEFAULT_NUM_INITIAL_PAIRS_CRITICAL
251
+ )
252
+
253
+ # Message metadata storage
254
+ self.message_metadata: List[MessageMetadata] = []
255
+
256
+ # Task tracking
257
+ self.active_tasks: Set[str] = set()
258
+ self.completed_tasks: Set[str] = set()
259
+
260
+ # Summary storage
261
+ self.conversation_summary: Optional[str] = None
262
+ self.task_summaries: Dict[str, str] = {}
263
+
264
+ # Statistics
265
+ self.stats = {
266
+ 'messages_removed': 0,
267
+ 'messages_summarized': 0,
268
+ 'tokens_saved': 0,
269
+ 'memory_optimizations': 0
270
+ }
271
+
272
+ # Tool call tracking for proper pairing
273
+ self._tool_call_pairs: Dict[str, Tuple[int, int]] = {} # tool_call_id -> (call_index, response_index)
274
+ self._resolved_errors: Set[str] = set() # Track resolved error tool_call_ids
275
+
276
+ def _count_message_tokens(self, message: Dict[str, Any], token_counter: Callable[[str], int]) -> int:
277
+ """
278
+ Properly count tokens in a message, including tool calls.
279
+
280
+ Args:
281
+ message: The message to count tokens for
282
+ token_counter: Function to count tokens in text
283
+
284
+ Returns:
285
+ Total token count for the message
286
+ """
287
+ total_tokens = 0
288
+
289
+ # Count content tokens
290
+ content = message.get('content', '')
291
+ if content:
292
+ total_tokens += token_counter(str(content))
293
+
294
+ # Count tool call tokens
295
+ if 'tool_calls' in message and message['tool_calls']:
296
+ for tool_call in message['tool_calls']:
297
+ # Count function name
298
+ if isinstance(tool_call, dict):
299
+ if 'function' in tool_call:
300
+ func_data = tool_call['function']
301
+ if 'name' in func_data:
302
+ total_tokens += token_counter(func_data['name'])
303
+ if 'arguments' in func_data:
304
+ total_tokens += token_counter(str(func_data['arguments']))
305
+ # Count tool call ID
306
+ if 'id' in tool_call:
307
+ total_tokens += token_counter(str(tool_call['id']))
308
+ elif hasattr(tool_call, 'function'):
309
+ # Handle object-style tool calls
310
+ if hasattr(tool_call.function, 'name'):
311
+ total_tokens += token_counter(tool_call.function.name)
312
+ if hasattr(tool_call.function, 'arguments'):
313
+ total_tokens += token_counter(str(tool_call.function.arguments))
314
+ if hasattr(tool_call, 'id'):
315
+ total_tokens += token_counter(str(tool_call.id))
316
+
317
+ # Count tool call ID for tool responses
318
+ if 'tool_call_id' in message and message['tool_call_id']:
319
+ total_tokens += token_counter(str(message['tool_call_id']))
320
+
321
+ # Count tool name for tool responses
322
+ if 'name' in message and message.get('role') == 'tool':
323
+ total_tokens += token_counter(str(message['name']))
324
+
325
+ return total_tokens
326
+
327
+ def _calculate_dynamic_importance(
328
+ self,
329
+ message: Dict[str, Any],
330
+ index: int,
331
+ total_messages: int,
332
+ message_pairs: List[Tuple[int, int]]
333
+ ) -> MessageImportance:
334
+ """
335
+ Calculate dynamic importance based on position, content, and context.
336
+
337
+ Args:
338
+ message: The message to evaluate
339
+ index: Position of the message in the conversation
340
+ total_messages: Total number of messages
341
+ message_pairs: List of message pair ranges
342
+
343
+ Returns:
344
+ MessageImportance level
345
+ """
346
+ role = message.get('role', '')
347
+ content = str(message.get('content', ''))
348
+
349
+ # System messages are always CRITICAL
350
+ if role == 'system':
351
+ return MessageImportance.CRITICAL
352
+
353
+ # Check if this is a final_answer or ask_question tool call (HIGH importance)
354
+ if role == 'assistant' and message.get('tool_calls'):
355
+ tool_calls = message.get('tool_calls', [])
356
+ if any(tc.get('function', {}).get('name') in ['final_answer', 'ask_question']
357
+ for tc in tool_calls):
358
+ return MessageImportance.HIGH
359
+
360
+ # Check if this is an error response (HIGH importance until resolved)
361
+ if self._is_tool_error_response(message):
362
+ return MessageImportance.HIGH
363
+
364
+ # Position-based importance (first N pairs are CRITICAL, last N pairs are HIGH)
365
+ if total_messages <= 10:
366
+ # For short conversations, keep everything at MEDIUM or higher
367
+ return MessageImportance.MEDIUM
368
+
369
+ # Find which pair this message belongs to
370
+ current_pair_index = None
371
+ for pair_idx, (start_idx, end_idx) in enumerate(message_pairs):
372
+ if start_idx <= index <= end_idx:
373
+ current_pair_index = pair_idx
374
+ break
375
+
376
+ if current_pair_index is not None:
377
+ # First N pairs are CRITICAL
378
+ if current_pair_index < self._num_initial_pairs_critical:
379
+ return MessageImportance.CRITICAL
380
+
381
+ # Last N pairs are HIGH
382
+ if current_pair_index >= len(message_pairs) - self._num_recent_pairs_for_high_importance:
383
+ return MessageImportance.HIGH
384
+
385
+ # Content-based importance adjustments
386
+ if role == 'user':
387
+ # User queries are generally important
388
+ return MessageImportance.MEDIUM
389
+ elif role == 'assistant':
390
+ # Assistant responses vary by content length and complexity
391
+ if len(content) > 500: # Long responses might be more important
392
+ return MessageImportance.MEDIUM
393
+ else:
394
+ return MessageImportance.LOW
395
+ elif role == 'tool':
396
+ # Tool responses are generally MEDIUM unless they're errors
397
+ return MessageImportance.MEDIUM
398
+
399
+ # Default importance
400
+ return MessageImportance.LOW
401
+
402
+ def categorize_message(self, message: Dict[str, Any], index: int, total_messages: int) -> Tuple[MessageType, MessageImportance]:
403
+ """
404
+ Categorize a message and determine its base importance.
405
+
406
+ Args:
407
+ message: The message to categorize
408
+ index: Position of the message in the conversation
409
+ total_messages: Total number of messages in the conversation
410
+
411
+ Returns:
412
+ Tuple of (MessageType, MessageImportance)
413
+ """
414
+ role = message.get('role', '')
415
+ content = message.get('content', '')
416
+
417
+ # Determine message type
418
+ if role == 'system':
419
+ msg_type = MessageType.SYSTEM
420
+ elif role == 'user':
421
+ msg_type = MessageType.USER_QUERY
422
+ elif role == 'tool':
423
+ if self._is_tool_error_response(message):
424
+ msg_type = MessageType.TOOL_ERROR
425
+ else:
426
+ msg_type = MessageType.TOOL_RESPONSE
427
+ elif role == 'assistant':
428
+ if message.get('tool_calls'):
429
+ # Check if this is a final_answer or ask_question tool call
430
+ tool_calls = message.get('tool_calls', [])
431
+ if any(tc.get('function', {}).get('name') in ['final_answer', 'ask_question']
432
+ for tc in tool_calls):
433
+ msg_type = MessageType.FINAL_ANSWER
434
+ else:
435
+ msg_type = MessageType.TOOL_CALL
436
+ else:
437
+ msg_type = MessageType.ASSISTANT_RESPONSE
438
+ else:
439
+ msg_type = MessageType.ASSISTANT_RESPONSE
440
+
441
+ # Calculate message pairs for dynamic importance
442
+ message_pairs = self._calculate_message_pairs()
443
+
444
+ # Calculate dynamic importance
445
+ importance = self._calculate_dynamic_importance(message, index, total_messages, message_pairs)
446
+
447
+ return msg_type, importance
448
+
449
+ def add_message_metadata(
450
+ self,
451
+ message: Dict[str, Any],
452
+ token_count: int,
453
+ position: int,
454
+ total_messages: int
455
+ ) -> None:
456
+ """
457
+ Add metadata for a message and update tool call pairs.
458
+
459
+ Args:
460
+ message: The message to add metadata for
461
+ token_count: Number of tokens in the message
462
+ position: Position of the message in the conversation
463
+ total_messages: Total number of messages in the conversation
464
+ """
465
+ # Categorize the message
466
+ msg_type, base_importance = self.categorize_message(message, position, total_messages)
467
+
468
+ # Extract task information
469
+ task_id = self._extract_task_id(message)
470
+ if task_id:
471
+ self.active_tasks.add(task_id)
472
+
473
+ # Check if this is an error message
474
+ is_error = self._is_tool_error_response(message)
475
+
476
+ # Extract tool call ID - handle both tool calls and tool responses
477
+ tool_call_id = None
478
+ if message.get('role') == 'tool':
479
+ # Tool response - get tool_call_id directly
480
+ tool_call_id = message.get('tool_call_id')
481
+ elif message.get('role') == 'assistant' and message.get('tool_calls'):
482
+ # Tool call - get the first tool call ID (assuming single tool call per message)
483
+ tool_calls = message.get('tool_calls', [])
484
+ if tool_calls:
485
+ tool_call_id = tool_calls[0].get('id')
486
+
487
+ # Create metadata
488
+ metadata = MessageMetadata(
489
+ message_type=msg_type,
490
+ importance=base_importance, # Will be recalculated dynamically
491
+ created_at=time.time(),
492
+ token_count=token_count,
493
+ is_error=is_error,
494
+ error_resolved=False,
495
+ part_of_task=task_id,
496
+ task_completed=task_id in self.completed_tasks if task_id else False,
497
+ tool_call_id=tool_call_id,
498
+ can_summarize=msg_type not in [MessageType.SYSTEM, MessageType.FINAL_ANSWER],
499
+ summary=None
500
+ )
501
+
502
+ # Add to metadata list
503
+ self.message_metadata.append(metadata)
504
+
505
+ # Update tool call pairs
506
+ self._update_tool_call_pairs()
507
+
508
+ # Update resolved errors
509
+ self._update_resolved_errors()
510
+
511
+ # Synchronize tool call pair importance levels
512
+ self._synchronize_tool_call_pairs()
513
+
514
+ self.logger.debug(f"Added metadata for message at position {position}: {msg_type.value}, {base_importance.value}, tool_call_id: {tool_call_id}")
515
+
516
+ def _update_tool_call_pairs(self) -> None:
517
+ """Update the tool call pairs mapping based on current messages."""
518
+ self._tool_call_pairs.clear()
519
+
520
+ # Find all tool calls and their responses
521
+ for i, metadata in enumerate(self.message_metadata):
522
+ if metadata.tool_call_id:
523
+ if metadata.message_type in [MessageType.TOOL_CALL, MessageType.FINAL_ANSWER]:
524
+ # This is a tool call, look for its response
525
+ for j in range(i + 1, len(self.message_metadata)):
526
+ response_meta = self.message_metadata[j]
527
+ if (response_meta.tool_call_id == metadata.tool_call_id and
528
+ response_meta.message_type in [MessageType.TOOL_RESPONSE, MessageType.TOOL_ERROR]):
529
+ self._tool_call_pairs[metadata.tool_call_id] = (i, j)
530
+ break
531
+
532
+ def _recalculate_all_importance_levels(self) -> None:
533
+ """Recalculate importance levels for all messages based on current context."""
534
+ if not self.message_metadata:
535
+ return
536
+
537
+ # Calculate message pairs for context
538
+ message_pairs = self._calculate_message_pairs()
539
+ total_messages = len(self.message_metadata)
540
+
541
+ # Recalculate importance for each message
542
+ for i, metadata in enumerate(self.message_metadata):
543
+ # We need the original message to recalculate importance
544
+ # For now, we'll use a simplified approach based on message type and position
545
+ new_importance = self._calculate_positional_importance(i, total_messages, message_pairs, metadata)
546
+ metadata.importance = new_importance
547
+
548
+ # After recalculating all, synchronize tool call pairs
549
+ self._synchronize_tool_call_pairs()
550
+
551
+ self.logger.debug(f"Recalculated importance levels for {total_messages} messages")
552
+
553
+ def _calculate_positional_importance(
554
+ self,
555
+ index: int,
556
+ total_messages: int,
557
+ message_pairs: List[Tuple[int, int]],
558
+ metadata: MessageMetadata
559
+ ) -> MessageImportance:
560
+ """Calculate importance based on position and message type."""
561
+
562
+ # System messages are always CRITICAL
563
+ if metadata.message_type == MessageType.SYSTEM:
564
+ return MessageImportance.CRITICAL
565
+
566
+ # Final answers are HIGH
567
+ if metadata.message_type == MessageType.FINAL_ANSWER:
568
+ return MessageImportance.HIGH
569
+
570
+ # Errors are HIGH until resolved
571
+ if metadata.is_error and not metadata.error_resolved:
572
+ return MessageImportance.HIGH
573
+
574
+ # Position-based importance
575
+ if total_messages <= 10:
576
+ return MessageImportance.MEDIUM
577
+
578
+ # Find which pair this message belongs to
579
+ current_pair_index = None
580
+ for pair_idx, (start_idx, end_idx) in enumerate(message_pairs):
581
+ if start_idx <= index <= end_idx:
582
+ current_pair_index = pair_idx
583
+ break
584
+
585
+ if current_pair_index is not None:
586
+ # First N pairs are CRITICAL
587
+ if current_pair_index < self._num_initial_pairs_critical:
588
+ return MessageImportance.CRITICAL
589
+
590
+ # Last N pairs are HIGH
591
+ if current_pair_index >= len(message_pairs) - self._num_recent_pairs_for_high_importance:
592
+ return MessageImportance.HIGH
593
+
594
+ # Default based on message type
595
+ if metadata.message_type in [MessageType.USER_QUERY, MessageType.TOOL_RESPONSE]:
596
+ return MessageImportance.MEDIUM
597
+
598
+ return MessageImportance.LOW
599
+
600
+ def _calculate_message_pairs(self) -> List[Tuple[int, int]]:
601
+ """Calculate logical message pairs for positional importance."""
602
+ pairs = []
603
+ i = 0
604
+
605
+ while i < len(self.message_metadata):
606
+ metadata = self.message_metadata[i]
607
+
608
+ # System message stands alone
609
+ if metadata.message_type == MessageType.SYSTEM:
610
+ pairs.append((i, i))
611
+ i += 1
612
+ continue
613
+
614
+ # User message followed by assistant response
615
+ if metadata.message_type == MessageType.USER_QUERY:
616
+ if i + 1 < len(self.message_metadata):
617
+ next_meta = self.message_metadata[i + 1]
618
+ if next_meta.message_type in [MessageType.ASSISTANT_RESPONSE, MessageType.TOOL_CALL]:
619
+ pairs.append((i, i + 1))
620
+ i += 2
621
+ continue
622
+
623
+ # User message without response
624
+ pairs.append((i, i))
625
+ i += 1
626
+ continue
627
+
628
+ # Tool call with response
629
+ if metadata.tool_call_id and metadata.tool_call_id in self._tool_call_pairs:
630
+ call_idx, response_idx = self._tool_call_pairs[metadata.tool_call_id]
631
+ if i == call_idx:
632
+ pairs.append((call_idx, response_idx))
633
+ i = response_idx + 1
634
+ continue
635
+
636
+ # Single message
637
+ pairs.append((i, i))
638
+ i += 1
639
+
640
+ return pairs
641
+
642
+ def _update_resolved_errors(self) -> None:
643
+ """Update the set of resolved error tool call IDs."""
644
+ self._resolved_errors.clear()
645
+
646
+ # Track tool calls that had errors but later succeeded
647
+ error_tool_calls = set()
648
+ success_tool_calls = set()
649
+
650
+ for metadata in self.message_metadata:
651
+ if metadata.tool_call_id:
652
+ if metadata.is_error:
653
+ error_tool_calls.add(metadata.tool_call_id)
654
+ elif metadata.message_type in [MessageType.TOOL_RESPONSE]:
655
+ # Check if this is a successful response (not an error)
656
+ success_tool_calls.add(metadata.tool_call_id)
657
+
658
+ # Find tool functions that had both errors and successes
659
+ for tool_call_id in self._tool_call_pairs:
660
+ call_idx, response_idx = self._tool_call_pairs[tool_call_id]
661
+
662
+ if (call_idx < len(self.message_metadata) and
663
+ response_idx < len(self.message_metadata)):
664
+
665
+ call_meta = self.message_metadata[call_idx]
666
+ response_meta = self.message_metadata[response_idx]
667
+
668
+ # Check if there's a later successful call to the same function
669
+ if response_meta.is_error:
670
+ function_name = self._extract_function_name(call_meta, call_idx)
671
+ if function_name and self._has_later_success(function_name, call_idx):
672
+ self._resolved_errors.add(tool_call_id)
673
+ response_meta.error_resolved = True
674
+
675
+ def _extract_function_name(self, metadata: MessageMetadata, message_index: int) -> Optional[str]:
676
+ """Extract function name from a tool call message."""
677
+ # This would need access to the actual message content
678
+ # For now, return a placeholder - this should be implemented based on message structure
679
+ return f"function_{message_index}" # Placeholder
680
+
681
+ def _has_later_success(self, function_name: str, error_position: int) -> bool:
682
+ """Check if there's a later successful call to the same function."""
683
+ # Look for successful calls to the same function after the error
684
+ for i in range(error_position + 1, len(self.message_metadata)):
685
+ metadata = self.message_metadata[i]
686
+ if (metadata.message_type == MessageType.TOOL_RESPONSE and
687
+ not metadata.is_error):
688
+ # Check if this is the same function (simplified check)
689
+ return True
690
+ return False
691
+
692
+ def _synchronize_tool_call_pairs(self) -> None:
693
+ """Ensure tool call pairs have synchronized importance levels."""
694
+ for tool_call_id, (call_idx, response_idx) in self._tool_call_pairs.items():
695
+ if (call_idx < len(self.message_metadata) and
696
+ response_idx < len(self.message_metadata)):
697
+
698
+ call_meta = self.message_metadata[call_idx]
699
+ response_meta = self.message_metadata[response_idx]
700
+
701
+ # Use the higher importance level for both
702
+ importance_order = [
703
+ MessageImportance.TEMP,
704
+ MessageImportance.LOW,
705
+ MessageImportance.MEDIUM,
706
+ MessageImportance.HIGH,
707
+ MessageImportance.CRITICAL
708
+ ]
709
+
710
+ call_priority = importance_order.index(call_meta.importance)
711
+ response_priority = importance_order.index(response_meta.importance)
712
+
713
+ target_importance = importance_order[max(call_priority, response_priority)]
714
+
715
+ # Update both to use the higher importance
716
+ call_meta.importance = target_importance
717
+ response_meta.importance = target_importance
718
+
719
+ self.logger.debug(f"Synchronized tool call pair {tool_call_id}: both set to {target_importance.value}")
720
+
721
+ def _extract_task_id(self, message: Dict[str, Any]) -> Optional[str]:
722
+ """Extract task identifier from message content."""
723
+ # Simple implementation - could be enhanced with more sophisticated parsing
724
+ content = str(message.get('content', ''))
725
+
726
+ # Look for task patterns
727
+ if 'task:' in content.lower():
728
+ parts = content.lower().split('task:')
729
+ if len(parts) > 1:
730
+ task_part = parts[1].split()[0] if parts[1].split() else None
731
+ return f"task_{task_part}" if task_part else None
732
+
733
+ return None
734
+
735
+ def _is_tool_error_response(self, message: Dict[str, Any]) -> bool:
736
+ """
737
+ Check if a tool response message represents an error.
738
+
739
+ Args:
740
+ message: The tool response message to check
741
+
742
+ Returns:
743
+ True if the message represents a tool error
744
+ """
745
+ if message.get('role') != 'tool':
746
+ return False
747
+
748
+ content = str(message.get('content', '')).strip().lower()
749
+
750
+ # Check if content starts with "error"
751
+ return content.startswith('error')
752
+
753
+ def calculate_memory_pressure(self, total_tokens: int) -> float:
754
+ """Calculate current memory pressure (0.0 to 1.0)."""
755
+ return min(1.0, total_tokens / self.max_tokens)
756
+
757
+ def should_optimize_memory(self, total_tokens: int) -> bool:
758
+ """Determine if memory optimization is needed."""
759
+ return total_tokens > self.target_tokens
760
+
761
+ def optimize_messages(
762
+ self,
763
+ messages: List[Dict[str, Any]],
764
+ token_counter: Callable[[str], int]
765
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
766
+ """
767
+ Optimize message list by removing/summarizing less important messages
768
+ while preserving tool call/response pairs and maintaining conversation integrity.
769
+ """
770
+ # Ensure metadata is up to date
771
+ if len(messages) > len(self.message_metadata):
772
+ for i in range(len(self.message_metadata), len(messages)):
773
+ msg = messages[i]
774
+ token_count = self._count_message_tokens(msg, token_counter)
775
+ self.add_message_metadata(msg, token_count, i, len(messages))
776
+
777
+ if len(messages) != len(self.message_metadata):
778
+ self.logger.warning("Message count mismatch with metadata")
779
+ return messages, {"error": "Message metadata mismatch"}
780
+
781
+ # Recalculate importance levels based on current conversation state
782
+ self._recalculate_all_importance_levels()
783
+
784
+ # Calculate current token usage using proper token counting
785
+ total_tokens = sum(self._count_message_tokens(msg, token_counter) for msg in messages)
786
+
787
+ if not self.should_optimize_memory(total_tokens):
788
+ return messages, {'action': 'none', 'reason': 'within_limits'}
789
+
790
+ memory_pressure = self.calculate_memory_pressure(total_tokens)
791
+ context = {'memory_pressure': memory_pressure}
792
+
793
+ self.logger.info(f"Memory optimization needed. Total tokens: {total_tokens}, pressure: {memory_pressure:.2f}")
794
+
795
+ # Find all tool call/response pairs
796
+ tool_call_pairs = self._tool_call_pairs
797
+
798
+ # Create sets of message indices that must be kept together
799
+ protected_indices = set()
800
+ pair_groups = {} # Maps group_id to set of indices
801
+
802
+ for tool_call_id, (call_idx, response_idx) in tool_call_pairs.items():
803
+ group_id = f"pair_{tool_call_id}"
804
+ pair_groups[group_id] = {call_idx, response_idx}
805
+ protected_indices.update({call_idx, response_idx})
806
+
807
+ # Always protect system message and recent critical messages
808
+ for i, meta in enumerate(self.message_metadata):
809
+ if meta.importance == MessageImportance.CRITICAL:
810
+ protected_indices.add(i)
811
+
812
+ # Build optimized message list
813
+ optimized_messages = []
814
+ optimized_metadata = []
815
+ tokens_used = 0
816
+ tokens_saved = 0
817
+ messages_removed = 0
818
+ messages_summarized = 0
819
+
820
+ # Process messages in order, respecting pairs and importance
821
+ i = 0
822
+ while i < len(messages):
823
+ msg = messages[i]
824
+ meta = self.message_metadata[i]
825
+ msg_tokens = self._count_message_tokens(msg, token_counter)
826
+
827
+ # Check if this message is part of a protected pair
828
+ current_group = None
829
+ for group_id, indices in pair_groups.items():
830
+ if i in indices:
831
+ current_group = group_id
832
+ break
833
+
834
+ if current_group:
835
+ # Process the entire group
836
+ group_indices = sorted(pair_groups[current_group])
837
+ group_messages = [messages[idx] for idx in group_indices]
838
+ group_metadata = [self.message_metadata[idx] for idx in group_indices]
839
+ group_tokens = sum(self._count_message_tokens(messages[idx], token_counter) for idx in group_indices)
840
+
841
+ # Check if we should keep this group
842
+ group_importance_values = [self.message_metadata[idx].importance.value for idx in group_indices]
843
+ group_importance = max(group_importance_values, key=lambda x:
844
+ {"critical": 4, "high": 3, "medium": 2, "low": 1, "temp": 0}.get(x, 0)
845
+ )
846
+ should_keep_group = (
847
+ group_importance == MessageImportance.CRITICAL.value or
848
+ self.strategy.should_keep_message(msg, meta, context)
849
+ )
850
+
851
+ if should_keep_group and tokens_used + group_tokens <= self.target_tokens:
852
+ # Keep the entire group
853
+ optimized_messages.extend(group_messages)
854
+ optimized_metadata.extend(group_metadata)
855
+ tokens_used += group_tokens
856
+ self.logger.debug(f"Kept tool call pair group: {group_indices}")
857
+ else:
858
+ # Skip the entire group
859
+ tokens_saved += group_tokens
860
+ messages_removed += len(group_indices)
861
+ self.logger.debug(f"Removed tool call pair group: {group_indices}")
862
+
863
+ # Skip to after this group
864
+ i = max(group_indices) + 1
865
+ continue
866
+
867
+ # Single message processing
868
+ if i in protected_indices:
869
+ # Always keep protected messages
870
+ optimized_messages.append(msg)
871
+ optimized_metadata.append(meta)
872
+ tokens_used += msg_tokens
873
+ elif self.strategy.should_keep_message(msg, meta, context) and tokens_used + msg_tokens <= self.target_tokens:
874
+ # Keep this message
875
+ optimized_messages.append(msg)
876
+ optimized_metadata.append(meta)
877
+ tokens_used += msg_tokens
878
+ elif self.enable_summarization and meta.can_summarize and not meta.summary:
879
+ # Try to summarize
880
+ summary = self._summarize_message(msg)
881
+ summary_tokens = token_counter(summary)
882
+
883
+ if tokens_used + summary_tokens <= self.target_tokens:
884
+ # Create summarized message
885
+ summarized_msg = msg.copy()
886
+ summarized_msg['content'] = summary
887
+ optimized_messages.append(summarized_msg)
888
+
889
+ # Update metadata
890
+ meta.summary = summary
891
+ optimized_metadata.append(meta)
892
+
893
+ tokens_used += summary_tokens
894
+ tokens_saved += msg_tokens - summary_tokens
895
+ messages_summarized += 1
896
+ else:
897
+ # Skip this message
898
+ tokens_saved += msg_tokens
899
+ messages_removed += 1
900
+ else:
901
+ # Skip this message
902
+ tokens_saved += msg_tokens
903
+ messages_removed += 1
904
+
905
+ i += 1
906
+
907
+ # Update metadata list
908
+ self.message_metadata = optimized_metadata
909
+
910
+ # Update statistics
911
+ self.stats['messages_removed'] += messages_removed
912
+ self.stats['messages_summarized'] += messages_summarized
913
+ self.stats['tokens_saved'] += tokens_saved
914
+ self.stats['memory_optimizations'] += 1
915
+
916
+ optimization_info = {
917
+ 'action': 'optimized',
918
+ 'original_tokens': total_tokens,
919
+ 'final_tokens': tokens_used,
920
+ 'tokens_saved': tokens_saved,
921
+ 'messages_removed': messages_removed,
922
+ 'messages_summarized': messages_summarized,
923
+ 'memory_pressure_before': memory_pressure,
924
+ 'memory_pressure_after': self.calculate_memory_pressure(tokens_used),
925
+ 'tool_pairs_preserved': len(tool_call_pairs)
926
+ }
927
+
928
+ self.logger.info(f"Memory optimization completed: {optimization_info}")
929
+
930
+ # Final validation: ensure tool call integrity is maintained
931
+ final_pairs = self._tool_call_pairs
932
+ if len(final_pairs) != len([pair for pair in tool_call_pairs.values() if all(idx < len(optimized_messages) for idx in pair)]):
933
+ self.logger.warning("Tool call/response integrity may be compromised")
934
+
935
+ return optimized_messages, optimization_info
936
+
937
+ def _summarize_message(self, message: Dict[str, Any]) -> str:
938
+ """Create a summary of a message."""
939
+ content = str(message.get('content', ''))
940
+ role = message.get('role', '')
941
+
942
+ # Simple summarization - could be enhanced with LLM-based summarization
943
+ if role == 'tool':
944
+ tool_name = message.get('name', 'unknown')
945
+ if len(content) > 200:
946
+ return f"[SUMMARY] Tool {tool_name} executed: {content[:100]}... [truncated]"
947
+ return content
948
+
949
+ if role == 'assistant' and len(content) > 300:
950
+ return f"[SUMMARY] Assistant response: {content[:150]}... [truncated]"
951
+
952
+ if len(content) > 200:
953
+ return f"[SUMMARY] {content[:100]}... [truncated]"
954
+
955
+ return content
956
+
957
+ def get_memory_stats(self) -> Dict[str, Any]:
958
+ """Get memory management statistics."""
959
+ return {
960
+ **self.stats,
961
+ 'active_tasks': len(self.active_tasks),
962
+ 'completed_tasks': len(self.completed_tasks),
963
+ 'total_messages': len(self.message_metadata),
964
+ 'critical_messages': sum(1 for m in self.message_metadata if m.importance == MessageImportance.CRITICAL),
965
+ 'error_messages': sum(1 for m in self.message_metadata if m.is_error),
966
+ 'resolved_errors': sum(1 for m in self.message_metadata if m.is_error and m.error_resolved)
967
+ }
968
+
969
+ def reset_stats(self) -> None:
970
+ """Reset memory management statistics."""
971
+ self.stats = {
972
+ 'messages_removed': 0,
973
+ 'messages_summarized': 0,
974
+ 'tokens_saved': 0,
975
+ 'memory_optimizations': 0
976
+ }
977
+
978
+ def clear_completed_tasks(self) -> None:
979
+ """Clear metadata for completed tasks to free up memory."""
980
+ # Remove metadata for completed, non-critical messages
981
+ kept_metadata = []
982
+ removed_count = 0
983
+
984
+ for metadata in self.message_metadata:
985
+ if (metadata.task_completed and
986
+ metadata.importance not in [MessageImportance.CRITICAL, MessageImportance.HIGH] and
987
+ time.time() - metadata.created_at > 1800): # 30 minutes old
988
+ removed_count += 1
989
+ else:
990
+ kept_metadata.append(metadata)
991
+
992
+ self.message_metadata = kept_metadata
993
+ self.logger.info(f"Cleared {removed_count} completed task metadata entries")
994
+
995
+ def to_dict(self) -> Dict[str, Any]:
996
+ """Serialize memory manager state."""
997
+ return {
998
+ 'max_tokens': self.max_tokens,
999
+ 'target_tokens': self.target_tokens,
1000
+ 'enable_summarization': self.enable_summarization,
1001
+ 'active_tasks': list(self.active_tasks),
1002
+ 'completed_tasks': list(self.completed_tasks),
1003
+ 'conversation_summary': self.conversation_summary,
1004
+ 'task_summaries': self.task_summaries,
1005
+ 'stats': self.stats,
1006
+ 'message_metadata': [
1007
+ {
1008
+ 'message_type': meta.message_type.value,
1009
+ 'importance': meta.importance.value,
1010
+ 'created_at': meta.created_at,
1011
+ 'token_count': meta.token_count,
1012
+ 'is_error': meta.is_error,
1013
+ 'error_resolved': meta.error_resolved,
1014
+ 'part_of_task': meta.part_of_task,
1015
+ 'task_completed': meta.task_completed,
1016
+ 'can_summarize': meta.can_summarize,
1017
+ 'summary': meta.summary,
1018
+ 'related_messages': meta.related_messages,
1019
+ 'tool_call_id': meta.tool_call_id
1020
+ }
1021
+ for meta in self.message_metadata
1022
+ ]
1023
+ }
1024
+
1025
+ @classmethod
1026
+ def from_dict(
1027
+ cls,
1028
+ data: Dict[str, Any],
1029
+ strategy: MemoryStrategy = None,
1030
+ logger: Optional[logging.Logger] = None
1031
+ ) -> 'MemoryManager':
1032
+ """Deserialize memory manager state."""
1033
+ manager = cls(
1034
+ max_tokens=data.get('max_tokens', 8000),
1035
+ target_tokens=data.get('target_tokens', 6000),
1036
+ strategy=strategy,
1037
+ enable_summarization=data.get('enable_summarization', True),
1038
+ logger=logger
1039
+ )
1040
+
1041
+ manager.active_tasks = set(data.get('active_tasks', []))
1042
+ manager.completed_tasks = set(data.get('completed_tasks', []))
1043
+ manager.conversation_summary = data.get('conversation_summary')
1044
+ manager.task_summaries = data.get('task_summaries', {})
1045
+ manager.stats = data.get('stats', manager.stats)
1046
+
1047
+ # Restore message metadata
1048
+ metadata_list = data.get('message_metadata', [])
1049
+ manager.message_metadata = [
1050
+ MessageMetadata(
1051
+ message_type=MessageType(meta['message_type']),
1052
+ importance=MessageImportance(meta['importance']),
1053
+ created_at=meta['created_at'],
1054
+ token_count=meta['token_count'],
1055
+ is_error=meta['is_error'],
1056
+ error_resolved=meta['error_resolved'],
1057
+ part_of_task=meta['part_of_task'],
1058
+ task_completed=meta['task_completed'],
1059
+ can_summarize=meta['can_summarize'],
1060
+ summary=meta['summary'],
1061
+ related_messages=meta['related_messages'],
1062
+ tool_call_id=meta['tool_call_id']
1063
+ )
1064
+ for meta in metadata_list
1065
+ ]
1066
+
1067
+ return manager
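
For orientation, a minimal usage sketch of the new memory manager follows. It is not part of the published diff: it assumes the file shown above is saved locally as memory_manager.py, and the hand-written messages, tiny token budgets, and whitespace-based token counter are illustrative stand-ins only.

# Minimal usage sketch (not part of the package diff). Assumes the module
# above is saved locally as memory_manager.py; the messages, token budgets,
# and whitespace token counter below are assumptions for illustration.
from memory_manager import AggressiveStrategy, MemoryManager


def rough_token_counter(text: str) -> int:
    # Crude stand-in for a real tokenizer: one token per whitespace-delimited word.
    return len(str(text).split())


messages = [
    {"role": "system", "content": "You are a helpful agent."},
    {"role": "user", "content": "Summarize the latest report."},
    # The assistant issues a tool call with an empty content field ...
    {"role": "assistant", "content": "",
     "tool_calls": [{"id": "call_1",
                     "function": {"name": "fetch_report", "arguments": "{}"}}]},
    # ... and the paired tool response is an error, so both messages are
    # synchronized to the same importance level and kept or dropped together.
    {"role": "tool", "tool_call_id": "call_1", "name": "fetch_report",
     "content": "Error: report service unavailable"},
]

# Deliberately small budgets so the optimizer has something to do.
manager = MemoryManager(max_tokens=30, target_tokens=15,
                        strategy=AggressiveStrategy())
optimized, info = manager.optimize_messages(messages, rough_token_counter)

print(info)                        # e.g. {'action': 'optimized', ...}
print(manager.get_memory_stats())  # counters for removed/summarized messages

With a realistic token budget the same call returns {'action': 'none', 'reason': 'within_limits'} and leaves the history untouched; the strategy classes (Conservative, Balanced, Aggressive) only differ in how early and how aggressively they start dropping or summarizing once the target is exceeded.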