aient 1.2.4__py3-none-any.whl → 1.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -299,7 +299,7 @@ class Message(ABC):
299
299
  processed_items.append(Texts(text=item))
300
300
 
301
301
  elif isinstance(item, Message):
302
- processed_items.extend(item.providers())
302
+ processed_items.extend(item.provider())
303
303
  elif isinstance(item, ContextProvider):
304
304
  processed_items.append(item)
305
305
  elif isinstance(item, list):
@@ -360,14 +360,24 @@ class Message(ABC):
360
360
  if self._parent_messages:
361
361
  self._parent_messages._notify_provider_added(item, self)
362
362
 
363
- def providers(self) -> List[ContextProvider]: return self._items
363
+ def provider(self, name: Optional[str] = None) -> Optional[Union[ContextProvider, ProviderGroup, List[ContextProvider]]]:
364
+ if name is None:
365
+ return self._items
366
+
367
+ named_providers = [p for p in self._items if hasattr(p, 'name') and p.name == name]
368
+
369
+ if not named_providers:
370
+ return None
371
+ if len(named_providers) == 1:
372
+ return named_providers[0]
373
+ return ProviderGroup(named_providers)
364
374
 
365
375
  def __add__(self, other):
366
376
  if isinstance(other, str):
367
377
  new_items = self._items + [Texts(text=other)]
368
378
  return type(self)(*new_items)
369
379
  if isinstance(other, Message):
370
- new_items = self._items + other.providers()
380
+ new_items = self._items + other.provider()
371
381
  return type(self)(*new_items)
372
382
  return NotImplemented
373
383
 
@@ -376,7 +386,7 @@ class Message(ABC):
376
386
  new_items = [Texts(text=other)] + self._items
377
387
  return type(self)(*new_items)
378
388
  if isinstance(other, Message):
379
- new_items = other.providers() + self._items
389
+ new_items = other.provider() + self._items
380
390
  return type(self)(*new_items)
381
391
  return NotImplemented
382
392
 
@@ -563,7 +573,7 @@ class Messages:
563
573
  return None
564
574
  popped_message = self._messages.pop(key)
565
575
  popped_message._parent_messages = None
566
- for provider in popped_message.providers():
576
+ for provider in popped_message.provider():
567
577
  self._notify_provider_removed(provider)
568
578
  return popped_message
569
579
  except IndexError:
@@ -589,12 +599,12 @@ class Messages:
589
599
  def append(self, message: Message):
590
600
  if self._messages and self._messages[-1].role == message.role:
591
601
  last_message = self._messages[-1]
592
- for provider in message.providers():
602
+ for provider in message.provider():
593
603
  last_message.append(provider)
594
604
  else:
595
605
  message._parent_messages = self
596
606
  self._messages.append(message)
597
- for p in message.providers():
607
+ for p in message.provider():
598
608
  self._notify_provider_added(p, message)
599
609
 
600
610
  def save(self, file_path: str):
@@ -204,7 +204,7 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
204
204
 
205
205
  # 验证是否真的从 Message 对象中弹出了
206
206
  self.assertIs(popped_provider, self.tools_provider, "应该从 SystemMessage 中成功弹出 provider")
207
- self.assertNotIn(self.tools_provider, system_message.providers(), "provider 不应再存在于 SystemMessage 的 providers 列表中")
207
+ self.assertNotIn(self.tools_provider, system_message.provider(), "provider 不应再存在于 SystemMessage 的 provider 列表中")
208
208
 
209
209
  # 3. 核心问题:检查顶层 Messages 的索引
210
210
  # 在理想情况下,直接修改子消息应该同步更新顶层索引。
@@ -237,8 +237,8 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
237
237
 
238
238
  # 验证弹出的消息是否正确
239
239
  self.assertIsInstance(popped_message, UserMessage)
240
- self.assertEqual(len(popped_message.providers()), 1)
241
- self.assertEqual(popped_message.providers()[0].name, user_provider.name)
240
+ self.assertEqual(len(popped_message.provider()), 1)
241
+ self.assertEqual(popped_message.provider()[0].name, user_provider.name)
242
242
 
243
243
  # 验证 Messages 对象的当前状态
244
244
  self.assertEqual(len(messages), 2)
@@ -418,13 +418,13 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
418
418
  user_message = UserMessage(self.files_provider, "This is a raw string.")
419
419
 
420
420
  # 验证 _items 列表中的第二个元素是否是 Texts 类的实例
421
- self.assertEqual(len(user_message.providers()), 2)
422
- self.assertIsInstance(user_message.providers()[0], Files)
423
- self.assertIsInstance(user_message.providers()[1], Texts)
421
+ self.assertEqual(len(user_message.provider()), 2)
422
+ self.assertIsInstance(user_message.provider()[0], Files)
423
+ self.assertIsInstance(user_message.provider()[1], Texts)
424
424
 
425
425
  # 验证转换后的 Texts provider 内容是否正确
426
426
  # 我们需要异步地获取内容
427
- text_provider = user_message.providers()[1]
427
+ text_provider = user_message.provider()[1]
428
428
  await text_provider.refresh() # 手动刷新以获取内容
429
429
  content_block = text_provider.get_content_block()
430
430
  self.assertIsNotNone(content_block)
@@ -446,7 +446,7 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
446
446
  # 3. 测试RoleMessage工厂类
447
447
  factory_user_msg = RoleMessage('user', "Factory-created string.")
448
448
  self.assertIsInstance(factory_user_msg, UserMessage)
449
- self.assertIsInstance(factory_user_msg.providers()[0], Texts)
449
+ self.assertIsInstance(factory_user_msg.provider()[0], Texts)
450
450
 
451
451
  # 4. 测试无效类型
452
452
  with self.assertRaises(TypeError):
@@ -461,12 +461,12 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
461
461
  ]
462
462
  user_message_mixed = UserMessage(mixed_content_list)
463
463
 
464
- self.assertEqual(len(user_message_mixed.providers()), 2)
465
- self.assertIsInstance(user_message_mixed.providers()[0], Texts)
466
- self.assertIsInstance(user_message_mixed.providers()[1], Images)
464
+ self.assertEqual(len(user_message_mixed.provider()), 2)
465
+ self.assertIsInstance(user_message_mixed.provider()[0], Texts)
466
+ self.assertIsInstance(user_message_mixed.provider()[1], Images)
467
467
 
468
468
  # 验证内容
469
- providers = user_message_mixed.providers()
469
+ providers = user_message_mixed.provider()
470
470
  await asyncio.gather(*[p.refresh() for p in providers]) # 刷新所有providers
471
471
  self.assertEqual(providers[0].get_content_block().content, 'Describe the following image.')
472
472
  self.assertEqual(providers[1].get_content_block().content, 'data:image/png;base64,VGhpcyBpcyBhIGR1bW15IGltYWdlIGZpbGUu')
@@ -478,9 +478,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
478
478
  ]
479
479
  user_message_text_only = UserMessage(text_only_list)
480
480
 
481
- self.assertEqual(len(user_message_text_only.providers()), 2)
482
- self.assertIsInstance(user_message_text_only.providers()[0], Texts)
483
- self.assertIsInstance(user_message_text_only.providers()[1], Texts)
481
+ self.assertEqual(len(user_message_text_only.provider()), 2)
482
+ self.assertIsInstance(user_message_text_only.provider()[0], Texts)
483
+ self.assertIsInstance(user_message_text_only.provider()[1], Texts)
484
484
 
485
485
  # 3. 在 Messages 容器中测试
486
486
  messages = Messages(UserMessage(mixed_content_list))
@@ -523,9 +523,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
523
523
 
524
524
  # 3. 验证新消息的类型和内容
525
525
  self.assertIsInstance(new_message, UserMessage, "结果应该是一个 UserMessage 实例")
526
- self.assertEqual(len(new_message.providers()), 2, "新消息应该包含两个 provider")
526
+ self.assertEqual(len(new_message.provider()), 2, "新消息应该包含两个 provider")
527
527
 
528
- providers = new_message.providers()
528
+ providers = new_message.provider()
529
529
  self.assertIsInstance(providers[0], Texts, "第一个 provider 应该是 Texts 类型")
530
530
  self.assertIsInstance(providers[1], Texts, "第二个 provider 应该是 Texts 类型")
531
531
 
@@ -536,14 +536,14 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
536
536
  self.assertEqual(providers[1].get_content_block().content, "hello", "第二个 provider 的内容应该是 'hello'")
537
537
 
538
538
  # 4. 验证原始消息没有被修改
539
- self.assertEqual(len(original_message.providers()), 1, "原始消息不应该被修改")
539
+ self.assertEqual(len(original_message.provider()), 1, "原始消息不应该被修改")
540
540
 
541
541
  # 5. 测试 UserMessage + "string"
542
542
  new_message_add = original_message + "world"
543
543
  self.assertIsInstance(new_message_add, UserMessage)
544
- self.assertEqual(len(new_message_add.providers()), 2)
544
+ self.assertEqual(len(new_message_add.provider()), 2)
545
545
 
546
- providers_add = new_message_add.providers()
546
+ providers_add = new_message_add.provider()
547
547
  await asyncio.gather(*[p.refresh() for p in providers_add])
548
548
  self.assertEqual(providers_add[0].get_content_block().content, "hello")
549
549
  self.assertEqual(providers_add[1].get_content_block().content, "world")
@@ -553,9 +553,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
553
553
  # 1. 测试 "str" + UserMessage
554
554
  combined_message = "hi" + UserMessage("hello")
555
555
  self.assertIsInstance(combined_message, UserMessage)
556
- self.assertEqual(len(combined_message.providers()), 2)
556
+ self.assertEqual(len(combined_message.provider()), 2)
557
557
 
558
- providers = combined_message.providers()
558
+ providers = combined_message.provider()
559
559
  await asyncio.gather(*[p.refresh() for p in providers])
560
560
  self.assertEqual(providers[0].get_content_block().content, "hi")
561
561
  self.assertEqual(providers[1].get_content_block().content, "hello")
@@ -563,9 +563,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
563
563
  # 2. 测试 UserMessage(UserMessage(...)) 扁平化
564
564
  # 按照用户的要求,UserMessage(UserMessage(...)) 应该被扁平化
565
565
  nested_message = UserMessage(UserMessage("item1", "item2"))
566
- self.assertEqual(len(nested_message.providers()), 2)
566
+ self.assertEqual(len(nested_message.provider()), 2)
567
567
 
568
- providers_nested = nested_message.providers()
568
+ providers_nested = nested_message.provider()
569
569
  self.assertIsInstance(providers_nested[0], Texts)
570
570
  self.assertIsInstance(providers_nested[1], Texts)
571
571
 
@@ -576,9 +576,9 @@ class TestContextManagement(unittest.IsolatedAsyncioTestCase):
576
576
  # 3. 结合 1 和 2,测试用户的完整场景
577
577
  final_message = UserMessage("hi" + UserMessage("hello"))
578
578
  self.assertIsInstance(final_message, UserMessage)
579
- self.assertEqual(len(final_message.providers()), 2)
579
+ self.assertEqual(len(final_message.provider()), 2)
580
580
 
581
- providers_final = final_message.providers()
581
+ providers_final = final_message.provider()
582
582
  await asyncio.gather(*[p.refresh() for p in providers_final])
583
583
  self.assertEqual(providers_final[0].get_content_block().content, "hi")
584
584
  self.assertEqual(providers_final[1].get_content_block().content, "hello")
@@ -1108,7 +1108,7 @@ Current time: {Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}
1108
1108
  user_message = UserMessage("你好, Architext!")
1109
1109
  # 对于简单的 Texts, refresh 不是必须的, 但这是个好习惯
1110
1110
  # Message 类本身没有 refresh, 调用其 providers 的 refresh
1111
- for p in user_message.providers():
1111
+ for p in user_message.provider():
1112
1112
  await p.refresh()
1113
1113
 
1114
1114
  # 2. 直接访问 .content 属性
@@ -1120,7 +1120,7 @@ Current time: {Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}
1120
1120
  "这是一张图片:",
1121
1121
  Images(url="data:image/png;base64,FAKE_IMG_DATA")
1122
1122
  )
1123
- for p in multimodal_message.providers():
1123
+ for p in multimodal_message.provider():
1124
1124
  await p.refresh()
1125
1125
 
1126
1126
  # 4. 访问多模态消息的 .content 属性,期望返回一个列表
@@ -1132,7 +1132,7 @@ Current time: {Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}
1132
1132
 
1133
1133
  # 5. 测试通过 RoleMessage 工厂创建的消息
1134
1134
  role_message = RoleMessage('user', "通过工厂创建的内容")
1135
- for p in role_message.providers():
1135
+ for p in role_message.provider():
1136
1136
  await p.refresh()
1137
1137
  self.assertEqual(role_message.content, "通过工厂创建的内容")
1138
1138
 
@@ -1158,6 +1158,30 @@ Current time: {Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}
1158
1158
  with self.assertRaises(IndexError):
1159
1159
  _ = mess[2]
1160
1160
 
1161
+ async def test_zb_message_provider_by_name(self):
1162
+ """测试是否可以通过名称从 Message 对象中获取 provider"""
1163
+ # 1. 创建一个包含命名 provider 的 Message
1164
+ message = UserMessage(
1165
+ Texts("Some instruction", name="instruction"),
1166
+ Tools([{"name": "a_tool"}], name="tools"),
1167
+ Texts("Another instruction", name="instruction")
1168
+ )
1169
+
1170
+ # 2. 测试获取单个 provider
1171
+ tools_provider = message.provider("tools")
1172
+ self.assertIsInstance(tools_provider, Tools)
1173
+ self.assertEqual(tools_provider.name, "tools")
1174
+
1175
+ # 3. 测试获取多个同名 provider
1176
+ instruction_providers = message.provider("instruction")
1177
+ self.assertIsInstance(instruction_providers, ProviderGroup)
1178
+ self.assertEqual(len(instruction_providers), 2)
1179
+ self.assertTrue(all(isinstance(p, Texts) for p in instruction_providers))
1180
+
1181
+ # 4. 测试获取不存在的 provider
1182
+ non_existent_provider = message.provider("non_existent")
1183
+ self.assertIsNone(non_existent_provider)
1184
+
1161
1185
  # ==============================================================================
1162
1186
  # 6. 演示
1163
1187
  # ==============================================================================
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: aient
3
- Version: 1.2.4
3
+ Version: 1.2.6
4
4
  Summary: Aient: The Awakening of Agent.
5
5
  Requires-Python: >=3.11
6
6
  Description-Content-Type: text/markdown
@@ -1,8 +1,8 @@
1
1
  aient/__init__.py,sha256=SRfF7oDVlOOAi6nGKiJIUK6B_arqYLO9iSMp-2IZZps,21
2
2
  aient/architext/architext/__init__.py,sha256=79Ih1151rfcqZdr7F8HSZSTs_iT2SKd1xCkehMsXeXs,19
3
- aient/architext/architext/core.py,sha256=qSRiIfx9Cl_OZIDfGMFr7AKbwRWTBrJNt_jqJeOr5Lw,25636
3
+ aient/architext/architext/core.py,sha256=A1ZeZSJwcdl-svYA12uD2qSlU-wfNTVWvpSg8AF8_Gk,26015
4
4
  aient/architext/test/openai_client.py,sha256=Dqtbmubv6vwF8uBqcayG0kbsiO65of7sgU2-DRBi-UM,4590
5
- aient/architext/test/test.py,sha256=efWn_qYdgPZO-Wz68o33QoWhWiJjU-cAzPbIVL0sjhw,55316
5
+ aient/architext/test/test.py,sha256=trVHo2we0W8RN-0QNvP3sJ3yUpe08-34Ae_ZVNZJHdE,56378
6
6
  aient/architext/test/test_save_load.py,sha256=o8DqH6gDYZkFkQy-a7blqLtJTRj5e4a-Lil48pJ0V3g,3260
7
7
  aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
8
8
  aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
@@ -35,8 +35,8 @@ aient/plugins/write_file.py,sha256=Jt8fOEwqhYiSWpCbwfAr1xoi_BmFnx3076GMhuL06uI,3
35
35
  aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
36
  aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
37
37
  aient/utils/scripts.py,sha256=VqtK4RFEx7KxkmcqG3lFDS1DxoNlFFGErEjopVcc8IE,40974
38
- aient-1.2.4.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
39
- aient-1.2.4.dist-info/METADATA,sha256=ibFnvDPyG3FMCoVOkm4yjuI_nUqrtqPtMKLITvC5rKM,4841
40
- aient-1.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
41
- aient-1.2.4.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
42
- aient-1.2.4.dist-info/RECORD,,
38
+ aient-1.2.6.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
39
+ aient-1.2.6.dist-info/METADATA,sha256=aKQPqRvoFl1kd5GXFxOziGdeRhNAZDL8ksFspWdljoo,4841
40
+ aient-1.2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
41
+ aient-1.2.6.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
42
+ aient-1.2.6.dist-info/RECORD,,
File without changes