aient 1.1.91__py3-none-any.whl → 1.1.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,93 @@
1
+ import asyncio
2
+ import os
3
+ import sys
4
+
5
+ # Add the project root to the Python path
6
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
7
+
8
+ from architext.core import (
9
+ Messages,
10
+ SystemMessage,
11
+ UserMessage,
12
+ AssistantMessage,
13
+ ToolCalls,
14
+ ToolResults,
15
+ Texts,
16
+ Tools,
17
+ Files,
18
+ )
19
+
20
+ async def main():
21
+ """
22
+ Tests the save and load functionality of the Messages class using pickle.
23
+ """
24
+ print("--- Test Save/Load (pickle) ---")
25
+
26
+ # 1. Create an initial Messages object
27
+ messages = Messages(
28
+ SystemMessage(Texts("system_prompt", "You are a helpful assistant.")),
29
+ UserMessage(Texts("user_input", "What is the weather in Shanghai?")),
30
+ AssistantMessage(Texts("thought", "I should call a tool for this.")),
31
+ ToolCalls(tool_calls=[{
32
+ 'id': 'call_1234',
33
+ 'type': 'function',
34
+ 'function': {'name': 'get_weather', 'arguments': '{"location": "Shanghai"}'}
35
+ }]),
36
+ ToolResults(tool_call_id="call_1234", content='{"temperature": "25°C"}')
37
+ )
38
+
39
+ # Add a message with Files provider
40
+ files_provider = Files()
41
+ files_provider.update("test.txt", "This is a test file.")
42
+ messages.append(UserMessage(files_provider))
43
+
44
+ # Render the original messages
45
+ original_render = await messages.render_latest()
46
+ print("Original Messages Render:")
47
+ print(original_render)
48
+
49
+ # 2. Save the messages to a file
50
+ file_path = "test_messages.pkl"
51
+ messages.save(file_path)
52
+ print(f"\nMessages saved to {file_path}")
53
+
54
+ assert os.path.exists(file_path), "Save file was not created."
55
+
56
+ # 3. Load the messages from the file
57
+ loaded_messages = Messages.load(file_path)
58
+ print("\nMessages loaded from file.")
59
+
60
+ assert loaded_messages is not None, "Loaded messages should not be None."
61
+
62
+ # Render the loaded messages
63
+ loaded_render = await loaded_messages.render_latest()
64
+ print("\nLoaded Messages Render:")
65
+ print(loaded_render)
66
+
67
+ # 4. Compare the original and loaded content
68
+ assert original_render == loaded_render, "Rendered content of original and loaded messages do not match."
69
+ print("\n✅ Assertion passed: Original and loaded message renders are identical.")
70
+
71
+ # 5. Check if the loaded object retains its class structure and methods
72
+ print(f"\nType of loaded object: {type(loaded_messages)}")
73
+ assert isinstance(loaded_messages, Messages), "Loaded object is not a Messages instance."
74
+
75
+ # Test pop functionality on the loaded object
76
+ popped_item = loaded_messages.pop(0)
77
+ assert isinstance(popped_item, SystemMessage), "Popped item is not a SystemMessage."
78
+ print(f"Popped first message: {popped_item}")
79
+
80
+ popped_render = await loaded_messages.render_latest()
81
+ print("\nRender after popping first message from loaded object:")
82
+ print(popped_render)
83
+ assert len(popped_render) == len(original_render) - 1, "Popping a message did not reduce the message count."
84
+ print("✅ Assertion passed: Pop functionality works on the loaded object.")
85
+
86
+ # 6. Clean up the test file
87
+ os.remove(file_path)
88
+ print(f"\nCleaned up {file_path}.")
89
+
90
+ print("\n--- Test Completed Successfully ---")
91
+
92
+ if __name__ == "__main__":
93
+ asyncio.run(main())
aient/models/chatgpt.py CHANGED
@@ -14,6 +14,7 @@ from ..plugins import PLUGINS, get_tools_result_async, function_call_list, updat
14
14
  from ..utils.scripts import safe_get, async_generator_to_sync, parse_function_xml, parse_continuous_json, convert_functions_to_xml, remove_xml_tags_and_content
15
15
  from ..core.request import prepare_request_payload
16
16
  from ..core.response import fetch_response_stream, fetch_response
17
+ from ..architext.architext import Messages, SystemMessage, UserMessage, AssistantMessage, ToolCalls, ToolResults, Texts, RoleMessage, Images, Files
17
18
 
18
19
  class APITimeoutError(Exception):
19
20
  """Custom exception for API timeout errors."""
@@ -88,7 +89,6 @@ class chatgpt(BaseLLM):
88
89
  print_log: bool = False,
89
90
  tools: Optional[Union[list, str, Callable]] = [],
90
91
  function_call_max_loop: int = 3,
91
- cut_history_by_function_name: str = "",
92
92
  cache_messages: list = None,
93
93
  logger: logging.Logger = None,
94
94
  check_done: bool = False,
@@ -109,8 +109,6 @@ class chatgpt(BaseLLM):
109
109
  self.conversation["default"] = cache_messages
110
110
  self.function_calls_counter = {}
111
111
  self.function_call_max_loop = function_call_max_loop
112
- self.cut_history_by_function_name = cut_history_by_function_name
113
- self.latest_file_content = {}
114
112
  self.check_done = check_done
115
113
 
116
114
  if logger:
@@ -164,95 +162,48 @@ class chatgpt(BaseLLM):
164
162
  if convo_id not in self.conversation:
165
163
  self.reset(convo_id=convo_id)
166
164
  if function_name == "" and message:
167
- self.conversation[convo_id].append({"role": role, "content": message})
165
+ self.conversation[convo_id].append(RoleMessage(role, message))
168
166
  elif function_name != "" and message:
169
- # 删除从 cut_history_by_function_name 以后的所有历史记录
170
- if function_name == self.cut_history_by_function_name:
171
- matching_message = next(filter(lambda x: safe_get(x, "tool_calls", 0, "function", "name", default="") == 'get_next_pdf', self.conversation[convo_id]), None)
172
- if matching_message is not None:
173
- self.conversation[convo_id] = self.conversation[convo_id][:self.conversation[convo_id].index(matching_message)]
174
-
175
167
  if not (all(value == False for value in self.plugins.values()) or self.use_plugins == False):
176
- self.conversation[convo_id].append({
177
- "role": "assistant",
178
- "tool_calls": [
179
- {
180
- "id": function_call_id,
181
- "type": "function",
182
- "function": {
183
- "name": function_name,
184
- "arguments": function_arguments,
185
- },
186
- }
187
- ],
188
- })
189
- self.conversation[convo_id].append({"role": role, "tool_call_id": function_call_id, "content": message})
168
+ tool_calls = [
169
+ {
170
+ "id": function_call_id,
171
+ "type": "function",
172
+ "function": {
173
+ "name": function_name,
174
+ "arguments": function_arguments,
175
+ },
176
+ }
177
+ ]
178
+ self.conversation[convo_id].append(ToolCalls(tool_calls))
179
+ self.conversation[convo_id].append(ToolResults(tool_call_id=function_call_id, content=message))
190
180
  else:
191
181
  last_user_message = self.conversation[convo_id][-1]["content"]
192
182
  if last_user_message != message:
193
- image_message_list = []
183
+ image_message_list = UserMessage()
194
184
  if isinstance(function_arguments, str):
195
185
  functions_list = json.loads(function_arguments)
196
186
  else:
197
187
  functions_list = function_arguments
198
188
  for tool_info in functions_list:
199
189
  if tool_info.get("base64_image"):
200
- image_message_list.append({"type": "text", "text": safe_get(tool_info, "parameter", "image_path", default="") + " image:"})
201
- image_message_list.append({
202
- "type": "image_url",
203
- "image_url": {
204
- "url": tool_info["base64_image"],
205
- }
206
- })
207
- self.conversation[convo_id].append({"role": "assistant", "content": convert_functions_to_xml(function_arguments)})
190
+ image_message_list.extend([
191
+ safe_get(tool_info, "parameter", "image_path", default="") + " image:",
192
+ Images(tool_info["base64_image"]),
193
+ ])
194
+ self.conversation[convo_id].append(AssistantMessage(convert_functions_to_xml(function_arguments)))
208
195
  if image_message_list:
209
- self.conversation[convo_id].append({"role": "user", "content": [{"type": "text", "text": message}] + image_message_list})
196
+ self.conversation[convo_id].append(UserMessage(message + image_message_list))
210
197
  else:
211
- self.conversation[convo_id].append({"role": "user", "content": message})
198
+ self.conversation[convo_id].append(UserMessage(message))
212
199
  else:
213
- self.conversation[convo_id].append({"role": "assistant", "content": "我已经执行过这个工具了,接下来我需要做什么?"})
214
-
200
+ self.conversation[convo_id].append(AssistantMessage("我已经执行过这个工具了,接下来我需要做什么?"))
215
201
  else:
216
202
  self.logger.error(f"error: add_to_conversation message is None or empty, role: {role}, function_name: {function_name}, message: {message}")
217
203
 
218
- conversation_len = len(self.conversation[convo_id]) - 1
219
- message_index = 0
220
204
  # if self.print_log:
221
205
  # replaced_text = json.loads(re.sub(r';base64,([A-Za-z0-9+/=]+)', ';base64,***', json.dumps(self.conversation[convo_id])))
222
206
  # self.logger.info(json.dumps(replaced_text, indent=4, ensure_ascii=False))
223
- while message_index < conversation_len:
224
- if self.conversation[convo_id][message_index]["role"] == self.conversation[convo_id][message_index + 1]["role"]:
225
- if self.conversation[convo_id][message_index].get("content") and self.conversation[convo_id][message_index + 1].get("content") \
226
- and self.conversation[convo_id][message_index].get("content") != self.conversation[convo_id][message_index + 1].get("content"):
227
- if type(self.conversation[convo_id][message_index + 1]["content"]) == str \
228
- and type(self.conversation[convo_id][message_index]["content"]) == list:
229
- self.conversation[convo_id][message_index + 1]["content"] = [{"type": "text", "text": self.conversation[convo_id][message_index + 1]["content"]}]
230
- if type(self.conversation[convo_id][message_index]["content"]) == str \
231
- and type(self.conversation[convo_id][message_index + 1]["content"]) == list:
232
- self.conversation[convo_id][message_index]["content"] = [{"type": "text", "text": self.conversation[convo_id][message_index]["content"]}]
233
- if type(self.conversation[convo_id][message_index]["content"]) == dict \
234
- and type(self.conversation[convo_id][message_index + 1]["content"]) == str:
235
- self.conversation[convo_id][message_index]["content"] = [self.conversation[convo_id][message_index]["content"]]
236
- self.conversation[convo_id][message_index + 1]["content"] = [{"type": "text", "text": self.conversation[convo_id][message_index + 1]["content"]}]
237
- if type(self.conversation[convo_id][message_index]["content"]) == dict \
238
- and type(self.conversation[convo_id][message_index + 1]["content"]) == list:
239
- self.conversation[convo_id][message_index]["content"] = [self.conversation[convo_id][message_index]["content"]]
240
- if type(self.conversation[convo_id][message_index]["content"]) == dict \
241
- and type(self.conversation[convo_id][message_index + 1]["content"]) == dict:
242
- self.conversation[convo_id][message_index]["content"] = [self.conversation[convo_id][message_index]["content"]]
243
- self.conversation[convo_id][message_index + 1]["content"] = [self.conversation[convo_id][message_index + 1]["content"]]
244
- if type(self.conversation[convo_id][message_index]["content"]) == list \
245
- and type(self.conversation[convo_id][message_index + 1]["content"]) == dict:
246
- self.conversation[convo_id][message_index + 1]["content"] = [self.conversation[convo_id][message_index + 1]["content"]]
247
- if type(self.conversation[convo_id][message_index]["content"]) == str \
248
- and type(self.conversation[convo_id][message_index + 1]["content"]) == str \
249
- and self.conversation[convo_id][message_index].get("content").endswith(self.conversation[convo_id][message_index + 1].get("content")):
250
- self.conversation[convo_id][message_index + 1]["content"] = ""
251
- self.conversation[convo_id][message_index]["content"] += self.conversation[convo_id][message_index + 1]["content"]
252
- self.conversation[convo_id].pop(message_index + 1)
253
- conversation_len = conversation_len - 1
254
- else:
255
- message_index = message_index + 1
256
207
 
257
208
  history_len = len(self.conversation[convo_id])
258
209
 
@@ -290,27 +241,6 @@ class chatgpt(BaseLLM):
290
241
  else:
291
242
  break
292
243
 
293
- def get_latest_file_content(self) -> str:
294
- """
295
- 获取最新文件内容
296
- """
297
- result = ""
298
- if self.latest_file_content:
299
- for file_path, content in self.latest_file_content.items():
300
- result += (
301
- "<file>"
302
- f"<file_path>{file_path}</file_path>"
303
- f"<file_content>{content}</file_content>"
304
- "</file>\n\n"
305
- )
306
- if result:
307
- result = (
308
- "<latest_file_content>"
309
- f"{result}"
310
- "</latest_file_content>"
311
- )
312
- return result
313
-
314
244
  async def get_post_body(
315
245
  self,
316
246
  prompt: str,
@@ -321,8 +251,6 @@ class chatgpt(BaseLLM):
321
251
  stream: bool = True,
322
252
  **kwargs,
323
253
  ):
324
- self.conversation[convo_id][0] = {"role": "system","content": self.system_prompt + "\n\n" + self.get_latest_file_content()}
325
-
326
254
  # 构造 provider 信息
327
255
  provider = {
328
256
  "provider": "openai",
@@ -336,10 +264,10 @@ class chatgpt(BaseLLM):
336
264
  # 构造请求数据
337
265
  request_data = {
338
266
  "model": model or self.engine,
339
- "messages": copy.deepcopy(self.conversation[convo_id]) if pass_history else [
340
- {"role": "system","content": self.system_prompt + "\n\n" + self.get_latest_file_content()},
341
- {"role": role, "content": prompt}
342
- ],
267
+ "messages": self.conversation[convo_id].render_latest() if pass_history else Messages(
268
+ SystemMessage(self.system_prompt, self.conversation[convo_id].provider("files")),
269
+ UserMessage(prompt)
270
+ ),
343
271
  "stream": stream,
344
272
  "temperature": kwargs.get("temperature", self.temperature)
345
273
  }
@@ -655,7 +583,7 @@ class chatgpt(BaseLLM):
655
583
  else:
656
584
  yield chunk
657
585
  if tool_name == "read_file" and "<tool_error>" not in tool_response:
658
- self.latest_file_content[tool_info['parameter']["file_path"]] = tool_response
586
+ self.conversation[convo_id].provider("files").update(tool_info['parameter']["file_path"], tool_response)
659
587
  all_responses.append(f"[{tool_name}({tool_args}) Result]:\n\nRead file successfully! The file content has been updated in the tag <latest_file_content>.")
660
588
  elif tool_name == "write_to_file" and "<tool_error>" not in tool_response:
661
589
  all_responses.append(f"[{tool_name} Result]:\n\n{tool_response}")
@@ -998,9 +926,8 @@ class chatgpt(BaseLLM):
998
926
  Reset the conversation
999
927
  """
1000
928
  self.system_prompt = system_prompt or self.system_prompt
1001
- self.latest_file_content = {}
1002
- self.conversation[convo_id] = [
1003
- {"role": "system", "content": self.system_prompt},
1004
- ]
929
+ self.conversation[convo_id] = Messages(
930
+ SystemMessage(Texts("system_prompt", self.system_prompt), self.conversation[convo_id].provider("files")),
931
+ )
1005
932
  self.tokens_usage[convo_id] = 0
1006
933
  self.current_tokens[convo_id] = 0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aient
3
- Version: 1.1.91
3
+ Version: 1.1.92
4
4
  Summary: Aient: The Awakening of Agent.
5
5
  Requires-Python: >=3.11
6
6
  Description-Content-Type: text/markdown
@@ -1,7 +1,9 @@
1
1
  aient/__init__.py,sha256=SRfF7oDVlOOAi6nGKiJIUK6B_arqYLO9iSMp-2IZZps,21
2
- aient/architext/test.py,sha256=n2hUwHOPhOdpHRLCPuqeR-r4hGK43SF1WKaLdivdmC8,10021
3
2
  aient/architext/architext/__init__.py,sha256=79Ih1151rfcqZdr7F8HSZSTs_iT2SKd1xCkehMsXeXs,19
4
- aient/architext/architext/core.py,sha256=vww5nxNyoGXmWwmNEfLMGfv7k4GhCDZSiRqKzyqYv_8,10070
3
+ aient/architext/architext/core.py,sha256=i8UGEwoAS7Cwat3wB-jU8EC23gSwAXjFiwEc8QknE-k,17875
4
+ aient/architext/test/openai_client.py,sha256=Dqtbmubv6vwF8uBqcayG0kbsiO65of7sgU2-DRBi-UM,4590
5
+ aient/architext/test/test.py,sha256=CeFyloj-RJRk8hr6QDLjnRETG-KylD2bo3MVq5GVKPA,40597
6
+ aient/architext/test/test_save_load.py,sha256=o8DqH6gDYZkFkQy-a7blqLtJTRj5e4a-Lil48pJ0V3g,3260
5
7
  aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
6
8
  aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
7
9
  aient/core/models.py,sha256=KMlCRLjtq1wQHZTJGqnbWhPS2cHq6eLdnk7peKDrzR8,7490
@@ -15,7 +17,7 @@ aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhF
15
17
  aient/models/__init__.py,sha256=ZTiZgbfBPTjIPSKURE7t6hlFBVLRS9lluGbmqc1WjxQ,43
16
18
  aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
17
19
  aient/models/base.py,sha256=-nnihYnx-vHZMqeVO9ljjt3k4FcD3n-iMk4tT-10nRQ,7232
18
- aient/models/chatgpt.py,sha256=2RaObZmliqJlGveOSWbwgpscjPWk7R1RmxwbEAH0xXo,47315
20
+ aient/models/chatgpt.py,sha256=EqvIcB6R_JGmvD5kEfR1ZRQYLfyGpmy4PvfRx8hMIKE,42019
19
21
  aient/plugins/__init__.py,sha256=p3KO6Aa3Lupos4i2SjzLQw1hzQTigOAfEHngsldrsyk,986
20
22
  aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
21
23
  aient/plugins/config.py,sha256=TGgZ5SnNKZ8MmdznrZ-TEq7s2ulhAAwTSKH89bci3dA,7079
@@ -33,8 +35,8 @@ aient/plugins/write_file.py,sha256=Jt8fOEwqhYiSWpCbwfAr1xoi_BmFnx3076GMhuL06uI,3
33
35
  aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
36
  aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
35
37
  aient/utils/scripts.py,sha256=VqtK4RFEx7KxkmcqG3lFDS1DxoNlFFGErEjopVcc8IE,40974
36
- aient-1.1.91.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
37
- aient-1.1.91.dist-info/METADATA,sha256=SByLoF4YtLoR1QxQ1qmVim3FbMHQIgD-b1OcxliO4bw,4842
38
- aient-1.1.91.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
39
- aient-1.1.91.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
40
- aient-1.1.91.dist-info/RECORD,,
38
+ aient-1.1.92.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
39
+ aient-1.1.92.dist-info/METADATA,sha256=EYclEjlBbW-N9rjlfZPA-1HlyQST_DmK3qqBNabzC3g,4842
40
+ aient-1.1.92.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
41
+ aient-1.1.92.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
42
+ aient-1.1.92.dist-info/RECORD,,
aient/architext/test.py DELETED
@@ -1,226 +0,0 @@
1
- import unittest
2
- from unittest.mock import AsyncMock
3
- from architext import *
4
-
5
- # ==============================================================================
6
- # 单元测试部分
7
- # ==============================================================================
8
- class TestContextManagement(unittest.IsolatedAsyncioTestCase):
9
-
10
- def setUp(self):
11
- """在每个测试前设置环境"""
12
- self.system_prompt_provider = Texts("system_prompt", "你是一个AI助手。")
13
- self.tools_provider = Tools(tools_json=[{"name": "read_file"}])
14
- self.files_provider = Files()
15
-
16
- async def test_a_initial_construction_and_render(self):
17
- """测试优雅的初始化和首次渲染"""
18
- messages = Messages(
19
- SystemMessage(self.system_prompt_provider, self.tools_provider),
20
- UserMessage(self.files_provider, Texts("user_input", "这是我的初始问题。"))
21
- )
22
-
23
- self.assertEqual(len(messages), 2)
24
- rendered = await messages.render_latest()
25
-
26
- self.assertEqual(len(rendered), 2)
27
- self.assertIn("<tools>", rendered[0]['content'])
28
- self.assertNotIn("<files>", rendered[1]['content'])
29
-
30
- async def test_b_provider_passthrough_and_refresh(self):
31
- """测试通过 mock 验证缓存和刷新逻辑"""
32
- # 我们真正关心的是 _fetch_content 是否被不必要地调用
33
- # 所以我们 mock 底层的它,而不是 refresh 方法
34
- original_fetch_content = self.files_provider._fetch_content
35
- self.files_provider._fetch_content = AsyncMock(side_effect=original_fetch_content)
36
-
37
- messages = Messages(UserMessage(self.files_provider))
38
-
39
- # 1. 首次刷新
40
- self.files_provider.update("path1", "content1")
41
- await messages.refresh()
42
- # _fetch_content 应该被调用了 1 次
43
- self.assertEqual(self.files_provider._fetch_content.call_count, 1)
44
-
45
- # 2. 再次刷新,内容未变,不应再次调用 _fetch_content
46
- await messages.refresh()
47
- # 调用次数应该仍然是 1,证明缓存生效
48
- self.assertEqual(self.files_provider._fetch_content.call_count, 1)
49
-
50
- # 3. 更新文件内容,这会标记 provider 为 stale
51
- self.files_provider.update("path2", "content2")
52
-
53
- # 4. 再次刷新,现在应该会重新调用 _fetch_content
54
- await messages.refresh()
55
- rendered = messages.render()
56
- # 调用次数应该变为 2
57
- self.assertEqual(self.files_provider._fetch_content.call_count, 2)
58
- # 并且渲染结果包含了新内容
59
- self.assertIn("content2", rendered[0]['content'])
60
-
61
- async def test_c_global_pop_and_indexed_insert(self):
62
- """测试全局pop和通过索引insert的功能"""
63
- messages = Messages(
64
- SystemMessage(self.system_prompt_provider, self.tools_provider),
65
- UserMessage(self.files_provider)
66
- )
67
-
68
- # 验证初始状态
69
- initial_rendered = await messages.render_latest()
70
- self.assertTrue(any("<tools>" in msg['content'] for msg in initial_rendered if msg['role'] == 'system'))
71
-
72
- # 全局弹出 'tools' Provider
73
- popped_tools_provider = messages.pop("tools")
74
- self.assertIs(popped_tools_provider, self.tools_provider)
75
-
76
- # 验证 pop 后的状态
77
- rendered_after_pop = messages.render()
78
- self.assertFalse(any("<tools>" in msg['content'] for msg in rendered_after_pop if msg['role'] == 'system'))
79
-
80
- # 通过索引将弹出的provider插入到UserMessage的开头
81
- messages[1].insert(0, popped_tools_provider)
82
-
83
- # 验证 insert 后的状态
84
- rendered_after_insert = messages.render()
85
- user_message_content = next(msg['content'] for msg in rendered_after_insert if msg['role'] == 'user')
86
- self.assertTrue(user_message_content.startswith("<tools>"))
87
-
88
- async def test_d_multimodal_rendering(self):
89
- """测试多模态(文本+图片)渲染"""
90
- # Create a dummy image file for the test
91
- dummy_image_path = "test_dummy_image.png"
92
- with open(dummy_image_path, "w") as f:
93
- f.write("dummy content")
94
-
95
- messages = Messages(
96
- UserMessage(
97
- Texts("prompt", "Describe the image."),
98
- Images(dummy_image_path) # Test with optional name
99
- )
100
- )
101
-
102
- rendered = await messages.render_latest()
103
- self.assertEqual(len(rendered), 1)
104
-
105
- content = rendered[0]['content']
106
- self.assertIsInstance(content, list)
107
- self.assertEqual(len(content), 2)
108
-
109
- # Check text part
110
- self.assertEqual(content[0]['type'], 'text')
111
- self.assertEqual(content[0]['text'], 'Describe the image.')
112
-
113
- # Check image part
114
- self.assertEqual(content[1]['type'], 'image_url')
115
- self.assertIn('data:image/png;base64,', content[1]['image_url']['url'])
116
-
117
- # Clean up the dummy file
118
- import os
119
- os.remove(dummy_image_path)
120
-
121
- async def test_e_multimodal_type_switching(self):
122
- """测试多模态消息在pop图片后是否能正确回退到字符串渲染"""
123
- dummy_image_path = "test_dummy_image_2.png"
124
- with open(dummy_image_path, "w") as f:
125
- f.write("dummy content")
126
-
127
- messages = Messages(
128
- UserMessage(
129
- Texts("prefix", "Look at this:"),
130
- Images(dummy_image_path, name="image"), # Explicit name for popping
131
- Texts("suffix", "Any thoughts?")
132
- )
133
- )
134
-
135
- # 1. Initial multimodal render
136
- rendered_multi = await messages.render_latest()
137
- content_multi = rendered_multi[0]['content']
138
- self.assertIsInstance(content_multi, list)
139
- self.assertEqual(len(content_multi), 3) # prefix, image, suffix
140
-
141
- # 2. Pop the image
142
- popped_image = messages.pop("image")
143
- self.assertIsNotNone(popped_image)
144
-
145
- # 3. Render again, should fall back to string content
146
- rendered_str = messages.render() # No refresh needed
147
- content_str = rendered_str[0]['content']
148
- self.assertIsInstance(content_str, str)
149
- self.assertEqual(content_str, "Look at this:\n\nAny thoughts?")
150
-
151
- # Clean up
152
- import os
153
- os.remove(dummy_image_path)
154
-
155
- def test_f_message_merging(self):
156
- """测试初始化和追加时自动合并消息的功能"""
157
- # 1. Test merging during initialization
158
- messages = Messages(
159
- UserMessage(Texts("part1", "Hello,")),
160
- UserMessage(Texts("part2", "world!")),
161
- SystemMessage(Texts("system", "System prompt.")),
162
- UserMessage(Texts("part3", "How are you?"))
163
- )
164
- # Should be merged into: User, System, User
165
- self.assertEqual(len(messages), 3)
166
- self.assertEqual(len(messages[0]._items), 2) # First UserMessage has 2 items
167
- self.assertEqual(messages[0]._items[1].name, "part2")
168
- self.assertEqual(messages[1].role, "system")
169
- self.assertEqual(messages[2].role, "user")
170
-
171
- # 2. Test merging during append
172
- messages.append(UserMessage(Texts("part4", "I am fine.")))
173
- self.assertEqual(len(messages), 3) # Still 3 messages
174
- self.assertEqual(len(messages[2]._items), 2) # Last UserMessage now has 2 items
175
- self.assertEqual(messages[2]._items[1].name, "part4")
176
-
177
- # 3. Test appending a different role
178
- messages.append(SystemMessage(Texts("system2", "Another prompt.")))
179
- self.assertEqual(len(messages), 4) # Should not merge
180
- self.assertEqual(messages[3].role, "system")
181
-
182
- async def test_g_state_inconsistency_on_direct_message_modification(self):
183
- """
184
- 测试当直接在 Message 对象上执行 pop 操作时,
185
- 顶层 Messages 对象的 _providers_index 是否会产生不一致。
186
- """
187
- messages = Messages(
188
- SystemMessage(self.system_prompt_provider, self.tools_provider),
189
- UserMessage(self.files_provider)
190
- )
191
-
192
- # 0. 先刷新一次,确保所有 provider 的 cache 都已填充
193
- await messages.refresh()
194
-
195
- # 1. 初始状态:'tools' 提供者应该在索引中
196
- self.assertIsNotNone(messages.provider("tools"), "初始状态下 'tools' 提供者应该能被找到")
197
- self.assertIs(messages.provider("tools"), self.tools_provider)
198
-
199
- # 2. 直接在子消息对象上执行 pop 操作
200
- system_message = messages[0]
201
- popped_provider = system_message.pop("tools")
202
-
203
- # 验证是否真的从 Message 对象中弹出了
204
- self.assertIs(popped_provider, self.tools_provider, "应该从 SystemMessage 中成功弹出 provider")
205
- self.assertNotIn(self.tools_provider, system_message.providers(), "provider 不应再存在于 SystemMessage 的 providers 列表中")
206
-
207
- # 3. 核心问题:检查顶层 Messages 的索引
208
- # 在理想情况下,直接修改子消息应该同步更新顶层索引。
209
- # 因此,我们断言 provider 现在应该是找不到的。这个测试现在应该会失败。
210
- provider_after_pop = messages.provider("tools")
211
- self.assertIsNone(provider_after_pop, "BUG: 直接从子消息中 pop 后,顶层索引未同步,仍然可以找到 provider")
212
-
213
- # 4. 进一步验证:渲染结果和索引内容不一致
214
- # 渲染结果应该不再包含 tools 内容,因为 Message 对象本身是正确的
215
- rendered_messages = messages.render()
216
- self.assertGreater(len(rendered_messages), 0, "渲染后的消息列表不应为空")
217
- rendered_content = rendered_messages[0]['content']
218
- self.assertNotIn("<tools>", rendered_content, "渲染结果中不应再包含 'tools' 的内容,证明数据源已更新")
219
-
220
-
221
- if __name__ == '__main__':
222
- # 为了在普通脚本环境中运行,添加这两行
223
- suite = unittest.TestSuite()
224
- suite.addTest(unittest.makeSuite(TestContextManagement))
225
- runner = unittest.TextTestRunner()
226
- runner.run(suite)
File without changes