massgen 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of massgen might be problematic. Click here for more details.
- massgen/__init__.py +1 -1
- massgen/backend/base_with_custom_tool_and_mcp.py +453 -23
- massgen/backend/capabilities.py +39 -0
- massgen/backend/chat_completions.py +111 -197
- massgen/backend/claude.py +210 -181
- massgen/backend/gemini.py +1015 -1559
- massgen/backend/grok.py +3 -2
- massgen/backend/response.py +160 -220
- massgen/chat_agent.py +340 -20
- massgen/cli.py +399 -25
- massgen/config_builder.py +20 -54
- massgen/config_validator.py +931 -0
- massgen/configs/README.md +95 -10
- massgen/configs/memory/gpt5mini_gemini_baseline_research_to_implementation.yaml +94 -0
- massgen/configs/memory/gpt5mini_gemini_context_window_management.yaml +187 -0
- massgen/configs/memory/gpt5mini_gemini_research_to_implementation.yaml +127 -0
- massgen/configs/memory/gpt5mini_high_reasoning_gemini.yaml +107 -0
- massgen/configs/memory/single_agent_compression_test.yaml +64 -0
- massgen/configs/tools/custom_tools/claude_code_custom_tool_with_mcp_example.yaml +1 -0
- massgen/configs/tools/custom_tools/claude_custom_tool_example_no_path.yaml +1 -1
- massgen/configs/tools/custom_tools/claude_custom_tool_with_mcp_example.yaml +1 -0
- massgen/configs/tools/custom_tools/computer_use_browser_example.yaml +1 -1
- massgen/configs/tools/custom_tools/computer_use_docker_example.yaml +1 -1
- massgen/configs/tools/custom_tools/gemini_custom_tool_with_mcp_example.yaml +1 -0
- massgen/configs/tools/custom_tools/gpt5_nano_custom_tool_with_mcp_example.yaml +1 -0
- massgen/configs/tools/custom_tools/gpt_oss_custom_tool_with_mcp_example.yaml +1 -0
- massgen/configs/tools/custom_tools/grok3_mini_custom_tool_with_mcp_example.yaml +1 -0
- massgen/configs/tools/custom_tools/interop/ag2_and_langgraph_lesson_planner.yaml +65 -0
- massgen/configs/tools/custom_tools/interop/ag2_and_openai_assistant_lesson_planner.yaml +65 -0
- massgen/configs/tools/custom_tools/interop/ag2_lesson_planner_example.yaml +48 -0
- massgen/configs/tools/custom_tools/interop/agentscope_lesson_planner_example.yaml +48 -0
- massgen/configs/tools/custom_tools/interop/langgraph_lesson_planner_example.yaml +49 -0
- massgen/configs/tools/custom_tools/interop/openai_assistant_lesson_planner_example.yaml +50 -0
- massgen/configs/tools/custom_tools/interop/smolagent_lesson_planner_example.yaml +49 -0
- massgen/configs/tools/custom_tools/qwen_api_custom_tool_with_mcp_example.yaml +1 -0
- massgen/configs/tools/custom_tools/two_models_with_tools_example.yaml +44 -0
- massgen/formatter/_gemini_formatter.py +61 -15
- massgen/memory/README.md +277 -0
- massgen/memory/__init__.py +26 -0
- massgen/memory/_base.py +193 -0
- massgen/memory/_compression.py +237 -0
- massgen/memory/_context_monitor.py +211 -0
- massgen/memory/_conversation.py +255 -0
- massgen/memory/_fact_extraction_prompts.py +333 -0
- massgen/memory/_mem0_adapters.py +257 -0
- massgen/memory/_persistent.py +687 -0
- massgen/memory/docker-compose.qdrant.yml +36 -0
- massgen/memory/docs/DESIGN.md +388 -0
- massgen/memory/docs/QUICKSTART.md +409 -0
- massgen/memory/docs/SUMMARY.md +319 -0
- massgen/memory/docs/agent_use_memory.md +408 -0
- massgen/memory/docs/orchestrator_use_memory.md +586 -0
- massgen/memory/examples.py +237 -0
- massgen/orchestrator.py +207 -7
- massgen/tests/memory/test_agent_compression.py +174 -0
- massgen/tests/memory/test_context_window_management.py +286 -0
- massgen/tests/memory/test_force_compression.py +154 -0
- massgen/tests/memory/test_simple_compression.py +147 -0
- massgen/tests/test_ag2_lesson_planner.py +223 -0
- massgen/tests/test_agent_memory.py +534 -0
- massgen/tests/test_config_validator.py +1156 -0
- massgen/tests/test_conversation_memory.py +382 -0
- massgen/tests/test_langgraph_lesson_planner.py +223 -0
- massgen/tests/test_orchestrator_memory.py +620 -0
- massgen/tests/test_persistent_memory.py +435 -0
- massgen/token_manager/token_manager.py +6 -0
- massgen/tool/__init__.py +2 -9
- massgen/tool/_decorators.py +52 -0
- massgen/tool/_extraframework_agents/ag2_lesson_planner_tool.py +251 -0
- massgen/tool/_extraframework_agents/agentscope_lesson_planner_tool.py +303 -0
- massgen/tool/_extraframework_agents/langgraph_lesson_planner_tool.py +275 -0
- massgen/tool/_extraframework_agents/openai_assistant_lesson_planner_tool.py +247 -0
- massgen/tool/_extraframework_agents/smolagent_lesson_planner_tool.py +180 -0
- massgen/tool/_manager.py +102 -16
- massgen/tool/_registered_tool.py +3 -0
- massgen/tool/_result.py +3 -0
- {massgen-0.1.4.dist-info → massgen-0.1.6.dist-info}/METADATA +138 -77
- {massgen-0.1.4.dist-info → massgen-0.1.6.dist-info}/RECORD +82 -37
- massgen/backend/gemini_mcp_manager.py +0 -545
- massgen/backend/gemini_trackers.py +0 -344
- {massgen-0.1.4.dist-info → massgen-0.1.6.dist-info}/WHEEL +0 -0
- {massgen-0.1.4.dist-info → massgen-0.1.6.dist-info}/entry_points.txt +0 -0
- {massgen-0.1.4.dist-info → massgen-0.1.6.dist-info}/licenses/LICENSE +0 -0
- {massgen-0.1.4.dist-info → massgen-0.1.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
Tests for ConversationMemory implementation.
|
|
5
|
+
|
|
6
|
+
This module tests the in-memory conversation storage functionality,
|
|
7
|
+
including adding, retrieving, deleting, and managing messages.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import pytest
|
|
11
|
+
|
|
12
|
+
from massgen.memory import ConversationMemory
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.mark.asyncio
async def test_conversation_memory_initialization():
    """A freshly constructed ConversationMemory holds no messages."""
    fresh = ConversationMemory()

    # Both the async size accessor and the ``messages`` attribute must
    # report an empty store.
    initial_size = await fresh.size()
    assert initial_size == 0
    assert fresh.messages == []
    print("✅ ConversationMemory initialization works")
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@pytest.mark.asyncio
async def test_add_single_message():
    """Adding one message dict stores it with its role, content and an id."""
    store = ConversationMemory()
    await store.add({"role": "user", "content": "Hello, world!"})

    assert await store.size() == 1

    stored = await store.get_messages()
    assert len(stored) == 1
    only = stored[0]
    assert only["role"] == "user"
    assert only["content"] == "Hello, world!"
    # The input dict carried no "id", yet the stored message must have one.
    assert "id" in only
    print("✅ Adding single message works")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@pytest.mark.asyncio
async def test_add_multiple_messages():
    """A list passed to ``add`` is stored in full and in the same order."""
    store = ConversationMemory()
    batch = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there!"},
        {"role": "user", "content": "How are you?"},
    ]
    await store.add(batch)

    assert await store.size() == 3

    stored = await store.get_messages()
    assert len(stored) == 3
    # Check each position against the content that was submitted there.
    expected_contents = ("Hello", "Hi there!", "How are you?")
    for position, expected in enumerate(expected_contents):
        assert stored[position]["content"] == expected
    print("✅ Adding multiple messages works")
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@pytest.mark.asyncio
async def test_duplicate_prevention():
    """Re-adding the same message leaves a single stored copy by default."""
    store = ConversationMemory()
    entry = {"id": "msg_123", "role": "user", "content": "Hello"}

    # Submit the identical dict two times in a row.
    for _ in range(2):
        await store.add(entry)

    # The second submission must not have been stored.
    stored_count = await store.size()
    assert stored_count == 1
    print("✅ Duplicate prevention works")
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
@pytest.mark.asyncio
async def test_allow_duplicates():
    """With ``allow_duplicates=True`` the same message is stored each time."""
    store = ConversationMemory()
    entry = {"id": "msg_123", "role": "user", "content": "Hello"}

    # Submit the identical dict twice, explicitly opting in to duplicates.
    for _ in range(2):
        await store.add(entry, allow_duplicates=True)

    # Both submissions must have been kept.
    stored_count = await store.size()
    assert stored_count == 2
    print("✅ Allowing duplicates works")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
@pytest.mark.asyncio
async def test_delete_by_index():
    """Deleting a single index removes only the message at that position."""
    store = ConversationMemory()
    await store.add(
        [
            {"role": "user", "content": "Message 1"},
            {"role": "assistant", "content": "Message 2"},
            {"role": "user", "content": "Message 3"},
        ],
    )

    # Index 1 is the middle entry ("Message 2").
    await store.delete(1)

    assert await store.size() == 2
    kept = await store.get_messages()
    first, second = kept[0], kept[1]
    assert first["content"] == "Message 1"
    assert second["content"] == "Message 3"
    print("✅ Deleting by index works")
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@pytest.mark.asyncio
async def test_delete_multiple_indices():
    """Passing a list of indices to ``delete`` removes all of them at once."""
    store = ConversationMemory()
    await store.add([{"role": "user", "content": f"Message {i}"} for i in range(5)])

    # Drop the second and fourth entries in one call.
    await store.delete([1, 3])

    assert await store.size() == 3
    kept = await store.get_messages()
    # Messages 0, 2 and 4 survive and keep their relative order.
    for position, expected in ((0, "Message 0"), (1, "Message 2"), (2, "Message 4")):
        assert kept[position]["content"] == expected
    print("✅ Deleting multiple indices works")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@pytest.mark.asyncio
async def test_delete_invalid_index():
    """Deleting an index beyond the stored range raises ``IndexError``."""
    store = ConversationMemory()
    await store.add({"role": "user", "content": "Hello"})

    # Only one message was added, so index 10 is out of range.
    out_of_range = 10
    with pytest.raises(IndexError):
        await store.delete(out_of_range)

    print("✅ Invalid index deletion raises error correctly")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
@pytest.mark.asyncio
async def test_get_last_message():
    """``get_last_message`` yields None when empty, else the newest message."""
    store = ConversationMemory()

    # Nothing stored yet: there is no last message to return.
    before_any = await store.get_last_message()
    assert before_any is None

    # Populate with three messages; "Third" is added last.
    await store.add(
        [
            {"role": "user", "content": "First"},
            {"role": "assistant", "content": "Second"},
            {"role": "user", "content": "Third"},
        ],
    )

    newest = await store.get_last_message()
    assert newest is not None
    assert newest["content"] == "Third"
    print("✅ Getting last message works")
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
@pytest.mark.asyncio
async def test_get_messages_by_role():
    """``get_messages_by_role`` returns only messages with the given role."""
    store = ConversationMemory()
    await store.add(
        [
            {"role": "user", "content": "User 1"},
            {"role": "assistant", "content": "Assistant 1"},
            {"role": "user", "content": "User 2"},
            {"role": "assistant", "content": "Assistant 2"},
            {"role": "system", "content": "System 1"},
        ],
    )

    # Two user and two assistant messages were added; every hit must carry
    # the role that was asked for.
    for role, expected_count in (("user", 2), ("assistant", 2)):
        matches = await store.get_messages_by_role(role)
        assert len(matches) == expected_count
        assert all(msg["role"] == role for msg in matches)

    # For the system role only the count is checked.
    system_only = await store.get_messages_by_role("system")
    assert len(system_only) == 1

    print("✅ Filtering by role works")
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
@pytest.mark.asyncio
async def test_get_messages_with_limit():
    """``get_messages(limit=n)`` returns the n most recent messages."""
    store = ConversationMemory()
    await store.add([{"role": "user", "content": f"Message {i}"} for i in range(10)])

    # Of ten messages (0..9), a limit of 3 should yield 7, 8 and 9 in order.
    tail = await store.get_messages(limit=3)
    assert len(tail) == 3
    for position, expected in enumerate(("Message 7", "Message 8", "Message 9")):
        assert tail[position]["content"] == expected
    print("✅ Getting messages with limit works")
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@pytest.mark.asyncio
async def test_truncate_to_size():
    """``truncate_to_size`` keeps only the most recent messages."""
    store = ConversationMemory()
    await store.add([{"role": "user", "content": f"Message {i}"} for i in range(10)])

    # Shrink ten messages down to five; the oldest five should be dropped.
    await store.truncate_to_size(5)

    assert await store.size() == 5
    kept = await store.get_messages()
    oldest_kept, newest_kept = kept[0], kept[4]
    assert oldest_kept["content"] == "Message 5"
    assert newest_kept["content"] == "Message 9"
    print("✅ Truncating to size works")
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
@pytest.mark.asyncio
async def test_clear_memory():
    """``clear`` empties a populated memory."""
    store = ConversationMemory()
    await store.add([{"role": "user", "content": f"Message {i}"} for i in range(5)])

    # Sanity check: the five messages are present before clearing.
    size_before = await store.size()
    assert size_before == 5

    await store.clear()

    # After clearing, both the size and the message list report empty.
    assert await store.size() == 0
    assert await store.get_messages() == []
    print("✅ Clearing memory works")
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@pytest.mark.asyncio
async def test_state_dict_serialization():
    """State exported from one memory can be loaded into another."""
    source = ConversationMemory()
    await source.add(
        [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi!"},
        ],
    )

    # Export: the snapshot must expose both messages under "messages".
    snapshot = source.state_dict()
    assert "messages" in snapshot
    assert len(snapshot["messages"]) == 2

    # Import into a brand-new instance and check what it now holds.
    target = ConversationMemory()
    target.load_state_dict(snapshot)

    assert await target.size() == 2
    restored = await target.get_messages()
    assert restored[0]["content"] == "Hello"
    assert restored[1]["content"] == "Hi!"
    print("✅ State serialization works")
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
@pytest.mark.asyncio
async def test_state_dict_strict_mode():
    """``load_state_dict`` rejects a bad state only when ``strict`` is true."""
    store = ConversationMemory()

    # There is no "messages" key, so this is not a valid state dict.
    malformed = {"wrong_key": []}

    # Strict loading must refuse it.
    with pytest.raises(ValueError):
        store.load_state_dict(malformed, strict=True)

    # Non-strict loading must accept it without raising and leave the
    # memory empty.
    store.load_state_dict(malformed, strict=False)
    assert await store.size() == 0

    print("✅ State dict strict mode works")
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@pytest.mark.asyncio
async def test_add_none_message():
    """Passing ``None`` to ``add`` does not raise and stores nothing."""
    store = ConversationMemory()

    await store.add(None)

    size_after = await store.size()
    assert size_after == 0
    print("✅ Adding None message handled gracefully")
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
@pytest.mark.asyncio
async def test_add_invalid_message_type():
    """``add`` raises ``TypeError`` for non-dict messages."""
    store = ConversationMemory()

    bad_payloads = (
        "invalid message",  # a bare string instead of a dict
        ["invalid", "messages"],  # a list whose items are not dicts
    )
    for payload in bad_payloads:
        with pytest.raises(TypeError):
            await store.add(payload)

    print("✅ Invalid message type raises error correctly")
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
@pytest.mark.asyncio
async def test_retrieve_not_implemented():
    """Calling ``retrieve`` on ConversationMemory raises NotImplementedError."""
    store = ConversationMemory()
    query = "some query"

    with pytest.raises(NotImplementedError):
        await store.retrieve(query)

    print("✅ Retrieve method correctly not implemented")
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
@pytest.mark.asyncio
async def test_message_isolation():
    """Mutating a returned message must not alter what memory holds."""
    store = ConversationMemory()
    await store.add({"role": "user", "content": "Original"})

    # Tamper with the first message of one retrieved snapshot.
    first_snapshot = await store.get_messages()
    first_snapshot[0]["content"] = "Modified"

    # A second retrieval should still show the untouched content.
    second_snapshot = await store.get_messages()
    assert second_snapshot[0]["content"] == "Original"
    print("✅ Message isolation works correctly")
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
if __name__ == "__main__":
    import asyncio

    async def run_all_tests():
        """Await every test coroutine of this module, one after another."""
        print("\n=== Running ConversationMemory Tests ===\n")

        # Listed in the order in which the tests are defined in the module.
        ordered_tests = (
            test_conversation_memory_initialization,
            test_add_single_message,
            test_add_multiple_messages,
            test_duplicate_prevention,
            test_allow_duplicates,
            test_delete_by_index,
            test_delete_multiple_indices,
            test_delete_invalid_index,
            test_get_last_message,
            test_get_messages_by_role,
            test_get_messages_with_limit,
            test_truncate_to_size,
            test_clear_memory,
            test_state_dict_serialization,
            test_state_dict_strict_mode,
            test_add_none_message,
            test_add_invalid_message_type,
            test_retrieve_not_implemented,
            test_message_isolation,
        )
        for test_coroutine in ordered_tests:
            await test_coroutine()

        # Only reached when no test raised.
        print("\n=== All ConversationMemory Tests Passed! ===\n")

    asyncio.run(run_all_tests())
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""
|
|
3
|
+
Test LangGraph Lesson Planner Tool
|
|
4
|
+
Tests the interoperability feature where LangGraph state graph is wrapped as a MassGen custom tool.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import os
|
|
9
|
+
import sys
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
import pytest
|
|
13
|
+
|
|
14
|
+
# Add parent directory to path for imports
|
|
15
|
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
|
16
|
+
|
|
17
|
+
from massgen.tool._extraframework_agents.langgraph_lesson_planner_tool import ( # noqa: E402
|
|
18
|
+
langgraph_lesson_planner,
|
|
19
|
+
)
|
|
20
|
+
from massgen.tool._result import ExecutionResult # noqa: E402
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TestLangGraphLessonPlannerTool:
    """Test LangGraph Lesson Planner Tool functionality.

    Every test except ``test_missing_api_key_error`` is skipped when the
    ``OPENAI_API_KEY`` environment variable is not set.
    """

    @pytest.mark.asyncio
    async def test_basic_lesson_plan_creation(self):
        """Test basic lesson plan creation with a simple topic."""
        # Skip if no API key
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            pytest.skip("OPENAI_API_KEY not set")

        # Test with a simple topic
        result = await langgraph_lesson_planner(topic="photosynthesis", api_key=api_key)

        # Verify result structure
        assert isinstance(result, ExecutionResult)
        assert len(result.output_blocks) > 0
        # Check that the result doesn't contain an error
        assert not result.output_blocks[0].data.startswith("Error:")

        # Verify lesson plan mentions the requested topic
        lesson_plan = result.output_blocks[0].data
        assert "photosynthesis" in lesson_plan.lower()

    @pytest.mark.asyncio
    async def test_lesson_plan_with_env_api_key(self):
        """Test lesson plan creation using environment variable for API key."""
        # Skip if no API key
        if not os.getenv("OPENAI_API_KEY"):
            pytest.skip("OPENAI_API_KEY not set")

        # Test without passing api_key parameter (should use env var)
        result = await langgraph_lesson_planner(topic="fractions")

        assert isinstance(result, ExecutionResult)
        assert len(result.output_blocks) > 0
        # Check that the result doesn't contain an error
        assert not result.output_blocks[0].data.startswith("Error:")

    @pytest.mark.asyncio
    async def test_missing_api_key_error(self):
        """Test error handling when API key is missing."""
        # Remove the variable for the duration of the call and remember its
        # previous value; None means it was not set at all.
        original_key = os.environ.pop("OPENAI_API_KEY", None)

        try:
            result = await langgraph_lesson_planner(topic="test topic")

            # Should return error result
            assert isinstance(result, ExecutionResult)
            assert result.output_blocks[0].data.startswith("Error:")
            assert "OPENAI_API_KEY not found" in result.output_blocks[0].data
        finally:
            # Restore exactly what was there before. Comparing against None
            # (rather than truthiness) also puts back an empty-string value.
            if original_key is not None:
                os.environ["OPENAI_API_KEY"] = original_key

    @pytest.mark.asyncio
    async def test_different_topics(self):
        """Test lesson plan creation with different topics."""
        # Skip if no API key
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            pytest.skip("OPENAI_API_KEY not set")

        topics = ["addition", "animals", "water cycle"]

        # One call per topic, run sequentially.
        for topic in topics:
            result = await langgraph_lesson_planner(topic=topic, api_key=api_key)

            assert isinstance(result, ExecutionResult)
            assert len(result.output_blocks) > 0
            # Check that the result doesn't contain an error
            assert not result.output_blocks[0].data.startswith("Error:")
            assert topic.lower() in result.output_blocks[0].data.lower()

    @pytest.mark.asyncio
    async def test_concurrent_lesson_plan_creation(self):
        """Test creating multiple lesson plans concurrently."""
        # Skip if no API key
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            pytest.skip("OPENAI_API_KEY not set")

        topics = ["math", "science", "reading"]

        # Create tasks for concurrent execution
        tasks = [langgraph_lesson_planner(topic=topic, api_key=api_key) for topic in topics]

        # Execute concurrently; gather returns results in the order of its
        # arguments, so results line up with topics.
        results = await asyncio.gather(*tasks)

        # Verify all results
        assert len(results) == len(topics)
        for topic, result in zip(topics, results):
            assert isinstance(result, ExecutionResult)
            assert len(result.output_blocks) > 0
            # Check that the result doesn't contain an error
            assert not result.output_blocks[0].data.startswith("Error:")
            assert topic.lower() in result.output_blocks[0].data.lower()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class TestLangGraphToolIntegration:
    """Checks on the tool's callable signature and on the result it returns."""

    def test_tool_function_signature(self):
        """The tool is a coroutine function taking ``topic`` and ``api_key``."""
        import inspect

        assert inspect.iscoroutinefunction(langgraph_lesson_planner)

        signature = inspect.signature(langgraph_lesson_planner)

        # Both parameters must be part of the signature.
        for required in ("topic", "api_key"):
            assert required in signature.parameters

        # The declared return type must be ExecutionResult.
        assert signature.return_annotation == ExecutionResult

    @pytest.mark.asyncio
    async def test_execution_result_structure(self):
        """A successful call returns a non-empty list of text blocks."""
        # Without a key the tool cannot be called for real, so skip.
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            pytest.skip("OPENAI_API_KEY not set")

        outcome = await langgraph_lesson_planner(topic="test", api_key=api_key)

        # The result exposes its blocks as a non-empty list.
        assert hasattr(outcome, "output_blocks")
        assert isinstance(outcome.output_blocks, list)
        assert len(outcome.output_blocks) > 0
        # The first block must not be an error report.
        assert not outcome.output_blocks[0].data.startswith("Error:")

        # The first block is a TextContent whose ``data`` is a string.
        from massgen.tool._result import TextContent

        leading_block = outcome.output_blocks[0]
        assert isinstance(leading_block, TextContent)
        assert hasattr(leading_block, "data")
        assert isinstance(leading_block.data, str)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class TestLangGraphToolWithBackend:
    """Checks that the tool can be registered on a ResponseBackend."""

    @pytest.mark.asyncio
    async def test_backend_registration(self):
        """Registering the tool exposes its name and a function schema."""
        from massgen.backend.response import ResponseBackend

        # Falls back to a placeholder key when none is set in the environment.
        api_key = os.getenv("OPENAI_API_KEY", "test-key")

        # Import the tool
        from massgen.tool._extraframework_agents.langgraph_lesson_planner_tool import (
            langgraph_lesson_planner,
        )

        # Build a backend with the tool passed in as its only custom tool.
        tool_spec = {
            "func": langgraph_lesson_planner,
            "description": "Create a comprehensive lesson plan using LangGraph state graph",
        }
        backend = ResponseBackend(api_key=api_key, custom_tools=[tool_spec])

        # The tool's name must be among the registered custom tool names.
        assert "langgraph_lesson_planner" in backend._custom_tool_names

        # At least one schema must have been generated.
        schemas = backend._get_custom_tools_schemas()
        assert len(schemas) >= 1

        # Pick the first schema that belongs to our tool, or None if absent.
        langgraph_schema = next(
            (entry for entry in schemas if entry["function"]["name"] == "langgraph_lesson_planner"),
            None,
        )

        assert langgraph_schema is not None
        assert langgraph_schema["type"] == "function"
        assert "parameters" in langgraph_schema["function"]
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
# ============================================================================
|
|
218
|
+
# Run tests
|
|
219
|
+
# ============================================================================
|
|
220
|
+
|
|
221
|
+
if __name__ == "__main__":
    # Run this file's tests verbosely and exit with pytest's return code, so
    # that a failing run is visible to the calling shell instead of exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|