massgen-0.1.4-py3-none-any.whl → massgen-0.1.5-py3-none-any.whl

This diff shows the content of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of massgen might be problematic.

Files changed (46)
  1. massgen/__init__.py +1 -1
  2. massgen/chat_agent.py +340 -20
  3. massgen/cli.py +326 -19
  4. massgen/configs/README.md +52 -10
  5. massgen/configs/memory/gpt5mini_gemini_baseline_research_to_implementation.yaml +94 -0
  6. massgen/configs/memory/gpt5mini_gemini_context_window_management.yaml +187 -0
  7. massgen/configs/memory/gpt5mini_gemini_research_to_implementation.yaml +127 -0
  8. massgen/configs/memory/gpt5mini_high_reasoning_gemini.yaml +107 -0
  9. massgen/configs/memory/single_agent_compression_test.yaml +64 -0
  10. massgen/configs/tools/custom_tools/multimodal_tools/playwright_with_img_understanding.yaml +98 -0
  11. massgen/configs/tools/custom_tools/multimodal_tools/understand_video_example.yaml +54 -0
  12. massgen/memory/README.md +277 -0
  13. massgen/memory/__init__.py +26 -0
  14. massgen/memory/_base.py +193 -0
  15. massgen/memory/_compression.py +237 -0
  16. massgen/memory/_context_monitor.py +211 -0
  17. massgen/memory/_conversation.py +255 -0
  18. massgen/memory/_fact_extraction_prompts.py +333 -0
  19. massgen/memory/_mem0_adapters.py +257 -0
  20. massgen/memory/_persistent.py +687 -0
  21. massgen/memory/docker-compose.qdrant.yml +36 -0
  22. massgen/memory/docs/DESIGN.md +388 -0
  23. massgen/memory/docs/QUICKSTART.md +409 -0
  24. massgen/memory/docs/SUMMARY.md +319 -0
  25. massgen/memory/docs/agent_use_memory.md +408 -0
  26. massgen/memory/docs/orchestrator_use_memory.md +586 -0
  27. massgen/memory/examples.py +237 -0
  28. massgen/orchestrator.py +207 -7
  29. massgen/tests/memory/test_agent_compression.py +174 -0
  30. massgen/tests/memory/test_context_window_management.py +286 -0
  31. massgen/tests/memory/test_force_compression.py +154 -0
  32. massgen/tests/memory/test_simple_compression.py +147 -0
  33. massgen/tests/test_agent_memory.py +534 -0
  34. massgen/tests/test_conversation_memory.py +382 -0
  35. massgen/tests/test_orchestrator_memory.py +620 -0
  36. massgen/tests/test_persistent_memory.py +435 -0
  37. massgen/token_manager/token_manager.py +6 -0
  38. massgen/tools/__init__.py +8 -0
  39. massgen/tools/_planning_mcp_server.py +520 -0
  40. massgen/tools/planning_dataclasses.py +434 -0
  41. {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/METADATA +109 -76
  42. {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/RECORD +46 -12
  43. {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/WHEEL +0 -0
  44. {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/entry_points.txt +0 -0
  45. {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/licenses/LICENSE +0 -0
  46. {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/top_level.txt +0 -0
massgen/memory/_compression.py
@@ -0,0 +1,237 @@
+ # -*- coding: utf-8 -*-
+ """
+ Context Window Compression
+
+ Automatically compresses conversation history when the context window fills up.
+ Since messages are already recorded to persistent memory after each turn,
+ compression simply removes old messages from active context while keeping
+ them accessible via semantic search.
+ """
+
+ from typing import Any, Callable, Dict, List, Optional
+
+ from ..logger_config import logger
+ from ..token_manager.token_manager import TokenCostCalculator
+ from ._conversation import ConversationMemory
+ from ._persistent import PersistentMemoryBase
+
+
+ class CompressionStats:
+     """Statistics about a compression operation."""
+
+     def __init__(
+         self,
+         messages_removed: int = 0,
+         tokens_removed: int = 0,
+         messages_kept: int = 0,
+         tokens_kept: int = 0,
+     ):
+         self.messages_removed = messages_removed
+         self.tokens_removed = tokens_removed
+         self.messages_kept = messages_kept
+         self.tokens_kept = tokens_kept
+
+
+ class ContextCompressor:
+     """
+     Compresses conversation history when the context window fills up.
+
+     Strategy:
+     - Messages are already recorded to persistent_memory after each turn
+     - Compression removes old messages from conversation_memory
+     - Recent messages stay in active context
+     - Old messages remain accessible via semantic retrieval
+
+     Features:
+     - Token-aware compression (not just message count)
+     - Preserves system messages
+     - Keeps most recent messages
+     - Detailed compression logging
+
+     Example:
+         >>> compressor = ContextCompressor(
+         ...     token_calculator=TokenCostCalculator(),
+         ...     conversation_memory=conversation_memory,
+         ...     persistent_memory=persistent_memory
+         ... )
+         >>>
+         >>> stats = await compressor.compress_if_needed(
+         ...     messages=messages,
+         ...     current_tokens=96000,
+         ...     target_tokens=51200
+         ... )
+     """
+
+     def __init__(
+         self,
+         token_calculator: TokenCostCalculator,
+         conversation_memory: ConversationMemory,
+         persistent_memory: Optional[PersistentMemoryBase] = None,
+         on_compress: Optional[Callable[[CompressionStats], None]] = None,
+     ):
+         """
+         Initialize context compressor.
+
+         Args:
+             token_calculator: Calculator for token estimation
+             conversation_memory: Conversation memory to compress
+             persistent_memory: Optional persistent memory (for logging purposes)
+             on_compress: Optional callback called after compression
+         """
+         self.token_calculator = token_calculator
+         self.conversation_memory = conversation_memory
+         self.persistent_memory = persistent_memory
+         self.on_compress = on_compress
+
+         # Stats tracking
+         self.total_compressions = 0
+         self.total_messages_removed = 0
+         self.total_tokens_removed = 0
+
+     async def compress_if_needed(
+         self,
+         messages: List[Dict[str, Any]],
+         current_tokens: int,
+         target_tokens: int,
+         should_compress: Optional[bool] = None,
+     ) -> Optional[CompressionStats]:
+         """
+         Compress messages if needed.
+
+         Args:
+             messages: Current conversation messages
+             current_tokens: Current token count
+             target_tokens: Target token count after compression
+             should_compress: Optional explicit compression flag.
+                 If None, compresses only if current_tokens > target_tokens
+
+         Returns:
+             CompressionStats if compression occurred, None otherwise
+         """
+         # Determine if we need to compress
+         if should_compress is None:
+             should_compress = current_tokens > target_tokens
+
+         if not should_compress:
+             return None
+
+         # Select messages to keep
+         messages_to_keep = self._select_messages_to_keep(
+             messages=messages,
+             target_tokens=target_tokens,
+         )
+
+         if len(messages_to_keep) >= len(messages):
+             # No compression needed (already under target)
+             logger.debug("All messages fit within target, skipping compression")
+             return None
+
+         # Calculate stats
+         messages_removed = len(messages) - len(messages_to_keep)
+         messages_to_remove = [msg for msg in messages if msg not in messages_to_keep]
+         tokens_removed = self.token_calculator.estimate_tokens(messages_to_remove)
+         tokens_kept = self.token_calculator.estimate_tokens(messages_to_keep)
+
+         # Update conversation memory
+         try:
+             await self.conversation_memory.clear()
+             await self.conversation_memory.add(messages_to_keep)
+         except Exception as e:
+             logger.error(f"Failed to update conversation memory during compression: {e}")
+             return None
+
+         # Log compression result
+         if self.persistent_memory:
+             logger.info(
+                 f"📦 Context compressed: Removed {messages_removed} old messages "
+                 f"({tokens_removed:,} tokens) from active context.\n"
+                 f"   Kept {len(messages_to_keep)} recent messages ({tokens_kept:,} tokens).\n"
+                 f"   Old messages remain accessible via semantic search.",
+             )
+         else:
+             logger.warning(
+                 f"⚠️ Context compressed: Removed {messages_removed} old messages "
+                 f"({tokens_removed:,} tokens) from active context.\n"
+                 f"   Kept {len(messages_to_keep)} recent messages ({tokens_kept:,} tokens).\n"
+                 f"   No persistent memory - old messages NOT retrievable.",
+             )
+
+         # Update stats
+         self.total_compressions += 1
+         self.total_messages_removed += messages_removed
+         self.total_tokens_removed += tokens_removed
+
+         # Create stats object
+         stats = CompressionStats(
+             messages_removed=messages_removed,
+             tokens_removed=tokens_removed,
+             messages_kept=len(messages_to_keep),
+             tokens_kept=tokens_kept,
+         )
+
+         # Trigger callback if provided
+         if self.on_compress:
+             self.on_compress(stats)
+
+         return stats
+
+     def _select_messages_to_keep(
+         self,
+         messages: List[Dict[str, Any]],
+         target_tokens: int,
+     ) -> List[Dict[str, Any]]:
+         """
+         Select which messages to keep in active context.
+
+         Strategy:
+         1. Always keep system messages at the start
+         2. Keep the most recent messages that fit in target_tokens
+         3. Remove everything in between
+
+         Args:
+             messages: All messages in conversation
+             target_tokens: Target token budget for kept messages
+
+         Returns:
+             List of messages to keep in conversation_memory
+         """
+         if not messages:
+             return []
+
+         # Separate system messages from others
+         system_messages = []
+         non_system_messages = []
+
+         for msg in messages:
+             if msg.get("role") == "system":
+                 system_messages.append(msg)
+             else:
+                 non_system_messages.append(msg)
+
+         # Start with system messages in kept list
+         messages_to_keep = system_messages.copy()
+         tokens_so_far = self.token_calculator.estimate_tokens(system_messages)
+
+         # Work backwards from most recent, adding messages until we hit target
+         recent_messages_to_keep = []
+         for msg in reversed(non_system_messages):
+             msg_tokens = self.token_calculator.estimate_tokens([msg])
+             if tokens_so_far + msg_tokens <= target_tokens:
+                 tokens_so_far += msg_tokens
+                 recent_messages_to_keep.insert(0, msg)  # Maintain order
+             else:
+                 # Hit token limit, stop here
+                 break
+
+         # Combine: system messages + recent messages
+         messages_to_keep.extend(recent_messages_to_keep)
+
+         return messages_to_keep
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get compression statistics."""
+         return {
+             "total_compressions": self.total_compressions,
+             "total_messages_removed": self.total_messages_removed,
+             "total_tokens_removed": self.total_tokens_removed,
+         }
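
The compressor above is meant to be driven by an external usage check rather than measuring tokens itself. A minimal sketch of that wiring, assuming the ContextWindowMonitor from the next hunk and direct imports of the private modules (the import paths mirror the file list above; whether massgen re-exports these names publicly is not shown in this diff):

from typing import Any, Dict, List

from massgen.memory._compression import ContextCompressor
from massgen.memory._context_monitor import ContextWindowMonitor
from massgen.memory._conversation import ConversationMemory
from massgen.token_manager.token_manager import TokenCostCalculator


async def maybe_compress(messages: List[Dict[str, Any]], memory: ConversationMemory) -> None:
    # The monitor reports current usage and whether the trigger threshold was crossed
    monitor = ContextWindowMonitor(model_name="gpt-4o-mini", provider="openai")
    usage = monitor.log_context_usage(messages)

    # The compressor trims the active context down to the monitor's target budget
    compressor = ContextCompressor(
        token_calculator=TokenCostCalculator(),
        conversation_memory=memory,
    )
    stats = await compressor.compress_if_needed(
        messages=messages,
        current_tokens=usage["current_tokens"],
        target_tokens=usage["target_tokens"],
        should_compress=usage["should_compress"],
    )
    if stats:
        print(f"Removed {stats.messages_removed} messages ({stats.tokens_removed:,} tokens)")

Note that without a persistent_memory backend, compress_if_needed takes the warning branch above: removed messages are gone, not retrievable via semantic search.
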
massgen/memory/_context_monitor.py
@@ -0,0 +1,211 @@
+ # -*- coding: utf-8 -*-
+ """
+ Context Window Monitoring Utility
+
+ Provides logging and tracking for context window usage during agent execution.
+ Helps debug memory and token management by showing real-time context usage.
+ """
+
+ from typing import Any, Dict, List, Optional
+
+ from ..logger_config import logger
+ from ..token_manager.token_manager import TokenCostCalculator
+
+
+ class ContextWindowMonitor:
+     """Monitor and log context window usage during agent execution."""
+
+     def __init__(
+         self,
+         model_name: str,
+         provider: str = "openai",
+         trigger_threshold: float = 0.75,
+         target_ratio: float = 0.40,
+         enabled: bool = True,
+     ):
+         """
+         Initialize context window monitor.
+
+         Args:
+             model_name: Name of the model (e.g., "gpt-4o-mini")
+             provider: Provider name (e.g., "openai", "anthropic")
+             trigger_threshold: Fraction (0-1) of the window at which to warn about context usage
+             target_ratio: Target fraction of the window after compression
+             enabled: Whether to enable logging
+         """
+         self.model_name = model_name
+         self.provider = provider
+         self.trigger_threshold = trigger_threshold
+         self.target_ratio = target_ratio
+         self.enabled = enabled
+
+         # Get model pricing info to determine context window size
+         self.calculator = TokenCostCalculator()
+         self.pricing = self.calculator.get_model_pricing(provider, model_name)
+
+         if self.pricing and self.pricing.context_window:
+             self.context_window = self.pricing.context_window
+         else:
+             # Default fallback
+             self.context_window = 128000  # Common default
+             logger.warning(
+                 f"Could not determine context window for {provider}/{model_name}, "
+                 f"using default {self.context_window}",
+             )
+
+         # Tracking
+         self.total_input_tokens = 0
+         self.total_output_tokens = 0
+         self.turn_count = 0
+
+     def log_context_usage(
+         self,
+         messages: List[Dict[str, Any]],
+         turn_number: Optional[int] = None,
+     ) -> Dict[str, Any]:
+         """
+         Log current context window usage.
+
+         Args:
+             messages: Current conversation messages
+             turn_number: Optional turn number to display
+
+         Returns:
+             Dict with usage stats: {
+                 "current_tokens": int,
+                 "max_tokens": int,
+                 "usage_percent": float,
+                 "should_compress": bool,
+                 "target_tokens": int
+             }
+         """
+         if not self.enabled:
+             return {}
+
+         # Estimate tokens in current context
+         current_tokens = self.calculator.estimate_tokens(messages)
+         usage_percent = current_tokens / self.context_window
+         should_compress = usage_percent >= self.trigger_threshold
+         target_tokens = int(self.context_window * self.target_ratio)
+
+         # Build log message
+         turn_str = f" (Turn {turn_number})" if turn_number is not None else ""
+
+         if should_compress:
+             logger.warning(
+                 f"⚠️ Context Window{turn_str}: "
+                 f"{current_tokens:,} / {self.context_window:,} tokens "
+                 f"({usage_percent*100:.1f}%) - Approaching limit!",
+             )
+             logger.warning(
+                 f"   Compression threshold reached. Target after compression: "
+                 f"{target_tokens:,} tokens ({self.target_ratio*100:.0f}%)",
+             )
+         else:
+             logger.info(
+                 f"📊 Context Window{turn_str}: "
+                 f"{current_tokens:,} / {self.context_window:,} tokens "
+                 f"({usage_percent*100:.1f}%)",
+             )
+
+         return {
+             "current_tokens": current_tokens,
+             "max_tokens": self.context_window,
+             "usage_percent": usage_percent,
+             "should_compress": should_compress,
+             "target_tokens": target_tokens,
+             "trigger_threshold": self.trigger_threshold,
+             "target_ratio": self.target_ratio,
+         }
+
+     def log_turn_summary(
+         self,
+         input_tokens: int,
+         output_tokens: int,
+         turn_number: Optional[int] = None,
+     ):
+         """
+         Log summary for a single turn.
+
+         Args:
+             input_tokens: Input tokens for this turn
+             output_tokens: Output tokens for this turn
+             turn_number: Optional turn number
+         """
+         if not self.enabled:
+             return
+
+         self.total_input_tokens += input_tokens
+         self.total_output_tokens += output_tokens
+         self.turn_count += 1
+
+         turn_str = f" {turn_number}" if turn_number is not None else f" {self.turn_count}"
+
+         logger.info(
+             f"Turn{turn_str} tokens: {input_tokens:,} input + {output_tokens:,} output = "
+             f"{input_tokens + output_tokens:,} total",
+         )
+
+     def log_session_summary(self):
+         """Log overall session summary."""
+         if not self.enabled or self.turn_count == 0:
+             return
+
+         total_tokens = self.total_input_tokens + self.total_output_tokens
+         avg_per_turn = total_tokens / self.turn_count
+
+         logger.info("=" * 70)
+         logger.info("📊 Session Summary:")
+         logger.info(f"   Total turns: {self.turn_count}")
+         logger.info(f"   Total input tokens: {self.total_input_tokens:,}")
+         logger.info(f"   Total output tokens: {self.total_output_tokens:,}")
+         logger.info(f"   Total tokens: {total_tokens:,}")
+         logger.info(f"   Average per turn: {avg_per_turn:,.0f} tokens")
+         logger.info(f"   Context window: {self.context_window:,} tokens")
+         logger.info(f"   Cumulative usage: {(total_tokens/self.context_window)*100:.1f}%")
+         logger.info("=" * 70)
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get current monitoring stats."""
+         total_tokens = self.total_input_tokens + self.total_output_tokens
+
+         return {
+             "turn_count": self.turn_count,
+             "total_input_tokens": self.total_input_tokens,
+             "total_output_tokens": self.total_output_tokens,
+             "total_tokens": total_tokens,
+             "context_window": self.context_window,
+             "avg_tokens_per_turn": total_tokens / self.turn_count if self.turn_count > 0 else 0,
+             # Cumulative session tokens relative to the window, not a per-turn peak
+             "peak_usage_percent": total_tokens / self.context_window if self.context_window > 0 else 0,
+         }
+
+
+ def create_monitor_from_config(
+     config: Dict[str, Any],
+     model_name: str,
+     provider: str = "openai",
+ ) -> ContextWindowMonitor:
+     """
+     Create a context window monitor from YAML config.
+
+     Args:
+         config: Config dict (should have a 'memory' section)
+         model_name: Model name
+         provider: Provider name
+
+     Returns:
+         ContextWindowMonitor instance
+     """
+     memory_config = config.get("memory", {})
+     compression_config = memory_config.get("compression", {})
+
+     trigger_threshold = compression_config.get("trigger_threshold", 0.75)
+     target_ratio = compression_config.get("target_ratio", 0.40)
+     enabled = memory_config.get("enabled", True)
+
+     return ContextWindowMonitor(
+         model_name=model_name,
+         provider=provider,
+         trigger_threshold=trigger_threshold,
+         target_ratio=target_ratio,
+         enabled=enabled,
+     )
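
create_monitor_from_config reads only the memory.enabled and memory.compression.{trigger_threshold, target_ratio} keys. A sketch of an equivalent plain-dict config (the shipped configs/memory/*.yaml files may nest more than this; only these keys are taken from the code above):

config = {
    "memory": {
        "enabled": True,
        "compression": {
            "trigger_threshold": 0.75,  # warn and flag for compression at 75% of the window
            "target_ratio": 0.40,       # aim for 40% of the window after compression
        },
    },
}

monitor = create_monitor_from_config(config, model_name="gpt-4o-mini", provider="openai")
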
massgen/memory/_conversation.py
@@ -0,0 +1,255 @@
+ # -*- coding: utf-8 -*-
+ """
+ Conversation memory implementation for MassGen.
+
+ This module provides in-memory storage for conversation messages, optimized
+ for quick access during active chat sessions.
+ """
+
+ import uuid
+ from typing import Any, Dict, Iterable, List, Optional, Union
+
+ from ._base import MemoryBase
+
+
+ class ConversationMemory(MemoryBase):
+     """
+     In-memory storage for conversation messages.
+
+     This memory type is designed for short-term storage of ongoing conversations.
+     It keeps messages in a simple list structure for fast access and iteration.
+
+     Features:
+     - Fast in-memory access
+     - Duplicate detection based on message IDs
+     - Index-based deletion
+     - State serialization for session persistence
+
+     Example:
+         >>> memory = ConversationMemory()
+         >>> await memory.add({"role": "user", "content": "Hello"})
+         >>> messages = await memory.get_messages()
+         >>> print(len(messages))  # 1
+     """
+
+     def __init__(self) -> None:
+         """Initialize an empty conversation memory."""
+         super().__init__()
+         self.messages: List[Dict[str, Any]] = []
+
+     def state_dict(self) -> Dict[str, Any]:
+         """
+         Serialize memory state to a dictionary.
+
+         Returns:
+             Dictionary with 'messages' key containing all stored messages
+         """
+         return {
+             "messages": [msg.copy() for msg in self.messages],
+         }
+
+     def load_state_dict(
+         self,
+         state_dict: Dict[str, Any],
+         strict: bool = True,
+     ) -> None:
+         """
+         Load memory state from a serialized dictionary.
+
+         Args:
+             state_dict: Dictionary containing 'messages' key
+             strict: If True, validates the state dictionary structure
+
+         Raises:
+             ValueError: If strict=True and state_dict is invalid
+         """
+         if strict and "messages" not in state_dict:
+             raise ValueError(
+                 "State dictionary must contain 'messages' key when strict=True",
+             )
+
+         self.messages = []
+         for msg_data in state_dict.get("messages", []):
+             # Ensure each message is a proper dictionary
+             if isinstance(msg_data, dict):
+                 self.messages.append(msg_data.copy())
+
+     async def size(self) -> int:
+         """
+         Get the number of messages in memory.
+
+         Returns:
+             Count of stored messages
+         """
+         return len(self.messages)
+
+     async def retrieve(self, *args: Any, **kwargs: Any) -> None:
+         """
+         Retrieve is not supported for basic conversation memory.
+
+         Use get_messages() to access all messages directly.
+
+         Raises:
+             NotImplementedError: Always, as basic retrieval is not supported
+         """
+         raise NotImplementedError(
+             f"The retrieve method is not implemented in {self.__class__.__name__}. "
+             "Use get_messages() to access conversation history directly.",
+         )
+
+     async def delete(self, index: Union[Iterable[int], int]) -> None:
+         """
+         Delete message(s) by index position.
+
+         Args:
+             index: Single index or iterable of indices to delete
+
+         Raises:
+             IndexError: If any index is out of range
+
+         Example:
+             >>> await memory.delete(0)  # Delete first message
+             >>> await memory.delete([1, 3, 5])  # Delete multiple messages
+         """
+         if isinstance(index, int):
+             index = [index]
+
+         # Materialize once so generators are not exhausted during validation
+         indices = set(index)
+
+         # Validate all indices first
+         invalid_indices = [i for i in indices if i < 0 or i >= len(self.messages)]
+
+         if invalid_indices:
+             raise IndexError(
+                 f"The following indices do not exist: {sorted(invalid_indices)}. "
+                 f"Valid range is 0-{len(self.messages) - 1}",
+             )
+
+         # Create new list excluding deleted indices
+         self.messages = [msg for idx, msg in enumerate(self.messages) if idx not in indices]
+
+     async def add(
+         self,
+         messages: Union[List[Dict[str, Any]], Dict[str, Any], None],
+         allow_duplicates: bool = False,
+     ) -> None:
+         """
+         Add one or more messages to the conversation memory.
+
+         Args:
+             messages: Single message dict or list of message dicts to add.
+                 Each message should have at minimum a 'role' and 'content'.
+             allow_duplicates: If False, skip messages with duplicate IDs
+
+         Raises:
+             TypeError: If messages are not in the expected format
+
+         Example:
+             >>> # Add single message
+             >>> await memory.add({"role": "user", "content": "Hello"})
+             >>>
+             >>> # Add multiple messages
+             >>> await memory.add([
+             ...     {"role": "user", "content": "Hi"},
+             ...     {"role": "assistant", "content": "Hello!"}
+             ... ])
+         """
+         if messages is None:
+             return
+
+         # Normalize to list
+         if isinstance(messages, dict):
+             messages = [messages]
+
+         if not isinstance(messages, list):
+             raise TypeError(
+                 f"Messages should be a list of dicts or a single dict, "
+                 f"but got {type(messages)}",
+             )
+
+         # Validate each message
+         for msg in messages:
+             if not isinstance(msg, dict):
+                 raise TypeError(
+                     f"Each message should be a dictionary, but got {type(msg)}",
+                 )
+
+         # Add message IDs if not present (for duplicate detection)
+         processed_messages = []
+         for msg in messages:
+             msg_copy = msg.copy()
+             if "id" not in msg_copy:
+                 msg_copy["id"] = f"msg_{uuid.uuid4().hex[:12]}"
+             processed_messages.append(msg_copy)
+
+         # Filter duplicates if needed
+         if not allow_duplicates:
+             existing_ids = {msg.get("id") for msg in self.messages if "id" in msg}
+             processed_messages = [msg for msg in processed_messages if msg.get("id") not in existing_ids]
+
+         self.messages.extend(processed_messages)
+
+     async def get_messages(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
+         """
+         Get messages in the conversation, optionally limited to the most recent.
+
+         Args:
+             limit: Optional limit on number of most recent messages to return
+
+         Returns:
+             List of message dictionaries (copies, not references)
+
+         Example:
+             >>> # Get all messages
+             >>> all_msgs = await memory.get_messages()
+             >>>
+             >>> # Get last 10 messages
+             >>> recent = await memory.get_messages(limit=10)
+         """
+         if limit is not None and limit > 0:
+             return [msg.copy() for msg in self.messages[-limit:]]
+         return [msg.copy() for msg in self.messages]
+
+     async def clear(self) -> None:
+         """
+         Remove all messages from memory.
+
+         Example:
+             >>> await memory.clear()
+             >>> assert await memory.size() == 0
+         """
+         self.messages = []
+
+     async def get_last_message(self) -> Optional[Dict[str, Any]]:
+         """
+         Get the most recent message.
+
+         Returns:
+             Last message dictionary, or None if memory is empty
+         """
+         if not self.messages:
+             return None
+         return self.messages[-1].copy()
+
+     async def get_messages_by_role(self, role: str) -> List[Dict[str, Any]]:
+         """
+         Filter messages by role.
+
+         Args:
+             role: Role to filter by (e.g., 'user', 'assistant', 'system')
+
+         Returns:
+             List of messages with matching role
+         """
+         return [msg.copy() for msg in self.messages if msg.get("role") == role]
+
+     async def truncate_to_size(self, max_messages: int) -> None:
+         """
+         Keep only the most recent messages, up to max_messages.
+
+         This is useful for managing memory usage in long conversations.
+
+         Args:
+             max_messages: Maximum number of messages to keep
+
+         Example:
+             >>> # Keep only last 100 messages
+             >>> await memory.truncate_to_size(100)
+         """
+         if max_messages <= 0:
+             # A non-positive limit keeps nothing (slicing with -0 would keep everything)
+             self.messages = []
+         elif max_messages < len(self.messages):
+             self.messages = self.messages[-max_messages:]
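
Since ConversationMemory exposes state_dict()/load_state_dict() for session persistence, here is a short round-trip sketch using only the methods defined above (the session.json path and the direct private-module import are illustrative):

import asyncio
import json

from massgen.memory._conversation import ConversationMemory


async def main() -> None:
    memory = ConversationMemory()
    await memory.add([
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ])

    # Serialize, then restore into a fresh instance
    with open("session.json", "w") as f:
        json.dump(memory.state_dict(), f)

    restored = ConversationMemory()
    with open("session.json") as f:
        restored.load_state_dict(json.load(f))

    assert await restored.size() == await memory.size()

    # Re-adding the same payload is a no-op: the auto-assigned message IDs
    # already exist, so duplicate detection filters them out
    await restored.add(await memory.get_messages())
    assert await restored.size() == 2


asyncio.run(main())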