camel-ai 0.2.71a4__py3-none-any.whl → 0.2.71a5__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry. It is provided for informational purposes only.

Potentially problematic release: the registry flags this version of camel-ai as possibly problematic; see the registry's advisory page for details.

@@ -16,23 +16,26 @@ from __future__ import annotations
16
16
  import asyncio
17
17
  import json
18
18
  import logging
19
+ import queue
19
20
  import textwrap
20
21
  import threading
21
22
  import time
22
23
  import uuid
23
- from collections import defaultdict
24
24
  from pathlib import Path
25
25
  from typing import (
26
26
  TYPE_CHECKING,
27
27
  Any,
28
+ AsyncGenerator,
28
29
  Callable,
29
30
  Dict,
31
+ Generator,
30
32
  List,
31
33
  Optional,
32
34
  Set,
33
35
  Tuple,
34
36
  Type,
35
37
  Union,
38
+ cast,
36
39
  )
37
40
 
38
41
  from openai import (
@@ -50,6 +53,7 @@ from camel.agents._utils import (
50
53
  safe_model_dump,
51
54
  )
52
55
  from camel.agents.base import BaseAgent
56
+ from camel.logger import get_logger
53
57
  from camel.memories import (
54
58
  AgentMemory,
55
59
  ChatHistoryMemory,
@@ -90,7 +94,7 @@ from camel.utils.tool_result import ToolResult
90
94
  if TYPE_CHECKING:
91
95
  from camel.terminators import ResponseTerminator
92
96
 
93
- logger = logging.getLogger(__name__)
97
+ logger = get_logger(__name__)
94
98
 
95
99
  # AgentOps decorator setting
96
100
  try:
@@ -109,6 +113,11 @@ if os.environ.get("LANGFUSE_ENABLED", "False").lower() == "true":
109
113
  from langfuse.decorators import observe
110
114
  except ImportError:
111
115
  from camel.utils import observe
116
+ elif os.environ.get("TRACEROOT_ENABLED", "False").lower() == "true":
117
+ try:
118
+ from traceroot import trace as observe # type: ignore[import]
119
+ except ImportError:
120
+ from camel.utils import observe
112
121
  else:
113
122
  from camel.utils import observe
114
123
 
@@ -124,6 +133,189 @@ SIMPLE_FORMAT_PROMPT = TextPrompt(
124
133
  )
125
134
 
126
135
 
136
+ class StreamContentAccumulator:
137
+ r"""Manages content accumulation across streaming responses to ensure
138
+ all responses contain complete cumulative content."""
139
+
140
+ def __init__(self):
141
+ self.base_content = "" # Content before tool calls
142
+ self.current_content = "" # Current streaming content
143
+ self.tool_status_messages = [] # Accumulated tool status messages
144
+
145
+ def set_base_content(self, content: str):
146
+ r"""Set the base content (usually empty or pre-tool content)."""
147
+ self.base_content = content
148
+
149
+ def add_streaming_content(self, new_content: str):
150
+ r"""Add new streaming content."""
151
+ self.current_content += new_content
152
+
153
+ def add_tool_status(self, status_message: str):
154
+ r"""Add a tool status message."""
155
+ self.tool_status_messages.append(status_message)
156
+
157
+ def get_full_content(self) -> str:
158
+ r"""Get the complete accumulated content."""
159
+ tool_messages = "".join(self.tool_status_messages)
160
+ return self.base_content + tool_messages + self.current_content
161
+
162
+ def get_content_with_new_status(self, status_message: str) -> str:
163
+ r"""Get content with a new status message appended."""
164
+ tool_messages = "".join([*self.tool_status_messages, status_message])
165
+ return self.base_content + tool_messages + self.current_content
166
+
167
+ def reset_streaming_content(self):
168
+ r"""Reset only the streaming content, keep base and tool status."""
169
+ self.current_content = ""
170
+
171
+
172
+ class StreamingChatAgentResponse:
173
+ r"""A wrapper that makes streaming responses compatible with
174
+ non-streaming code.
175
+
176
+ This class wraps a Generator[ChatAgentResponse, None, None] and provides
177
+ the same interface as ChatAgentResponse, so existing code doesn't need to
178
+ change.
179
+ """
180
+
181
+ def __init__(self, generator: Generator[ChatAgentResponse, None, None]):
182
+ self._generator = generator
183
+ self._current_response: Optional[ChatAgentResponse] = None
184
+ self._responses: List[ChatAgentResponse] = []
185
+ self._consumed = False
186
+
187
+ def _ensure_latest_response(self):
188
+ r"""Ensure we have the latest response by consuming the generator."""
189
+ if not self._consumed:
190
+ try:
191
+ for response in self._generator:
192
+ self._responses.append(response)
193
+ self._current_response = response
194
+ self._consumed = True
195
+ except StopIteration:
196
+ self._consumed = True
197
+
198
+ @property
199
+ def msgs(self) -> List[BaseMessage]:
200
+ r"""Get messages from the latest response."""
201
+ self._ensure_latest_response()
202
+ if self._current_response:
203
+ return self._current_response.msgs
204
+ return []
205
+
206
+ @property
207
+ def terminated(self) -> bool:
208
+ r"""Get terminated status from the latest response."""
209
+ self._ensure_latest_response()
210
+ if self._current_response:
211
+ return self._current_response.terminated
212
+ return False
213
+
214
+ @property
215
+ def info(self) -> Dict[str, Any]:
216
+ r"""Get info from the latest response."""
217
+ self._ensure_latest_response()
218
+ if self._current_response:
219
+ return self._current_response.info
220
+ return {}
221
+
222
+ @property
223
+ def msg(self):
224
+ r"""Get the single message if there's exactly one message."""
225
+ self._ensure_latest_response()
226
+ if self._current_response:
227
+ return self._current_response.msg
228
+ return None
229
+
230
+ def __iter__(self):
231
+ r"""Make this object iterable."""
232
+ if self._consumed:
233
+ # If already consumed, iterate over stored responses
234
+ return iter(self._responses)
235
+ else:
236
+ # If not consumed, consume and yield
237
+ try:
238
+ for response in self._generator:
239
+ self._responses.append(response)
240
+ self._current_response = response
241
+ yield response
242
+ self._consumed = True
243
+ except StopIteration:
244
+ self._consumed = True
245
+
246
+ def __getattr__(self, name):
247
+ r"""Forward any other attribute access to the latest response."""
248
+ self._ensure_latest_response()
249
+ if self._current_response and hasattr(self._current_response, name):
250
+ return getattr(self._current_response, name)
251
+ raise AttributeError(
252
+ f"'StreamingChatAgentResponse' object has no attribute '{name}'"
253
+ )
254
+
255
+
256
+ class AsyncStreamingChatAgentResponse:
257
+ r"""A wrapper that makes async streaming responses awaitable and
258
+ compatible with non-streaming code.
259
+
260
+ This class wraps an AsyncGenerator[ChatAgentResponse, None] and provides
261
+ both awaitable and async iterable interfaces.
262
+ """
263
+
264
+ def __init__(
265
+ self, async_generator: AsyncGenerator[ChatAgentResponse, None]
266
+ ):
267
+ self._async_generator = async_generator
268
+ self._current_response: Optional[ChatAgentResponse] = None
269
+ self._responses: List[ChatAgentResponse] = []
270
+ self._consumed = False
271
+
272
+ async def _ensure_latest_response(self):
273
+ r"""Ensure the latest response by consuming the async generator."""
274
+ if not self._consumed:
275
+ try:
276
+ async for response in self._async_generator:
277
+ self._responses.append(response)
278
+ self._current_response = response
279
+ self._consumed = True
280
+ except StopAsyncIteration:
281
+ self._consumed = True
282
+
283
+ async def _get_final_response(self) -> ChatAgentResponse:
284
+ r"""Get the final response after consuming the entire stream."""
285
+ await self._ensure_latest_response()
286
+ if self._current_response:
287
+ return self._current_response
288
+ # Return a default response if nothing was consumed
289
+ return ChatAgentResponse(msgs=[], terminated=False, info={})
290
+
291
+ def __await__(self):
292
+ r"""Make this object awaitable - returns the final response."""
293
+ return self._get_final_response().__await__()
294
+
295
+ def __aiter__(self):
296
+ r"""Make this object async iterable."""
297
+ if self._consumed:
298
+ # If already consumed, create async iterator from stored responses
299
+ async def _async_iter():
300
+ for response in self._responses:
301
+ yield response
302
+
303
+ return _async_iter()
304
+ else:
305
+ # If not consumed, consume and yield
306
+ async def _consume_and_yield():
307
+ try:
308
+ async for response in self._async_generator:
309
+ self._responses.append(response)
310
+ self._current_response = response
311
+ yield response
312
+ self._consumed = True
313
+ except StopAsyncIteration:
314
+ self._consumed = True
315
+
316
+ return _consume_and_yield()
317
+
318
+
127
319
  @track_agent(name="ChatAgent")
128
320
  class ChatAgent(BaseAgent):
129
321
  r"""Class for managing conversations of CAMEL Chat Agents.
@@ -176,6 +368,8 @@ class ChatAgent(BaseAgent):
176
368
  stop_event (Optional[threading.Event], optional): Event to signal
177
369
  termination of the agent's operation. When set, the agent will
178
370
  terminate its execution. (default: :obj:`None`)
371
+ tool_execution_timeout (Optional[float], optional): Timeout
372
+ for individual tool execution. If None, wait indefinitely.
179
373
  mask_tool_output (Optional[bool]): Whether to return a sanitized
180
374
  placeholder instead of the raw tool output. (default: :obj:`False`)
181
375
  pause_event (Optional[asyncio.Event]): Event to signal pause of the
@@ -214,6 +408,7 @@ class ChatAgent(BaseAgent):
214
408
  max_iteration: Optional[int] = None,
215
409
  agent_id: Optional[str] = None,
216
410
  stop_event: Optional[threading.Event] = None,
411
+ tool_execution_timeout: Optional[float] = None,
217
412
  mask_tool_output: bool = False,
218
413
  pause_event: Optional[asyncio.Event] = None,
219
414
  ) -> None:
@@ -290,6 +485,7 @@ class ChatAgent(BaseAgent):
290
485
  self.response_terminators = response_terminators or []
291
486
  self.max_iteration = max_iteration
292
487
  self.stop_event = stop_event
488
+ self.tool_execution_timeout = tool_execution_timeout
293
489
  self.mask_tool_output = mask_tool_output
294
490
  self._secure_result_store: Dict[str, Any] = {}
295
491
  self._pending_images: List[str] = []
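The hunks above introduce `tool_execution_timeout`, which caps how long a single tool call may run during streaming. A hedged construction sketch (the tool and system message below are placeholders, not taken from the diff):

```python
from camel.agents import ChatAgent
from camel.toolkits import FunctionTool

def add(a: int, b: int) -> int:
    r"""Add two integers."""
    return a + b

# Illustrative only: cap each tool call at 30 seconds. A call that exceeds
# the limit is surfaced as a timeout status instead of blocking the stream.
agent = ChatAgent(
    "You are a helpful assistant.",
    tools=[FunctionTool(add)],
    tool_execution_timeout=30.0,
)
```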
@@ -1106,7 +1302,7 @@ class ChatAgent(BaseAgent):
1106
1302
  self,
1107
1303
  input_message: Union[BaseMessage, str],
1108
1304
  response_format: Optional[Type[BaseModel]] = None,
1109
- ) -> ChatAgentResponse:
1305
+ ) -> Union[ChatAgentResponse, StreamingChatAgentResponse]:
1110
1306
  r"""Executes a single step in the chat session, generating a response
1111
1307
  to the input message.
1112
1308
 
@@ -1120,10 +1316,20 @@ class ChatAgent(BaseAgent):
1120
1316
  :obj:`None`)
1121
1317
 
1122
1318
  Returns:
1123
- ChatAgentResponse: Contains output messages, a termination status
1124
- flag, and session information.
1319
+ Union[ChatAgentResponse, StreamingChatAgentResponse]: If stream is
1320
+ False, returns a ChatAgentResponse. If stream is True, returns
1321
+ a StreamingChatAgentResponse that behaves like
1322
+ ChatAgentResponse but can also be iterated for
1323
+ streaming updates.
1125
1324
  """
1126
1325
 
1326
+ stream = self.model_backend.model_config_dict.get("stream", False)
1327
+
1328
+ if stream:
1329
+ # Return wrapped generator that has ChatAgentResponse interface
1330
+ generator = self._stream(input_message, response_format)
1331
+ return StreamingChatAgentResponse(generator)
1332
+
1127
1333
  # Set Langfuse session_id using agent_id for trace grouping
1128
1334
  try:
1129
1335
  from camel.utils.langfuse import set_current_agent_session_id
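As the hunk above shows, `step()` switches to the streaming path purely on the model backend's config, not on a new argument. A sketch of opting in; the factory and enum names follow CAMEL's public API but are assumptions here, not part of the diff:

```python
from camel.agents import ChatAgent
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType

# Illustrative only: "stream": True in the model config makes step() return
# a StreamingChatAgentResponse instead of a plain ChatAgentResponse.
model = ModelFactory.create(
    model_platform=ModelPlatformType.OPENAI,
    model_type=ModelType.GPT_4O_MINI,
    model_config_dict={"stream": True},
)
agent = ChatAgent("You are concise.", model=model)
```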
@@ -1222,9 +1428,8 @@ class ChatAgent(BaseAgent):
1222
1428
  ):
1223
1429
  while not self.pause_event.is_set():
1224
1430
  time.sleep(0.001)
1225
- tool_call_records.append(
1226
- self._execute_tool(tool_call_request)
1227
- )
1431
+ result = self._execute_tool(tool_call_request)
1432
+ tool_call_records.append(result)
1228
1433
 
1229
1434
  # If we found external tool calls, break the loop
1230
1435
  if external_tool_call_requests:
@@ -1271,7 +1476,7 @@ class ChatAgent(BaseAgent):
1271
1476
  self,
1272
1477
  input_message: Union[BaseMessage, str],
1273
1478
  response_format: Optional[Type[BaseModel]] = None,
1274
- ) -> ChatAgentResponse:
1479
+ ) -> Union[ChatAgentResponse, AsyncStreamingChatAgentResponse]:
1275
1480
  r"""Performs a single step in the chat session by generating a response
1276
1481
  to the input message. This agent step can call async function calls.
1277
1482
 
@@ -1287,12 +1492,38 @@ class ChatAgent(BaseAgent):
1287
1492
  used to generate a structured response by LLM. This schema
1288
1493
  helps in defining the expected output format. (default:
1289
1494
  :obj:`None`)
1290
-
1291
1495
  Returns:
1292
- ChatAgentResponse: A struct containing the output messages,
1293
- a boolean indicating whether the chat session has terminated,
1294
- and information about the chat session.
1496
+ Union[ChatAgentResponse, AsyncStreamingChatAgentResponse]:
1497
+ If stream is False, returns a ChatAgentResponse. If stream is
1498
+ True, returns an AsyncStreamingChatAgentResponse that can be
1499
+ awaited for the final result or async iterated for streaming
1500
+ updates.
1295
1501
  """
1502
+
1503
+ try:
1504
+ from camel.utils.langfuse import set_current_agent_session_id
1505
+
1506
+ set_current_agent_session_id(self.agent_id)
1507
+ except ImportError:
1508
+ pass # Langfuse not available
1509
+
1510
+ stream = self.model_backend.model_config_dict.get("stream", False)
1511
+ if stream:
1512
+ # Return wrapped async generator that is awaitable
1513
+ async_generator = self._astream(input_message, response_format)
1514
+ return AsyncStreamingChatAgentResponse(async_generator)
1515
+ else:
1516
+ return await self._astep_non_streaming_task(
1517
+ input_message, response_format
1518
+ )
1519
+
1520
+ async def _astep_non_streaming_task(
1521
+ self,
1522
+ input_message: Union[BaseMessage, str],
1523
+ response_format: Optional[Type[BaseModel]] = None,
1524
+ ) -> ChatAgentResponse:
1525
+ r"""Internal async method for non-streaming astep logic."""
1526
+
1296
1527
  try:
1297
1528
  from camel.utils.langfuse import set_current_agent_session_id
1298
1529
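The async counterpart returns an `AsyncStreamingChatAgentResponse`, which can either be awaited for the final result or async-iterated for incremental updates. A sketch, again assuming a streaming-enabled `agent`:

```python
import asyncio

# Illustrative only: `agent` is a hypothetical ChatAgent whose model config
# sets "stream": True.
async def main() -> None:
    # Awaiting the wrapper drains the stream and returns the final response.
    streaming = await agent.astep("Explain streaming in one line.")
    final = await streaming
    print(final.msgs[-1].content)

    # Or async-iterate a fresh call for incremental, cumulative updates.
    async for partial in await agent.astep("And once more."):
        print(partial.msg.content, end="\r")
    print()

asyncio.run(main())
```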
 
@@ -1731,11 +1962,12 @@ class ChatAgent(BaseAgent):
1731
1962
  f"index {self.model_backend.current_model_index}, "
1732
1963
  f"processed these messages: {sanitized_messages}"
1733
1964
  )
1734
-
1735
- if isinstance(response, ChatCompletion):
1736
- return self._handle_batch_response(response)
1737
- else:
1738
- return self._handle_stream_response(response, num_tokens)
1965
+ if not isinstance(response, ChatCompletion):
1966
+ raise TypeError(
1967
+ f"Expected response to be a `ChatCompletion` object, but "
1968
+ f"got {type(response).__name__} instead."
1969
+ )
1970
+ return self._handle_batch_response(response)
1739
1971
 
1740
1972
  async def _aget_model_response(
1741
1973
  self,
@@ -1799,11 +2031,12 @@ class ChatAgent(BaseAgent):
1799
2031
  f"index {self.model_backend.current_model_index}, "
1800
2032
  f"processed these messages: {sanitized_messages}"
1801
2033
  )
1802
-
1803
- if isinstance(response, ChatCompletion):
1804
- return self._handle_batch_response(response)
1805
- else:
1806
- return await self._ahandle_stream_response(response, num_tokens)
2034
+ if not isinstance(response, ChatCompletion):
2035
+ raise TypeError(
2036
+ f"Expected response to be a `ChatCompletion` object, but "
2037
+ f"got {type(response).__name__} instead."
2038
+ )
2039
+ return self._handle_batch_response(response)
1807
2040
 
1808
2041
  def _sanitize_messages_for_logging(self, messages):
1809
2042
  r"""Sanitize OpenAI messages for logging by replacing base64 image
@@ -2065,116 +2298,6 @@ class ChatAgent(BaseAgent):
2065
2298
  response_id=response.id or "",
2066
2299
  )
2067
2300
 
2068
- def _handle_stream_response(
2069
- self,
2070
- response: Stream[ChatCompletionChunk],
2071
- prompt_tokens: int,
2072
- ) -> ModelResponse:
2073
- r"""Process a stream response from the model and extract the necessary
2074
- information.
2075
-
2076
- Args:
2077
- response (dict): Model response.
2078
- prompt_tokens (int): Number of input prompt tokens.
2079
-
2080
- Returns:
2081
- _ModelResponse: a parsed model response.
2082
- """
2083
- content_dict: defaultdict = defaultdict(lambda: "")
2084
- finish_reasons_dict: defaultdict = defaultdict(lambda: "")
2085
- output_messages: List[BaseMessage] = []
2086
- response_id: str = ""
2087
- # All choices in one response share one role
2088
- for chunk in response:
2089
- # Some model platforms like siliconflow may return None for the
2090
- # chunk.id
2091
- response_id = chunk.id if chunk.id else str(uuid.uuid4())
2092
- self._handle_chunk(
2093
- chunk, content_dict, finish_reasons_dict, output_messages
2094
- )
2095
- finish_reasons = [
2096
- finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
2097
- ]
2098
- usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
2099
-
2100
- # TODO: Handle tool calls
2101
- return ModelResponse(
2102
- response=response,
2103
- tool_call_requests=None,
2104
- output_messages=output_messages,
2105
- finish_reasons=finish_reasons,
2106
- usage_dict=usage_dict,
2107
- response_id=response_id,
2108
- )
2109
-
2110
- async def _ahandle_stream_response(
2111
- self,
2112
- response: AsyncStream[ChatCompletionChunk],
2113
- prompt_tokens: int,
2114
- ) -> ModelResponse:
2115
- r"""Process a stream response from the model and extract the necessary
2116
- information.
2117
-
2118
- Args:
2119
- response (dict): Model response.
2120
- prompt_tokens (int): Number of input prompt tokens.
2121
-
2122
- Returns:
2123
- _ModelResponse: a parsed model response.
2124
- """
2125
- content_dict: defaultdict = defaultdict(lambda: "")
2126
- finish_reasons_dict: defaultdict = defaultdict(lambda: "")
2127
- output_messages: List[BaseMessage] = []
2128
- response_id: str = ""
2129
- # All choices in one response share one role
2130
- async for chunk in response:
2131
- # Some model platforms like siliconflow may return None for the
2132
- # chunk.id
2133
- response_id = chunk.id if chunk.id else str(uuid.uuid4())
2134
- self._handle_chunk(
2135
- chunk, content_dict, finish_reasons_dict, output_messages
2136
- )
2137
- finish_reasons = [
2138
- finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
2139
- ]
2140
- usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
2141
-
2142
- # TODO: Handle tool calls
2143
- return ModelResponse(
2144
- response=response,
2145
- tool_call_requests=None,
2146
- output_messages=output_messages,
2147
- finish_reasons=finish_reasons,
2148
- usage_dict=usage_dict,
2149
- response_id=response_id,
2150
- )
2151
-
2152
- def _handle_chunk(
2153
- self,
2154
- chunk: ChatCompletionChunk,
2155
- content_dict: defaultdict,
2156
- finish_reasons_dict: defaultdict,
2157
- output_messages: List[BaseMessage],
2158
- ) -> None:
2159
- r"""Handle a chunk of the model response."""
2160
- for choice in chunk.choices:
2161
- index = choice.index
2162
- delta = choice.delta
2163
- if delta.content is not None:
2164
- content_dict[index] += delta.content
2165
-
2166
- if not choice.finish_reason:
2167
- continue
2168
-
2169
- finish_reasons_dict[index] = choice.finish_reason
2170
- chat_message = BaseMessage(
2171
- role_name=self.role_name,
2172
- role_type=self.role_type,
2173
- meta_dict=dict(),
2174
- content=content_dict[index],
2175
- )
2176
- output_messages.append(chat_message)
2177
-
2178
2301
  def _step_terminate(
2179
2302
  self,
2180
2303
  num_tokens: int,
@@ -2309,7 +2432,7 @@ class ChatAgent(BaseAgent):
2309
2432
  except Exception as e:
2310
2433
  # Capture the error message to prevent framework crash
2311
2434
  error_msg = f"Error executing async tool '{func_name}': {e!s}"
2312
- result = {"error": error_msg}
2435
+ result = f"Tool execution failed: {error_msg}"
2313
2436
  logging.warning(error_msg)
2314
2437
 
2315
2438
  # Check if result is a ToolResult with images
@@ -2406,6 +2529,1230 @@ class ChatAgent(BaseAgent):
2406
2529
 
2407
2530
  return tool_record
2408
2531
 
2532
+ def _stream(
2533
+ self,
2534
+ input_message: Union[BaseMessage, str],
2535
+ response_format: Optional[Type[BaseModel]] = None,
2536
+ ) -> Generator[ChatAgentResponse, None, None]:
2537
+ r"""Executes a streaming step in the chat session, yielding
2538
+ intermediate responses as they are generated.
2539
+
2540
+ Args:
2541
+ input_message (Union[BaseMessage, str]): The input message for the
2542
+ agent.
2543
+ response_format (Optional[Type[BaseModel]], optional): A Pydantic
2544
+ model defining the expected structure of the response.
2545
+
2546
+ Yields:
2547
+ ChatAgentResponse: Intermediate responses containing partial
2548
+ content, tool calls, and other information as they become
2549
+ available.
2550
+ """
2551
+ # Convert input message to BaseMessage if necessary
2552
+ if isinstance(input_message, str):
2553
+ input_message = BaseMessage.make_user_message(
2554
+ role_name="User", content=input_message
2555
+ )
2556
+
2557
+ # Add user input to memory
2558
+ self.update_memory(input_message, OpenAIBackendRole.USER)
2559
+
2560
+ # Get context for streaming
2561
+ try:
2562
+ openai_messages, num_tokens = self.memory.get_context()
2563
+ except RuntimeError as e:
2564
+ yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
2565
+ return
2566
+
2567
+ # Start streaming response
2568
+ yield from self._stream_response(
2569
+ openai_messages, num_tokens, response_format
2570
+ )
2571
+
2572
+ def _get_token_count(self, content: str) -> int:
2573
+ r"""Get token count for content with fallback."""
2574
+ if hasattr(self.model_backend, 'token_counter'):
2575
+ return len(self.model_backend.token_counter.encode(content))
2576
+ else:
2577
+ return len(content.split())
2578
+
2579
+ def _stream_response(
2580
+ self,
2581
+ openai_messages: List[OpenAIMessage],
2582
+ num_tokens: int,
2583
+ response_format: Optional[Type[BaseModel]] = None,
2584
+ ) -> Generator[ChatAgentResponse, None, None]:
2585
+ r"""Internal method to handle streaming responses with tool calls."""
2586
+
2587
+ tool_call_records: List[ToolCallingRecord] = []
2588
+ accumulated_tool_calls: Dict[str, Any] = {}
2589
+ step_token_usage = self._create_token_usage_tracker()
2590
+
2591
+ # Create content accumulator for proper content management
2592
+ content_accumulator = StreamContentAccumulator()
2593
+ iteration_count = 0
2594
+ while True:
2595
+ # Check termination condition
2596
+ if self.stop_event and self.stop_event.is_set():
2597
+ yield self._step_terminate(
2598
+ num_tokens, tool_call_records, "termination_triggered"
2599
+ )
2600
+ return
2601
+
2602
+ # Get streaming response from model
2603
+ try:
2604
+ response = self.model_backend.run(
2605
+ openai_messages,
2606
+ response_format,
2607
+ self._get_full_tool_schemas() or None,
2608
+ )
2609
+ iteration_count += 1
2610
+ except Exception as exc:
2611
+ logger.error(
2612
+ f"Error in streaming model response: {exc}", exc_info=exc
2613
+ )
2614
+ yield self._create_error_response(str(exc), tool_call_records)
2615
+ return
2616
+
2617
+ # Handle streaming response
2618
+ if isinstance(response, Stream):
2619
+ (
2620
+ stream_completed,
2621
+ tool_calls_complete,
2622
+ ) = yield from self._process_stream_chunks_with_accumulator(
2623
+ response,
2624
+ content_accumulator,
2625
+ accumulated_tool_calls,
2626
+ tool_call_records,
2627
+ step_token_usage,
2628
+ response_format,
2629
+ )
2630
+
2631
+ if tool_calls_complete:
2632
+ # Clear completed tool calls
2633
+ accumulated_tool_calls.clear()
2634
+
2635
+ # If we executed tools and not in
2636
+ # single iteration mode, continue
2637
+ if tool_call_records and (
2638
+ self.max_iteration is None
2639
+ or iteration_count < self.max_iteration
2640
+ ):
2641
+ # Update messages with tool results for next iteration
2642
+ try:
2643
+ openai_messages, num_tokens = (
2644
+ self.memory.get_context()
2645
+ )
2646
+ except RuntimeError as e:
2647
+ yield self._step_terminate(
2648
+ e.args[1],
2649
+ tool_call_records,
2650
+ "max_tokens_exceeded",
2651
+ )
2652
+ return
2653
+ # Reset streaming content for next iteration
2654
+ content_accumulator.reset_streaming_content()
2655
+ continue
2656
+ else:
2657
+ break
2658
+ else:
2659
+ # Stream completed without tool calls
2660
+ accumulated_tool_calls.clear()
2661
+ break
2662
+ elif hasattr(response, '__enter__') and hasattr(
2663
+ response, '__exit__'
2664
+ ):
2665
+ # Handle structured output stream (ChatCompletionStreamManager)
2666
+ with response as stream:
2667
+ parsed_object = None
2668
+
2669
+ for event in stream:
2670
+ if event.type == "content.delta":
2671
+ if getattr(event, "delta", None):
2672
+ # Use accumulator for proper content management
2673
+ partial_response = self._create_streaming_response_with_accumulator( # noqa: E501
2674
+ content_accumulator,
2675
+ getattr(event, "delta", ""),
2676
+ step_token_usage,
2677
+ tool_call_records=tool_call_records.copy(),
2678
+ )
2679
+ yield partial_response
2680
+
2681
+ elif event.type == "content.done":
2682
+ parsed_object = getattr(event, "parsed", None)
2683
+ break
2684
+ elif event.type == "error":
2685
+ logger.error(
2686
+ f"Error in structured stream: "
2687
+ f"{getattr(event, 'error', '')}"
2688
+ )
2689
+ yield self._create_error_response(
2690
+ str(getattr(event, 'error', '')),
2691
+ tool_call_records,
2692
+ )
2693
+ return
2694
+
2695
+ # Get final completion and record final message
2696
+ try:
2697
+ final_completion = stream.get_final_completion()
2698
+ final_content = (
2699
+ final_completion.choices[0].message.content or ""
2700
+ )
2701
+
2702
+ final_message = BaseMessage(
2703
+ role_name=self.role_name,
2704
+ role_type=self.role_type,
2705
+ meta_dict={},
2706
+ content=final_content,
2707
+ parsed=cast(
2708
+ "BaseModel | dict[str, Any] | None",
2709
+ parsed_object,
2710
+ ), # type: ignore[arg-type]
2711
+ )
2712
+
2713
+ self.record_message(final_message)
2714
+
2715
+ # Create final response
2716
+ final_response = ChatAgentResponse(
2717
+ msgs=[final_message],
2718
+ terminated=False,
2719
+ info={
2720
+ "id": final_completion.id or "",
2721
+ "usage": safe_model_dump(
2722
+ final_completion.usage
2723
+ )
2724
+ if final_completion.usage
2725
+ else {},
2726
+ "finish_reasons": [
2727
+ choice.finish_reason or "stop"
2728
+ for choice in final_completion.choices
2729
+ ],
2730
+ "num_tokens": self._get_token_count(
2731
+ final_content
2732
+ ),
2733
+ "tool_calls": tool_call_records,
2734
+ "external_tool_requests": None,
2735
+ "streaming": False,
2736
+ "partial": False,
2737
+ },
2738
+ )
2739
+ yield final_response
2740
+ break
2741
+
2742
+ except Exception as e:
2743
+ logger.error(f"Error getting final completion: {e}")
2744
+ yield self._create_error_response(
2745
+ str(e), tool_call_records
2746
+ )
2747
+ return
2748
+ else:
2749
+ # Handle non-streaming response (fallback)
2750
+ model_response = self._handle_batch_response(response)
2751
+ yield self._convert_to_chatagent_response(
2752
+ model_response,
2753
+ tool_call_records,
2754
+ num_tokens,
2755
+ None,
2756
+ model_response.usage_dict.get("prompt_tokens", 0),
2757
+ model_response.usage_dict.get("completion_tokens", 0),
2758
+ model_response.usage_dict.get("total_tokens", 0),
2759
+ )
2760
+ accumulated_tool_calls.clear()
2761
+ break
2762
+
2763
+ def _process_stream_chunks_with_accumulator(
2764
+ self,
2765
+ stream: Stream[ChatCompletionChunk],
2766
+ content_accumulator: StreamContentAccumulator,
2767
+ accumulated_tool_calls: Dict[str, Any],
2768
+ tool_call_records: List[ToolCallingRecord],
2769
+ step_token_usage: Dict[str, int],
2770
+ response_format: Optional[Type[BaseModel]] = None,
2771
+ ) -> Generator[ChatAgentResponse, None, Tuple[bool, bool]]:
2772
+ r"""Process streaming chunks with content accumulator."""
2773
+
2774
+ tool_calls_complete = False
2775
+ stream_completed = False
2776
+
2777
+ for chunk in stream:
2778
+ # Update token usage if available
2779
+ if chunk.usage:
2780
+ self._update_token_usage_tracker(
2781
+ step_token_usage, safe_model_dump(chunk.usage)
2782
+ )
2783
+
2784
+ # Process chunk delta
2785
+ if chunk.choices and len(chunk.choices) > 0:
2786
+ choice = chunk.choices[0]
2787
+ delta = choice.delta
2788
+
2789
+ # Handle content streaming
2790
+ if delta.content:
2791
+ # Use accumulator for proper content management
2792
+ partial_response = (
2793
+ self._create_streaming_response_with_accumulator(
2794
+ content_accumulator,
2795
+ delta.content,
2796
+ step_token_usage,
2797
+ getattr(chunk, 'id', ''),
2798
+ tool_call_records.copy(),
2799
+ )
2800
+ )
2801
+ yield partial_response
2802
+
2803
+ # Handle tool calls streaming
2804
+ if delta.tool_calls:
2805
+ tool_calls_complete = self._accumulate_tool_calls(
2806
+ delta.tool_calls, accumulated_tool_calls
2807
+ )
2808
+
2809
+ # Check if stream is complete
2810
+ if choice.finish_reason:
2811
+ stream_completed = True
2812
+
2813
+ # If we have complete tool calls, execute them with
2814
+ # sync status updates
2815
+ if accumulated_tool_calls:
2816
+ # Record assistant message with tool calls first
2817
+ self._record_assistant_tool_calls_message(
2818
+ accumulated_tool_calls,
2819
+ content_accumulator.get_full_content(),
2820
+ )
2821
+
2822
+ # Execute tools synchronously with
2823
+ # optimized status updates
2824
+ for (
2825
+ status_response
2826
+ ) in self._execute_tools_sync_with_status_accumulator(
2827
+ accumulated_tool_calls,
2828
+ content_accumulator,
2829
+ step_token_usage,
2830
+ tool_call_records,
2831
+ ):
2832
+ yield status_response
2833
+
2834
+ # Yield "Sending back result to model" status
2835
+ if tool_call_records:
2836
+ sending_status = self._create_tool_status_response_with_accumulator( # noqa: E501
2837
+ content_accumulator,
2838
+ "\n------\n\nSending back result to model\n\n",
2839
+ "tool_sending",
2840
+ step_token_usage,
2841
+ )
2842
+ yield sending_status
2843
+
2844
+ # Record final message only if we have content AND no tool
2845
+ # calls. If there are tool calls, _record_tool_calling
2846
+ # will handle message recording.
2847
+ final_content = content_accumulator.get_full_content()
2848
+ if final_content.strip() and not accumulated_tool_calls:
2849
+ final_message = BaseMessage(
2850
+ role_name=self.role_name,
2851
+ role_type=self.role_type,
2852
+ meta_dict={},
2853
+ content=final_content,
2854
+ )
2855
+
2856
+ if response_format:
2857
+ self._try_format_message(
2858
+ final_message, response_format
2859
+ )
2860
+
2861
+ self.record_message(final_message)
2862
+ break
2863
+
2864
+ return stream_completed, tool_calls_complete
2865
+
2866
+ def _accumulate_tool_calls(
2867
+ self,
2868
+ tool_call_deltas: List[Any],
2869
+ accumulated_tool_calls: Dict[str, Any],
2870
+ ) -> bool:
2871
+ r"""Accumulate tool call chunks and return True when
2872
+ any tool call is complete.
2873
+
2874
+ Args:
2875
+ tool_call_deltas (List[Any]): List of tool call deltas.
2876
+ accumulated_tool_calls (Dict[str, Any]): Dictionary of accumulated
2877
+ tool calls.
2878
+
2879
+ Returns:
2880
+ bool: True if any tool call is complete, False otherwise.
2881
+ """
2882
+
2883
+ for delta_tool_call in tool_call_deltas:
2884
+ index = delta_tool_call.index
2885
+ tool_call_id = getattr(delta_tool_call, 'id', None)
2886
+
2887
+ # Initialize tool call entry if not exists
2888
+ if index not in accumulated_tool_calls:
2889
+ accumulated_tool_calls[index] = {
2890
+ 'id': '',
2891
+ 'type': 'function',
2892
+ 'function': {'name': '', 'arguments': ''},
2893
+ 'complete': False,
2894
+ }
2895
+
2896
+ tool_call_entry = accumulated_tool_calls[index]
2897
+
2898
+ # Accumulate tool call data
2899
+ if tool_call_id:
2900
+ tool_call_entry['id'] = (
2901
+ tool_call_id # Set full ID, don't append
2902
+ )
2903
+
2904
+ if (
2905
+ hasattr(delta_tool_call, 'function')
2906
+ and delta_tool_call.function
2907
+ ):
2908
+ if delta_tool_call.function.name:
2909
+ tool_call_entry['function']['name'] += (
2910
+ delta_tool_call.function.name
2911
+ ) # Append incremental name
2912
+ if delta_tool_call.function.arguments:
2913
+ tool_call_entry['function']['arguments'] += (
2914
+ delta_tool_call.function.arguments
2915
+ )
2916
+
2917
+ # Check if any tool calls are complete
2918
+ any_complete = False
2919
+ for _index, tool_call_entry in accumulated_tool_calls.items():
2920
+ if (
2921
+ tool_call_entry['id']
2922
+ and tool_call_entry['function']['name']
2923
+ and tool_call_entry['function']['arguments']
2924
+ and tool_call_entry['function']['name'] in self._internal_tools
2925
+ ):
2926
+ try:
2927
+ # Try to parse arguments to check completeness
2928
+ json.loads(tool_call_entry['function']['arguments'])
2929
+ tool_call_entry['complete'] = True
2930
+ any_complete = True
2931
+ except json.JSONDecodeError:
2932
+ # Arguments not complete yet
2933
+ tool_call_entry['complete'] = False
2934
+
2935
+ return any_complete
2936
+
2937
+ def _execute_tools_sync_with_status_accumulator(
2938
+ self,
2939
+ accumulated_tool_calls: Dict[str, Any],
2940
+ content_accumulator: StreamContentAccumulator,
2941
+ step_token_usage: Dict[str, int],
2942
+ tool_call_records: List[ToolCallingRecord],
2943
+ ) -> Generator[ChatAgentResponse, None, None]:
2944
+ r"""Execute multiple tools synchronously with
2945
+ proper content accumulation, using threads+queue for
2946
+ non-blocking status streaming."""
2947
+
2948
+ def tool_worker(tool_func, args, result_queue, tool_call_data):
2949
+ try:
2950
+ tool_call_record = self._execute_tool_from_stream_data(
2951
+ tool_call_data
2952
+ )
2953
+ result_queue.put(tool_call_record)
2954
+ except Exception as e:
2955
+ logger.error(f"Error in threaded tool execution: {e}")
2956
+ result_queue.put(None)
2957
+
2958
+ tool_calls_to_execute = []
2959
+ for _tool_call_index, tool_call_data in accumulated_tool_calls.items():
2960
+ if tool_call_data.get('complete', False):
2961
+ tool_calls_to_execute.append(tool_call_data)
2962
+
2963
+ # Phase 2: Execute tools in threads and yield status while waiting
2964
+ for tool_call_data in tool_calls_to_execute:
2965
+ function_name = tool_call_data['function']['name']
2966
+ try:
2967
+ args = json.loads(tool_call_data['function']['arguments'])
2968
+ except json.JSONDecodeError:
2969
+ args = tool_call_data['function']['arguments']
2970
+ result_queue: queue.Queue[Optional[ToolCallingRecord]] = (
2971
+ queue.Queue()
2972
+ )
2973
+ thread = threading.Thread(
2974
+ target=tool_worker,
2975
+ args=(
2976
+ self._internal_tools[function_name],
2977
+ args,
2978
+ result_queue,
2979
+ tool_call_data,
2980
+ ),
2981
+ )
2982
+ thread.start()
2983
+
2984
+ status_message = (
2985
+ f"\nCalling function: {function_name} "
2986
+ f"with arguments:\n{args}\n"
2987
+ )
2988
+ status_status = self._create_tool_status_response_with_accumulator(
2989
+ content_accumulator,
2990
+ status_message,
2991
+ "tool_calling",
2992
+ step_token_usage,
2993
+ )
2994
+ yield status_status
2995
+ # wait for tool thread to finish with optional timeout
2996
+ thread.join(self.tool_execution_timeout)
2997
+
2998
+ # If timeout occurred, mark as error and continue
2999
+ if thread.is_alive():
3000
+ timeout_msg = (
3001
+ f"\nFunction '{function_name}' timed out after "
3002
+ f"{self.tool_execution_timeout} seconds.\n---------\n"
3003
+ )
3004
+ timeout_status = (
3005
+ self._create_tool_status_response_with_accumulator(
3006
+ content_accumulator,
3007
+ timeout_msg,
3008
+ "tool_timeout",
3009
+ step_token_usage,
3010
+ )
3011
+ )
3012
+ yield timeout_status
3013
+ logger.error(timeout_msg.strip())
3014
+ # Detach thread (it may still finish later). Skip recording.
3015
+ continue
3016
+
3017
+ # Tool finished, get result
3018
+ tool_call_record = result_queue.get()
3019
+ if tool_call_record:
3020
+ tool_call_records.append(tool_call_record)
3021
+ raw_result = tool_call_record.result
3022
+ result_str = str(raw_result)
3023
+ status_message = (
3024
+ f"\nFunction output: {result_str}\n---------\n"
3025
+ )
3026
+ output_status = (
3027
+ self._create_tool_status_response_with_accumulator(
3028
+ content_accumulator,
3029
+ status_message,
3030
+ "tool_output",
3031
+ step_token_usage,
3032
+ [tool_call_record],
3033
+ )
3034
+ )
3035
+ yield output_status
3036
+ else:
3037
+ # Error already logged
3038
+ continue
3039
+
3040
+ def _execute_tool_from_stream_data(
3041
+ self, tool_call_data: Dict[str, Any]
3042
+ ) -> Optional[ToolCallingRecord]:
3043
+ r"""Execute a tool from accumulated stream data."""
3044
+
3045
+ try:
3046
+ function_name = tool_call_data['function']['name']
3047
+ args = json.loads(tool_call_data['function']['arguments'])
3048
+ tool_call_id = tool_call_data['id']
3049
+
3050
+ if function_name in self._internal_tools:
3051
+ tool = self._internal_tools[function_name]
3052
+ try:
3053
+ result = tool(**args)
3054
+
3055
+ # Only record the tool response message, not the assistant
3056
+ # message assistant message with tool_calls was already
3057
+ # recorded in _record_assistant_tool_calls_message
3058
+ func_msg = FunctionCallingMessage(
3059
+ role_name=self.role_name,
3060
+ role_type=self.role_type,
3061
+ meta_dict=None,
3062
+ content="",
3063
+ func_name=function_name,
3064
+ result=result,
3065
+ tool_call_id=tool_call_id,
3066
+ )
3067
+
3068
+ self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
3069
+
3070
+ return ToolCallingRecord(
3071
+ tool_name=function_name,
3072
+ args=args,
3073
+ result=result,
3074
+ tool_call_id=tool_call_id,
3075
+ )
3076
+
3077
+ except Exception as e:
3078
+ error_msg = (
3079
+ f"Error executing tool '{function_name}': {e!s}"
3080
+ )
3081
+ result = {"error": error_msg}
3082
+ logging.warning(error_msg)
3083
+
3084
+ # Record error response
3085
+ func_msg = FunctionCallingMessage(
3086
+ role_name=self.role_name,
3087
+ role_type=self.role_type,
3088
+ meta_dict=None,
3089
+ content="",
3090
+ func_name=function_name,
3091
+ result=result,
3092
+ tool_call_id=tool_call_id,
3093
+ )
3094
+
3095
+ self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
3096
+
3097
+ return ToolCallingRecord(
3098
+ tool_name=function_name,
3099
+ args=args,
3100
+ result=result,
3101
+ tool_call_id=tool_call_id,
3102
+ )
3103
+ else:
3104
+ logger.warning(
3105
+ f"Tool '{function_name}' not found in internal tools"
3106
+ )
3107
+ return None
3108
+
3109
+ except Exception as e:
3110
+ logger.error(f"Error processing tool call: {e}")
3111
+ return None
3112
+
3113
+ async def _aexecute_tool_from_stream_data(
3114
+ self, tool_call_data: Dict[str, Any]
3115
+ ) -> Optional[ToolCallingRecord]:
3116
+ r"""Async execute a tool from accumulated stream data."""
3117
+
3118
+ try:
3119
+ function_name = tool_call_data['function']['name']
3120
+ args = json.loads(tool_call_data['function']['arguments'])
3121
+ tool_call_id = tool_call_data['id']
3122
+
3123
+ if function_name in self._internal_tools:
3124
+ tool = self._internal_tools[function_name]
3125
+ try:
3126
+ result = await tool.async_call(**args)
3127
+
3128
+ # Only record the tool response message, not the assistant
3129
+ # message assistant message with tool_calls was already
3130
+ # recorded in _record_assistant_tool_calls_message
3131
+ func_msg = FunctionCallingMessage(
3132
+ role_name=self.role_name,
3133
+ role_type=self.role_type,
3134
+ meta_dict=None,
3135
+ content="",
3136
+ func_name=function_name,
3137
+ result=result,
3138
+ tool_call_id=tool_call_id,
3139
+ )
3140
+
3141
+ self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
3142
+
3143
+ return ToolCallingRecord(
3144
+ tool_name=function_name,
3145
+ args=args,
3146
+ result=result,
3147
+ tool_call_id=tool_call_id,
3148
+ )
3149
+
3150
+ except Exception as e:
3151
+ error_msg = (
3152
+ f"Error executing async tool '{function_name}': {e!s}"
3153
+ )
3154
+ result = {"error": error_msg}
3155
+ logging.warning(error_msg)
3156
+
3157
+ # Record error response
3158
+ func_msg = FunctionCallingMessage(
3159
+ role_name=self.role_name,
3160
+ role_type=self.role_type,
3161
+ meta_dict=None,
3162
+ content="",
3163
+ func_name=function_name,
3164
+ result=result,
3165
+ tool_call_id=tool_call_id,
3166
+ )
3167
+
3168
+ self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
3169
+
3170
+ return ToolCallingRecord(
3171
+ tool_name=function_name,
3172
+ args=args,
3173
+ result=result,
3174
+ tool_call_id=tool_call_id,
3175
+ )
3176
+ else:
3177
+ logger.warning(
3178
+ f"Tool '{function_name}' not found in internal tools"
3179
+ )
3180
+ return None
3181
+
3182
+ except Exception as e:
3183
+ logger.error(f"Error processing async tool call: {e}")
3184
+ return None
3185
+
3186
+ def _create_error_response(
3187
+ self, error_message: str, tool_call_records: List[ToolCallingRecord]
3188
+ ) -> ChatAgentResponse:
3189
+ r"""Create an error response for streaming."""
3190
+
3191
+ error_msg = BaseMessage(
3192
+ role_name=self.role_name,
3193
+ role_type=self.role_type,
3194
+ meta_dict={},
3195
+ content=f"Error: {error_message}",
3196
+ )
3197
+
3198
+ return ChatAgentResponse(
3199
+ msgs=[error_msg],
3200
+ terminated=True,
3201
+ info={
3202
+ "error": error_message,
3203
+ "tool_calls": tool_call_records,
3204
+ "streaming": True,
3205
+ },
3206
+ )
3207
+
3208
+ async def _astream(
3209
+ self,
3210
+ input_message: Union[BaseMessage, str],
3211
+ response_format: Optional[Type[BaseModel]] = None,
3212
+ ) -> AsyncGenerator[ChatAgentResponse, None]:
3213
+ r"""Asynchronous version of stream method."""
3214
+
3215
+ # Convert input message to BaseMessage if necessary
3216
+ if isinstance(input_message, str):
3217
+ input_message = BaseMessage.make_user_message(
3218
+ role_name="User", content=input_message
3219
+ )
3220
+
3221
+ # Add user input to memory
3222
+ self.update_memory(input_message, OpenAIBackendRole.USER)
3223
+
3224
+ # Get context for streaming
3225
+ try:
3226
+ openai_messages, num_tokens = self.memory.get_context()
3227
+ except RuntimeError as e:
3228
+ yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
3229
+ return
3230
+
3231
+ # Start async streaming response
3232
+ async for response in self._astream_response(
3233
+ openai_messages, num_tokens, response_format
3234
+ ):
3235
+ yield response
3236
+
3237
+ async def _astream_response(
3238
+ self,
3239
+ openai_messages: List[OpenAIMessage],
3240
+ num_tokens: int,
3241
+ response_format: Optional[Type[BaseModel]] = None,
3242
+ ) -> AsyncGenerator[ChatAgentResponse, None]:
3243
+ r"""Async method to handle streaming responses with tool calls."""
3244
+
3245
+ tool_call_records: List[ToolCallingRecord] = []
3246
+ accumulated_tool_calls: Dict[str, Any] = {}
3247
+ step_token_usage = self._create_token_usage_tracker()
3248
+
3249
+ # Create content accumulator for proper content management
3250
+ content_accumulator = StreamContentAccumulator()
3251
+ iteration_count = 0
3252
+ while True:
3253
+ # Check termination condition
3254
+ if self.stop_event and self.stop_event.is_set():
3255
+ yield self._step_terminate(
3256
+ num_tokens, tool_call_records, "termination_triggered"
3257
+ )
3258
+ return
3259
+
3260
+ # Get async streaming response from model
3261
+ try:
3262
+ response = await self.model_backend.arun(
3263
+ openai_messages,
3264
+ response_format,
3265
+ self._get_full_tool_schemas() or None,
3266
+ )
3267
+ iteration_count += 1
3268
+ except Exception as exc:
3269
+ logger.error(
3270
+ f"Error in async streaming model response: {exc}",
3271
+ exc_info=exc,
3272
+ )
3273
+ yield self._create_error_response(str(exc), tool_call_records)
3274
+ return
3275
+
3276
+ # Handle streaming response
3277
+ if isinstance(response, AsyncStream):
3278
+ stream_completed = False
3279
+ tool_calls_complete = False
3280
+
3281
+ # Process chunks and forward them
3282
+ async for (
3283
+ item
3284
+ ) in self._aprocess_stream_chunks_with_accumulator(
3285
+ response,
3286
+ content_accumulator,
3287
+ accumulated_tool_calls,
3288
+ tool_call_records,
3289
+ step_token_usage,
3290
+ response_format,
3291
+ ):
3292
+ if isinstance(item, tuple):
3293
+ # This is the final return value (stream_completed,
3294
+ # tool_calls_complete)
3295
+ stream_completed, tool_calls_complete = item
3296
+ break
3297
+ else:
3298
+ # This is a ChatAgentResponse to be yielded
3299
+ yield item
3300
+
3301
+ if tool_calls_complete:
3302
+ # Clear completed tool calls
3303
+ accumulated_tool_calls.clear()
3304
+
3305
+ # If we executed tools and not in
3306
+ # single iteration mode, continue
3307
+ if tool_call_records and (
3308
+ self.max_iteration is None
3309
+ or iteration_count < self.max_iteration
3310
+ ):
3311
+ # Update messages with tool results for next iteration
3312
+ try:
3313
+ openai_messages, num_tokens = (
3314
+ self.memory.get_context()
3315
+ )
3316
+ except RuntimeError as e:
3317
+ yield self._step_terminate(
3318
+ e.args[1],
3319
+ tool_call_records,
3320
+ "max_tokens_exceeded",
3321
+ )
3322
+ return
3323
+ # Reset streaming content for next iteration
3324
+ content_accumulator.reset_streaming_content()
3325
+ continue
3326
+ else:
3327
+ break
3328
+ else:
3329
+ # Stream completed without tool calls
3330
+ accumulated_tool_calls.clear()
3331
+ break
3332
+ elif hasattr(response, '__aenter__') and hasattr(
3333
+ response, '__aexit__'
3334
+ ):
3335
+ # Handle structured output stream
3336
+ # (AsyncChatCompletionStreamManager)
3337
+ async with response as stream:
3338
+ parsed_object = None
3339
+
3340
+ async for event in stream:
3341
+ if event.type == "content.delta":
3342
+ if getattr(event, "delta", None):
3343
+ # Use accumulator for proper content management
3344
+ partial_response = self._create_streaming_response_with_accumulator( # noqa: E501
3345
+ content_accumulator,
3346
+ getattr(event, "delta", ""),
3347
+ step_token_usage,
3348
+ tool_call_records=tool_call_records.copy(),
3349
+ )
3350
+ yield partial_response
3351
+
3352
+ elif event.type == "content.done":
3353
+ parsed_object = getattr(event, "parsed", None)
3354
+ break
3355
+ elif event.type == "error":
3356
+ logger.error(
3357
+ f"Error in async structured stream: "
3358
+ f"{getattr(event, 'error', '')}"
3359
+ )
3360
+ yield self._create_error_response(
3361
+ str(getattr(event, 'error', '')),
3362
+ tool_call_records,
3363
+ )
3364
+ return
3365
+
3366
+ # Get final completion and record final message
3367
+ try:
3368
+ final_completion = await stream.get_final_completion()
3369
+ final_content = (
3370
+ final_completion.choices[0].message.content or ""
3371
+ )
3372
+
3373
+ final_message = BaseMessage(
3374
+ role_name=self.role_name,
3375
+ role_type=self.role_type,
3376
+ meta_dict={},
3377
+ content=final_content,
3378
+ parsed=cast(
3379
+ "BaseModel | dict[str, Any] | None",
3380
+ parsed_object,
3381
+ ), # type: ignore[arg-type]
3382
+ )
3383
+
3384
+ self.record_message(final_message)
3385
+
3386
+ # Create final response
3387
+ final_response = ChatAgentResponse(
3388
+ msgs=[final_message],
3389
+ terminated=False,
3390
+ info={
3391
+ "id": final_completion.id or "",
3392
+ "usage": safe_model_dump(
3393
+ final_completion.usage
3394
+ )
3395
+ if final_completion.usage
3396
+ else {},
3397
+ "finish_reasons": [
3398
+ choice.finish_reason or "stop"
3399
+ for choice in final_completion.choices
3400
+ ],
3401
+ "num_tokens": self._get_token_count(
3402
+ final_content
3403
+ ),
3404
+ "tool_calls": tool_call_records,
3405
+ "external_tool_requests": None,
3406
+ "streaming": False,
3407
+ "partial": False,
3408
+ },
3409
+ )
3410
+ yield final_response
3411
+ break
3412
+
3413
+ except Exception as e:
3414
+ logger.error(
3415
+ f"Error getting async final completion: {e}"
3416
+ )
3417
+ yield self._create_error_response(
3418
+ str(e), tool_call_records
3419
+ )
3420
+ return
3421
+ else:
3422
+ # Handle non-streaming response (fallback)
3423
+ model_response = self._handle_batch_response(response)
3424
+ yield self._convert_to_chatagent_response(
3425
+ model_response,
3426
+ tool_call_records,
3427
+ num_tokens,
3428
+ None,
3429
+ model_response.usage_dict.get("prompt_tokens", 0),
3430
+ model_response.usage_dict.get("completion_tokens", 0),
3431
+ model_response.usage_dict.get("total_tokens", 0),
3432
+ )
3433
+ accumulated_tool_calls.clear()
3434
+ break
3435
+
3436
+ def _record_assistant_tool_calls_message(
3437
+ self, accumulated_tool_calls: Dict[str, Any], content: str = ""
3438
+ ) -> None:
3439
+ r"""Record the assistant message that contains tool calls.
3440
+
3441
+ This method creates and records an assistant message that includes
3442
+ the tool calls information, which is required by OpenAI's API format.
3443
+ """
3444
+ # Create a BaseMessage with tool_calls information in meta_dict
3445
+ # This will be converted to the proper OpenAI format when needed
3446
+ tool_calls_list = []
3447
+ for tool_call_data in accumulated_tool_calls.values():
3448
+ if tool_call_data.get('complete', False):
3449
+ tool_call_dict = {
3450
+ "id": tool_call_data["id"],
3451
+ "type": "function",
3452
+ "function": {
3453
+ "name": tool_call_data["function"]["name"],
3454
+ "arguments": tool_call_data["function"]["arguments"],
3455
+ },
3456
+ }
3457
+ tool_calls_list.append(tool_call_dict)
3458
+
3459
+ # Create an assistant message with tool calls
3460
+ assist_msg = BaseMessage(
3461
+ role_name=self.role_name,
3462
+ role_type=self.role_type,
3463
+ meta_dict={"tool_calls": tool_calls_list},
3464
+ content=content or "",
3465
+ )
3466
+
3467
+ # Record this assistant message
3468
+ self.update_memory(assist_msg, OpenAIBackendRole.ASSISTANT)
3469
+
3470
+ async def _aprocess_stream_chunks_with_accumulator(
3471
+ self,
3472
+ stream: AsyncStream[ChatCompletionChunk],
3473
+ content_accumulator: StreamContentAccumulator,
3474
+ accumulated_tool_calls: Dict[str, Any],
3475
+ tool_call_records: List[ToolCallingRecord],
3476
+ step_token_usage: Dict[str, int],
3477
+ response_format: Optional[Type[BaseModel]] = None,
3478
+ ) -> AsyncGenerator[Union[ChatAgentResponse, Tuple[bool, bool]], None]:
3479
+ r"""Async version of process streaming chunks with
3480
+ content accumulator."""
3481
+
3482
+ tool_calls_complete = False
3483
+ stream_completed = False
3484
+
3485
+ async for chunk in stream:
3486
+ # Update token usage if available
3487
+ if chunk.usage:
3488
+ self._update_token_usage_tracker(
3489
+ step_token_usage, safe_model_dump(chunk.usage)
3490
+ )
3491
+
3492
+ # Process chunk delta
3493
+ if chunk.choices and len(chunk.choices) > 0:
3494
+ choice = chunk.choices[0]
3495
+ delta = choice.delta
3496
+
3497
+ # Handle content streaming
3498
+ if delta.content:
3499
+ # Use accumulator for proper content management
3500
+ partial_response = (
3501
+ self._create_streaming_response_with_accumulator(
3502
+ content_accumulator,
3503
+ delta.content,
3504
+ step_token_usage,
3505
+ getattr(chunk, 'id', ''),
3506
+ tool_call_records.copy(),
3507
+ )
3508
+ )
3509
+ yield partial_response
3510
+
3511
+ # Handle tool calls streaming
3512
+ if delta.tool_calls:
3513
+ tool_calls_complete = self._accumulate_tool_calls(
3514
+ delta.tool_calls, accumulated_tool_calls
3515
+ )
3516
+
3517
+ # Check if stream is complete
3518
+ if choice.finish_reason:
3519
+ stream_completed = True
3520
+
3521
+ # If we have complete tool calls, execute them with
3522
+ # async status updates
3523
+ if accumulated_tool_calls:
3524
+ # Record assistant message with
3525
+ # tool calls first
3526
+ self._record_assistant_tool_calls_message(
3527
+ accumulated_tool_calls,
3528
+ content_accumulator.get_full_content(),
3529
+ )
3530
+
3531
+ # Execute tools asynchronously with real-time
3532
+ # status updates
3533
+ async for (
3534
+ status_response
3535
+ ) in self._execute_tools_async_with_status_accumulator(
3536
+ accumulated_tool_calls,
3537
+ content_accumulator,
3538
+ step_token_usage,
3539
+ tool_call_records,
3540
+ ):
3541
+ yield status_response
3542
+
3543
+ # Yield "Sending back result to model" status
3544
+ if tool_call_records:
3545
+ sending_status = self._create_tool_status_response_with_accumulator( # noqa: E501
3546
+ content_accumulator,
3547
+ "\n------\n\nSending back result to model\n\n",
3548
+ "tool_sending",
3549
+ step_token_usage,
3550
+ )
3551
+ yield sending_status
3552
+
3553
+ # Record final message only if we have content AND no tool
3554
+ # calls. If there are tool calls, _record_tool_calling
3555
+ # will handle message recording.
3556
+ final_content = content_accumulator.get_full_content()
3557
+ if final_content.strip() and not accumulated_tool_calls:
3558
+ final_message = BaseMessage(
3559
+ role_name=self.role_name,
3560
+ role_type=self.role_type,
3561
+ meta_dict={},
3562
+ content=final_content,
3563
+ )
3564
+
3565
+ if response_format:
3566
+ self._try_format_message(
3567
+ final_message, response_format
3568
+ )
3569
+
3570
+ self.record_message(final_message)
3571
+ break
3572
+
3573
+ # Yield the final status as a tuple
3574
+ yield (stream_completed, tool_calls_complete)
3575
+
3576
+ async def _execute_tools_async_with_status_accumulator(
3577
+ self,
3578
+ accumulated_tool_calls: Dict[str, Any],
3579
+ content_accumulator: StreamContentAccumulator,
3580
+ step_token_usage: Dict[str, int],
3581
+ tool_call_records: List[ToolCallingRecord],
3582
+ ) -> AsyncGenerator[ChatAgentResponse, None]:
3583
+ r"""Execute multiple tools asynchronously with
3584
+ proper content accumulation."""
3585
+ import asyncio
3586
+
3587
+ # Phase 1: Start all tools and yield "Calling function"
3588
+ # statuses immediately
3589
+ tool_tasks = []
3590
+ for _tool_call_index, tool_call_data in accumulated_tool_calls.items():
3591
+ if tool_call_data.get('complete', False):
3592
+ function_name = tool_call_data['function']['name']
3593
+ try:
3594
+ args = json.loads(tool_call_data['function']['arguments'])
3595
+ except json.JSONDecodeError:
3596
+ args = tool_call_data['function']['arguments']
3597
+
3598
+ status_message = (
3599
+ f"\nCalling function: {function_name} "
3600
+ f"with arguments:\n{args}\n"
3601
+ )
3602
+
3603
+ # Immediately yield "Calling function" status
3604
+ calling_status = (
3605
+ self._create_tool_status_response_with_accumulator(
3606
+ content_accumulator,
3607
+ status_message,
3608
+ "tool_calling",
3609
+ step_token_usage,
3610
+ )
3611
+ )
3612
+ yield calling_status
3613
+
3614
+ # Start tool execution asynchronously (non-blocking)
3615
+ if self.tool_execution_timeout is not None:
3616
+ task = asyncio.create_task(
3617
+ asyncio.wait_for(
3618
+ self._aexecute_tool_from_stream_data(
3619
+ tool_call_data
3620
+ ),
3621
+ timeout=self.tool_execution_timeout,
3622
+ )
3623
+ )
3624
+ else:
3625
+ task = asyncio.create_task(
3626
+ self._aexecute_tool_from_stream_data(tool_call_data)
3627
+ )
3628
+ tool_tasks.append((task, tool_call_data))
3629
+
3630
+ # Phase 2: Wait for tools to complete and yield results as they finish
3631
+ if tool_tasks:
3632
+ # Use asyncio.as_completed for true async processing
3633
+ for completed_task in asyncio.as_completed(
3634
+ [task for task, _ in tool_tasks]
3635
+ ):
3636
+ try:
3637
+ tool_call_record = await completed_task
3638
+ if tool_call_record:
3639
+ # Add to the shared tool_call_records list
3640
+ tool_call_records.append(tool_call_record)
3641
+
3642
+ # Create output status message
3643
+ raw_result = tool_call_record.result
3644
+ result_str = str(raw_result)
3645
+ status_message = (
3646
+ f"\nFunction output: {result_str}\n---------\n"
3647
+ )
3648
+
3649
+ # Yield "Function output" status as soon as this
3650
+ # tool completes
3651
+ output_status = (
3652
+ self._create_tool_status_response_with_accumulator(
3653
+ content_accumulator,
3654
+ status_message,
3655
+ "tool_output",
3656
+ step_token_usage,
3657
+ [tool_call_record],
3658
+ )
3659
+ )
3660
+ yield output_status
3661
+
3662
+ except Exception as e:
3663
+ if isinstance(e, asyncio.TimeoutError):
3664
+ timeout_msg = (
3665
+ f"\nFunction timed out after "
3666
+ f"{self.tool_execution_timeout} seconds.\n"
3667
+ f"---------\n"
3668
+ )
3669
+ timeout_status = (
3670
+ self._create_tool_status_response_with_accumulator(
3671
+ content_accumulator,
3672
+ timeout_msg,
3673
+ "tool_timeout",
3674
+ step_token_usage,
3675
+ )
3676
+ )
3677
+ yield timeout_status
3678
+ logger.error("Async tool execution timeout")
3679
+ else:
3680
+ logger.error(f"Error in async tool execution: {e}")
3681
+ continue
3682
+
3683
+ def _create_tool_status_response_with_accumulator(
3684
+ self,
3685
+ accumulator: StreamContentAccumulator,
3686
+ status_message: str,
3687
+ status_type: str,
3688
+ step_token_usage: Dict[str, int],
3689
+ tool_calls: Optional[List[ToolCallingRecord]] = None,
3690
+ ) -> ChatAgentResponse:
3691
+ r"""Create a tool status response using content accumulator."""
3692
+
3693
+ # Add this status message to accumulator and get full content
3694
+ accumulator.add_tool_status(status_message)
3695
+ full_content = accumulator.get_full_content()
3696
+
3697
+ message = BaseMessage(
3698
+ role_name=self.role_name,
3699
+ role_type=self.role_type,
3700
+ meta_dict={},
3701
+ content=full_content,
3702
+ )
3703
+
3704
+ return ChatAgentResponse(
3705
+ msgs=[message],
3706
+ terminated=False,
3707
+ info={
3708
+ "id": "",
3709
+ "usage": step_token_usage.copy(),
3710
+ "finish_reasons": [status_type],
3711
+ "num_tokens": self._get_token_count(full_content),
3712
+ "tool_calls": tool_calls or [],
3713
+ "external_tool_requests": None,
3714
+ "streaming": True,
3715
+ "tool_status": status_type,
3716
+ "partial": True,
3717
+ },
3718
+ )
3719
+
3720
+ def _create_streaming_response_with_accumulator(
3721
+ self,
3722
+ accumulator: StreamContentAccumulator,
3723
+ new_content: str,
3724
+ step_token_usage: Dict[str, int],
3725
+ response_id: str = "",
3726
+ tool_call_records: Optional[List[ToolCallingRecord]] = None,
3727
+ ) -> ChatAgentResponse:
3728
+ r"""Create a streaming response using content accumulator."""
3729
+
3730
+ # Add new content to accumulator and get full content
3731
+ accumulator.add_streaming_content(new_content)
3732
+ full_content = accumulator.get_full_content()
3733
+
3734
+ message = BaseMessage(
3735
+ role_name=self.role_name,
3736
+ role_type=self.role_type,
3737
+ meta_dict={},
3738
+ content=full_content,
3739
+ )
3740
+
3741
+ return ChatAgentResponse(
3742
+ msgs=[message],
3743
+ terminated=False,
3744
+ info={
3745
+ "id": response_id,
3746
+ "usage": step_token_usage.copy(),
3747
+ "finish_reasons": ["streaming"],
3748
+ "num_tokens": self._get_token_count(full_content),
3749
+ "tool_calls": tool_call_records or [],
3750
+ "external_tool_requests": None,
3751
+ "streaming": True,
3752
+ "partial": True,
3753
+ },
3754
+ )
3755
+
2409
3756
  def get_usage_dict(
2410
3757
  self, output_messages: List[BaseMessage], prompt_tokens: int
2411
3758
  ) -> Dict[str, int]:
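Each intermediate response yielded by the streaming path above carries its status in `info` (`"streaming"`, `"partial"`, and, for tool progress, `"tool_status"`), as set by `_create_streaming_response_with_accumulator` and `_create_tool_status_response_with_accumulator`. A sketch of a caller separating tool progress from streamed model text, assuming the streaming-enabled `agent` with tools from the earlier sketches:

```python
# Illustrative only: distinguish tool progress updates from model text.
for partial in agent.step("What is 7 * 6? Use the add tool if helpful."):
    status = partial.info.get("tool_status")
    if status:
        print(f"[{status}]")                  # e.g. tool_calling, tool_output
    elif partial.info.get("partial"):
        print(partial.msg.content, end="\r")  # cumulative partial content
print()  # content is cumulative, so the last item holds the full text
```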
@@ -2478,6 +3825,7 @@ class ChatAgent(BaseAgent):
2478
3825
  ),
2479
3826
  max_iteration=self.max_iteration,
2480
3827
  stop_event=self.stop_event,
3828
+ tool_execution_timeout=self.tool_execution_timeout,
2481
3829
  pause_event=self.pause_event,
2482
3830
  )
2483
3831