camel-ai 0.2.71a4__py3-none-any.whl → 0.2.71a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- camel/__init__.py +1 -1
- camel/agents/chat_agent.py +1482 -134
- camel/agents/repo_agent.py +2 -1
- camel/benchmarks/browsecomp.py +6 -6
- camel/logger.py +1 -1
- camel/messages/base.py +12 -1
- camel/models/azure_openai_model.py +96 -7
- camel/models/base_model.py +68 -10
- camel/models/deepseek_model.py +5 -0
- camel/models/gemini_model.py +5 -0
- camel/models/litellm_model.py +48 -16
- camel/models/model_manager.py +24 -6
- camel/models/openai_compatible_model.py +109 -5
- camel/models/openai_model.py +117 -8
- camel/societies/workforce/prompts.py +68 -5
- camel/societies/workforce/role_playing_worker.py +1 -0
- camel/societies/workforce/single_agent_worker.py +1 -0
- camel/societies/workforce/utils.py +67 -2
- camel/societies/workforce/workforce.py +270 -36
- camel/societies/workforce/workforce_logger.py +0 -8
- camel/tasks/task.py +2 -0
- camel/toolkits/__init__.py +2 -0
- camel/toolkits/file_write_toolkit.py +526 -121
- camel/toolkits/message_agent_toolkit.py +608 -0
- {camel_ai-0.2.71a4.dist-info → camel_ai-0.2.71a5.dist-info}/METADATA +6 -4
- {camel_ai-0.2.71a4.dist-info → camel_ai-0.2.71a5.dist-info}/RECORD +28 -27
- {camel_ai-0.2.71a4.dist-info → camel_ai-0.2.71a5.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.71a4.dist-info → camel_ai-0.2.71a5.dist-info}/licenses/LICENSE +0 -0
camel/agents/chat_agent.py
CHANGED

@@ -16,23 +16,26 @@ from __future__ import annotations
 import asyncio
 import json
 import logging
+import queue
 import textwrap
 import threading
 import time
 import uuid
-from collections import defaultdict
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncGenerator,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Set,
     Tuple,
     Type,
     Union,
+    cast,
 )
 
 from openai import (

@@ -50,6 +53,7 @@ from camel.agents._utils import (
     safe_model_dump,
 )
 from camel.agents.base import BaseAgent
+from camel.logger import get_logger
 from camel.memories import (
     AgentMemory,
     ChatHistoryMemory,

@@ -90,7 +94,7 @@ from camel.utils.tool_result import ToolResult
 if TYPE_CHECKING:
     from camel.terminators import ResponseTerminator
 
-logger =
+logger = get_logger(__name__)
 
 # AgentOps decorator setting
 try:

@@ -109,6 +113,11 @@ if os.environ.get("LANGFUSE_ENABLED", "False").lower() == "true":
         from langfuse.decorators import observe
     except ImportError:
         from camel.utils import observe
+elif os.environ.get("TRACEROOT_ENABLED", "False").lower() == "true":
+    try:
+        from traceroot import trace as observe  # type: ignore[import]
+    except ImportError:
+        from camel.utils import observe
 else:
     from camel.utils import observe
 

@@ -124,6 +133,189 @@ SIMPLE_FORMAT_PROMPT = TextPrompt(
 )
 
 
+class StreamContentAccumulator:
+    r"""Manages content accumulation across streaming responses to ensure
+    all responses contain complete cumulative content."""
+
+    def __init__(self):
+        self.base_content = ""  # Content before tool calls
+        self.current_content = ""  # Current streaming content
+        self.tool_status_messages = []  # Accumulated tool status messages
+
+    def set_base_content(self, content: str):
+        r"""Set the base content (usually empty or pre-tool content)."""
+        self.base_content = content
+
+    def add_streaming_content(self, new_content: str):
+        r"""Add new streaming content."""
+        self.current_content += new_content
+
+    def add_tool_status(self, status_message: str):
+        r"""Add a tool status message."""
+        self.tool_status_messages.append(status_message)
+
+    def get_full_content(self) -> str:
+        r"""Get the complete accumulated content."""
+        tool_messages = "".join(self.tool_status_messages)
+        return self.base_content + tool_messages + self.current_content
+
+    def get_content_with_new_status(self, status_message: str) -> str:
+        r"""Get content with a new status message appended."""
+        tool_messages = "".join([*self.tool_status_messages, status_message])
+        return self.base_content + tool_messages + self.current_content
+
+    def reset_streaming_content(self):
+        r"""Reset only the streaming content, keep base and tool status."""
+        self.current_content = ""
+
+
+class StreamingChatAgentResponse:
+    r"""A wrapper that makes streaming responses compatible with
+    non-streaming code.
+
+    This class wraps a Generator[ChatAgentResponse, None, None] and provides
+    the same interface as ChatAgentResponse, so existing code doesn't need to
+    change.
+    """
+
+    def __init__(self, generator: Generator[ChatAgentResponse, None, None]):
+        self._generator = generator
+        self._current_response: Optional[ChatAgentResponse] = None
+        self._responses: List[ChatAgentResponse] = []
+        self._consumed = False
+
+    def _ensure_latest_response(self):
+        r"""Ensure we have the latest response by consuming the generator."""
+        if not self._consumed:
+            try:
+                for response in self._generator:
+                    self._responses.append(response)
+                    self._current_response = response
+                self._consumed = True
+            except StopIteration:
+                self._consumed = True
+
+    @property
+    def msgs(self) -> List[BaseMessage]:
+        r"""Get messages from the latest response."""
+        self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response.msgs
+        return []
+
+    @property
+    def terminated(self) -> bool:
+        r"""Get terminated status from the latest response."""
+        self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response.terminated
+        return False
+
+    @property
+    def info(self) -> Dict[str, Any]:
+        r"""Get info from the latest response."""
+        self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response.info
+        return {}
+
+    @property
+    def msg(self):
+        r"""Get the single message if there's exactly one message."""
+        self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response.msg
+        return None
+
+    def __iter__(self):
+        r"""Make this object iterable."""
+        if self._consumed:
+            # If already consumed, iterate over stored responses
+            return iter(self._responses)
+        else:
+            # If not consumed, consume and yield
+            try:
+                for response in self._generator:
+                    self._responses.append(response)
+                    self._current_response = response
+                    yield response
+                self._consumed = True
+            except StopIteration:
+                self._consumed = True
+
+    def __getattr__(self, name):
+        r"""Forward any other attribute access to the latest response."""
+        self._ensure_latest_response()
+        if self._current_response and hasattr(self._current_response, name):
+            return getattr(self._current_response, name)
+        raise AttributeError(
+            f"'StreamingChatAgentResponse' object has no attribute '{name}'"
+        )
+
+
+class AsyncStreamingChatAgentResponse:
+    r"""A wrapper that makes async streaming responses awaitable and
+    compatible with non-streaming code.
+
+    This class wraps an AsyncGenerator[ChatAgentResponse, None] and provides
+    both awaitable and async iterable interfaces.
+    """
+
+    def __init__(
+        self, async_generator: AsyncGenerator[ChatAgentResponse, None]
+    ):
+        self._async_generator = async_generator
+        self._current_response: Optional[ChatAgentResponse] = None
+        self._responses: List[ChatAgentResponse] = []
+        self._consumed = False
+
+    async def _ensure_latest_response(self):
+        r"""Ensure the latest response by consuming the async generator."""
+        if not self._consumed:
+            try:
+                async for response in self._async_generator:
+                    self._responses.append(response)
+                    self._current_response = response
+                self._consumed = True
+            except StopAsyncIteration:
+                self._consumed = True
+
+    async def _get_final_response(self) -> ChatAgentResponse:
+        r"""Get the final response after consuming the entire stream."""
+        await self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response
+        # Return a default response if nothing was consumed
+        return ChatAgentResponse(msgs=[], terminated=False, info={})
+
+    def __await__(self):
+        r"""Make this object awaitable - returns the final response."""
+        return self._get_final_response().__await__()
+
+    def __aiter__(self):
+        r"""Make this object async iterable."""
+        if self._consumed:
+            # If already consumed, create async iterator from stored responses
+            async def _async_iter():
+                for response in self._responses:
+                    yield response
+
+            return _async_iter()
+        else:
+            # If not consumed, consume and yield
+            async def _consume_and_yield():
+                try:
+                    async for response in self._async_generator:
+                        self._responses.append(response)
+                        self._current_response = response
+                        yield response
+                    self._consumed = True
+                except StopAsyncIteration:
+                    self._consumed = True
+
+            return _consume_and_yield()
+
+
 @track_agent(name="ChatAgent")
 class ChatAgent(BaseAgent):
     r"""Class for managing conversations of CAMEL Chat Agents.
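For orientation, a minimal usage sketch (not part of the diff) of the wrapper classes added above, assuming `agent` is a ChatAgent whose model backend was configured with `"stream": True`; the prompts are illustrative:

    # StreamingChatAgentResponse behaves like a ChatAgentResponse: accessing
    # .msgs/.msg/.info consumes the stream and returns the final state.
    response = agent.step("Summarize the CAMEL project in one sentence.")
    print(response.msgs[0].content)

    # It is also iterable, yielding intermediate responses whose content
    # is cumulative (see StreamContentAccumulator above).
    for partial in agent.step("Now list three key features."):
        print(partial.msg.content)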
@@ -176,6 +368,8 @@ class ChatAgent(BaseAgent):
         stop_event (Optional[threading.Event], optional): Event to signal
             termination of the agent's operation. When set, the agent will
             terminate its execution. (default: :obj:`None`)
+        tool_execution_timeout (Optional[float], optional): Timeout
+            for individual tool execution. If None, wait indefinitely.
         mask_tool_output (Optional[bool]): Whether to return a sanitized
             placeholder instead of the raw tool output. (default: :obj:`False`)
         pause_event (Optional[asyncio.Event]): Event to signal pause of the

@@ -214,6 +408,7 @@ class ChatAgent(BaseAgent):
         max_iteration: Optional[int] = None,
         agent_id: Optional[str] = None,
         stop_event: Optional[threading.Event] = None,
+        tool_execution_timeout: Optional[float] = None,
         mask_tool_output: bool = False,
         pause_event: Optional[asyncio.Event] = None,
     ) -> None:

@@ -290,6 +485,7 @@ class ChatAgent(BaseAgent):
         self.response_terminators = response_terminators or []
         self.max_iteration = max_iteration
         self.stop_event = stop_event
+        self.tool_execution_timeout = tool_execution_timeout
         self.mask_tool_output = mask_tool_output
         self._secure_result_store: Dict[str, Any] = {}
         self._pending_images: List[str] = []
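A minimal constructor sketch (not part of the diff) showing the new `tool_execution_timeout` argument wired in above; `model` and `my_tool` are assumed to be defined elsewhere:

    agent = ChatAgent(
        system_message="You are a helpful assistant.",
        model=model,                    # any model backend instance
        tools=[my_tool],                # illustrative tool list
        tool_execution_timeout=30.0,    # per-tool timeout in seconds; None waits indefinitely
    )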
@@ -1106,7 +1302,7 @@ class ChatAgent(BaseAgent):
         self,
         input_message: Union[BaseMessage, str],
         response_format: Optional[Type[BaseModel]] = None,
-    ) -> ChatAgentResponse:
+    ) -> Union[ChatAgentResponse, StreamingChatAgentResponse]:
         r"""Executes a single step in the chat session, generating a response
         to the input message.
 

@@ -1120,10 +1316,20 @@ class ChatAgent(BaseAgent):
                 :obj:`None`)
 
         Returns:
-            ChatAgentResponse:
-
+            Union[ChatAgentResponse, StreamingChatAgentResponse]: If stream is
+                False, returns a ChatAgentResponse. If stream is True, returns
+                a StreamingChatAgentResponse that behaves like
+                ChatAgentResponse but can also be iterated for
+                streaming updates.
         """
 
+        stream = self.model_backend.model_config_dict.get("stream", False)
+
+        if stream:
+            # Return wrapped generator that has ChatAgentResponse interface
+            generator = self._stream(input_message, response_format)
+            return StreamingChatAgentResponse(generator)
+
         # Set Langfuse session_id using agent_id for trace grouping
         try:
             from camel.utils.langfuse import set_current_agent_session_id
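A minimal sketch (not part of the diff) of how the streaming path above is triggered; it assumes the usual ModelFactory/ModelPlatformType/ModelType imports and a valid API key:

    from camel.agents import ChatAgent
    from camel.models import ModelFactory
    from camel.types import ModelPlatformType, ModelType

    # step() checks model_config_dict["stream"]; True routes through _stream().
    model = ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=ModelType.GPT_4O_MINI,
        model_config_dict={"stream": True},
    )
    agent = ChatAgent(system_message="You are concise.", model=model)

    streaming = agent.step("Explain streaming in one paragraph.")
    for partial in streaming:          # StreamingChatAgentResponse is iterable
        print(partial.msg.content)     # cumulative content so far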
@@ -1222,9 +1428,8 @@ class ChatAgent(BaseAgent):
                     ):
                         while not self.pause_event.is_set():
                             time.sleep(0.001)
-
-
-                    )
+                    result = self._execute_tool(tool_call_request)
+                    tool_call_records.append(result)
 
                 # If we found external tool calls, break the loop
                 if external_tool_call_requests:

@@ -1271,7 +1476,7 @@ class ChatAgent(BaseAgent):
         self,
         input_message: Union[BaseMessage, str],
         response_format: Optional[Type[BaseModel]] = None,
-    ) -> ChatAgentResponse:
+    ) -> Union[ChatAgentResponse, AsyncStreamingChatAgentResponse]:
         r"""Performs a single step in the chat session by generating a response
         to the input message. This agent step can call async function calls.
 

@@ -1287,12 +1492,38 @@ class ChatAgent(BaseAgent):
                 used to generate a structured response by LLM. This schema
                 helps in defining the expected output format. (default:
                 :obj:`None`)
-
         Returns:
-            ChatAgentResponse:
-
-
+            Union[ChatAgentResponse, AsyncStreamingChatAgentResponse]:
+                If stream is False, returns a ChatAgentResponse. If stream is
+                True, returns an AsyncStreamingChatAgentResponse that can be
+                awaited for the final result or async iterated for streaming
+                updates.
         """
+
+        try:
+            from camel.utils.langfuse import set_current_agent_session_id
+
+            set_current_agent_session_id(self.agent_id)
+        except ImportError:
+            pass  # Langfuse not available
+
+        stream = self.model_backend.model_config_dict.get("stream", False)
+        if stream:
+            # Return wrapped async generator that is awaitable
+            async_generator = self._astream(input_message, response_format)
+            return AsyncStreamingChatAgentResponse(async_generator)
+        else:
+            return await self._astep_non_streaming_task(
+                input_message, response_format
+            )
+
+    async def _astep_non_streaming_task(
+        self,
+        input_message: Union[BaseMessage, str],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> ChatAgentResponse:
+        r"""Internal async method for non-streaming astep logic."""
+
         try:
             from camel.utils.langfuse import set_current_agent_session_id
 
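An async counterpart sketch (not part of the diff), based on the docstring above; `agent` is assumed to be configured with `"stream": True` as in the earlier example:

    import asyncio

    async def main() -> None:
        # astep() is itself async, so awaiting the call yields the
        # AsyncStreamingChatAgentResponse wrapper when streaming is enabled.
        streaming = await agent.astep("Give me a haiku about camels.")

        # The wrapper can be awaited for the final ChatAgentResponse...
        final = await streaming
        print(final.msgs[0].content)

        # ...or async-iterated for incremental updates on a fresh call.
        async for partial in await agent.astep("And one about llamas."):
            print(partial.msg.content)

    asyncio.run(main())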
@@ -1731,11 +1962,12 @@ class ChatAgent(BaseAgent):
                 f"index {self.model_backend.current_model_index}, "
                 f"processed these messages: {sanitized_messages}"
             )
-
-
-
-
-
+        if not isinstance(response, ChatCompletion):
+            raise TypeError(
+                f"Expected response to be a `ChatCompletion` object, but "
+                f"got {type(response).__name__} instead."
+            )
+        return self._handle_batch_response(response)
 
     async def _aget_model_response(
         self,

@@ -1799,11 +2031,12 @@ class ChatAgent(BaseAgent):
                 f"index {self.model_backend.current_model_index}, "
                 f"processed these messages: {sanitized_messages}"
             )
-
-
-
-
-
+        if not isinstance(response, ChatCompletion):
+            raise TypeError(
+                f"Expected response to be a `ChatCompletion` object, but "
+                f"got {type(response).__name__} instead."
+            )
+        return self._handle_batch_response(response)
 
     def _sanitize_messages_for_logging(self, messages):
         r"""Sanitize OpenAI messages for logging by replacing base64 image

@@ -2065,116 +2298,6 @@ class ChatAgent(BaseAgent):
             response_id=response.id or "",
         )
 
-    def _handle_stream_response(
-        self,
-        response: Stream[ChatCompletionChunk],
-        prompt_tokens: int,
-    ) -> ModelResponse:
-        r"""Process a stream response from the model and extract the necessary
-        information.
-
-        Args:
-            response (dict): Model response.
-            prompt_tokens (int): Number of input prompt tokens.
-
-        Returns:
-            _ModelResponse: a parsed model response.
-        """
-        content_dict: defaultdict = defaultdict(lambda: "")
-        finish_reasons_dict: defaultdict = defaultdict(lambda: "")
-        output_messages: List[BaseMessage] = []
-        response_id: str = ""
-        # All choices in one response share one role
-        for chunk in response:
-            # Some model platforms like siliconflow may return None for the
-            # chunk.id
-            response_id = chunk.id if chunk.id else str(uuid.uuid4())
-            self._handle_chunk(
-                chunk, content_dict, finish_reasons_dict, output_messages
-            )
-        finish_reasons = [
-            finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
-        ]
-        usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
-
-        # TODO: Handle tool calls
-        return ModelResponse(
-            response=response,
-            tool_call_requests=None,
-            output_messages=output_messages,
-            finish_reasons=finish_reasons,
-            usage_dict=usage_dict,
-            response_id=response_id,
-        )
-
-    async def _ahandle_stream_response(
-        self,
-        response: AsyncStream[ChatCompletionChunk],
-        prompt_tokens: int,
-    ) -> ModelResponse:
-        r"""Process a stream response from the model and extract the necessary
-        information.
-
-        Args:
-            response (dict): Model response.
-            prompt_tokens (int): Number of input prompt tokens.
-
-        Returns:
-            _ModelResponse: a parsed model response.
-        """
-        content_dict: defaultdict = defaultdict(lambda: "")
-        finish_reasons_dict: defaultdict = defaultdict(lambda: "")
-        output_messages: List[BaseMessage] = []
-        response_id: str = ""
-        # All choices in one response share one role
-        async for chunk in response:
-            # Some model platforms like siliconflow may return None for the
-            # chunk.id
-            response_id = chunk.id if chunk.id else str(uuid.uuid4())
-            self._handle_chunk(
-                chunk, content_dict, finish_reasons_dict, output_messages
-            )
-        finish_reasons = [
-            finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
-        ]
-        usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
-
-        # TODO: Handle tool calls
-        return ModelResponse(
-            response=response,
-            tool_call_requests=None,
-            output_messages=output_messages,
-            finish_reasons=finish_reasons,
-            usage_dict=usage_dict,
-            response_id=response_id,
-        )
-
-    def _handle_chunk(
-        self,
-        chunk: ChatCompletionChunk,
-        content_dict: defaultdict,
-        finish_reasons_dict: defaultdict,
-        output_messages: List[BaseMessage],
-    ) -> None:
-        r"""Handle a chunk of the model response."""
-        for choice in chunk.choices:
-            index = choice.index
-            delta = choice.delta
-            if delta.content is not None:
-                content_dict[index] += delta.content
-
-            if not choice.finish_reason:
-                continue
-
-            finish_reasons_dict[index] = choice.finish_reason
-            chat_message = BaseMessage(
-                role_name=self.role_name,
-                role_type=self.role_type,
-                meta_dict=dict(),
-                content=content_dict[index],
-            )
-            output_messages.append(chat_message)
-
     def _step_terminate(
         self,
         num_tokens: int,

@@ -2309,7 +2432,7 @@ class ChatAgent(BaseAgent):
         except Exception as e:
             # Capture the error message to prevent framework crash
             error_msg = f"Error executing async tool '{func_name}': {e!s}"
-            result =
+            result = f"Tool execution failed: {error_msg}"
             logging.warning(error_msg)
 
         # Check if result is a ToolResult with images
@@ -2406,6 +2529,1230 @@ class ChatAgent(BaseAgent):
|
|
|
2406
2529
|
|
|
2407
2530
|
return tool_record
|
|
2408
2531
|
|
|
2532
|
+
def _stream(
|
|
2533
|
+
self,
|
|
2534
|
+
input_message: Union[BaseMessage, str],
|
|
2535
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
2536
|
+
) -> Generator[ChatAgentResponse, None, None]:
|
|
2537
|
+
r"""Executes a streaming step in the chat session, yielding
|
|
2538
|
+
intermediate responses as they are generated.
|
|
2539
|
+
|
|
2540
|
+
Args:
|
|
2541
|
+
input_message (Union[BaseMessage, str]): The input message for the
|
|
2542
|
+
agent.
|
|
2543
|
+
response_format (Optional[Type[BaseModel]], optional): A Pydantic
|
|
2544
|
+
model defining the expected structure of the response.
|
|
2545
|
+
|
|
2546
|
+
Yields:
|
|
2547
|
+
ChatAgentResponse: Intermediate responses containing partial
|
|
2548
|
+
content, tool calls, and other information as they become
|
|
2549
|
+
available.
|
|
2550
|
+
"""
|
|
2551
|
+
# Convert input message to BaseMessage if necessary
|
|
2552
|
+
if isinstance(input_message, str):
|
|
2553
|
+
input_message = BaseMessage.make_user_message(
|
|
2554
|
+
role_name="User", content=input_message
|
|
2555
|
+
)
|
|
2556
|
+
|
|
2557
|
+
# Add user input to memory
|
|
2558
|
+
self.update_memory(input_message, OpenAIBackendRole.USER)
|
|
2559
|
+
|
|
2560
|
+
# Get context for streaming
|
|
2561
|
+
try:
|
|
2562
|
+
openai_messages, num_tokens = self.memory.get_context()
|
|
2563
|
+
except RuntimeError as e:
|
|
2564
|
+
yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
|
|
2565
|
+
return
|
|
2566
|
+
|
|
2567
|
+
# Start streaming response
|
|
2568
|
+
yield from self._stream_response(
|
|
2569
|
+
openai_messages, num_tokens, response_format
|
|
2570
|
+
)
|
|
2571
|
+
|
|
2572
|
+
def _get_token_count(self, content: str) -> int:
|
|
2573
|
+
r"""Get token count for content with fallback."""
|
|
2574
|
+
if hasattr(self.model_backend, 'token_counter'):
|
|
2575
|
+
return len(self.model_backend.token_counter.encode(content))
|
|
2576
|
+
else:
|
|
2577
|
+
return len(content.split())
|
|
2578
|
+
|
|
2579
|
+
def _stream_response(
|
|
2580
|
+
self,
|
|
2581
|
+
openai_messages: List[OpenAIMessage],
|
|
2582
|
+
num_tokens: int,
|
|
2583
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
2584
|
+
) -> Generator[ChatAgentResponse, None, None]:
|
|
2585
|
+
r"""Internal method to handle streaming responses with tool calls."""
|
|
2586
|
+
|
|
2587
|
+
tool_call_records: List[ToolCallingRecord] = []
|
|
2588
|
+
accumulated_tool_calls: Dict[str, Any] = {}
|
|
2589
|
+
step_token_usage = self._create_token_usage_tracker()
|
|
2590
|
+
|
|
2591
|
+
# Create content accumulator for proper content management
|
|
2592
|
+
content_accumulator = StreamContentAccumulator()
|
|
2593
|
+
iteration_count = 0
|
|
2594
|
+
while True:
|
|
2595
|
+
# Check termination condition
|
|
2596
|
+
if self.stop_event and self.stop_event.is_set():
|
|
2597
|
+
yield self._step_terminate(
|
|
2598
|
+
num_tokens, tool_call_records, "termination_triggered"
|
|
2599
|
+
)
|
|
2600
|
+
return
|
|
2601
|
+
|
|
2602
|
+
# Get streaming response from model
|
|
2603
|
+
try:
|
|
2604
|
+
response = self.model_backend.run(
|
|
2605
|
+
openai_messages,
|
|
2606
|
+
response_format,
|
|
2607
|
+
self._get_full_tool_schemas() or None,
|
|
2608
|
+
)
|
|
2609
|
+
iteration_count += 1
|
|
2610
|
+
except Exception as exc:
|
|
2611
|
+
logger.error(
|
|
2612
|
+
f"Error in streaming model response: {exc}", exc_info=exc
|
|
2613
|
+
)
|
|
2614
|
+
yield self._create_error_response(str(exc), tool_call_records)
|
|
2615
|
+
return
|
|
2616
|
+
|
|
2617
|
+
# Handle streaming response
|
|
2618
|
+
if isinstance(response, Stream):
|
|
2619
|
+
(
|
|
2620
|
+
stream_completed,
|
|
2621
|
+
tool_calls_complete,
|
|
2622
|
+
) = yield from self._process_stream_chunks_with_accumulator(
|
|
2623
|
+
response,
|
|
2624
|
+
content_accumulator,
|
|
2625
|
+
accumulated_tool_calls,
|
|
2626
|
+
tool_call_records,
|
|
2627
|
+
step_token_usage,
|
|
2628
|
+
response_format,
|
|
2629
|
+
)
|
|
2630
|
+
|
|
2631
|
+
if tool_calls_complete:
|
|
2632
|
+
# Clear completed tool calls
|
|
2633
|
+
accumulated_tool_calls.clear()
|
|
2634
|
+
|
|
2635
|
+
# If we executed tools and not in
|
|
2636
|
+
# single iteration mode, continue
|
|
2637
|
+
if tool_call_records and (
|
|
2638
|
+
self.max_iteration is None
|
|
2639
|
+
or iteration_count < self.max_iteration
|
|
2640
|
+
):
|
|
2641
|
+
# Update messages with tool results for next iteration
|
|
2642
|
+
try:
|
|
2643
|
+
openai_messages, num_tokens = (
|
|
2644
|
+
self.memory.get_context()
|
|
2645
|
+
)
|
|
2646
|
+
except RuntimeError as e:
|
|
2647
|
+
yield self._step_terminate(
|
|
2648
|
+
e.args[1],
|
|
2649
|
+
tool_call_records,
|
|
2650
|
+
"max_tokens_exceeded",
|
|
2651
|
+
)
|
|
2652
|
+
return
|
|
2653
|
+
# Reset streaming content for next iteration
|
|
2654
|
+
content_accumulator.reset_streaming_content()
|
|
2655
|
+
continue
|
|
2656
|
+
else:
|
|
2657
|
+
break
|
|
2658
|
+
else:
|
|
2659
|
+
# Stream completed without tool calls
|
|
2660
|
+
accumulated_tool_calls.clear()
|
|
2661
|
+
break
|
|
2662
|
+
elif hasattr(response, '__enter__') and hasattr(
|
|
2663
|
+
response, '__exit__'
|
|
2664
|
+
):
|
|
2665
|
+
# Handle structured output stream (ChatCompletionStreamManager)
|
|
2666
|
+
with response as stream:
|
|
2667
|
+
parsed_object = None
|
|
2668
|
+
|
|
2669
|
+
for event in stream:
|
|
2670
|
+
if event.type == "content.delta":
|
|
2671
|
+
if getattr(event, "delta", None):
|
|
2672
|
+
# Use accumulator for proper content management
|
|
2673
|
+
partial_response = self._create_streaming_response_with_accumulator( # noqa: E501
|
|
2674
|
+
content_accumulator,
|
|
2675
|
+
getattr(event, "delta", ""),
|
|
2676
|
+
step_token_usage,
|
|
2677
|
+
tool_call_records=tool_call_records.copy(),
|
|
2678
|
+
)
|
|
2679
|
+
yield partial_response
|
|
2680
|
+
|
|
2681
|
+
elif event.type == "content.done":
|
|
2682
|
+
parsed_object = getattr(event, "parsed", None)
|
|
2683
|
+
break
|
|
2684
|
+
elif event.type == "error":
|
|
2685
|
+
logger.error(
|
|
2686
|
+
f"Error in structured stream: "
|
|
2687
|
+
f"{getattr(event, 'error', '')}"
|
|
2688
|
+
)
|
|
2689
|
+
yield self._create_error_response(
|
|
2690
|
+
str(getattr(event, 'error', '')),
|
|
2691
|
+
tool_call_records,
|
|
2692
|
+
)
|
|
2693
|
+
return
|
|
2694
|
+
|
|
2695
|
+
# Get final completion and record final message
|
|
2696
|
+
try:
|
|
2697
|
+
final_completion = stream.get_final_completion()
|
|
2698
|
+
final_content = (
|
|
2699
|
+
final_completion.choices[0].message.content or ""
|
|
2700
|
+
)
|
|
2701
|
+
|
|
2702
|
+
final_message = BaseMessage(
|
|
2703
|
+
role_name=self.role_name,
|
|
2704
|
+
role_type=self.role_type,
|
|
2705
|
+
meta_dict={},
|
|
2706
|
+
content=final_content,
|
|
2707
|
+
parsed=cast(
|
|
2708
|
+
"BaseModel | dict[str, Any] | None",
|
|
2709
|
+
parsed_object,
|
|
2710
|
+
), # type: ignore[arg-type]
|
|
2711
|
+
)
|
|
2712
|
+
|
|
2713
|
+
self.record_message(final_message)
|
|
2714
|
+
|
|
2715
|
+
# Create final response
|
|
2716
|
+
final_response = ChatAgentResponse(
|
|
2717
|
+
msgs=[final_message],
|
|
2718
|
+
terminated=False,
|
|
2719
|
+
info={
|
|
2720
|
+
"id": final_completion.id or "",
|
|
2721
|
+
"usage": safe_model_dump(
|
|
2722
|
+
final_completion.usage
|
|
2723
|
+
)
|
|
2724
|
+
if final_completion.usage
|
|
2725
|
+
else {},
|
|
2726
|
+
"finish_reasons": [
|
|
2727
|
+
choice.finish_reason or "stop"
|
|
2728
|
+
for choice in final_completion.choices
|
|
2729
|
+
],
|
|
2730
|
+
"num_tokens": self._get_token_count(
|
|
2731
|
+
final_content
|
|
2732
|
+
),
|
|
2733
|
+
"tool_calls": tool_call_records,
|
|
2734
|
+
"external_tool_requests": None,
|
|
2735
|
+
"streaming": False,
|
|
2736
|
+
"partial": False,
|
|
2737
|
+
},
|
|
2738
|
+
)
|
|
2739
|
+
yield final_response
|
|
2740
|
+
break
|
|
2741
|
+
|
|
2742
|
+
except Exception as e:
|
|
2743
|
+
logger.error(f"Error getting final completion: {e}")
|
|
2744
|
+
yield self._create_error_response(
|
|
2745
|
+
str(e), tool_call_records
|
|
2746
|
+
)
|
|
2747
|
+
return
|
|
2748
|
+
else:
|
|
2749
|
+
# Handle non-streaming response (fallback)
|
|
2750
|
+
model_response = self._handle_batch_response(response)
|
|
2751
|
+
yield self._convert_to_chatagent_response(
|
|
2752
|
+
model_response,
|
|
2753
|
+
tool_call_records,
|
|
2754
|
+
num_tokens,
|
|
2755
|
+
None,
|
|
2756
|
+
model_response.usage_dict.get("prompt_tokens", 0),
|
|
2757
|
+
model_response.usage_dict.get("completion_tokens", 0),
|
|
2758
|
+
model_response.usage_dict.get("total_tokens", 0),
|
|
2759
|
+
)
|
|
2760
|
+
accumulated_tool_calls.clear()
|
|
2761
|
+
break
|
|
2762
|
+
|
|
2763
|
+
def _process_stream_chunks_with_accumulator(
|
|
2764
|
+
self,
|
|
2765
|
+
stream: Stream[ChatCompletionChunk],
|
|
2766
|
+
content_accumulator: StreamContentAccumulator,
|
|
2767
|
+
accumulated_tool_calls: Dict[str, Any],
|
|
2768
|
+
tool_call_records: List[ToolCallingRecord],
|
|
2769
|
+
step_token_usage: Dict[str, int],
|
|
2770
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
2771
|
+
) -> Generator[ChatAgentResponse, None, Tuple[bool, bool]]:
|
|
2772
|
+
r"""Process streaming chunks with content accumulator."""
|
|
2773
|
+
|
|
2774
|
+
tool_calls_complete = False
|
|
2775
|
+
stream_completed = False
|
|
2776
|
+
|
|
2777
|
+
for chunk in stream:
|
|
2778
|
+
# Update token usage if available
|
|
2779
|
+
if chunk.usage:
|
|
2780
|
+
self._update_token_usage_tracker(
|
|
2781
|
+
step_token_usage, safe_model_dump(chunk.usage)
|
|
2782
|
+
)
|
|
2783
|
+
|
|
2784
|
+
# Process chunk delta
|
|
2785
|
+
if chunk.choices and len(chunk.choices) > 0:
|
|
2786
|
+
choice = chunk.choices[0]
|
|
2787
|
+
delta = choice.delta
|
|
2788
|
+
|
|
2789
|
+
# Handle content streaming
|
|
2790
|
+
if delta.content:
|
|
2791
|
+
# Use accumulator for proper content management
|
|
2792
|
+
partial_response = (
|
|
2793
|
+
self._create_streaming_response_with_accumulator(
|
|
2794
|
+
content_accumulator,
|
|
2795
|
+
delta.content,
|
|
2796
|
+
step_token_usage,
|
|
2797
|
+
getattr(chunk, 'id', ''),
|
|
2798
|
+
tool_call_records.copy(),
|
|
2799
|
+
)
|
|
2800
|
+
)
|
|
2801
|
+
yield partial_response
|
|
2802
|
+
|
|
2803
|
+
# Handle tool calls streaming
|
|
2804
|
+
if delta.tool_calls:
|
|
2805
|
+
tool_calls_complete = self._accumulate_tool_calls(
|
|
2806
|
+
delta.tool_calls, accumulated_tool_calls
|
|
2807
|
+
)
|
|
2808
|
+
|
|
2809
|
+
# Check if stream is complete
|
|
2810
|
+
if choice.finish_reason:
|
|
2811
|
+
stream_completed = True
|
|
2812
|
+
|
|
2813
|
+
# If we have complete tool calls, execute them with
|
|
2814
|
+
# sync status updates
|
|
2815
|
+
if accumulated_tool_calls:
|
|
2816
|
+
# Record assistant message with tool calls first
|
|
2817
|
+
self._record_assistant_tool_calls_message(
|
|
2818
|
+
accumulated_tool_calls,
|
|
2819
|
+
content_accumulator.get_full_content(),
|
|
2820
|
+
)
|
|
2821
|
+
|
|
2822
|
+
# Execute tools synchronously with
|
|
2823
|
+
# optimized status updates
|
|
2824
|
+
for (
|
|
2825
|
+
status_response
|
|
2826
|
+
) in self._execute_tools_sync_with_status_accumulator(
|
|
2827
|
+
accumulated_tool_calls,
|
|
2828
|
+
content_accumulator,
|
|
2829
|
+
step_token_usage,
|
|
2830
|
+
tool_call_records,
|
|
2831
|
+
):
|
|
2832
|
+
yield status_response
|
|
2833
|
+
|
|
2834
|
+
# Yield "Sending back result to model" status
|
|
2835
|
+
if tool_call_records:
|
|
2836
|
+
sending_status = self._create_tool_status_response_with_accumulator( # noqa: E501
|
|
2837
|
+
content_accumulator,
|
|
2838
|
+
"\n------\n\nSending back result to model\n\n",
|
|
2839
|
+
"tool_sending",
|
|
2840
|
+
step_token_usage,
|
|
2841
|
+
)
|
|
2842
|
+
yield sending_status
|
|
2843
|
+
|
|
2844
|
+
# Record final message only if we have content AND no tool
|
|
2845
|
+
# calls. If there are tool calls, _record_tool_calling
|
|
2846
|
+
# will handle message recording.
|
|
2847
|
+
final_content = content_accumulator.get_full_content()
|
|
2848
|
+
if final_content.strip() and not accumulated_tool_calls:
|
|
2849
|
+
final_message = BaseMessage(
|
|
2850
|
+
role_name=self.role_name,
|
|
2851
|
+
role_type=self.role_type,
|
|
2852
|
+
meta_dict={},
|
|
2853
|
+
content=final_content,
|
|
2854
|
+
)
|
|
2855
|
+
|
|
2856
|
+
if response_format:
|
|
2857
|
+
self._try_format_message(
|
|
2858
|
+
final_message, response_format
|
|
2859
|
+
)
|
|
2860
|
+
|
|
2861
|
+
self.record_message(final_message)
|
|
2862
|
+
break
|
|
2863
|
+
|
|
2864
|
+
return stream_completed, tool_calls_complete
|
|
2865
|
+
|
|
2866
|
+
def _accumulate_tool_calls(
|
|
2867
|
+
self,
|
|
2868
|
+
tool_call_deltas: List[Any],
|
|
2869
|
+
accumulated_tool_calls: Dict[str, Any],
|
|
2870
|
+
) -> bool:
|
|
2871
|
+
r"""Accumulate tool call chunks and return True when
|
|
2872
|
+
any tool call is complete.
|
|
2873
|
+
|
|
2874
|
+
Args:
|
|
2875
|
+
tool_call_deltas (List[Any]): List of tool call deltas.
|
|
2876
|
+
accumulated_tool_calls (Dict[str, Any]): Dictionary of accumulated
|
|
2877
|
+
tool calls.
|
|
2878
|
+
|
|
2879
|
+
Returns:
|
|
2880
|
+
bool: True if any tool call is complete, False otherwise.
|
|
2881
|
+
"""
|
|
2882
|
+
|
|
2883
|
+
for delta_tool_call in tool_call_deltas:
|
|
2884
|
+
index = delta_tool_call.index
|
|
2885
|
+
tool_call_id = getattr(delta_tool_call, 'id', None)
|
|
2886
|
+
|
|
2887
|
+
# Initialize tool call entry if not exists
|
|
2888
|
+
if index not in accumulated_tool_calls:
|
|
2889
|
+
accumulated_tool_calls[index] = {
|
|
2890
|
+
'id': '',
|
|
2891
|
+
'type': 'function',
|
|
2892
|
+
'function': {'name': '', 'arguments': ''},
|
|
2893
|
+
'complete': False,
|
|
2894
|
+
}
|
|
2895
|
+
|
|
2896
|
+
tool_call_entry = accumulated_tool_calls[index]
|
|
2897
|
+
|
|
2898
|
+
# Accumulate tool call data
|
|
2899
|
+
if tool_call_id:
|
|
2900
|
+
tool_call_entry['id'] = (
|
|
2901
|
+
tool_call_id # Set full ID, don't append
|
|
2902
|
+
)
|
|
2903
|
+
|
|
2904
|
+
if (
|
|
2905
|
+
hasattr(delta_tool_call, 'function')
|
|
2906
|
+
and delta_tool_call.function
|
|
2907
|
+
):
|
|
2908
|
+
if delta_tool_call.function.name:
|
|
2909
|
+
tool_call_entry['function']['name'] += (
|
|
2910
|
+
delta_tool_call.function.name
|
|
2911
|
+
) # Append incremental name
|
|
2912
|
+
if delta_tool_call.function.arguments:
|
|
2913
|
+
tool_call_entry['function']['arguments'] += (
|
|
2914
|
+
delta_tool_call.function.arguments
|
|
2915
|
+
)
|
|
2916
|
+
|
|
2917
|
+
# Check if any tool calls are complete
|
|
2918
|
+
any_complete = False
|
|
2919
|
+
for _index, tool_call_entry in accumulated_tool_calls.items():
|
|
2920
|
+
if (
|
|
2921
|
+
tool_call_entry['id']
|
|
2922
|
+
and tool_call_entry['function']['name']
|
|
2923
|
+
and tool_call_entry['function']['arguments']
|
|
2924
|
+
and tool_call_entry['function']['name'] in self._internal_tools
|
|
2925
|
+
):
|
|
2926
|
+
try:
|
|
2927
|
+
# Try to parse arguments to check completeness
|
|
2928
|
+
json.loads(tool_call_entry['function']['arguments'])
|
|
2929
|
+
tool_call_entry['complete'] = True
|
|
2930
|
+
any_complete = True
|
|
2931
|
+
except json.JSONDecodeError:
|
|
2932
|
+
# Arguments not complete yet
|
|
2933
|
+
tool_call_entry['complete'] = False
|
|
2934
|
+
|
|
2935
|
+
return any_complete
|
|
2936
|
+
|
|
2937
|
+
def _execute_tools_sync_with_status_accumulator(
|
|
2938
|
+
self,
|
|
2939
|
+
accumulated_tool_calls: Dict[str, Any],
|
|
2940
|
+
content_accumulator: StreamContentAccumulator,
|
|
2941
|
+
step_token_usage: Dict[str, int],
|
|
2942
|
+
tool_call_records: List[ToolCallingRecord],
|
|
2943
|
+
) -> Generator[ChatAgentResponse, None, None]:
|
|
2944
|
+
r"""Execute multiple tools synchronously with
|
|
2945
|
+
proper content accumulation, using threads+queue for
|
|
2946
|
+
non-blocking status streaming."""
|
|
2947
|
+
|
|
2948
|
+
def tool_worker(tool_func, args, result_queue, tool_call_data):
|
|
2949
|
+
try:
|
|
2950
|
+
tool_call_record = self._execute_tool_from_stream_data(
|
|
2951
|
+
tool_call_data
|
|
2952
|
+
)
|
|
2953
|
+
result_queue.put(tool_call_record)
|
|
2954
|
+
except Exception as e:
|
|
2955
|
+
logger.error(f"Error in threaded tool execution: {e}")
|
|
2956
|
+
result_queue.put(None)
|
|
2957
|
+
|
|
2958
|
+
tool_calls_to_execute = []
|
|
2959
|
+
for _tool_call_index, tool_call_data in accumulated_tool_calls.items():
|
|
2960
|
+
if tool_call_data.get('complete', False):
|
|
2961
|
+
tool_calls_to_execute.append(tool_call_data)
|
|
2962
|
+
|
|
2963
|
+
# Phase 2: Execute tools in threads and yield status while waiting
|
|
2964
|
+
for tool_call_data in tool_calls_to_execute:
|
|
2965
|
+
function_name = tool_call_data['function']['name']
|
|
2966
|
+
try:
|
|
2967
|
+
args = json.loads(tool_call_data['function']['arguments'])
|
|
2968
|
+
except json.JSONDecodeError:
|
|
2969
|
+
args = tool_call_data['function']['arguments']
|
|
2970
|
+
result_queue: queue.Queue[Optional[ToolCallingRecord]] = (
|
|
2971
|
+
queue.Queue()
|
|
2972
|
+
)
|
|
2973
|
+
thread = threading.Thread(
|
|
2974
|
+
target=tool_worker,
|
|
2975
|
+
args=(
|
|
2976
|
+
self._internal_tools[function_name],
|
|
2977
|
+
args,
|
|
2978
|
+
result_queue,
|
|
2979
|
+
tool_call_data,
|
|
2980
|
+
),
|
|
2981
|
+
)
|
|
2982
|
+
thread.start()
|
|
2983
|
+
|
|
2984
|
+
status_message = (
|
|
2985
|
+
f"\nCalling function: {function_name} "
|
|
2986
|
+
f"with arguments:\n{args}\n"
|
|
2987
|
+
)
|
|
2988
|
+
status_status = self._create_tool_status_response_with_accumulator(
|
|
2989
|
+
content_accumulator,
|
|
2990
|
+
status_message,
|
|
2991
|
+
"tool_calling",
|
|
2992
|
+
step_token_usage,
|
|
2993
|
+
)
|
|
2994
|
+
yield status_status
|
|
2995
|
+
# wait for tool thread to finish with optional timeout
|
|
2996
|
+
thread.join(self.tool_execution_timeout)
|
|
2997
|
+
|
|
2998
|
+
# If timeout occurred, mark as error and continue
|
|
2999
|
+
if thread.is_alive():
|
|
3000
|
+
timeout_msg = (
|
|
3001
|
+
f"\nFunction '{function_name}' timed out after "
|
|
3002
|
+
f"{self.tool_execution_timeout} seconds.\n---------\n"
|
|
3003
|
+
)
|
|
3004
|
+
timeout_status = (
|
|
3005
|
+
self._create_tool_status_response_with_accumulator(
|
|
3006
|
+
content_accumulator,
|
|
3007
|
+
timeout_msg,
|
|
3008
|
+
"tool_timeout",
|
|
3009
|
+
step_token_usage,
|
|
3010
|
+
)
|
|
3011
|
+
)
|
|
3012
|
+
yield timeout_status
|
|
3013
|
+
logger.error(timeout_msg.strip())
|
|
3014
|
+
# Detach thread (it may still finish later). Skip recording.
|
|
3015
|
+
continue
|
|
3016
|
+
|
|
3017
|
+
# Tool finished, get result
|
|
3018
|
+
tool_call_record = result_queue.get()
|
|
3019
|
+
if tool_call_record:
|
|
3020
|
+
tool_call_records.append(tool_call_record)
|
|
3021
|
+
raw_result = tool_call_record.result
|
|
3022
|
+
result_str = str(raw_result)
|
|
3023
|
+
status_message = (
|
|
3024
|
+
f"\nFunction output: {result_str}\n---------\n"
|
|
3025
|
+
)
|
|
3026
|
+
output_status = (
|
|
3027
|
+
self._create_tool_status_response_with_accumulator(
|
|
3028
|
+
content_accumulator,
|
|
3029
|
+
status_message,
|
|
3030
|
+
"tool_output",
|
|
3031
|
+
step_token_usage,
|
|
3032
|
+
[tool_call_record],
|
|
3033
|
+
)
|
|
3034
|
+
)
|
|
3035
|
+
yield output_status
|
|
3036
|
+
else:
|
|
3037
|
+
# Error already logged
|
|
3038
|
+
continue
|
|
3039
|
+
|
|
3040
|
+
def _execute_tool_from_stream_data(
|
|
3041
|
+
self, tool_call_data: Dict[str, Any]
|
|
3042
|
+
) -> Optional[ToolCallingRecord]:
|
|
3043
|
+
r"""Execute a tool from accumulated stream data."""
|
|
3044
|
+
|
|
3045
|
+
try:
|
|
3046
|
+
function_name = tool_call_data['function']['name']
|
|
3047
|
+
args = json.loads(tool_call_data['function']['arguments'])
|
|
3048
|
+
tool_call_id = tool_call_data['id']
|
|
3049
|
+
|
|
3050
|
+
if function_name in self._internal_tools:
|
|
3051
|
+
tool = self._internal_tools[function_name]
|
|
3052
|
+
try:
|
|
3053
|
+
result = tool(**args)
|
|
3054
|
+
|
|
3055
|
+
# Only record the tool response message, not the assistant
|
|
3056
|
+
# message assistant message with tool_calls was already
|
|
3057
|
+
# recorded in _record_assistant_tool_calls_message
|
|
3058
|
+
func_msg = FunctionCallingMessage(
|
|
3059
|
+
role_name=self.role_name,
|
|
3060
|
+
role_type=self.role_type,
|
|
3061
|
+
meta_dict=None,
|
|
3062
|
+
content="",
|
|
3063
|
+
func_name=function_name,
|
|
3064
|
+
result=result,
|
|
3065
|
+
tool_call_id=tool_call_id,
|
|
3066
|
+
)
|
|
3067
|
+
|
|
3068
|
+
self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
|
|
3069
|
+
|
|
3070
|
+
return ToolCallingRecord(
|
|
3071
|
+
tool_name=function_name,
|
|
3072
|
+
args=args,
|
|
3073
|
+
result=result,
|
|
3074
|
+
tool_call_id=tool_call_id,
|
|
3075
|
+
)
|
|
3076
|
+
|
|
3077
|
+
except Exception as e:
|
|
3078
|
+
error_msg = (
|
|
3079
|
+
f"Error executing tool '{function_name}': {e!s}"
|
|
3080
|
+
)
|
|
3081
|
+
result = {"error": error_msg}
|
|
3082
|
+
logging.warning(error_msg)
|
|
3083
|
+
|
|
3084
|
+
# Record error response
|
|
3085
|
+
func_msg = FunctionCallingMessage(
|
|
3086
|
+
role_name=self.role_name,
|
|
3087
|
+
role_type=self.role_type,
|
|
3088
|
+
meta_dict=None,
|
|
3089
|
+
content="",
|
|
3090
|
+
func_name=function_name,
|
|
3091
|
+
result=result,
|
|
3092
|
+
tool_call_id=tool_call_id,
|
|
3093
|
+
)
|
|
3094
|
+
|
|
3095
|
+
self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
|
|
3096
|
+
|
|
3097
|
+
return ToolCallingRecord(
|
|
3098
|
+
tool_name=function_name,
|
|
3099
|
+
args=args,
|
|
3100
|
+
result=result,
|
|
3101
|
+
tool_call_id=tool_call_id,
|
|
3102
|
+
)
|
|
3103
|
+
else:
|
|
3104
|
+
logger.warning(
|
|
3105
|
+
f"Tool '{function_name}' not found in internal tools"
|
|
3106
|
+
)
|
|
3107
|
+
return None
|
|
3108
|
+
|
|
3109
|
+
except Exception as e:
|
|
3110
|
+
logger.error(f"Error processing tool call: {e}")
|
|
3111
|
+
return None
|
|
3112
|
+
|
|
3113
|
+
async def _aexecute_tool_from_stream_data(
|
|
3114
|
+
self, tool_call_data: Dict[str, Any]
|
|
3115
|
+
) -> Optional[ToolCallingRecord]:
|
|
3116
|
+
r"""Async execute a tool from accumulated stream data."""
|
|
3117
|
+
|
|
3118
|
+
try:
|
|
3119
|
+
function_name = tool_call_data['function']['name']
|
|
3120
|
+
args = json.loads(tool_call_data['function']['arguments'])
|
|
3121
|
+
tool_call_id = tool_call_data['id']
|
|
3122
|
+
|
|
3123
|
+
if function_name in self._internal_tools:
|
|
3124
|
+
tool = self._internal_tools[function_name]
|
|
3125
|
+
try:
|
|
3126
|
+
result = await tool.async_call(**args)
|
|
3127
|
+
|
|
3128
|
+
# Only record the tool response message, not the assistant
|
|
3129
|
+
# message assistant message with tool_calls was already
|
|
3130
|
+
# recorded in _record_assistant_tool_calls_message
|
|
3131
|
+
func_msg = FunctionCallingMessage(
|
|
3132
|
+
role_name=self.role_name,
|
|
3133
|
+
role_type=self.role_type,
|
|
3134
|
+
meta_dict=None,
|
|
3135
|
+
content="",
|
|
3136
|
+
func_name=function_name,
|
|
3137
|
+
result=result,
|
|
3138
|
+
tool_call_id=tool_call_id,
|
|
3139
|
+
)
|
|
3140
|
+
|
|
3141
|
+
self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
|
|
3142
|
+
|
|
3143
|
+
return ToolCallingRecord(
|
|
3144
|
+
tool_name=function_name,
|
|
3145
|
+
args=args,
|
|
3146
|
+
result=result,
|
|
3147
|
+
tool_call_id=tool_call_id,
|
|
3148
|
+
)
|
|
3149
|
+
|
|
3150
|
+
except Exception as e:
|
|
3151
|
+
error_msg = (
|
|
3152
|
+
f"Error executing async tool '{function_name}': {e!s}"
|
|
3153
|
+
)
|
|
3154
|
+
result = {"error": error_msg}
|
|
3155
|
+
logging.warning(error_msg)
|
|
3156
|
+
|
|
3157
|
+
# Record error response
|
|
3158
|
+
func_msg = FunctionCallingMessage(
|
|
3159
|
+
role_name=self.role_name,
|
|
3160
|
+
role_type=self.role_type,
|
|
3161
|
+
meta_dict=None,
|
|
3162
|
+
content="",
|
|
3163
|
+
func_name=function_name,
|
|
3164
|
+
result=result,
|
|
3165
|
+
tool_call_id=tool_call_id,
|
|
3166
|
+
)
|
|
3167
|
+
|
|
3168
|
+
self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
|
|
3169
|
+
|
|
3170
|
+
return ToolCallingRecord(
|
|
3171
|
+
tool_name=function_name,
|
|
3172
|
+
args=args,
|
|
3173
|
+
result=result,
|
|
3174
|
+
tool_call_id=tool_call_id,
|
|
3175
|
+
)
|
|
3176
|
+
else:
|
|
3177
|
+
logger.warning(
|
|
3178
|
+
f"Tool '{function_name}' not found in internal tools"
|
|
3179
|
+
)
|
|
3180
|
+
return None
|
|
3181
|
+
|
|
3182
|
+
except Exception as e:
|
|
3183
|
+
logger.error(f"Error processing async tool call: {e}")
|
|
3184
|
+
return None
|
|
3185
|
+
|
|
3186
|
+
def _create_error_response(
|
|
3187
|
+
self, error_message: str, tool_call_records: List[ToolCallingRecord]
|
|
3188
|
+
) -> ChatAgentResponse:
|
|
3189
|
+
r"""Create an error response for streaming."""
|
|
3190
|
+
|
|
3191
|
+
error_msg = BaseMessage(
|
|
3192
|
+
role_name=self.role_name,
|
|
3193
|
+
role_type=self.role_type,
|
|
3194
|
+
meta_dict={},
|
|
3195
|
+
content=f"Error: {error_message}",
|
|
3196
|
+
)
|
|
3197
|
+
|
|
3198
|
+
return ChatAgentResponse(
|
|
3199
|
+
msgs=[error_msg],
|
|
3200
|
+
terminated=True,
|
|
3201
|
+
info={
|
|
3202
|
+
"error": error_message,
|
|
3203
|
+
"tool_calls": tool_call_records,
|
|
3204
|
+
"streaming": True,
|
|
3205
|
+
},
|
|
3206
|
+
)
|
|
3207
|
+
|
|
3208
|
+
async def _astream(
|
|
3209
|
+
self,
|
|
3210
|
+
input_message: Union[BaseMessage, str],
|
|
3211
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
3212
|
+
) -> AsyncGenerator[ChatAgentResponse, None]:
|
|
3213
|
+
r"""Asynchronous version of stream method."""
|
|
3214
|
+
|
|
3215
|
+
# Convert input message to BaseMessage if necessary
|
|
3216
|
+
if isinstance(input_message, str):
|
|
3217
|
+
input_message = BaseMessage.make_user_message(
|
|
3218
|
+
role_name="User", content=input_message
|
|
3219
|
+
)
|
|
3220
|
+
|
|
3221
|
+
# Add user input to memory
|
|
3222
|
+
self.update_memory(input_message, OpenAIBackendRole.USER)
|
|
3223
|
+
|
|
3224
|
+
# Get context for streaming
|
|
3225
|
+
try:
|
|
3226
|
+
openai_messages, num_tokens = self.memory.get_context()
|
|
3227
|
+
except RuntimeError as e:
|
|
3228
|
+
yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
|
|
3229
|
+
return
|
|
3230
|
+
|
|
3231
|
+
# Start async streaming response
|
|
3232
|
+
async for response in self._astream_response(
|
|
3233
|
+
openai_messages, num_tokens, response_format
|
|
3234
|
+
):
|
|
3235
|
+
yield response
|
|
3236
|
+
|
|
3237
|
+
async def _astream_response(
|
|
3238
|
+
self,
|
|
3239
|
+
openai_messages: List[OpenAIMessage],
|
|
3240
|
+
num_tokens: int,
|
|
3241
|
+
response_format: Optional[Type[BaseModel]] = None,
|
|
3242
|
+
) -> AsyncGenerator[ChatAgentResponse, None]:
|
|
3243
|
+
r"""Async method to handle streaming responses with tool calls."""
|
|
3244
|
+
|
|
3245
|
+
tool_call_records: List[ToolCallingRecord] = []
|
|
3246
|
+
accumulated_tool_calls: Dict[str, Any] = {}
|
|
3247
|
+
step_token_usage = self._create_token_usage_tracker()
|
|
3248
|
+
|
|
3249
|
+
# Create content accumulator for proper content management
|
|
3250
|
+
content_accumulator = StreamContentAccumulator()
|
|
3251
|
+
iteration_count = 0
|
|
3252
|
+
while True:
|
|
3253
|
+
# Check termination condition
|
|
3254
|
+
if self.stop_event and self.stop_event.is_set():
|
|
3255
|
+
yield self._step_terminate(
|
|
3256
|
+
num_tokens, tool_call_records, "termination_triggered"
|
|
3257
|
+
)
|
|
3258
|
+
return
|
|
3259
|
+
|
|
3260
|
+
# Get async streaming response from model
|
|
3261
|
+
try:
|
|
3262
|
+
response = await self.model_backend.arun(
|
|
3263
|
+
openai_messages,
|
|
3264
|
+
response_format,
|
|
3265
|
+
self._get_full_tool_schemas() or None,
|
|
3266
|
+
)
|
|
3267
|
+
iteration_count += 1
|
|
3268
|
+
except Exception as exc:
|
|
3269
|
+
logger.error(
|
|
3270
|
+
f"Error in async streaming model response: {exc}",
|
|
3271
|
+
exc_info=exc,
|
|
3272
|
+
)
|
|
3273
|
+
yield self._create_error_response(str(exc), tool_call_records)
|
|
3274
|
+
return
|
|
3275
|
+
|
|
3276
|
+
# Handle streaming response
|
|
3277
|
+
if isinstance(response, AsyncStream):
|
|
3278
|
+
stream_completed = False
|
|
3279
|
+
tool_calls_complete = False
|
|
3280
|
+
|
|
3281
|
+
# Process chunks and forward them
|
|
3282
|
+
async for (
|
|
3283
|
+
item
|
|
3284
|
+
) in self._aprocess_stream_chunks_with_accumulator(
|
|
3285
|
+
response,
|
|
3286
|
+
content_accumulator,
|
|
3287
|
+
accumulated_tool_calls,
|
|
3288
|
+
tool_call_records,
|
|
3289
|
+
step_token_usage,
|
|
3290
|
+
response_format,
|
|
3291
|
+
):
|
|
3292
|
+
if isinstance(item, tuple):
|
|
3293
|
+
# This is the final return value (stream_completed,
|
|
3294
|
+
# tool_calls_complete)
|
|
3295
|
+
stream_completed, tool_calls_complete = item
|
|
3296
|
+
break
|
|
3297
|
+
else:
|
|
3298
|
+
# This is a ChatAgentResponse to be yielded
|
|
3299
|
+
yield item
|
|
3300
|
+
|
|
3301
|
+
if tool_calls_complete:
|
|
3302
|
+
# Clear completed tool calls
|
|
3303
|
+
accumulated_tool_calls.clear()
|
|
3304
|
+
|
|
3305
|
+
# If we executed tools and not in
|
|
3306
|
+
# single iteration mode, continue
|
|
3307
|
+
if tool_call_records and (
|
|
3308
|
+
self.max_iteration is None
|
|
3309
|
+
or iteration_count < self.max_iteration
|
|
3310
|
+
):
|
|
3311
|
+
# Update messages with tool results for next iteration
|
|
3312
|
+
try:
|
|
3313
|
+
+                            openai_messages, num_tokens = (
+                                self.memory.get_context()
+                            )
+                        except RuntimeError as e:
+                            yield self._step_terminate(
+                                e.args[1],
+                                tool_call_records,
+                                "max_tokens_exceeded",
+                            )
+                            return
+                        # Reset streaming content for next iteration
+                        content_accumulator.reset_streaming_content()
+                        continue
+                    else:
+                        break
+                else:
+                    # Stream completed without tool calls
+                    accumulated_tool_calls.clear()
+                    break
+            elif hasattr(response, '__aenter__') and hasattr(
+                response, '__aexit__'
+            ):
+                # Handle structured output stream
+                # (AsyncChatCompletionStreamManager)
+                async with response as stream:
+                    parsed_object = None
+
+                    async for event in stream:
+                        if event.type == "content.delta":
+                            if getattr(event, "delta", None):
+                                # Use accumulator for proper content management
+                                partial_response = self._create_streaming_response_with_accumulator(  # noqa: E501
+                                    content_accumulator,
+                                    getattr(event, "delta", ""),
+                                    step_token_usage,
+                                    tool_call_records=tool_call_records.copy(),
+                                )
+                                yield partial_response
+
+                        elif event.type == "content.done":
+                            parsed_object = getattr(event, "parsed", None)
+                            break
+                        elif event.type == "error":
+                            logger.error(
+                                f"Error in async structured stream: "
+                                f"{getattr(event, 'error', '')}"
+                            )
+                            yield self._create_error_response(
+                                str(getattr(event, 'error', '')),
+                                tool_call_records,
+                            )
+                            return
+
+                    # Get final completion and record final message
+                    try:
+                        final_completion = await stream.get_final_completion()
+                        final_content = (
+                            final_completion.choices[0].message.content or ""
+                        )
+
+                        final_message = BaseMessage(
+                            role_name=self.role_name,
+                            role_type=self.role_type,
+                            meta_dict={},
+                            content=final_content,
+                            parsed=cast(
+                                "BaseModel | dict[str, Any] | None",
+                                parsed_object,
+                            ),  # type: ignore[arg-type]
+                        )
+
+                        self.record_message(final_message)
+
+                        # Create final response
+                        final_response = ChatAgentResponse(
+                            msgs=[final_message],
+                            terminated=False,
+                            info={
+                                "id": final_completion.id or "",
+                                "usage": safe_model_dump(
+                                    final_completion.usage
+                                )
+                                if final_completion.usage
+                                else {},
+                                "finish_reasons": [
+                                    choice.finish_reason or "stop"
+                                    for choice in final_completion.choices
+                                ],
+                                "num_tokens": self._get_token_count(
+                                    final_content
+                                ),
+                                "tool_calls": tool_call_records,
+                                "external_tool_requests": None,
+                                "streaming": False,
+                                "partial": False,
+                            },
+                        )
+                        yield final_response
+                        break
+
+                    except Exception as e:
+                        logger.error(
+                            f"Error getting async final completion: {e}"
+                        )
+                        yield self._create_error_response(
+                            str(e), tool_call_records
+                        )
+                        return
+            else:
+                # Handle non-streaming response (fallback)
+                model_response = self._handle_batch_response(response)
+                yield self._convert_to_chatagent_response(
+                    model_response,
+                    tool_call_records,
+                    num_tokens,
+                    None,
+                    model_response.usage_dict.get("prompt_tokens", 0),
+                    model_response.usage_dict.get("completion_tokens", 0),
+                    model_response.usage_dict.get("total_tokens", 0),
+                )
+                accumulated_tool_calls.clear()
+                break
+
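The streaming path above yields partial ChatAgentResponse objects that carry the full accumulated text so far and mark themselves via info["partial"] / info["streaming"], while the last yield of a step is a complete response with usage filled in. A minimal consumption sketch (illustration only; it assumes the agent's async streaming step exposes these responses as an async generator, and the function and variable names below are not part of the diff):

from typing import Any, AsyncGenerator

async def drain_stream(stream: AsyncGenerator[Any, None]) -> None:
    # Partial responses are flagged with info["partial"]; the final one is not.
    async for response in stream:
        info = response.info
        if info.get("partial"):
            print("partial update:", info.get("tool_status", "content"))
        else:
            print("final answer:", response.msgs[0].content)
            print("usage:", info.get("usage"))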
+    def _record_assistant_tool_calls_message(
+        self, accumulated_tool_calls: Dict[str, Any], content: str = ""
+    ) -> None:
+        r"""Record the assistant message that contains tool calls.
+
+        This method creates and records an assistant message that includes
+        the tool calls information, which is required by OpenAI's API format.
+        """
+        # Create a BaseMessage with tool_calls information in meta_dict
+        # This will be converted to the proper OpenAI format when needed
+        tool_calls_list = []
+        for tool_call_data in accumulated_tool_calls.values():
+            if tool_call_data.get('complete', False):
+                tool_call_dict = {
+                    "id": tool_call_data["id"],
+                    "type": "function",
+                    "function": {
+                        "name": tool_call_data["function"]["name"],
+                        "arguments": tool_call_data["function"]["arguments"],
+                    },
+                }
+                tool_calls_list.append(tool_call_dict)
+
+        # Create an assistant message with tool calls
+        assist_msg = BaseMessage(
+            role_name=self.role_name,
+            role_type=self.role_type,
+            meta_dict={"tool_calls": tool_calls_list},
+            content=content or "",
+        )
+
+        # Record this assistant message
+        self.update_memory(assist_msg, OpenAIBackendRole.ASSISTANT)
+
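_record_assistant_tool_calls_message only converts entries whose 'complete' flag is set. A sketch of the accumulated_tool_calls shape it expects, inferred from the keys read above (the index key and example values are assumptions):

accumulated_tool_calls = {
    "0": {
        "id": "call_abc123",  # assumed example id
        "complete": True,  # set once the arguments finished streaming
        "function": {
            "name": "search_tool",  # assumed example tool name
            "arguments": '{"query": "camel-ai"}',
        },
    },
}
# Each complete entry becomes an OpenAI-style
# {"id": ..., "type": "function", "function": {...}} dict stored under
# meta_dict["tool_calls"] of the recorded assistant message.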
+    async def _aprocess_stream_chunks_with_accumulator(
+        self,
+        stream: AsyncStream[ChatCompletionChunk],
+        content_accumulator: StreamContentAccumulator,
+        accumulated_tool_calls: Dict[str, Any],
+        tool_call_records: List[ToolCallingRecord],
+        step_token_usage: Dict[str, int],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> AsyncGenerator[Union[ChatAgentResponse, Tuple[bool, bool]], None]:
+        r"""Async version of process streaming chunks with
+        content accumulator."""
+
+        tool_calls_complete = False
+        stream_completed = False
+
+        async for chunk in stream:
+            # Update token usage if available
+            if chunk.usage:
+                self._update_token_usage_tracker(
+                    step_token_usage, safe_model_dump(chunk.usage)
+                )
+
+            # Process chunk delta
+            if chunk.choices and len(chunk.choices) > 0:
+                choice = chunk.choices[0]
+                delta = choice.delta
+
+                # Handle content streaming
+                if delta.content:
+                    # Use accumulator for proper content management
+                    partial_response = (
+                        self._create_streaming_response_with_accumulator(
+                            content_accumulator,
+                            delta.content,
+                            step_token_usage,
+                            getattr(chunk, 'id', ''),
+                            tool_call_records.copy(),
+                        )
+                    )
+                    yield partial_response
+
+                # Handle tool calls streaming
+                if delta.tool_calls:
+                    tool_calls_complete = self._accumulate_tool_calls(
+                        delta.tool_calls, accumulated_tool_calls
+                    )
+
+                # Check if stream is complete
+                if choice.finish_reason:
+                    stream_completed = True
+
+                    # If we have complete tool calls, execute them with
+                    # async status updates
+                    if accumulated_tool_calls:
+                        # Record assistant message with
+                        # tool calls first
+                        self._record_assistant_tool_calls_message(
+                            accumulated_tool_calls,
+                            content_accumulator.get_full_content(),
+                        )
+
+                        # Execute tools asynchronously with real-time
+                        # status updates
+                        async for (
+                            status_response
+                        ) in self._execute_tools_async_with_status_accumulator(
+                            accumulated_tool_calls,
+                            content_accumulator,
+                            step_token_usage,
+                            tool_call_records,
+                        ):
+                            yield status_response
+
+                        # Yield "Sending back result to model" status
+                        if tool_call_records:
+                            sending_status = self._create_tool_status_response_with_accumulator(  # noqa: E501
+                                content_accumulator,
+                                "\n------\n\nSending back result to model\n\n",
+                                "tool_sending",
+                                step_token_usage,
+                            )
+                            yield sending_status
+
+                    # Record final message only if we have content AND no tool
+                    # calls. If there are tool calls, _record_tool_calling
+                    # will handle message recording.
+                    final_content = content_accumulator.get_full_content()
+                    if final_content.strip() and not accumulated_tool_calls:
+                        final_message = BaseMessage(
+                            role_name=self.role_name,
+                            role_type=self.role_type,
+                            meta_dict={},
+                            content=final_content,
+                        )
+
+                        if response_format:
+                            self._try_format_message(
+                                final_message, response_format
+                            )
+
+                        self.record_message(final_message)
+                    break
+
+        # Yield the final status as a tuple
+        yield (stream_completed, tool_calls_complete)
+
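Since this generator yields both partial ChatAgentResponse objects and, at the very end, a (stream_completed, tool_calls_complete) tuple, a caller has to separate the two. A small sketch of that dispatch (the helper name is illustrative; the tuple semantics come from the code above):

async def split_stream(chunk_generator):
    # Collect partial responses and the final status flags from the generator.
    responses = []
    completed, tools_done = False, False
    async for item in chunk_generator:
        if isinstance(item, tuple):
            # Sentinel yielded once the underlying stream ends.
            completed, tools_done = item
        else:
            # A partial ChatAgentResponse to forward to the caller.
            responses.append(item)
    return responses, completed, tools_done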
+    async def _execute_tools_async_with_status_accumulator(
+        self,
+        accumulated_tool_calls: Dict[str, Any],
+        content_accumulator: StreamContentAccumulator,
+        step_token_usage: Dict[str, int],
+        tool_call_records: List[ToolCallingRecord],
+    ) -> AsyncGenerator[ChatAgentResponse, None]:
+        r"""Execute multiple tools asynchronously with
+        proper content accumulation."""
+        import asyncio
+
+        # Phase 1: Start all tools and yield "Calling function"
+        # statuses immediately
+        tool_tasks = []
+        for _tool_call_index, tool_call_data in accumulated_tool_calls.items():
+            if tool_call_data.get('complete', False):
+                function_name = tool_call_data['function']['name']
+                try:
+                    args = json.loads(tool_call_data['function']['arguments'])
+                except json.JSONDecodeError:
+                    args = tool_call_data['function']['arguments']
+
+                status_message = (
+                    f"\nCalling function: {function_name} "
+                    f"with arguments:\n{args}\n"
+                )
+
+                # Immediately yield "Calling function" status
+                calling_status = (
+                    self._create_tool_status_response_with_accumulator(
+                        content_accumulator,
+                        status_message,
+                        "tool_calling",
+                        step_token_usage,
+                    )
+                )
+                yield calling_status
+
+                # Start tool execution asynchronously (non-blocking)
+                if self.tool_execution_timeout is not None:
+                    task = asyncio.create_task(
+                        asyncio.wait_for(
+                            self._aexecute_tool_from_stream_data(
+                                tool_call_data
+                            ),
+                            timeout=self.tool_execution_timeout,
+                        )
+                    )
+                else:
+                    task = asyncio.create_task(
+                        self._aexecute_tool_from_stream_data(tool_call_data)
+                    )
+                tool_tasks.append((task, tool_call_data))
+
+        # Phase 2: Wait for tools to complete and yield results as they finish
+        if tool_tasks:
+            # Use asyncio.as_completed for true async processing
+            for completed_task in asyncio.as_completed(
+                [task for task, _ in tool_tasks]
+            ):
+                try:
+                    tool_call_record = await completed_task
+                    if tool_call_record:
+                        # Add to the shared tool_call_records list
+                        tool_call_records.append(tool_call_record)
+
+                        # Create output status message
+                        raw_result = tool_call_record.result
+                        result_str = str(raw_result)
+                        status_message = (
+                            f"\nFunction output: {result_str}\n---------\n"
+                        )
+
+                        # Yield "Function output" status as soon as this
+                        # tool completes
+                        output_status = (
+                            self._create_tool_status_response_with_accumulator(
+                                content_accumulator,
+                                status_message,
+                                "tool_output",
+                                step_token_usage,
+                                [tool_call_record],
+                            )
+                        )
+                        yield output_status
+
+                except Exception as e:
+                    if isinstance(e, asyncio.TimeoutError):
+                        timeout_msg = (
+                            f"\nFunction timed out after "
+                            f"{self.tool_execution_timeout} seconds.\n"
+                            f"---------\n"
+                        )
+                        timeout_status = (
+                            self._create_tool_status_response_with_accumulator(
+                                content_accumulator,
+                                timeout_msg,
+                                "tool_timeout",
+                                step_token_usage,
+                            )
+                        )
+                        yield timeout_status
+                        logger.error("Async tool execution timeout")
+                    else:
+                        logger.error(f"Error in async tool execution: {e}")
+                    continue
+
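The executor above follows a two-phase pattern: create one asyncio task per tool call (optionally wrapped in asyncio.wait_for to enforce tool_execution_timeout), then consume results with asyncio.as_completed so each "Function output" status is emitted as soon as that tool finishes, regardless of start order. A self-contained sketch of the same pattern with stand-in coroutines (names and delays are assumptions):

import asyncio

async def slow_tool(name: str, delay: float) -> str:
    # Stand-in for _aexecute_tool_from_stream_data.
    await asyncio.sleep(delay)
    return f"{name} done"

async def main() -> None:
    # Phase 1: start tasks, each bounded by a timeout.
    tasks = [
        asyncio.create_task(asyncio.wait_for(slow_tool("a", 0.2), timeout=1.0)),
        asyncio.create_task(asyncio.wait_for(slow_tool("b", 0.1), timeout=1.0)),
    ]
    # Phase 2: results arrive in completion order, not submission order.
    for finished in asyncio.as_completed(tasks):
        try:
            print(await finished)
        except asyncio.TimeoutError:
            print("a tool timed out")

asyncio.run(main())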
+    def _create_tool_status_response_with_accumulator(
+        self,
+        accumulator: StreamContentAccumulator,
+        status_message: str,
+        status_type: str,
+        step_token_usage: Dict[str, int],
+        tool_calls: Optional[List[ToolCallingRecord]] = None,
+    ) -> ChatAgentResponse:
+        r"""Create a tool status response using content accumulator."""
+
+        # Add this status message to accumulator and get full content
+        accumulator.add_tool_status(status_message)
+        full_content = accumulator.get_full_content()
+
+        message = BaseMessage(
+            role_name=self.role_name,
+            role_type=self.role_type,
+            meta_dict={},
+            content=full_content,
+        )
+
+        return ChatAgentResponse(
+            msgs=[message],
+            terminated=False,
+            info={
+                "id": "",
+                "usage": step_token_usage.copy(),
+                "finish_reasons": [status_type],
+                "num_tokens": self._get_token_count(full_content),
+                "tool_calls": tool_calls or [],
+                "external_tool_requests": None,
+                "streaming": True,
+                "tool_status": status_type,
+                "partial": True,
+            },
+        )
+
+    def _create_streaming_response_with_accumulator(
+        self,
+        accumulator: StreamContentAccumulator,
+        new_content: str,
+        step_token_usage: Dict[str, int],
+        response_id: str = "",
+        tool_call_records: Optional[List[ToolCallingRecord]] = None,
+    ) -> ChatAgentResponse:
+        r"""Create a streaming response using content accumulator."""
+
+        # Add new content to accumulator and get full content
+        accumulator.add_streaming_content(new_content)
+        full_content = accumulator.get_full_content()
+
+        message = BaseMessage(
+            role_name=self.role_name,
+            role_type=self.role_type,
+            meta_dict={},
+            content=full_content,
+        )
+
+        return ChatAgentResponse(
+            msgs=[message],
+            terminated=False,
+            info={
+                "id": response_id,
+                "usage": step_token_usage.copy(),
+                "finish_reasons": ["streaming"],
+                "num_tokens": self._get_token_count(full_content),
+                "tool_calls": tool_call_records or [],
+                "external_tool_requests": None,
+                "streaming": True,
+                "partial": True,
+            },
+        )
+
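Both helpers above return partial responses whose msgs[0].content is the full accumulated text (streamed deltas plus tool status lines), not only the latest delta. A caller that wants just the newly added characters therefore has to diff against its previous snapshot; a small sketch (the function name and print target are illustrative):

def print_incremental(previous: str, response) -> str:
    # Each partial ChatAgentResponse carries the complete text so far.
    current = response.msgs[0].content
    print(current[len(previous):], end="", flush=True)
    return current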
     def get_usage_dict(
         self, output_messages: List[BaseMessage], prompt_tokens: int
     ) -> Dict[str, int]:
@@ -2478,6 +3825,7 @@ class ChatAgent(BaseAgent):
             ),
             max_iteration=self.max_iteration,
             stop_event=self.stop_event,
+            tool_execution_timeout=self.tool_execution_timeout,
             pause_event=self.pause_event,
         )
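This hunk forwards the new tool_execution_timeout setting when the agent is re-created, and the async executor above enforces it through asyncio.wait_for. A minimal construction sketch, assuming the ChatAgent constructor accepts the same keyword (the system message and timeout value are placeholders):

from camel.agents import ChatAgent

agent = ChatAgent(
    "You are a helpful assistant.",
    tool_execution_timeout=30.0,  # seconds; slower tool calls yield a timeout status
)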