massgen 0.1.4__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of massgen might be problematic.
- massgen/__init__.py +1 -1
- massgen/chat_agent.py +340 -20
- massgen/cli.py +326 -19
- massgen/configs/README.md +52 -10
- massgen/configs/memory/gpt5mini_gemini_baseline_research_to_implementation.yaml +94 -0
- massgen/configs/memory/gpt5mini_gemini_context_window_management.yaml +187 -0
- massgen/configs/memory/gpt5mini_gemini_research_to_implementation.yaml +127 -0
- massgen/configs/memory/gpt5mini_high_reasoning_gemini.yaml +107 -0
- massgen/configs/memory/single_agent_compression_test.yaml +64 -0
- massgen/configs/tools/custom_tools/multimodal_tools/playwright_with_img_understanding.yaml +98 -0
- massgen/configs/tools/custom_tools/multimodal_tools/understand_video_example.yaml +54 -0
- massgen/memory/README.md +277 -0
- massgen/memory/__init__.py +26 -0
- massgen/memory/_base.py +193 -0
- massgen/memory/_compression.py +237 -0
- massgen/memory/_context_monitor.py +211 -0
- massgen/memory/_conversation.py +255 -0
- massgen/memory/_fact_extraction_prompts.py +333 -0
- massgen/memory/_mem0_adapters.py +257 -0
- massgen/memory/_persistent.py +687 -0
- massgen/memory/docker-compose.qdrant.yml +36 -0
- massgen/memory/docs/DESIGN.md +388 -0
- massgen/memory/docs/QUICKSTART.md +409 -0
- massgen/memory/docs/SUMMARY.md +319 -0
- massgen/memory/docs/agent_use_memory.md +408 -0
- massgen/memory/docs/orchestrator_use_memory.md +586 -0
- massgen/memory/examples.py +237 -0
- massgen/orchestrator.py +207 -7
- massgen/tests/memory/test_agent_compression.py +174 -0
- massgen/tests/memory/test_context_window_management.py +286 -0
- massgen/tests/memory/test_force_compression.py +154 -0
- massgen/tests/memory/test_simple_compression.py +147 -0
- massgen/tests/test_agent_memory.py +534 -0
- massgen/tests/test_conversation_memory.py +382 -0
- massgen/tests/test_orchestrator_memory.py +620 -0
- massgen/tests/test_persistent_memory.py +435 -0
- massgen/token_manager/token_manager.py +6 -0
- massgen/tools/__init__.py +8 -0
- massgen/tools/_planning_mcp_server.py +520 -0
- massgen/tools/planning_dataclasses.py +434 -0
- {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/METADATA +109 -76
- {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/RECORD +46 -12
- {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/WHEEL +0 -0
- {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/entry_points.txt +0 -0
- {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/licenses/LICENSE +0 -0
- {massgen-0.1.4.dist-info → massgen-0.1.5.dist-info}/top_level.txt +0 -0
massgen/memory/_compression.py
@@ -0,0 +1,237 @@
# -*- coding: utf-8 -*-
"""
Context Window Compression

Automatically compresses conversation history when the context window fills up.
Since messages are already recorded to persistent memory after each turn,
compression simply removes old messages from active context while keeping
them accessible via semantic search.
"""

from typing import Any, Callable, Dict, List, Optional

from ..logger_config import logger
from ..token_manager.token_manager import TokenCostCalculator
from ._conversation import ConversationMemory
from ._persistent import PersistentMemoryBase


class CompressionStats:
    """Statistics about a compression operation."""

    def __init__(
        self,
        messages_removed: int = 0,
        tokens_removed: int = 0,
        messages_kept: int = 0,
        tokens_kept: int = 0,
    ):
        self.messages_removed = messages_removed
        self.tokens_removed = tokens_removed
        self.messages_kept = messages_kept
        self.tokens_kept = tokens_kept


class ContextCompressor:
    """
    Compresses conversation history when the context window fills up.

    Strategy:
    - Messages are already recorded to persistent_memory after each turn
    - Compression removes old messages from conversation_memory
    - Recent messages stay in active context
    - Old messages remain accessible via semantic retrieval

    Features:
    - Token-aware compression (not just message count)
    - Preserves system messages
    - Keeps most recent messages
    - Detailed compression logging

    Example:
        >>> compressor = ContextCompressor(
        ...     token_calculator=TokenCostCalculator(),
        ...     conversation_memory=conversation_memory,
        ...     persistent_memory=persistent_memory
        ... )
        >>>
        >>> stats = await compressor.compress_if_needed(
        ...     messages=messages,
        ...     current_tokens=96000,
        ...     target_tokens=51200
        ... )
    """

    def __init__(
        self,
        token_calculator: TokenCostCalculator,
        conversation_memory: ConversationMemory,
        persistent_memory: Optional[PersistentMemoryBase] = None,
        on_compress: Optional[Callable[[CompressionStats], None]] = None,
    ):
        """
        Initialize context compressor.

        Args:
            token_calculator: Calculator for token estimation
            conversation_memory: Conversation memory to compress
            persistent_memory: Optional persistent memory (for logging purposes)
            on_compress: Optional callback called after compression
        """
        self.token_calculator = token_calculator
        self.conversation_memory = conversation_memory
        self.persistent_memory = persistent_memory
        self.on_compress = on_compress

        # Stats tracking
        self.total_compressions = 0
        self.total_messages_removed = 0
        self.total_tokens_removed = 0

    async def compress_if_needed(
        self,
        messages: List[Dict[str, Any]],
        current_tokens: int,
        target_tokens: int,
        should_compress: Optional[bool] = None,
    ) -> Optional[CompressionStats]:
        """
        Compress messages if needed.

        Args:
            messages: Current conversation messages
            current_tokens: Current token count
            target_tokens: Target token count after compression
            should_compress: Optional explicit compression flag.
                If None, compresses only if current_tokens > target_tokens

        Returns:
            CompressionStats if compression occurred, None otherwise
        """
        # Determine if we need to compress
        if should_compress is None:
            should_compress = current_tokens > target_tokens

        if not should_compress:
            return None

        # Select messages to keep
        messages_to_keep = self._select_messages_to_keep(
            messages=messages,
            target_tokens=target_tokens,
        )

        if len(messages_to_keep) >= len(messages):
            # No compression needed (already under target)
            logger.debug("All messages fit within target, skipping compression")
            return None

        # Calculate stats
        messages_removed = len(messages) - len(messages_to_keep)
        messages_to_remove = [msg for msg in messages if msg not in messages_to_keep]
        tokens_removed = self.token_calculator.estimate_tokens(messages_to_remove)
        tokens_kept = self.token_calculator.estimate_tokens(messages_to_keep)

        # Update conversation memory
        try:
            await self.conversation_memory.clear()
            await self.conversation_memory.add(messages_to_keep)
        except Exception as e:
            logger.error(f"Failed to update conversation memory during compression: {e}")
            return None

        # Log compression result
        if self.persistent_memory:
            logger.info(
                f"📦 Context compressed: Removed {messages_removed} old messages "
                f"({tokens_removed:,} tokens) from active context.\n"
                f" Kept {len(messages_to_keep)} recent messages ({tokens_kept:,} tokens).\n"
                f" Old messages remain accessible via semantic search.",
            )
        else:
            logger.warning(
                f"⚠️ Context compressed: Removed {messages_removed} old messages "
                f"({tokens_removed:,} tokens) from active context.\n"
                f" Kept {len(messages_to_keep)} recent messages ({tokens_kept:,} tokens).\n"
                f" No persistent memory - old messages NOT retrievable.",
            )

        # Update stats
        self.total_compressions += 1
        self.total_messages_removed += messages_removed
        self.total_tokens_removed += tokens_removed

        # Create stats object
        stats = CompressionStats(
            messages_removed=messages_removed,
            tokens_removed=tokens_removed,
            messages_kept=len(messages_to_keep),
            tokens_kept=tokens_kept,
        )

        # Trigger callback if provided
        if self.on_compress:
            self.on_compress(stats)

        return stats

    def _select_messages_to_keep(
        self,
        messages: List[Dict[str, Any]],
        target_tokens: int,
    ) -> List[Dict[str, Any]]:
        """
        Select which messages to keep in active context.

        Strategy:
        1. Always keep system messages at the start
        2. Keep the most recent messages that fit in target_tokens
        3. Remove everything in between

        Args:
            messages: All messages in conversation
            target_tokens: Target token budget for kept messages

        Returns:
            List of messages to keep in conversation_memory
        """
        if not messages:
            return []

        # Separate system messages from others
        system_messages = []
        non_system_messages = []

        for msg in messages:
            if msg.get("role") == "system":
                system_messages.append(msg)
            else:
                non_system_messages.append(msg)

        # Start with system messages in kept list
        messages_to_keep = system_messages.copy()
        tokens_so_far = self.token_calculator.estimate_tokens(system_messages)

        # Work backwards from most recent, adding messages until we hit target
        recent_messages_to_keep = []
        for msg in reversed(non_system_messages):
            msg_tokens = self.token_calculator.estimate_tokens([msg])
            if tokens_so_far + msg_tokens <= target_tokens:
                tokens_so_far += msg_tokens
                recent_messages_to_keep.insert(0, msg)  # Maintain order
            else:
                # Hit token limit, stop here
                break

        # Combine: system messages + recent messages
        messages_to_keep.extend(recent_messages_to_keep)

        return messages_to_keep

    def get_stats(self) -> Dict[str, Any]:
        """Get compression statistics."""
        return {
            "total_compressions": self.total_compressions,
            "total_messages_removed": self.total_messages_removed,
            "total_tokens_removed": self.total_tokens_removed,
        }
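
For orientation, a minimal usage sketch of the compressor above. It only uses names introduced in this release (ContextCompressor, ConversationMemory, TokenCostCalculator); the main() scaffolding and the token counts are illustrative, not part of the package.

    import asyncio

    from massgen.memory._compression import ContextCompressor
    from massgen.memory._conversation import ConversationMemory
    from massgen.token_manager.token_manager import TokenCostCalculator

    async def main() -> None:
        conversation = ConversationMemory()
        await conversation.add([
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize our project status."},
        ])

        compressor = ContextCompressor(
            token_calculator=TokenCostCalculator(),
            conversation_memory=conversation,
            # persistent_memory omitted: compression will warn that removed
            # messages are not retrievable via semantic search.
            on_compress=lambda s: print(
                f"removed {s.messages_removed} messages / {s.tokens_removed} tokens"
            ),
        )

        # With such a short history everything fits under target_tokens, so
        # compress_if_needed returns None and logs a debug message instead of
        # compressing; in a real session the counts come from a monitor.
        stats = await compressor.compress_if_needed(
            messages=await conversation.get_messages(),
            current_tokens=96_000,  # illustrative values
            target_tokens=51_200,
        )
        print(stats, compressor.get_stats())

    asyncio.run(main())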
massgen/memory/_context_monitor.py
@@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
"""
Context Window Monitoring Utility

Provides logging and tracking for context window usage during agent execution.
Helps debug memory and token management by showing real-time context usage.
"""

from typing import Any, Dict, List, Optional

from ..logger_config import logger
from ..token_manager.token_manager import TokenCostCalculator


class ContextWindowMonitor:
    """Monitor and log context window usage during agent execution."""

    def __init__(
        self,
        model_name: str,
        provider: str = "openai",
        trigger_threshold: float = 0.75,
        target_ratio: float = 0.40,
        enabled: bool = True,
    ):
        """
        Initialize context window monitor.

        Args:
            model_name: Name of the model (e.g., "gpt-4o-mini")
            provider: Provider name (e.g., "openai", "anthropic")
            trigger_threshold: Fraction (0-1) of the context window at which to
                warn about context usage
            target_ratio: Target fraction of the context window after compression
            enabled: Whether to enable logging
        """
        self.model_name = model_name
        self.provider = provider
        self.trigger_threshold = trigger_threshold
        self.target_ratio = target_ratio
        self.enabled = enabled

        # Get model pricing info to determine context window size
        self.calculator = TokenCostCalculator()
        self.pricing = self.calculator.get_model_pricing(provider, model_name)

        if self.pricing and self.pricing.context_window:
            self.context_window = self.pricing.context_window
        else:
            # Default fallback
            self.context_window = 128000  # Common default
            logger.warning(
                f"Could not determine context window for {provider}/{model_name}, " f"using default {self.context_window}",
            )

        # Tracking
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.turn_count = 0

    def log_context_usage(
        self,
        messages: List[Dict[str, Any]],
        turn_number: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Log current context window usage.

        Args:
            messages: Current conversation messages
            turn_number: Optional turn number to display

        Returns:
            Dict with usage stats: {
                "current_tokens": int,
                "max_tokens": int,
                "usage_percent": float,
                "should_compress": bool,
                "target_tokens": int
            }
        """
        if not self.enabled:
            return {}

        # Estimate tokens in current context
        current_tokens = self.calculator.estimate_tokens(messages)
        usage_percent = current_tokens / self.context_window
        should_compress = usage_percent >= self.trigger_threshold
        target_tokens = int(self.context_window * self.target_ratio)

        # Build log message
        turn_str = f" (Turn {turn_number})" if turn_number is not None else ""

        if should_compress:
            logger.warning(
                f"⚠️ Context Window{turn_str}: " f"{current_tokens:,} / {self.context_window:,} tokens " f"({usage_percent*100:.1f}%) - Approaching limit!",
            )
            logger.warning(
                f" Compression threshold reached. Target after compression: " f"{target_tokens:,} tokens ({self.target_ratio*100:.0f}%)",
            )
        else:
            logger.info(
                f"📊 Context Window{turn_str}: " f"{current_tokens:,} / {self.context_window:,} tokens " f"({usage_percent*100:.1f}%)",
            )

        return {
            "current_tokens": current_tokens,
            "max_tokens": self.context_window,
            "usage_percent": usage_percent,
            "should_compress": should_compress,
            "target_tokens": target_tokens,
            "trigger_threshold": self.trigger_threshold,
            "target_ratio": self.target_ratio,
        }

    def log_turn_summary(
        self,
        input_tokens: int,
        output_tokens: int,
        turn_number: Optional[int] = None,
    ):
        """
        Log summary for a single turn.

        Args:
            input_tokens: Input tokens for this turn
            output_tokens: Output tokens for this turn
            turn_number: Optional turn number
        """
        if not self.enabled:
            return

        self.total_input_tokens += input_tokens
        self.total_output_tokens += output_tokens
        self.turn_count += 1

        turn_str = f" {turn_number}" if turn_number is not None else f" {self.turn_count}"

        logger.info(
            f"Turn{turn_str} tokens: {input_tokens:,} input + {output_tokens:,} output = " f"{input_tokens + output_tokens:,} total",
        )

    def log_session_summary(self):
        """Log overall session summary."""
        if not self.enabled or self.turn_count == 0:
            return

        total_tokens = self.total_input_tokens + self.total_output_tokens
        avg_per_turn = total_tokens / self.turn_count

        logger.info("=" * 70)
        logger.info("📊 Session Summary:")
        logger.info(f" Total turns: {self.turn_count}")
        logger.info(f" Total input tokens: {self.total_input_tokens:,}")
        logger.info(f" Total output tokens: {self.total_output_tokens:,}")
        logger.info(f" Total tokens: {total_tokens:,}")
        logger.info(f" Average per turn: {avg_per_turn:,.0f} tokens")
        logger.info(f" Context window: {self.context_window:,} tokens")
        # Cumulative session tokens relative to the context window,
        # not a true per-turn peak.
        logger.info(f" Cumulative usage: {(total_tokens/self.context_window)*100:.1f}%")
        logger.info("=" * 70)

    def get_stats(self) -> Dict[str, Any]:
        """Get current monitoring stats."""
        total_tokens = self.total_input_tokens + self.total_output_tokens

        return {
            "turn_count": self.turn_count,
            "total_input_tokens": self.total_input_tokens,
            "total_output_tokens": self.total_output_tokens,
            "total_tokens": total_tokens,
            "context_window": self.context_window,
            "avg_tokens_per_turn": total_tokens / self.turn_count if self.turn_count > 0 else 0,
            # Cumulative session tokens over the context window, not a true peak
            "peak_usage_percent": total_tokens / self.context_window if self.context_window > 0 else 0,
        }


def create_monitor_from_config(
    config: Dict[str, Any],
    model_name: str,
    provider: str = "openai",
) -> ContextWindowMonitor:
    """
    Create a context window monitor from YAML config.

    Args:
        config: Config dict (should have a 'memory' section)
        model_name: Model name
        provider: Provider name

    Returns:
        ContextWindowMonitor instance
    """
    memory_config = config.get("memory", {})
    compression_config = memory_config.get("compression", {})

    trigger_threshold = compression_config.get("trigger_threshold", 0.75)
    target_ratio = compression_config.get("target_ratio", 0.40)
    enabled = memory_config.get("enabled", True)

    return ContextWindowMonitor(
        model_name=model_name,
        provider=provider,
        trigger_threshold=trigger_threshold,
        target_ratio=target_ratio,
        enabled=enabled,
    )
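
A sketch of driving the monitor from a config dict shaped like the 'memory' section that create_monitor_from_config reads; the model and provider names come from the docstrings above, and the token numbers are made up. The should_compress and target_tokens fields in the returned dict are what ContextCompressor.compress_if_needed would consume.

    from massgen.memory._context_monitor import create_monitor_from_config

    config = {
        "memory": {
            "enabled": True,
            "compression": {
                "trigger_threshold": 0.75,  # warn once usage reaches 75%
                "target_ratio": 0.40,       # aim for 40% after compression
            },
        },
    }

    monitor = create_monitor_from_config(config, model_name="gpt-4o-mini", provider="openai")

    usage = monitor.log_context_usage(
        messages=[{"role": "user", "content": "Hello"}],
        turn_number=1,
    )
    if usage.get("should_compress"):
        # Feed usage["current_tokens"] / usage["target_tokens"] into
        # ContextCompressor.compress_if_needed (see _compression.py above).
        pass

    monitor.log_turn_summary(input_tokens=1_200, output_tokens=300, turn_number=1)
    monitor.log_session_summary()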
massgen/memory/_conversation.py
@@ -0,0 +1,255 @@
# -*- coding: utf-8 -*-
"""
Conversation memory implementation for MassGen.

This module provides in-memory storage for conversation messages, optimized
for quick access during active chat sessions.
"""

import uuid
from typing import Any, Dict, Iterable, List, Optional, Union

from ._base import MemoryBase


class ConversationMemory(MemoryBase):
    """
    In-memory storage for conversation messages.

    This memory type is designed for short-term storage of ongoing conversations.
    It keeps messages in a simple list structure for fast access and iteration.

    Features:
    - Fast in-memory access
    - Duplicate detection based on message IDs
    - Index-based deletion
    - State serialization for session persistence

    Example:
        >>> memory = ConversationMemory()
        >>> await memory.add({"role": "user", "content": "Hello"})
        >>> messages = await memory.get_messages()
        >>> print(len(messages))  # 1
    """

    def __init__(self) -> None:
        """Initialize an empty conversation memory."""
        super().__init__()
        self.messages: List[Dict[str, Any]] = []

    def state_dict(self) -> Dict[str, Any]:
        """
        Serialize memory state to a dictionary.

        Returns:
            Dictionary with 'messages' key containing all stored messages
        """
        return {
            "messages": [msg.copy() for msg in self.messages],
        }

    def load_state_dict(
        self,
        state_dict: Dict[str, Any],
        strict: bool = True,
    ) -> None:
        """
        Load memory state from a serialized dictionary.

        Args:
            state_dict: Dictionary containing 'messages' key
            strict: If True, validates the state dictionary structure

        Raises:
            ValueError: If strict=True and state_dict is invalid
        """
        if strict and "messages" not in state_dict:
            raise ValueError(
                "State dictionary must contain 'messages' key when strict=True",
            )

        self.messages = []
        for msg_data in state_dict.get("messages", []):
            # Ensure each message is a proper dictionary
            if isinstance(msg_data, dict):
                self.messages.append(msg_data.copy())

    async def size(self) -> int:
        """
        Get the number of messages in memory.

        Returns:
            Count of stored messages
        """
        return len(self.messages)

    async def retrieve(self, *args: Any, **kwargs: Any) -> None:
        """
        Retrieve is not supported for basic conversation memory.

        Use get_messages() to access all messages directly.

        Raises:
            NotImplementedError: Always, as basic retrieval is not supported
        """
        raise NotImplementedError(
            f"The retrieve method is not implemented in {self.__class__.__name__}. " "Use get_messages() to access conversation history directly.",
        )

    async def delete(self, index: Union[Iterable[int], int]) -> None:
        """
        Delete message(s) by index position.

        Args:
            index: Single index or iterable of indices to delete

        Raises:
            IndexError: If any index is out of range

        Example:
            >>> await memory.delete(0)  # Delete first message
            >>> await memory.delete([1, 3, 5])  # Delete multiple messages
        """
        if isinstance(index, int):
            index = [index]

        # Materialize into a set so one-shot iterables (e.g. generators) survive
        # both the validation pass and the filtering pass below
        indices = set(index)

        # Validate all indices first
        invalid_indices = sorted(i for i in indices if i < 0 or i >= len(self.messages))

        if invalid_indices:
            raise IndexError(
                f"The following indices do not exist: {invalid_indices}. " f"Valid range is 0-{len(self.messages) - 1}",
            )

        # Create new list excluding deleted indices
        self.messages = [msg for idx, msg in enumerate(self.messages) if idx not in indices]

    async def add(
        self,
        messages: Union[List[Dict[str, Any]], Dict[str, Any], None],
        allow_duplicates: bool = False,
    ) -> None:
        """
        Add one or more messages to the conversation memory.

        Args:
            messages: Single message dict or list of message dicts to add.
                Each message should have at minimum a 'role' and 'content'.
            allow_duplicates: If False, skip messages with duplicate IDs

        Raises:
            TypeError: If messages are not in the expected format

        Example:
            >>> # Add a single message
            >>> await memory.add({"role": "user", "content": "Hello"})
            >>>
            >>> # Add multiple messages
            >>> await memory.add([
            ...     {"role": "user", "content": "Hi"},
            ...     {"role": "assistant", "content": "Hello!"}
            ... ])
        """
        if messages is None:
            return

        # Normalize to list
        if isinstance(messages, dict):
            messages = [messages]

        if not isinstance(messages, list):
            raise TypeError(
                f"Messages should be a list of dicts or a single dict, " f"but got {type(messages)}",
            )

        # Validate each message
        for msg in messages:
            if not isinstance(msg, dict):
                raise TypeError(
                    f"Each message should be a dictionary, but got {type(msg)}",
                )

        # Add message IDs if not present (for duplicate detection)
        processed_messages = []
        for msg in messages:
            msg_copy = msg.copy()
            if "id" not in msg_copy:
                msg_copy["id"] = f"msg_{uuid.uuid4().hex[:12]}"
            processed_messages.append(msg_copy)

        # Filter duplicates if needed
        if not allow_duplicates:
            existing_ids = {msg.get("id") for msg in self.messages if "id" in msg}
            processed_messages = [msg for msg in processed_messages if msg.get("id") not in existing_ids]

        self.messages.extend(processed_messages)

    async def get_messages(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Get messages from the conversation.

        Args:
            limit: Optional limit on the number of most recent messages to return

        Returns:
            List of message dictionaries (copies, not references)

        Example:
            >>> # Get all messages
            >>> all_msgs = await memory.get_messages()
            >>>
            >>> # Get last 10 messages
            >>> recent = await memory.get_messages(limit=10)
        """
        if limit is not None and limit > 0:
            return [msg.copy() for msg in self.messages[-limit:]]
        return [msg.copy() for msg in self.messages]

    async def clear(self) -> None:
        """
        Remove all messages from memory.

        Example:
            >>> await memory.clear()
            >>> assert await memory.size() == 0
        """
        self.messages = []

    async def get_last_message(self) -> Optional[Dict[str, Any]]:
        """
        Get the most recent message.

        Returns:
            Last message dictionary, or None if memory is empty
        """
        if not self.messages:
            return None
        return self.messages[-1].copy()

    async def get_messages_by_role(self, role: str) -> List[Dict[str, Any]]:
        """
        Filter messages by role.

        Args:
            role: Role to filter by (e.g., 'user', 'assistant', 'system')

        Returns:
            List of messages with matching role
        """
        return [msg.copy() for msg in self.messages if msg.get("role") == role]

    async def truncate_to_size(self, max_messages: int) -> None:
        """
        Keep only the most recent messages, up to max_messages.

        This is useful for managing memory usage in long conversations.

        Args:
            max_messages: Maximum number of messages to keep

        Example:
            >>> # Keep only the last 100 messages
            >>> await memory.truncate_to_size(100)
        """
        if max_messages < len(self.messages):
            self.messages = self.messages[-max_messages:]