beswarm 0.2.81__py3-none-any.whl → 0.2.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- beswarm/agents/planact.py +24 -71
- beswarm/aient/aient/architext/architext/__init__.py +1 -0
- beswarm/aient/aient/architext/architext/core.py +694 -0
- beswarm/aient/aient/architext/test/openai_client.py +146 -0
- beswarm/aient/aient/architext/test/test.py +1410 -0
- beswarm/aient/aient/architext/test/test_save_load.py +93 -0
- beswarm/aient/aient/models/chatgpt.py +39 -111
- beswarm/prompt.py +44 -17
- {beswarm-0.2.81.dist-info → beswarm-0.2.83.dist-info}/METADATA +1 -1
- {beswarm-0.2.81.dist-info → beswarm-0.2.83.dist-info}/RECORD +12 -7
- {beswarm-0.2.81.dist-info → beswarm-0.2.83.dist-info}/WHEEL +0 -0
- {beswarm-0.2.81.dist-info → beswarm-0.2.83.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,694 @@
|
|
1
|
+
import pickle
|
2
|
+
import base64
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
import hashlib
|
6
|
+
import mimetypes
|
7
|
+
import uuid
|
8
|
+
import threading
|
9
|
+
from dataclasses import dataclass
|
10
|
+
from abc import ABC, abstractmethod
|
11
|
+
from typing import List, Dict, Any, Optional, Union, Callable
|
12
|
+
|
13
|
+
# Container used when several providers share one name, so they can be handled as a unit.
class ProviderGroup:
    """Group of same-named providers supporting indexing, iteration, len() and bulk visibility."""

    def __init__(self, providers: List['ContextProvider']):
        # The list is stored as given (not copied).
        self._providers = providers

    def __getitem__(self, key: int) -> 'ContextProvider':
        """Return the provider at position *key*; negative indices work, e.g. group[-1]."""
        members = self._providers
        return members[key]

    def __iter__(self):
        """Iterate over the grouped providers in their stored order."""
        return iter(self._providers)

    def __len__(self) -> int:
        """Number of providers in the group."""
        return len(self._providers)

    @property
    def visible(self) -> List[bool]:
        """Visibility flag of each member, in order."""
        return [member.visible for member in self._providers]

    @visible.setter
    def visible(self, value: bool):
        """Assign the same visibility flag to every member."""
        for member in self._providers:
            member.visible = value
|
36
|
+
|
37
|
+
# Global, thread-safe registry for providers created within f-strings
|
38
|
+
_fstring_provider_registry = {}
|
39
|
+
_registry_lock = threading.Lock()
|
40
|
+
|
41
|
+
def _register_provider(provider: 'ContextProvider') -> str:
|
42
|
+
"""Registers a provider and returns a unique placeholder."""
|
43
|
+
with _registry_lock:
|
44
|
+
provider_id = f"__provider_placeholder_{uuid.uuid4().hex}__"
|
45
|
+
_fstring_provider_registry[provider_id] = provider
|
46
|
+
return provider_id
|
47
|
+
|
48
|
+
def _retrieve_provider(placeholder: str) -> Optional['ContextProvider']:
|
49
|
+
"""Retrieves a provider from the registry."""
|
50
|
+
with _registry_lock:
|
51
|
+
return _fstring_provider_registry.pop(placeholder, None)
|
52
|
+
|
53
|
+
# 1. Core data structure: ContentBlock
@dataclass
class ContentBlock:
    """Rendered output of one context provider, tagged with the provider's name."""

    # Name of the provider the content came from.
    name: str
    # The provider's rendered text.
    content: str
|
58
|
+
|
59
|
+
# 2. Context providers (with caching)
class ContextProvider(ABC):
    """Abstract, named source of context whose rendered output is cached.

    Subclasses implement ``render`` (async; produces the content) and
    ``update`` (changes the underlying data). ``refresh`` calls ``render`` only
    when the provider has been marked stale, and ``get_content_block`` exposes
    the cached result.
    """

    def __init__(self, name: str):
        self.name = name
        self._cached_content: Optional[str] = None  # result of the last render()
        self._is_stale: bool = True  # starts stale so the first refresh renders
        self._visible: bool = True  # hidden providers produce no content block

    def __str__(self):
        # Interpolating a provider into an f-string yields a registry placeholder,
        # which Message.__init__ later turns back into this object.
        return _register_provider(self)

    def mark_stale(self):
        """Flag the cached content as outdated so the next refresh re-renders."""
        self._is_stale = True

    @property
    def visible(self) -> bool:
        """Whether the provider currently contributes a content block."""
        return self._visible

    @visible.setter
    def visible(self, value: bool):
        """Change visibility; only an actual change marks the provider stale."""
        if self._visible == value:
            return
        self._visible = value
        # The source data did not change; flagging staleness is enough for the
        # renderer to reconsider this provider.
        self.mark_stale()

    async def refresh(self):
        """Re-render into the cache if, and only if, the provider is stale."""
        if not self._is_stale:
            return
        self._cached_content = await self.render()
        self._is_stale = False

    @abstractmethod
    async def render(self) -> Optional[str]:
        """Produce the current content; None means there is nothing to show."""
        raise NotImplementedError

    @abstractmethod
    def update(self, *args, **kwargs):
        """Replace the provider's underlying data."""
        raise NotImplementedError

    def get_content_block(self) -> Optional[ContentBlock]:
        """Return the cached content as a ContentBlock, or None if hidden or not rendered."""
        if not self.visible or self._cached_content is None:
            return None
        return ContentBlock(self.name, self._cached_content)
|
98
|
+
|
99
|
+
class Texts(ContextProvider):
    """Text context provider: a static string or a zero-argument callable.

    A callable makes the text "dynamic": it is re-evaluated on every refresh.
    Any other non-None value is converted with ``str()`` both at construction
    and on ``update``.
    """

    def __init__(self, text: Optional[Union[str, Callable[[], str]]] = None, name: Optional[str] = None):
        """Create a text provider.

        Args:
            text: Static text (non-callables are coerced with ``str()``) or a
                callable returning the text.
            name: Provider name. Defaults to ``dynamic_text_<8 random hex>`` for
                callables and ``text_<8 hex of sha1(text)>`` for static text.

        Raises:
            ValueError: if neither ``text`` nor ``name`` is given.
        """
        if text is None and name is None:
            raise ValueError("Either 'text' or 'name' must be provided.")

        self._text = self._coerce_text(text)
        self._is_dynamic = callable(self._text)

        if name is None:
            if self._is_dynamic:
                # A callable has no stable content to hash, so use a random suffix.
                name = f"dynamic_text_{uuid.uuid4().hex[:8]}"
            else:
                # Static text gets a deterministic name derived from its content.
                digest = hashlib.sha1(self._text.encode() if self._text else b'').hexdigest()
                name = f"text_{digest[:8]}"
        super().__init__(name)

    @staticmethod
    def _coerce_text(text):
        """Return callables and None unchanged; convert any other value with str()."""
        if text is None or callable(text):
            return text
        return str(text)

    async def refresh(self):
        """Refresh the cache; dynamic text is always re-rendered."""
        if self._is_dynamic:
            self._is_stale = True
        await super().refresh()

    def update(self, text: Union[str, Callable[[], str]]):
        """Replace the text (static or callable) and mark the provider stale.

        Non-callable values are coerced with ``str()`` exactly as in ``__init__``,
        so rendering always yields a string.
        """
        self._text = self._coerce_text(text)
        self._is_dynamic = callable(self._text)
        self.mark_stale()

    @property
    def content(self) -> Optional[str]:
        """
        Synchronously retrieves the raw text content as a property.
        If the content is dynamic (a callable), it executes the callable.
        A None result (or None static text) is returned as "".
        """
        if self._is_dynamic:
            result = self._text()
            return result if result is not None else ""
        return self._text if self._text is not None else ""

    async def render(self) -> Optional[str]:
        """Return the current text (see ``content``)."""
        return self.content

    def __getstate__(self):
        """Custom state for pickling.

        Callables (e.g. lambdas) cannot be pickled, so dynamic text is
        snapshotted to its current value and stored as static text.
        """
        state = self.__dict__.copy()
        if self._is_dynamic:
            try:
                state['_text'] = self.content
            except Exception as e:
                # Evaluation failed: persist an error marker instead of crashing the save.
                logging.error(f"Error evaluating dynamic text '{self.name}' during pickling: {e}")
                state['_text'] = f"[Error: Could not evaluate dynamic content during save: {e}]"
            state['_is_dynamic'] = False
        return state

    def __setstate__(self, state):
        """Custom state for unpickling: restore the dictionary (the snapshot is one-way)."""
        self.__dict__.update(state)

    def __eq__(self, other):
        """Static texts compare by content; a dynamic text is only equal to itself."""
        if not isinstance(other, Texts):
            return NotImplemented
        if self._is_dynamic or getattr(other, '_is_dynamic', False):
            return self is other
        return self.content == other.content
|
181
|
+
|
182
|
+
class Tools(ContextProvider):
    """Provider that renders a list of tool definitions wrapped in a <tools> tag."""

    def __init__(self, tools_json: Optional[List[Dict]] = None, name: str = "tools"):
        super().__init__(name)
        # Any falsy input (None, empty list) is normalized to an empty list.
        self._tools_json = tools_json if tools_json else []

    def update(self, tools_json: List[Dict]):
        """Replace the tool definitions and mark the provider stale."""
        self._tools_json = tools_json
        self.mark_stale()

    async def render(self) -> Optional[str]:
        """Return ``<tools>...</tools>`` using str() of the list, or None when there are no tools."""
        definitions = self._tools_json
        if definitions:
            return "<tools>" + str(definitions) + "</tools>"
        return None

    def __eq__(self, other):
        """Two Tools providers are equal when their tool definitions are equal."""
        if isinstance(other, Tools):
            return self._tools_json == other._tools_json
        return NotImplemented
|
198
|
+
|
199
|
+
class Files(ContextProvider):
    """Provider that renders the contents of a set of UTF-8 text files.

    Contents are kept in memory keyed by path. ``refresh`` re-reads every
    tracked path from disk and only re-renders when something changed; read
    failures are stored as ``[Error: ...]`` text in place of the content.
    """

    def __init__(self, *paths: Union[str, List[str]], name: str = "files"):
        """Track and read the given files.

        Accepts either individual paths, ``Files('a', 'b')``, or a single list
        or tuple of paths, ``Files(['a', 'b'])``. Files that cannot be read at
        construction time are logged and not tracked.
        """
        super().__init__(name)
        self._files: Dict[str, str] = {}

        if len(paths) == 1 and isinstance(paths[0], (list, tuple)):
            file_paths = list(paths[0])
        else:
            file_paths = list(paths)

        for path in file_paths:
            try:
                self._files[path] = self._read(path)
            except FileNotFoundError:
                logging.warning(f"File not found during initialization: {path}. Skipping.")
            except Exception as e:
                logging.error(f"Error reading file {path} during initialization: {e}")

    @staticmethod
    def _read(path: str) -> str:
        """Read and return the whole file at *path* as UTF-8 text."""
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()

    async def refresh(self):
        """
        Overrides the default refresh behavior. It synchronizes the content of
        all tracked files with the file system. If a file cannot be read, its
        content is replaced by an error message; the provider is marked stale
        only when some stored content actually changed.
        """
        is_changed = False
        for path in list(self._files):
            try:
                new_content = self._read(path)
            except FileNotFoundError:
                new_content = f"[Error: File not found at path '{path}']"
            except Exception as e:
                new_content = f"[Error: Could not read file at path '{path}': {e}]"
            if self._files.get(path) != new_content:
                self._files[path] = new_content
                is_changed = True

        if is_changed:
            self.mark_stale()

        await super().refresh()

    def update(self, path: str, content: Optional[str] = None):
        """
        Updates a single file (adding it if not yet tracked). If content is
        provided, it is stored as-is; if content is None, the file is read from
        disk, storing an error message when reading fails.
        """
        if content is not None:
            self._files[path] = content
        else:
            try:
                self._files[path] = self._read(path)
            except FileNotFoundError:
                logging.error(f"File not found for update: {path}.")
                self._files[path] = f"[Error: File not found at path '{path}']"
            except Exception as e:
                logging.error(f"Error reading file for update {path}: {e}.")
                self._files[path] = f"[Error: Could not read file at path '{path}': {e}]"
        self.mark_stale()

    async def render(self) -> Optional[str]:
        """Wrap all tracked files in ``<latest_file_content>``; None when no files are tracked."""
        if not self._files:
            return None
        file_blocks = "\n".join(
            f"<file><file_path>{p}</file_path><file_content>{c}</file_content></file>"
            for p, c in self._files.items()
        )
        return "<latest_file_content>" + file_blocks + "\n</latest_file_content>"

    def __eq__(self, other):
        """Two Files providers are equal when they hold the same path -> content mapping."""
        if not isinstance(other, Files):
            return NotImplemented
        return self._files == other._files
|
279
|
+
|
280
|
+
class Images(ContextProvider):
    """Provider for one image, rendered as a URL suitable for an ``image_url`` content part.

    ``data:`` URIs and remote ``http(s)://`` URLs are passed through unchanged;
    any other value is treated as a local file path and inlined as a base64
    data URI.
    """

    def __init__(self, url: str, name: Optional[str] = None):
        super().__init__(name or url)
        self.url = url

    def update(self, url: str):
        """Point the provider at a new image and mark it stale."""
        self.url = url
        self.mark_stale()

    async def render(self) -> Optional[str]:
        """Return the image URL/data URI, or None if the local file cannot be read."""
        # Remote URLs must not be opened as local files: that would always fail
        # and silently drop the image from the message.
        if self.url.startswith(("data:", "http://", "https://")):
            return self.url
        try:
            with open(self.url, "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
        except FileNotFoundError:
            logging.warning(f"Image file not found: {self.url}. Skipping.")
            return None
        except OSError as e:
            # e.g. a directory or an unreadable file; skip instead of aborting a refresh.
            logging.error(f"Error reading image file {self.url}: {e}. Skipping.")
            return None
        mime_type, _ = mimetypes.guess_type(self.url)
        if not mime_type:
            mime_type = "application/octet-stream"  # fallback for unknown extensions
        return f"data:{mime_type};base64,{encoded_string}"

    def __eq__(self, other):
        """Two Images providers are equal when they reference the same URL."""
        if not isinstance(other, Images):
            return NotImplemented
        return self.url == other.url
|
304
|
+
|
305
|
+
# 3. Message classes (MessageContent has been merged into Message)
class Message(ABC):
    """A chat message: a role plus an ordered list of context providers.

    Constructor items may be plain strings (wrapped in Texts, with f-string
    provider placeholders resolved back to their providers), ContextProvider
    instances, other Message objects (their providers are reused), or lists of
    ``{"type": "text" | "image_url", ...}`` dicts.
    """

    def __init__(self, role: str, *initial_items: Union[ContextProvider, str, list]):
        import re  # local import: the module header does not import re

        self.role = role
        # Placeholders come from ContextProvider.__str__ (f-string interpolation).
        # The capturing group makes re.split keep the placeholders in its output.
        placeholder_pattern = re.compile(r'(__provider_placeholder_[a-f0-9]{32}__)')
        processed_items = []
        for item in initial_items:
            if isinstance(item, str):
                parts = placeholder_pattern.split(item)
                if len(parts) > 1:  # placeholders were found
                    for part in parts:
                        if not part:
                            continue
                        if placeholder_pattern.match(part):
                            provider = _retrieve_provider(part)
                            # An already-consumed or unknown placeholder yields None and is dropped.
                            if provider:
                                processed_items.append(provider)
                        else:
                            processed_items.append(Texts(text=part))
                else:  # no placeholders, just a regular string
                    processed_items.append(Texts(text=item))
            elif isinstance(item, Message):
                processed_items.extend(item.provider())
            elif isinstance(item, ContextProvider):
                processed_items.append(item)
            elif isinstance(item, list):
                for sub_item in item:
                    if not isinstance(sub_item, dict) or 'type' not in sub_item:
                        raise ValueError("List items must be dicts with a 'type' key.")
                    item_type = sub_item['type']
                    if item_type == 'text':
                        processed_items.append(Texts(text=sub_item.get('text', '')))
                    elif item_type == 'image_url':
                        image_url = sub_item.get('image_url', {}).get('url')
                        if image_url:
                            processed_items.append(Images(url=image_url))
                    else:
                        raise ValueError(f"Unsupported item type in list: {item_type}")
            else:
                raise TypeError(f"Unsupported item type: {type(item)}. Must be str, ContextProvider, or list.")
        self._items: List[ContextProvider] = processed_items
        # Set by the owning Messages container so provider additions/removals are indexed there.
        self._parent_messages: Optional['Messages'] = None

    @property
    def content(self) -> Optional[Union[str, List[Dict[str, Any]]]]:
        """
        Renders the message content from the providers' cached content.
        For simple text messages, returns a string.
        For multimodal messages, returns a list of content blocks.
        Returns None when nothing renders.
        """
        rendered_dict = self.to_dict()
        return rendered_dict.get('content') if rendered_dict else None

    def _render_content(self) -> str:
        """Concatenate the cached content of all visible, rendered providers."""
        final_parts = []
        for item in self._items:
            block = item.get_content_block()
            if block and block.content is not None:
                final_parts.append(block.content)
        return "".join(final_parts)

    def pop(self, name: str) -> Optional[ContextProvider]:
        """Remove and return the first provider named *name* (None if absent)."""
        popped_item = None
        for i, item in enumerate(self._items):
            if hasattr(item, 'name') and item.name == name:
                popped_item = self._items.pop(i)
                break
        if popped_item and self._parent_messages:
            self._parent_messages._notify_provider_removed(popped_item)
        return popped_item

    def insert(self, index: int, item: ContextProvider):
        """Insert a provider at *index* and register it with the parent container."""
        self._items.insert(index, item)
        if self._parent_messages:
            self._parent_messages._notify_provider_added(item, self)

    def append(self, item: ContextProvider):
        """Append a provider and register it with the parent container."""
        self._items.append(item)
        if self._parent_messages:
            self._parent_messages._notify_provider_added(item, self)

    def provider(self, name: Optional[str] = None) -> Optional[Union[ContextProvider, ProviderGroup, List[ContextProvider]]]:
        """Look up providers.

        With no name, returns the internal provider list itself (not a copy).
        With a name, returns None, the single matching provider, or a
        ProviderGroup when several providers share the name.
        """
        if name is None:
            return self._items

        named_providers = [p for p in self._items if hasattr(p, 'name') and p.name == name]

        if not named_providers:
            return None
        if len(named_providers) == 1:
            return named_providers[0]
        return ProviderGroup(named_providers)

    def _with_items(self, items: List[ContextProvider]) -> 'Message':
        """Build a new message of the same kind as this one holding *items*."""
        if type(self) is Message:
            # The base class takes the role as its first positional argument,
            # unlike the role-specific subclasses; passing only the items
            # would bind the first provider to `role`.
            return Message(self.role, *items)
        return type(self)(*items)

    def __add__(self, other):
        """``message + str`` / ``message + message`` -> new message with the combined providers."""
        if isinstance(other, str):
            return self._with_items(self._items + [Texts(text=other)])
        if isinstance(other, Message):
            return self._with_items(self._items + other.provider())
        return NotImplemented

    def __radd__(self, other):
        """``str + message`` -> new message with the text provider prepended."""
        if isinstance(other, str):
            return self._with_items([Texts(text=other)] + self._items)
        if isinstance(other, Message):
            return self._with_items(other.provider() + self._items)
        return NotImplemented

    def __getitem__(self, key: Union[str, int]) -> Any:
        """
        Support dict-style access (e.g. message['content']) and
        list-style index access to providers (e.g. message[-1]).
        """
        if isinstance(key, str):
            if key == 'role':
                return self.role
            if key == 'content':
                # Go through to_dict() so the result matches the serialized form.
                rendered_dict = self.to_dict()
                return rendered_dict.get('content') if rendered_dict else None
            if hasattr(self, key):
                # Special attributes such as tool_calls are also read via to_dict();
                # fall back to the plain attribute when to_dict() lacks the key.
                rendered_dict = self.to_dict()
                if rendered_dict and key in rendered_dict:
                    return rendered_dict[key]
                return getattr(self, key)
            # Found neither on the object nor in to_dict().
            raise KeyError(f"'{key}'")
        elif isinstance(key, int):
            return self._items[key]
        else:
            raise TypeError(f"Message indices must be integers or strings, not {type(key).__name__}")

    def __len__(self) -> int:
        """Return the number of providers in the message."""
        return len(self._items)

    def __repr__(self):
        return f"Message(role='{self.role}', items={[i.name for i in self._items]})"

    def __bool__(self) -> bool:
        """A message is truthy when it has at least one provider."""
        return bool(self._items)

    def get(self, key: str, default: Any = None) -> Any:
        """Dict-like .get(): return attribute *key*, or *default* if it does not exist."""
        return getattr(self, key, default)

    def to_dict(self) -> Optional[Dict[str, Any]]:
        """Serialize to a ``{"role", "content"}`` dict, or None when nothing renders.

        Content is a string unless an Images provider is present, in which
        case it is a list of ``text`` / ``image_url`` parts. Only cached
        provider content is used; refresh providers first for fresh output.
        """
        is_multimodal = any(isinstance(p, Images) for p in self._items)

        if not is_multimodal:
            rendered_content = self._render_content()
            if not rendered_content:
                return None
            return {"role": self.role, "content": rendered_content}

        content_list = []
        for item in self._items:
            block = item.get_content_block()
            if not block or not block.content:
                continue
            if isinstance(item, Images):
                content_list.append({"type": "image_url", "image_url": {"url": block.content}})
            else:
                content_list.append({"type": "text", "text": block.content})
        if not content_list:
            return None
        return {"role": self.role, "content": content_list}
|
476
|
+
|
477
|
+
class SystemMessage(Message):
    """Message whose role is fixed to 'system'."""

    def __init__(self, *items):
        super().__init__("system", *items)
|
479
|
+
class UserMessage(Message):
    """Message whose role is fixed to 'user'."""

    def __init__(self, *items):
        super().__init__("user", *items)
|
481
|
+
class AssistantMessage(Message):
    """Message whose role is fixed to 'assistant'."""

    def __init__(self, *items):
        super().__init__("assistant", *items)
|
483
|
+
|
484
|
+
class RoleMessage:
    """Factory: ``RoleMessage(role, *items)`` returns a System/User/AssistantMessage.

    Raises ValueError for any other role.
    """

    def __new__(cls, role: str, *items):
        candidates = (
            ('system', SystemMessage),
            ('user', UserMessage),
            ('assistant', AssistantMessage),
        )
        for role_name, message_cls in candidates:
            if role == role_name:
                return message_cls(*items)
        raise ValueError(f"Invalid role: {role}. Must be 'system', 'user', or 'assistant'.")
|
495
|
+
|
496
|
+
class ToolCalls(Message):
    """Represents an assistant message that requests tool calls (no text content)."""

    def __init__(self, tool_calls: List[Any]):
        super().__init__("assistant")
        self.tool_calls = tool_calls

    def __bool__(self) -> bool:
        # Message judges truthiness by its providers, but a ToolCalls message
        # never has any; it is non-empty whenever it carries tool calls.
        return bool(self.tool_calls)

    @staticmethod
    def _serialize_call(tc: Any) -> Dict[str, Any]:
        """Convert one tool call (attribute-style object or plain dict) to a dict.

        Raises:
            TypeError: if *tc* is neither attribute-style nor a dict.
        """
        try:
            # Duck-typed on the attributes read here: id, type, function.name, function.arguments.
            func = tc.function
            return {
                "id": tc.id,
                "type": tc.type,
                "function": {"name": func.name, "arguments": func.arguments},
            }
        except AttributeError:
            if isinstance(tc, dict):
                return tc  # assumed to already be in serializable form
            raise TypeError(f"Unsupported tool_call type: {type(tc)}. It should be an OpenAI tool_call object or a dict.")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"role", "tool_calls", "content": None}``."""
        serialized_calls = [self._serialize_call(tc) for tc in self.tool_calls]
        return {
            "role": self.role,
            "tool_calls": serialized_calls,
            "content": None,
        }
|
525
|
+
|
526
|
+
class ToolResults(Message):
    """Tool-role message carrying the result of a single tool call."""

    def __init__(self, tool_call_id: str, content: str):
        # A Texts provider is handed to the base class so the generic Message
        # machinery can see the content; to_dict() serializes the raw value kept below.
        super().__init__("tool", Texts(text=content))
        self.tool_call_id = tool_call_id
        self._content = content

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as ``{"role", "tool_call_id", "content"}``."""
        payload: Dict[str, Any] = {"role": self.role}
        payload["tool_call_id"] = self.tool_call_id
        payload["content"] = self._content
        return payload
|
541
|
+
|
542
|
+
# 4. Top-level container: Messages
class Messages:
    """Ordered list of Message objects plus a name index of all their providers.

    ``_providers_index`` maps provider name -> list of ``(provider, owning
    message)`` pairs and is kept in sync as messages and providers are added or
    removed, so providers can be looked up, popped and refreshed by name.
    """

    def __init__(self, *initial_messages: Message):
        from typing import Tuple

        self._messages: List[Message] = []
        self._providers_index: Dict[str, List[Tuple[ContextProvider, Message]]] = {}
        for msg in initial_messages:
            self.append(msg)

    def _notify_provider_added(self, provider: ContextProvider, message: Message):
        """Record that *provider* now belongs to *message*."""
        self._providers_index.setdefault(provider.name, []).append((provider, message))

    def _notify_provider_removed(self, provider: ContextProvider):
        """Drop the index entries for this exact provider object."""
        entries = self._providers_index.get(provider.name)
        if entries is None:
            return
        # Compare by identity: providers define __eq__ by content, and distinct
        # providers with equal content must not be removed together.
        remaining = [(p, m) for p, m in entries if p is not provider]
        if remaining:
            self._providers_index[provider.name] = remaining
        else:
            del self._providers_index[provider.name]

    def provider(self, name: str) -> Optional[Union[ContextProvider, ProviderGroup]]:
        """Return the provider named *name*, a ProviderGroup if several match, or None."""
        indexed_list = self._providers_index.get(name)
        if not indexed_list:
            return None
        providers = [p for p, _ in indexed_list]
        if len(providers) == 1:
            return providers[0]
        return ProviderGroup(providers)

    def pop(self, key: Optional[Union[str, int]] = None) -> Union[Optional[ContextProvider], Optional[Message]]:
        """Remove a provider by name, or a message by index (default: the last message).

        A str key removes the first indexed provider with that name from its
        message and returns it. An int key (negative allowed) removes and
        returns that message, detaching it from this container. Returns None
        when nothing matches.
        """
        if key is None:
            key = len(self._messages) - 1

        if isinstance(key, str):
            indexed_list = self._providers_index.get(key)
            if not indexed_list:
                return None
            _provider, parent_message = indexed_list[0]
            # Message.pop() calls back into _notify_provider_removed.
            return parent_message.pop(key)

        if isinstance(key, int):
            if key < 0:
                key += len(self._messages)
            if not 0 <= key < len(self._messages):
                return None
            popped_message = self._messages.pop(key)
            popped_message._parent_messages = None
            for provider in popped_message.provider():
                self._notify_provider_removed(provider)
            return popped_message

        return None

    async def refresh(self):
        """Refresh every indexed provider concurrently."""
        tasks = [
            provider.refresh()
            for provider_list in self._providers_index.values()
            for provider, _ in provider_list
        ]
        await asyncio.gather(*tasks)

    def render(self) -> List[Dict[str, Any]]:
        """Serialize all messages from cached content, skipping ones that render to nothing."""
        results = [msg.to_dict() for msg in self._messages]
        return [res for res in results if res]

    async def render_latest(self) -> List[Dict[str, Any]]:
        """Refresh all providers, then render."""
        await self.refresh()
        return self.render()

    @staticmethod
    def _can_merge(previous: Message, incoming: Message) -> bool:
        """Whether *incoming* may be folded into *previous* (plain provider-based messages only)."""
        # ToolCalls/ToolResults keep their payload (tool_calls, tool_call_id,
        # content) outside the provider list, so merging them would lose data.
        special = (ToolCalls, ToolResults)
        return not isinstance(previous, special) and not isinstance(incoming, special)

    def append(self, message: Message):
        """Append *message*, merging its providers into the last message when the roles match.

        Tool-call requests and tool results are never merged, so every tool
        call and every tool result stays a separate message.
        """
        if self._messages:
            last_message = self._messages[-1]
            if last_message.role == message.role and self._can_merge(last_message, message):
                # Snapshot the list: `message` may be `last_message` itself, in
                # which case iterating the live list while appending never ends.
                for provider in list(message.provider()):
                    last_message.append(provider)
                return
        message._parent_messages = self
        self._messages.append(message)
        for p in message.provider():
            self._notify_provider_added(p, message)

    def save(self, file_path: str):
        """
        Saves the entire Messages object to a file using pickle.
        Dynamic Texts are stored as snapshots of their current value.
        Warning: Deserializing data with pickle from an untrusted source is insecure.
        """
        with open(file_path, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, file_path: str) -> Optional['Messages']:
        """
        Loads a Messages object from a file using pickle.
        Returns the unpickled object, or an empty Messages if the file does not
        exist or cannot be unpickled (UnpicklingError/EOFError).
        Warning: pickle can execute arbitrary code; only load files from a trusted source.
        """
        try:
            with open(file_path, 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            return cls()
        except (pickle.UnpicklingError, EOFError) as e:
            logging.error(f"Could not deserialize file {file_path}: {e}")
            return cls()

    def __getitem__(self, index: Union[int, slice]) -> Union[Message, 'Messages']:
        """Return one message, or a new Messages built from a slice (via append)."""
        if isinstance(index, slice):
            return Messages(*self._messages[index])
        return self._messages[index]

    def __setitem__(self, index: slice, value: 'Messages'):
        """Replace a contiguous slice of messages with the messages of *value*.

        Only step-1 slices are supported. Removed messages are detached from
        this container (as pop() does) and their providers un-indexed; the new
        messages are attached and indexed.
        """
        if not isinstance(index, slice) or not isinstance(value, Messages):
            raise TypeError("Unsupported operand type(s) for slice assignment")

        start, stop, step = index.indices(len(self._messages))
        if step != 1:
            raise ValueError("Slice assignment with step is not supported.")

        for old_message in self._messages[start:stop]:
            # Detach first; a message also present in `value` is re-attached below.
            old_message._parent_messages = None
            for provider in old_message.provider():
                self._notify_provider_removed(provider)

        self._messages[start:stop] = value._messages

        for msg in value:
            msg._parent_messages = self
            for provider in msg.provider():
                self._notify_provider_added(provider, msg)

    def __len__(self) -> int:
        return len(self._messages)

    def __iter__(self):
        return iter(self._messages)
|