aient 1.2.30__tar.gz → 1.2.32__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {aient-1.2.30 → aient-1.2.32}/PKG-INFO +1 -1
- {aient-1.2.30 → aient-1.2.32}/aient/architext/architext/core.py +35 -2
- {aient-1.2.30 → aient-1.2.32}/aient/architext/test/test.py +124 -0
- {aient-1.2.30 → aient-1.2.32}/aient/models/chatgpt.py +39 -23
- {aient-1.2.30 → aient-1.2.32}/aient.egg-info/PKG-INFO +1 -1
- {aient-1.2.30 → aient-1.2.32}/pyproject.toml +1 -1
- {aient-1.2.30 → aient-1.2.32}/LICENSE +0 -0
- {aient-1.2.30 → aient-1.2.32}/README.md +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/__init__.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/architext/architext/__init__.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/architext/test/openai_client.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/architext/test/test_save_load.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/__init__.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/log_config.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/models.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/request.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/response.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/test/test_base_api.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/test/test_geminimask.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/test/test_image.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/test/test_payload.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/core/utils.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/models/__init__.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/models/audio.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/models/base.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/__init__.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/arXiv.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/config.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/excute_command.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/get_time.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/image.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/list_directory.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/read_file.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/read_image.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/readonly.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/registry.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/run_python.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/websearch.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/plugins/write_file.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/utils/__init__.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/utils/prompt.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient/utils/scripts.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient.egg-info/SOURCES.txt +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient.egg-info/dependency_links.txt +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient.egg-info/requires.txt +0 -0
- {aient-1.2.30 → aient-1.2.32}/aient.egg-info/top_level.txt +0 -0
- {aient-1.2.30 → aient-1.2.32}/setup.cfg +0 -0
- {aient-1.2.30 → aient-1.2.32}/test/test_Web_crawler.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/test/test_ddg_search.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/test/test_google_search.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/test/test_ollama.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/test/test_plugin.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/test/test_url.py +0 -0
- {aient-1.2.30 → aient-1.2.32}/test/test_whisper.py +0 -0
@@ -104,9 +104,10 @@ class ContextProvider(ABC):
|
|
104
104
|
return NotImplemented
|
105
105
|
|
106
106
|
class Texts(ContextProvider):
|
107
|
-
def __init__(self, text: Optional[Union[str, Callable[[], str]]] = None, name: Optional[str] = None, visible: bool = True):
|
107
|
+
def __init__(self, text: Optional[Union[str, Callable[[], str]]] = None, name: Optional[str] = None, visible: bool = True, newline: bool = False):
|
108
108
|
if text is None and name is None:
|
109
109
|
raise ValueError("Either 'text' or 'name' must be provided.")
|
110
|
+
self.newline = newline
|
110
111
|
|
111
112
|
# Ensure that non-callable inputs are treated as strings
|
112
113
|
if not callable(text):
|
@@ -373,8 +374,11 @@ class Message(ABC):
|
|
373
374
|
for item in self._items:
|
374
375
|
block = item.get_content_block()
|
375
376
|
if block and block.content is not None:
|
377
|
+
# Check if it's a Texts provider with newline=True
|
378
|
+
# and it's not the very first item with content.
|
379
|
+
if isinstance(item, Texts) and hasattr(item, 'newline') and item.newline and final_parts:
|
380
|
+
final_parts.append("\n\n")
|
376
381
|
final_parts.append(block.content)
|
377
|
-
|
378
382
|
return "".join(final_parts)
|
379
383
|
|
380
384
|
def pop(self, name: str) -> Optional[ContextProvider]:
|
@@ -468,6 +472,35 @@ class Message(ABC):
|
|
468
472
|
# and our custom __eq__ on ContextProvider handles the comparison logic.
|
469
473
|
return item in self._items
|
470
474
|
|
475
|
+
def has(self, provider_type: type) -> bool:
|
476
|
+
"""Checks if the message contains a provider of a specific type."""
|
477
|
+
if not isinstance(provider_type, type) or not issubclass(provider_type, ContextProvider):
|
478
|
+
raise TypeError("provider_type must be a subclass of ContextProvider")
|
479
|
+
return any(isinstance(p, provider_type) for p in self._items)
|
480
|
+
|
481
|
+
def lstrip(self, provider_type: type):
|
482
|
+
"""
|
483
|
+
从消息的左侧(开头)移除所有指定类型的 provider。
|
484
|
+
移除操作会一直持续,直到遇到一个不同类型的 provider 为止。
|
485
|
+
"""
|
486
|
+
while self._items and type(self._items[0]) is provider_type:
|
487
|
+
self.pop(self._items[0].name)
|
488
|
+
|
489
|
+
def rstrip(self, provider_type: type):
|
490
|
+
"""
|
491
|
+
从消息的右侧(末尾)移除所有指定类型的 provider。
|
492
|
+
移除操作会一直持续,直到遇到一个不同类型的 provider 为止。
|
493
|
+
"""
|
494
|
+
while self._items and type(self._items[-1]) is provider_type:
|
495
|
+
self.pop(self._items[-1].name)
|
496
|
+
|
497
|
+
def strip(self, provider_type: type):
|
498
|
+
"""
|
499
|
+
从消息的两侧移除所有指定类型的 provider。
|
500
|
+
"""
|
501
|
+
self.lstrip(provider_type)
|
502
|
+
self.rstrip(provider_type)
|
503
|
+
|
471
504
|
def __bool__(self) -> bool:
|
472
505
|
return bool(self._items)
|
473
506
|
def get(self, key: str, default: Any = None) -> Any:
|
@@ -1454,6 +1454,130 @@ Files: {Files(visible=True, name="files")}
|
|
1454
1454
|
self.assertEqual(len(message_mixed.provider()), 1)
|
1455
1455
|
self.assertIsInstance(message_mixed.provider()[0], Texts)
|
1456
1456
|
|
1457
|
+
async def test_zaa_has_method_for_provider_type_check(self):
|
1458
|
+
"""测试 Message.has(type) 方法是否能正确检查 provider 类型"""
|
1459
|
+
# 1. 创建一个混合类型的消息
|
1460
|
+
message_with_text = UserMessage(Texts("hi"), Images("url"))
|
1461
|
+
|
1462
|
+
# 2. 测试存在的情况
|
1463
|
+
# This line is expected to fail with an AttributeError before implementation
|
1464
|
+
self.assertTrue(message_with_text.has(Texts))
|
1465
|
+
self.assertTrue(message_with_text.has(Images))
|
1466
|
+
|
1467
|
+
# 3. 测试不存在的情况
|
1468
|
+
self.assertFalse(message_with_text.has(Tools))
|
1469
|
+
|
1470
|
+
# 4. 测试空消息
|
1471
|
+
empty_message = UserMessage()
|
1472
|
+
self.assertFalse(empty_message.has(Texts))
|
1473
|
+
|
1474
|
+
# 5. 测试传入无效类型
|
1475
|
+
with self.assertRaises(TypeError):
|
1476
|
+
message_with_text.has(str)
|
1477
|
+
|
1478
|
+
with self.assertRaises(TypeError):
|
1479
|
+
# Also test with a class that is not a subclass of ContextProvider
|
1480
|
+
class NotAProvider: pass
|
1481
|
+
message_with_text.has(NotAProvider)
|
1482
|
+
|
1483
|
+
async def test_zab_lstrip_and_rstrip(self):
|
1484
|
+
"""测试 lstrip, rstrip, 和 strip 方法是否能正确移除两侧的特定类型的 provider"""
|
1485
|
+
# 1. 定义一个用于测试的子类
|
1486
|
+
class SpecialTexts(Texts):
|
1487
|
+
pass
|
1488
|
+
url = "_IMG"
|
1489
|
+
|
1490
|
+
# 2. 创建一个复杂的测试消息
|
1491
|
+
message = UserMessage(
|
1492
|
+
Texts("leading1"),
|
1493
|
+
Texts("leading2"),
|
1494
|
+
Images(url, name="image1"),
|
1495
|
+
Texts("middle"),
|
1496
|
+
SpecialTexts("special_middle"),
|
1497
|
+
Images(url, name="image2"),
|
1498
|
+
Texts("trailing1"),
|
1499
|
+
SpecialTexts("special_trailing"), # rstrip(Texts) should stop here
|
1500
|
+
Texts("trailing2")
|
1501
|
+
)
|
1502
|
+
|
1503
|
+
# 3. 测试 rstrip(Texts)
|
1504
|
+
r_stripped_message = UserMessage(*message.provider()) # 创建副本
|
1505
|
+
r_stripped_message.rstrip(Texts)
|
1506
|
+
# 应移除 "trailing2",但在 "special_trailing" 处停止
|
1507
|
+
self.assertEqual(len(r_stripped_message), 8)
|
1508
|
+
self.assertIs(type(r_stripped_message[-1]), SpecialTexts)
|
1509
|
+
|
1510
|
+
# 4. 测试 lstrip(Texts)
|
1511
|
+
l_stripped_message = UserMessage(*message.provider()) # 创建副本
|
1512
|
+
l_stripped_message.lstrip(Texts)
|
1513
|
+
# 应移除 "leading1" 和 "leading2",但在 "image1" 处停止
|
1514
|
+
self.assertEqual(len(l_stripped_message), 7)
|
1515
|
+
self.assertIs(type(l_stripped_message[0]), Images)
|
1516
|
+
|
1517
|
+
# 5. 测试 strip(Texts)
|
1518
|
+
stripped_message = UserMessage(*message.provider()) # 创建副本
|
1519
|
+
stripped_message.strip(Texts)
|
1520
|
+
# 应同时移除 "leading1", "leading2", 和 "trailing2"
|
1521
|
+
self.assertEqual(len(stripped_message), 6)
|
1522
|
+
self.assertIs(type(stripped_message[0]), Images)
|
1523
|
+
self.assertIs(type(stripped_message[-1]), SpecialTexts)
|
1524
|
+
|
1525
|
+
# 6. 测试在一个只包含一种类型的消息上进行剥离
|
1526
|
+
only_texts = UserMessage(Texts("a"), Texts("b"))
|
1527
|
+
only_texts.strip(Texts)
|
1528
|
+
self.assertEqual(len(only_texts), 0)
|
1529
|
+
|
1530
|
+
# 7. 测试剥离一个不包含目标类型的消息
|
1531
|
+
only_images = UserMessage(Images("url1"), Images("url2"))
|
1532
|
+
only_images.strip(Texts)
|
1533
|
+
self.assertEqual(len(only_images), 2) # 不应改变
|
1534
|
+
|
1535
|
+
# 8. 测试在一个空消息上进行剥离
|
1536
|
+
empty_message = UserMessage()
|
1537
|
+
empty_message.strip(Texts)
|
1538
|
+
self.assertEqual(len(empty_message), 0)
|
1539
|
+
|
1540
|
+
# 9. 测试剥离子类
|
1541
|
+
message_ending_with_special = UserMessage(Texts("a"), SpecialTexts("b"))
|
1542
|
+
message_ending_with_special.rstrip(SpecialTexts)
|
1543
|
+
self.assertEqual(len(message_ending_with_special), 1)
|
1544
|
+
self.assertIsInstance(message_ending_with_special[0], Texts)
|
1545
|
+
|
1546
|
+
async def test_zac_texts_join_parameter(self):
|
1547
|
+
"""测试 Texts provider 是否支持通过参数控制拼接方式"""
|
1548
|
+
# 1. 测试默认行为:直接拼接
|
1549
|
+
message_default = UserMessage(
|
1550
|
+
Texts("First line."),
|
1551
|
+
Texts("Second line.")
|
1552
|
+
)
|
1553
|
+
rendered_default = await message_default.render_latest()
|
1554
|
+
self.assertEqual(rendered_default['content'], "First line.Second line.")
|
1555
|
+
|
1556
|
+
# 2. 测试新功能:使用 \n\n 拼接
|
1557
|
+
# 假设新参数为 `newline=True`
|
1558
|
+
message_newline = UserMessage(
|
1559
|
+
Texts("First paragraph."),
|
1560
|
+
Texts("Second paragraph.", newline=True)
|
1561
|
+
)
|
1562
|
+
rendered_newline = await message_newline.render_latest()
|
1563
|
+
self.assertEqual(rendered_newline['content'], "First paragraph.\n\nSecond paragraph.")
|
1564
|
+
|
1565
|
+
# 3. 测试多个 provider 的情况
|
1566
|
+
message_multiple = UserMessage(
|
1567
|
+
Texts("First."),
|
1568
|
+
Texts("Second.", newline=True),
|
1569
|
+
Texts("Third.", newline=True)
|
1570
|
+
)
|
1571
|
+
rendered_multiple = await message_multiple.render_latest()
|
1572
|
+
self.assertEqual(rendered_multiple['content'], "First.\n\nSecond.\n\nThird.")
|
1573
|
+
|
1574
|
+
# 4. 测试只有一个 provider 的情况
|
1575
|
+
message_single = UserMessage(
|
1576
|
+
Texts("Only one.", newline=True)
|
1577
|
+
)
|
1578
|
+
rendered_single = await message_single.render_latest()
|
1579
|
+
self.assertEqual(rendered_single['content'], "Only one.")
|
1580
|
+
|
1457
1581
|
|
1458
1582
|
# ==============================================================================
|
1459
1583
|
# 6. 演示
|
@@ -17,6 +17,21 @@ from ..core.request import prepare_request_payload
|
|
17
17
|
from ..core.response import fetch_response_stream, fetch_response
|
18
18
|
from ..architext.architext import Messages, SystemMessage, UserMessage, AssistantMessage, ToolCalls, ToolResults, Texts, RoleMessage, Images, Files
|
19
19
|
|
20
|
+
class ToolResult(Texts):
|
21
|
+
def __init__(self, tool_name: str, tool_args: str, tool_response: str, name: Optional[str] = None, visible: bool = True, newline: bool = True):
|
22
|
+
super().__init__(text=tool_response, name=name or f"tool_result_{tool_name}", visible=visible, newline=newline)
|
23
|
+
self.tool_name = tool_name
|
24
|
+
self.tool_args = tool_args
|
25
|
+
|
26
|
+
async def render(self) -> Optional[str]:
|
27
|
+
tool_response = await super().render()
|
28
|
+
if tool_response is None:
|
29
|
+
tool_response = ""
|
30
|
+
if self.tool_args:
|
31
|
+
return f"[{self.tool_name}({self.tool_args}) Result]:\n\n{tool_response}"
|
32
|
+
else:
|
33
|
+
return f"[{self.tool_name} Result]:\n\n{tool_response}"
|
34
|
+
|
20
35
|
class APITimeoutError(Exception):
|
21
36
|
"""Custom exception for API timeout errors."""
|
22
37
|
pass
|
@@ -172,8 +187,8 @@ class chatgpt(BaseLLM):
|
|
172
187
|
self.conversation[convo_id].append(ToolCalls(tool_calls))
|
173
188
|
self.conversation[convo_id].append(ToolResults(tool_call_id=function_call_id, content=message))
|
174
189
|
else:
|
175
|
-
last_user_message = self.conversation[convo_id][-1]
|
176
|
-
if last_user_message != message:
|
190
|
+
last_user_message = self.conversation[convo_id][-1]
|
191
|
+
if last_user_message != UserMessage(message):
|
177
192
|
image_message_list = UserMessage()
|
178
193
|
if isinstance(function_arguments, str):
|
179
194
|
functions_list = json.loads(function_arguments)
|
@@ -564,7 +579,7 @@ class chatgpt(BaseLLM):
|
|
564
579
|
tool_calls = function_parameter
|
565
580
|
|
566
581
|
# 处理所有工具调用
|
567
|
-
all_responses =
|
582
|
+
all_responses = UserMessage()
|
568
583
|
|
569
584
|
for tool_info in tool_calls:
|
570
585
|
tool_name = tool_info['function_name']
|
@@ -584,27 +599,28 @@ class chatgpt(BaseLLM):
|
|
584
599
|
tool_response = chunk.replace("function_response:", "")
|
585
600
|
else:
|
586
601
|
yield chunk
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
592
|
-
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
602
|
+
final_tool_response = tool_response
|
603
|
+
if "<tool_error>" not in tool_response:
|
604
|
+
if tool_name == "read_file":
|
605
|
+
self.conversation[convo_id].provider("files").update(tool_info['parameter']["file_path"], tool_response)
|
606
|
+
final_tool_response = "Read file successfully! The file content has been updated in the tag <latest_file_content>."
|
607
|
+
elif tool_name == "get_knowledge_graph_tree":
|
608
|
+
self.conversation[convo_id].provider("knowledge_graph").visible = True
|
609
|
+
final_tool_response = "Get knowledge graph tree successfully! The knowledge graph tree has been updated in the tag <knowledge_graph_tree>."
|
610
|
+
elif tool_name == "write_to_file":
|
611
|
+
tool_args = None
|
612
|
+
elif tool_name == "read_image":
|
613
|
+
tool_info["base64_image"] = tool_response
|
614
|
+
final_tool_response = "Read image successfully!"
|
615
|
+
elif tool_response.startswith("data:image/") and ";base64," in tool_response:
|
616
|
+
tool_info["base64_image"] = tool_response
|
617
|
+
final_tool_response = "Read image successfully!"
|
618
|
+
all_responses.append(ToolResult(tool_name, tool_args, final_tool_response))
|
603
619
|
|
604
620
|
# 合并所有工具响应
|
605
|
-
function_response =
|
621
|
+
function_response = all_responses
|
606
622
|
if missing_required_params:
|
607
|
-
function_response
|
623
|
+
function_response.append(Texts("\n\n".join(missing_required_params)))
|
608
624
|
|
609
625
|
# 使用第一个工具的名称和参数作为历史记录
|
610
626
|
function_call_name = tool_calls[0]['function_name']
|
@@ -672,7 +688,7 @@ class chatgpt(BaseLLM):
|
|
672
688
|
# 准备会话
|
673
689
|
self.system_prompt = system_prompt or self.system_prompt
|
674
690
|
if convo_id not in self.conversation or pass_history <= 2:
|
675
|
-
self.reset(convo_id=convo_id, system_prompt=system_prompt)
|
691
|
+
self.reset(convo_id=convo_id, system_prompt=self.system_prompt)
|
676
692
|
self.add_to_conversation(prompt, role, convo_id=convo_id, function_name=function_name, total_tokens=total_tokens, function_arguments=function_arguments, pass_history=pass_history, function_call_id=function_call_id)
|
677
693
|
|
678
694
|
# 获取请求体
|
@@ -929,7 +945,7 @@ class chatgpt(BaseLLM):
|
|
929
945
|
"""
|
930
946
|
self.system_prompt = system_prompt or self.system_prompt
|
931
947
|
self.conversation[convo_id] = Messages(
|
932
|
-
SystemMessage(
|
948
|
+
SystemMessage(self.system_prompt, self.conversation[convo_id].provider("files")),
|
933
949
|
)
|
934
950
|
self.tokens_usage[convo_id] = 0
|
935
951
|
self.current_tokens[convo_id] = 0
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|