aient 1.1.99__py3-none-any.whl → 1.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/architext/architext/core.py +58 -10
- aient/architext/test/test.py +48 -0
- aient/models/chatgpt.py +2 -7
- {aient-1.1.99.dist-info → aient-1.2.1.dist-info}/METADATA +1 -1
- {aient-1.1.99.dist-info → aient-1.2.1.dist-info}/RECORD +8 -8
- {aient-1.1.99.dist-info → aient-1.2.1.dist-info}/WHEEL +0 -0
- {aient-1.1.99.dist-info → aient-1.2.1.dist-info}/licenses/LICENSE +0 -0
- {aient-1.1.99.dist-info → aient-1.2.1.dist-info}/top_level.txt +0 -0
@@ -10,6 +10,30 @@ from dataclasses import dataclass
|
|
10
10
|
from abc import ABC, abstractmethod
|
11
11
|
from typing import List, Dict, Any, Optional, Union, Callable
|
12
12
|
|
13
|
+
# A wrapper to manage multiple providers with the same name
class ProviderGroup:
    """A container for multiple providers that share the same name.

    Supports indexing (``group[-1]``), iteration and ``len()``, plus bulk
    visibility control: reading ``group.visible`` returns one flag per member,
    and assigning ``group.visible = value`` sets that value on every member.
    """

    def __init__(self, providers: List['ContextProvider']):
        """Wrap ``providers``.

        Any iterable is accepted. It is copied into a private list so that later
        changes to the caller's sequence do not silently alter the group.
        """
        self._providers = list(providers)

    def __getitem__(self, key: int) -> 'ContextProvider':
        """Return the member at ``key``; negative indices work, e.g. ``group[-1]``."""
        return self._providers[key]

    def __iter__(self):
        """Iterate over the members in order."""
        return iter(self._providers)

    def __len__(self) -> int:
        """Return the number of members in the group."""
        return len(self._providers)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._providers!r})"

    @property
    def visible(self) -> List[bool]:
        """Visibility flag of each member, in member order.

        Note: the result is a list, so it is truthy whenever the group is
        non-empty regardless of the individual flags; use ``any(group.visible)``
        or ``all(group.visible)`` to test the flags themselves.
        """
        return [p.visible for p in self._providers]

    @visible.setter
    def visible(self, value: bool):
        """Set ``visible`` to ``value`` on every member of the group."""
        for p in self._providers:
            p.visible = value
|
36
|
+
|
13
37
|
# Global, thread-safe registry for providers created within f-strings
|
14
38
|
_fstring_provider_registry = {}
|
15
39
|
_registry_lock = threading.Lock()
|
@@ -436,22 +460,40 @@ class Messages:
|
|
436
460
|
def __init__(self, *initial_messages: Message):
    """Create the message list, optionally seeded with ``initial_messages``.

    Each seed message is added through ``self.append``, in the order given.
    """
    from typing import Tuple

    self._messages: List[Message] = []
    # Name -> list of (provider, owning message) pairs; one name may map to
    # several providers.
    self._providers_index: Dict[str, List[Tuple[ContextProvider, Message]]] = {}
    for seed in initial_messages:
        self.append(seed)
|
443
467
|
|
444
468
|
def _notify_provider_added(self, provider: ContextProvider, message: Message):
    """Index ``provider`` (owned by ``message``) under its name.

    A name can hold several providers; new entries are appended after any
    existing ones with the same name.
    """
    bucket = self._providers_index.setdefault(provider.name, [])
    bucket.append((provider, message))
|
447
472
|
|
448
473
|
def _notify_provider_removed(self, provider: ContextProvider):
    """Remove ``provider`` from the name index.

    Entries are matched by object identity (``is``), so other providers that
    happen to share the same name stay indexed. When no entries remain under
    the name, the key itself is deleted. Unknown names are ignored.
    """
    name = provider.name
    if name not in self._providers_index:
        return

    survivors = [
        entry for entry in self._providers_index[name] if entry[0] is not provider
    ]
    if survivors:
        self._providers_index[name] = survivors
    else:
        del self._providers_index[name]
|
486
|
+
|
487
|
+
def provider(self, name: str) -> Optional[Union[ContextProvider, ProviderGroup]]:
    """Look up the provider(s) indexed under ``name``.

    Returns ``None`` when nothing is indexed under the name, the provider itself
    when exactly one matches, and otherwise a ``ProviderGroup`` wrapping every
    match in index order.
    """
    entries = self._providers_index.get(name)
    if not entries:
        return None

    matches = [prov for prov, _owner in entries]
    return matches[0] if len(matches) == 1 else ProviderGroup(matches)
|
455
497
|
|
456
498
|
def pop(self, key: Optional[Union[str, int]] = None) -> Union[Optional[ContextProvider], Optional[Message]]:
|
457
499
|
# If no key is provided, pop the last message.
|
@@ -459,10 +501,13 @@ class Messages:
|
|
459
501
|
key = len(self._messages) - 1
|
460
502
|
|
461
503
|
if isinstance(key, str):
|
462
|
-
|
463
|
-
if not
|
504
|
+
indexed_list = self._providers_index.get(key)
|
505
|
+
if not indexed_list:
|
464
506
|
return None
|
465
|
-
|
507
|
+
# Pop the first one found, which is consistent with how pop usually works
|
508
|
+
_provider, parent_message = indexed_list[0]
|
509
|
+
# The actual removal from _providers_index happens in _notify_provider_removed
|
510
|
+
# which is called by message.pop()
|
466
511
|
return parent_message.pop(key)
|
467
512
|
elif isinstance(key, int):
|
468
513
|
try:
|
@@ -481,7 +526,10 @@ class Messages:
|
|
481
526
|
return None
|
482
527
|
|
483
528
|
async def refresh(self):
    """Refresh every indexed provider concurrently.

    Calls ``refresh()`` on each provider in the name index (all names, all
    entries) and awaits the results together via ``asyncio.gather``.
    """
    pending = [
        prov.refresh()
        for entries in self._providers_index.values()
        for prov, _owner in entries
    ]
    await asyncio.gather(*pending)
|
486
534
|
|
487
535
|
def render(self) -> List[Dict[str, Any]]:
|
aient/architext/test/test.py
CHANGED
@@ -1053,6 +1053,54 @@ Current time: {Texts(lambda: datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}
|
|
1053
1053
|
self.assertEqual(len(rendered_visible_again), 1)
|
1054
1054
|
self.assertEqual(rendered_visible_again[0]['content'], "Hello, World!")
|
1055
1055
|
|
1056
|
+
async def test_z8_bulk_provider_visibility_control(self):
    """Providers sharing a name can be hidden in bulk and re-shown individually."""
    first = "First explanation."
    second = "Second explanation."
    third = "Third explanation."
    other = "Some other text."

    # Three providers share the name "explanation"; one unnamed provider sits among them.
    messages = Messages(
        UserMessage(
            Texts(first, name="explanation"),
            Texts(second, name="explanation"),
            Texts(other),
            Texts(third, name="explanation"),
        )
    )

    async def latest_content():
        rendered = await messages.render_latest()
        return rendered[0]['content']

    # Initially every "explanation" provider is rendered.
    content = await latest_content()
    for snippet in (first, second, third):
        self.assertIn(snippet, content)

    # Looking up the shared name yields a group holding all three providers.
    group = messages.provider("explanation")
    self.assertIsInstance(group, ProviderGroup)
    self.assertEqual(len(group), 3)

    # Assigning visibility on the group hides every member.
    group.visible = False
    for member in group:
        self.assertFalse(member.visible)

    # Hidden members vanish from the render; the unnamed text stays.
    content = await latest_content()
    for snippet in (first, second, third):
        self.assertNotIn(snippet, content)
    self.assertIn(other, content)

    # Individual members can still be toggled through indexing.
    group[-1].visible = True
    self.assertTrue(group[-1].visible)
    self.assertFalse(group[0].visible)

    # Only the last explanation reappears.
    content = await latest_content()
    for snippet in (first, second):
        self.assertNotIn(snippet, content)
    for snippet in (third, other):
        self.assertIn(snippet, content)
|
1056
1104
|
|
1057
1105
|
# ==============================================================================
|
1058
1106
|
# 6. 演示
|
aient/models/chatgpt.py
CHANGED
@@ -97,13 +97,8 @@ class chatgpt(BaseLLM):
|
|
97
97
|
Initialize Chatbot with API key (from https://platform.openai.com/account/api-keys)
|
98
98
|
"""
|
99
99
|
super().__init__(api_key, engine, api_url, system_prompt, proxy, timeout, max_tokens, temperature, top_p, presence_penalty, frequency_penalty, reply_count, truncate_limit, use_plugins=use_plugins, print_log=print_log)
|
100
|
-
self.conversation: dict[str,
|
101
|
-
"default":
|
102
|
-
{
|
103
|
-
"role": "system",
|
104
|
-
"content": self.system_prompt,
|
105
|
-
},
|
106
|
-
],
|
100
|
+
self.conversation: dict[str, Messages] = {
|
101
|
+
"default": Messages(SystemMessage(self.system_prompt)),
|
107
102
|
}
|
108
103
|
if cache_messages:
|
109
104
|
self.conversation["default"] = cache_messages
|
@@ -1,8 +1,8 @@
|
|
1
1
|
aient/__init__.py,sha256=SRfF7oDVlOOAi6nGKiJIUK6B_arqYLO9iSMp-2IZZps,21
|
2
2
|
aient/architext/architext/__init__.py,sha256=79Ih1151rfcqZdr7F8HSZSTs_iT2SKd1xCkehMsXeXs,19
|
3
|
-
aient/architext/architext/core.py,sha256=
|
3
|
+
aient/architext/architext/core.py,sha256=zdgUvCBr_BMZ1s51ZN_28CjN8RDXY6DzpnodJaIG80Y,23735
|
4
4
|
aient/architext/test/openai_client.py,sha256=Dqtbmubv6vwF8uBqcayG0kbsiO65of7sgU2-DRBi-UM,4590
|
5
|
-
aient/architext/test/test.py,sha256=
|
5
|
+
aient/architext/test/test.py,sha256=EsgtMdIUSz4eCB7T0zs3wqYg_E1gP7h11WZyVevlqjw,52951
|
6
6
|
aient/architext/test/test_save_load.py,sha256=o8DqH6gDYZkFkQy-a7blqLtJTRj5e4a-Lil48pJ0V3g,3260
|
7
7
|
aient/core/__init__.py,sha256=NxjebTlku35S4Dzr16rdSqSTWUvvwEeACe8KvHJnjPg,34
|
8
8
|
aient/core/log_config.py,sha256=kz2_yJv1p-o3lUQOwA3qh-LSc3wMHv13iCQclw44W9c,274
|
@@ -17,7 +17,7 @@ aient/core/test/test_payload.py,sha256=8jBiJY1uidm1jzL-EiK0s6UGmW9XkdsuuKFGrwFhF
|
|
17
17
|
aient/models/__init__.py,sha256=ZTiZgbfBPTjIPSKURE7t6hlFBVLRS9lluGbmqc1WjxQ,43
|
18
18
|
aient/models/audio.py,sha256=kRd-8-WXzv4vwvsTGwnstK-WR8--vr9CdfCZzu8y9LA,1934
|
19
19
|
aient/models/base.py,sha256=-nnihYnx-vHZMqeVO9ljjt3k4FcD3n-iMk4tT-10nRQ,7232
|
20
|
-
aient/models/chatgpt.py,sha256=
|
20
|
+
aient/models/chatgpt.py,sha256=_fYGUnSciKSWYhPuS0lECtU6-C42GGGhFiqbVxejaJg,41919
|
21
21
|
aient/plugins/__init__.py,sha256=p3KO6Aa3Lupos4i2SjzLQw1hzQTigOAfEHngsldrsyk,986
|
22
22
|
aient/plugins/arXiv.py,sha256=yHjb6PS3GUWazpOYRMKMzghKJlxnZ5TX8z9F6UtUVow,1461
|
23
23
|
aient/plugins/config.py,sha256=TGgZ5SnNKZ8MmdznrZ-TEq7s2ulhAAwTSKH89bci3dA,7079
|
@@ -35,8 +35,8 @@ aient/plugins/write_file.py,sha256=Jt8fOEwqhYiSWpCbwfAr1xoi_BmFnx3076GMhuL06uI,3
|
|
35
35
|
aient/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
36
36
|
aient/utils/prompt.py,sha256=UcSzKkFE4-h_1b6NofI6xgk3GoleqALRKY8VBaXLjmI,11311
|
37
37
|
aient/utils/scripts.py,sha256=VqtK4RFEx7KxkmcqG3lFDS1DxoNlFFGErEjopVcc8IE,40974
|
38
|
-
aient-1.1.
|
39
|
-
aient-1.1.
|
40
|
-
aient-1.1.
|
41
|
-
aient-1.1.
|
42
|
-
aient-1.1.
|
38
|
+
aient-1.2.1.dist-info/licenses/LICENSE,sha256=XNdbcWldt0yaNXXWB_Bakoqnxb3OVhUft4MgMA_71ds,1051
|
39
|
+
aient-1.2.1.dist-info/METADATA,sha256=XlgqCsm1fEQcbglTceD25kCVKaFmO5rvcON71zdg5cs,4841
|
40
|
+
aient-1.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
41
|
+
aient-1.2.1.dist-info/top_level.txt,sha256=3oXzrP5sAVvyyqabpeq8A2_vfMtY554r4bVE-OHBrZk,6
|
42
|
+
aient-1.2.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|