aient 1.2.3__py3-none-any.whl → 1.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/architext/architext/core.py +7 -7
- aient/architext/test/test.py +29 -29
- aient/models/chatgpt.py +6 -0
- {aient-1.2.3.dist-info → aient-1.2.5.dist-info}/METADATA +1 -1
- {aient-1.2.3.dist-info → aient-1.2.5.dist-info}/RECORD +8 -8
- {aient-1.2.3.dist-info → aient-1.2.5.dist-info}/WHEEL +0 -0
- {aient-1.2.3.dist-info → aient-1.2.5.dist-info}/licenses/LICENSE +0 -0
- {aient-1.2.3.dist-info → aient-1.2.5.dist-info}/top_level.txt +0 -0
@@ -299,7 +299,7 @@ class Message(ABC):
|
|
299
299
|
processed_items.append(Texts(text=item))
|
300
300
|
|
301
301
|
elif isinstance(item, Message):
|
302
|
-
processed_items.extend(item.
|
302
|
+
processed_items.extend(item.provider())
|
303
303
|
elif isinstance(item, ContextProvider):
|
304
304
|
processed_items.append(item)
|
305
305
|
elif isinstance(item, list):
|
@@ -360,14 +360,14 @@ class Message(ABC):
|
|
360
360
|
if self._parent_messages:
|
361
361
|
self._parent_messages._notify_provider_added(item, self)
|
362
362
|
|
363
|
-
def
|
363
|
+
def provider(self) -> List[ContextProvider]: return self._items
|
364
364
|
|
365
365
|
def __add__(self, other):
|
366
366
|
if isinstance(other, str):
|
367
367
|
new_items = self._items + [Texts(text=other)]
|
368
368
|
return type(self)(*new_items)
|
369
369
|
if isinstance(other, Message):
|
370
|
-
new_items = self._items + other.
|
370
|
+
new_items = self._items + other.provider()
|
371
371
|
return type(self)(*new_items)
|
372
372
|
return NotImplemented
|
373
373
|
|
@@ -376,7 +376,7 @@ class Message(ABC):
|
|
376
376
|
new_items = [Texts(text=other)] + self._items
|
377
377
|
return type(self)(*new_items)
|
378
378
|
if isinstance(other, Message):
|
379
|
-
new_items = other.
|
379
|
+
new_items = other.provider() + self._items
|
380
380
|
return type(self)(*new_items)
|
381
381
|
return NotImplemented
|
382
382
|
|
@@ -563,7 +563,7 @@ class Messages:
|
|
563
563
|
return None
|
564
564
|
popped_message = self._messages.pop(key)
|
565
565
|
popped_message._parent_messages = None
|
566
|
-
for provider in popped_message.
|
566
|
+
for provider in popped_message.provider():
|
567
567
|
self._notify_provider_removed(provider)
|
568
568
|
return popped_message
|
569
569
|
except IndexError:
|
@@ -589,12 +589,12 @@ class Messages:
|
|
589
589
|
def append(self, message: Message):
|
590
590
|
if self._messages and self._messages[-1].role == message.role:
|
591
591
|
last_message = self._messages[-1]
|
592
|
-
for provider in message.
|
592
|
+
for provider in message.provider():
|
593
593
|
last_message.append(provider)
|
594
594
|
else:
|
595
595
|
message._parent_messages = self
|
596
596
|
self._messages.append(message)
|
597
|
-
for p in message.
|
597
|
+
for p in message.provider():
|
598
598
|
self._notify_provider_added(p, message)
|
599
599
|
|
600
600
|
def save(self, file_path: str):
|
aient/architext/test/test.py
CHANGED
@@ -204,7 +204,7 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
204
204
|
|
205
205
|
# 验证是否真的从 Message 对象中弹出了
|
206
206
|
self.assertIs(popped_provider, self.tools_provider, "应该从 SystemMessage 中成功弹出 provider")
|
207
|
-
self.assertNotIn(self.tools_provider, system_message.
|
207
|
+
self.assertNotIn(self.tools_provider, system_message.provider(), "provider 不应再存在于 SystemMessage 的 provider 列表中")
|
208
208
|
|
209
209
|
# 3. 核心问题:检查顶层 Messages 的索引
|
210
210
|
# 在理想情况下,直接修改子消息应该同步更新顶层索引。
|
@@ -237,8 +237,8 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
237
237
|
|
238
238
|
# 验证弹出的消息是否正确
|
239
239
|
self.assertIsInstance(popped_message, UserMessage)
|
240
|
-
self.assertEqual(len(popped_message.
|
241
|
-
self.assertEqual(popped_message.
|
240
|
+
self.assertEqual(len(popped_message.provider()), 1)
|
241
|
+
self.assertEqual(popped_message.provider()[0].name, user_provider.name)
|
242
242
|
|
243
243
|
# 验证 Messages 对象的当前状态
|
244
244
|
self.assertEqual(len(messages), 2)
|
@@ -418,13 +418,13 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
418
418
|
user_message = UserMessage(self.files_provider, "This is a raw string.")
|
419
419
|
|
420
420
|
# 验证 _items 列表中的第二个元素是否是 Texts 类的实例
|
421
|
-
self.assertEqual(len(user_message.
|
422
|
-
self.assertIsInstance(user_message.
|
423
|
-
self.assertIsInstance(user_message.
|
421
|
+
self.assertEqual(len(user_message.provider()), 2)
|
422
|
+
self.assertIsInstance(user_message.provider()[0], Files)
|
423
|
+
self.assertIsInstance(user_message.provider()[1], Texts)
|
424
424
|
|
425
425
|
# 验证转换后的 Texts provider 内容是否正确
|
426
426
|
# 我们需要异步地获取内容
|
427
|
-
text_provider = user_message.
|
427
|
+
text_provider = user_message.provider()[1]
|
428
428
|
await text_provider.refresh() # 手动刷新以获取内容
|
429
429
|
content_block = text_provider.get_content_block()
|
430
430
|
self.assertIsNotNone(content_block)
|
@@ -446,7 +446,7 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
446
446
|
# 3. 测试RoleMessage工厂类
|
447
447
|
factory_user_msg = RoleMessage('user', "Factory-created string.")
|
448
448
|
self.assertIsInstance(factory_user_msg, UserMessage)
|
449
|
-
self.assertIsInstance(factory_user_msg.
|
449
|
+
self.assertIsInstance(factory_user_msg.provider()[0], Texts)
|
450
450
|
|
451
451
|
# 4. 测试无效类型
|
452
452
|
with self.assertRaises(TypeError):
|
@@ -461,12 +461,12 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
461
461
|
]
|
462
462
|
user_message_mixed = UserMessage(mixed_content_list)
|
463
463
|
|
464
|
-
self.assertEqual(len(user_message_mixed.
|
465
|
-
self.assertIsInstance(user_message_mixed.
|
466
|
-
self.assertIsInstance(user_message_mixed.
|
464
|
+
self.assertEqual(len(user_message_mixed.provider()), 2)
|
465
|
+
self.assertIsInstance(user_message_mixed.provider()[0], Texts)
|
466
|
+
self.assertIsInstance(user_message_mixed.provider()[1], Images)
|
467
467
|
|
468
468
|
# 验证内容
|
469
|
-
providers = user_message_mixed.
|
469
|
+
providers = user_message_mixed.provider()
|
470
470
|
await asyncio.gather(*[p.refresh() for p in providers]) # 刷新所有providers
|
471
471
|
self.assertEqual(providers[0].get_content_block().content, 'Describe the following image.')
|
472
472
|
self.assertEqual(providers[1].get_content_block().content, 'data:image/png;base64,VGhpcyBpcyBhIGR1bW15IGltYWdlIGZpbGUu')
|
@@ -478,9 +478,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
478
478
|
]
|
479
479
|
user_message_text_only = UserMessage(text_only_list)
|
480
480
|
|
481
|
-
self.assertEqual(len(user_message_text_only.
|
482
|
-
self.assertIsInstance(user_message_text_only.
|
483
|
-
self.assertIsInstance(user_message_text_only.
|
481
|
+
self.assertEqual(len(user_message_text_only.provider()), 2)
|
482
|
+
self.assertIsInstance(user_message_text_only.provider()[0], Texts)
|
483
|
+
self.assertIsInstance(user_message_text_only.provider()[1], Texts)
|
484
484
|
|
485
485
|
# 3. 在 Messages 容器中测试
|
486
486
|
messages = Messages(UserMessage(mixed_content_list))
|
@@ -523,9 +523,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
523
523
|
|
524
524
|
# 3. 验证新消息的类型和内容
|
525
525
|
self.assertIsInstance(new_message, UserMessage, "结果应该是一个 UserMessage 实例")
|
526
|
-
self.assertEqual(len(new_message.
|
526
|
+
self.assertEqual(len(new_message.provider()), 2, "新消息应该包含两个 provider")
|
527
527
|
|
528
|
-
providers = new_message.
|
528
|
+
providers = new_message.provider()
|
529
529
|
self.assertIsInstance(providers[0], Texts, "第一个 provider 应该是 Texts 类型")
|
530
530
|
self.assertIsInstance(providers[1], Texts, "第二个 provider 应该是 Texts 类型")
|
531
531
|
|
@@ -536,14 +536,14 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
536
536
|
self.assertEqual(providers[1].get_content_block().content, "hello", "第二个 provider 的内容应该是 'hello'")
|
537
537
|
|
538
538
|
# 4. 验证原始消息没有被修改
|
539
|
-
self.assertEqual(len(original_message.
|
539
|
+
self.assertEqual(len(original_message.provider()), 1, "原始消息不应该被修改")
|
540
540
|
|
541
541
|
# 5. 测试 UserMessage + "string"
|
542
542
|
new_message_add = original_message + "world"
|
543
543
|
self.assertIsInstance(new_message_add, UserMessage)
|
544
|
-
self.assertEqual(len(new_message_add.
|
544
|
+
self.assertEqual(len(new_message_add.provider()), 2)
|
545
545
|
|
546
|
-
providers_add = new_message_add.
|
546
|
+
providers_add = new_message_add.provider()
|
547
547
|
await asyncio.gather(*[p.refresh() for p in providers_add])
|
548
548
|
self.assertEqual(providers_add[0].get_content_block().content, "hello")
|
549
549
|
self.assertEqual(providers_add[1].get_content_block().content, "world")
|
@@ -553,9 +553,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
553
553
|
# 1. 测试 "str" + UserMessage
|
554
554
|
combined_message = "hi" + UserMessage("hello")
|
555
555
|
self.assertIsInstance(combined_message, UserMessage)
|
556
|
-
self.assertEqual(len(combined_message.
|
556
|
+
self.assertEqual(len(combined_message.provider()), 2)
|
557
557
|
|
558
|
-
providers = combined_message.
|
558
|
+
providers = combined_message.provider()
|
559
559
|
await asyncio.gather(*[p.refresh() for p in providers])
|
560
560
|
self.assertEqual(providers[0].get_content_block().content, "hi")
|
561
561
|
self.assertEqual(providers[1].get_content_block().content, "hello")
|
@@ -563,9 +563,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
563
563
|
# 2. 测试 UserMessage(UserMessage(...)) 扁平化
|
564
564
|
# 按照用户的要求,UserMessage(UserMessage(...)) 应该被扁平化
|
565
565
|
nested_message = UserMessage(UserMessage("item1", "item2"))
|
566
|
-
self.assertEqual(len(nested_message.
|
566
|
+
self.assertEqual(len(nested_message.provider()), 2)
|
567
567
|
|
568
|
-
providers_nested = nested_message.
|
568
|
+
providers_nested = nested_message.provider()
|
569
569
|
self.assertIsInstance(providers_nested[0], Texts)
|
570
570
|
self.assertIsInstance(providers_nested[1], Texts)
|
571
571
|
|
@@ -576,9 +576,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
|
|
576
576
|
# 3. 结合 1 和 2,测试用户的完整场景
|
577
577
|
final_message = UserMessage("hi" + UserMessage("hello"))
|
578
578
|
self.assertIsInstance(final_message, UserMessage)
|
579
|
-
self.assertEqual(len(final_message.
|
579
|
+
self.assertEqual(len(final_message.provider()), 2)
|
580
580
|
|
581
|
-
providers_final = final_message.
|
581
|
+
providers_final = final_message.provider()
|
582
582
|
await asyncio.gather(*[p.refresh() for p in providers_final])
|
583
583
|
self.assertEqual(providers_final[0].get_content_block().content, "hi")
|
584
584
|
self.assertEqual(providers_final[1].get_content_block().content, "hello")
|
@@ -1108,7 +1108,7 @@ Current time: {Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}
|
|
1108
1108
|
user_message = UserMessage("你好, Architext!")
|
1109
1109
|
# 对于简单的 Texts, refresh 不是必须的, 但这是个好习惯
|
1110
1110
|
# Message 类本身没有 refresh, 调用其 providers 的 refresh
|
1111
|
-
for p in user_message.
|
1111
|
+
for p in user_message.provider():
|
1112
1112
|
await p.refresh()
|
1113
1113
|
|
1114
1114
|
# 2. 直接访问 .content 属性
|
@@ -1120,7 +1120,7 @@ Current time: {Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}
|
|
1120
1120
|
"这是一张图片:",
|
1121
1121
|
Images(url="data:image/png;base64,FAKE_IMG_DATA")
|
1122
1122
|
)
|
1123
|
-
for p in multimodal_message.
|
1123
|
+
for p in multimodal_message.provider():
|
1124
1124
|
await p.refresh()
|
1125
1125
|
|
1126
1126
|
# 4. 访问多模态消息的 .content 属性,期望返回一个列表
|
@@ -1132,7 +1132,7 @@ Current time: {Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}
|
|
1132
1132
|
|
1133
1133
|
# 5. 测试通过 RoleMessage 工厂创建的消息
|
1134
1134
|
role_message = RoleMessage('user', "通过工厂创建的内容")
|
1135
|
-
for p in role_message.
|
1135
|
+
for p in role_message.provider():
|
1136
1136
|
await p.refresh()
|
1137
1137
|
self.assertEqual(role_message.content, "通过工厂创建的内容")
|
1138
1138
|
|
aient/models/chatgpt.py
CHANGED
@@ -256,6 +256,12 @@ class chatgpt(BaseLLM):
|
|
256
256
|
"image": True
|
257
257
|
}
|
258
258
|
|
259
|
+
done_message = self.conversation[convo_id].provider("done")
|
260
|
+
if self.check_done and done_message:
|
261
|
+
done_message.visible = False
|
262
|
+
if self.conversation[convo_id][-1][-1].name == "done":
|
263
|
+
done_message[-1].visible = True
|
264
|
+
|
259
265
|
# 构造请求数据
|
260
266
|
request_data = {
|
261
267
|
"model": model or self.engine,
|
@@ -1,8 +1,8 @@
|
|
1
1
|
aient/__init__.py,sha256=SRfF7oDVlOOAi6nGKiJIUK6B_arqYLO9iSMp-2IZZps,21
|
2
2
|
aient/architext/architext/__init__.py,sha256=79Ih1151rfcqZdr7F8HSZSTs_iT2SKd1xCkehMsXeXs,19
|
3
|
-
aient/architext/architext/core.py,sha256=
|
3
|
+
aient/architext/architext/core.py,sha256=Wn4a8npNERbVW9BrRR55-ePjNjIjkKZjKEAp7Lg-LbQ,25629
|
4
4
|
aient/architext/test/openai_client.py,sha256=Dqtbmubv6vwF8uBqcayG0kbsiO65of7sgU2-DRBi-UM,4590
|
5
|
-
aient/architext/test/test.py,sha256=
|
5
|
+
aient/architext/test/test.py,sha256=w5ANQyoD_HQFol4EZKUKOa4MyDMBWP1YOUzKUxr2PaA,55286
|
6
6
|
aient/architext/test/test_save_load.py,sha256=o8DqH6gDYZkFkQy-a7blqLtJTRj5e4a-Lil48pJ0V3g,3260
|
7
7
|
aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
|
8
8
|
aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
|
@@ -17,7 +17,7 @@ aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhF
|
|
17
17
|
aient/models/__init__.py,sha256=ZTiZgbfBPTjIPSKURE7t6hlFBVLRS9lluGbmqc1WjxQ,43
|
18
18
|
aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
|
19
19
|
aient/models/base.py,sha256=-nnihYnx-vHZMqeVO9ljjt3k4FcD3n-iMk4tT-10nRQ,7232
|
20
|
-
aient/models/chatgpt.py,sha256=
|
20
|
+
aient/models/chatgpt.py,sha256=XM63-dJql-Ti6mWK2nEFUACjdoz8uN-mqh4B3NH8kLw,42189
|
21
21
|
aient/plugins/__init__.py,sha256=p3KO6Aa3Lupos4i2SjzLQw1hzQTigOAfEHngsldrsyk,986
|
22
22
|
aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
|
23
23
|
aient/plugins/config.py,sha256=TGgZ5SnNKZ8MmdznrZ-TEq7s2ulhAAwTSKH89bci3dA,7079
|
@@ -35,8 +35,8 @@ aient/plugins/write_file.py,sha256=Jt8fOEwqhYiSWpCbwfAr1xoi_BmFnx3076GMhuL06uI,3
|
|
35
35
|
aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
36
36
|
aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
|
37
37
|
aient/utils/scripts.py,sha256=VqtK4RFEx7KxkmcqG3lFDS1DxoNlFFGErEjopVcc8IE,40974
|
38
|
-
aient-1.2.
|
39
|
-
aient-1.2.
|
40
|
-
aient-1.2.
|
41
|
-
aient-1.2.
|
42
|
-
aient-1.2.
|
38
|
+
aient-1.2.5.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
|
39
|
+
aient-1.2.5.dist-info/METADATA,sha256=l3t245vxtz0K3cf1bzWZ2E4kN33Qyb3z5aMGFG2Hv_U,4841
|
40
|
+
aient-1.2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
41
|
+
aient-1.2.5.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
|
42
|
+
aient-1.2.5.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|