aient 1.1.91__py3-none-any.whl → 1.1.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,980 @@
1
import asyncio
import os
import sys
import unittest
from unittest.mock import AsyncMock

# Add the project root to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))


from architext import *
12
+
13
+ # ==============================================================================
14
+ # 单元测试部分
15
+ # ==============================================================================
16
+ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
17
+
18
def setUp(self):
    """Create fresh system-prompt, tools and files providers for every test."""
    # Built in the same order as before: named system prompt, one-tool list,
    # then an empty Files collection.
    self.system_prompt_provider, self.tools_provider, self.files_provider = (
        Texts("你是一个AI助手。", name="system_prompt"),
        Tools(tools_json=[{"name": "read_file"}]),
        Files(),
    )
23
+
24
async def test_a_initial_construction_and_render(self):
    """Build a system + user conversation and check its first render."""
    conversation = Messages(
        SystemMessage(self.system_prompt_provider, self.tools_provider),
        UserMessage(self.files_provider, Texts("这是我的初始问题。")),
    )
    self.assertEqual(len(conversation), 2)

    output = await conversation.render_latest()
    self.assertEqual(len(output), 2)

    system_content = output[0]['content']
    user_content = output[1]['content']
    # The tools provider contributes a <tools> section to the system message.
    self.assertIn("<tools>", system_content)
    # The Files provider was created with no files, so no <files> section.
    self.assertNotIn("<files>", user_content)
37
+
38
async def test_b_provider_passthrough_and_refresh(self):
    """A provider renders once, then only again after update() marks it stale."""
    # A plain Texts provider exercises the generic caching path without the
    # filesystem side effects a Files provider would bring.
    provider = Texts("initial text")
    provider.render = AsyncMock(wraps=provider.render)
    conversation = Messages(UserMessage(provider))

    # The first refresh renders; a repeat refresh with unchanged content must not.
    for _ in range(2):
        await conversation.refresh()
        self.assertEqual(provider.render.call_count, 1)

    # update() marks the provider stale, so the next refresh renders again.
    provider.update("updated text")
    await conversation.refresh()
    output = conversation.render()
    self.assertEqual(provider.render.call_count, 2)
    self.assertIn("updated text", output[0]['content'])
61
+
62
async def test_c_global_pop_and_indexed_insert(self):
    """Pop a provider by name across all messages, then re-insert it by index."""
    conversation = Messages(
        SystemMessage(self.system_prompt_provider, self.tools_provider),
        UserMessage(self.files_provider),
    )

    def system_has_tools(rendered_messages):
        # True when any rendered system message contains a <tools> section.
        return any(
            "<tools>" in entry['content']
            for entry in rendered_messages
            if entry['role'] == 'system'
        )

    self.assertTrue(system_has_tools(await conversation.render_latest()))

    # Popping by name searches every message and hands back the provider object.
    tools = conversation.pop("tools")
    self.assertIs(tools, self.tools_provider)
    self.assertFalse(system_has_tools(conversation.render()))

    # Put the provider at the front of the user message (message index 1).
    conversation[1].insert(0, tools)
    user_content = next(
        entry['content']
        for entry in conversation.render()
        if entry['role'] == 'user'
    )
    self.assertTrue(user_content.startswith("<tools>"))
88
+
89
async def test_d_multimodal_rendering(self):
    """Render a user message that mixes text with a local image file.

    The Images provider is created without a ``name`` to exercise the
    optional-name path. The rendered content must be a two-part list: a text
    part followed by an ``image_url`` part holding a PNG base64 data URL.
    """
    # Placeholder file. Its bytes are not a real PNG; the test only checks that
    # the rendered URL carries the PNG data-URL prefix.
    dummy_image_path = "test_dummy_image.png"
    with open(dummy_image_path, "w") as f:
        f.write("dummy content")
    # Register cleanup immediately so the file is removed even if an assertion fails.
    self.addCleanup(os.remove, dummy_image_path)

    messages = Messages(
        UserMessage(
            Texts("Describe the image."),
            Images(url=dummy_image_path),  # name deliberately omitted (it is optional)
        )
    )

    rendered = await messages.render_latest()
    self.assertEqual(len(rendered), 1)

    content = rendered[0]['content']
    self.assertIsInstance(content, list)
    self.assertEqual(len(content), 2)

    # Text part.
    self.assertEqual(content[0]['type'], 'text')
    self.assertEqual(content[0]['text'], 'Describe the image.')

    # Image part.
    self.assertEqual(content[1]['type'], 'image_url')
    self.assertIn('data:image/png;base64,', content[1]['image_url']['url'])
121
+
122
async def test_e_multimodal_type_switching(self):
    """After popping its only image, a multimodal message renders as a plain string.

    With the image present the content is a list of parts. Once the image is
    removed, the remaining text parts are expected to render as one string
    separated by a blank line.
    """
    dummy_image_path = "test_dummy_image_2.png"
    with open(dummy_image_path, "w") as f:
        f.write("dummy content")
    # Register cleanup immediately so the file is removed even if an assertion fails.
    self.addCleanup(os.remove, dummy_image_path)

    messages = Messages(
        UserMessage(
            Texts("Look at this:"),
            Images(url=dummy_image_path, name="image"),  # explicit name so it can be popped
            Texts("Any thoughts?")
        )
    )

    # 1. The initial render is multimodal: text, image, text.
    rendered_multi = await messages.render_latest()
    content_multi = rendered_multi[0]['content']
    self.assertIsInstance(content_multi, list)
    self.assertEqual(len(content_multi), 3)

    # 2. Pop the image by name.
    popped_image = messages.pop("image")
    self.assertIsNotNone(popped_image)

    # 3. A plain render (no refresh needed) falls back to string content.
    rendered_str = messages.render()
    content_str = rendered_str[0]['content']
    self.assertIsInstance(content_str, str)
    self.assertEqual(content_str, "Look at this:\n\nAny thoughts?")
155
+
156
def test_f_message_merging(self):
    """Adjacent same-role messages merge, both at construction and on append."""
    # Construction: the two leading user messages collapse into one,
    # leaving user / system / user.
    conversation = Messages(
        UserMessage(Texts("Hello,")),
        UserMessage(Texts("world!")),
        SystemMessage(Texts("System prompt.")),
        UserMessage(Texts("How are you?")),
    )
    self.assertEqual(len(conversation), 3)
    head = conversation[0]
    self.assertEqual(len(head._items), 2)
    self.assertIn("text_", head._items[1].name)
    self.assertEqual(conversation[1].role, "system")
    self.assertEqual(conversation[2].role, "user")

    # Appending another user message merges into the trailing user message.
    conversation.append(UserMessage(Texts("I am fine.")))
    self.assertEqual(len(conversation), 3)
    tail = conversation[2]
    self.assertEqual(len(tail._items), 2)
    self.assertIn("text_", tail._items[1].name)

    # A message with a different role is appended as a new entry.
    conversation.append(SystemMessage(Texts("Another prompt.")))
    self.assertEqual(len(conversation), 4)
    self.assertEqual(conversation[3].role, "system")
182
+
183
async def test_g_state_inconsistency_on_direct_message_modification(self):
    """Popping a provider straight from a child Message must update the top-level index.

    Checks that ``Messages`` lookups (its ``_providers_index``) stay consistent
    when a provider is removed via ``Message.pop`` instead of ``Messages.pop``.
    """
    conversation = Messages(
        SystemMessage(self.system_prompt_provider, self.tools_provider),
        UserMessage(self.files_provider),
    )
    # Refresh once so every provider's cache is populated.
    await conversation.refresh()

    # Before the pop, the tools provider resolves through the container.
    self.assertIsNotNone(conversation.provider("tools"), "初始状态下 'tools' 提供者应该能被找到")
    self.assertIs(conversation.provider("tools"), self.tools_provider)

    # Pop directly on the child SystemMessage, bypassing Messages.pop.
    system_msg = conversation[0]
    removed = system_msg.pop("tools")
    self.assertIs(removed, self.tools_provider, "应该从 SystemMessage 中成功弹出 provider")
    self.assertNotIn(self.tools_provider, system_msg.providers(), "provider 不应再存在于 SystemMessage 的 providers 列表中")

    # The container's lookup must no longer resolve the removed provider.
    self.assertIsNone(conversation.provider("tools"), "BUG: 直接从子消息中 pop 后,顶层索引未同步,仍然可以找到 provider")

    # Rendering reads the child message directly, so <tools> is gone there too.
    rendered = conversation.render()
    self.assertGreater(len(rendered), 0, "渲染后的消息列表不应为空")
    self.assertNotIn("<tools>", rendered[0]['content'], "渲染结果中不应再包含 'tools' 的内容,证明数据源已更新")
220
+
221
async def test_h_pop_message_by_index(self):
    """Messages.pop(int) removes the message at that index and unindexes its providers."""
    user_text = Texts("User message 1")
    conversation = Messages(
        SystemMessage(Texts("System message")),
        UserMessage(user_text),
        AssistantMessage(Texts("Assistant response")),
    )
    self.assertEqual(len(conversation), 3)
    self.assertIsNotNone(conversation.provider(user_text.name))

    # Remove the user message from the middle.
    removed = conversation.pop(1)
    self.assertIsInstance(removed, UserMessage)
    removed_providers = removed.providers()
    self.assertEqual(len(removed_providers), 1)
    self.assertEqual(removed_providers[0].name, user_text.name)

    # The remaining messages close the gap, and the removed provider is no
    # longer reachable through the container.
    self.assertEqual(len(conversation), 2)
    self.assertEqual(
        (conversation[0].role, conversation[1].role),
        ("system", "assistant"),
    )
    self.assertIsNone(conversation.provider(user_text.name))

    # An out-of-range index yields None and leaves the container untouched.
    self.assertIsNone(conversation.pop(99))
    self.assertEqual(len(conversation), 2)
254
+
255
async def test_i_generic_update_and_refresh(self):
    """``update()`` on Texts, Tools and Images swaps content and marks the provider stale.

    After the updates, each provider's ``render`` must run exactly once more,
    and the new content must appear in the rendered message.
    """
    # 1. Set up providers.
    text_provider = Texts("Hello")
    tools_provider = Tools([{"name": "tool_A"}])

    dummy_image_path = "test_dummy_image_3.png"
    with open(dummy_image_path, "w") as f:
        f.write("dummy content")
    # Register cleanup immediately so the file is removed even if an assertion fails.
    self.addCleanup(os.remove, dummy_image_path)
    image_provider = Images(url=dummy_image_path, name="logo")

    messages = Messages(UserMessage(text_provider, tools_provider, image_provider))

    # Wrap render methods so calls can be counted without changing behavior.
    text_provider.render = AsyncMock(wraps=text_provider.render)
    tools_provider.render = AsyncMock(wraps=tools_provider.render)
    image_provider.render = AsyncMock(wraps=image_provider.render)

    # 2. Initial render: every provider renders once. The message holds an
    # image, so content is a list of parts (text first, then tools).
    rendered_initial = await messages.render_latest()
    self.assertIn("Hello", rendered_initial[0]['content'][0]['text'])
    self.assertIn("tool_A", rendered_initial[0]['content'][1]['text'])
    self.assertEqual(text_provider.render.call_count, 1)
    self.assertEqual(tools_provider.render.call_count, 1)
    self.assertEqual(image_provider.render.call_count, 1)

    # 3. Update every provider.
    text_provider.update("Goodbye")
    tools_provider.update([{"name": "tool_B"}])

    new_dummy_image_path = "test_dummy_image_4.png"
    with open(new_dummy_image_path, "w") as f:
        f.write("new dummy content")
    self.addCleanup(os.remove, new_dummy_image_path)
    image_provider.update(url=new_dummy_image_path)

    # Refreshing a stale provider directly re-renders it right away,
    # independent of messages.refresh().
    await text_provider.refresh()
    self.assertEqual(text_provider.render.call_count, 2)

    # 4. Re-render the whole message after the updates.
    rendered_updated = await messages.render_latest()
    self.assertIn("Goodbye", rendered_updated[0]['content'][0]['text'])
    self.assertIn("tool_B", rendered_updated[0]['content'][1]['text'])

    # Each updated provider rendered exactly once more. text_provider was
    # already refreshed above, so render_latest() does not render it again.
    self.assertEqual(text_provider.render.call_count, 2)
    self.assertEqual(tools_provider.render.call_count, 2)
    self.assertEqual(image_provider.render.call_count, 2)
305
+
306
async def test_j_pop_last_message_without_arguments(self):
    """Messages.pop() with no argument removes from the end, like list.pop()."""
    system_msg = SystemMessage(Texts("System"))
    user_msg = UserMessage(Texts("User"))
    assistant_msg = AssistantMessage(Texts("Assistant"))
    conversation = Messages(system_msg, user_msg, assistant_msg)
    self.assertEqual(len(conversation), 3)

    # Each pop returns whatever is currently last.
    self.assertIs(conversation.pop(), assistant_msg)
    self.assertEqual(len(conversation), 2)
    self.assertIs(conversation[-1], user_msg)

    self.assertIs(conversation.pop(), user_msg)
    self.assertEqual(len(conversation), 1)

    self.assertIs(conversation.pop(), system_msg)
    self.assertEqual(len(conversation), 0)

    # Popping an empty container returns None instead of raising.
    self.assertIsNone(conversation.pop())
335
+
336
async def test_k_image_provider_with_base64_url(self):
    """Images passes a base64 ``data:`` URL through unchanged, including after update()."""
    # Two distinct 1x1 PNG data URLs. They must differ so the update() check
    # below observes a real change; an empty string would make it vacuous.
    base64_image_url = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
    )
    new_base64_url = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg=="
    )

    messages = Messages(
        UserMessage(
            Texts("This is a base64 image."),
            Images(url=base64_image_url, name="base64_img")
        )
    )

    rendered = await messages.render_latest()
    self.assertEqual(len(rendered), 1)

    content = rendered[0]['content']
    self.assertIsInstance(content, list)
    self.assertEqual(len(content), 2)

    # Text part.
    self.assertEqual(content[0]['type'], 'text')
    self.assertEqual(content[0]['text'], 'This is a base64 image.')

    # Image part: the data URL is emitted as-is.
    image_content = content[1]
    self.assertEqual(image_content['type'], 'image_url')
    self.assertEqual(image_content['image_url']['url'], base64_image_url)

    # update() swaps the URL; the next render reflects the new one.
    provider = messages.provider("base64_img")
    self.assertIsNotNone(provider)
    provider.update(url=new_base64_url)

    rendered_updated = await messages.render_latest()
    self.assertEqual(rendered_updated[0]['content'][1]['image_url']['url'], new_base64_url)
373
+
374
def test_l_role_message_factory(self):
    """RoleMessage(role, ...) returns the Message subclass matching ``role``."""
    cases = [
        ('system', SystemMessage, "System content"),
        ('user', UserMessage, "User content"),
        ('assistant', AssistantMessage, "Assistant content"),
    ]
    for role, expected_cls, text in cases:
        built = RoleMessage(role, Texts(text))
        self.assertIsInstance(built, expected_cls)
        self.assertEqual(built.role, role)

    # An unknown role is rejected.
    with self.assertRaises(ValueError):
        RoleMessage('invalid_role', Texts("Content"))
390
+
391
async def test_m_optional_name_for_texts(self):
    """Texts gets a content-derived ``text_`` name when none is supplied."""
    auto_named = Texts("This is a test.")
    self.assertTrue(auto_named.name.startswith("text_"))

    explicit = Texts("This is another test.", name="my_name")
    self.assertEqual(explicit.name, "my_name")

    # Same content gives the same generated name; different content does not.
    self.assertEqual(auto_named.name, Texts("This is a test.").name)
    self.assertNotEqual(auto_named.name, Texts("This is a different test.").name)

    # The generated name works as a lookup key inside Messages.
    conversation = Messages(UserMessage(auto_named))
    self.assertIs(conversation.provider(auto_named.name), auto_named)
413
+
414
async def test_n_string_to_texts_conversion(self):
    """Plain strings passed to a Message are wrapped in Texts providers."""
    user_message = UserMessage(self.files_provider, "This is a raw string.")

    providers = user_message.providers()
    self.assertEqual(len(providers), 2)
    files_part, text_part = providers
    self.assertIsInstance(files_part, Files)
    self.assertIsInstance(text_part, Texts)

    # Refresh manually so the converted provider has a content block.
    await text_part.refresh()
    block = text_part.get_content_block()
    self.assertIsNotNone(block)
    self.assertEqual(block.content, "This is a raw string.")

    # Inside a container, strings work for SystemMessage as well.
    conversation = Messages(SystemMessage("System prompt here."), user_message)
    await conversation.refresh()
    rendered = conversation.render()
    self.assertEqual(len(rendered), 2)
    self.assertEqual(rendered[0]['content'], "System prompt here.")
    # The empty Files provider adds nothing, leaving only the string's text.
    self.assertEqual(rendered[1]['content'], "This is a raw string.")

    # The RoleMessage factory applies the same conversion.
    factory_msg = RoleMessage('user', "Factory-created string.")
    self.assertIsInstance(factory_msg, UserMessage)
    self.assertIsInstance(factory_msg.providers()[0], Texts)

    # Unsupported item types are rejected.
    with self.assertRaises(TypeError):
        UserMessage(123)
453
+
454
async def test_o_list_to_providers_conversion(self):
    """Lists of content-part dicts passed to a Message become providers.

    ``{'type': 'text', ...}`` parts become Texts and
    ``{'type': 'image_url', ...}`` parts become Images. An unknown part type
    raises ValueError.
    """
    # A 1x1 PNG data URL, so the image part needs no file on disk and the
    # Images conversion is checked against a real, non-empty URL.
    image_data_url = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
    )

    # 1. Mixed text + image list.
    mixed_content_list = [
        {'type': 'text', 'text': 'Describe the following image.'},
        {'type': 'image_url', 'image_url': {'url': image_data_url}}
    ]
    user_message_mixed = UserMessage(mixed_content_list)

    self.assertEqual(len(user_message_mixed.providers()), 2)
    self.assertIsInstance(user_message_mixed.providers()[0], Texts)
    self.assertIsInstance(user_message_mixed.providers()[1], Images)

    # Refresh every provider so each has a content block to inspect.
    providers = user_message_mixed.providers()
    await asyncio.gather(*[p.refresh() for p in providers])
    self.assertEqual(providers[0].get_content_block().content, 'Describe the following image.')
    self.assertEqual(providers[1].get_content_block().content, image_data_url)

    # 2. Text-only list.
    text_only_list = [
        {'type': 'text', 'text': 'First line.'},
        {'type': 'text', 'text': 'Second line.'}
    ]
    user_message_text_only = UserMessage(text_only_list)

    self.assertEqual(len(user_message_text_only.providers()), 2)
    self.assertIsInstance(user_message_text_only.providers()[0], Texts)
    self.assertIsInstance(user_message_text_only.providers()[1], Texts)

    # 3. Inside a Messages container the mixed list renders as two parts.
    messages = Messages(UserMessage(mixed_content_list))
    rendered = await messages.render_latest()

    self.assertEqual(len(rendered), 1)
    self.assertIsInstance(rendered[0]['content'], list)
    self.assertEqual(len(rendered[0]['content']), 2)
    self.assertEqual(rendered[0]['content'][0]['type'], 'text')
    self.assertEqual(rendered[0]['content'][1]['type'], 'image_url')

    # 4. An unknown part type is rejected.
    invalid_list = [{'type': 'invalid_type'}]
    with self.assertRaises(ValueError):
        UserMessage(invalid_list)
498
+
499
async def test_p_empty_message_boolean_context(self):
    """A Message is falsy when it holds no providers and truthy otherwise."""
    self.assertFalse(UserMessage(), "一个空的 UserMessage 在布尔上下文中应该为 False")
    self.assertTrue(UserMessage("Hello"), "一个非空的 UserMessage 在布尔上下文中应该为 True")

    # Removing the only provider turns a truthy message falsy.
    shrinking = UserMessage(Texts("content", name="removable"))
    self.assertTrue(shrinking, "消息在移除前应为 True")
    shrinking.pop("removable")
    self.assertFalse(shrinking, "消息在最后一个 provider 被移除后应为 False")
514
+
515
async def test_q_string_addition_to_message(self):
    """``str + Message`` and ``Message + str`` build a new Message and leave the operand unchanged."""
    original = UserMessage("hello")

    # str on the left: the string's provider comes first.
    prepended = "hi" + original
    self.assertIsInstance(prepended, UserMessage, "结果应该是一个 UserMessage 实例")
    self.assertEqual(len(prepended.providers()), 2, "新消息应该包含两个 provider")

    parts = prepended.providers()
    type_messages = ("第一个 provider 应该是 Texts 类型", "第二个 provider 应该是 Texts 类型")
    for part, failure_msg in zip(parts, type_messages):
        self.assertIsInstance(part, Texts, failure_msg)

    await asyncio.gather(*(part.refresh() for part in parts))
    self.assertEqual(parts[0].get_content_block().content, "hi", "第一个 provider 的内容应该是 'hi'")
    self.assertEqual(parts[1].get_content_block().content, "hello", "第二个 provider 的内容应该是 'hello'")

    # The left-hand operand is untouched.
    self.assertEqual(len(original.providers()), 1, "原始消息不应该被修改")

    # str on the right: the string's provider comes last.
    appended = original + "world"
    self.assertIsInstance(appended, UserMessage)
    self.assertEqual(len(appended.providers()), 2)

    appended_parts = appended.providers()
    await asyncio.gather(*(part.refresh() for part in appended_parts))
    self.assertEqual(
        [part.get_content_block().content for part in appended_parts],
        ["hello", "world"],
    )
549
+
550
async def test_r_message_addition_and_flattening(self):
    """Message concatenation and nested construction flatten into one provider list."""

    async def refreshed_contents(message):
        # Refresh every provider of ``message`` and return their content strings.
        parts = message.providers()
        await asyncio.gather(*(part.refresh() for part in parts))
        return [part.get_content_block().content for part in parts]

    # "str" + UserMessage.
    combined = "hi" + UserMessage("hello")
    self.assertIsInstance(combined, UserMessage)
    self.assertEqual(len(combined.providers()), 2)
    self.assertEqual(await refreshed_contents(combined), ["hi", "hello"])

    # UserMessage(UserMessage(...)) takes over the inner message's providers.
    nested = UserMessage(UserMessage("item1", "item2"))
    self.assertEqual(len(nested.providers()), 2)
    for part in nested.providers():
        self.assertIsInstance(part, Texts)
    self.assertEqual(await refreshed_contents(nested), ["item1", "item2"])

    # Both together: wrap the result of "str" + UserMessage.
    final = UserMessage("hi" + UserMessage("hello"))
    self.assertIsInstance(final, UserMessage)
    self.assertEqual(len(final.providers()), 2)
    self.assertEqual(await refreshed_contents(final), ["hi", "hello"])
584
+
585
async def test_s_len_and_pop_with_get_method(self):
    """len() counts messages; a popped Message supports dict-like ``.get()``."""
    conversation = Messages(
        SystemMessage("System prompt"),
        UserMessage("User question"),
        AssistantMessage("Assistant answer"),
    )
    self.assertEqual(len(conversation), 3, "len(messages) 应该返回消息的数量")

    removed = conversation.pop(1)
    self.assertIsNotNone(removed, "pop(1) 应该返回一个消息对象")
    self.assertIsInstance(removed, UserMessage)
    # .get('role') returns the message's role.
    self.assertEqual(removed.get("role"), "user", "弹出的消息应该可以通过 .get('role') 获取角色")

    self.assertEqual(len(conversation), 2, "pop() 后消息数量应该减少")
    self.assertEqual(conversation[0].role, "system")
    self.assertEqual(conversation[1].role, "assistant")

    # Missing keys act like dict.get: None by default, or the supplied fallback.
    self.assertIsNone(removed.get("non_existent_key"), ".get() 对不存在的键应该返回 None")
    self.assertEqual(removed.get("non_existent_key", "default"), "default", ".get() 应支持默认值")
613
+
614
async def test_t_pop_and_get_tool_calls(self):
    """A popped ToolCalls message exposes its original list via ``.get('tool_calls')``."""
    from dataclasses import dataclass, field

    # Minimal stand-ins with the attribute shape of a tool call.
    @dataclass
    class MockFunction:
        name: str
        arguments: str

    @dataclass
    class MockToolCall:
        id: str
        type: str = "function"
        function: MockFunction = field(default_factory=lambda: MockFunction("", ""))

    calls = [MockToolCall(id="call_123", function=MockFunction(name="test", arguments="{}"))]
    conversation = Messages(UserMessage("A regular message"), ToolCalls(tool_calls=calls))

    tool_msg = conversation.pop(1)
    self.assertIsInstance(tool_msg, ToolCalls)

    fetched = tool_msg.get("tool_calls")
    self.assertIsNotNone(fetched)
    self.assertEqual(len(fetched), 1)
    # The very list passed in is returned, not a copy.
    self.assertIs(fetched, calls)

    # An ordinary message has no tool_calls, so .get() yields None.
    plain_msg = conversation.pop(0)
    self.assertIsInstance(plain_msg, UserMessage)
    self.assertIsNone(plain_msg.get("tool_calls"), "在没有 tool_calls 属性的消息上 .get() 应该返回 None")
651
+
652
async def test_u_message_dictionary_style_access(self):
    """A Message supports dict-style indexing such as ``message['role']`` and ``message['content']``.

    A text-only message yields string content, a text + image message yields
    a list of parts, and an unknown key raises KeyError.
    """
    # A 1x1 PNG data URL, so the image part needs no file on disk.
    image_data_url = (
        "data:image/png;base64,"
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
    )
    messages = Messages(
        UserMessage("Hello, world!"),
        AssistantMessage(
            "A picture:",
            Images(url=image_data_url, name="fake_image")
        )
    )
    await messages.refresh()

    # 1. Text-only message.
    user_msg = messages[0]
    self.assertEqual(user_msg['role'], 'user')
    self.assertEqual(user_msg['content'], "Hello, world!")

    # 2. Multimodal message.
    assistant_msg = messages[1]
    self.assertEqual(assistant_msg['role'], 'assistant')
    content = assistant_msg['content']
    self.assertIsInstance(content, list)
    self.assertEqual(len(content), 2)
    self.assertEqual(content[0]['type'], 'text')
    self.assertEqual(content[1]['type'], 'image_url')

    # 3. Unknown keys raise KeyError, as with a dict.
    with self.assertRaises(KeyError):
        _ = user_msg['non_existent_key']
681
+
682
async def test_v_files_initialization_with_list(self):
    """Files accepts a single list of paths and renders each file's path and content."""
    fixtures = [
        ("test_file_1.txt", "Content of file 1."),
        ("test_file_2.txt", "Content of file 2."),
    ]
    for path, text in fixtures:
        with open(path, "w") as handle:
            handle.write(text)

    expected_fragments = [
        "<file_path>test_file_1.txt</file_path>",
        "<file_content>Content of file 1.</file_content>",
        "<file_path>test_file_2.txt</file_path>",
        "<file_content>Content of file 2.</file_content>",
    ]
    try:
        files_provider = Files([path for path, _ in fixtures])
        conversation = Messages(UserMessage(files_provider))
        rendered = await conversation.render_latest()

        self.assertEqual(len(rendered), 1)
        content = rendered[0]['content']
        for fragment in expected_fragments:
            self.assertIn(fragment, content)
    finally:
        # Remove the fixture files whatever the outcome.
        for path, _ in fixtures:
            os.remove(path)
713
+
714
async def test_w_files_initialization_with_args(self):
    """Files accepts several paths as positional arguments and renders each one."""
    fixtures = [
        ("test_file_3.txt", "Content of file 3."),
        ("test_file_4.txt", "Content of file 4."),
    ]
    for path, text in fixtures:
        with open(path, "w") as handle:
            handle.write(text)

    expected_fragments = [
        "<file_path>test_file_3.txt</file_path>",
        "<file_content>Content of file 3.</file_content>",
        "<file_path>test_file_4.txt</file_path>",
        "<file_content>Content of file 4.</file_content>",
    ]
    try:
        files_provider = Files(*[path for path, _ in fixtures])
        conversation = Messages(UserMessage(files_provider))
        rendered = await conversation.render_latest()

        self.assertEqual(len(rendered), 1)
        content = rendered[0]['content']
        for fragment in expected_fragments:
            self.assertIn(fragment, content)
    finally:
        # Remove the fixture files whatever the outcome.
        for path, _ in fixtures:
            os.remove(path)
745
+
746
async def test_x_files_provider_refresh_logic(self):
    """render_latest() re-syncs Files with disk: external edits and deletions show up."""
    test_file = "test_file_refresh.txt"

    def write_fixture(text):
        # Overwrite the fixture file with ``text`` (UTF-8).
        with open(test_file, "w", encoding='utf-8') as handle:
            handle.write(text)

    write_fixture("Initial content for refresh.")
    try:
        files_provider = Files(test_file)
        conversation = Messages(UserMessage(files_provider))
        files_provider.render = AsyncMock(wraps=files_provider.render)

        # First render.
        await conversation.render_latest()
        self.assertEqual(files_provider.render.call_count, 1)

        # An external edit is picked up by the refresh inside render_latest().
        updated_content = "Updated content from external."
        write_fixture(updated_content)
        after_edit = await conversation.render_latest()
        self.assertEqual(files_provider.render.call_count, 2)
        self.assertIn(updated_content, after_edit[0]['content'])

        # After an external delete, an error marker is rendered instead.
        os.remove(test_file)
        after_delete = await conversation.render_latest()
        self.assertEqual(files_provider.render.call_count, 3)
        self.assertIn("[Error: File not found", after_delete[0]['content'])
    finally:
        if os.path.exists(test_file):
            os.remove(test_file)
783
+
784
async def test_y_files_provider_update_logic(self):
    """Files.update() with and without explicit content; the on-disk file wins on refresh."""
    test_file = "test_file_update.txt"
    initial_content = "Initial content for update."
    with open(test_file, "w", encoding='utf-8') as handle:
        handle.write(initial_content)

    try:
        files_provider = Files()  # starts with no files
        conversation = Messages(UserMessage(files_provider))
        files_provider.render = AsyncMock(wraps=files_provider.render)

        async def update_then_render(*update_args):
            # Apply update(*update_args), then return the refreshed user content.
            files_provider.update(*update_args)
            rendered = await conversation.render_latest()
            return rendered[0]['content']

        # 1. In-memory content is overwritten: render_latest() refreshes from disk.
        content = await update_then_render(test_file, "Memory content.")
        self.assertEqual(files_provider.render.call_count, 1)
        self.assertIn(initial_content, content)
        self.assertNotIn("Memory content.", content)

        # 2. A path-only update reads from disk.
        content = await update_then_render(test_file)
        self.assertEqual(files_provider.render.call_count, 2)
        self.assertIn(initial_content, content)

        # 3. A missing path renders an error marker.
        content = await update_then_render("non_existent.txt")
        self.assertEqual(files_provider.render.call_count, 3)
        self.assertIn("[Error: File not found", content)
    finally:
        if os.path.exists(test_file):
            os.remove(test_file)
821
+
822
+ async def test_z_dynamic_texts_provider(self):
823
+ """测试 Texts provider 是否支持可调用对象以实现动态内容"""
824
+ import time
825
+ from datetime import datetime
826
+
827
+ # 1. 使用 lambda 函数创建一个动态的 Texts provider
828
+ # 每次调用 render 时,它都应该返回当前时间
829
+ dynamic_text_provider = Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
830
+ messages = Messages(UserMessage(dynamic_text_provider))
831
+
832
+ # 2. 第一次渲染
833
+ rendered1 = await messages.render_latest()
834
+ time1_str = rendered1[0]['content']
835
+ self.assertIsNotNone(time1_str)
836
+
837
+ # 3. 等待一秒钟
838
+ time.sleep(1)
839
+
840
+ # 4. 第二次渲染,并期望内容已更新
841
+ rendered2 = await messages.render_latest()
842
+ time2_str = rendered2[0]['content']
843
+ self.assertIsNotNone(time2_str)
844
+
845
+ # 5. 验证两次渲染的时间戳不同
846
+ self.assertNotEqual(time1_str, time2_str, "动态 Texts provider 的内容在两次渲染之间应该更新")
847
+
848
+ async def test_z2_dynamic_texts_with_prefix(self):
849
+ """测试动态 Texts provider 包含静态前缀时也能正确更新"""
850
+ import time
851
+ from datetime import datetime
852
+ import platform
853
+
854
+ # 1. 创建一个包含静态前缀和动态内容的 provider
855
+ # 正确的用法是将整个表达式放入 lambda
856
+ dynamic_provider = Texts(lambda: f"平台信息:{platform.platform()}, 时间:{datetime.now().isoformat()}")
857
+ messages = Messages(UserMessage(dynamic_provider))
858
+
859
+ # 2. 第一次渲染
860
+ rendered1 = await messages.render_latest()
861
+ content1 = rendered1[0]['content']
862
+ self.assertIn("平台信息:", content1)
863
+
864
+ # 3. 等待一秒
865
+ time.sleep(1)
866
+
867
+ # 4. 第二次渲染
868
+ rendered2 = await messages.render_latest()
869
+ content2 = rendered2[0]['content']
870
+ self.assertIn("平台信息:", content2)
871
+
872
+ # 5. 验证两次内容不同(因为时间戳变了)
873
+ self.assertNotEqual(content1, content2, "包含静态前缀的动态 provider 内容应该更新")
874
+
875
+
876
+ # ==============================================================================
877
+ # 6. 演示
878
+ # ==============================================================================
879
async def run_demo():
    """Walk through the main Messages/provider features, printing each step.

    Scenarios:
      A. Build Messages directly from SystemMessage / UserMessage objects.
      B. Update the Files provider through ``messages.provider("files")`` and
         re-render.
      C. Pop the tools provider from the whole Messages object and insert it
         at index 0 of the second message.
      D. Render a multimodal (text + image) user message.
      E. Render a full tool-use exchange (tool call, tool result, final answer).

    Scenario D writes a placeholder ``dummy_image.png`` into the current
    working directory and removes it again when that scenario finishes.
    """
    import json
    from dataclasses import dataclass

    # --- 1. Initialize providers ---
    system_prompt_provider = Texts("你是一个AI助手。", name="system_prompt")
    tools_provider = Tools(tools_json=[{"name": "read_file"}])
    files_provider = Files()

    # --- 2. Build Messages directly through the constructor ---
    print("\n>>> 场景 A: 使用新的、优雅的构造函数直接初始化 Messages")
    messages = Messages(
        SystemMessage(system_prompt_provider, tools_provider),
        UserMessage(files_provider, Texts("这是我的初始问题。")),
        UserMessage(Texts("这是我的初始问题2。"))
    )

    print("\n--- 渲染后的初始 Messages (首次渲染,全部刷新) ---")
    for msg_dict in await messages.render_latest():
        print(msg_dict)
    print("-" * 40)

    # --- 3. Pass-through update: reach a provider via Messages and update it ---
    print("\n>>> 场景 B: 穿透更新 File Provider,渲染时自动刷新")
    files_provider_instance = messages.provider("files")
    if isinstance(files_provider_instance, Files):
        files_provider_instance.update("file1.py", "这是新的文件内容!")

    print("\n--- 再次渲染 Messages (只有文件提供者会刷新) ---")
    for msg_dict in await messages.render_latest():
        print(msg_dict)
    print("-" * 40)

    # --- 4. Global pop by name, then insert by message index ---
    print("\n>>> 场景 C: 全局 Pop 工具提供者,并 Insert 到 UserMessage 中")
    popped_tools_provider = messages.pop("tools")
    if popped_tools_provider:
        messages[1].insert(0, popped_tools_provider)
        print(f"\n已成功将 '{popped_tools_provider.name}' 提供者移动到用户消息。")

    print("\n--- Pop 和 Insert 后渲染的 Messages (验证移动效果) ---")
    for msg_dict in messages.render():
        print(msg_dict)
    print("-" * 40)

    # --- 5. Multimodal rendering ---
    print("\n>>> 场景 D: 演示多模态 (文本+图片) 渲染")
    dummy_image_path = "dummy_image.png"
    with open(dummy_image_path, "w") as f:
        f.write("This is a dummy image file.")

    try:
        multimodal_message = Messages(
            UserMessage(
                Texts("What do you see in this image?"),
                Images(url=dummy_image_path)
            )
        )
        print("\n--- 渲染后的多模态 Message ---")
        for msg_dict in await multimodal_message.render_latest():
            if isinstance(msg_dict['content'], list):
                for item in msg_dict['content']:
                    if item['type'] == 'image_url':
                        # Shorten the image URL so the printout stays readable.
                        item['image_url']['url'] = item['image_url']['url'][:80] + "..."
            print(msg_dict)
        print("-" * 40)
    finally:
        # Do not leave the placeholder file behind in the working directory.
        if os.path.exists(dummy_image_path):
            os.remove(dummy_image_path)

    # --- 6. Tool-use flow ---
    print("\n>>> 场景 E: 模拟完整的 Tool-Use 流程")

    # Minimal stand-ins for the tool_call objects an OpenAI-style SDK returns.
    @dataclass
    class MockFunction:
        name: str
        arguments: str

    @dataclass
    class MockToolCall:
        id: str
        # Required: MockFunction has no defaults, so a default_factory of
        # MockFunction would raise TypeError whenever it was invoked.
        function: MockFunction
        type: str = "function"

    tool_call_request = [
        MockToolCall(
            id="call_rddWXkDikIxllRgbPrR6XjtMVSBPv",
            function=MockFunction(name="add", arguments='{"b": 10, "a": 5}')
        )
    ]

    tool_use_messages = Messages(
        SystemMessage(Texts("You are a helpful assistant. You must use the provided tools to answer questions.")),
        UserMessage(Texts("What is the sum of 5 and 10?")),
        ToolCalls(tool_call_request),
        ToolResults(tool_call_id="call_rddWXkDikIxllRgbPrR6XjtMVSBPv", content="15"),
        AssistantMessage(Texts("The sum of 5 and 10 is 15."))
    )

    print("\n--- 渲染后的 Tool-Use Messages ---")
    print(json.dumps(await tool_use_messages.render_latest(), indent=2))
    print("-" * 40)
973
+
974
if __name__ == '__main__':
    # The file's top-level imports do not include asyncio. Without this line,
    # asyncio.run below resolves only if `from architext import *` happens to
    # re-export it.
    import asyncio

    # Run the unit tests directly, so the file also works as a plain script
    # without a test-runner CLI, then run the interactive demo.
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestContextManagement)
    runner = unittest.TextTestRunner()
    runner.run(suite)
    asyncio.run(run_demo())