aient 1.1.90__py3-none-any.whl → 1.1.92__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,165 +1,436 @@
1
+ import pickle
2
+ import base64
1
3
  import asyncio
4
+ import logging
5
+ import hashlib
6
+ import mimetypes
7
+ from dataclasses import dataclass
2
8
  from abc import ABC, abstractmethod
3
9
  from typing import List, Dict, Any, Optional, Union
4
10
 
# 1. Core data structure: ContentBlock
@dataclass
class ContentBlock:
    """A provider's rendered output, tagged with the provider's name.

    Created by ``ContextProvider.get_content_block`` from the provider's cached
    content and read when a message serializes itself.
    """

    # Name of the provider that produced this block.
    name: str
    # The rendered text (for image providers, the URL / data URL string).
    content: str
10
16
 
# 2. Context providers (with a render cache)
class ContextProvider(ABC):
    """Base class for a named source of prompt content with a render cache.

    Subclasses implement ``render`` (produce the content) and ``update``
    (change the underlying source). ``refresh`` re-renders only while the
    provider is flagged stale; ``get_content_block`` exposes the cached result.
    """

    def __init__(self, name: str):
        self.name = name
        # Last value produced by ``render``; stays None until the first refresh.
        self._cached_content: Optional[str] = None
        # Begin stale so that the very first ``refresh`` always renders.
        self._is_stale: bool = True

    def mark_stale(self):
        """Flag the cached content as outdated so the next refresh re-renders."""
        self._is_stale = True

    async def refresh(self):
        """Re-render into the cache, but only if the provider is stale."""
        if not self._is_stale:
            return
        self._cached_content = await self.render()
        self._is_stale = False

    @abstractmethod
    async def render(self) -> Optional[str]:
        """Produce the provider's current content, or None for no content."""
        raise NotImplementedError

    @abstractmethod
    def update(self, *args, **kwargs):
        """Replace the provider's source data (implementations in this module then call mark_stale())."""
        raise NotImplementedError

    def get_content_block(self) -> Optional[ContentBlock]:
        """Wrap the cached content in a ContentBlock, or return None when nothing is cached."""
        cached = self._cached_content
        return None if cached is None else ContentBlock(self.name, cached)
30
33
 
class Texts(ContextProvider):
    """A provider that renders a fixed piece of text.

    When no name is supplied, one is derived from the text itself:
    ``text_`` plus the first 8 hex digits of the SHA-1 of its encoded bytes,
    so identical texts end up with identical names.
    """

    def __init__(self, text: str, name: Optional[str] = None):
        self._text = text
        provider_name = name
        if provider_name is None:
            digest = hashlib.sha1(text.encode()).hexdigest()
            provider_name = f"text_{digest[:8]}"
        super().__init__(provider_name)

    def update(self, text: str):
        """Replace the stored text and mark the cached render stale."""
        self._text = text
        self.mark_stale()

    async def render(self) -> str:
        """Return the stored text unchanged."""
        return self._text
34
49
 
class Tools(ContextProvider):
    """A provider (always named ``tools``) that renders a tool list inside ``<tools>`` tags."""

    def __init__(self, tools_json: List[Dict]):
        super().__init__("tools")
        self._tools_json = tools_json

    def update(self, tools_json: List[Dict]):
        """Swap in a new tool list and mark the cached render stale."""
        self._tools_json = tools_json
        self.mark_stale()

    async def render(self) -> str:
        # NOTE(review): str() yields Python's repr of the list (single quotes,
        # True/None), not JSON — confirm this is what downstream prompts expect.
        return "<tools>" + str(self._tools_json) + "</tools>"
38
56
 
class Files(ContextProvider):
    """A provider (named ``files``) that renders the latest contents of text files.

    Files are read as UTF-8. A path that cannot be read is skipped at
    construction time; afterwards (``refresh`` / ``update``) an unreadable file
    is kept with a bracketed error message as its content.
    """

    def __init__(self, *paths: Union[str, List[str]]):
        """Track *paths*, reading each one immediately.

        Accepts individual paths (``Files('a', 'b')``), a single list
        (``Files(['a', 'b'])``), or any mix; list/tuple arguments are flattened.
        """
        super().__init__("files")
        self._files: Dict[str, str] = {}

        file_paths: List[str] = []
        for entry in paths:
            # Flatten every list/tuple argument, not only a lone first one.
            if isinstance(entry, (list, tuple)):
                file_paths.extend(entry)
            else:
                file_paths.append(entry)

        for path in file_paths:
            try:
                self._files[path] = self._read_text(path)
            except FileNotFoundError:
                logging.warning(f"File not found during initialization: {path}. Skipping.")
            except Exception as e:
                # Best effort: one unreadable file must not abort construction.
                logging.error(f"Error reading file {path} during initialization: {e}")

    @staticmethod
    def _read_text(path: str) -> str:
        """Read *path* as UTF-8 text; any error propagates to the caller."""
        with open(path, 'r', encoding='utf-8') as f:
            return f.read()

    @staticmethod
    def _error_text(path: str, exc: Exception) -> str:
        """Placeholder content recorded for a file that could not be read."""
        if isinstance(exc, FileNotFoundError):
            return f"[Error: File not found at path '{path}']"
        return f"[Error: Could not read file at path '{path}': {exc}]"

    async def refresh(self):
        """Re-read every tracked file, then re-render if anything changed.

        Overrides the base behaviour: the file system is polled on every call,
        and a file that has become unreadable gets an error placeholder as its
        content instead of being dropped.
        """
        is_changed = False
        for path in list(self._files):
            try:
                new_content = self._read_text(path)
            except Exception as e:
                new_content = self._error_text(path, e)
            if self._files.get(path) != new_content:
                self._files[path] = new_content
                is_changed = True

        if is_changed:
            self.mark_stale()

        await super().refresh()

    def update(self, path: str, content: Optional[str] = None):
        """Set the content of one file and mark the provider stale.

        If *content* is given it is stored as-is and the disk is not touched;
        otherwise the file is read from disk, and a read failure is logged and
        recorded as an error placeholder.
        """
        if content is None:
            try:
                content = self._read_text(path)
            except FileNotFoundError as e:
                logging.error(f"File not found for update: {path}.")
                content = self._error_text(path, e)
            except Exception as e:
                logging.error(f"Error reading file for update {path}: {e}.")
                content = self._error_text(path, e)
        self._files[path] = content
        self.mark_stale()

    async def render(self) -> Optional[str]:
        """Render all tracked files inside ``<latest_file_content>``, or None if none are tracked."""
        if not self._files:
            return None
        file_entries = "\n".join(
            f"<file><file_path>{p}</file_path><file_content>{c}</file_content></file>"
            for p, c in self._files.items()
        )
        return "<latest_file_content>" + file_entries + "\n</latest_file_content>"
132
+
class Images(ContextProvider):
    """A provider for one image, rendered as a URL usable in an ``image_url`` content part.

    ``data:`` URLs and remote ``http://`` / ``https://`` URLs are passed through
    unchanged; anything else is treated as a local file path and inlined as a
    base64 ``data:`` URL. Unreadable local files render as None (skipped).
    """

    def __init__(self, url: str, name: Optional[str] = None):
        super().__init__(name or url)
        self.url = url

    def update(self, url: str):
        """Point the provider at a different image and mark the cached render stale."""
        self.url = url
        self.mark_stale()

    async def render(self) -> Optional[str]:
        # Already-encoded or remote images need no local processing; opening a
        # remote URL as a file path would fail and silently drop the image.
        if self.url.startswith(("data:", "http://", "https://")):
            return self.url
        try:
            with open(self.url, "rb") as image_file:
                encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
        except FileNotFoundError:
            logging.warning(f"Image file not found: {self.url}. Skipping.")
            return None
        except OSError as e:
            # e.g. permission denied or a directory path: skip instead of
            # failing the whole refresh.
            logging.warning(f"Could not read image file {self.url}: {e}. Skipping.")
            return None
        mime_type, _ = mimetypes.guess_type(self.url)
        if not mime_type:
            mime_type = "application/octet-stream"  # fallback for unknown extensions
        return f"data:{mime_type};base64,{encoded_string}"
152
+
# 3. Message classes (MessageContent has been merged into Message)
class Message(ABC):
    """A chat message: a role plus an ordered list of context providers.

    The message serializes from its providers' *cached* content (see
    ``ContextProvider.refresh``); it never refreshes them itself. While the
    message belongs to a ``Messages`` container, adding or removing providers
    keeps that container's provider index in sync.
    """

    def __init__(self, role: str, *initial_items: Union[ContextProvider, str, list]):
        """Create a message with *role* from a mix of items.

        Each item may be a ``str`` (wrapped in ``Texts``), another ``Message``
        (its provider objects are reused), a ``ContextProvider``, or an
        OpenAI-style content list of ``{'type': 'text' | 'image_url', ...}`` dicts.

        Raises:
            ValueError: a content-list entry is malformed or has an unsupported type.
            TypeError: an item is of an unsupported type.
        """
        self.role = role
        processed_items: List[ContextProvider] = []
        for item in initial_items:
            if isinstance(item, str):
                processed_items.append(Texts(text=item))
            elif isinstance(item, Message):
                processed_items.extend(item.providers())
            elif isinstance(item, ContextProvider):
                processed_items.append(item)
            elif isinstance(item, list):
                processed_items.extend(self._providers_from_content_list(item))
            else:
                raise TypeError(f"Unsupported item type: {type(item)}. Must be str, Message, ContextProvider, or list.")
        self._items: List[ContextProvider] = processed_items
        # Assigned by Messages.append so provider changes can be reported back.
        self._parent_messages: Optional['Messages'] = None

    @staticmethod
    def _providers_from_content_list(parts: list) -> List[ContextProvider]:
        """Convert an OpenAI-style content list to providers; image parts without a URL are skipped."""
        providers: List[ContextProvider] = []
        for sub_item in parts:
            if not isinstance(sub_item, dict) or 'type' not in sub_item:
                raise ValueError("List items must be dicts with a 'type' key.")
            item_type = sub_item['type']
            if item_type == 'text':
                providers.append(Texts(text=sub_item.get('text', '')))
            elif item_type == 'image_url':
                image_url = sub_item.get('image_url', {}).get('url')
                if image_url:
                    providers.append(Images(url=image_url))
            else:
                raise ValueError(f"Unsupported item type in list: {item_type}")
        return providers

    def _render_content(self) -> str:
        """Join the providers' cached content with blank lines, skipping empty blocks."""
        blocks = [item.get_content_block() for item in self._items]
        return "\n\n".join(b.content for b in blocks if b and b.content)

    def pop(self, name: str) -> Optional[ContextProvider]:
        """Remove and return the first provider called *name*, or None if there is none."""
        popped_item = None
        for i, item in enumerate(self._items):
            if hasattr(item, 'name') and item.name == name:
                popped_item = self._items.pop(i)
                break
        if popped_item is not None and self._parent_messages is not None:
            self._parent_messages._notify_provider_removed(popped_item)
        return popped_item

    def insert(self, index: int, item: ContextProvider):
        """Insert a provider at *index* and register it with the parent container, if any."""
        self._items.insert(index, item)
        if self._parent_messages is not None:
            self._parent_messages._notify_provider_added(item, self)

    def append(self, item: ContextProvider):
        """Append a provider and register it with the parent container, if any."""
        self._items.append(item)
        if self._parent_messages is not None:
            self._parent_messages._notify_provider_added(item, self)

    def providers(self) -> List[ContextProvider]:
        """Return the providers (the live internal list, not a copy)."""
        return self._items

    def _with_items(self, items: List[ContextProvider]) -> 'Message':
        """Build a new message of the same role (and class, where possible) holding *items*."""
        if type(self) is Message:
            # The base class takes the role as its first positional argument;
            # type(self)(*items) would turn the first provider into the role.
            return Message(self.role, *items)
        return type(self)(*items)

    def __add__(self, other):
        """Return a new message: this one's providers followed by *other* (str or Message)."""
        if isinstance(other, str):
            return self._with_items(self._items + [Texts(text=other)])
        if isinstance(other, Message):
            return self._with_items(self._items + other.providers())
        return NotImplemented

    def __radd__(self, other):
        """Return a new message: *other* (str or Message) followed by this one's providers."""
        if isinstance(other, str):
            return self._with_items([Texts(text=other)] + self._items)
        if isinstance(other, Message):
            return self._with_items(other.providers() + self._items)
        return NotImplemented

    def __getitem__(self, key: str) -> Any:
        """Dict-style access (e.g. ``message['content']``).

        ``'role'`` returns the role; ``'content'`` and other existing attribute
        names are looked up in ``to_dict()`` first, then fall back to the
        attribute itself.

        Raises:
            KeyError: *key* is neither a serialized field nor an attribute.
        """
        if key == 'role':
            return self.role
        if key == 'content':
            rendered_dict = self.to_dict()
            return rendered_dict.get('content') if rendered_dict else None
        if hasattr(self, key):
            # Special fields such as tool_calls come from the serialized form.
            rendered_dict = self.to_dict()
            if rendered_dict and key in rendered_dict:
                return rendered_dict[key]
            return getattr(self, key)
        raise KeyError(f"'{key}'")

    def __repr__(self): return f"Message(role='{self.role}', items={[i.name for i in self._items]})"

    def __bool__(self) -> bool:
        return bool(self._items)

    def get(self, key: str, default: Any = None) -> Any:
        """Dict-style ``get``: same lookup as ``message[key]``, returning *default* on a miss.

        Previously this used ``getattr``, so ``get('content')`` returned
        *default* for ordinary messages even though ``message['content']``
        had a value.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def to_dict(self) -> Optional[Dict[str, Any]]:
        """Serialize to an OpenAI-style message dict from the providers' cached content.

        Text-only messages get a single string ``content`` (blocks joined by
        blank lines). If any provider is an ``Images``, ``content`` is a list
        of ``text`` / ``image_url`` parts in provider order. Returns None when
        there is nothing to send.
        """
        is_multimodal = any(isinstance(p, Images) for p in self._items)

        if not is_multimodal:
            rendered_content = self._render_content()
            if not rendered_content:
                return None
            return {"role": self.role, "content": rendered_content}

        content_list = []
        for item in self._items:
            block = item.get_content_block()
            if not block or not block.content:
                continue
            if isinstance(item, Images):
                content_list.append({"type": "image_url", "image_url": {"url": block.content}})
            else:
                content_list.append({"type": "text", "text": block.content})
        if not content_list:
            return None
        return {"role": self.role, "content": content_list}
79
273
 
class SystemMessage(Message):
    """A message whose role is fixed to ``system``."""

    def __init__(self, *items):
        super().__init__("system", *items)
class UserMessage(Message):
    """A message whose role is fixed to ``user``."""

    def __init__(self, *items):
        super().__init__("user", *items)
class AssistantMessage(Message):
    """A message whose role is fixed to ``assistant``."""

    def __init__(self, *items):
        super().__init__("assistant", *items)
280
+
class RoleMessage:
    """Factory: ``RoleMessage(role, *items)`` returns the Message subclass matching *role*.

    Raises:
        ValueError: *role* is not 'system', 'user' or 'assistant'.
    """

    def __new__(cls, role: str, *items):
        candidates = (
            ('system', SystemMessage),
            ('user', UserMessage),
            ('assistant', AssistantMessage),
        )
        for role_name, message_cls in candidates:
            if role == role_name:
                return message_cls(*items)
        raise ValueError(f"Invalid role: {role}. Must be 'system', 'user', or 'assistant'.")
292
+
class ToolCalls(Message):
    """An assistant message that carries tool-call requests instead of text content."""

    def __init__(self, tool_calls: List[Any]):
        super().__init__("assistant")
        self.tool_calls = tool_calls

    def __bool__(self) -> bool:
        # The inherited __bool__ counts providers, and a ToolCalls never has any,
        # so it was always falsy; judge it by whether it carries tool calls.
        return bool(self.tool_calls)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as an OpenAI-style assistant message with ``tool_calls`` and ``content: None``.

        Each tool call may be an object exposing ``id``, ``type`` and
        ``function.name`` / ``function.arguments`` (duck-typed against
        openai-python >= 1.0 tool_call objects), or an already-serializable dict.

        Raises:
            TypeError: a tool call is neither such an object nor a dict.
        """
        serialized_calls = []
        for tc in self.tool_calls:
            try:
                func = tc.function
                serialized = {
                    "id": tc.id,
                    "type": tc.type,
                    "function": {"name": func.name, "arguments": func.arguments},
                }
            except AttributeError:
                if not isinstance(tc, dict):
                    raise TypeError(f"Unsupported tool_call type: {type(tc)}. It should be an OpenAI tool_call object or a dict.")
                serialized = tc  # assume it is already a serializable dict
            serialized_calls.append(serialized)

        return {
            "role": self.role,
            "tool_calls": serialized_calls,
            "content": None,
        }
84
322
 
85
- # 4. 上下文构建器 (内部使用)
86
- class ContextBuilder:
87
- def __init__(self, providers: List[ContextProvider]): self.providers = {p.name: p for p in providers}
88
- def get_provider(self, name: str) -> Optional[ContextProvider]: return self.providers.get(name)
class ToolResults(Message):
    """A ``tool`` message carrying the result of a single tool call."""

    def __init__(self, tool_call_id: str, content: str):
        super().__init__("tool")
        self.tool_call_id = tool_call_id  # id of the tool call this result answers
        self.content = content

    def __bool__(self) -> bool:
        # The inherited __bool__ counts providers, which a ToolResults never
        # holds, so it was always falsy; a tool result is always a real message.
        return True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize as an OpenAI-style ``tool`` message."""
        return {
            "role": self.role,
            "tool_call_id": self.tool_call_id,
            "content": self.content,
        }
336
+
# 4. Top-level container: Messages
class Messages:
    """An ordered conversation of Message objects with a by-name provider index.

    ``append`` merges consecutive plain messages of the same role.
    ``provider(name)`` / ``pop(name)`` resolve providers through the index,
    which maps each name to the first provider registered under it.
    """

    def __init__(self, *initial_messages: Message):
        from typing import Tuple
        self._messages: List[Message] = []
        # provider name -> (provider, message currently holding it)
        self._providers_index: Dict[str, Tuple[ContextProvider, Message]] = {}
        for msg in initial_messages:
            self.append(msg)

    def _notify_provider_added(self, provider: ContextProvider, message: Message):
        """Index *provider* under its name unless that name is already indexed."""
        if provider.name not in self._providers_index:
            self._providers_index[provider.name] = (provider, message)

    def _notify_provider_removed(self, provider: ContextProvider):
        """Drop *provider* from the index, re-pointing its name at a surviving provider if any.

        Must be called after *provider* has been taken out of its message (or
        after its message has been taken out of this container).
        """
        indexed = self._providers_index.get(provider.name)
        # Only un-index if the entry really refers to this object; a different
        # provider with the same name must stay reachable.
        if indexed is None or indexed[0] is not provider:
            return
        del self._providers_index[provider.name]
        for message in self._messages:
            for p in message.providers():
                if p.name == provider.name:
                    self._providers_index[p.name] = (p, message)
                    return

    def provider(self, name: str) -> Optional[ContextProvider]:
        """Return the indexed provider called *name*, or None."""
        indexed = self._providers_index.get(name)
        return indexed[0] if indexed else None

    def pop(self, key: Optional[Union[str, int]] = None) -> Union[Optional[ContextProvider], Optional[Message]]:
        """Remove a provider (by name) or a message (by index).

        - ``str``: remove the indexed provider with that name from its message.
        - ``int``: remove the message at that index (negative indices allowed).
        - ``None``: remove the last message.
        Returns the removed object, or None if nothing matched.
        """
        if key is None:
            key = len(self._messages) - 1

        if isinstance(key, str):
            indexed = self._providers_index.get(key)
            if not indexed:
                return None
            _provider, parent_message = indexed
            return parent_message.pop(key)

        if isinstance(key, int):
            if key < 0:  # support -1 etc.
                key += len(self._messages)
            if not 0 <= key < len(self._messages):
                return None
            popped_message = self._messages.pop(key)
            popped_message._parent_messages = None
            for provider in popped_message.providers():
                self._notify_provider_removed(provider)
            return popped_message

        return None

    async def refresh(self):
        """Concurrently refresh every provider of every message, each distinct object once.

        The name index keeps only the first provider per name, so refreshing
        through it would leave same-named providers (e.g. identical texts,
        which share a hash-derived name) with no cached content.
        """
        unique_providers: Dict[int, ContextProvider] = {}
        for message in self._messages:
            for provider in message.providers():
                unique_providers.setdefault(id(provider), provider)
        await asyncio.gather(*(p.refresh() for p in unique_providers.values()))

    def render(self) -> List[Dict[str, Any]]:
        """Serialize all messages from cached content, omitting those with nothing to send."""
        results = [msg.to_dict() for msg in self._messages]
        return [res for res in results if res]

    async def render_latest(self) -> List[Dict[str, Any]]:
        """Refresh all providers, then render."""
        await self.refresh()
        return self.render()

    @staticmethod
    def _can_merge(last: Message, new: Message) -> bool:
        """Whether *new* may be folded into *last*: same role and both provider-backed messages.

        ToolCalls / ToolResults serialize from their own attributes, not from
        providers, so merging them would silently lose data.
        """
        if last.role != new.role:
            return False
        special = (ToolCalls, ToolResults)
        return not isinstance(last, special) and not isinstance(new, special)

    def append(self, message: Message):
        """Add *message*, merging it into the last message when ``_can_merge`` allows."""
        if self._messages and self._can_merge(self._messages[-1], message):
            last_message = self._messages[-1]
            # Snapshot first: appending a message to itself would otherwise
            # iterate a list that grows on every step and never terminate.
            for provider in list(message.providers()):
                last_message.append(provider)
        else:
            message._parent_messages = self
            self._messages.append(message)
            for p in message.providers():
                self._notify_provider_added(p, message)

    def save(self, file_path: str):
        """Pickle this whole Messages object to *file_path*.

        Warning: deserializing pickle data from an untrusted source is insecure.
        """
        with open(file_path, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, file_path: str) -> Optional['Messages']:
        """Load a Messages object previously written by ``save``.

        Returns an empty ``cls()`` (not None) when the file does not exist or
        cannot be unpickled (``UnpicklingError`` / ``EOFError``); other errors
        propagate.
        Warning: ``pickle.load`` can execute arbitrary code — only load files
        from a trusted source.
        """
        try:
            with open(file_path, 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            logging.warning(f"File not found at {file_path}, returning empty Messages.")
            return cls()
        except (pickle.UnpicklingError, EOFError) as e:
            logging.error(f"Could not deserialize file {file_path}: {e}")
            return cls()

    def __getitem__(self, index: int) -> Message: return self._messages[index]
    def __len__(self) -> int: return len(self._messages)
    def __iter__(self): return iter(self._messages)
114
-
115
-
116
- # ==============================================================================
117
- # 6. 演示
118
- # ==============================================================================
119
- async def run_demo():
120
- # --- 1. 初始化提供者 ---
121
- system_prompt_provider = Texts("system_prompt", "你是一个AI助手。")
122
- tools_provider = Tools(tools_json=[{"name": "read_file"}])
123
- files_provider = Files()
124
-
125
- # --- 2. 演示新功能:优雅地构建 Messages ---
126
- print("\n>>> 场景 A: 使用新的、优雅的构造函数直接初始化 Messages")
127
- messages = Messages(
128
- SystemMessage(system_prompt_provider, tools_provider),
129
- UserMessage(files_provider, Texts("user_input", "这是我的初始问题。"))
130
- )
131
-
132
- print("\n--- 渲染后的初始 Messages (首次渲染,全部刷新) ---")
133
- for msg_dict in await messages.render(): print(msg_dict)
134
- print("-" * 40)
135
-
136
- # --- 3. 演示穿透更新 ---
137
- print("\n>>> 场景 B: 穿透更新 File Provider,渲染时自动刷新")
138
-
139
- # 直接通过 messages 对象穿透访问并更新 files provider
140
- files_provider_instance = messages.provider("files")
141
- if isinstance(files_provider_instance, Files):
142
- files_provider_instance.update("file1.py", "这是新的文件内容!")
143
-
144
- print("\n--- 再次渲染 Messages (只有文件提供者会刷新) ---")
145
- for msg_dict in await messages.render(): print(msg_dict)
146
- print("-" * 40)
147
-
148
- # --- 4. 演示全局 Pop 和通过索引 Insert ---
149
- print("\n>>> 场景 C: 全局 Pop 工具提供者,并 Insert 到 UserMessage 中")
150
-
151
- # a. 全局弹出 'tools' Provider
152
- popped_tools_provider = messages.pop("tools")
153
-
154
- # b. 将弹出的 Provider 插入到第一个 UserMessage (索引为1) 的开头
155
- if popped_tools_provider:
156
- # 通过索引精确定位
157
- messages[1].content.insert(0, popped_tools_provider)
158
- print(f"\n已成功将 '{popped_tools_provider.name}' 提供者移动到用户消息。")
159
-
160
- print("\n--- Pop 和 Insert 后渲染的 Messages (验证移动效果) ---")
161
- for msg_dict in await messages.render(): print(msg_dict)
162
- print("-" * 40)
163
-
164
- if __name__ == "__main__":
165
- asyncio.run(run_demo())