aient 1.1.91__py3-none-any.whl → 1.1.92__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aient/architext/architext/core.py +290 -91
- aient/architext/test/openai_client.py +146 -0
- aient/architext/test/test.py +927 -0
- aient/architext/test/test_save_load.py +93 -0
- aient/models/chatgpt.py +31 -104
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/METADATA +1 -1
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/RECORD +10 -8
- aient/architext/test.py +0 -226
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/WHEEL +0 -0
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/licenses/LICENSE +0 -0
- {aient-1.1.91.dist-info → aient-1.1.92.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,12 @@
|
|
1
|
+
import pickle
|
1
2
|
import base64
|
2
3
|
import asyncio
|
3
4
|
import logging
|
5
|
+
import hashlib
|
4
6
|
import mimetypes
|
5
7
|
from dataclasses import dataclass
|
6
8
|
from abc import ABC, abstractmethod
|
7
|
-
from typing import List, Dict, Any, Optional
|
9
|
+
from typing import List, Dict, Any, Optional, Union
|
8
10
|
|
9
11
|
# 1. 核心数据结构: ContentBlock
|
10
12
|
@dataclass
|
@@ -19,49 +21,164 @@ class ContextProvider(ABC):
|
|
19
21
|
def mark_stale(self): self._is_stale = True
|
20
22
|
async def refresh(self):
|
21
23
|
if self._is_stale:
|
22
|
-
self._cached_content = await self.
|
24
|
+
self._cached_content = await self.render()
|
23
25
|
self._is_stale = False
|
24
26
|
@abstractmethod
|
25
|
-
async def
|
27
|
+
async def render(self) -> Optional[str]: raise NotImplementedError
|
28
|
+
@abstractmethod
|
29
|
+
def update(self, *args, **kwargs): raise NotImplementedError
|
26
30
|
def get_content_block(self) -> Optional[ContentBlock]:
|
27
31
|
if self._cached_content is not None: return ContentBlock(self.name, self._cached_content)
|
28
32
|
return None
|
29
33
|
|
30
34
|
class Texts(ContextProvider):
|
31
|
-
def __init__(self,
|
32
|
-
|
35
|
+
def __init__(self, text: str, name: Optional[str] = None):
|
36
|
+
self._text = text
|
37
|
+
if name is None:
|
38
|
+
h = hashlib.sha1(self._text.encode()).hexdigest()
|
39
|
+
_name = f"text_{h[:8]}"
|
40
|
+
else:
|
41
|
+
_name = name
|
42
|
+
super().__init__(_name)
|
43
|
+
|
44
|
+
def update(self, text: str):
|
45
|
+
self._text = text
|
46
|
+
self.mark_stale()
|
47
|
+
|
48
|
+
async def render(self) -> str: return self._text
|
33
49
|
|
34
50
|
class Tools(ContextProvider):
|
35
51
|
def __init__(self, tools_json: List[Dict]): super().__init__("tools"); self._tools_json = tools_json
|
36
|
-
|
52
|
+
def update(self, tools_json: List[Dict]):
|
53
|
+
self._tools_json = tools_json
|
54
|
+
self.mark_stale()
|
55
|
+
async def render(self) -> str: return f"<tools>{str(self._tools_json)}</tools>"
|
37
56
|
|
38
57
|
class Files(ContextProvider):
|
39
|
-
def __init__(self
|
40
|
-
|
41
|
-
|
58
|
+
def __init__(self, *paths: Union[str, List[str]]):
|
59
|
+
super().__init__("files")
|
60
|
+
self._files: Dict[str, str] = {}
|
61
|
+
|
62
|
+
file_paths: List[str] = []
|
63
|
+
if paths:
|
64
|
+
# Handle the case where the first argument is a list of paths, e.g., Files(['a', 'b'])
|
65
|
+
if len(paths) == 1 and isinstance(paths[0], list):
|
66
|
+
file_paths.extend(paths[0])
|
67
|
+
# Handle the case where arguments are individual string paths, e.g., Files('a', 'b')
|
68
|
+
else:
|
69
|
+
file_paths.extend(paths)
|
70
|
+
|
71
|
+
if file_paths:
|
72
|
+
for path in file_paths:
|
73
|
+
try:
|
74
|
+
with open(path, 'r', encoding='utf-8') as f:
|
75
|
+
self._files[path] = f.read()
|
76
|
+
except FileNotFoundError:
|
77
|
+
logging.warning(f"File not found during initialization: {path}. Skipping.")
|
78
|
+
except Exception as e:
|
79
|
+
logging.error(f"Error reading file {path} during initialization: {e}")
|
80
|
+
|
81
|
+
async def refresh(self):
|
82
|
+
"""
|
83
|
+
Overrides the default refresh behavior. It synchronizes the content of
|
84
|
+
all tracked files with the file system. If a file is not found, its
|
85
|
+
content is updated to reflect the error.
|
86
|
+
"""
|
87
|
+
is_changed = False
|
88
|
+
for path in list(self._files.keys()):
|
89
|
+
try:
|
90
|
+
with open(path, 'r', encoding='utf-8') as f:
|
91
|
+
new_content = f.read()
|
92
|
+
if self._files.get(path) != new_content:
|
93
|
+
self._files[path] = new_content
|
94
|
+
is_changed = True
|
95
|
+
except FileNotFoundError:
|
96
|
+
error_msg = f"[Error: File not found at path '{path}']"
|
97
|
+
if self._files.get(path) != error_msg:
|
98
|
+
self._files[path] = error_msg
|
99
|
+
is_changed = True
|
100
|
+
except Exception as e:
|
101
|
+
error_msg = f"[Error: Could not read file at path '{path}': {e}]"
|
102
|
+
if self._files.get(path) != error_msg:
|
103
|
+
self._files[path] = error_msg
|
104
|
+
is_changed = True
|
105
|
+
|
106
|
+
if is_changed:
|
107
|
+
self.mark_stale()
|
108
|
+
|
109
|
+
await super().refresh()
|
110
|
+
|
111
|
+
def update(self, path: str, content: Optional[str] = None):
|
112
|
+
"""
|
113
|
+
Updates a single file. If content is provided, it updates the file in
|
114
|
+
memory. If content is None, it reads the file from disk.
|
115
|
+
"""
|
116
|
+
if content is not None:
|
117
|
+
self._files[path] = content
|
118
|
+
else:
|
119
|
+
try:
|
120
|
+
with open(path, 'r', encoding='utf-8') as f:
|
121
|
+
self._files[path] = f.read()
|
122
|
+
except FileNotFoundError:
|
123
|
+
logging.error(f"File not found for update: {path}.")
|
124
|
+
self._files[path] = f"[Error: File not found at path '{path}']"
|
125
|
+
except Exception as e:
|
126
|
+
logging.error(f"Error reading file for update {path}: {e}.")
|
127
|
+
self._files[path] = f"[Error: Could not read file at path '{path}': {e}]"
|
128
|
+
self.mark_stale()
|
129
|
+
async def render(self) -> str:
|
42
130
|
if not self._files: return None
|
43
|
-
return "<
|
131
|
+
return "<latest_file_content>" + "\n".join([f"<file><file_path>{p}</file_path><file_content>{c}</file_content></file>" for p, c in self._files.items()]) + "\n</latest_file_content>"
|
44
132
|
|
45
133
|
class Images(ContextProvider):
|
46
|
-
def __init__(self,
|
47
|
-
super().__init__(name or
|
48
|
-
self.
|
49
|
-
|
134
|
+
def __init__(self, url: str, name: Optional[str] = None):
|
135
|
+
super().__init__(name or url)
|
136
|
+
self.url = url
|
137
|
+
def update(self, url: str):
|
138
|
+
self.url = url
|
139
|
+
self.mark_stale()
|
140
|
+
async def render(self) -> Optional[str]:
|
141
|
+
if self.url.startswith("data:"):
|
142
|
+
return self.url
|
50
143
|
try:
|
51
|
-
with open(self.
|
144
|
+
with open(self.url, "rb") as image_file:
|
52
145
|
encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
|
53
|
-
mime_type, _ = mimetypes.guess_type(self.
|
146
|
+
mime_type, _ = mimetypes.guess_type(self.url)
|
54
147
|
if not mime_type: mime_type = "application/octet-stream" # Fallback
|
55
148
|
return f"data:{mime_type};base64,{encoded_string}"
|
56
149
|
except FileNotFoundError:
|
57
|
-
logging.warning(f"Image file not found: {self.
|
150
|
+
logging.warning(f"Image file not found: {self.url}. Skipping.")
|
58
151
|
return None # Or handle error appropriately
|
59
152
|
|
60
153
|
# 3. 消息类 (已合并 MessageContent)
|
61
154
|
class Message(ABC):
|
62
|
-
def __init__(self, role: str, *initial_items: ContextProvider):
|
155
|
+
def __init__(self, role: str, *initial_items: Union[ContextProvider, str, list]):
|
63
156
|
self.role = role
|
64
|
-
|
157
|
+
processed_items = []
|
158
|
+
for item in initial_items:
|
159
|
+
if isinstance(item, str):
|
160
|
+
processed_items.append(Texts(text=item))
|
161
|
+
elif isinstance(item, Message):
|
162
|
+
processed_items.extend(item.providers())
|
163
|
+
elif isinstance(item, ContextProvider):
|
164
|
+
processed_items.append(item)
|
165
|
+
elif isinstance(item, list):
|
166
|
+
for sub_item in item:
|
167
|
+
if not isinstance(sub_item, dict) or 'type' not in sub_item:
|
168
|
+
raise ValueError("List items must be dicts with a 'type' key.")
|
169
|
+
|
170
|
+
item_type = sub_item['type']
|
171
|
+
if item_type == 'text':
|
172
|
+
processed_items.append(Texts(text=sub_item.get('text', '')))
|
173
|
+
elif item_type == 'image_url':
|
174
|
+
image_url = sub_item.get('image_url', {}).get('url')
|
175
|
+
if image_url:
|
176
|
+
processed_items.append(Images(url=image_url))
|
177
|
+
else:
|
178
|
+
raise ValueError(f"Unsupported item type in list: {item_type}")
|
179
|
+
else:
|
180
|
+
raise TypeError(f"Unsupported item type: {type(item)}. Must be str, ContextProvider, or list.")
|
181
|
+
self._items: List[ContextProvider] = processed_items
|
65
182
|
self._parent_messages: Optional['Messages'] = None
|
66
183
|
|
67
184
|
def _render_content(self) -> str:
|
@@ -89,7 +206,52 @@ class Message(ABC):
|
|
89
206
|
self._parent_messages._notify_provider_added(item, self)
|
90
207
|
|
91
208
|
def providers(self) -> List[ContextProvider]: return self._items
|
209
|
+
|
210
|
+
def __add__(self, other):
|
211
|
+
if isinstance(other, str):
|
212
|
+
new_items = self._items + [Texts(text=other)]
|
213
|
+
return type(self)(*new_items)
|
214
|
+
if isinstance(other, Message):
|
215
|
+
new_items = self._items + other.providers()
|
216
|
+
return type(self)(*new_items)
|
217
|
+
return NotImplemented
|
218
|
+
|
219
|
+
def __radd__(self, other):
|
220
|
+
if isinstance(other, str):
|
221
|
+
new_items = [Texts(text=other)] + self._items
|
222
|
+
return type(self)(*new_items)
|
223
|
+
if isinstance(other, Message):
|
224
|
+
new_items = other.providers() + self._items
|
225
|
+
return type(self)(*new_items)
|
226
|
+
return NotImplemented
|
227
|
+
|
228
|
+
def __getitem__(self, key: str) -> Any:
|
229
|
+
"""
|
230
|
+
使得 Message 对象支持字典风格的访问 (e.g., message['content'])。
|
231
|
+
"""
|
232
|
+
if key == 'role':
|
233
|
+
return self.role
|
234
|
+
elif key == 'content':
|
235
|
+
# 直接调用 to_dict 并提取 'content',确保逻辑一致
|
236
|
+
rendered_dict = self.to_dict()
|
237
|
+
return rendered_dict.get('content') if rendered_dict else None
|
238
|
+
# 对于 tool_calls 等特殊属性,也通过 to_dict 获取
|
239
|
+
elif hasattr(self, key):
|
240
|
+
rendered_dict = self.to_dict()
|
241
|
+
if rendered_dict and key in rendered_dict:
|
242
|
+
return rendered_dict[key]
|
243
|
+
|
244
|
+
# 如果在对象本身或其 to_dict() 中都找不到,则引发 KeyError
|
245
|
+
if hasattr(self, key):
|
246
|
+
return getattr(self, key)
|
247
|
+
raise KeyError(f"'{key}'")
|
248
|
+
|
92
249
|
def __repr__(self): return f"Message(role='{self.role}', items={[i.name for i in self._items]})"
|
250
|
+
def __bool__(self) -> bool:
|
251
|
+
return bool(self._items)
|
252
|
+
def get(self, key: str, default: Any = None) -> Any:
|
253
|
+
"""提供类似字典的 .get() 方法来访问属性。"""
|
254
|
+
return getattr(self, key, default)
|
93
255
|
def to_dict(self) -> Optional[Dict[str, Any]]:
|
94
256
|
is_multimodal = any(isinstance(p, Images) for p in self._items)
|
95
257
|
|
@@ -113,6 +275,64 @@ class SystemMessage(Message):
|
|
113
275
|
def __init__(self, *items): super().__init__("system", *items)
|
114
276
|
class UserMessage(Message):
|
115
277
|
def __init__(self, *items): super().__init__("user", *items)
|
278
|
+
class AssistantMessage(Message):
|
279
|
+
def __init__(self, *items): super().__init__("assistant", *items)
|
280
|
+
|
281
|
+
class RoleMessage:
|
282
|
+
"""A factory class that creates a specific message type based on the role."""
|
283
|
+
def __new__(cls, role: str, *items):
|
284
|
+
if role == 'system':
|
285
|
+
return SystemMessage(*items)
|
286
|
+
elif role == 'user':
|
287
|
+
return UserMessage(*items)
|
288
|
+
elif role == 'assistant':
|
289
|
+
return AssistantMessage(*items)
|
290
|
+
else:
|
291
|
+
raise ValueError(f"Invalid role: {role}. Must be 'system', 'user', or 'assistant'.")
|
292
|
+
|
293
|
+
class ToolCalls(Message):
|
294
|
+
"""Represents an assistant message that requests tool calls."""
|
295
|
+
def __init__(self, tool_calls: List[Any]):
|
296
|
+
super().__init__("assistant")
|
297
|
+
self.tool_calls = tool_calls
|
298
|
+
|
299
|
+
def to_dict(self) -> Dict[str, Any]:
|
300
|
+
# Duck-typing serialization for OpenAI's tool_call objects
|
301
|
+
serialized_calls = []
|
302
|
+
for tc in self.tool_calls:
|
303
|
+
try:
|
304
|
+
# Attempt to serialize based on openai-python > 1.0 tool_call structure
|
305
|
+
func = tc.function
|
306
|
+
serialized_calls.append({
|
307
|
+
"id": tc.id,
|
308
|
+
"type": tc.type,
|
309
|
+
"function": { "name": func.name, "arguments": func.arguments }
|
310
|
+
})
|
311
|
+
except AttributeError:
|
312
|
+
if isinstance(tc, dict):
|
313
|
+
serialized_calls.append(tc) # Assume it's already a serializable dict
|
314
|
+
else:
|
315
|
+
raise TypeError(f"Unsupported tool_call type: {type(tc)}. It should be an OpenAI tool_call object or a dict.")
|
316
|
+
|
317
|
+
return {
|
318
|
+
"role": self.role,
|
319
|
+
"tool_calls": serialized_calls,
|
320
|
+
"content": None
|
321
|
+
}
|
322
|
+
|
323
|
+
class ToolResults(Message):
|
324
|
+
"""Represents a tool message with the result of a single tool call."""
|
325
|
+
def __init__(self, tool_call_id: str, content: str):
|
326
|
+
super().__init__("tool")
|
327
|
+
self.tool_call_id = tool_call_id
|
328
|
+
self.content = content
|
329
|
+
|
330
|
+
def to_dict(self) -> Dict[str, Any]:
|
331
|
+
return {
|
332
|
+
"role": self.role,
|
333
|
+
"tool_call_id": self.tool_call_id,
|
334
|
+
"content": self.content
|
335
|
+
}
|
116
336
|
|
117
337
|
# 4. 顶层容器: Messages
|
118
338
|
class Messages:
|
@@ -136,12 +356,32 @@ class Messages:
|
|
136
356
|
indexed = self._providers_index.get(name)
|
137
357
|
return indexed[0] if indexed else None
|
138
358
|
|
139
|
-
def pop(self,
|
140
|
-
|
141
|
-
if
|
142
|
-
|
143
|
-
|
144
|
-
|
359
|
+
def pop(self, key: Optional[Union[str, int]] = None) -> Union[Optional[ContextProvider], Optional[Message]]:
|
360
|
+
# If no key is provided, pop the last message.
|
361
|
+
if key is None:
|
362
|
+
key = len(self._messages) - 1
|
363
|
+
|
364
|
+
if isinstance(key, str):
|
365
|
+
indexed = self._providers_index.get(key)
|
366
|
+
if not indexed:
|
367
|
+
return None
|
368
|
+
_provider, parent_message = indexed
|
369
|
+
return parent_message.pop(key)
|
370
|
+
elif isinstance(key, int):
|
371
|
+
try:
|
372
|
+
if key < 0: # Handle negative indices like -1
|
373
|
+
key += len(self._messages)
|
374
|
+
if not (0 <= key < len(self._messages)):
|
375
|
+
return None
|
376
|
+
popped_message = self._messages.pop(key)
|
377
|
+
popped_message._parent_messages = None
|
378
|
+
for provider in popped_message.providers():
|
379
|
+
self._notify_provider_removed(provider)
|
380
|
+
return popped_message
|
381
|
+
except IndexError:
|
382
|
+
return None
|
383
|
+
|
384
|
+
return None
|
145
385
|
|
146
386
|
async def refresh(self):
|
147
387
|
tasks = [provider.refresh() for provider, _ in self._providers_index.values()]
|
@@ -166,72 +406,31 @@ class Messages:
|
|
166
406
|
for p in message.providers():
|
167
407
|
self._notify_provider_added(p, message)
|
168
408
|
|
409
|
+
def save(self, file_path: str):
|
410
|
+
"""
|
411
|
+
Saves the entire Messages object to a file using pickle.
|
412
|
+
Warning: Deserializing data with pickle from an untrusted source is insecure.
|
413
|
+
"""
|
414
|
+
with open(file_path, 'wb') as f:
|
415
|
+
pickle.dump(self, f)
|
416
|
+
|
417
|
+
@classmethod
|
418
|
+
def load(cls, file_path: str) -> Optional['Messages']:
|
419
|
+
"""
|
420
|
+
Loads a Messages object from a file using pickle.
|
421
|
+
Returns the loaded object, or None if the file is not found or an error occurs.
|
422
|
+
Warning: Only load files from a trusted source.
|
423
|
+
"""
|
424
|
+
try:
|
425
|
+
with open(file_path, 'rb') as f:
|
426
|
+
return pickle.load(f)
|
427
|
+
except FileNotFoundError:
|
428
|
+
logging.warning(f"File not found at {file_path}, returning empty Messages.")
|
429
|
+
return cls()
|
430
|
+
except (pickle.UnpicklingError, EOFError) as e:
|
431
|
+
logging.error(f"Could not deserialize file {file_path}: {e}")
|
432
|
+
return cls()
|
433
|
+
|
169
434
|
def __getitem__(self, index: int) -> Message: return self._messages[index]
|
170
435
|
def __len__(self) -> int: return len(self._messages)
|
171
436
|
def __iter__(self): return iter(self._messages)
|
172
|
-
|
173
|
-
# ==============================================================================
|
174
|
-
# 6. 演示
|
175
|
-
# ==============================================================================
|
176
|
-
async def run_demo():
|
177
|
-
# --- 1. 初始化提供者 ---
|
178
|
-
system_prompt_provider = Texts("system_prompt", "你是一个AI助手。")
|
179
|
-
tools_provider = Tools(tools_json=[{"name": "read_file"}])
|
180
|
-
files_provider = Files()
|
181
|
-
|
182
|
-
# --- 2. 演示新功能:优雅地构建 Messages ---
|
183
|
-
print("\n>>> 场景 A: 使用新的、优雅的构造函数直接初始化 Messages")
|
184
|
-
messages = Messages(
|
185
|
-
SystemMessage(system_prompt_provider, tools_provider),
|
186
|
-
UserMessage(files_provider, Texts("user_input", "这是我的初始问题。")),
|
187
|
-
UserMessage(Texts("user_input2", "这是我的初始问题2。"))
|
188
|
-
)
|
189
|
-
|
190
|
-
print("\n--- 渲染后的初始 Messages (首次渲染,全部刷新) ---")
|
191
|
-
for msg_dict in await messages.render_latest(): print(msg_dict)
|
192
|
-
print("-" * 40)
|
193
|
-
|
194
|
-
# --- 3. 演示穿透更新 ---
|
195
|
-
print("\n>>> 场景 B: 穿透更新 File Provider,渲染时自动刷新")
|
196
|
-
files_provider_instance = messages.provider("files")
|
197
|
-
if isinstance(files_provider_instance, Files):
|
198
|
-
files_provider_instance.update("file1.py", "这是新的文件内容!")
|
199
|
-
|
200
|
-
print("\n--- 再次渲染 Messages (只有文件提供者会刷新) ---")
|
201
|
-
for msg_dict in await messages.render_latest(): print(msg_dict)
|
202
|
-
print("-" * 40)
|
203
|
-
|
204
|
-
# --- 4. 演示全局 Pop 和通过索引 Insert ---
|
205
|
-
print("\n>>> 场景 C: 全局 Pop 工具提供者,并 Insert 到 UserMessage 中")
|
206
|
-
popped_tools_provider = messages.pop("tools")
|
207
|
-
if popped_tools_provider:
|
208
|
-
messages[1].insert(0, popped_tools_provider)
|
209
|
-
print(f"\n已成功将 '{popped_tools_provider.name}' 提供者移动到用户消息。")
|
210
|
-
|
211
|
-
print("\n--- Pop 和 Insert 后渲染的 Messages (验证移动效果) ---")
|
212
|
-
for msg_dict in messages.render(): print(msg_dict)
|
213
|
-
print("-" * 40)
|
214
|
-
|
215
|
-
# --- 5. 演示多模态渲染 ---
|
216
|
-
print("\n>>> 场景 D: 演示多模态 (文本+图片) 渲染")
|
217
|
-
with open("dummy_image.png", "w") as f:
|
218
|
-
f.write("This is a dummy image file.")
|
219
|
-
|
220
|
-
multimodal_message = Messages(
|
221
|
-
UserMessage(
|
222
|
-
Texts("prompt", "What do you see in this image?"),
|
223
|
-
Images("dummy_image.png")
|
224
|
-
)
|
225
|
-
)
|
226
|
-
print("\n--- 渲染后的多模态 Message ---")
|
227
|
-
for msg_dict in await multimodal_message.render_latest():
|
228
|
-
if isinstance(msg_dict['content'], list):
|
229
|
-
for item in msg_dict['content']:
|
230
|
-
if item['type'] == 'image_url':
|
231
|
-
item['image_url']['url'] = item['image_url']['url'][:80] + "..."
|
232
|
-
print(msg_dict)
|
233
|
-
print("-" * 40)
|
234
|
-
|
235
|
-
|
236
|
-
if __name__ == "__main__":
|
237
|
-
asyncio.run(run_demo())
|
@@ -0,0 +1,146 @@
|
|
1
|
+
import os
|
2
|
+
import json
|
3
|
+
import asyncio
|
4
|
+
from openai import AsyncOpenAI
|
5
|
+
|
6
|
+
# 从我们设计的 architext 库中导入消息类
|
7
|
+
from architext.core import (
|
8
|
+
Messages,
|
9
|
+
SystemMessage,
|
10
|
+
UserMessage,
|
11
|
+
AssistantMessage,
|
12
|
+
ToolCalls,
|
13
|
+
ToolResults,
|
14
|
+
Texts,
|
15
|
+
)
|
16
|
+
|
17
|
+
def _add_tool(a: int, b: int) -> int:
|
18
|
+
"""(工具函数) 计算两个整数的和。"""
|
19
|
+
print(f"Executing tool: add(a={a}, b={b})")
|
20
|
+
return a + b
|
21
|
+
|
22
|
+
async def main():
|
23
|
+
"""
|
24
|
+
一个简化的、函数式的流程,用于处理单个包含工具调用的用户查询。
|
25
|
+
"""
|
26
|
+
print("Starting simplified Tool Use demonstration...")
|
27
|
+
|
28
|
+
# --- 1. 初始化 ---
|
29
|
+
# 确保环境变量已设置
|
30
|
+
if not os.getenv("API_KEY"):
|
31
|
+
print("\nERROR: API_KEY environment variable not set.")
|
32
|
+
return
|
33
|
+
|
34
|
+
client = AsyncOpenAI(base_url=os.getenv("BASE_URL"), api_key=os.getenv("API_KEY"))
|
35
|
+
model = os.getenv("MODEL", "gpt-4o-mini")
|
36
|
+
|
37
|
+
# 定义工具
|
38
|
+
tool_executors = { "add": _add_tool }
|
39
|
+
tools_definition = [{
|
40
|
+
"type": "function", "function": {
|
41
|
+
"name": "add", "description": "Calculate the sum of two integers.",
|
42
|
+
"parameters": {
|
43
|
+
"type": "object",
|
44
|
+
"properties": {
|
45
|
+
"a": {"type": "integer", "description": "The first integer."},
|
46
|
+
"b": {"type": "integer", "description": "The second integer."},
|
47
|
+
}, "required": ["a", "b"],
|
48
|
+
},
|
49
|
+
},
|
50
|
+
}]
|
51
|
+
|
52
|
+
# --- 2. 处理查询 ---
|
53
|
+
# 初始消息
|
54
|
+
messages = Messages(
|
55
|
+
SystemMessage(Texts("system_prompt", "You are a helpful assistant. You must use the provided tools to answer questions.")),
|
56
|
+
UserMessage(Texts("user_question", "What is the sum of 5 and 10?"))
|
57
|
+
)
|
58
|
+
|
59
|
+
# 第一次 API 调用
|
60
|
+
print("\n--- [Step 1] Calling OpenAI with tools...")
|
61
|
+
response = await client.chat.completions.create(
|
62
|
+
model=model,
|
63
|
+
messages=await messages.render_latest(),
|
64
|
+
tools=tools_definition,
|
65
|
+
tool_choice="auto",
|
66
|
+
)
|
67
|
+
response_message = response.choices[0].message
|
68
|
+
|
69
|
+
# 检查是否需要工具调用
|
70
|
+
if not response_message.tool_calls:
|
71
|
+
final_content = response_message.content or ""
|
72
|
+
messages.append(AssistantMessage(Texts("assistant_response", final_content)))
|
73
|
+
else:
|
74
|
+
# 执行工具调用
|
75
|
+
print("--- [Step 2] Assistant requested tool calls. Executing them...")
|
76
|
+
messages.append(ToolCalls(response_message.tool_calls))
|
77
|
+
|
78
|
+
for tool_call in response_message.tool_calls:
|
79
|
+
if tool_call.function is None: continue
|
80
|
+
|
81
|
+
executor = tool_executors.get(tool_call.function.name)
|
82
|
+
if not executor: continue
|
83
|
+
|
84
|
+
try:
|
85
|
+
args = json.loads(tool_call.function.arguments)
|
86
|
+
result = executor(**args)
|
87
|
+
messages.append(ToolResults(tool_call_id=tool_call.id, content=str(result)))
|
88
|
+
print(f" - Executed '{tool_call.function.name}'. Result: {result}")
|
89
|
+
except (json.JSONDecodeError, TypeError) as e:
|
90
|
+
print(f" - Error processing tool call '{tool_call.function.name}': {e}")
|
91
|
+
|
92
|
+
# 第二次 API 调用
|
93
|
+
print("--- [Step 3] Calling OpenAI with tool results for final answer...")
|
94
|
+
final_response = await client.chat.completions.create(
|
95
|
+
model=model,
|
96
|
+
messages=await messages.render_latest(),
|
97
|
+
)
|
98
|
+
final_content = final_response.choices[0].message.content or ""
|
99
|
+
messages.append(AssistantMessage(Texts("final_response", final_content)))
|
100
|
+
|
101
|
+
# --- 3. 显示结果 ---
|
102
|
+
print("\n--- Final request body sent to OpenAI: ---")
|
103
|
+
print(json.dumps(await messages.render_latest(), indent=2, ensure_ascii=False))
|
104
|
+
|
105
|
+
print("\n--- Final Assistant Answer ---")
|
106
|
+
print(final_content)
|
107
|
+
print("\nDemonstration finished.")
|
108
|
+
|
109
|
+
if __name__ == "__main__":
|
110
|
+
asyncio.run(main())
|
111
|
+
|
112
|
+
"""
|
113
|
+
[
|
114
|
+
{
|
115
|
+
"role": "system",
|
116
|
+
"content": "You are a helpful assistant. You must use the provided tools to answer questions."
|
117
|
+
},
|
118
|
+
{
|
119
|
+
"role": "user",
|
120
|
+
"content": "What is the sum of 5 and 10?"
|
121
|
+
},
|
122
|
+
{
|
123
|
+
"role": "assistant",
|
124
|
+
"tool_calls": [
|
125
|
+
{
|
126
|
+
"id": "call_rddWXkDikIxllRgbPrR6XjtMVSBPv",
|
127
|
+
"type": "function",
|
128
|
+
"function": {
|
129
|
+
"name": "add",
|
130
|
+
"arguments": "{\"b\": 10, \"a\": 5}"
|
131
|
+
}
|
132
|
+
}
|
133
|
+
],
|
134
|
+
"content": null
|
135
|
+
},
|
136
|
+
{
|
137
|
+
"role": "tool",
|
138
|
+
"tool_call_id": "call_rddWXkDikIxllRgbPrR6XjtMVSBPv",
|
139
|
+
"content": "15"
|
140
|
+
},
|
141
|
+
{
|
142
|
+
"role": "assistant",
|
143
|
+
"content": "The sum of 5 and 10 is 15."
|
144
|
+
}
|
145
|
+
]
|
146
|
+
"""
|