aient 1.1.91__py3-none-any.whl → 1.1.92__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/architext/architext/core.py +290 -91
- aient/architext/test/openai_client.py +146 -0
- aient/architext/test/test.py +927 -0
- aient/architext/test/test_save_load.py +93 -0
- aient/models/chatgpt.py +31 -104
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/METADATA +1 -1
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/RECORD +10 -8
- aient/architext/test.py +0 -226
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/WHEEL +0 -0
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/licenses/LICENSE +0 -0
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,927 @@
|
|
1
|
+
import asyncio
import os
import sys
import unittest
from unittest.mock import AsyncMock

# Add the project root to the Python path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))


from architext import *
|
12
|
+
|
13
|
+
# ==============================================================================
|
14
|
+
# 单元测试部分
|
15
|
+
# ==============================================================================
|
16
|
+
class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
17
|
+
|
18
|
+
def setUp(self):
    """Create fresh providers before every test.

    - ``system_prompt_provider``: a ``Texts`` provider explicitly named "system_prompt".
    - ``tools_provider``: a ``Tools`` provider describing a single ``read_file`` tool.
    - ``files_provider``: an empty ``Files`` provider.
    """
    prompt_text = "你是一个AI助手。"  # "You are an AI assistant."
    self.system_prompt_provider = Texts(prompt_text, name="system_prompt")

    tool_specs = [{"name": "read_file"}]
    self.tools_provider = Tools(tools_json=tool_specs)

    self.files_provider = Files()
|
23
|
+
|
24
|
+
async def test_a_initial_construction_and_render(self):
    """Build a system + user conversation and check the first render."""
    system_msg = SystemMessage(self.system_prompt_provider, self.tools_provider)
    user_msg = UserMessage(self.files_provider, Texts("这是我的初始问题。"))  # "This is my initial question."
    messages = Messages(system_msg, user_msg)

    self.assertEqual(len(messages), 2)

    rendered = await messages.render_latest()
    self.assertEqual(len(rendered), 2)

    system_content = rendered[0]['content']
    user_content = rendered[1]['content']
    # The tools provider renders inside the system message...
    self.assertIn("<tools>", system_content)
    # ...while the empty Files provider contributes no <files> section.
    self.assertNotIn("<files>", user_content)
|
37
|
+
|
38
|
+
async def test_b_provider_passthrough_and_refresh(self):
    """A provider is rendered once, cached while unchanged, and re-rendered after ``update()``.

    A plain ``Texts`` provider is used (instead of ``Files``) so the caching
    logic is exercised without filesystem side effects.
    """
    provider = Texts("initial text")
    # Spy on render(); wraps= still delegates to the real implementation.
    render_spy = AsyncMock(wraps=provider.render)
    provider.render = render_spy
    messages = Messages(UserMessage(provider))

    # First refresh renders; a second refresh with unchanged content hits the cache.
    for expected_calls in (1, 1):
        await messages.refresh()
        self.assertEqual(render_spy.call_count, expected_calls)

    # update() marks the provider stale, so the next refresh renders again.
    provider.update("updated text")
    await messages.refresh()
    rendered = messages.render()

    self.assertEqual(render_spy.call_count, 2)
    self.assertIn("updated text", rendered[0]['content'])
|
61
|
+
|
62
|
+
async def test_c_global_pop_and_indexed_insert(self):
    """Pop a provider by name from the whole conversation, then insert it into a message by index."""
    messages = Messages(
        SystemMessage(self.system_prompt_provider, self.tools_provider),
        UserMessage(self.files_provider),
    )

    def system_has_tools(rendered_messages):
        """Return True if any rendered system message contains a <tools> section."""
        return any(
            "<tools>" in m['content']
            for m in rendered_messages
            if m['role'] == 'system'
        )

    # Initially the tools live in the system message.
    self.assertTrue(system_has_tools(await messages.render_latest()))

    # Container-level pop by provider name returns the very same instance.
    popped = messages.pop("tools")
    self.assertIs(popped, self.tools_provider)
    self.assertFalse(system_has_tools(messages.render()))

    # Re-insert at the front of the user message (index 1 in the container).
    messages[1].insert(0, popped)
    after_insert = messages.render()
    user_content = next(m['content'] for m in after_insert if m['role'] == 'user')
    self.assertTrue(user_content.startswith("<tools>"))
|
88
|
+
|
89
|
+
async def test_d_multimodal_rendering(self):
    """Render a user message that mixes text with a local image file.

    With an image present, the rendered ``content`` is a list of typed parts:
    a ``text`` part followed by an ``image_url`` part whose URL is a base64
    PNG data URL.
    """
    # Placeholder file: not a real PNG; the test only checks the data-URL prefix.
    dummy_image_path = "test_dummy_image.png"
    with open(dummy_image_path, "w") as f:
        f.write("dummy content")
    # Register cleanup immediately so the file is removed even when an
    # assertion below fails.
    self.addCleanup(os.remove, dummy_image_path)

    messages = Messages(
        UserMessage(
            Texts("Describe the image."),
            Images(url=dummy_image_path),  # name omitted on purpose: it is optional
        )
    )

    rendered = await messages.render_latest()
    self.assertEqual(len(rendered), 1)

    content = rendered[0]['content']
    self.assertIsInstance(content, list)
    self.assertEqual(len(content), 2)

    # Text part comes first, verbatim.
    self.assertEqual(content[0]['type'], 'text')
    self.assertEqual(content[0]['text'], 'Describe the image.')

    # Image part carries the file as a base64 data URL.
    self.assertEqual(content[1]['type'], 'image_url')
    self.assertIn('data:image/png;base64,', content[1]['image_url']['url'])
|
121
|
+
|
122
|
+
async def test_e_multimodal_type_switching(self):
    """After popping the only image, a multimodal message renders as a plain string again."""
    dummy_image_path = "test_dummy_image_2.png"
    with open(dummy_image_path, "w") as f:
        f.write("dummy content")
    # Register cleanup immediately so the file is removed even when an
    # assertion below fails.
    self.addCleanup(os.remove, dummy_image_path)

    messages = Messages(
        UserMessage(
            Texts("Look at this:"),
            Images(url=dummy_image_path, name="image"),  # explicit name so it can be popped
            Texts("Any thoughts?"),
        )
    )

    # 1. With the image present, content is a list: text, image, text.
    rendered_multi = await messages.render_latest()
    content_multi = rendered_multi[0]['content']
    self.assertIsInstance(content_multi, list)
    self.assertEqual(len(content_multi), 3)

    # 2. Remove the image provider by name from the whole conversation.
    popped_image = messages.pop("image")
    self.assertIsNotNone(popped_image)

    # 3. Without another refresh, render() falls back to a single string:
    #    the remaining text parts joined by a blank line.
    rendered_str = messages.render()
    content_str = rendered_str[0]['content']
    self.assertIsInstance(content_str, str)
    self.assertEqual(content_str, "Look at this:\n\nAny thoughts?")
|
155
|
+
|
156
|
+
def test_f_message_merging(self):
    """Adjacent messages with the same role are merged, both at construction and on append."""
    # 1. Construction: user, user, system, user  ->  user, system, user
    messages = Messages(
        UserMessage(Texts("Hello,")),
        UserMessage(Texts("world!")),
        SystemMessage(Texts("System prompt.")),
        UserMessage(Texts("How are you?")),
    )
    self.assertEqual(len(messages), 3)

    first = messages[0]
    self.assertEqual(len(first._items), 2)  # both user texts now live in one message
    self.assertIn("text_", first._items[1].name)
    self.assertEqual([messages[1].role, messages[2].role], ["system", "user"])

    # 2. Appending another user message merges into the trailing user message.
    messages.append(UserMessage(Texts("I am fine.")))
    self.assertEqual(len(messages), 3)
    last = messages[2]
    self.assertEqual(len(last._items), 2)
    self.assertIn("text_", last._items[1].name)

    # 3. A message with a different role is appended as-is.
    messages.append(SystemMessage(Texts("Another prompt.")))
    self.assertEqual(len(messages), 4)
    self.assertEqual(messages[3].role, "system")
|
182
|
+
|
183
|
+
async def test_g_state_inconsistency_on_direct_message_modification(self):
    """Popping a provider directly from a child Message must also update the container.

    ``Messages`` keeps a provider lookup (``_providers_index``). Removing a
    provider via ``messages[0].pop(...)`` must not leave a stale entry that
    ``messages.provider(...)`` can still resolve.
    """
    messages = Messages(
        SystemMessage(self.system_prompt_provider, self.tools_provider),
        UserMessage(self.files_provider),
    )
    # Fill every provider's cache before inspecting state.
    await messages.refresh()

    # Before: the container resolves "tools" to the exact provider instance.
    found = messages.provider("tools")
    self.assertIsNotNone(found, "初始状态下 'tools' 提供者应该能被找到")
    self.assertIs(messages.provider("tools"), self.tools_provider)

    # Pop straight from the child message, bypassing the container.
    system_message = messages[0]
    popped = system_message.pop("tools")
    self.assertIs(popped, self.tools_provider, "应该从 SystemMessage 中成功弹出 provider")
    self.assertNotIn(self.tools_provider, system_message.providers(), "provider 不应再存在于 SystemMessage 的 providers 列表中")

    # After: the container-level lookup must no longer find it.
    self.assertIsNone(messages.provider("tools"), "BUG: 直接从子消息中 pop 后,顶层索引未同步,仍然可以找到 provider")

    # Rendering reads the Message objects themselves, so the tools section is gone.
    rendered_messages = messages.render()
    self.assertGreater(len(rendered_messages), 0, "渲染后的消息列表不应为空")
    self.assertNotIn("<tools>", rendered_messages[0]['content'], "渲染结果中不应再包含 'tools' 的内容,证明数据源已更新")
|
220
|
+
|
221
|
+
async def test_h_pop_message_by_index(self):
    """``Messages.pop(int)`` removes and returns the Message at that position."""
    user_provider = Texts("User message 1")
    messages = Messages(
        SystemMessage(Texts("System message")),
        UserMessage(user_provider),
        AssistantMessage(Texts("Assistant response")),
    )

    self.assertEqual(len(messages), 3)
    self.assertIsNotNone(messages.provider(user_provider.name))

    # Remove the middle (user) message.
    popped = messages.pop(1)
    self.assertIsInstance(popped, UserMessage)
    popped_providers = popped.providers()
    self.assertEqual(len(popped_providers), 1)
    self.assertEqual(popped_providers[0].name, user_provider.name)

    # The remaining messages close the gap, in order.
    self.assertEqual(len(messages), 2)
    self.assertEqual(messages[0].role, "system")
    self.assertEqual(messages[1].role, "assistant")
    # Providers of the popped message are dropped from the container lookup.
    self.assertIsNone(messages.provider(user_provider.name))

    # An out-of-range index returns None and leaves the container untouched.
    self.assertIsNone(messages.pop(99))
    self.assertEqual(len(messages), 2)
|
254
|
+
|
255
|
+
async def test_i_generic_update_and_refresh(self):
    """``update()`` on Texts, Tools and Images replaces content and marks each provider stale.

    Every provider's ``render`` is wrapped in an ``AsyncMock`` spy that still
    delegates to the real method, so the test can count re-renders.
    """
    # 1. Providers under test.
    text_provider = Texts("Hello")
    tools_provider = Tools([{"name": "tool_A"}])

    dummy_image_path = "test_dummy_image_3.png"
    with open(dummy_image_path, "w") as f:
        f.write("dummy content")
    # Register cleanup immediately so the file is removed even if an assertion fails.
    self.addCleanup(os.remove, dummy_image_path)
    image_provider = Images(url=dummy_image_path, name="logo")

    messages = Messages(UserMessage(text_provider, tools_provider, image_provider))

    # Spy on render(); wraps= keeps the real behaviour.
    text_provider.render = AsyncMock(wraps=text_provider.render)
    tools_provider.render = AsyncMock(wraps=tools_provider.render)
    image_provider.render = AsyncMock(wraps=image_provider.render)

    # 2. First render: every provider renders exactly once. Because an image is
    #    present, content is a list of parts and text is read via ['text'].
    rendered_initial = await messages.render_latest()
    self.assertIn("Hello", rendered_initial[0]['content'][0]['text'])
    self.assertIn("tool_A", rendered_initial[0]['content'][1]['text'])
    self.assertEqual(text_provider.render.call_count, 1)
    self.assertEqual(tools_provider.render.call_count, 1)
    self.assertEqual(image_provider.render.call_count, 1)

    # 3. Update every provider.
    text_provider.update("Goodbye")
    tools_provider.update([{"name": "tool_B"}])

    new_dummy_image_path = "test_dummy_image_4.png"
    with open(new_dummy_image_path, "w") as f:
        f.write("new dummy content")
    self.addCleanup(os.remove, new_dummy_image_path)
    image_provider.update(url=new_dummy_image_path)

    # Refreshing an updated provider directly re-renders it right away,
    # without going through the Messages container.
    await text_provider.refresh()
    self.assertEqual(text_provider.render.call_count, 2)

    # 4. Re-render the conversation; the new content must show up.
    rendered_updated = await messages.render_latest()
    self.assertIn("Goodbye", rendered_updated[0]['content'][0]['text'])
    self.assertIn("tool_B", rendered_updated[0]['content'][1]['text'])

    # Each provider re-rendered exactly once after its update. text_provider
    # stays at 2: the direct refresh above already re-rendered it, and
    # render_latest() does not render it a second time.
    self.assertEqual(text_provider.render.call_count, 2)
    self.assertEqual(tools_provider.render.call_count, 2)
    self.assertEqual(image_provider.render.call_count, 2)
|
305
|
+
|
306
|
+
async def test_j_pop_last_message_without_arguments(self):
    """``Messages.pop()`` with no argument removes the last message, and returns None once empty."""
    m1 = SystemMessage(Texts("System"))
    m2 = UserMessage(Texts("User"))
    m3 = AssistantMessage(Texts("Assistant"))
    messages = Messages(m1, m2, m3)
    self.assertEqual(len(messages), 3)

    # The first pop takes the tail; the new tail is the previous message.
    self.assertIs(messages.pop(), m3)
    self.assertEqual(len(messages), 2)
    self.assertIs(messages[-1], m2)

    # Keep popping from the back until the container is empty.
    for expected, remaining in ((m2, 1), (m1, 0)):
        self.assertIs(messages.pop(), expected)
        self.assertEqual(len(messages), remaining)

    # Popping an empty container yields None instead of raising.
    self.assertIsNone(messages.pop())
|
335
|
+
|
336
|
+
async def test_k_image_provider_with_base64_url(self):
    """An Images provider given a ``data:`` URL passes it through unchanged, also after ``update()``."""
    # Described as a 1x1 transparent PNG, already base64-encoded.
    base64_image_url = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="

    messages = Messages(
        UserMessage(
            Texts("This is a base64 image."),
            Images(url=base64_image_url, name="base64_img"),
        )
    )

    rendered = await messages.render_latest()
    self.assertEqual(len(rendered), 1)

    content = rendered[0]['content']
    self.assertIsInstance(content, list)
    self.assertEqual(len(content), 2)
    text_part, image_part = content

    self.assertEqual(text_part['type'], 'text')
    self.assertEqual(text_part['text'], 'This is a base64 image.')

    self.assertEqual(image_part['type'], 'image_url')
    self.assertEqual(image_part['image_url']['url'], base64_image_url)

    # Swapping in another data: URL through update() shows up on the next render.
    provider = messages.provider("base64_img")
    self.assertIsNotNone(provider)

    new_base64_url = "data:image/gif;base64,R0lGODlhAQABAIAAAAAAAP///yH5BAEAAAAALAAAAAABAAEAAAIBRAA7"
    provider.update(url=new_base64_url)

    rendered_updated = await messages.render_latest()
    updated_image_part = rendered_updated[0]['content'][1]
    self.assertEqual(updated_image_part['image_url']['url'], new_base64_url)
|
373
|
+
|
374
|
+
def test_l_role_message_factory(self):
    """``RoleMessage(role, ...)`` returns the Message subclass matching *role* and rejects unknown roles."""
    cases = [
        (RoleMessage('system', Texts("System content")), SystemMessage, 'system'),
        (RoleMessage('user', Texts("User content")), UserMessage, 'user'),
        (RoleMessage('assistant', Texts("Assistant content")), AssistantMessage, 'assistant'),
    ]

    for message, expected_cls, expected_role in cases:
        self.assertIsInstance(message, expected_cls)
        self.assertEqual(message.role, expected_role)

    # An unsupported role is a ValueError.
    with self.assertRaises(ValueError):
        RoleMessage('invalid_role', Texts("Content"))
|
390
|
+
|
391
|
+
async def test_m_optional_name_for_texts(self):
    """``Texts`` takes an optional ``name``; when omitted, a ``text_``-prefixed name is generated.

    Identical text yields the identical generated name; different text yields
    a different one.
    """
    auto_named = Texts("This is a test.")
    self.assertTrue(auto_named.name.startswith("text_"))

    explicitly_named = Texts("This is another test.", name="my_name")
    self.assertEqual(explicitly_named.name, "my_name")

    # Same text -> same generated name.
    same_text = Texts("This is a test.")
    self.assertEqual(auto_named.name, same_text.name)

    # Different text -> different generated name.
    other_text = Texts("This is a different test.")
    self.assertNotEqual(auto_named.name, other_text.name)

    # The generated name works as a lookup key inside a Messages container.
    messages = Messages(UserMessage(auto_named))
    self.assertIs(messages.provider(auto_named.name), auto_named)
|
413
|
+
|
414
|
+
async def test_n_string_to_texts_conversion(self):
    """Plain ``str`` arguments to a Message become ``Texts`` providers; unsupported types raise TypeError."""
    # 1. A bare string next to an existing provider.
    user_message = UserMessage(self.files_provider, "This is a raw string.")
    providers = user_message.providers()
    self.assertEqual(len(providers), 2)
    files_part, text_part = providers
    self.assertIsInstance(files_part, Files)
    self.assertIsInstance(text_part, Texts)

    # Content is fetched asynchronously, so refresh before reading it.
    await text_part.refresh()
    content_block = text_part.get_content_block()
    self.assertIsNotNone(content_block)
    self.assertEqual(content_block.content, "This is a raw string.")

    # 2. Inside a container; SystemMessage accepts a bare string as well.
    messages = Messages(SystemMessage("System prompt here."), user_message)
    await messages.refresh()
    rendered = messages.render()

    self.assertEqual(len(rendered), 2)
    self.assertEqual(rendered[0]['content'], "System prompt here.")
    # The empty Files provider contributes nothing, leaving just the string.
    self.assertEqual(rendered[1]['content'], "This is a raw string.")

    # 3. The RoleMessage factory applies the same conversion.
    factory_user_msg = RoleMessage('user', "Factory-created string.")
    self.assertIsInstance(factory_user_msg, UserMessage)
    self.assertIsInstance(factory_user_msg.providers()[0], Texts)

    # 4. Unsupported argument types are rejected.
    with self.assertRaises(TypeError):
        UserMessage(123)
|
453
|
+
|
454
|
+
async def test_o_list_to_providers_conversion(self):
    """OpenAI-style content-part lists passed to a Message become matching providers.

    ``{'type': 'text'}`` parts become ``Texts`` and ``{'type': 'image_url'}``
    parts become ``Images``; an unknown part type raises ValueError.
    """
    image_data_url = 'data:image/png;base64,VGhpcyBpcyBhIGR1bW15IGltYWdlIGZpbGUu'
    mixed_parts = [
        {'type': 'text', 'text': 'Describe the following image.'},
        {'type': 'image_url', 'image_url': {'url': image_data_url}},
    ]

    # 1. Mixed text + image list.
    mixed_message = UserMessage(mixed_parts)
    providers = mixed_message.providers()
    self.assertEqual(len(providers), 2)
    self.assertIsInstance(providers[0], Texts)
    self.assertIsInstance(providers[1], Images)

    await asyncio.gather(*(p.refresh() for p in providers))
    self.assertEqual(providers[0].get_content_block().content, 'Describe the following image.')
    self.assertEqual(providers[1].get_content_block().content, image_data_url)

    # 2. Text-only list: every part becomes a Texts provider.
    text_only_message = UserMessage([
        {'type': 'text', 'text': 'First line.'},
        {'type': 'text', 'text': 'Second line.'},
    ])
    text_only_providers = text_only_message.providers()
    self.assertEqual(len(text_only_providers), 2)
    for provider in text_only_providers:
        self.assertIsInstance(provider, Texts)

    # 3. Rendered inside a container, the parts keep their types and order.
    messages = Messages(UserMessage(mixed_parts))
    rendered = await messages.render_latest()

    self.assertEqual(len(rendered), 1)
    rendered_content = rendered[0]['content']
    self.assertIsInstance(rendered_content, list)
    self.assertEqual(len(rendered_content), 2)
    self.assertEqual([part['type'] for part in rendered_content], ['text', 'image_url'])

    # 4. An unknown part type is rejected.
    with self.assertRaises(ValueError):
        UserMessage([{'type': 'invalid_type'}])
|
498
|
+
|
499
|
+
async def test_p_empty_message_boolean_context(self):
    """A Message without providers is falsy; one with providers is truthy."""
    # Freshly constructed, empty vs. non-empty.
    self.assertFalse(UserMessage(), "一个空的 UserMessage 在布尔上下文中应该为 False")
    self.assertTrue(UserMessage("Hello"), "一个非空的 UserMessage 在布尔上下文中应该为 True")

    # A message becomes falsy once its last provider is popped.
    message = UserMessage(Texts("content", name="removable"))
    self.assertTrue(message, "消息在移除前应为 True")
    message.pop("removable")
    self.assertFalse(message, "消息在最后一个 provider 被移除后应为 False")
|
514
|
+
|
515
|
+
async def test_q_string_addition_to_message(self):
    """``str + Message`` prepends and ``Message + str`` appends, each returning a new message.

    The left/right operand message itself is left unmodified.
    """
    original_message = UserMessage("hello")

    # str + UserMessage -> the string becomes the first provider.
    new_message = "hi" + original_message
    self.assertIsInstance(new_message, UserMessage, "结果应该是一个 UserMessage 实例")
    self.assertEqual(len(new_message.providers()), 2, "新消息应该包含两个 provider")

    first, second = new_message.providers()
    self.assertIsInstance(first, Texts, "第一个 provider 应该是 Texts 类型")
    self.assertIsInstance(second, Texts, "第二个 provider 应该是 Texts 类型")

    # Content is fetched asynchronously; refresh both before reading.
    await asyncio.gather(first.refresh(), second.refresh())
    self.assertEqual(first.get_content_block().content, "hi", "第一个 provider 的内容应该是 'hi'")
    self.assertEqual(second.get_content_block().content, "hello", "第二个 provider 的内容应该是 'hello'")

    # The operand is not mutated.
    self.assertEqual(len(original_message.providers()), 1, "原始消息不应该被修改")

    # UserMessage + str -> the string becomes the last provider.
    appended = original_message + "world"
    self.assertIsInstance(appended, UserMessage)
    appended_providers = appended.providers()
    self.assertEqual(len(appended_providers), 2)

    await asyncio.gather(*[p.refresh() for p in appended_providers])
    self.assertEqual(
        [p.get_content_block().content for p in appended_providers],
        ["hello", "world"],
    )
|
549
|
+
|
550
|
+
async def test_r_message_addition_and_flattening(self):
    """Concatenation and nested construction both yield a flat list of providers."""

    async def refreshed_contents(message):
        """Refresh every provider of *message* concurrently; return their contents in order."""
        providers = message.providers()
        await asyncio.gather(*[p.refresh() for p in providers])
        return [p.get_content_block().content for p in providers]

    # 1. "str" + UserMessage
    combined_message = "hi" + UserMessage("hello")
    self.assertIsInstance(combined_message, UserMessage)
    self.assertEqual(len(combined_message.providers()), 2)
    self.assertEqual(await refreshed_contents(combined_message), ["hi", "hello"])

    # 2. UserMessage(UserMessage(...)) is flattened rather than nested.
    nested_message = UserMessage(UserMessage("item1", "item2"))
    self.assertEqual(len(nested_message.providers()), 2)
    for provider in nested_message.providers():
        self.assertIsInstance(provider, Texts)
    self.assertEqual(await refreshed_contents(nested_message), ["item1", "item2"])

    # 3. Both at once: wrapping a concatenation result.
    final_message = UserMessage("hi" + UserMessage("hello"))
    self.assertIsInstance(final_message, UserMessage)
    self.assertEqual(len(final_message.providers()), 2)
    self.assertEqual(await refreshed_contents(final_message), ["hi", "hello"])
|
584
|
+
|
585
|
+
async def test_s_len_and_pop_with_get_method(self):
    """``len(messages)`` counts messages; a popped Message supports dict-style ``.get(key, default)``."""
    messages = Messages(
        SystemMessage("System prompt"),
        UserMessage("User question"),
        AssistantMessage("Assistant answer"),
    )
    self.assertEqual(len(messages), 3, "len(messages) 应该返回消息的数量")

    # Pop the middle message.
    popped_message = messages.pop(1)
    self.assertIsNotNone(popped_message, "pop(1) 应该返回一个消息对象")
    self.assertIsInstance(popped_message, UserMessage)
    self.assertEqual(popped_message.get("role"), "user", "弹出的消息应该可以通过 .get('role') 获取角色")

    # The container shrinks and keeps the remaining order.
    self.assertEqual(len(messages), 2, "pop() 后消息数量应该减少")
    self.assertEqual([messages[0].role, messages[1].role], ["system", "assistant"])

    # Unknown keys fall back to None, or to the supplied default.
    self.assertIsNone(popped_message.get("non_existent_key"), ".get() 对不存在的键应该返回 None")
    self.assertEqual(popped_message.get("non_existent_key", "default"), "default", ".get() 应支持默认值")
|
613
|
+
|
614
|
+
async def test_t_pop_and_get_tool_calls(self):
    """A popped ToolCalls message returns its original list via ``.get('tool_calls')``.

    On an ordinary message the same lookup yields None.
    """
    from dataclasses import dataclass, field

    @dataclass
    class MockFunction:
        """Minimal stand-in for a tool-call function payload (name + JSON arguments)."""
        name: str
        arguments: str

    @dataclass
    class MockToolCall:
        """Minimal stand-in for a tool-call object (id, type, function)."""
        id: str
        type: str = "function"
        function: MockFunction = field(default_factory=lambda: MockFunction("", ""))

    tool_call_list = [
        MockToolCall(id="call_123", function=MockFunction(name="test", arguments="{}")),
    ]
    messages = Messages(UserMessage("A regular message"), ToolCalls(tool_calls=tool_call_list))

    # Pop the ToolCalls message.
    tool_call_message = messages.pop(1)
    self.assertIsInstance(tool_call_message, ToolCalls)

    retrieved = tool_call_message.get("tool_calls")
    self.assertIsNotNone(retrieved)
    self.assertEqual(len(retrieved), 1)
    # Identity, not mere equality: the very same list object comes back.
    self.assertIs(retrieved, tool_call_list)

    # An ordinary message has no tool_calls.
    user_message = messages.pop(0)
    self.assertIsInstance(user_message, UserMessage)
    self.assertIsNone(user_message.get("tool_calls"), "在没有 tool_calls 属性的消息上 .get() 应该返回 None")
|
651
|
+
|
652
|
+
async def test_u_message_dictionary_style_access(self):
    """A Message supports ``message['role']`` / ``message['content']`` and raises KeyError otherwise."""
    messages = Messages(
        UserMessage("Hello, world!"),
        AssistantMessage(
            "A picture:",
            Images(url="data:image/png;base64,FAKE", name="fake_image"),
        ),
    )
    await messages.refresh()

    # Text-only message: content is a plain string.
    user_msg = messages[0]
    self.assertEqual(user_msg['role'], 'user')
    self.assertEqual(user_msg['content'], "Hello, world!")

    # Text + image message: content is a list of typed parts.
    assistant_msg = messages[1]
    self.assertEqual(assistant_msg['role'], 'assistant')
    parts = assistant_msg['content']
    self.assertIsInstance(parts, list)
    self.assertEqual(len(parts), 2)
    self.assertEqual([part['type'] for part in parts], ['text', 'image_url'])

    # Unknown keys behave like a dict miss.
    with self.assertRaises(KeyError):
        _ = user_msg['non_existent_key']
|
681
|
+
|
682
|
+
async def test_v_files_initialization_with_list(self):
    """``Files([...])`` accepts a list of paths and renders each file's path and content."""
    test_file_1, test_file_2 = "test_file_1.txt", "test_file_2.txt"
    for path, text in ((test_file_1, "Content of file 1."), (test_file_2, "Content of file 2.")):
        with open(path, "w") as f:
            f.write(text)

    try:
        files_provider = Files([test_file_1, test_file_2])
        messages = Messages(UserMessage(files_provider))
        rendered = await messages.render_latest()

        self.assertEqual(len(rendered), 1)
        content = rendered[0]['content']
        for expected in (
            "<file_path>test_file_1.txt</file_path>",
            "<file_content>Content of file 1.</file_content>",
            "<file_path>test_file_2.txt</file_path>",
            "<file_content>Content of file 2.</file_content>",
        ):
            self.assertIn(expected, content)
    finally:
        # Always remove the fixture files, pass or fail.
        for path in (test_file_1, test_file_2):
            os.remove(path)
|
713
|
+
|
714
|
+
async def test_w_files_initialization_with_args(self):
    """Files can be initialised from several positional file-path arguments.

    Creates two scratch files (relative paths, i.e. in the current working
    directory), builds ``Files(path3, path4)``, renders it inside a
    ``UserMessage`` and checks that both paths and both contents appear in
    the rendered content. The scratch files are always removed.
    """
    test_file_3 = "test_file_3.txt"
    test_file_4 = "test_file_4.txt"
    try:
        # 1. Create the fixture files inside the try so that a failure while
        #    writing the second file still cleans up the first one.
        with open(test_file_3, "w", encoding="utf-8") as f:
            f.write("Content of file 3.")
        with open(test_file_4, "w", encoding="utf-8") as f:
            f.write("Content of file 4.")

        # 2. Initialise the Files provider with multiple positional paths.
        files_provider = Files(test_file_3, test_file_4)

        # 3. Wrap it in Messages and render.
        messages = Messages(UserMessage(files_provider))
        rendered = await messages.render_latest()

        # 4. Verify the rendered output.
        self.assertEqual(len(rendered), 1)
        content = rendered[0]['content']
        self.assertIn("<file_path>test_file_3.txt</file_path>", content)
        self.assertIn("<file_content>Content of file 3.</file_content>", content)
        self.assertIn("<file_path>test_file_4.txt</file_path>", content)
        self.assertIn("<file_content>Content of file 4.</file_content>", content)

    finally:
        # 5. Remove whichever fixture files exist; the existence check keeps a
        #    missing/failed first file from skipping cleanup of the second.
        for path in (test_file_3, test_file_4):
            if os.path.exists(path):
                os.remove(path)
|
745
|
+
|
746
|
+
async def test_x_files_provider_refresh_logic(self):
    """render_latest() must resynchronise the Files provider with the disk.

    Sequence: render once, overwrite the file on disk and render again
    (new text expected), then delete the file and render a third time
    (a "file not found" error expected). The provider's ``render`` is
    wrapped in an AsyncMock spy so each step can assert it ran once more.
    """
    path = "test_file_refresh.txt"

    def write_text(text):
        # Overwrite the fixture file with *text*, UTF-8 encoded.
        with open(path, "w", encoding='utf-8') as handle:
            handle.write(text)

    write_text("Initial content for refresh.")

    try:
        provider = Files(path)
        conversation = Messages(UserMessage(provider))
        # Spy on render while still delegating to the real implementation.
        provider.render = AsyncMock(wraps=provider.render)

        # Step 1: first render.
        await conversation.render_latest()
        self.assertEqual(provider.render.call_count, 1)

        # Step 2: change the file behind the provider's back; the next
        # render_latest() is expected to pick the change up via refresh().
        new_text = "Updated content from external."
        write_text(new_text)
        after_edit = await conversation.render_latest()
        self.assertEqual(provider.render.call_count, 2)
        self.assertIn(new_text, after_edit[0]['content'])

        # Step 3: delete the file; the next render must report it missing.
        os.remove(path)
        after_delete = await conversation.render_latest()
        self.assertEqual(provider.render.call_count, 3)
        self.assertIn("[Error: File not found", after_delete[0]['content'])

    finally:
        if os.path.exists(path):
            os.remove(path)
|
783
|
+
|
784
|
+
async def test_y_files_provider_update_logic(self):
    """Exercise both modes of ``Files.update`` plus the missing-file case.

    1. ``update(path, content)``: the in-memory text does not survive the
       next render_latest(), whose refresh re-reads the file from disk.
    2. ``update(path)``: the content is loaded from disk.
    3. ``update(missing_path)``: rendering reports "file not found".
    """
    path = "test_file_update.txt"
    on_disk = "Initial content for update."
    with open(path, "w", encoding='utf-8') as handle:
        handle.write(on_disk)

    try:
        provider = Files()  # begins with no files registered
        conversation = Messages(UserMessage(provider))
        # Spy on render while still delegating to the real implementation.
        provider.render = AsyncMock(wraps=provider.render)

        # Mode 1: push content from memory. render_latest() triggers a
        # refresh that reads the disk and overwrites the in-memory text;
        # that is the intended behaviour.
        provider.update(path, "Memory content.")
        first = await conversation.render_latest()
        self.assertEqual(provider.render.call_count, 1)
        # The disk content wins over what was passed in memory.
        self.assertIn(on_disk, first[0]['content'])
        self.assertNotIn("Memory content.", first[0]['content'])

        # Mode 2: no content argument, so the file is read from disk.
        provider.update(path)
        second = await conversation.render_latest()
        self.assertEqual(provider.render.call_count, 2)
        self.assertIn(on_disk, second[0]['content'])

        # Mode 3: a path that does not exist must render an error marker.
        provider.update("non_existent.txt")
        third = await conversation.render_latest()
        self.assertEqual(provider.render.call_count, 3)
        self.assertIn("[Error: File not found", third[0]['content'])

    finally:
        if os.path.exists(path):
            os.remove(path)
|
821
|
+
|
822
|
+
|
823
|
+
# ==============================================================================
|
824
|
+
# 6. 演示
|
825
|
+
# ==============================================================================
|
826
|
+
async def run_demo():
    """Walk through the main Messages/provider features, printing each result.

    Scenarios:
      A. build Messages directly from SystemMessage/UserMessage objects;
      B. update the Files provider reached through Messages, then re-render;
      C. pop the tools provider globally and insert it into a UserMessage;
      D. render a multimodal (text + image) message from a scratch file,
         which is removed afterwards;
      E. render a complete tool-use exchange built from mock tool calls.
    """
    # --- 1. Initialise the providers ---
    system_prompt_provider = Texts("你是一个AI助手。", name="system_prompt")
    tools_provider = Tools(tools_json=[{"name": "read_file"}])
    files_provider = Files()

    # --- 2. Scenario A: build Messages directly via the constructor ---
    print("\n>>> 场景 A: 使用新的、优雅的构造函数直接初始化 Messages")
    messages = Messages(
        SystemMessage(system_prompt_provider, tools_provider),
        UserMessage(files_provider, Texts("这是我的初始问题。")),
        UserMessage(Texts("这是我的初始问题2。"))
    )

    print("\n--- 渲染后的初始 Messages (首次渲染,全部刷新) ---")
    for msg_dict in await messages.render_latest(): print(msg_dict)
    print("-" * 40)

    # --- 3. Scenario B: update a provider through Messages ---
    print("\n>>> 场景 B: 穿透更新 File Provider,渲染时自动刷新")
    files_provider_instance = messages.provider("files")
    if isinstance(files_provider_instance, Files):
        files_provider_instance.update("file1.py", "这是新的文件内容!")

    print("\n--- 再次渲染 Messages (只有文件提供者会刷新) ---")
    for msg_dict in await messages.render_latest(): print(msg_dict)
    print("-" * 40)

    # --- 4. Scenario C: global pop, then insert by index ---
    print("\n>>> 场景 C: 全局 Pop 工具提供者,并 Insert 到 UserMessage 中")
    popped_tools_provider = messages.pop("tools")
    if popped_tools_provider:
        messages[1].insert(0, popped_tools_provider)
        print(f"\n已成功将 '{popped_tools_provider.name}' 提供者移动到用户消息。")

    print("\n--- Pop 和 Insert 后渲染的 Messages (验证移动效果) ---")
    for msg_dict in messages.render(): print(msg_dict)
    print("-" * 40)

    # --- 5. Scenario D: multimodal rendering ---
    print("\n>>> 场景 D: 演示多模态 (文本+图片) 渲染")
    dummy_image_path = "dummy_image.png"
    with open(dummy_image_path, "w") as f:
        f.write("This is a dummy image file.")

    try:
        multimodal_message = Messages(
            UserMessage(
                Texts("What do you see in this image?"),
                Images(url=dummy_image_path)
            )
        )
        print("\n--- 渲染后的多模态 Message ---")
        for msg_dict in await multimodal_message.render_latest():
            if isinstance(msg_dict['content'], list):
                for item in msg_dict['content']:
                    if item['type'] == 'image_url':
                        # Shorten the image URL so the printout stays readable
                        # (presumably a long data URL; not verified here).
                        item['image_url']['url'] = item['image_url']['url'][:80] + "..."
            print(msg_dict)
    finally:
        # Remove the scratch image so the demo leaves no artefacts on disk.
        if os.path.exists(dummy_image_path):
            os.remove(dummy_image_path)
    print("-" * 40)

    # --- 6. Scenario E: a complete tool-use flow ---
    print("\n>>> 场景 E: 模拟完整的 Tool-Use 流程")
    # Lightweight stand-ins for the tool_call objects returned by the OpenAI SDK.
    from dataclasses import dataclass, field

    @dataclass
    class MockFunction:
        name: str
        arguments: str

    @dataclass
    class MockToolCall:
        id: str
        type: str = "function"
        # MockFunction has required fields, so `default_factory=MockFunction`
        # would raise TypeError whenever `function` is omitted; build a valid
        # empty function instead.
        function: MockFunction = field(
            default_factory=lambda: MockFunction(name="", arguments="{}")
        )

    tool_call_request = [
        MockToolCall(
            id="call_rddWXkDikIxllRgbPrR6XjtMVSBPv",
            function=MockFunction(name="add", arguments='{"b": 10, "a": 5}')
        )
    ]

    tool_use_messages = Messages(
        SystemMessage(Texts("You are a helpful assistant. You must use the provided tools to answer questions.")),
        UserMessage(Texts("What is the sum of 5 and 10?")),
        ToolCalls(tool_call_request),
        ToolResults(tool_call_id="call_rddWXkDikIxllRgbPrR6XjtMVSBPv", content="15"),
        AssistantMessage(Texts("The sum of 5 and 10 is 15."))
    )

    print("\n--- 渲染后的 Tool-Use Messages ---")
    import json
    print(json.dumps(await tool_use_messages.render_latest(), indent=2))
    print("-" * 40)
|
920
|
+
|
921
|
+
if __name__ == '__main__':
    # Run the unit-test suite with a plain TextTestRunner, then the demo.
    # asyncio is imported explicitly: the file's top-level imports do not
    # include it, so otherwise it would have to leak in via
    # `from architext import *`, which is fragile.
    import asyncio

    loader = unittest.TestLoader()
    suite = loader.loadTestsFromTestCase(TestContextManagement)
    runner = unittest.TextTestRunner()
    runner.run(suite)
    asyncio.run(run_demo())
|