camel-ai 0.2.67__py3-none-any.whl → 0.2.80a2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- camel/__init__.py +1 -1
- camel/agents/_types.py +6 -2
- camel/agents/_utils.py +38 -0
- camel/agents/chat_agent.py +4014 -410
- camel/agents/mcp_agent.py +30 -27
- camel/agents/repo_agent.py +2 -1
- camel/benchmarks/browsecomp.py +6 -6
- camel/configs/__init__.py +15 -0
- camel/configs/aihubmix_config.py +88 -0
- camel/configs/amd_config.py +70 -0
- camel/configs/cometapi_config.py +104 -0
- camel/configs/minimax_config.py +93 -0
- camel/configs/nebius_config.py +103 -0
- camel/configs/vllm_config.py +2 -0
- camel/data_collectors/alpaca_collector.py +15 -6
- camel/datagen/self_improving_cot.py +1 -1
- camel/datasets/base_generator.py +39 -10
- camel/environments/__init__.py +12 -0
- camel/environments/rlcards_env.py +860 -0
- camel/environments/single_step.py +28 -3
- camel/environments/tic_tac_toe.py +1 -1
- camel/interpreters/__init__.py +2 -0
- camel/interpreters/docker/Dockerfile +4 -16
- camel/interpreters/docker_interpreter.py +3 -2
- camel/interpreters/e2b_interpreter.py +34 -1
- camel/interpreters/internal_python_interpreter.py +51 -2
- camel/interpreters/microsandbox_interpreter.py +395 -0
- camel/loaders/__init__.py +11 -2
- camel/loaders/base_loader.py +85 -0
- camel/loaders/chunkr_reader.py +9 -0
- camel/loaders/firecrawl_reader.py +4 -4
- camel/logger.py +1 -1
- camel/memories/agent_memories.py +84 -1
- camel/memories/base.py +34 -0
- camel/memories/blocks/chat_history_block.py +122 -4
- camel/memories/blocks/vectordb_block.py +8 -1
- camel/memories/context_creators/score_based.py +29 -237
- camel/memories/records.py +88 -8
- camel/messages/base.py +166 -40
- camel/messages/func_message.py +32 -5
- camel/models/__init__.py +10 -0
- camel/models/aihubmix_model.py +83 -0
- camel/models/aiml_model.py +1 -16
- camel/models/amd_model.py +101 -0
- camel/models/anthropic_model.py +117 -18
- camel/models/aws_bedrock_model.py +2 -33
- camel/models/azure_openai_model.py +205 -91
- camel/models/base_audio_model.py +3 -1
- camel/models/base_model.py +189 -24
- camel/models/cohere_model.py +5 -17
- camel/models/cometapi_model.py +83 -0
- camel/models/crynux_model.py +1 -16
- camel/models/deepseek_model.py +6 -16
- camel/models/fish_audio_model.py +6 -0
- camel/models/gemini_model.py +71 -20
- camel/models/groq_model.py +1 -17
- camel/models/internlm_model.py +1 -16
- camel/models/litellm_model.py +49 -32
- camel/models/lmstudio_model.py +1 -17
- camel/models/minimax_model.py +83 -0
- camel/models/mistral_model.py +1 -16
- camel/models/model_factory.py +27 -1
- camel/models/model_manager.py +24 -6
- camel/models/modelscope_model.py +1 -16
- camel/models/moonshot_model.py +185 -19
- camel/models/nebius_model.py +83 -0
- camel/models/nemotron_model.py +0 -5
- camel/models/netmind_model.py +1 -16
- camel/models/novita_model.py +1 -16
- camel/models/nvidia_model.py +1 -16
- camel/models/ollama_model.py +4 -19
- camel/models/openai_compatible_model.py +171 -46
- camel/models/openai_model.py +205 -77
- camel/models/openrouter_model.py +1 -17
- camel/models/ppio_model.py +1 -16
- camel/models/qianfan_model.py +1 -16
- camel/models/qwen_model.py +1 -16
- camel/models/reka_model.py +1 -16
- camel/models/samba_model.py +34 -47
- camel/models/sglang_model.py +64 -31
- camel/models/siliconflow_model.py +1 -16
- camel/models/stub_model.py +0 -4
- camel/models/togetherai_model.py +1 -16
- camel/models/vllm_model.py +1 -16
- camel/models/volcano_model.py +0 -17
- camel/models/watsonx_model.py +1 -16
- camel/models/yi_model.py +1 -16
- camel/models/zhipuai_model.py +60 -16
- camel/parsers/__init__.py +18 -0
- camel/parsers/mcp_tool_call_parser.py +176 -0
- camel/retrievers/auto_retriever.py +1 -0
- camel/runtimes/configs.py +11 -11
- camel/runtimes/daytona_runtime.py +15 -16
- camel/runtimes/docker_runtime.py +6 -6
- camel/runtimes/remote_http_runtime.py +5 -5
- camel/services/agent_openapi_server.py +380 -0
- camel/societies/__init__.py +2 -0
- camel/societies/role_playing.py +26 -28
- camel/societies/workforce/__init__.py +2 -0
- camel/societies/workforce/events.py +122 -0
- camel/societies/workforce/prompts.py +249 -38
- camel/societies/workforce/role_playing_worker.py +82 -20
- camel/societies/workforce/single_agent_worker.py +634 -34
- camel/societies/workforce/structured_output_handler.py +512 -0
- camel/societies/workforce/task_channel.py +169 -23
- camel/societies/workforce/utils.py +176 -9
- camel/societies/workforce/worker.py +77 -23
- camel/societies/workforce/workflow_memory_manager.py +772 -0
- camel/societies/workforce/workforce.py +3168 -478
- camel/societies/workforce/workforce_callback.py +74 -0
- camel/societies/workforce/workforce_logger.py +203 -175
- camel/societies/workforce/workforce_metrics.py +33 -0
- camel/storages/__init__.py +4 -0
- camel/storages/key_value_storages/json.py +15 -2
- camel/storages/key_value_storages/mem0_cloud.py +48 -47
- camel/storages/object_storages/google_cloud.py +1 -1
- camel/storages/vectordb_storages/__init__.py +6 -0
- camel/storages/vectordb_storages/chroma.py +731 -0
- camel/storages/vectordb_storages/oceanbase.py +13 -13
- camel/storages/vectordb_storages/pgvector.py +349 -0
- camel/storages/vectordb_storages/qdrant.py +3 -3
- camel/storages/vectordb_storages/surreal.py +365 -0
- camel/storages/vectordb_storages/tidb.py +8 -6
- camel/tasks/task.py +244 -27
- camel/toolkits/__init__.py +46 -8
- camel/toolkits/aci_toolkit.py +64 -19
- camel/toolkits/arxiv_toolkit.py +6 -6
- camel/toolkits/base.py +63 -5
- camel/toolkits/code_execution.py +28 -1
- camel/toolkits/context_summarizer_toolkit.py +684 -0
- camel/toolkits/craw4ai_toolkit.py +93 -0
- camel/toolkits/dappier_toolkit.py +10 -6
- camel/toolkits/dingtalk.py +1135 -0
- camel/toolkits/edgeone_pages_mcp_toolkit.py +49 -0
- camel/toolkits/excel_toolkit.py +901 -67
- camel/toolkits/file_toolkit.py +1402 -0
- camel/toolkits/function_tool.py +30 -6
- camel/toolkits/github_toolkit.py +107 -20
- camel/toolkits/gmail_toolkit.py +1839 -0
- camel/toolkits/google_calendar_toolkit.py +38 -4
- camel/toolkits/google_drive_mcp_toolkit.py +54 -0
- camel/toolkits/human_toolkit.py +34 -10
- camel/toolkits/hybrid_browser_toolkit/__init__.py +18 -0
- camel/toolkits/hybrid_browser_toolkit/config_loader.py +185 -0
- camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit.py +246 -0
- camel/toolkits/hybrid_browser_toolkit/hybrid_browser_toolkit_ts.py +1973 -0
- camel/toolkits/hybrid_browser_toolkit/installer.py +203 -0
- camel/toolkits/hybrid_browser_toolkit/ts/package-lock.json +3749 -0
- camel/toolkits/hybrid_browser_toolkit/ts/package.json +32 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/browser-scripts.js +125 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/browser-session.ts +1815 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/config-loader.ts +233 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/hybrid-browser-toolkit.ts +590 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/index.ts +7 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/parent-child-filter.ts +226 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/snapshot-parser.ts +219 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/som-screenshot-injected.ts +543 -0
- camel/toolkits/hybrid_browser_toolkit/ts/src/types.ts +130 -0
- camel/toolkits/hybrid_browser_toolkit/ts/tsconfig.json +26 -0
- camel/toolkits/hybrid_browser_toolkit/ts/websocket-server.js +319 -0
- camel/toolkits/hybrid_browser_toolkit/ws_wrapper.py +1032 -0
- camel/toolkits/hybrid_browser_toolkit_py/__init__.py +17 -0
- camel/toolkits/hybrid_browser_toolkit_py/actions.py +575 -0
- camel/toolkits/hybrid_browser_toolkit_py/agent.py +311 -0
- camel/toolkits/hybrid_browser_toolkit_py/browser_session.py +787 -0
- camel/toolkits/hybrid_browser_toolkit_py/config_loader.py +490 -0
- camel/toolkits/hybrid_browser_toolkit_py/hybrid_browser_toolkit.py +2390 -0
- camel/toolkits/hybrid_browser_toolkit_py/snapshot.py +233 -0
- camel/toolkits/hybrid_browser_toolkit_py/stealth_script.js +0 -0
- camel/toolkits/hybrid_browser_toolkit_py/unified_analyzer.js +1043 -0
- camel/toolkits/image_generation_toolkit.py +390 -0
- camel/toolkits/jina_reranker_toolkit.py +3 -4
- camel/toolkits/klavis_toolkit.py +5 -1
- camel/toolkits/markitdown_toolkit.py +104 -0
- camel/toolkits/math_toolkit.py +64 -10
- camel/toolkits/mcp_toolkit.py +370 -45
- camel/toolkits/memory_toolkit.py +5 -1
- camel/toolkits/message_agent_toolkit.py +608 -0
- camel/toolkits/message_integration.py +724 -0
- camel/toolkits/minimax_mcp_toolkit.py +195 -0
- camel/toolkits/note_taking_toolkit.py +277 -0
- camel/toolkits/notion_mcp_toolkit.py +224 -0
- camel/toolkits/openbb_toolkit.py +5 -1
- camel/toolkits/origene_mcp_toolkit.py +56 -0
- camel/toolkits/playwright_mcp_toolkit.py +12 -31
- camel/toolkits/pptx_toolkit.py +25 -12
- camel/toolkits/resend_toolkit.py +168 -0
- camel/toolkits/screenshot_toolkit.py +213 -0
- camel/toolkits/search_toolkit.py +437 -142
- camel/toolkits/slack_toolkit.py +104 -50
- camel/toolkits/sympy_toolkit.py +1 -1
- camel/toolkits/task_planning_toolkit.py +3 -3
- camel/toolkits/terminal_toolkit/__init__.py +18 -0
- camel/toolkits/terminal_toolkit/terminal_toolkit.py +957 -0
- camel/toolkits/terminal_toolkit/utils.py +532 -0
- camel/toolkits/thinking_toolkit.py +1 -1
- camel/toolkits/vertex_ai_veo_toolkit.py +590 -0
- camel/toolkits/video_analysis_toolkit.py +106 -26
- camel/toolkits/video_download_toolkit.py +17 -14
- camel/toolkits/web_deploy_toolkit.py +1219 -0
- camel/toolkits/wechat_official_toolkit.py +483 -0
- camel/toolkits/zapier_toolkit.py +5 -1
- camel/types/__init__.py +2 -2
- camel/types/agents/tool_calling_record.py +4 -1
- camel/types/enums.py +316 -40
- camel/types/openai_types.py +2 -2
- camel/types/unified_model_type.py +31 -4
- camel/utils/commons.py +36 -5
- camel/utils/constants.py +3 -0
- camel/utils/context_utils.py +1003 -0
- camel/utils/mcp.py +138 -4
- camel/utils/mcp_client.py +45 -1
- camel/utils/message_summarizer.py +148 -0
- camel/utils/token_counting.py +43 -20
- camel/utils/tool_result.py +44 -0
- {camel_ai-0.2.67.dist-info → camel_ai-0.2.80a2.dist-info}/METADATA +296 -85
- {camel_ai-0.2.67.dist-info → camel_ai-0.2.80a2.dist-info}/RECORD +219 -146
- camel/loaders/pandas_reader.py +0 -368
- camel/toolkits/dalle_toolkit.py +0 -175
- camel/toolkits/file_write_toolkit.py +0 -444
- camel/toolkits/openai_agent_toolkit.py +0 -135
- camel/toolkits/terminal_toolkit.py +0 -1037
- {camel_ai-0.2.67.dist-info → camel_ai-0.2.80a2.dist-info}/WHEEL +0 -0
- {camel_ai-0.2.67.dist-info → camel_ai-0.2.80a2.dist-info}/licenses/LICENSE +0 -0
camel/agents/chat_agent.py
CHANGED
@@ -13,34 +13,51 @@
 # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
 from __future__ import annotations
 
+import asyncio
+import atexit
+import base64
+import concurrent.futures
+import hashlib
+import inspect
 import json
-import
+import os
+import random
+import re
+import tempfile
 import textwrap
 import threading
+import time
 import uuid
-
+import warnings
+from dataclasses import dataclass
+from datetime import datetime
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Any,
+    AsyncGenerator,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Set,
     Tuple,
     Type,
     Union,
+    cast,
 )
 
 from openai import (
     AsyncStream,
+    RateLimitError,
     Stream,
 )
 from pydantic import BaseModel, ValidationError
 
 from camel.agents._types import ModelResponse, ToolCallRequest
 from camel.agents._utils import (
+    build_default_summary_prompt,
     convert_to_function_tool,
     convert_to_schema,
     get_info_dict,
@@ -48,13 +65,19 @@ from camel.agents._utils import (
     safe_model_dump,
 )
 from camel.agents.base import BaseAgent
+from camel.logger import get_logger
 from camel.memories import (
     AgentMemory,
     ChatHistoryMemory,
+    ContextRecord,
     MemoryRecord,
     ScoreBasedContextCreator,
 )
-from camel.messages import
+from camel.messages import (
+    BaseMessage,
+    FunctionCallingMessage,
+    OpenAIMessage,
+)
 from camel.models import (
     BaseModelBackend,
     ModelFactory,
@@ -64,7 +87,7 @@ from camel.models import (
 from camel.prompts import TextPrompt
 from camel.responses import ChatAgentResponse
 from camel.storages import JsonStorage
-from camel.toolkits import FunctionTool
+from camel.toolkits import FunctionTool, RegisteredAgentToolkit
 from camel.types import (
     ChatCompletion,
     ChatCompletionChunk,
@@ -75,20 +98,46 @@ from camel.types import (
 )
 from camel.types.agents import ToolCallingRecord
 from camel.utils import (
+    Constants,
     get_model_encoding,
     model_from_json_schema,
 )
 from camel.utils.commons import dependencies_required
+from camel.utils.context_utils import ContextUtility
+
+TOKEN_LIMIT_ERROR_MARKERS = (
+    "context_length_exceeded",
+    "prompt is too long",
+    "exceeded your current quota",
+    "tokens must be reduced",
+    "context length",
+    "token count",
+    "context limit",
+)
 
 if TYPE_CHECKING:
     from camel.terminators import ResponseTerminator
 
-logger =
+logger = get_logger(__name__)
+
+# Cleanup temp files on exit
+_temp_files: Set[str] = set()
+_temp_files_lock = threading.Lock()
+
+
+def _cleanup_temp_files():
+    with _temp_files_lock:
+        for path in _temp_files:
+            try:
+                os.unlink(path)
+            except Exception:
+                pass
+
+
+atexit.register(_cleanup_temp_files)
 
 # AgentOps decorator setting
 try:
-    import os
-
     if os.getenv("AGENTOPS_API_KEY") is not None:
         from agentops import track_agent
     else:
@@ -102,6 +151,11 @@ if os.environ.get("LANGFUSE_ENABLED", "False").lower() == "true":
         from langfuse.decorators import observe
     except ImportError:
         from camel.utils import observe
+elif os.environ.get("TRACEROOT_ENABLED", "False").lower() == "true":
+    try:
+        from traceroot import trace as observe  # type: ignore[import]
+    except ImportError:
+        from camel.utils import observe
 else:
     from camel.utils import observe
 
@@ -117,6 +171,189 @@ SIMPLE_FORMAT_PROMPT = TextPrompt(
 )
 
 
+@dataclass
+class _ToolOutputHistoryEntry:
+    tool_name: str
+    tool_call_id: str
+    result_text: str
+    record_uuids: List[str]
+    record_timestamps: List[float]
+    cached: bool = False
+
+
+class StreamContentAccumulator:
+    r"""Manages content accumulation across streaming responses to ensure
+    all responses contain complete cumulative content."""
+
+    def __init__(self):
+        self.base_content = ""  # Content before tool calls
+        self.current_content = []  # Accumulated streaming fragments
+        self.tool_status_messages = []  # Accumulated tool status messages
+
+    def set_base_content(self, content: str):
+        r"""Set the base content (usually empty or pre-tool content)."""
+        self.base_content = content
+
+    def add_streaming_content(self, new_content: str):
+        r"""Add new streaming content."""
+        self.current_content.append(new_content)
+
+    def add_tool_status(self, status_message: str):
+        r"""Add a tool status message."""
+        self.tool_status_messages.append(status_message)
+
+    def get_full_content(self) -> str:
+        r"""Get the complete accumulated content."""
+        tool_messages = "".join(self.tool_status_messages)
+        current = "".join(self.current_content)
+        return self.base_content + tool_messages + current
+
+    def get_content_with_new_status(self, status_message: str) -> str:
+        r"""Get content with a new status message appended."""
+        tool_messages = "".join([*self.tool_status_messages, status_message])
+        current = "".join(self.current_content)
+        return self.base_content + tool_messages + current
+
+    def reset_streaming_content(self):
+        r"""Reset only the streaming content, keep base and tool status."""
+        self.current_content = []
+
+
+class StreamingChatAgentResponse:
+    r"""A wrapper that makes streaming responses compatible with
+    non-streaming code.
+
+    This class wraps a Generator[ChatAgentResponse, None, None] and provides
+    the same interface as ChatAgentResponse, so existing code doesn't need to
+    change.
+    """
+
+    def __init__(self, generator: Generator[ChatAgentResponse, None, None]):
+        self._generator = generator
+        self._current_response: Optional[ChatAgentResponse] = None
+        self._responses: List[ChatAgentResponse] = []
+        self._consumed = False
+
+    def _ensure_latest_response(self):
+        r"""Ensure we have the latest response by consuming the generator."""
+        if not self._consumed:
+            for response in self._generator:
+                self._responses.append(response)
+                self._current_response = response
+            self._consumed = True
+
+    @property
+    def msgs(self) -> List[BaseMessage]:
+        r"""Get messages from the latest response."""
+        self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response.msgs
+        return []
+
+    @property
+    def terminated(self) -> bool:
+        r"""Get terminated status from the latest response."""
+        self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response.terminated
+        return False
+
+    @property
+    def info(self) -> Dict[str, Any]:
+        r"""Get info from the latest response."""
+        self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response.info
+        return {}
+
+    @property
+    def msg(self):
+        r"""Get the single message if there's exactly one message."""
+        self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response.msg
+        return None
+
+    def __iter__(self):
+        r"""Make this object iterable."""
+        if self._consumed:
+            # If already consumed, iterate over stored responses
+            yield from self._responses
+        else:
+            # If not consumed, consume and yield
+            for response in self._generator:
+                self._responses.append(response)
+                self._current_response = response
+                yield response
+            self._consumed = True
+
+    def __getattr__(self, name):
+        r"""Forward any other attribute access to the latest response."""
+        self._ensure_latest_response()
+        if self._current_response and hasattr(self._current_response, name):
+            return getattr(self._current_response, name)
+        raise AttributeError(
+            f"'StreamingChatAgentResponse' object has no attribute '{name}'"
+        )
+
+
+class AsyncStreamingChatAgentResponse:
+    r"""A wrapper that makes async streaming responses awaitable and
+    compatible with non-streaming code.
+
+    This class wraps an AsyncGenerator[ChatAgentResponse, None] and provides
+    both awaitable and async iterable interfaces.
+    """
+
+    def __init__(
+        self, async_generator: AsyncGenerator[ChatAgentResponse, None]
+    ):
+        self._async_generator = async_generator
+        self._current_response: Optional[ChatAgentResponse] = None
+        self._responses: List[ChatAgentResponse] = []
+        self._consumed = False
+
+    async def _ensure_latest_response(self):
+        r"""Ensure the latest response by consuming the async generator."""
+        if not self._consumed:
+            async for response in self._async_generator:
+                self._responses.append(response)
+                self._current_response = response
+            self._consumed = True
+
+    async def _get_final_response(self) -> ChatAgentResponse:
+        r"""Get the final response after consuming the entire stream."""
+        await self._ensure_latest_response()
+        if self._current_response:
+            return self._current_response
+        # Return a default response if nothing was consumed
+        return ChatAgentResponse(msgs=[], terminated=False, info={})
+
+    def __await__(self):
+        r"""Make this object awaitable - returns the final response."""
+        return self._get_final_response().__await__()
+
+    def __aiter__(self):
+        r"""Make this object async iterable."""
+        if self._consumed:
+            # If already consumed, create async iterator from stored responses
+            async def _async_iter():
+                for response in self._responses:
+                    yield response
+
+            return _async_iter()
+        else:
+            # If not consumed, consume and yield
+            async def _consume_and_yield():
+                async for response in self._async_generator:
+                    self._responses.append(response)
+                    self._current_response = response
+                    yield response
+                self._consumed = True
+
+            return _consume_and_yield()
+
+
 @track_agent(name="ChatAgent")
 class ChatAgent(BaseAgent):
     r"""Class for managing conversations of CAMEL Chat Agents.
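# --- Editor's illustration (not part of the package diff) -------------------
# A minimal sketch of how the StreamingChatAgentResponse wrapper added above
# behaves. It assumes camel-ai >= 0.2.80a2, where the class is defined in
# camel.agents.chat_agent; fake_stream() is a stand-in for whatever streaming
# step actually yields partial ChatAgentResponse objects.
from camel.agents.chat_agent import StreamingChatAgentResponse
from camel.responses import ChatAgentResponse


def fake_stream():
    # Stand-in for the partial responses produced during a streaming step.
    for i in range(3):
        yield ChatAgentResponse(msgs=[], terminated=False, info={"chunk": i})


streaming = StreamingChatAgentResponse(fake_stream())

# Streaming consumers can iterate over the partial responses...
for partial in streaming:
    print(partial.info)

# ...while non-streaming code keeps reading the final state directly.
print(streaming.terminated)  # False
print(streaming.info)        # info of the last yielded response: {'chunk': 2}
# -----------------------------------------------------------------------------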
@@ -140,15 +377,22 @@ class ChatAgent(BaseAgent):
         message_window_size (int, optional): The maximum number of previous
             messages to include in the context window. If `None`, no windowing
             is performed. (default: :obj:`None`)
-
-
-
+        summarize_threshold (int, optional): The percentage of the context
+            window that triggers summarization. If `None`, will trigger
+            summarization when the context window is full.
             (default: :obj:`None`)
         output_language (str, optional): The language to be output by the
             agent. (default: :obj:`None`)
         tools (Optional[List[Union[FunctionTool, Callable]]], optional): List
             of available :obj:`FunctionTool` or :obj:`Callable`. (default:
             :obj:`None`)
+        toolkits_to_register_agent (Optional[List[RegisteredAgentToolkit]],
+            optional): List of toolkit instances that inherit from
+            :obj:`RegisteredAgentToolkit`. The agent will register itself with
+            these toolkits, allowing them to access the agent instance. Note:
+            This does NOT add the toolkit's tools to the agent. To use tools
+            from these toolkits, pass them explicitly via the `tools`
+            parameter. (default: :obj:`None`)
         external_tools (Optional[List[Union[FunctionTool, Callable,
             Dict[str, Any]]]], optional): List of external tools
             (:obj:`FunctionTool` or :obj:`Callable` or :obj:`Dict[str, Any]`)
@@ -169,6 +413,39 @@ class ChatAgent(BaseAgent):
         stop_event (Optional[threading.Event], optional): Event to signal
             termination of the agent's operation. When set, the agent will
             terminate its execution. (default: :obj:`None`)
+        tool_execution_timeout (Optional[float], optional): Timeout
+            for individual tool execution. If None, wait indefinitely.
+        mask_tool_output (Optional[bool]): Whether to return a sanitized
+            placeholder instead of the raw tool output. (default: :obj:`False`)
+        pause_event (Optional[Union[threading.Event, asyncio.Event]]): Event to
+            signal pause of the agent's operation. When clear, the agent will
+            pause its execution. Use threading.Event for sync operations or
+            asyncio.Event for async operations. (default: :obj:`None`)
+        prune_tool_calls_from_memory (bool): Whether to clean tool
+            call messages from memory after response generation to save token
+            usage. When enabled, removes FUNCTION/TOOL role messages and
+            ASSISTANT messages with tool_calls after each step.
+            (default: :obj:`False`)
+        enable_snapshot_clean (bool, optional): Whether to clean snapshot
+            markers and references from historical tool outputs in memory.
+            This removes verbose DOM markers (like [ref=...]) from older tool
+            results while keeping the latest output intact for immediate use.
+            (default: :obj:`False`)
+        retry_attempts (int, optional): Maximum number of retry attempts for
+            rate limit errors. (default: :obj:`3`)
+        retry_delay (float, optional): Initial delay in seconds between
+            retries. Uses exponential backoff. (default: :obj:`1.0`)
+        step_timeout (Optional[float], optional): Timeout in seconds for the
+            entire step operation. If None, no timeout is applied.
+            (default: :obj:`None`)
+        stream_accumulate (bool, optional): When True, partial streaming
+            updates return accumulated content (current behavior). When False,
+            partial updates return only the incremental delta. (default:
+            :obj:`True`)
+        summary_window_ratio (float, optional): Maximum fraction of the total
+            context window that can be occupied by summary information. Used
+            to limit how much of the model's context is reserved for
+            summarization results. (default: :obj:`0.6`)
     """
 
     def __init__(
@@ -191,9 +468,13 @@ class ChatAgent(BaseAgent):
         ] = None,
         memory: Optional[AgentMemory] = None,
         message_window_size: Optional[int] = None,
+        summarize_threshold: Optional[int] = 50,
         token_limit: Optional[int] = None,
         output_language: Optional[str] = None,
         tools: Optional[List[Union[FunctionTool, Callable]]] = None,
+        toolkits_to_register_agent: Optional[
+            List[RegisteredAgentToolkit]
+        ] = None,
         external_tools: Optional[
             List[Union[FunctionTool, Callable, Dict[str, Any]]]
         ] = None,
@@ -202,6 +483,16 @@ class ChatAgent(BaseAgent):
         max_iteration: Optional[int] = None,
         agent_id: Optional[str] = None,
         stop_event: Optional[threading.Event] = None,
+        tool_execution_timeout: Optional[float] = Constants.TIMEOUT_THRESHOLD,
+        mask_tool_output: bool = False,
+        pause_event: Optional[Union[threading.Event, asyncio.Event]] = None,
+        prune_tool_calls_from_memory: bool = False,
+        enable_snapshot_clean: bool = False,
+        retry_attempts: int = 3,
+        retry_delay: float = 1.0,
+        step_timeout: Optional[float] = Constants.TIMEOUT_THRESHOLD,
+        stream_accumulate: bool = True,
+        summary_window_ratio: float = 0.6,
     ) -> None:
         if isinstance(model, ModelManager):
             self.model_backend = model
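# --- Editor's illustration (not part of the package diff) -------------------
# A hedged sketch of the new constructor knobs shown in the hunk above, using
# the parameter names and defaults visible in the diff. Passing a plain
# system-message string and relying on the default model backend are
# assumptions based on common ChatAgent usage, not on this diff.
from camel.agents import ChatAgent

agent = ChatAgent(
    "You are a helpful assistant.",
    summarize_threshold=50,             # compress once ~50% of the context is used
    prune_tool_calls_from_memory=True,  # drop tool-call messages after each step
    retry_attempts=3,                   # retries on rate-limit errors
    retry_delay=1.0,                    # initial backoff in seconds (exponential)
    stream_accumulate=False,            # partial updates carry only the delta
)

response = agent.step("Summarize the CAMEL project in one sentence.")
print(response.msgs[0].content)
# -----------------------------------------------------------------------------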
@@ -217,13 +508,16 @@ class ChatAgent(BaseAgent):
         # Assign unique ID
         self.agent_id = agent_id if agent_id else str(uuid.uuid4())
 
+        self._enable_snapshot_clean = enable_snapshot_clean
+        self._tool_output_history: List[_ToolOutputHistoryEntry] = []
+
         # Set up memory
         context_creator = ScoreBasedContextCreator(
             self.model_backend.token_counter,
-
+            self.model_backend.token_limit,
         )
 
-        self.
+        self._memory: AgentMemory = memory or ChatHistoryMemory(
             context_creator,
             window_size=message_window_size,
             agent_id=self.agent_id,
@@ -231,13 +525,11 @@ class ChatAgent(BaseAgent):
 
         # So we don't have to pass agent_id when we define memory
         if memory is not None:
-
+            self._memory.agent_id = self.agent_id
 
         # Set up system message and initialize messages
         self._original_system_message = (
-            BaseMessage.
-                role_name="Assistant", content=system_message
-            )
+            BaseMessage.make_system_message(system_message)
             if isinstance(system_message, str)
             else system_message
         )
@@ -247,6 +539,21 @@ class ChatAgent(BaseAgent):
         )
         self.init_messages()
 
+        # Set up summarize threshold with validation
+        if summarize_threshold is not None:
+            if not (0 < summarize_threshold <= 100):
+                raise ValueError(
+                    f"summarize_threshold must be between 0 and 100, "
+                    f"got {summarize_threshold}"
+                )
+            logger.info(
+                f"Automatic context compression is enabled. Will trigger "
+                f"summarization when context window exceeds "
+                f"{summarize_threshold}% of the total token limit."
+            )
+        self.summarize_threshold = summarize_threshold
+        self._reset_summary_state()
+
         # Set up role name and role type
         self.role_name: str = (
             getattr(self.system_message, "role_name", None) or "assistant"
@@ -264,6 +571,12 @@ class ChatAgent(BaseAgent):
             ]
         }
 
+        # Register agent with toolkits that have RegisteredAgentToolkit mixin
+        if toolkits_to_register_agent:
+            for toolkit in toolkits_to_register_agent:
+                if isinstance(toolkit, RegisteredAgentToolkit):
+                    toolkit.register_agent(self)
+
         self._external_tool_schemas = {
             tool_schema["function"]["name"]: tool_schema
             for tool_schema in [
@@ -276,11 +589,28 @@ class ChatAgent(BaseAgent):
         self.response_terminators = response_terminators or []
         self.max_iteration = max_iteration
         self.stop_event = stop_event
+        self.tool_execution_timeout = tool_execution_timeout
+        self.mask_tool_output = mask_tool_output
+        self._secure_result_store: Dict[str, Any] = {}
+        self._secure_result_store_lock = threading.Lock()
+        self.pause_event = pause_event
+        self.prune_tool_calls_from_memory = prune_tool_calls_from_memory
+        self.retry_attempts = max(1, retry_attempts)
+        self.retry_delay = max(0.0, retry_delay)
+        self.step_timeout = step_timeout
+        self._context_utility: Optional[ContextUtility] = None
+        self._context_summary_agent: Optional["ChatAgent"] = None
+        self.stream_accumulate = stream_accumulate
+        self._last_tool_call_record: Optional[ToolCallingRecord] = None
+        self._last_tool_call_signature: Optional[str] = None
+        self._last_token_limit_tool_signature: Optional[str] = None
+        self.summary_window_ratio = summary_window_ratio
 
     def reset(self):
         r"""Resets the :obj:`ChatAgent` to its initial state."""
         self.terminated = False
         self.init_messages()
+        self._reset_summary_state()
         for terminator in self.response_terminators:
             terminator.reset()
 
@@ -402,7 +732,10 @@ class ChatAgent(BaseAgent):
             # List of tuples (platform, type)
             resolved_models_list = []
             for model_spec in model_list:
-                platform, type_ =
+                platform, type_ = (  # type: ignore[index]
+                    model_spec[0],
+                    model_spec[1],
+                )
                 resolved_models_list.append(
                     ModelFactory.create(
                         model_platform=platform, model_type=type_
@@ -442,6 +775,39 @@ class ChatAgent(BaseAgent):
         )
         self.init_messages()
 
+    @property
+    def memory(self) -> AgentMemory:
+        r"""Returns the agent memory."""
+        return self._memory
+
+    @memory.setter
+    def memory(self, value: AgentMemory) -> None:
+        r"""Set the agent memory.
+
+        When setting a new memory, the system message is automatically
+        re-added to ensure it's not lost.
+
+        Args:
+            value (AgentMemory): The new agent memory to use.
+        """
+        self._memory = value
+        # Ensure the new memory has the system message
+        self.init_messages()
+
+    def set_context_utility(
+        self, context_utility: Optional[ContextUtility]
+    ) -> None:
+        r"""Set the context utility for the agent.
+
+        This allows external components (like SingleAgentWorker) to provide
+        a shared context utility instance for workflow management.
+
+        Args:
+            context_utility (ContextUtility, optional): The context utility
+                to use. If None, the agent will create its own when needed.
+        """
+        self._context_utility = context_utility
+
     def _get_full_tool_schemas(self) -> List[Dict[str, Any]]:
         r"""Returns a list of tool schemas of all tools, including internal
         and external tools.
@@ -451,6 +817,329 @@ class ChatAgent(BaseAgent):
             for func_tool in self._internal_tools.values()
         ]
 
+    @staticmethod
+    def _is_token_limit_error(error: Exception) -> bool:
+        r"""Return True when the exception message indicates a token limit."""
+        error_message = str(error).lower()
+        return any(
+            marker in error_message for marker in TOKEN_LIMIT_ERROR_MARKERS
+        )
+
+    @staticmethod
+    def _is_tool_related_record(record: MemoryRecord) -> bool:
+        r"""Determine whether the given memory record
+        belongs to a tool call."""
+        if record.role_at_backend in {
+            OpenAIBackendRole.TOOL,
+            OpenAIBackendRole.FUNCTION,
+        }:
+            return True
+
+        if (
+            record.role_at_backend == OpenAIBackendRole.ASSISTANT
+            and isinstance(record.message, FunctionCallingMessage)
+        ):
+            return True
+
+        return False
+
+    def _find_indices_to_remove_for_last_tool_pair(
+        self, recent_records: List[ContextRecord]
+    ) -> List[int]:
+        """Find indices of records that should be removed to clean up the most
+        recent incomplete tool interaction pair.
+
+        This method identifies tool call/result pairs by tool_call_id and
+        returns the exact indices to remove, allowing non-contiguous deletions.
+
+        Logic:
+        - If the last record is a tool result (TOOL/FUNCTION) with a
+          tool_call_id, find the matching assistant call anywhere in history
+          and return both indices.
+        - If the last record is an assistant tool call without a result yet,
+          return just that index.
+        - For normal messages (non tool-related): remove just the last one.
+        - Fallback: If no tool_call_id is available, use heuristic (last 2 if
+          tool-related, otherwise last 1).
+
+        Returns:
+            List[int]: Indices to remove (may be non-contiguous).
+        """
+        if not recent_records:
+            return []
+
+        last_idx = len(recent_records) - 1
+        last_record = recent_records[last_idx].memory_record
+
+        # Case A: Last is an ASSISTANT tool call with no result yet
+        if (
+            last_record.role_at_backend == OpenAIBackendRole.ASSISTANT
+            and isinstance(last_record.message, FunctionCallingMessage)
+            and last_record.message.result is None
+        ):
+            return [last_idx]
+
+        # Case B: Last is TOOL/FUNCTION result, try id-based pairing
+        if last_record.role_at_backend in {
+            OpenAIBackendRole.TOOL,
+            OpenAIBackendRole.FUNCTION,
+        }:
+            tool_id = None
+            if isinstance(last_record.message, FunctionCallingMessage):
+                tool_id = last_record.message.tool_call_id
+
+            if tool_id:
+                for idx in range(len(recent_records) - 2, -1, -1):
+                    rec = recent_records[idx].memory_record
+                    if rec.role_at_backend != OpenAIBackendRole.ASSISTANT:
+                        continue
+
+                    # Check if this assistant message contains the tool_call_id
+                    matched = False
+
+                    # Case 1: FunctionCallingMessage (single tool call)
+                    if isinstance(rec.message, FunctionCallingMessage):
+                        if rec.message.tool_call_id == tool_id:
+                            matched = True
+
+                    # Case 2: BaseMessage with multiple tool_calls in meta_dict
+                    elif (
+                        hasattr(rec.message, "meta_dict")
+                        and rec.message.meta_dict
+                    ):
+                        tool_calls_list = rec.message.meta_dict.get(
+                            "tool_calls", []
+                        )
+                        if isinstance(tool_calls_list, list):
+                            for tc in tool_calls_list:
+                                if (
+                                    isinstance(tc, dict)
+                                    and tc.get("id") == tool_id
+                                ):
+                                    matched = True
+                                    break
+
+                    if matched:
+                        # Return both assistant call and tool result indices
+                        return [idx, last_idx]
+
+            # Fallback: no tool_call_id, use heuristic
+            if self._is_tool_related_record(last_record):
+                # Remove last 2 (assume they are paired)
+                return [last_idx - 1, last_idx] if last_idx > 0 else [last_idx]
+            else:
+                return [last_idx]
+
+        # Default: non tool-related tail => remove last one
+        return [last_idx]
+
+    @staticmethod
+    def _serialize_tool_args(args: Dict[str, Any]) -> str:
+        try:
+            return json.dumps(args, ensure_ascii=False, sort_keys=True)
+        except TypeError:
+            return str(args)
+
+    @classmethod
+    def _build_tool_signature(
+        cls, func_name: str, args: Dict[str, Any]
+    ) -> str:
+        args_repr = cls._serialize_tool_args(args)
+        return f"{func_name}:{args_repr}"
+
+    def _describe_tool_call(
+        self, record: Optional[ToolCallingRecord]
+    ) -> Optional[str]:
+        if record is None:
+            return None
+        args_repr = self._serialize_tool_args(record.args)
+        return f"Tool `{record.tool_name}` invoked with arguments {args_repr}."
+
+    def _update_last_tool_call_state(
+        self, record: Optional[ToolCallingRecord]
+    ) -> None:
+        """Track the most recent tool call and its identifying signature."""
+        self._last_tool_call_record = record
+        if record is None:
+            self._last_tool_call_signature = None
+            return
+
+        args = (
+            record.args
+            if isinstance(record.args, dict)
+            else {"_raw": record.args}
+        )
+        try:
+            signature = self._build_tool_signature(record.tool_name, args)
+        except Exception:  # pragma: no cover - defensive guard
+            signature = None
+        self._last_tool_call_signature = signature
+
+    def _format_tool_limit_notice(self) -> Optional[str]:
+        record = self._last_tool_call_record
+        description = self._describe_tool_call(record)
+        if description is None:
+            return None
+        notice_lines = [
+            "[Tool Call Causing Token Limit]",
+            description,
+        ]
+
+        if record is not None:
+            result = record.result
+            if isinstance(result, bytes):
+                result_repr = result.decode(errors="replace")
+            elif isinstance(result, str):
+                result_repr = result
+            else:
+                try:
+                    result_repr = json.dumps(
+                        result, ensure_ascii=False, sort_keys=True
+                    )
+                except (TypeError, ValueError):
+                    result_repr = str(result)
+
+            result_length = len(result_repr)
+            notice_lines.append(f"Tool result length: {result_length}")
+        if self.model_backend.token_limit != 999999999:
+            notice_lines.append(
+                f"Token limit: {self.model_backend.token_limit}"
+            )
+
+        return "\n".join(notice_lines)
+
+    @staticmethod
+    def _append_user_messages_section(
+        summary_content: str, user_messages: List[str]
+    ) -> str:
+        section_title = "- **All User Messages**:"
+        sanitized_messages: List[str] = []
+        for msg in user_messages:
+            if not isinstance(msg, str):
+                msg = str(msg)
+            cleaned = " ".join(msg.strip().splitlines())
+            if cleaned:
+                sanitized_messages.append(cleaned)
+
+        bullet_block = (
+            "\n".join(f"- {m}" for m in sanitized_messages)
+            if sanitized_messages
+            else "- None noted"
+        )
+        user_section = f"{section_title}\n{bullet_block}"
+
+        summary_clean = summary_content.rstrip()
+        separator = "\n\n" if summary_clean else ""
+        return f"{summary_clean}{separator}{user_section}"
+
+    def _reset_summary_state(self) -> None:
+        self._summary_token_count = 0  # Total tokens in summary messages
+
+    def _calculate_next_summary_threshold(self) -> int:
+        r"""Calculate the next token threshold that should trigger
+        summarization.
+
+        The threshold calculation follows a progressive strategy:
+        - First time: token_limit * (summarize_threshold / 100)
+        - Subsequent times: (limit - summary_token) / 2 + summary_token
+
+        This ensures that as summaries accumulate, the threshold adapts
+        to maintain a reasonable balance between context and summaries.
+
+        Returns:
+            int: The token count threshold for next summarization.
+        """
+        token_limit = self.model_backend.token_limit
+        summary_token_count = self._summary_token_count
+
+        # First summarization: use the percentage threshold
+        if summary_token_count == 0:
+            threshold = int(token_limit * self.summarize_threshold / 100)
+        else:
+            # Subsequent summarizations: adaptive threshold
+            threshold = int(
+                (token_limit - summary_token_count)
+                * self.summarize_threshold
+                / 100
+                + summary_token_count
+            )
+
+        return threshold
+
+    def _update_memory_with_summary(
+        self, summary: str, include_summaries: bool = False
+    ) -> None:
+        r"""Update memory with summary result.
+
+        This method handles memory clearing and restoration of summaries based
+        on whether it's a progressive or full compression.
+        """
+
+        summary_content: str = summary
+
+        existing_summaries = []
+        if not include_summaries:
+            messages, _ = self.memory.get_context()
+            for msg in messages:
+                content = msg.get('content', '')
+                if isinstance(content, str) and content.startswith(
+                    '[CONTEXT_SUMMARY]'
+                ):
+                    existing_summaries.append(msg)
+
+        # Clear memory
+        self.clear_memory()
+
+        # Restore old summaries (for progressive compression)
+        for old_summary in existing_summaries:
+            content = old_summary.get('content', '')
+            if not isinstance(content, str):
+                content = str(content)
+            summary_msg = BaseMessage.make_assistant_message(
+                role_name="assistant", content=content
+            )
+            self.update_memory(summary_msg, OpenAIBackendRole.ASSISTANT)
+
+        # Add new summary
+        new_summary_msg = BaseMessage.make_assistant_message(
+            role_name="assistant", content=summary_content
+        )
+        self.update_memory(new_summary_msg, OpenAIBackendRole.ASSISTANT)
+        input_message = BaseMessage.make_assistant_message(
+            role_name="assistant",
+            content=(
+                "Please continue the conversation from "
+                "where we left it off without asking the user any further "
+                "questions. Continue with the last task that you were "
+                "asked to work on."
+            ),
+        )
+        self.update_memory(input_message, OpenAIBackendRole.ASSISTANT)
+        # Update token count
+        try:
+            summary_tokens = (
+                self.model_backend.token_counter.count_tokens_from_messages(
+                    [{"role": "assistant", "content": summary_content}]
+                )
+            )
+
+            if include_summaries:  # Full compression - reset count
+                self._summary_token_count = summary_tokens
+                logger.info(
+                    f"Full compression: Summary with {summary_tokens} tokens. "
+                    f"Total summary tokens reset to: {summary_tokens}"
+                )
+            else:  # Progressive compression - accumulate
+                self._summary_token_count += summary_tokens
+                logger.info(
+                    f"Progressive compression: New summary "
+                    f"with {summary_tokens} tokens. "
+                    f"Total summary tokens: "
+                    f"{self._summary_token_count}"
+                )
+        except Exception as e:
+            logger.warning(f"Failed to count summary tokens: {e}")
+
     def _get_external_tool_names(self) -> Set[str]:
         r"""Returns a set of external tool names."""
         return set(self._external_tool_schemas.keys())
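# --- Editor's illustration (not part of the package diff) -------------------
# The progressive trigger implemented in _calculate_next_summary_threshold
# above, worked through with made-up numbers (the token limit and summary size
# here are hypothetical).
token_limit = 100_000      # hypothetical model context window
summarize_threshold = 50   # percent; the default shown in the diff

# First compression: a plain percentage of the window.
first_trigger = int(token_limit * summarize_threshold / 100)        # 50_000

# Once a ~10_000-token summary is kept in memory, the percentage is re-applied
# to the space left beside the summary, so the trigger point moves up.
summary_tokens = 10_000
next_trigger = int(
    (token_limit - summary_tokens) * summarize_threshold / 100 + summary_tokens
)                                                                    # 55_000
print(first_trigger, next_trigger)
# -----------------------------------------------------------------------------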
@@ -465,6 +1154,282 @@ class ChatAgent(BaseAgent):
|
|
|
465
1154
|
for tool in tools:
|
|
466
1155
|
self.add_tool(tool)
|
|
467
1156
|
|
|
1157
|
+
def _serialize_tool_result(self, result: Any) -> str:
|
|
1158
|
+
if isinstance(result, str):
|
|
1159
|
+
return result
|
|
1160
|
+
try:
|
|
1161
|
+
return json.dumps(result, ensure_ascii=False)
|
|
1162
|
+
except (TypeError, ValueError):
|
|
1163
|
+
return str(result)
|
|
1164
|
+
|
|
1165
|
+
def _clean_snapshot_line(self, line: str) -> str:
|
|
1166
|
+
r"""Clean a single snapshot line by removing prefixes and references.
|
|
1167
|
+
|
|
1168
|
+
This method handles snapshot lines in the format:
|
|
1169
|
+
- [prefix] "quoted text" [attributes] [ref=...]: description
|
|
1170
|
+
|
|
1171
|
+
It preserves:
|
|
1172
|
+
- Quoted text content (including brackets inside quotes)
|
|
1173
|
+
- Description text after the colon
|
|
1174
|
+
|
|
1175
|
+
It removes:
|
|
1176
|
+
- Line prefixes (e.g., "- button", "- tooltip", "generic:")
|
|
1177
|
+
- Attribute markers (e.g., [disabled], [ref=e47])
|
|
1178
|
+
- Lines with only element types
|
|
1179
|
+
- All indentation
|
|
1180
|
+
|
|
1181
|
+
Args:
|
|
1182
|
+
line: The original line content.
|
|
1183
|
+
|
|
1184
|
+
Returns:
|
|
1185
|
+
The cleaned line content, or empty string if line should be
|
|
1186
|
+
removed.
|
|
1187
|
+
"""
|
|
1188
|
+
original = line.strip()
|
|
1189
|
+
if not original:
|
|
1190
|
+
return ''
|
|
1191
|
+
|
|
1192
|
+
# Check if line is just an element type marker
|
|
1193
|
+
# (e.g., "- generic:", "button:")
|
|
1194
|
+
if re.match(r'^(?:-\s+)?\w+\s*:?\s*$', original):
|
|
1195
|
+
return ''
|
|
1196
|
+
|
|
1197
|
+
# Remove element type prefix
|
|
1198
|
+
line = re.sub(r'^(?:-\s+)?\w+[\s:]+', '', original)
|
|
1199
|
+
|
|
1200
|
+
# Remove bracket markers while preserving quoted text
|
|
1201
|
+
quoted_parts = []
|
|
1202
|
+
|
|
1203
|
+
def save_quoted(match):
|
|
1204
|
+
quoted_parts.append(match.group(0))
|
|
1205
|
+
return f'__QUOTED_{len(quoted_parts)-1}__'
|
|
1206
|
+
|
|
1207
|
+
line = re.sub(r'"[^"]*"', save_quoted, line)
|
|
1208
|
+
line = re.sub(r'\s*\[[^\]]+\]\s*', ' ', line)
|
|
1209
|
+
|
|
1210
|
+
for i, quoted in enumerate(quoted_parts):
|
|
1211
|
+
line = line.replace(f'__QUOTED_{i}__', quoted)
|
|
1212
|
+
|
|
1213
|
+
# Clean up formatting
|
|
1214
|
+
line = re.sub(r'\s+', ' ', line).strip()
|
|
1215
|
+
line = re.sub(r'\s*:\s*', ': ', line)
|
|
1216
|
+
line = line.lstrip(': ').strip()
|
|
1217
|
+
|
|
1218
|
+
return '' if not line else line
|
|
1219
|
+
|
|
1220
|
+
def _clean_snapshot_content(self, content: str) -> str:
|
|
1221
|
+
r"""Clean snapshot content by removing prefixes, references, and
|
|
1222
|
+
deduplicating lines.
|
|
1223
|
+
|
|
1224
|
+
This method identifies snapshot lines (containing element keywords or
|
|
1225
|
+
references) and cleans them while preserving non-snapshot content.
|
|
1226
|
+
It also handles JSON-formatted tool outputs with snapshot fields.
|
|
1227
|
+
|
|
1228
|
+
Args:
|
|
1229
|
+
content: The original snapshot content.
|
|
1230
|
+
|
|
1231
|
+
Returns:
|
|
1232
|
+
The cleaned content with deduplicated lines.
|
|
1233
|
+
"""
|
|
1234
|
+
try:
|
|
1235
|
+
import json
|
|
1236
|
+
|
|
1237
|
+
data = json.loads(content)
|
|
1238
|
+
modified = False
|
|
1239
|
+
|
|
1240
|
+
def clean_json_value(obj):
|
|
1241
|
+
nonlocal modified
|
|
1242
|
+
if isinstance(obj, dict):
|
|
1243
|
+
result = {}
|
|
1244
|
+
for key, value in obj.items():
|
|
1245
|
+
if key == 'snapshot' and isinstance(value, str):
|
|
1246
|
+
try:
|
|
1247
|
+
decoded_value = value.encode().decode(
|
|
1248
|
+
'unicode_escape'
|
|
1249
|
+
)
|
|
1250
|
+
except (UnicodeDecodeError, AttributeError):
|
|
1251
|
+
decoded_value = value
|
|
1252
|
+
|
|
1253
|
+
needs_cleaning = (
|
|
1254
|
+
'- ' in decoded_value
|
|
1255
|
+
or '[ref=' in decoded_value
|
|
1256
|
+
or any(
|
|
1257
|
+
elem + ':' in decoded_value
|
|
1258
|
+
for elem in [
|
|
1259
|
+
'generic',
|
|
1260
|
+
'img',
|
|
1261
|
+
'banner',
|
|
1262
|
+
'list',
|
|
1263
|
+
'listitem',
|
|
1264
|
+
'search',
|
|
1265
|
+
'navigation',
|
|
1266
|
+
]
|
|
1267
|
+
)
|
|
1268
|
+
)
|
|
1269
|
+
|
|
1270
|
+
if needs_cleaning:
|
|
1271
|
+
cleaned_snapshot = self._clean_text_snapshot(
|
|
1272
|
+
decoded_value
|
|
1273
|
+
)
|
|
1274
|
+
result[key] = cleaned_snapshot
|
|
1275
|
+
modified = True
|
|
1276
|
+
else:
|
|
1277
|
+
result[key] = value
|
|
1278
|
+
else:
|
|
1279
|
+
result[key] = clean_json_value(value)
|
|
1280
|
+
return result
|
|
1281
|
+
elif isinstance(obj, list):
|
|
1282
|
+
return [clean_json_value(item) for item in obj]
|
|
1283
|
+
else:
|
|
1284
|
+
return obj
|
|
1285
|
+
|
|
1286
|
+
cleaned_data = clean_json_value(data)
|
|
1287
|
+
|
|
1288
|
+
if modified:
|
|
1289
|
+
return json.dumps(cleaned_data, ensure_ascii=False, indent=4)
|
|
1290
|
+
else:
|
|
1291
|
+
return content
|
|
1292
|
+
|
|
1293
|
+
except (json.JSONDecodeError, TypeError):
|
|
1294
|
+
return self._clean_text_snapshot(content)
|
|
1295
|
+
|
|
1296
|
+
def _clean_text_snapshot(self, content: str) -> str:
|
|
1297
|
+
r"""Clean plain text snapshot content.
|
|
1298
|
+
|
|
1299
|
+
This method:
|
|
1300
|
+
- Removes all indentation
|
|
1301
|
+
- Deletes empty lines
|
|
1302
|
+
- Deduplicates all lines
|
|
1303
|
+
- Cleans snapshot-specific markers
|
|
1304
|
+
|
|
1305
|
+
Args:
|
|
1306
|
+
content: The original snapshot text.
|
|
1307
|
+
|
|
1308
|
+
Returns:
|
|
1309
|
+
The cleaned content with deduplicated lines, no indentation,
|
|
1310
|
+
and no empty lines.
|
|
1311
|
+
"""
|
|
1312
|
+
lines = content.split('\n')
|
|
1313
|
+
cleaned_lines = []
|
|
1314
|
+
seen = set()
|
|
1315
|
+
|
|
1316
|
+
for line in lines:
|
|
1317
|
+
stripped_line = line.strip()
|
|
1318
|
+
|
|
1319
|
+
if not stripped_line:
|
|
1320
|
+
continue
|
|
1321
|
+
|
|
1322
|
+
# Skip metadata lines (like "- /url:", "- /ref:")
|
|
1323
|
+
if re.match(r'^-?\s*/\w+\s*:', stripped_line):
|
|
1324
|
+
continue
|
|
1325
|
+
|
|
1326
|
+
is_snapshot_line = '[ref=' in stripped_line or re.match(
|
|
1327
|
+
r'^(?:-\s+)?\w+(?:[\s:]|$)', stripped_line
|
|
1328
|
+
)
|
|
1329
|
+
|
|
1330
|
+
if is_snapshot_line:
|
|
1331
|
+
cleaned = self._clean_snapshot_line(stripped_line)
|
|
1332
|
+
if cleaned and cleaned not in seen:
|
|
1333
|
+
cleaned_lines.append(cleaned)
|
|
1334
|
+
seen.add(cleaned)
|
|
1335
|
+
else:
|
|
1336
|
+
if stripped_line not in seen:
|
|
1337
|
+
cleaned_lines.append(stripped_line)
|
|
1338
|
+
seen.add(stripped_line)
|
|
1339
|
+
|
|
1340
|
+
return '\n'.join(cleaned_lines)
|
|
1341
|
+
+    def _register_tool_output_for_cache(
+        self,
+        func_name: str,
+        tool_call_id: str,
+        result_text: str,
+        records: List[MemoryRecord],
+    ) -> None:
+        if not records:
+            return
+
+        entry = _ToolOutputHistoryEntry(
+            tool_name=func_name,
+            tool_call_id=tool_call_id,
+            result_text=result_text,
+            record_uuids=[str(record.uuid) for record in records],
+            record_timestamps=[record.timestamp for record in records],
+        )
+        self._tool_output_history.append(entry)
+        self._process_tool_output_cache()
+
+    def _process_tool_output_cache(self) -> None:
+        if not self._enable_snapshot_clean or not self._tool_output_history:
+            return
+
+        # Only clean older results; keep the latest expanded for immediate use.
+        for entry in self._tool_output_history[:-1]:
+            if entry.cached:
+                continue
+            self._clean_snapshot_in_memory(entry)
+
+    def _clean_snapshot_in_memory(
+        self, entry: _ToolOutputHistoryEntry
+    ) -> None:
+        if not entry.record_uuids:
+            return
+
+        # Clean snapshot markers and references from historical tool output
+        result_text = entry.result_text
+        if '- ' in result_text and '[ref=' in result_text:
+            cleaned_result = self._clean_snapshot_content(result_text)
+
+            # Update the message in memory storage
+            timestamp = (
+                entry.record_timestamps[0]
+                if entry.record_timestamps
+                else time.time_ns() / 1_000_000_000
+            )
+            cleaned_message = FunctionCallingMessage(
+                role_name=self.role_name,
+                role_type=self.role_type,
+                meta_dict={},
+                content="",
+                func_name=entry.tool_name,
+                result=cleaned_result,
+                tool_call_id=entry.tool_call_id,
+            )
+
+            chat_history_block = getattr(
+                self.memory, "_chat_history_block", None
+            )
+            storage = getattr(chat_history_block, "storage", None)
+            if storage is None:
+                return
+
+            existing_records = storage.load()
+            updated_records = [
+                record
+                for record in existing_records
+                if record["uuid"] not in entry.record_uuids
+            ]
+            new_record = MemoryRecord(
+                message=cleaned_message,
+                role_at_backend=OpenAIBackendRole.FUNCTION,
+                timestamp=timestamp,
+                agent_id=self.agent_id,
+            )
+            updated_records.append(new_record.to_dict())
+            updated_records.sort(key=lambda record: record["timestamp"])
+            storage.clear()
+            storage.save(updated_records)
+
+            logger.info(
+                "Cleaned snapshot in memory for tool output '%s' (%s)",
+                entry.tool_name,
+                entry.tool_call_id,
+            )
+
+            entry.cached = True
+            entry.record_uuids = [str(new_record.uuid)]
+            entry.record_timestamps = [timestamp]
+
     def add_external_tool(
         self, tool: Union[FunctionTool, Callable, Dict[str, Any]]
     ) -> None:
@@ -509,7 +1474,8 @@ class ChatAgent(BaseAgent):
         message: BaseMessage,
         role: OpenAIBackendRole,
         timestamp: Optional[float] = None,
-
+        return_records: bool = False,
+    ) -> Optional[List[MemoryRecord]]:
         r"""Updates the agent memory with a new message.

         Args:
@@ -517,21 +1483,29 @@ class ChatAgent(BaseAgent):
                 messages.
             role (OpenAIBackendRole): The backend role type.
             timestamp (Optional[float], optional): Custom timestamp for the
-                memory record. If None
+                memory record. If `None`, the current time will be used.
                 (default: :obj:`None`)
+            return_records (bool, optional): When ``True`` the method returns
+                the list of MemoryRecord objects written to memory.
+                (default: :obj:`False`)
+
+        Returns:
+            Optional[List[MemoryRecord]]: The records that were written when
+            ``return_records`` is ``True``; otherwise ``None``.
         """
-            if timestamp is not None
-            else time.time_ns() / 1_000_000_000, # Nanosecond precision
-            agent_id=self.agent_id,
-        )
+        record = MemoryRecord(
+            message=message,
+            role_at_backend=role,
+            timestamp=timestamp
+            if timestamp is not None
+            else time.time_ns() / 1_000_000_000, # Nanosecond precision
+            agent_id=self.agent_id,
         )
+        self.memory.write_record(record)
+
+        if return_records:
+            return [record]
+        return None

     def load_memory(self, memory: AgentMemory) -> None:
         r"""Load the provided memory into the agent.
@@ -610,51 +1584,699 @@ class ChatAgent(BaseAgent):
             json_store.save(to_save)
         logger.info(f"Memory saved to {path}")

-    def
-        r"""Clear the agent's memory and reset to initial state.
-        Returns:
-            None
-        """
-        self.memory.clear()
-        if self.system_message is not None:
-            self.update_memory(self.system_message, OpenAIBackendRole.SYSTEM)
-    def _generate_system_message_for_output_language(
+    def summarize(
         self,
+        filename: Optional[str] = None,
+        summary_prompt: Optional[str] = None,
+        response_format: Optional[Type[BaseModel]] = None,
+        working_directory: Optional[Union[str, Path]] = None,
+        include_summaries: bool = False,
+        add_user_messages: bool = True,
+    ) -> Dict[str, Any]:
+        r"""Summarize the agent's current conversation context and persist it
+        to a markdown file.

+        .. deprecated:: 0.2.80
+            Use :meth:`asummarize` for async/await support and better
+            performance in parallel summarization workflows.

+        Args:
+            filename (Optional[str]): The base filename (without extension) to
+                use for the markdown file. Defaults to a timestamped name when
+                not provided.
+            summary_prompt (Optional[str]): Custom prompt for the summarizer.
+                When omitted, a default prompt highlighting key decisions,
+                action items, and open questions is used.
+            response_format (Optional[Type[BaseModel]]): A Pydantic model
+                defining the expected structure of the response. If provided,
+                the summary will be generated as structured output and included
+                in the result.
+            include_summaries (bool): Whether to include previously generated
+                summaries in the content to be summarized. If False (default),
+                only non-summary messages will be summarized. If True, all
+                messages including previous summaries will be summarized
+                (full compression). (default: :obj:`False`)
+            working_directory (Optional[str|Path]): Optional directory to save
+                the markdown summary file. If provided, overrides the default
+                directory used by ContextUtility.
+            add_user_messages (bool): Whether add user messages to summary.
+                (default: :obj:`True`)
         Returns:
+            Dict[str, Any]: A dictionary containing the summary text, file
+                path, status message, and optionally structured_summary if
+                response_format was provided.
+
+        See Also:
+            :meth:`asummarize`: Async version for non-blocking LLM calls.
         """
-        if not self._output_language:
-            return self._original_system_message

-        "
+        warnings.warn(
+            "summarize() is synchronous. Consider using asummarize() "
+            "for async/await support and better performance.",
+            DeprecationWarning,
+            stacklevel=2,
         )

+        result: Dict[str, Any] = {
+            "summary": "",
+            "file_path": None,
+            "status": "",
+        }
+
+        try:
+            # Use external context if set, otherwise create local one
+            if self._context_utility is None:
+                if working_directory is not None:
+                    self._context_utility = ContextUtility(
+                        working_directory=str(working_directory)
+                    )
+                else:
+                    self._context_utility = ContextUtility()
+            context_util = self._context_utility
+
+            # Get conversation directly from agent's memory
+            messages, _ = self.memory.get_context()
+
+            if not messages:
+                status_message = (
+                    "No conversation context available to summarize."
+                )
+                result["status"] = status_message
+                return result
+
+            # Convert messages to conversation text
+            conversation_lines = []
+            user_messages: List[str] = []
+            for message in messages:
+                role = message.get('role', 'unknown')
+                content = message.get('content', '')
+
+                # Skip summary messages if include_summaries is False
+                if not include_summaries and isinstance(content, str):
+                    # Check if this is a summary message by looking for marker
+                    if content.startswith('[CONTEXT_SUMMARY]'):
+                        continue
+
+                # Handle tool call messages (assistant calling tools)
+                tool_calls = message.get('tool_calls')
+                if tool_calls and isinstance(tool_calls, (list, tuple)):
+                    for tool_call in tool_calls:
+                        # Handle both dict and object formats
+                        if isinstance(tool_call, dict):
+                            func_name = tool_call.get('function', {}).get(
+                                'name', 'unknown_tool'
+                            )
+                            func_args_str = tool_call.get('function', {}).get(
+                                'arguments', '{}'
+                            )
+                        else:
+                            # Handle object format (Pydantic or similar)
+                            func_name = getattr(
+                                getattr(tool_call, 'function', None),
+                                'name',
+                                'unknown_tool',
+                            )
+                            func_args_str = getattr(
+                                getattr(tool_call, 'function', None),
+                                'arguments',
+                                '{}',
+                            )
+
+                        # Parse and format arguments for readability
+                        try:
+                            import json
+
+                            args_dict = json.loads(func_args_str)
+                            args_formatted = ', '.join(
+                                f"{k}={v}" for k, v in args_dict.items()
+                            )
+                        except (json.JSONDecodeError, ValueError, TypeError):
+                            args_formatted = func_args_str
+
+                        conversation_lines.append(
+                            f"[TOOL CALL] {func_name}({args_formatted})"
+                        )
+
+                # Handle tool response messages
+                elif role == 'tool':
+                    tool_name = message.get('name', 'unknown_tool')
+                    if not content:
+                        content = str(message.get('content', ''))
+                    conversation_lines.append(
+                        f"[TOOL RESULT] {tool_name} → {content}"
+                    )
+
+                # Handle regular content messages (user/assistant/system)
+                elif content:
+                    content = str(content)
+                    if role == 'user':
+                        user_messages.append(content)
+                    conversation_lines.append(f"{role}: {content}")
+
+            conversation_text = "\n".join(conversation_lines).strip()
+
+            if not conversation_text:
+                status_message = (
+                    "Conversation context is empty; skipping summary."
+                )
+                result["status"] = status_message
+                return result
+
+            if self._context_summary_agent is None:
+                self._context_summary_agent = ChatAgent(
+                    system_message=(
+                        "You are a helpful assistant that summarizes "
+                        "conversations"
+                    ),
+                    model=self.model_backend,
+                    agent_id=f"{self.agent_id}_context_summarizer",
+                    summarize_threshold=None,
+                )
+            else:
+                self._context_summary_agent.reset()
+
+            if summary_prompt:
+                prompt_text = (
+                    f"{summary_prompt.rstrip()}\n\n"
+                    f"AGENT CONVERSATION TO BE SUMMARIZED:\n"
+                    f"{conversation_text}"
+                )
+            else:
+                prompt_text = build_default_summary_prompt(conversation_text)
+
+            try:
+                # Use structured output if response_format is provided
+                if response_format:
+                    response = self._context_summary_agent.step(
+                        prompt_text, response_format=response_format
+                    )
+                else:
+                    response = self._context_summary_agent.step(prompt_text)
+            except Exception as step_exc:
+                error_message = (
+                    f"Failed to generate summary using model: {step_exc}"
+                )
+                logger.error(error_message)
+                result["status"] = error_message
+                return result
+
+            if not response.msgs:
+                status_message = (
+                    "Failed to generate summary from model response."
+                )
+                result["status"] = status_message
+                return result
+
+            summary_content = response.msgs[-1].content.strip()
+            if not summary_content:
+                status_message = "Generated summary is empty."
+                result["status"] = status_message
+                return result
+
+            # handle structured output if response_format was provided
+            structured_output = None
+            if response_format and response.msgs[-1].parsed:
+                structured_output = response.msgs[-1].parsed
+
+            # determine filename: use provided filename, or extract from
+            # structured output, or generate timestamp
+            if filename:
+                base_filename = filename
+            elif structured_output and hasattr(
+                structured_output, 'task_title'
+            ):
+                # use task_title from structured output for filename
+                task_title = structured_output.task_title
+                clean_title = ContextUtility.sanitize_workflow_filename(
+                    task_title
+                )
+                base_filename = (
+                    f"{clean_title}_workflow" if clean_title else "workflow"
+                )
+            else:
+                base_filename = f"context_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}"  # noqa: E501
+
+            base_filename = Path(base_filename).with_suffix("").name
+
+            metadata = context_util.get_session_metadata()
+            metadata.update(
+                {
+                    "agent_id": self.agent_id,
+                    "message_count": len(messages),
+                }
+            )
+
+            # convert structured output to custom markdown if present
+            if structured_output:
+                # convert structured output to custom markdown
+                summary_content = context_util.structured_output_to_markdown(
+                    structured_data=structured_output, metadata=metadata
+                )
+            if add_user_messages:
+                summary_content = self._append_user_messages_section(
+                    summary_content, user_messages
+                )
+
+            # Save the markdown (either custom structured or default)
+            save_status = context_util.save_markdown_file(
+                base_filename,
+                summary_content,
+                title="Conversation Summary"
+                if not structured_output
+                else None,
+                metadata=metadata if not structured_output else None,
+            )
+
+            file_path = (
+                context_util.get_working_directory() / f"{base_filename}.md"
+            )
+            summary_content = (
+                f"[CONTEXT_SUMMARY] The following is a summary of our "
+                f"conversation from a previous session: {summary_content}"
+            )
+            # Prepare result dictionary
+            result_dict = {
+                "summary": summary_content,
+                "file_path": str(file_path),
+                "status": save_status,
+                "structured_summary": structured_output,
+            }
+
+            result.update(result_dict)
+            logger.info("Conversation summary saved to %s", file_path)
+            return result
+
+        except Exception as exc:
+            error_message = f"Failed to summarize conversation context: {exc}"
+            logger.error(error_message)
+            result["status"] = error_message
+            return result
+
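Hedged usage sketch (not part of the diff) of the new summarize() API; the Pydantic model and its fields are invented examples, while the result keys match the code above:

    from pydantic import BaseModel

    class SessionSummary(BaseModel):  # hypothetical schema
        task_title: str
        key_decisions: list[str]

    result = agent.summarize(
        filename="session_notes",
        response_format=SessionSummary,  # structured output is optional
    )
    print(result["status"], result["file_path"])
    # result["structured_summary"] holds the parsed model when structured
    # output succeeded; result["summary"] is prefixed with the
    # [CONTEXT_SUMMARY] marker so later summarization passes can skip it.
    # In async code, prefer `await agent.asummarize(...)`, which returns the
    # same dictionary.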
+    async def asummarize(
+        self,
+        filename: Optional[str] = None,
+        summary_prompt: Optional[str] = None,
+        response_format: Optional[Type[BaseModel]] = None,
+        working_directory: Optional[Union[str, Path]] = None,
+        include_summaries: bool = False,
+        add_user_messages: bool = True,
+    ) -> Dict[str, Any]:
+        r"""Asynchronously summarize the agent's current conversation context
+        and persist it to a markdown file.
+
+        This is the async version of summarize() that uses astep() for
+        non-blocking LLM calls, enabling parallel summarization of multiple
+        agents.
+
+        Args:
+            filename (Optional[str]): The base filename (without extension) to
+                use for the markdown file. Defaults to a timestamped name when
+                not provided.
+            summary_prompt (Optional[str]): Custom prompt for the summarizer.
+                When omitted, a default prompt highlighting key decisions,
+                action items, and open questions is used.
+            response_format (Optional[Type[BaseModel]]): A Pydantic model
+                defining the expected structure of the response. If provided,
+                the summary will be generated as structured output and included
+                in the result.
+            working_directory (Optional[str|Path]): Optional directory to save
+                the markdown summary file. If provided, overrides the default
+                directory used by ContextUtility.
+            include_summaries (bool): Whether to include previously generated
+                summaries in the content to be summarized. If False (default),
+                only non-summary messages will be summarized. If True, all
+                messages including previous summaries will be summarized
+                (full compression). (default: :obj:`False`)
+            add_user_messages (bool): Whether add user messages to summary.
+                (default: :obj:`True`)
+        Returns:
+            Dict[str, Any]: A dictionary containing the summary text, file
+                path, status message, and optionally structured_summary if
+                response_format was provided.
+        """
+
+        result: Dict[str, Any] = {
+            "summary": "",
+            "file_path": None,
+            "status": "",
+        }
+
+        try:
+            # Use external context if set, otherwise create local one
+            if self._context_utility is None:
+                if working_directory is not None:
+                    self._context_utility = ContextUtility(
+                        working_directory=str(working_directory)
+                    )
+                else:
+                    self._context_utility = ContextUtility()
+            context_util = self._context_utility
+
+            # Get conversation directly from agent's memory
+            messages, _ = self.memory.get_context()
+
+            if not messages:
+                status_message = (
+                    "No conversation context available to summarize."
+                )
+                result["status"] = status_message
+                return result
+
+            # Convert messages to conversation text
+            conversation_lines = []
+            user_messages: List[str] = []
+            for message in messages:
+                role = message.get('role', 'unknown')
+                content = message.get('content', '')
+
+                # Skip summary messages if include_summaries is False
+                if not include_summaries and isinstance(content, str):
+                    # Check if this is a summary message by looking for marker
+                    if content.startswith('[CONTEXT_SUMMARY]'):
+                        continue
+
+                # Handle tool call messages (assistant calling tools)
+                tool_calls = message.get('tool_calls')
+                if tool_calls and isinstance(tool_calls, (list, tuple)):
+                    for tool_call in tool_calls:
+                        # Handle both dict and object formats
+                        if isinstance(tool_call, dict):
+                            func_name = tool_call.get('function', {}).get(
+                                'name', 'unknown_tool'
+                            )
+                            func_args_str = tool_call.get('function', {}).get(
+                                'arguments', '{}'
+                            )
+                        else:
+                            # Handle object format (Pydantic or similar)
+                            func_name = getattr(
+                                getattr(tool_call, 'function', None),
+                                'name',
+                                'unknown_tool',
+                            )
+                            func_args_str = getattr(
+                                getattr(tool_call, 'function', None),
+                                'arguments',
+                                '{}',
+                            )
+
+                        # Parse and format arguments for readability
+                        try:
+                            import json
+
+                            args_dict = json.loads(func_args_str)
+                            args_formatted = ', '.join(
+                                f"{k}={v}" for k, v in args_dict.items()
+                            )
+                        except (json.JSONDecodeError, ValueError, TypeError):
+                            args_formatted = func_args_str
+
+                        conversation_lines.append(
+                            f"[TOOL CALL] {func_name}({args_formatted})"
+                        )
+
+                # Handle tool response messages
+                elif role == 'tool':
+                    tool_name = message.get('name', 'unknown_tool')
+                    if not content:
+                        content = str(message.get('content', ''))
+                    conversation_lines.append(
+                        f"[TOOL RESULT] {tool_name} → {content}"
+                    )
+
+                # Handle regular content messages (user/assistant/system)
+                elif content:
+                    content = str(content)
+                    if role == 'user':
+                        user_messages.append(content)
+                    conversation_lines.append(f"{role}: {content}")
+
+            conversation_text = "\n".join(conversation_lines).strip()
+
+            if not conversation_text:
+                status_message = (
+                    "Conversation context is empty; skipping summary."
+                )
+                result["status"] = status_message
+                return result
+
+            if self._context_summary_agent is None:
+                self._context_summary_agent = ChatAgent(
+                    system_message=(
+                        "You are a helpful assistant that summarizes "
+                        "conversations"
+                    ),
+                    model=self.model_backend,
+                    agent_id=f"{self.agent_id}_context_summarizer",
+                    summarize_threshold=None,
+                )
+            else:
+                self._context_summary_agent.reset()
+
+            if summary_prompt:
+                prompt_text = (
+                    f"{summary_prompt.rstrip()}\n\n"
+                    f"AGENT CONVERSATION TO BE SUMMARIZED:\n"
+                    f"{conversation_text}"
+                )
+            else:
+                prompt_text = build_default_summary_prompt(conversation_text)
+
+            try:
+                # Use structured output if response_format is provided
+                if response_format:
+                    response = await self._context_summary_agent.astep(
+                        prompt_text, response_format=response_format
+                    )
+                else:
+                    response = await self._context_summary_agent.astep(
+                        prompt_text
+                    )
+
+                # Handle streaming response
+                if isinstance(response, AsyncStreamingChatAgentResponse):
+                    # Collect final response
+                    final_response = await response
+                    response = final_response
+
+            except Exception as step_exc:
+                error_message = (
+                    f"Failed to generate summary using model: {step_exc}"
+                )
+                logger.error(error_message)
+                result["status"] = error_message
+                return result
+
+            if not response.msgs:
+                status_message = (
+                    "Failed to generate summary from model response."
+                )
+                result["status"] = status_message
+                return result
+
+            summary_content = response.msgs[-1].content.strip()
+            if not summary_content:
+                status_message = "Generated summary is empty."
+                result["status"] = status_message
+                return result
+
+            # handle structured output if response_format was provided
+            structured_output = None
+            if response_format and response.msgs[-1].parsed:
+                structured_output = response.msgs[-1].parsed
+
+            # determine filename: use provided filename, or extract from
+            # structured output, or generate timestamp
+            if filename:
+                base_filename = filename
+            elif structured_output and hasattr(
+                structured_output, 'task_title'
+            ):
+                # use task_title from structured output for filename
+                task_title = structured_output.task_title
+                clean_title = ContextUtility.sanitize_workflow_filename(
+                    task_title
+                )
+                base_filename = (
+                    f"{clean_title}_workflow" if clean_title else "workflow"
+                )
+            else:
+                base_filename = f"context_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}"  # noqa: E501
+
+            base_filename = Path(base_filename).with_suffix("").name
+
+            metadata = context_util.get_session_metadata()
+            metadata.update(
+                {
+                    "agent_id": self.agent_id,
+                    "message_count": len(messages),
+                }
+            )
+
+            # convert structured output to custom markdown if present
+            if structured_output:
+                # convert structured output to custom markdown
+                summary_content = context_util.structured_output_to_markdown(
+                    structured_data=structured_output, metadata=metadata
+                )
+            if add_user_messages:
+                summary_content = self._append_user_messages_section(
+                    summary_content, user_messages
+                )
+
+            # Save the markdown (either custom structured or default)
+            save_status = context_util.save_markdown_file(
+                base_filename,
+                summary_content,
+                title="Conversation Summary"
+                if not structured_output
+                else None,
+                metadata=metadata if not structured_output else None,
+            )
+
+            file_path = (
+                context_util.get_working_directory() / f"{base_filename}.md"
+            )
+
+            summary_content = (
+                f"[CONTEXT_SUMMARY] The following is a summary of our "
+                f"conversation from a previous session: {summary_content}"
+            )
+
+            # Prepare result dictionary
+            result_dict = {
+                "summary": summary_content,
+                "file_path": str(file_path),
+                "status": save_status,
+                "structured_summary": structured_output,
+            }
+
+            result.update(result_dict)
+            logger.info("Conversation summary saved to %s", file_path)
+            return result
+
+        except Exception as exc:
+            error_message = f"Failed to summarize conversation context: {exc}"
+            logger.error(error_message)
+            result["status"] = error_message
+            return result
+
+    def clear_memory(self) -> None:
+        r"""Clear the agent's memory and reset to initial state.
+
+        Returns:
+            None
+        """
+        self.memory.clear()
+
+        if self.system_message is not None:
+            self.memory.write_record(
+                MemoryRecord(
+                    message=self.system_message,
+                    role_at_backend=OpenAIBackendRole.SYSTEM,
+                    timestamp=time.time_ns() / 1_000_000_000,
+                    agent_id=self.agent_id,
+                )
             )

+    def _generate_system_message_for_output_language(
+        self,
+    ) -> Optional[BaseMessage]:
+        r"""Generate a new system message with the output language prompt.
+
+        The output language determines the language in which the output text
+        should be generated.
+
+        Returns:
+            BaseMessage: The new system message.
+        """
+        if not self._output_language:
+            return self._original_system_message
+
+        language_prompt = (
+            "\nRegardless of the input language, "
+            f"you must output text in {self._output_language}."
+        )
+
+        if self._original_system_message is not None:
+            content = self._original_system_message.content + language_prompt
+            return self._original_system_message.create_new_instance(content)
+        else:
+            return BaseMessage.make_system_message(language_prompt)
+
     def init_messages(self) -> None:
         r"""Initializes the stored messages list with the current system
         message.
         """
-        self.
+        self._reset_summary_state()
+        self.clear_memory()
+
+    def update_system_message(
+        self,
+        system_message: Union[BaseMessage, str],
+        reset_memory: bool = True,
+    ) -> None:
+        r"""Update the system message.
+        It will reset conversation with new system message.
+
+        Args:
+            system_message (Union[BaseMessage, str]): The new system message.
+                Can be either a BaseMessage object or a string.
+                If a string is provided, it will be converted
+                into a BaseMessage object.
+            reset_memory (bool):
+                Whether to reinitialize conversation messages after updating
+                the system message. Defaults to True.
+        """
+        if system_message is None:
+            raise ValueError("system_message is required and cannot be None. ")
+        self._original_system_message = (
+            BaseMessage.make_system_message(system_message)
+            if isinstance(system_message, str)
+            else system_message
+        )
+        self._system_message = (
+            self._generate_system_message_for_output_language()
+        )
+        if reset_memory:
+            self.init_messages()
+
+    def append_to_system_message(
+        self, content: str, reset_memory: bool = True
+    ) -> None:
+        """Append additional context to existing system message.
+
+        Args:
+            content (str): The additional system message.
+            reset_memory (bool):
+                Whether to reinitialize conversation messages after appending
+                additional context. Defaults to True.
+        """
+        original_content = (
+            self._original_system_message.content
+            if self._original_system_message
+            else ""
+        )
+        new_system_message = original_content + '\n' + content
+        self._original_system_message = BaseMessage.make_system_message(
+            new_system_message
+        )
+        self._system_message = (
+            self._generate_system_message_for_output_language()
+        )
+        if reset_memory:
+            self.init_messages()
+
+    def reset_to_original_system_message(self) -> None:
+        r"""Reset system message to original, removing any appended context.
+
+        This method reverts the agent's system message back to its original
+        state, removing any workflow context or other modifications that may
+        have been appended. Useful for resetting agent state in multi-turn
+        scenarios.
+        """
+        self._system_message = self._original_system_message
+        self.init_messages()
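Small sketch (not part of the diff) of how the system-message helpers above compose; the prompt strings are invented:

    agent.update_system_message("You are a terse release-notes writer.")
    agent.append_to_system_message("Workflow context: summarize diffs only.")
    # ...later, drop any appended context and re-initialize the conversation:
    agent.reset_to_original_system_message()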

     def record_message(self, message: BaseMessage) -> None:
         r"""Records the externally provided message into the agent memory as if
@@ -687,6 +2309,210 @@ class ChatAgent(BaseAgent):
         except ValidationError:
             return False

+    def _check_tools_strict_compatibility(self) -> bool:
+        r"""Check if all tools are compatible with OpenAI strict mode.
+
+        Returns:
+            bool: True if all tools are strict mode compatible,
+                False otherwise.
+        """
+        tool_schemas = self._get_full_tool_schemas()
+        for schema in tool_schemas:
+            if not schema.get("function", {}).get("strict", True):
+                return False
+        return True
+
+    def _convert_response_format_to_prompt(
+        self, response_format: Type[BaseModel]
+    ) -> str:
+        r"""Convert a Pydantic response format to a prompt instruction.
+
+        Args:
+            response_format (Type[BaseModel]): The Pydantic model class.
+
+        Returns:
+            str: A prompt instruction requesting the specific format.
+        """
+        try:
+            # Get the JSON schema from the Pydantic model
+            schema = response_format.model_json_schema()
+
+            # Create a prompt based on the schema
+            format_instruction = (
+                "\n\nPlease respond in the following JSON format:\n{\n"
+            )
+
+            properties = schema.get("properties", {})
+            for field_name, field_info in properties.items():
+                field_type = field_info.get("type", "string")
+                description = field_info.get("description", "")
+
+                if field_type == "array":
+                    format_instruction += (
+                        f' "{field_name}": ["array of values"]'
+                    )
+                elif field_type == "object":
+                    format_instruction += f' "{field_name}": {{"object"}}'
+                elif field_type == "boolean":
+                    format_instruction += f' "{field_name}": true'
+                elif field_type == "number":
+                    format_instruction += f' "{field_name}": 0'
+                else:
+                    format_instruction += f' "{field_name}": "string value"'
+
+                if description:
+                    format_instruction += f' // {description}'
+
+                # Add comma if not the last item
+                if field_name != list(properties.keys())[-1]:
+                    format_instruction += ","
+                format_instruction += "\n"
+
+            format_instruction += "}"
+            return format_instruction
+
+        except Exception as e:
+            logger.warning(
+                f"Failed to convert response_format to prompt: {e}. "
+                f"Using generic format instruction."
+            )
+            return (
+                "\n\nPlease respond in a structured JSON format "
+                "that matches the requested schema."
+            )
+
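Hedged sketch (not part of the diff) of roughly what `_convert_response_format_to_prompt` emits for a simple model; the model is invented and the spacing is approximate:

    from pydantic import BaseModel, Field

    class Verdict(BaseModel):  # hypothetical schema
        passed: bool = Field(description="Whether the check passed")
        notes: str

    # agent._convert_response_format_to_prompt(Verdict) yields approximately:
    #
    #   Please respond in the following JSON format:
    #   {
    #    "passed": true // Whether the check passed,
    #    "notes": "string value"
    #   }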
+    def _handle_response_format_with_non_strict_tools(
+        self,
+        input_message: Union[BaseMessage, str],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> Tuple[Union[BaseMessage, str], Optional[Type[BaseModel]], bool]:
+        r"""Handle response format when tools are not strict mode compatible.
+
+        Args:
+            input_message: The original input message.
+            response_format: The requested response format.
+
+        Returns:
+            Tuple: (modified_message, modified_response_format,
+                used_prompt_formatting)
+        """
+        if response_format is None:
+            return input_message, response_format, False
+
+        # Check if tools are strict mode compatible
+        if self._check_tools_strict_compatibility():
+            return input_message, response_format, False
+
+        # Tools are not strict compatible, convert to prompt
+        logger.info(
+            "Non-strict tools detected. Converting response_format to "
+            "prompt-based formatting."
+        )
+
+        format_prompt = self._convert_response_format_to_prompt(
+            response_format
+        )
+
+        # Modify the message to include format instruction
+        modified_message: Union[BaseMessage, str]
+        if isinstance(input_message, str):
+            modified_message = input_message + format_prompt
+        else:
+            modified_message = input_message.create_new_instance(
+                input_message.content + format_prompt
+            )
+
+        # Return None for response_format to avoid strict mode conflicts
+        # and True to indicate we used prompt formatting
+        return modified_message, None, True
+
+    def _is_called_from_registered_toolkit(self) -> bool:
+        r"""Check if current step/astep call originates from a
+        RegisteredAgentToolkit.
+
+        This method uses stack inspection to detect if the current call
+        is originating from a toolkit that inherits from
+        RegisteredAgentToolkit. When detected, tools should be disabled to
+        prevent recursive calls.
+
+        Returns:
+            bool: True if called from a RegisteredAgentToolkit, False otherwise
+        """
+        from camel.toolkits.base import RegisteredAgentToolkit
+
+        try:
+            for frame_info in inspect.stack():
+                frame_locals = frame_info.frame.f_locals
+                if 'self' in frame_locals:
+                    caller_self = frame_locals['self']
+                    if isinstance(caller_self, RegisteredAgentToolkit):
+                        return True
+
+        except Exception:
+            return False
+
+        return False
+
+    def _apply_prompt_based_parsing(
+        self,
+        response: ModelResponse,
+        original_response_format: Type[BaseModel],
+    ) -> None:
+        r"""Apply manual parsing when using prompt-based formatting.
+
+        Args:
+            response: The model response to parse.
+            original_response_format: The original response format class.
+        """
+        for message in response.output_messages:
+            if message.content:
+                try:
+                    # Try to extract JSON from the response content
+                    import json
+
+                    from pydantic import ValidationError
+
+                    # Try to find JSON in the content
+                    content = message.content.strip()
+
+                    # Try direct parsing first
+                    try:
+                        parsed_json = json.loads(content)
+                        message.parsed = (
+                            original_response_format.model_validate(
+                                parsed_json
+                            )
+                        )
+                        continue
+                    except (json.JSONDecodeError, ValidationError):
+                        pass
+
+                    # Try to extract JSON from text
+                    json_pattern = r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}'
+                    json_matches = re.findall(json_pattern, content, re.DOTALL)
+
+                    for json_str in json_matches:
+                        try:
+                            parsed_json = json.loads(json_str)
+                            message.parsed = (
+                                original_response_format.model_validate(
+                                    parsed_json
+                                )
+                            )
+                            # Update content to just the JSON for consistency
+                            message.content = json.dumps(parsed_json)
+                            break
+                        except (json.JSONDecodeError, ValidationError):
+                            continue
+
+                    if not message.parsed:
+                        logger.warning(
+                            f"Failed to parse JSON from response: {content}"
+                        )
+
+                except Exception as e:
+                    logger.warning(f"Error during prompt-based parsing: {e}")
+
     def _format_response_if_needed(
         self,
         response: ModelResponse,
@@ -709,7 +2535,11 @@ class ChatAgent(BaseAgent):
         openai_message: OpenAIMessage = {"role": "user", "content": prompt}
         # Explicitly set the tools to empty list to avoid calling tools
         response = self._get_model_response(
-            [openai_message],
+            openai_messages=[openai_message],
+            num_tokens=0,
+            response_format=response_format,
+            tool_schemas=[],
+            prev_num_openai_messages=0,
         )
         message.content = response.output_messages[0].content
         if not self._try_format_message(message, response_format):
@@ -737,7 +2567,11 @@ class ChatAgent(BaseAgent):
         prompt = SIMPLE_FORMAT_PROMPT.format(content=message.content)
         openai_message: OpenAIMessage = {"role": "user", "content": prompt}
         response = await self._aget_model_response(
-            [openai_message],
+            openai_messages=[openai_message],
+            num_tokens=0,
+            response_format=response_format,
+            tool_schemas=[],
+            prev_num_openai_messages=0,
        )
         message.content = response.output_messages[0].content
         self._try_format_message(message, response_format)
@@ -747,7 +2581,7 @@ class ChatAgent(BaseAgent):
         self,
         input_message: Union[BaseMessage, str],
         response_format: Optional[Type[BaseModel]] = None,
-    ) -> ChatAgentResponse:
+    ) -> Union[ChatAgentResponse, StreamingChatAgentResponse]:
         r"""Executes a single step in the chat session, generating a response
         to the input message.

@@ -761,10 +2595,47 @@ class ChatAgent(BaseAgent):
                 :obj:`None`)

         Returns:
-            ChatAgentResponse:
+            Union[ChatAgentResponse, StreamingChatAgentResponse]: If stream is
+                False, returns a ChatAgentResponse. If stream is True, returns
+                a StreamingChatAgentResponse that behaves like
+                ChatAgentResponse but can also be iterated for
+                streaming updates.
+
+        Raises:
+            TimeoutError: If the step operation exceeds the configured timeout.
         """

+        stream = self.model_backend.model_config_dict.get("stream", False)
+
+        if stream:
+            # Return wrapped generator that has ChatAgentResponse interface
+            generator = self._stream(input_message, response_format)
+            return StreamingChatAgentResponse(generator)
+
+        # Execute with timeout if configured
+        if self.step_timeout is not None:
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=1
+            ) as executor:
+                future = executor.submit(
+                    self._step_impl, input_message, response_format
+                )
+                try:
+                    return future.result(timeout=self.step_timeout)
+                except concurrent.futures.TimeoutError:
+                    future.cancel()
+                    raise TimeoutError(
+                        f"Step timed out after {self.step_timeout}s"
+                    )
+        else:
+            return self._step_impl(input_message, response_format)
+
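Usage sketch (not part of the diff) of the updated step() return types; streaming is selected purely by the backend's `stream` config flag, and the partial-response attribute access is an assumption based on the "behaves like ChatAgentResponse" wording above:

    response = agent.step("Summarize the latest release")
    if isinstance(response, StreamingChatAgentResponse):
        # stream=True in the model config: iterate incremental updates
        for partial in response:
            print(partial.msgs[-1].content, end="", flush=True)
    else:
        # stream=False: a plain ChatAgentResponse
        print(response.msgs[-1].content)
    # With agent.step_timeout set, a slow non-streaming call raises TimeoutError.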
+    def _step_impl(
+        self,
+        input_message: Union[BaseMessage, str],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> ChatAgentResponse:
+        r"""Implementation of non-streaming step logic."""
         # Set Langfuse session_id using agent_id for trace grouping
         try:
             from camel.utils.langfuse import set_current_agent_session_id
@@ -773,6 +2644,18 @@ class ChatAgent(BaseAgent):
         except ImportError:
             pass # Langfuse not available

+        # Check if this call is from a RegisteredAgentToolkit to prevent tool
+        # use
+        disable_tools = self._is_called_from_registered_toolkit()
+
+        # Handle response format compatibility with non-strict tools
+        original_response_format = response_format
+        input_message, response_format, used_prompt_formatting = (
+            self._handle_response_format_with_non_strict_tools(
+                input_message, response_format
+            )
+        )
+
         # Convert input message to BaseMessage if necessary
         if isinstance(input_message, str):
             input_message = BaseMessage.make_user_message(
@@ -791,23 +2674,138 @@ class ChatAgent(BaseAgent):

         # Initialize token usage tracker
         step_token_usage = self._create_token_usage_tracker()
-        iteration_count = 0
+        iteration_count: int = 0
+        prev_num_openai_messages: int = 0

         while True:
+            if self.pause_event is not None and not self.pause_event.is_set():
+                # Use efficient blocking wait for threading.Event
+                if isinstance(self.pause_event, threading.Event):
+                    self.pause_event.wait()
+                else:
+                    # Fallback for asyncio.Event in sync context
+                    while not self.pause_event.is_set():
+                        time.sleep(0.001)
+
             try:
                 openai_messages, num_tokens = self.memory.get_context()
+                if self.summarize_threshold is not None:
+                    threshold = self._calculate_next_summary_threshold()
+                    summary_token_count = self._summary_token_count
+                    token_limit = self.model_backend.token_limit
+
+                    if num_tokens <= token_limit:
+                        if (
+                            summary_token_count
+                            > token_limit * self.summary_window_ratio
+                        ):
+                            logger.info(
+                                f"Summary tokens ({summary_token_count}) "
+                                f"exceed limit, full compression."
+                            )
+                            # Summarize everything (including summaries)
+                            summary = self.summarize(include_summaries=True)
+                            self._update_memory_with_summary(
+                                summary.get("summary", ""),
+                                include_summaries=True,
+                            )
+                        elif num_tokens > threshold:
+                            logger.info(
+                                f"Token count ({num_tokens}) exceed threshold "
+                                f"({threshold}). Triggering summarization."
+                            )
+                            # Only summarize non-summary content
+                            summary = self.summarize(include_summaries=False)
+                            self._update_memory_with_summary(
+                                summary.get("summary", ""),
+                                include_summaries=False,
+                            )
                 accumulated_context_tokens += num_tokens
             except RuntimeError as e:
                 return self._step_terminate(
                     e.args[1], tool_call_records, "max_tokens_exceeded"
                 )
-            # Get response from model backend
+            # Get response from model backend with token limit error handling
+            try:
+                response = self._get_model_response(
+                    openai_messages,
+                    num_tokens=num_tokens,
+                    current_iteration=iteration_count,
+                    response_format=response_format,
+                    tool_schemas=[]
+                    if disable_tools
+                    else self._get_full_tool_schemas(),
+                    prev_num_openai_messages=prev_num_openai_messages,
+                )
+            except Exception as exc:
+                logger.exception("Model error: %s", exc)
+
+                if self._is_token_limit_error(exc):
+                    tool_signature = self._last_tool_call_signature
+                    if (
+                        tool_signature is not None
+                        and tool_signature
+                        == self._last_token_limit_tool_signature
+                    ):
+                        description = self._describe_tool_call(
+                            self._last_tool_call_record
+                        )
+                        repeated_msg = (
+                            "Context exceeded again by the same tool call."
+                        )
+                        if description:
+                            repeated_msg += f" {description}"
+                        raise RuntimeError(repeated_msg) from exc
+
+                    user_message_count = sum(
+                        1
+                        for msg in openai_messages
+                        if getattr(msg, "role", None) == "user"
+                    )
+                    if (
+                        user_message_count == 1
+                        and getattr(openai_messages[-1], "role", None)
+                        == "user"
+                    ):
+                        raise RuntimeError(
+                            "The provided user input alone exceeds the "
+                            "context window. Please shorten the input."
+                        ) from exc
+
+                    logger.warning(
+                        "Token limit exceeded error detected. "
+                        "Summarizing context."
+                    )
+
+                    recent_records: List[ContextRecord]
+                    try:
+                        recent_records = self.memory.retrieve()
+                    except Exception: # pragma: no cover - defensive guard
+                        recent_records = []
+
+                    indices_to_remove = (
+                        self._find_indices_to_remove_for_last_tool_pair(
+                            recent_records
+                        )
+                    )
+                    self.memory.remove_records_by_indices(indices_to_remove)
+
+                    summary = self.summarize(include_summaries=False)
+                    tool_notice = self._format_tool_limit_notice()
+                    summary_messages = summary.get("summary", "")
+
+                    if tool_notice:
+                        summary_messages += "\n\n" + tool_notice
+
+                    self._update_memory_with_summary(
+                        summary_messages, include_summaries=False
+                    )
+                    self._last_token_limit_tool_signature = tool_signature
+                    return self._step_impl(input_message, response_format)
+
+                raise
+
+            prev_num_openai_messages = len(openai_messages)
             iteration_count += 1

             # Accumulate API token usage
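Illustrative sketch (not part of the diff) of the pause_event handshake used in the loop above; how the event reaches the agent (for example via a constructor argument) is an assumption:

    import threading

    pause_event = threading.Event()
    pause_event.set()      # agent runs while the event is set
    # agent = ChatAgent(..., pause_event=pause_event)  # assumed wiring

    pause_event.clear()    # the step loop blocks at its next checkpoint
    # ...inspect or adjust state while paused...
    pause_event.set()      # resume; asyncio.Event plays the same role in astep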
@@ -818,6 +2816,9 @@ class ChatAgent(BaseAgent):
             # Terminate Agent if stop_event is set
             if self.stop_event and self.stop_event.is_set():
                 # Use the _step_terminate to terminate the agent with reason
+                logger.info(
+                    f"Termination triggered at iteration {iteration_count}"
+                )
                 return self._step_terminate(
                     accumulated_context_tokens,
                     tool_call_records,
@@ -835,9 +2836,17 @@ class ChatAgent(BaseAgent):
                         external_tool_call_requests = []
                    external_tool_call_requests.append(tool_call_request)
                 else:
-                    self.
+                    if (
+                        self.pause_event is not None
+                        and not self.pause_event.is_set()
+                    ):
+                        if isinstance(self.pause_event, threading.Event):
+                            self.pause_event.wait()
+                        else:
+                            while not self.pause_event.is_set():
+                                time.sleep(0.001)
+                    result = self._execute_tool(tool_call_request)
+                    tool_call_records.append(result)

             # If we found external tool calls, break the loop
             if external_tool_call_requests:
@@ -847,6 +2856,7 @@ class ChatAgent(BaseAgent):
                 self.max_iteration is not None
                 and iteration_count >= self.max_iteration
             ):
+                logger.info(f"Max iteration reached: {iteration_count}")
                 break

             # If we're still here, continue the loop
@@ -855,8 +2865,19 @@ class ChatAgent(BaseAgent):
                 break

         self._format_response_if_needed(response, response_format)
+
+        # Apply manual parsing if we used prompt-based formatting
+        if used_prompt_formatting and original_response_format:
+            self._apply_prompt_based_parsing(
+                response, original_response_format
+            )
+
         self._record_final_output(response.output_messages)

+        # Clean tool call messages from memory after response generation
+        if self.prune_tool_calls_from_memory and tool_call_records:
+            self.memory.clean_tool_calls()
+
         return self._convert_to_chatagent_response(
             response,
             tool_call_records,
@@ -877,7 +2898,7 @@ class ChatAgent(BaseAgent):
         self,
         input_message: Union[BaseMessage, str],
         response_format: Optional[Type[BaseModel]] = None,
-    ) -> ChatAgentResponse:
+    ) -> Union[ChatAgentResponse, AsyncStreamingChatAgentResponse]:
         r"""Performs a single step in the chat session by generating a response
         to the input message. This agent step can call async function calls.

@@ -893,12 +2914,55 @@ class ChatAgent(BaseAgent):
                 used to generate a structured response by LLM. This schema
                 helps in defining the expected output format. (default:
                 :obj:`None`)
-
         Returns:
-            ChatAgentResponse:
-
-
+            Union[ChatAgentResponse, AsyncStreamingChatAgentResponse]:
+                If stream is False, returns a ChatAgentResponse. If stream is
+                True, returns an AsyncStreamingChatAgentResponse that can be
+                awaited for the final result or async iterated for streaming
+                updates.
+
+        Raises:
+            asyncio.TimeoutError: If the step operation exceeds the configured
+                timeout.
         """
+
+        try:
+            from camel.utils.langfuse import set_current_agent_session_id
+
+            set_current_agent_session_id(self.agent_id)
+        except ImportError:
+            pass  # Langfuse not available
+
+        stream = self.model_backend.model_config_dict.get("stream", False)
+        if stream:
+            # Return wrapped async generator that is awaitable
+            async_generator = self._astream(input_message, response_format)
+            return AsyncStreamingChatAgentResponse(async_generator)
+        else:
+            if self.step_timeout is not None:
+                try:
+                    return await asyncio.wait_for(
+                        self._astep_non_streaming_task(
+                            input_message, response_format
+                        ),
+                        timeout=self.step_timeout,
+                    )
+                except asyncio.TimeoutError:
+                    raise asyncio.TimeoutError(
+                        f"Async step timed out after {self.step_timeout}s"
+                    )
+            else:
+                return await self._astep_non_streaming_task(
+                    input_message, response_format
+                )
+
+    async def _astep_non_streaming_task(
+        self,
+        input_message: Union[BaseMessage, str],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> ChatAgentResponse:
+        r"""Internal async method for non-streaming astep logic."""
+
         try:
             from camel.utils.langfuse import set_current_agent_session_id

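The return-type change above means `astep` now hands back either a plain `ChatAgentResponse` or, when the model config sets `stream=True`, an `AsyncStreamingChatAgentResponse` that can be awaited or async-iterated. A minimal usage sketch, assuming `agent` is an already-constructed `ChatAgent` (not shown in this diff) and that `ChatAgentResponse` is importable from `camel.responses`:

    from camel.responses import ChatAgentResponse  # assumed import path

    async def run_once(agent, prompt: str) -> None:
        response = await agent.astep(prompt)
        if isinstance(response, ChatAgentResponse):
            # Non-streaming path: the full reply is already available.
            print(response.msgs[0].content)
        else:
            # Streaming path: iterate partial updates as they arrive.
            async for partial in response:
                print(partial.msgs[0].content)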
@@ -906,6 +2970,18 @@ class ChatAgent(BaseAgent):
         except ImportError:
             pass  # Langfuse not available

+        # Check if this call is from a RegisteredAgentToolkit to prevent tool
+        # use
+        disable_tools = self._is_called_from_registered_toolkit()
+
+        # Handle response format compatibility with non-strict tools
+        original_response_format = response_format
+        input_message, response_format, used_prompt_formatting = (
+            self._handle_response_format_with_non_strict_tools(
+                input_message, response_format
+            )
+        )
+
         if isinstance(input_message, str):
             input_message = BaseMessage.make_user_message(
                 role_name="User", content=input_message
@@ -921,27 +2997,155 @@ class ChatAgent(BaseAgent):

         # Initialize token usage tracker
         step_token_usage = self._create_token_usage_tracker()
-        iteration_count = 0
+        iteration_count: int = 0
+        prev_num_openai_messages: int = 0
+
         while True:
+            if self.pause_event is not None and not self.pause_event.is_set():
+                if isinstance(self.pause_event, asyncio.Event):
+                    await self.pause_event.wait()
+                elif isinstance(self.pause_event, threading.Event):
+                    # For threading.Event in async context, run in executor
+                    loop = asyncio.get_event_loop()
+                    await loop.run_in_executor(None, self.pause_event.wait)
             try:
                 openai_messages, num_tokens = self.memory.get_context()
+                if self.summarize_threshold is not None:
+                    threshold = self._calculate_next_summary_threshold()
+                    summary_token_count = self._summary_token_count
+                    token_limit = self.model_backend.token_limit
+
+                    if num_tokens <= token_limit:
+                        if (
+                            summary_token_count
+                            > token_limit * self.summary_window_ratio
+                        ):
+                            logger.info(
+                                f"Summary tokens ({summary_token_count}) "
+                                f"exceed limit, full compression."
+                            )
+                            # Summarize everything (including summaries)
+                            summary = await self.asummarize(
+                                include_summaries=True
+                            )
+                            self._update_memory_with_summary(
+                                summary.get("summary", ""),
+                                include_summaries=True,
+                            )
+                        elif num_tokens > threshold:
+                            logger.info(
+                                f"Token count ({num_tokens}) exceed threshold "
+                                "({threshold}). Triggering summarization."
+                            )
+                            # Only summarize non-summary content
+                            summary = await self.asummarize(
+                                include_summaries=False
+                            )
+                            self._update_memory_with_summary(
+                                summary.get("summary", ""),
+                                include_summaries=False,
+                            )
                 accumulated_context_tokens += num_tokens
             except RuntimeError as e:
                 return self._step_terminate(
                     e.args[1], tool_call_records, "max_tokens_exceeded"
                 )
+            # Get response from model backend with token limit error handling
+            try:
+                response = await self._aget_model_response(
+                    openai_messages,
+                    num_tokens=num_tokens,
+                    current_iteration=iteration_count,
+                    response_format=response_format,
+                    tool_schemas=[]
+                    if disable_tools
+                    else self._get_full_tool_schemas(),
+                    prev_num_openai_messages=prev_num_openai_messages,
+                )
+            except Exception as exc:
+                logger.exception("Model error: %s", exc)

-
-
-
-
-
-
+                if self._is_token_limit_error(exc):
+                    tool_signature = self._last_tool_call_signature
+                    if (
+                        tool_signature is not None
+                        and tool_signature
+                        == self._last_token_limit_tool_signature
+                    ):
+                        description = self._describe_tool_call(
+                            self._last_tool_call_record
+                        )
+                        repeated_msg = (
+                            "Context exceeded again by the same tool call."
+                        )
+                        if description:
+                            repeated_msg += f" {description}"
+                        raise RuntimeError(repeated_msg) from exc
+
+                    user_message_count = sum(
+                        1
+                        for msg in openai_messages
+                        if getattr(msg, "role", None) == "user"
+                    )
+                    if (
+                        user_message_count == 1
+                        and getattr(openai_messages[-1], "role", None)
+                        == "user"
+                    ):
+                        raise RuntimeError(
+                            "The provided user input alone exceeds the"
+                            "context window. Please shorten the input."
+                        ) from exc
+
+                    logger.warning(
+                        "Token limit exceeded error detected. "
+                        "Summarizing context."
+                    )
+
+                    recent_records: List[ContextRecord]
+                    try:
+                        recent_records = self.memory.retrieve()
+                    except Exception:  # pragma: no cover - defensive guard
+                        recent_records = []
+
+                    indices_to_remove = (
+                        self._find_indices_to_remove_for_last_tool_pair(
+                            recent_records
+                        )
+                    )
+                    self.memory.remove_records_by_indices(indices_to_remove)
+
+                    summary = await self.asummarize()
+
+                    tool_notice = self._format_tool_limit_notice()
+                    summary_messages = summary.get("summary", "")
+
+                    if tool_notice:
+                        summary_messages += "\n\n" + tool_notice
+                    self._update_memory_with_summary(
+                        summary_messages, include_summaries=False
+                    )
+                    self._last_token_limit_tool_signature = tool_signature
+                    return await self._astep_non_streaming_task(
+                        input_message, response_format
+                    )
+
+                raise
+
+            prev_num_openai_messages = len(openai_messages)
             iteration_count += 1

+            # Accumulate API token usage
+            self._update_token_usage_tracker(
+                step_token_usage, response.usage_dict
+            )
+
             # Terminate Agent if stop_event is set
             if self.stop_event and self.stop_event.is_set():
                 # Use the _step_terminate to terminate the agent with reason
+                logger.info(
+                    f"Termination triggered at iteration {iteration_count}"
+                )
                 return self._step_terminate(
                     accumulated_context_tokens,
                     tool_call_records,
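The pause handling added in this hunk waits on either an `asyncio.Event` or a `threading.Event`; a threading event is awaited through the running loop's executor so the coroutine does not block the event loop. A small standalone sketch of that pattern (the function name is illustrative, not part of the diff):

    import asyncio
    import threading

    async def wait_if_paused(pause_event) -> None:
        # Resume immediately if there is no event or it is already set.
        if pause_event is None or pause_event.is_set():
            return
        if isinstance(pause_event, asyncio.Event):
            await pause_event.wait()  # cooperative wait
        elif isinstance(pause_event, threading.Event):
            loop = asyncio.get_event_loop()
            # Run the blocking wait in a worker thread so the loop stays free.
            await loop.run_in_executor(None, pause_event.wait)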
@@ -959,6 +3163,17 @@ class ChatAgent(BaseAgent):
                         external_tool_call_requests = []
                     external_tool_call_requests.append(tool_call_request)
                 else:
+                    if (
+                        self.pause_event is not None
+                        and not self.pause_event.is_set()
+                    ):
+                        if isinstance(self.pause_event, asyncio.Event):
+                            await self.pause_event.wait()
+                        elif isinstance(self.pause_event, threading.Event):
+                            loop = asyncio.get_event_loop()
+                            await loop.run_in_executor(
+                                None, self.pause_event.wait
+                            )
                     tool_call_record = await self._aexecute_tool(
                         tool_call_request
                     )
@@ -980,13 +3195,20 @@ class ChatAgent(BaseAgent):
                 break

         await self._aformat_response_if_needed(response, response_format)
+
+        # Apply manual parsing if we used prompt-based formatting
+        if used_prompt_formatting and original_response_format:
+            self._apply_prompt_based_parsing(
+                response, original_response_format
+            )
+
         self._record_final_output(response.output_messages)

-        #
-
+        # Clean tool call messages from memory after response generation
+        if self.prune_tool_calls_from_memory and tool_call_records:
+            self.memory.clean_tool_calls()

-
-        self._update_token_usage_tracker(step_token_usage, response.usage_dict)
+        self._last_token_limit_user_signature = None

         return self._convert_to_chatagent_response(
             response,
@@ -1063,121 +3285,156 @@ class ChatAgent(BaseAgent):
             "selected message manually using `record_message()`."
         )

+    @observe()
     def _get_model_response(
         self,
         openai_messages: List[OpenAIMessage],
         num_tokens: int,
+        current_iteration: int = 0,
         response_format: Optional[Type[BaseModel]] = None,
         tool_schemas: Optional[List[Dict[str, Any]]] = None,
+        prev_num_openai_messages: int = 0,
     ) -> ModelResponse:
         r"""Internal function for agent step model response."""
+        last_error = None

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for attempt in range(self.retry_attempts):
+            try:
+                response = self.model_backend.run(
+                    openai_messages, response_format, tool_schemas or None
+                )
+                if response:
+                    break
+            except RateLimitError as e:
+                if self._is_token_limit_error(e):
+                    raise
+                last_error = e
+                if attempt < self.retry_attempts - 1:
+                    delay = min(self.retry_delay * (2**attempt), 60.0)
+                    delay = random.uniform(0, delay)  # Add jitter
+                    logger.warning(
+                        f"Rate limit hit (attempt {attempt + 1}"
+                        f"/{self.retry_attempts}). Retrying in {delay:.1f}s"
+                    )
+                    time.sleep(delay)
+                else:
+                    logger.error(
+                        f"Rate limit exhausted after "
+                        f"{self.retry_attempts} attempts"
+                    )
+            except Exception:
+                logger.error(
+                    f"Model error: {self.model_backend.model_type}",
+                )
+                raise
+        else:
+            # Loop completed without success
             raise ModelProcessingError(
-                f"Unable to process messages:
-                f"
+                f"Unable to process messages: "
+                f"{str(last_error) if last_error else 'Unknown error'}"
             )

-
-
+        # Log success
+        sanitized = self._sanitize_messages_for_logging(
+            openai_messages, prev_num_openai_messages
         )
         logger.info(
-            f"Model {self.model_backend.model_type}
-            f"
-            f"processed these messages: {sanitized_messages}"
+            f"Model {self.model_backend.model_type} "
+            f"[{current_iteration}]: {sanitized}"
         )

-        if isinstance(response, ChatCompletion):
-
-
-
+        if not isinstance(response, ChatCompletion):
+            raise TypeError(
+                f"Expected ChatCompletion, got {type(response).__name__}"
+            )
+
+        return self._handle_batch_response(response)

+    @observe()
     async def _aget_model_response(
         self,
         openai_messages: List[OpenAIMessage],
         num_tokens: int,
+        current_iteration: int = 0,
         response_format: Optional[Type[BaseModel]] = None,
         tool_schemas: Optional[List[Dict[str, Any]]] = None,
+        prev_num_openai_messages: int = 0,
     ) -> ModelResponse:
-        r"""Internal function for agent step model response."""
-
-        response = None
-        try:
-            response = await self.model_backend.arun(
-                openai_messages, response_format, tool_schemas or None
-            )
-        except Exception as exc:
-            logger.error(
-                f"An error occurred while running model "
-                f"{self.model_backend.model_type}, "
-                f"index: {self.model_backend.current_model_index}",
-                exc_info=exc,
-            )
-            error_info = str(exc)
+        r"""Internal function for agent async step model response."""
+        last_error = None

-
-
-
-
-
-
+        for attempt in range(self.retry_attempts):
+            try:
+                response = await self.model_backend.arun(
+                    openai_messages, response_format, tool_schemas or None
+                )
+                if response:
+                    break
+            except RateLimitError as e:
+                if self._is_token_limit_error(e):
+                    raise
+                last_error = e
+                if attempt < self.retry_attempts - 1:
+                    delay = min(self.retry_delay * (2**attempt), 60.0)
+                    delay = random.uniform(0, delay)  # Add jitter
+                    logger.warning(
+                        f"Rate limit hit (attempt {attempt + 1}"
+                        f"/{self.retry_attempts}). "
+                        f"Retrying in {delay:.1f}s"
+                    )
+                    await asyncio.sleep(delay)
+                else:
+                    logger.error(
+                        f"Rate limit exhausted after "
+                        f"{self.retry_attempts} attempts"
+                    )
+            except Exception:
+                logger.error(
+                    f"Model error: {self.model_backend.model_type}",
+                    exc_info=True,
+                )
+                raise
+        else:
+            # Loop completed without success
             raise ModelProcessingError(
-                f"Unable to process messages:
-                f"
+                f"Unable to process messages: "
+                f"{str(last_error) if last_error else 'Unknown error'}"
             )

-
-
+        # Log success
+        sanitized = self._sanitize_messages_for_logging(
+            openai_messages, prev_num_openai_messages
         )
         logger.info(
-            f"Model {self.model_backend.model_type}
-            f"
-            f"processed these messages: {sanitized_messages}"
+            f"Model {self.model_backend.model_type} "
+            f"[{current_iteration}]: {sanitized}"
         )

-        if isinstance(response, ChatCompletion):
-
-
-
+        if not isinstance(response, ChatCompletion):
+            raise TypeError(
+                f"Expected ChatCompletion, got {type(response).__name__}"
+            )
+
+        return self._handle_batch_response(response)

-    def _sanitize_messages_for_logging(
+    def _sanitize_messages_for_logging(
+        self, messages, prev_num_openai_messages: int
+    ):
         r"""Sanitize OpenAI messages for logging by replacing base64 image
         data with a simple message and a link to view the image.

         Args:
             messages (List[OpenAIMessage]): The OpenAI messages to sanitize.
+            prev_num_openai_messages (int): The number of openai messages
+                logged in the previous iteration.

         Returns:
             List[OpenAIMessage]: The sanitized OpenAI messages.
         """
-        import hashlib
-        import os
-        import re
-        import tempfile
-
         # Create a copy of messages for logging to avoid modifying the
         # original messages
         sanitized_messages = []
-        for msg in messages:
+        for msg in messages[prev_num_openai_messages:]:
             if isinstance(msg, dict):
                 sanitized_msg = msg.copy()
                 # Check if content is a list (multimodal content with images)
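Both retry loops above cap the exponential backoff at 60 seconds and then add jitter by sampling uniformly below that cap. A standalone sketch of the delay schedule they compute (the function name and default values are illustrative):

    import random

    def backoff_delay(attempt: int, retry_delay: float = 1.0, cap: float = 60.0) -> float:
        # Exponential backoff capped at `cap`, then jittered over [0, delay).
        delay = min(retry_delay * (2 ** attempt), cap)
        return random.uniform(0, delay)

    # e.g. attempts 0..4 with retry_delay=1.0 draw from [0,1), [0,2), [0,4), [0,8), [0,16)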
@@ -1214,7 +3471,14 @@ class ChatAgent(BaseAgent):

                                 # Save image to temp directory for viewing
                                 try:
-
+                                    # Sanitize img_format to prevent path
+                                    # traversal
+                                    safe_format = re.sub(
+                                        r'[^a-zA-Z0-9]', '', img_format
+                                    )[:10]
+                                    img_filename = (
+                                        f"image_{img_hash}.{safe_format}"
+                                    )

                                     temp_dir = tempfile.gettempdir()
                                     img_path = os.path.join(
@@ -1229,6 +3493,9 @@ class ChatAgent(BaseAgent):
                                             base64_data
                                         )
                                     )
+                                    # Register for cleanup
+                                    with _temp_files_lock:
+                                        _temp_files.add(img_path)

                                     # Create a file:// URL that can be
                                     # opened
@@ -1402,9 +3669,9 @@ class ChatAgent(BaseAgent):
         if tool_calls := response.choices[0].message.tool_calls:
             tool_call_requests = []
             for tool_call in tool_calls:
-                tool_name = tool_call.function.name
+                tool_name = tool_call.function.name  # type: ignore[union-attr]
                 tool_call_id = tool_call.id
-                args = json.loads(tool_call.function.arguments)
+                args = json.loads(tool_call.function.arguments)  # type: ignore[union-attr]
                 tool_call_request = ToolCallRequest(
                     tool_name=tool_name, args=args, tool_call_id=tool_call_id
                 )
@@ -1419,285 +3686,1491 @@ class ChatAgent(BaseAgent):
             response_id=response.id or "",
         )

-    def
-        self,
-
-
-
-
-
+    def _step_terminate(
+        self,
+        num_tokens: int,
+        tool_calls: List[ToolCallingRecord],
+        termination_reason: str,
+    ) -> ChatAgentResponse:
+        r"""Create a response when the agent execution is terminated.
+
+        This method is called when the agent needs to terminate its execution
+        due to various reasons such as token limit exceeded, or other
+        termination conditions. It creates a response with empty messages but
+        includes termination information in the info dictionary.
+
+        Args:
+            num_tokens (int): Number of tokens in the messages.
+            tool_calls (List[ToolCallingRecord]): List of information
+                objects of functions called in the current step.
+            termination_reason (str): String describing the reason for
+                termination.
+
+        Returns:
+            ChatAgentResponse: A response object with empty message list,
+                terminated flag set to True, and an info dictionary containing
+                termination details, token counts, and tool call information.
+        """
+        self.terminated = True
+
+        info = get_info_dict(
+            None,
+            None,
+            [termination_reason],
+            num_tokens,
+            tool_calls,
+        )
+
+        return ChatAgentResponse(
+            msgs=[],
+            terminated=self.terminated,
+            info=info,
+        )
+
+    @observe()
+    def _execute_tool(
+        self,
+        tool_call_request: ToolCallRequest,
+    ) -> ToolCallingRecord:
+        r"""Execute the tool with arguments following the model's response.
+
+        Args:
+            tool_call_request (_ToolCallRequest): The tool call request.
+
+        Returns:
+            FunctionCallingRecord: A struct for logging information about this
+                function call.
+        """
+        func_name = tool_call_request.tool_name
+        args = tool_call_request.args
+        tool_call_id = tool_call_request.tool_call_id
+        tool = self._internal_tools[func_name]
+        try:
+            raw_result = tool(**args)
+            if self.mask_tool_output:
+                with self._secure_result_store_lock:
+                    self._secure_result_store[tool_call_id] = raw_result
+                result = (
+                    "[The tool has been executed successfully, but the output"
+                    " from the tool is masked. You can move forward]"
+                )
+                mask_flag = True
+            else:
+                result = raw_result
+                mask_flag = False
+        except Exception as e:
+            # Capture the error message to prevent framework crash
+            error_msg = f"Error executing tool '{func_name}': {e!s}"
+            result = f"Tool execution failed: {error_msg}"
+            mask_flag = False
+            logger.warning(f"{error_msg} with result: {result}")
+
+        return self._record_tool_calling(
+            func_name, args, result, tool_call_id, mask_output=mask_flag
+        )
+
+    async def _aexecute_tool(
+        self,
+        tool_call_request: ToolCallRequest,
+    ) -> ToolCallingRecord:
+        func_name = tool_call_request.tool_name
+        args = tool_call_request.args
+        tool_call_id = tool_call_request.tool_call_id
+        tool = self._internal_tools[func_name]
+        import asyncio
+
+        try:
+            # Try different invocation paths in order of preference
+            if hasattr(tool, 'func') and hasattr(tool.func, 'async_call'):
+                # Case: FunctionTool wrapping an MCP tool
+                result = await tool.func.async_call(**args)
+
+            elif hasattr(tool, 'async_call') and callable(tool.async_call):
+                # Case: tool itself has async_call
+                result = await tool.async_call(**args)
+
+            elif hasattr(tool, 'func') and asyncio.iscoroutinefunction(
+                tool.func
+            ):
+                # Case: tool wraps a direct async function
+                result = await tool.func(**args)
+
+            elif asyncio.iscoroutinefunction(tool):
+                # Case: tool is itself a coroutine function
+                result = await tool(**args)
+
+            else:
+                # Fallback: synchronous call
+                result = tool(**args)
+
+        except Exception as e:
+            # Capture the error message to prevent framework crash
+            error_msg = f"Error executing async tool '{func_name}': {e!s}"
+            result = f"Tool execution failed: {error_msg}"
+            logger.warning(error_msg)
+        return self._record_tool_calling(func_name, args, result, tool_call_id)
+
+    def _record_tool_calling(
+        self,
+        func_name: str,
+        args: Dict[str, Any],
+        result: Any,
+        tool_call_id: str,
+        mask_output: bool = False,
+    ):
+        r"""Record the tool calling information in the memory, and return the
+        tool calling record.
+
+        Args:
+            func_name (str): The name of the tool function called.
+            args (Dict[str, Any]): The arguments passed to the tool.
+            result (Any): The result returned by the tool execution.
+            tool_call_id (str): A unique identifier for the tool call.
+            mask_output (bool, optional): Whether to return a sanitized
+                placeholder instead of the raw tool output.
+                (default: :obj:`False`)
+
+        Returns:
+            ToolCallingRecord: A struct containing information about
+                this tool call.
+        """
+        assist_msg = FunctionCallingMessage(
+            role_name=self.role_name,
+            role_type=self.role_type,
+            meta_dict=None,
+            content="",
+            func_name=func_name,
+            args=args,
+            tool_call_id=tool_call_id,
+        )
+        func_msg = FunctionCallingMessage(
+            role_name=self.role_name,
+            role_type=self.role_type,
+            meta_dict=None,
+            content="",
+            func_name=func_name,
+            result=result,
+            tool_call_id=tool_call_id,
+            mask_output=mask_output,
+        )
+
+        # Use precise timestamps to ensure correct ordering
+        # This ensures the assistant message (tool call) always appears before
+        # the function message (tool result) in the conversation context
+        # Use time.time_ns() for nanosecond precision to avoid collisions
+        current_time_ns = time.time_ns()
+        base_timestamp = current_time_ns / 1_000_000_000  # Convert to seconds
+
+        self.update_memory(
+            assist_msg,
+            OpenAIBackendRole.ASSISTANT,
+            timestamp=base_timestamp,
+            return_records=self._enable_snapshot_clean,
+        )
+
+        # Add minimal increment to ensure function message comes after
+        func_records = self.update_memory(
+            func_msg,
+            OpenAIBackendRole.FUNCTION,
+            timestamp=base_timestamp + 1e-6,
+            return_records=self._enable_snapshot_clean,
+        )
+
+        # Register tool output for snapshot cleaning if enabled
+        if self._enable_snapshot_clean and not mask_output and func_records:
+            serialized_result = self._serialize_tool_result(result)
+            self._register_tool_output_for_cache(
+                func_name,
+                tool_call_id,
+                serialized_result,
+                cast(List[MemoryRecord], func_records),
+            )
+
+        # Record information about this tool call
+        tool_record = ToolCallingRecord(
+            tool_name=func_name,
+            args=args,
+            result=result,
+            tool_call_id=tool_call_id,
+        )
+
+        self._update_last_tool_call_state(tool_record)
+        return tool_record
+
+    def _stream(
+        self,
+        input_message: Union[BaseMessage, str],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> Generator[ChatAgentResponse, None, None]:
+        r"""Executes a streaming step in the chat session, yielding
+        intermediate responses as they are generated.
+
+        Args:
+            input_message (Union[BaseMessage, str]): The input message for the
+                agent.
+            response_format (Optional[Type[BaseModel]], optional): A Pydantic
+                model defining the expected structure of the response.
+
+        Yields:
+            ChatAgentResponse: Intermediate responses containing partial
+                content, tool calls, and other information as they become
+                available.
+        """
+        # Convert input message to BaseMessage if necessary
+        if isinstance(input_message, str):
+            input_message = BaseMessage.make_user_message(
+                role_name="User", content=input_message
+            )
+
+        # Add user input to memory
+        self.update_memory(input_message, OpenAIBackendRole.USER)
+
+        # Get context for streaming
+        try:
+            openai_messages, num_tokens = self.memory.get_context()
+        except RuntimeError as e:
+            yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
+            return
+
+        # Start streaming response
+        yield from self._stream_response(
+            openai_messages, num_tokens, response_format
+        )
+
+    def _get_token_count(self, content: str) -> int:
+        r"""Get token count for content with fallback."""
+        if hasattr(self.model_backend, 'token_counter'):
+            return len(self.model_backend.token_counter.encode(content))
+        else:
+            return len(content.split())
+
+    def _stream_response(
+        self,
+        openai_messages: List[OpenAIMessage],
+        num_tokens: int,
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> Generator[ChatAgentResponse, None, None]:
+        r"""Internal method to handle streaming responses with tool calls."""
+
+        tool_call_records: List[ToolCallingRecord] = []
+        accumulated_tool_calls: Dict[str, Any] = {}
+        step_token_usage = self._create_token_usage_tracker()
+
+        # Create content accumulator for proper content management
+        content_accumulator = StreamContentAccumulator()
+        iteration_count = 0
+        while True:
+            # Check termination condition
+            if self.stop_event and self.stop_event.is_set():
+                logger.info(
+                    f"Termination triggered at iteration {iteration_count}"
+                )
+                yield self._step_terminate(
+                    num_tokens, tool_call_records, "termination_triggered"
+                )
+                return
+
+            # Get streaming response from model
+            try:
+                response = self.model_backend.run(
+                    openai_messages,
+                    response_format,
+                    self._get_full_tool_schemas() or None,
+                )
+                iteration_count += 1
+            except Exception as exc:
+                logger.error(
+                    f"Error in streaming model response: {exc}", exc_info=exc
+                )
+                yield self._create_error_response(str(exc), tool_call_records)
+                return
+
+            # Handle streaming response
+            if isinstance(response, Stream):
+                (
+                    stream_completed,
+                    tool_calls_complete,
+                ) = yield from self._process_stream_chunks_with_accumulator(
+                    response,
+                    content_accumulator,
+                    accumulated_tool_calls,
+                    tool_call_records,
+                    step_token_usage,
+                    response_format,
+                )
+
+                if tool_calls_complete:
+                    # Clear completed tool calls
+                    accumulated_tool_calls.clear()
+
+                    # If we executed tools and not in
+                    # single iteration mode, continue
+                    if tool_call_records and (
+                        self.max_iteration is None
+                        or iteration_count < self.max_iteration
+                    ):
+                        # Update messages with tool results for next iteration
+                        try:
+                            openai_messages, num_tokens = (
+                                self.memory.get_context()
+                            )
+                        except RuntimeError as e:
+                            yield self._step_terminate(
+                                e.args[1],
+                                tool_call_records,
+                                "max_tokens_exceeded",
+                            )
+                            return
+                        # Reset streaming content for next iteration
+                        content_accumulator.reset_streaming_content()
+                        continue
+                    else:
+                        break
+                else:
+                    # Stream completed without tool calls
+                    accumulated_tool_calls.clear()
+                    break
+            elif hasattr(response, '__enter__') and hasattr(
+                response, '__exit__'
+            ):
+                # Handle structured output stream (ChatCompletionStreamManager)
+                with response as stream:
+                    parsed_object = None
+
+                    for event in stream:
+                        if event.type == "content.delta":
+                            if getattr(event, "delta", None):
+                                # Use accumulator for proper content management
+                                partial_response = self._create_streaming_response_with_accumulator(  # noqa: E501
+                                    content_accumulator,
+                                    getattr(event, "delta", ""),
+                                    step_token_usage,
+                                    tool_call_records=tool_call_records.copy(),
+                                )
+                                yield partial_response
+
+                        elif event.type == "content.done":
+                            parsed_object = getattr(event, "parsed", None)
+                            break
+                        elif event.type == "error":
+                            logger.error(
+                                f"Error in structured stream: "
+                                f"{getattr(event, 'error', '')}"
+                            )
+                            yield self._create_error_response(
+                                str(getattr(event, 'error', '')),
+                                tool_call_records,
+                            )
+                            return
+
+                    # Get final completion and record final message
+                    try:
+                        final_completion = stream.get_final_completion()
+                        final_content = (
+                            final_completion.choices[0].message.content or ""
+                        )
+
+                        final_message = BaseMessage(
+                            role_name=self.role_name,
+                            role_type=self.role_type,
+                            meta_dict={},
+                            content=final_content,
+                            parsed=cast(
+                                "BaseModel | dict[str, Any] | None",
+                                parsed_object,
+                            ),  # type: ignore[arg-type]
+                        )
+
+                        self.record_message(final_message)
+
+                        # Create final response
+                        final_response = ChatAgentResponse(
+                            msgs=[final_message],
+                            terminated=False,
+                            info={
+                                "id": final_completion.id or "",
+                                "usage": safe_model_dump(
+                                    final_completion.usage
+                                )
+                                if final_completion.usage
+                                else {},
+                                "finish_reasons": [
+                                    choice.finish_reason or "stop"
+                                    for choice in final_completion.choices
+                                ],
+                                "num_tokens": self._get_token_count(
+                                    final_content
+                                ),
+                                "tool_calls": tool_call_records,
+                                "external_tool_requests": None,
+                                "streaming": False,
+                                "partial": False,
+                            },
+                        )
+                        yield final_response
+                        break
+
+                    except Exception as e:
+                        logger.error(f"Error getting final completion: {e}")
+                        yield self._create_error_response(
+                            str(e), tool_call_records
+                        )
+                        return
+            else:
+                # Handle non-streaming response (fallback)
+                model_response = self._handle_batch_response(response)
+                yield self._convert_to_chatagent_response(
+                    model_response,
+                    tool_call_records,
+                    num_tokens,
+                    None,
+                    model_response.usage_dict.get("prompt_tokens", 0),
+                    model_response.usage_dict.get("completion_tokens", 0),
+                    model_response.usage_dict.get("total_tokens", 0),
+                )
+                accumulated_tool_calls.clear()
+                break
+
+    def _process_stream_chunks_with_accumulator(
+        self,
+        stream: Stream[ChatCompletionChunk],
+        content_accumulator: StreamContentAccumulator,
+        accumulated_tool_calls: Dict[str, Any],
+        tool_call_records: List[ToolCallingRecord],
+        step_token_usage: Dict[str, int],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> Generator[ChatAgentResponse, None, Tuple[bool, bool]]:
+        r"""Process streaming chunks with content accumulator."""
+
+        tool_calls_complete = False
+        stream_completed = False
+
+        for chunk in stream:
+            # Process chunk delta
+            if chunk.choices and len(chunk.choices) > 0:
+                choice = chunk.choices[0]
+                delta = choice.delta
+
+                # Handle content streaming
+                if delta.content:
+                    # Use accumulator for proper content management
+                    partial_response = (
+                        self._create_streaming_response_with_accumulator(
+                            content_accumulator,
+                            delta.content,
+                            step_token_usage,
+                            getattr(chunk, 'id', ''),
+                            tool_call_records.copy(),
+                        )
+                    )
+                    yield partial_response
+
+                # Handle tool calls streaming
+                if delta.tool_calls:
+                    tool_calls_complete = self._accumulate_tool_calls(
+                        delta.tool_calls, accumulated_tool_calls
+                    )
+
+                # Check if stream is complete
+                if choice.finish_reason:
+                    stream_completed = True
+
+                    # If we have complete tool calls, execute them with
+                    # sync status updates
+                    if accumulated_tool_calls:
+                        # Execute tools synchronously with
+                        # optimized status updates
+                        for (
+                            status_response
+                        ) in self._execute_tools_sync_with_status_accumulator(
+                            accumulated_tool_calls,
+                            tool_call_records,
+                        ):
+                            yield status_response
+
+                        # Log sending status instead of adding to content
+                        if tool_call_records:
+                            logger.info("Sending back result to model")
+
+                    # Record final message only if we have content AND no tool
+                    # calls. If there are tool calls, _record_tool_calling
+                    # will handle message recording.
+                    final_content = content_accumulator.get_full_content()
+                    if final_content.strip() and not accumulated_tool_calls:
+                        final_message = BaseMessage(
+                            role_name=self.role_name,
+                            role_type=self.role_type,
+                            meta_dict={},
+                            content=final_content,
+                        )
+
+                        if response_format:
+                            self._try_format_message(
+                                final_message, response_format
+                            )
+
+                        self.record_message(final_message)
+            elif chunk.usage and not chunk.choices:
+                # Handle final chunk with usage but empty choices
+                # This happens when stream_options={"include_usage": True}
+                # Update the final usage from this chunk
+                self._update_token_usage_tracker(
+                    step_token_usage, safe_model_dump(chunk.usage)
+                )
+
+                # Create final response with final usage
+                final_content = content_accumulator.get_full_content()
+                if final_content.strip():
+                    final_message = BaseMessage(
+                        role_name=self.role_name,
+                        role_type=self.role_type,
+                        meta_dict={},
+                        content=final_content,
+                    )
+
+                    if response_format:
+                        self._try_format_message(
+                            final_message, response_format
+                        )
+
+                    # Create final response with final usage (not partial)
+                    final_response = ChatAgentResponse(
+                        msgs=[final_message],
+                        terminated=False,
+                        info={
+                            "id": getattr(chunk, 'id', ''),
+                            "usage": step_token_usage.copy(),
+                            "finish_reasons": ["stop"],
+                            "num_tokens": self._get_token_count(final_content),
+                            "tool_calls": tool_call_records or [],
+                            "external_tool_requests": None,
+                            "streaming": False,
+                            "partial": False,
+                        },
+                    )
+                    yield final_response
+                    break
+            elif stream_completed:
+                # If we've already seen finish_reason but no usage chunk, exit
+                break
+
+        return stream_completed, tool_calls_complete
+
+    def _accumulate_tool_calls(
+        self,
+        tool_call_deltas: List[Any],
+        accumulated_tool_calls: Dict[str, Any],
+    ) -> bool:
+        r"""Accumulate tool call chunks and return True when
+        any tool call is complete.
+
+        Args:
+            tool_call_deltas (List[Any]): List of tool call deltas.
+            accumulated_tool_calls (Dict[str, Any]): Dictionary of accumulated
+                tool calls.
+
+        Returns:
+            bool: True if any tool call is complete, False otherwise.
+        """
+
+        for delta_tool_call in tool_call_deltas:
+            index = delta_tool_call.index
+            tool_call_id = getattr(delta_tool_call, 'id', None)
+
+            # Initialize tool call entry if not exists
+            if index not in accumulated_tool_calls:
+                accumulated_tool_calls[index] = {
+                    'id': '',
+                    'type': 'function',
+                    'function': {'name': '', 'arguments': ''},
+                    'complete': False,
+                }
+
+            tool_call_entry = accumulated_tool_calls[index]
+
+            # Accumulate tool call data
+            if tool_call_id:
+                tool_call_entry['id'] = (
+                    tool_call_id  # Set full ID, don't append
+                )
+
+            if (
+                hasattr(delta_tool_call, 'function')
+                and delta_tool_call.function
+            ):
+                if delta_tool_call.function.name:
+                    tool_call_entry['function']['name'] += (
+                        delta_tool_call.function.name
+                    )  # Append incremental name
+                if delta_tool_call.function.arguments:
+                    tool_call_entry['function']['arguments'] += (
+                        delta_tool_call.function.arguments
+                    )
+
+        # Check if any tool calls are complete
+        any_complete = False
+        for _index, tool_call_entry in accumulated_tool_calls.items():
+            if (
+                tool_call_entry['id']
+                and tool_call_entry['function']['name']
+                and tool_call_entry['function']['arguments']
+                and tool_call_entry['function']['name'] in self._internal_tools
+            ):
+                try:
+                    # Try to parse arguments to check completeness
+                    json.loads(tool_call_entry['function']['arguments'])
+                    tool_call_entry['complete'] = True
+                    any_complete = True
+                except json.JSONDecodeError:
+                    # Arguments not complete yet
+                    tool_call_entry['complete'] = False
+
+        return any_complete
+
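The `_accumulate_tool_calls` helper above treats a streamed tool call as complete once its id, name, and arguments are all present and the arguments parse as JSON. A small standalone sketch of that completeness check (the delta values are made up for illustration):

    import json

    # Accumulated fragments for one streamed tool call (illustrative values).
    entry = {'id': 'call_1', 'function': {'name': 'search', 'arguments': ''}}

    for fragment in ['{"query": ', '"camel-ai"', '}']:
        entry['function']['arguments'] += fragment
        try:
            json.loads(entry['function']['arguments'])
            complete = True   # arguments form valid JSON -> call is complete
        except json.JSONDecodeError:
            complete = False  # still waiting for more chunks
        print(complete)       # False, False, True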
4330
|
+
def _execute_tools_sync_with_status_accumulator(
|
|
4331
|
+
self,
|
|
4332
|
+
accumulated_tool_calls: Dict[str, Any],
|
|
4333
|
+
tool_call_records: List[ToolCallingRecord],
|
|
4334
|
+
) -> Generator[ChatAgentResponse, None, None]:
|
|
4335
|
+
r"""Execute multiple tools synchronously with proper content
|
|
4336
|
+
accumulation, using ThreadPoolExecutor for better timeout handling."""
|
|
4337
|
+
|
|
4338
|
+
tool_calls_to_execute = []
|
|
4339
|
+
for _tool_call_index, tool_call_data in accumulated_tool_calls.items():
|
|
4340
|
+
if tool_call_data.get('complete', False):
|
|
4341
|
+
tool_calls_to_execute.append(tool_call_data)
|
|
4342
|
+
|
|
4343
|
+
if not tool_calls_to_execute:
|
|
4344
|
+
# No tools to execute, return immediately
|
|
4345
|
+
return
|
|
4346
|
+
yield # Make this a generator
|
|
4347
|
+
|
|
4348
|
+
# Execute tools using ThreadPoolExecutor for proper timeout handling
|
|
4349
|
+
# Use max_workers=len() for parallel execution, with min of 1
|
|
4350
|
+
with concurrent.futures.ThreadPoolExecutor(
|
|
4351
|
+
max_workers=max(1, len(tool_calls_to_execute))
|
|
4352
|
+
) as executor:
|
|
4353
|
+
# Submit all tools first (parallel execution)
|
|
4354
|
+
futures_map = {}
|
|
4355
|
+
for tool_call_data in tool_calls_to_execute:
|
|
4356
|
+
function_name = tool_call_data['function']['name']
|
|
4357
|
+
try:
|
|
4358
|
+
args = json.loads(tool_call_data['function']['arguments'])
|
|
4359
|
+
except json.JSONDecodeError:
|
|
4360
|
+
args = tool_call_data['function']['arguments']
|
|
4361
|
+
|
|
4362
|
+
# Log debug info
|
|
4363
|
+
logger.info(
|
|
4364
|
+
f"Calling function: {function_name} with arguments: {args}"
|
|
4365
|
+
)
|
|
4366
|
+
|
|
4367
|
+
# Submit tool execution (non-blocking)
|
|
4368
|
+
future = executor.submit(
|
|
4369
|
+
self._execute_tool_from_stream_data, tool_call_data
|
|
4370
|
+
)
|
|
4371
|
+
futures_map[future] = (function_name, tool_call_data)
|
|
4372
|
+
|
|
4373
|
+
# Wait for all futures to complete (or timeout)
|
|
4374
|
+
for future in concurrent.futures.as_completed(
|
|
4375
|
+
futures_map.keys(),
|
|
4376
|
+
timeout=self.tool_execution_timeout
|
|
4377
|
+
if self.tool_execution_timeout
|
|
4378
|
+
else None,
|
|
4379
|
+
):
|
|
4380
|
+
function_name, tool_call_data = futures_map[future]
|
|
4381
|
+
|
|
4382
|
+
try:
|
|
4383
|
+
tool_call_record = future.result()
|
|
4384
|
+
if tool_call_record:
|
|
4385
|
+
tool_call_records.append(tool_call_record)
|
|
4386
|
+
logger.info(
|
|
4387
|
+
f"Function output: {tool_call_record.result}"
|
|
4388
|
+
)
|
|
4389
|
+
except concurrent.futures.TimeoutError:
|
|
4390
|
+
logger.warning(
|
|
4391
|
+
f"Function '{function_name}' timed out after "
|
|
4392
|
+
f"{self.tool_execution_timeout} seconds"
|
|
4393
|
+
)
|
|
4394
|
+
future.cancel()
|
|
4395
|
+
except Exception as e:
|
|
4396
|
+
logger.error(
|
|
4397
|
+
f"Error executing tool '{function_name}': {e}"
|
|
4398
|
+
)
|
|
4399
|
+
|
|
4400
|
+
# Ensure this function remains a generator (required by type signature)
|
|
4401
|
+
return
|
|
4402
|
+
yield # This line is never reached but makes this a generator function
|
|
4403
|
+
|
|
4404
|
+
def _execute_tool_from_stream_data(
|
|
4405
|
+
self, tool_call_data: Dict[str, Any]
|
|
4406
|
+
) -> Optional[ToolCallingRecord]:
|
|
4407
|
+
r"""Execute a tool from accumulated stream data."""
|
|
4408
|
+
|
|
4409
|
+
try:
|
|
4410
|
+
function_name = tool_call_data['function']['name']
|
|
4411
|
+
args = json.loads(tool_call_data['function']['arguments'])
|
|
4412
|
+
tool_call_id = tool_call_data['id']
|
|
4413
|
+
|
|
4414
|
+
if function_name in self._internal_tools:
|
|
4415
|
+
tool = self._internal_tools[function_name]
|
|
4416
|
+
try:
|
|
4417
|
+
result = tool(**args)
|
|
4418
|
+
# First, create and record the assistant message with tool
|
|
4419
|
+
# call
|
|
4420
|
+
assist_msg = FunctionCallingMessage(
|
|
4421
|
+
role_name=self.role_name,
|
|
4422
|
+
role_type=self.role_type,
|
|
4423
|
+
meta_dict=None,
|
|
4424
|
+
content="",
|
|
4425
|
+
func_name=function_name,
|
|
4426
|
+
args=args,
|
|
4427
|
+
tool_call_id=tool_call_id,
|
|
4428
|
+
)
|
|
4429
|
+
|
|
4430
|
+
# Then create the tool response message
|
|
4431
|
+
func_msg = FunctionCallingMessage(
|
|
4432
|
+
role_name=self.role_name,
|
|
4433
|
+
role_type=self.role_type,
|
|
4434
|
+
meta_dict=None,
|
|
4435
|
+
content="",
|
|
4436
|
+
func_name=function_name,
|
|
4437
|
+
result=result,
|
|
4438
|
+
tool_call_id=tool_call_id,
|
|
4439
|
+
)
|
|
4440
|
+
|
|
4441
|
+
# Record both messages with precise timestamps to ensure
|
|
4442
|
+
# correct ordering
|
|
4443
|
+
current_time_ns = time.time_ns()
|
|
4444
|
+
base_timestamp = (
|
|
4445
|
+
current_time_ns / 1_000_000_000
|
|
4446
|
+
) # Convert to seconds
|
|
4447
|
+
|
|
4448
|
+
self.update_memory(
|
|
4449
|
+
assist_msg,
|
|
4450
|
+
OpenAIBackendRole.ASSISTANT,
|
|
4451
|
+
timestamp=base_timestamp,
|
|
4452
|
+
)
|
|
4453
|
+
self.update_memory(
|
|
4454
|
+
func_msg,
|
|
4455
|
+
OpenAIBackendRole.FUNCTION,
|
|
4456
|
+
timestamp=base_timestamp + 1e-6,
|
|
4457
|
+
)
|
|
4458
|
+
|
|
4459
|
+
tool_record = ToolCallingRecord(
|
|
4460
|
+
tool_name=function_name,
|
|
4461
|
+
args=args,
|
|
4462
|
+
result=result,
|
|
4463
|
+
tool_call_id=tool_call_id,
|
|
4464
|
+
)
|
|
4465
|
+
self._update_last_tool_call_state(tool_record)
|
|
4466
|
+
return tool_record
|
|
4467
|
+
|
|
4468
|
+
except Exception as e:
|
|
4469
|
+
error_msg = (
|
|
4470
|
+
f"Error executing tool '{function_name}': {e!s}"
|
|
4471
|
+
)
|
|
4472
|
+
result = {"error": error_msg}
|
|
4473
|
+
logger.warning(error_msg)
|
|
4474
|
+
|
|
4475
|
+
# Record error response
|
|
4476
|
+
func_msg = FunctionCallingMessage(
|
|
4477
|
+
role_name=self.role_name,
|
|
4478
|
+
role_type=self.role_type,
|
|
4479
|
+
meta_dict=None,
|
|
4480
|
+
content="",
|
|
4481
|
+
func_name=function_name,
|
|
4482
|
+
result=result,
|
|
4483
|
+
tool_call_id=tool_call_id,
|
|
4484
|
+
)
|
|
4485
|
+
|
|
4486
|
+
self.update_memory(func_msg, OpenAIBackendRole.FUNCTION)
|
|
4487
|
+
|
|
4488
|
+
tool_record = ToolCallingRecord(
|
|
4489
|
+
tool_name=function_name,
|
|
4490
|
+
args=args,
|
|
4491
|
+
result=result,
|
|
4492
|
+
tool_call_id=tool_call_id,
|
|
4493
|
+
)
|
|
4494
|
+
self._update_last_tool_call_state(tool_record)
|
|
4495
|
+
return tool_record
|
|
4496
|
+
else:
|
|
4497
|
+
logger.warning(
|
|
4498
|
+
f"Tool '{function_name}' not found in internal tools"
|
|
4499
|
+
)
|
|
4500
|
+
return None
|
|
4501
|
+
|
|
4502
|
+
except Exception as e:
|
|
4503
|
+
logger.error(f"Error processing tool call: {e}")
|
|
4504
|
+
return None
|
|
4505
|
+
|
|
4506
|
+
async def _aexecute_tool_from_stream_data(
|
|
4507
|
+
self, tool_call_data: Dict[str, Any]
|
|
4508
|
+
) -> Optional[ToolCallingRecord]:
|
|
4509
|
+
r"""Async execute a tool from accumulated stream data."""
|
|
4510
|
+
|
|
4511
|
+
try:
|
|
4512
|
+
function_name = tool_call_data['function']['name']
|
|
4513
|
+
args = json.loads(tool_call_data['function']['arguments'])
|
|
4514
|
+
tool_call_id = tool_call_data['id']
|
|
4515
|
+
|
|
4516
|
+
if function_name in self._internal_tools:
|
|
4517
|
+
# Create the tool call message
|
|
4518
|
+
assist_msg = FunctionCallingMessage(
|
|
4519
|
+
role_name=self.role_name,
|
|
4520
|
+
role_type=self.role_type,
|
|
4521
|
+
meta_dict=None,
|
|
4522
|
+
content="",
|
|
4523
|
+
func_name=function_name,
|
|
4524
|
+
args=args,
|
|
4525
|
+
tool_call_id=tool_call_id,
|
|
4526
|
+
)
|
|
4527
|
+
assist_ts = time.time_ns() / 1_000_000_000
|
|
4528
|
+
self.update_memory(
|
|
4529
|
+
assist_msg,
|
|
4530
|
+
OpenAIBackendRole.ASSISTANT,
|
|
4531
|
+
timestamp=assist_ts,
|
|
4532
|
+
)
|
|
4533
|
+
|
|
4534
|
+
tool = self._internal_tools[function_name]
|
|
4535
|
+
try:
|
|
4536
|
+
# Try different invocation paths in order of preference
|
|
4537
|
+
if hasattr(tool, 'func') and hasattr(
|
|
4538
|
+
tool.func, 'async_call'
|
|
4539
|
+
):
|
|
4540
|
+
# Case: FunctionTool wrapping an MCP tool
|
|
4541
|
+
result = await tool.func.async_call(**args)
|
|
4542
|
+
|
|
4543
|
+
elif hasattr(tool, 'async_call') and callable(
|
|
4544
|
+
tool.async_call
|
|
4545
|
+
):
|
|
4546
|
+
# Case: tool itself has async_call
|
|
4547
|
+
result = await tool.async_call(**args)
|
|
4548
|
+
|
|
4549
|
+
elif hasattr(tool, 'func') and asyncio.iscoroutinefunction(
|
|
4550
|
+
tool.func
|
|
4551
|
+
):
|
|
4552
|
+
# Case: tool wraps a direct async function
|
|
4553
|
+
result = await tool.func(**args)
|
|
4554
|
+
|
|
4555
|
+
elif asyncio.iscoroutinefunction(tool):
|
|
4556
|
+
# Case: tool is itself a coroutine function
|
|
4557
|
+
result = await tool(**args)
|
|
4558
|
+
|
|
4559
|
+
else:
|
|
4560
|
+
# Fallback: synchronous call
|
|
4561
|
+
result = tool(**args)
|
|
4562
|
+
|
|
4563
|
+
# Create the tool response message
|
|
4564
|
+
func_msg = FunctionCallingMessage(
|
|
4565
|
+
role_name=self.role_name,
|
|
4566
|
+
role_type=self.role_type,
|
|
4567
|
+
meta_dict=None,
|
|
4568
|
+
content="",
|
|
4569
|
+
func_name=function_name,
|
|
4570
|
+
result=result,
|
|
4571
|
+
tool_call_id=tool_call_id,
|
|
4572
|
+
)
|
|
4573
|
+
func_ts = time.time_ns() / 1_000_000_000
|
|
4574
|
+
self.update_memory(
|
|
4575
|
+
func_msg,
|
|
4576
|
+
OpenAIBackendRole.FUNCTION,
|
|
4577
|
+
timestamp=func_ts,
|
|
4578
|
+
)
|
|
4579
|
+
|
|
4580
|
+
tool_record = ToolCallingRecord(
|
|
4581
|
+
tool_name=function_name,
|
|
4582
|
+
args=args,
|
|
4583
|
+
result=result,
|
|
4584
|
+
tool_call_id=tool_call_id,
|
|
4585
|
+
)
|
|
4586
|
+
self._update_last_tool_call_state(tool_record)
|
|
4587
|
+
return tool_record
|
|
4588
|
+
|
|
4589
|
+
except Exception as e:
|
|
4590
|
+
error_msg = (
|
|
4591
|
+
f"Error executing async tool '{function_name}': {e!s}"
|
|
4592
|
+
)
|
|
4593
|
+
result = {"error": error_msg}
|
|
4594
|
+
logger.warning(error_msg)
|
|
4595
|
+
|
|
4596
|
+
# Record error response
|
|
4597
|
+
func_msg = FunctionCallingMessage(
|
|
4598
|
+
role_name=self.role_name,
|
|
4599
|
+
role_type=self.role_type,
|
|
4600
|
+
meta_dict=None,
|
|
4601
|
+
content="",
|
|
4602
|
+
func_name=function_name,
|
|
4603
|
+
result=result,
|
|
4604
|
+
tool_call_id=tool_call_id,
|
|
4605
|
+
)
|
|
4606
|
+
func_ts = time.time_ns() / 1_000_000_000
|
|
4607
|
+
self.update_memory(
|
|
4608
|
+
func_msg,
|
|
4609
|
+
OpenAIBackendRole.FUNCTION,
|
|
4610
|
+
timestamp=func_ts,
|
|
4611
|
+
)
|
|
4612
|
+
|
|
4613
|
+
tool_record = ToolCallingRecord(
|
|
4614
|
+
tool_name=function_name,
|
|
4615
|
+
args=args,
|
|
4616
|
+
result=result,
|
|
4617
|
+
tool_call_id=tool_call_id,
|
|
4618
|
+
)
|
|
4619
|
+
self._update_last_tool_call_state(tool_record)
|
|
4620
|
+
return tool_record
|
|
4621
|
+
else:
|
|
4622
|
+
logger.warning(
|
|
4623
|
+
f"Tool '{function_name}' not found in internal tools"
|
|
4624
|
+
)
|
|
4625
|
+
return None
|
|
4626
|
+
|
|
4627
|
+
except Exception as e:
|
|
4628
|
+
logger.error(f"Error processing async tool call: {e}")
|
|
4629
|
+
return None
|
|
4630
|
+
|
|
4631
|
+
def _create_error_response(
|
|
4632
|
+
self, error_message: str, tool_call_records: List[ToolCallingRecord]
|
|
4633
|
+
) -> ChatAgentResponse:
|
|
4634
|
+
r"""Create an error response for streaming."""
|
|
4635
|
+
|
|
4636
|
+
error_msg = BaseMessage(
|
|
4637
|
+
role_name=self.role_name,
|
|
4638
|
+
role_type=self.role_type,
|
|
4639
|
+
meta_dict={},
|
|
4640
|
+
content=f"Error: {error_message}",
|
|
4641
|
+
)
|
|
4642
|
+
|
|
4643
|
+
return ChatAgentResponse(
|
|
4644
|
+
msgs=[error_msg],
|
|
4645
|
+
terminated=True,
|
|
4646
|
+
info={
|
|
4647
|
+
"error": error_message,
|
|
4648
|
+
"tool_calls": tool_call_records,
|
|
4649
|
+
"streaming": True,
|
|
4650
|
+
},
|
|
4651
|
+
)
|
|
4652
|
+
|
|
+    async def _astream(
+        self,
+        input_message: Union[BaseMessage, str],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> AsyncGenerator[ChatAgentResponse, None]:
+        r"""Asynchronous version of stream method."""
+
+        # Convert input message to BaseMessage if necessary
+        if isinstance(input_message, str):
+            input_message = BaseMessage.make_user_message(
+                role_name="User", content=input_message
+            )
+
+        # Add user input to memory
+        self.update_memory(input_message, OpenAIBackendRole.USER)
+
+        # Get context for streaming
+        try:
+            openai_messages, num_tokens = self.memory.get_context()
+        except RuntimeError as e:
+            yield self._step_terminate(e.args[1], [], "max_tokens_exceeded")
+            return
+
+        # Start async streaming response
+        last_response = None
+        async for response in self._astream_response(
+            openai_messages, num_tokens, response_format
+        ):
+            last_response = response
+            yield response
+
+        # Clean tool call messages from memory after response generation
+        if self.prune_tool_calls_from_memory and last_response:
+            # Extract tool_calls from the last response info
+            tool_calls = last_response.info.get("tool_calls", [])
+            if tool_calls:
+                self.memory.clean_tool_calls()
+
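The `_astream` generator added above is the asynchronous entry point for streaming. A minimal consumption sketch (illustrative only: it calls the private method directly and assumes the agent already has a model backend configured; `consume_stream` is our own helper name):

```python
import asyncio

from camel.agents import ChatAgent


async def consume_stream(agent: ChatAgent, prompt: str) -> str:
    # Each yielded ChatAgentResponse carries text in msgs[0].content and
    # streaming metadata ("partial", "streaming", usage, tool calls) in info.
    final_text = ""
    async for resp in agent._astream(prompt):
        final_text = resp.msgs[0].content  # the last yield is the full message
    return final_text


# Example invocation (requires a configured model backend / API key):
# agent = ChatAgent(system_message="You are a helpful assistant.")
# print(asyncio.run(consume_stream(agent, "Say hello in one sentence.")))
```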
+    async def _astream_response(
+        self,
+        openai_messages: List[OpenAIMessage],
+        num_tokens: int,
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> AsyncGenerator[ChatAgentResponse, None]:
+        r"""Async method to handle streaming responses with tool calls."""
+
+        tool_call_records: List[ToolCallingRecord] = []
+        accumulated_tool_calls: Dict[str, Any] = {}
+        step_token_usage = self._create_token_usage_tracker()
+
+        # Create content accumulator for proper content management
+        content_accumulator = StreamContentAccumulator()
+        iteration_count = 0
+        while True:
+            # Check termination condition
+            if self.stop_event and self.stop_event.is_set():
+                logger.info(
+                    f"Termination triggered at iteration {iteration_count}"
+                )
+                yield self._step_terminate(
+                    num_tokens, tool_call_records, "termination_triggered"
+                )
+                return
+
+            # Get async streaming response from model
+            try:
+                response = await self.model_backend.arun(
+                    openai_messages,
+                    response_format,
+                    self._get_full_tool_schemas() or None,
+                )
+                iteration_count += 1
+            except Exception as exc:
+                logger.error(
+                    f"Error in async streaming model response: {exc}",
+                    exc_info=exc,
+                )
+                yield self._create_error_response(str(exc), tool_call_records)
+                return
+
+            # Handle streaming response
+            if isinstance(response, AsyncStream):
+                stream_completed = False
+                tool_calls_complete = False
+
+                # Process chunks and forward them
+                async for (
+                    item
+                ) in self._aprocess_stream_chunks_with_accumulator(
+                    response,
+                    content_accumulator,
+                    accumulated_tool_calls,
+                    tool_call_records,
+                    step_token_usage,
+                    response_format,
+                ):
+                    if isinstance(item, tuple):
+                        # This is the final return value (stream_completed,
+                        # tool_calls_complete)
+                        stream_completed, tool_calls_complete = item
+                        break
+                    else:
+                        # This is a ChatAgentResponse to be yielded
+                        yield item
+
+                if tool_calls_complete:
+                    # Clear completed tool calls
+                    accumulated_tool_calls.clear()
+
+                    # If we executed tools and not in
+                    # single iteration mode, continue
+                    if tool_call_records and (
+                        self.max_iteration is None
+                        or iteration_count < self.max_iteration
+                    ):
+                        # Update messages with tool results for next iteration
+                        try:
+                            openai_messages, num_tokens = (
+                                self.memory.get_context()
+                            )
+                        except RuntimeError as e:
+                            yield self._step_terminate(
+                                e.args[1],
+                                tool_call_records,
+                                "max_tokens_exceeded",
+                            )
+                            return
+                        # Reset streaming content for next iteration
+                        content_accumulator.reset_streaming_content()
+                        continue
+                    else:
+                        break
+                else:
+                    # Stream completed without tool calls
+                    accumulated_tool_calls.clear()
+                    break
+            elif hasattr(response, '__aenter__') and hasattr(
+                response, '__aexit__'
+            ):
+                # Handle structured output stream
+                # (AsyncChatCompletionStreamManager)
+                async with response as stream:
+                    parsed_object = None
+
+                    async for event in stream:
+                        if event.type == "content.delta":
+                            if getattr(event, "delta", None):
+                                # Use accumulator for proper content management
+                                partial_response = self._create_streaming_response_with_accumulator(  # noqa: E501
+                                    content_accumulator,
+                                    getattr(event, "delta", ""),
+                                    step_token_usage,
+                                    tool_call_records=tool_call_records.copy(),
+                                )
+                                yield partial_response
+
+                        elif event.type == "content.done":
+                            parsed_object = getattr(event, "parsed", None)
+                            break
+                        elif event.type == "error":
+                            logger.error(
+                                f"Error in async structured stream: "
+                                f"{getattr(event, 'error', '')}"
+                            )
+                            yield self._create_error_response(
+                                str(getattr(event, 'error', '')),
+                                tool_call_records,
+                            )
+                            return

-
-
-
+                    # Get final completion and record final message
+                    try:
+                        final_completion = await stream.get_final_completion()
+                        final_content = (
+                            final_completion.choices[0].message.content or ""
+                        )

-
-
-
-
-
-
-
-
-
-
-            # chunk.id
-            response_id = chunk.id if chunk.id else str(uuid.uuid4())
-            self._handle_chunk(
-                chunk, content_dict, finish_reasons_dict, output_messages
-            )
-        finish_reasons = [
-            finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
-        ]
-        usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
+                        final_message = BaseMessage(
+                            role_name=self.role_name,
+                            role_type=self.role_type,
+                            meta_dict={},
+                            content=final_content,
+                            parsed=cast(
+                                "BaseModel | dict[str, Any] | None",
+                                parsed_object,
+                            ),  # type: ignore[arg-type]
+                        )

-
-        return ModelResponse(
-            response=response,
-            tool_call_requests=None,
-            output_messages=output_messages,
-            finish_reasons=finish_reasons,
-            usage_dict=usage_dict,
-            response_id=response_id,
-        )
+                        self.record_message(final_message)

-
-
-
-
-
-
-
+                        # Create final response
+                        final_response = ChatAgentResponse(
+                            msgs=[final_message],
+                            terminated=False,
+                            info={
+                                "id": final_completion.id or "",
+                                "usage": safe_model_dump(
+                                    final_completion.usage
+                                )
+                                if final_completion.usage
+                                else {},
+                                "finish_reasons": [
+                                    choice.finish_reason or "stop"
+                                    for choice in final_completion.choices
+                                ],
+                                "num_tokens": self._get_token_count(
+                                    final_content
+                                ),
+                                "tool_calls": tool_call_records,
+                                "external_tool_requests": None,
+                                "streaming": False,
+                                "partial": False,
+                            },
+                        )
+                        yield final_response
+                        break

-
-
-
+                    except Exception as e:
+                        logger.error(
+                            f"Error getting async final completion: {e}"
+                        )
+                        yield self._create_error_response(
+                            str(e), tool_call_records
+                        )
+                        return
+            else:
+                # Handle non-streaming response (fallback)
+                model_response = self._handle_batch_response(response)
+                yield self._convert_to_chatagent_response(
+                    model_response,
+                    tool_call_records,
+                    num_tokens,
+                    None,
+                    model_response.usage_dict.get("prompt_tokens", 0),
+                    model_response.usage_dict.get("completion_tokens", 0),
+                    model_response.usage_dict.get("total_tokens", 0),
+                )
+                accumulated_tool_calls.clear()
+                break

-
-
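The `_astream_response` loop above keeps calling the model until a round produces no tool calls, `max_iteration` is reached, or `stop_event` is set. A schematic configuration sketch (the tool is a placeholder; `stop_event` only needs an `is_set()` method, here a `threading.Event`):

```python
import threading

from camel.agents import ChatAgent


def always_sunny(city: str) -> str:
    """Toy tool used only for illustration."""
    return f"The weather in {city} is sunny."


stop_event = threading.Event()

agent = ChatAgent(
    system_message="You are a helpful assistant.",
    tools=[always_sunny],   # plain callables or FunctionTool instances
    max_iteration=2,        # cap the tool-calling rounds per streaming step
    stop_event=stop_event,  # setting this terminates the streaming loop early
)

# stop_event.set()  # e.g. from another thread, to cancel the current step
```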
+    def _record_assistant_tool_calls_message(
+        self, accumulated_tool_calls: Dict[str, Any], content: str = ""
+    ) -> None:
+        r"""Record the assistant message that contains tool calls.
+
+        This method creates and records an assistant message that includes
+        the tool calls information, which is required by OpenAI's API format.
         """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ]
-        usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
+        # Create a BaseMessage with tool_calls information in meta_dict
+        # This will be converted to the proper OpenAI format when needed
+        tool_calls_list = []
+        for tool_call_data in accumulated_tool_calls.values():
+            if tool_call_data.get('complete', False):
+                tool_call_dict = {
+                    "id": tool_call_data["id"],
+                    "type": "function",
+                    "function": {
+                        "name": tool_call_data["function"]["name"],
+                        "arguments": tool_call_data["function"]["arguments"],
+                    },
+                }
+                tool_calls_list.append(tool_call_dict)

-            #
-
-
-
-
-
-            usage_dict=usage_dict,
-            response_id=response_id,
+        # Create an assistant message with tool calls
+        assist_msg = BaseMessage(
+            role_name=self.role_name,
+            role_type=self.role_type,
+            meta_dict={"tool_calls": tool_calls_list},
+            content=content or "",
         )

-
-        self,
-        chunk: ChatCompletionChunk,
-        content_dict: defaultdict,
-        finish_reasons_dict: defaultdict,
-        output_messages: List[BaseMessage],
-    ) -> None:
-        r"""Handle a chunk of the model response."""
-        for choice in chunk.choices:
-            index = choice.index
-            delta = choice.delta
-            if delta.content is not None:
-                content_dict[index] += delta.content
-
-            if not choice.finish_reason:
-                continue
-
-            finish_reasons_dict[index] = choice.finish_reason
-            chat_message = BaseMessage(
-                role_name=self.role_name,
-                role_type=self.role_type,
-                meta_dict=dict(),
-                content=content_dict[index],
-            )
-            output_messages.append(chat_message)
+        # Record this assistant message
+        self.update_memory(assist_msg, OpenAIBackendRole.ASSISTANT)

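For reference, the per-call entry shape that `_record_assistant_tool_calls_message` and the streaming helpers below read from `accumulated_tool_calls` (only the keys used by the code above are shown; the key and values are illustrative):

```python
# Illustrative entry in accumulated_tool_calls (Dict[str, Any]).
accumulated_tool_calls = {
    "0": {                                    # keyed per streamed tool call
        "id": "call_abc123",                  # tool_call_id echoed to the model
        "function": {
            "name": "always_sunny",           # hypothetical tool name
            "arguments": '{"city": "Oslo"}',  # raw JSON string from the stream
        },
        "complete": True,                     # set once all argument chunks arrived
    },
}
```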
-    def
+    async def _aprocess_stream_chunks_with_accumulator(
         self,
-
-
-
-
-
+        stream: AsyncStream[ChatCompletionChunk],
+        content_accumulator: StreamContentAccumulator,
+        accumulated_tool_calls: Dict[str, Any],
+        tool_call_records: List[ToolCallingRecord],
+        step_token_usage: Dict[str, int],
+        response_format: Optional[Type[BaseModel]] = None,
+    ) -> AsyncGenerator[Union[ChatAgentResponse, Tuple[bool, bool]], None]:
+        r"""Async version of process streaming chunks with
+        content accumulator.
+        """

-
-
-
-
+        tool_calls_complete = False
+        stream_completed = False
+
+        async for chunk in stream:
+            # Process chunk delta
+            if chunk.choices and len(chunk.choices) > 0:
+                choice = chunk.choices[0]
+                delta = choice.delta
+
+                # Handle content streaming
+                if delta.content:
+                    # Use accumulator for proper content management
+                    partial_response = (
+                        self._create_streaming_response_with_accumulator(
+                            content_accumulator,
+                            delta.content,
+                            step_token_usage,
+                            getattr(chunk, 'id', ''),
+                            tool_call_records.copy(),
+                        )
+                    )
+                    yield partial_response

-
-
-
-
-
-            termination.
+                # Handle tool calls streaming
+                if delta.tool_calls:
+                    tool_calls_complete = self._accumulate_tool_calls(
+                        delta.tool_calls, accumulated_tool_calls
+                    )

-
-
-
-
-
-
+                # Check if stream is complete
+                if choice.finish_reason:
+                    stream_completed = True
+
+                    # If we have complete tool calls, execute them with
+                    # async status updates
+                    if accumulated_tool_calls:
+                        # Execute tools asynchronously with real-time
+                        # status updates
+                        async for (
+                            status_response
+                        ) in self._execute_tools_async_with_status_accumulator(
+                            accumulated_tool_calls,
+                            content_accumulator,
+                            step_token_usage,
+                            tool_call_records,
+                        ):
+                            yield status_response
+
+                        # Log sending status instead of adding to content
+                        if tool_call_records:
+                            logger.info("Sending back result to model")
+
+                    # Record final message only if we have content AND no tool
+                    # calls. If there are tool calls, _record_tool_calling
+                    # will handle message recording.
+                    final_content = content_accumulator.get_full_content()
+                    if final_content.strip() and not accumulated_tool_calls:
+                        final_message = BaseMessage(
+                            role_name=self.role_name,
+                            role_type=self.role_type,
+                            meta_dict={},
+                            content=final_content,
+                        )

-
-
-
-
-            num_tokens,
-            tool_calls,
-        )
+                        if response_format:
+                            self._try_format_message(
+                                final_message, response_format
+                            )

-
-
-
-
-
+                        self.record_message(final_message)
+            elif chunk.usage and not chunk.choices:
+                # Handle final chunk with usage but empty choices
+                # This happens when stream_options={"include_usage": True}
+                # Update the final usage from this chunk
+                self._update_token_usage_tracker(
+                    step_token_usage, safe_model_dump(chunk.usage)
+                )

-
-
-
-
-
+                # Create final response with final usage
+                final_content = content_accumulator.get_full_content()
+                if final_content.strip():
+                    final_message = BaseMessage(
+                        role_name=self.role_name,
+                        role_type=self.role_type,
+                        meta_dict={},
+                        content=final_content,
+                    )

-
-
+                    if response_format:
+                        self._try_format_message(
+                            final_message, response_format
+                        )

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    # Create final response with final usage (not partial)
+                    final_response = ChatAgentResponse(
+                        msgs=[final_message],
+                        terminated=False,
+                        info={
+                            "id": getattr(chunk, 'id', ''),
+                            "usage": step_token_usage.copy(),
+                            "finish_reasons": ["stop"],
+                            "num_tokens": self._get_token_count(final_content),
+                            "tool_calls": tool_call_records or [],
+                            "external_tool_requests": None,
+                            "streaming": False,
+                            "partial": False,
+                        },
+                    )
+                    yield final_response
+                break
+            elif stream_completed:
+                # If we've already seen finish_reason but no usage chunk, exit
+                break

-
+        # Yield the final status as a tuple
+        yield (stream_completed, tool_calls_complete)

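The `elif chunk.usage and not chunk.choices` branch above consumes the final usage-only chunk that OpenAI-style backends emit when usage reporting is requested for streams. A hypothetical model configuration that would trigger it (the option names follow OpenAI's streaming options; treat this as a sketch, not the only way to enable it):

```python
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType

# Assumed configuration: stream with a trailing usage chunk.
model = ModelFactory.create(
    model_platform=ModelPlatformType.OPENAI,
    model_type=ModelType.GPT_4O_MINI,
    model_config_dict={
        "stream": True,
        "stream_options": {"include_usage": True},
    },
)
```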
-    async def
+    async def _execute_tools_async_with_status_accumulator(
         self,
-
-
-
-
-
-
+        accumulated_tool_calls: Dict[str, Any],
+        content_accumulator: StreamContentAccumulator,
+        step_token_usage: Dict[str, int],
+        tool_call_records: List[ToolCallingRecord],
+    ) -> AsyncGenerator[ChatAgentResponse, None]:
+        r"""Execute multiple tools asynchronously with
+        proper content accumulation."""
         import asyncio

-
-
-
-
-
+        # Phase 1: Start all tools and yield "Calling function"
+        # statuses immediately
+        tool_tasks = []
+        for _tool_call_index, tool_call_data in accumulated_tool_calls.items():
+            if tool_call_data.get('complete', False):
+                function_name = tool_call_data['function']['name']
+                try:
+                    args = json.loads(tool_call_data['function']['arguments'])
+                except json.JSONDecodeError:
+                    args = tool_call_data['function']['arguments']
+
+                # Log debug info instead of adding to content
+                logger.info(
+                    f"Calling function: {function_name} with arguments: {args}"
+                )

-
-
-
+                # Start tool execution asynchronously (non-blocking)
+                if self.tool_execution_timeout is not None:
+                    task = asyncio.create_task(
+                        asyncio.wait_for(
+                            self._aexecute_tool_from_stream_data(
+                                tool_call_data
+                            ),
+                            timeout=self.tool_execution_timeout,
+                        )
+                    )
+                else:
+                    task = asyncio.create_task(
+                        self._aexecute_tool_from_stream_data(tool_call_data)
+                    )
+                tool_tasks.append((task, tool_call_data))

-
-
+        # Phase 2: Wait for tools to complete and yield results as they finish
+        if tool_tasks:
+            # Use asyncio.as_completed for true async processing
+            for completed_task in asyncio.as_completed(
+                [task for task, _ in tool_tasks]
             ):
-
-
+                try:
+                    tool_call_record = await completed_task
+                    if tool_call_record:
+                        # Add to the shared tool_call_records list
+                        tool_call_records.append(tool_call_record)

-
-
-
+                        # Create output status message
+                        raw_result = tool_call_record.result
+                        result_str = str(raw_result)

-
-
-                    result = tool(**args)
+                        # Log debug info instead of adding to content
+                        logger.info(f"Function output: {result_str}")

-
-
-
-
-
+                except Exception as e:
+                    if isinstance(e, asyncio.TimeoutError):
+                        # Log timeout info instead of adding to content
+                        logger.warning(
+                            f"Function timed out after "
+                            f"{self.tool_execution_timeout} seconds"
+                        )
+                    else:
+                        logger.error(f"Error in async tool execution: {e}")
+                    continue

-
+        # Ensure this function remains an async generator
+        return
+        # This line is never reached but makes this an async generator function
+        yield

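The two-phase pattern above (start every tool as a task, optionally wrapped in a timeout, then consume results in completion order) is plain asyncio and can be illustrated standalone (a minimal sketch with toy coroutines, not CAMEL code):

```python
import asyncio


async def run_all(coros, timeout=None):
    # Phase 1: start everything immediately, optionally with a per-task timeout.
    tasks = [
        asyncio.create_task(
            asyncio.wait_for(c, timeout=timeout) if timeout is not None else c
        )
        for c in coros
    ]
    # Phase 2: consume results as they finish, tolerating per-task failures.
    results = []
    for fut in asyncio.as_completed(tasks):
        try:
            results.append(await fut)
        except asyncio.TimeoutError:
            results.append(None)  # mirrors the timeout-warning path in the diff
    return results


async def _demo():
    async def work(n: int) -> int:
        await asyncio.sleep(n / 10)
        return n

    print(await run_all([work(3), work(1), work(2)], timeout=1.0))


# asyncio.run(_demo())
```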
-    def
+    def _create_streaming_response_with_accumulator(
         self,
-
-
-
-
-
-
-
-        """
-        assist_msg = FunctionCallingMessage(
-            role_name=self.role_name,
-            role_type=self.role_type,
-            meta_dict=None,
-            content="",
-            func_name=func_name,
-            args=args,
-            tool_call_id=tool_call_id,
-        )
-        func_msg = FunctionCallingMessage(
-            role_name=self.role_name,
-            role_type=self.role_type,
-            meta_dict=None,
-            content="",
-            func_name=func_name,
-            result=result,
-            tool_call_id=tool_call_id,
-        )
-
-        # Use precise timestamps to ensure correct ordering
-        # This ensures the assistant message (tool call) always appears before
-        # the function message (tool result) in the conversation context
-        # Use time.time_ns() for nanosecond precision to avoid collisions
-        import time
-
-        current_time_ns = time.time_ns()
-        base_timestamp = current_time_ns / 1_000_000_000  # Convert to seconds
+        accumulator: StreamContentAccumulator,
+        new_content: str,
+        step_token_usage: Dict[str, int],
+        response_id: str = "",
+        tool_call_records: Optional[List[ToolCallingRecord]] = None,
+    ) -> ChatAgentResponse:
+        r"""Create a streaming response using content accumulator."""

-
-
-
+        # Add new content; only build full content when needed
+        accumulator.add_streaming_content(new_content)
+        if self.stream_accumulate:
+            message_content = accumulator.get_full_content()
+        else:
+            message_content = new_content

-
-
-
-
-
+        message = BaseMessage(
+            role_name=self.role_name,
+            role_type=self.role_type,
+            meta_dict={},
+            content=message_content,
         )

-
-
-
-
-
-
+        return ChatAgentResponse(
+            msgs=[message],
+            terminated=False,
+            info={
+                "id": response_id,
+                "usage": step_token_usage.copy(),
+                "finish_reasons": ["streaming"],
+                "num_tokens": self._get_token_count(message_content),
+                "tool_calls": tool_call_records or [],
+                "external_tool_requests": None,
+                "streaming": True,
+                "partial": True,
+            },
         )

-        return tool_record
-
     def get_usage_dict(
         self, output_messages: List[BaseMessage], prompt_tokens: int
     ) -> Dict[str, int]:
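`_create_streaming_response_with_accumulator` is where `stream_accumulate` decides whether each partial response carries the full text so far or only the new delta. A minimal consumer sketch handling both modes (the agent and prompt are placeholders; the loop relies only on the `info["partial"]` flag and the `stream_accumulate` attribute shown in this diff):

```python
async def collect_stream(agent, prompt: str) -> str:
    text = ""
    async for resp in agent._astream(prompt):
        piece = resp.msgs[0].content
        if resp.info.get("partial"):
            # stream_accumulate=True: each partial holds the full text so far.
            # stream_accumulate=False: each partial holds only the new delta.
            text = piece if agent.stream_accumulate else text + piece
        else:
            text = piece  # the final response always carries the full content
    return text
```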
@@ -1746,10 +5219,15 @@ class ChatAgent(BaseAgent):
             configuration.
         """
         # Create a new instance with the same configuration
-        # If with_memory is True, set system_message to None
-        #
+        # If with_memory is True, set system_message to None (it will be
+        # copied from memory below, including any workflow context)
+        # If with_memory is False, use the current system message
+        # (which may include appended workflow context)
         # To avoid duplicated system memory.
-        system_message = None if with_memory else self.
+        system_message = None if with_memory else self._system_message
+
+        # Clone tools and collect toolkits that need registration
+        cloned_tools, toolkits_to_register = self._clone_tools()

         new_agent = ChatAgent(
             system_message=system_message,
@@ -1760,14 +5238,21 @@ class ChatAgent(BaseAgent):
                 self.memory.get_context_creator(), "token_limit", None
             ),
             output_language=self._output_language,
-            tools=
+            tools=cast(List[Union[FunctionTool, Callable]], cloned_tools),
+            toolkits_to_register_agent=toolkits_to_register,
             external_tools=[
                 schema for schema in self._external_tool_schemas.values()
             ],
             response_terminators=self.response_terminators,
-            scheduling_strategy=
+            scheduling_strategy=(
+                self.model_backend.scheduling_strategy.__name__
+            ),
             max_iteration=self.max_iteration,
             stop_event=self.stop_event,
+            tool_execution_timeout=self.tool_execution_timeout,
+            pause_event=self.pause_event,
+            prune_tool_calls_from_memory=self.prune_tool_calls_from_memory,
+            stream_accumulate=self.stream_accumulate,
         )

         # Copy memory if requested
@@ -1780,6 +5265,125 @@ class ChatAgent(BaseAgent):

         return new_agent

+    def _clone_tools(
+        self,
+    ) -> Tuple[List[FunctionTool], List[RegisteredAgentToolkit]]:
+        r"""Clone tools and return toolkits that need agent registration.
+
+        This method handles stateful toolkits by cloning them if they have
+        a clone_for_new_session method, and collecting RegisteredAgentToolkit
+        instances for later registration.
+
+        Returns:
+            Tuple containing:
+            - List of cloned tools/functions
+            - List of RegisteredAgentToolkit instances need registration
+        """
+        cloned_tools = []
+        toolkits_to_register = []
+        cloned_toolkits = {}
+        # Cache for cloned toolkits by original toolkit id
+
+        for tool in self._internal_tools.values():
+            # Check if this tool is a method bound to a toolkit instance
+            if hasattr(tool.func, '__self__'):
+                toolkit_instance = tool.func.__self__
+                toolkit_id = id(toolkit_instance)
+
+                if toolkit_id not in cloned_toolkits:
+                    # Check if the toolkit has a clone method
+                    if hasattr(toolkit_instance, 'clone_for_new_session'):
+                        try:
+                            import uuid
+
+                            new_session_id = str(uuid.uuid4())[:8]
+                            new_toolkit = (
+                                toolkit_instance.clone_for_new_session(
+                                    new_session_id
+                                )
+                            )
+
+                            # If this is a RegisteredAgentToolkit,
+                            # add it to registration list
+                            if isinstance(new_toolkit, RegisteredAgentToolkit):
+                                toolkits_to_register.append(new_toolkit)
+
+                            cloned_toolkits[toolkit_id] = new_toolkit
+                        except Exception as e:
+                            logger.warning(
+                                f"Failed to clone toolkit {toolkit_instance.__class__.__name__}: {e}"  # noqa:E501
+                            )
+                            # Use original toolkit if cloning fails
+                            cloned_toolkits[toolkit_id] = toolkit_instance
+                    else:
+                        # Toolkit doesn't support cloning, use original
+                        cloned_toolkits[toolkit_id] = toolkit_instance
+
+                # Get the method from the cloned (or original) toolkit
+                toolkit = cloned_toolkits[toolkit_id]
+                method_name = tool.func.__name__
+
+                # Check if toolkit was actually cloned or just reused
+                toolkit_was_cloned = toolkit is not toolkit_instance
+
+                if hasattr(toolkit, method_name):
+                    new_method = getattr(toolkit, method_name)
+
+                    # If toolkit wasn't cloned (stateless), preserve the
+                    # original function to maintain any enhancements/wrappers
+                    if not toolkit_was_cloned:
+                        # Toolkit is stateless, safe to reuse original function
+                        cloned_tools.append(
+                            FunctionTool(
+                                func=tool.func,
+                                openai_tool_schema=tool.get_openai_tool_schema(),
+                            )
+                        )
+                        continue
+
+                    # Toolkit was cloned, use the new method
+                    # Wrap cloned method into a new FunctionTool,
+                    # preserving schema
+                    try:
+                        new_tool = FunctionTool(
+                            func=new_method,
+                            openai_tool_schema=tool.get_openai_tool_schema(),
+                        )
+                        cloned_tools.append(new_tool)
+                    except Exception as e:
+                        # If wrapping fails, fallback to wrapping the original
+                        # function with its schema to maintain consistency
+                        logger.warning(
+                            f"Failed to wrap cloned toolkit "
+                            f"method '{method_name}' "
+                            f"with FunctionTool: {e}. Using original "
+                            f"function with preserved schema instead."
+                        )
+                        cloned_tools.append(
+                            FunctionTool(
+                                func=tool.func,
+                                openai_tool_schema=tool.get_openai_tool_schema(),
+                            )
+                        )
+                else:
+                    # Fallback to original function wrapped in FunctionTool
+                    cloned_tools.append(
+                        FunctionTool(
+                            func=tool.func,
+                            openai_tool_schema=tool.get_openai_tool_schema(),
+                        )
+                    )
+            else:
+                # Not a toolkit method, preserve FunctionTool schema directly
+                cloned_tools.append(
+                    FunctionTool(
+                        func=tool.func,
+                        openai_tool_schema=tool.get_openai_tool_schema(),
+                    )
+                )
+
+        return cloned_tools, toolkits_to_register
+
     def __repr__(self) -> str:
         r"""Returns a string representation of the :obj:`ChatAgent`.

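`_clone_tools` only duplicates toolkit state for toolkits that expose `clone_for_new_session`. A hypothetical stateful toolkit that cooperates with it (illustrative only; not a CAMEL class):

```python
class SessionScratchpadToolkit:
    """Toy stateful toolkit; each agent clone gets its own scratchpad."""

    def __init__(self, session_id: str = "default"):
        self.session_id = session_id
        self.notes: list = []

    def add_note(self, note: str) -> str:
        """Store a note in this session's scratchpad."""
        self.notes.append(note)
        return f"[{self.session_id}] stored {len(self.notes)} notes"

    def clone_for_new_session(self, new_session_id: str) -> "SessionScratchpadToolkit":
        # Fresh state for the cloned agent; _clone_tools rebinds tool methods
        # (e.g. add_note) to this new instance while preserving their schemas.
        return SessionScratchpadToolkit(session_id=new_session_id)
```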