rossum-agent 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rossum_agent/__init__.py +9 -0
- rossum_agent/agent/__init__.py +32 -0
- rossum_agent/agent/core.py +932 -0
- rossum_agent/agent/memory.py +176 -0
- rossum_agent/agent/models.py +160 -0
- rossum_agent/agent/request_classifier.py +152 -0
- rossum_agent/agent/skills.py +132 -0
- rossum_agent/agent/types.py +5 -0
- rossum_agent/agent_logging.py +56 -0
- rossum_agent/api/__init__.py +1 -0
- rossum_agent/api/cli.py +51 -0
- rossum_agent/api/dependencies.py +190 -0
- rossum_agent/api/main.py +180 -0
- rossum_agent/api/models/__init__.py +1 -0
- rossum_agent/api/models/schemas.py +301 -0
- rossum_agent/api/routes/__init__.py +1 -0
- rossum_agent/api/routes/chats.py +95 -0
- rossum_agent/api/routes/files.py +113 -0
- rossum_agent/api/routes/health.py +44 -0
- rossum_agent/api/routes/messages.py +218 -0
- rossum_agent/api/services/__init__.py +1 -0
- rossum_agent/api/services/agent_service.py +451 -0
- rossum_agent/api/services/chat_service.py +197 -0
- rossum_agent/api/services/file_service.py +65 -0
- rossum_agent/assets/Primary_light_logo.png +0 -0
- rossum_agent/bedrock_client.py +64 -0
- rossum_agent/prompts/__init__.py +27 -0
- rossum_agent/prompts/base_prompt.py +80 -0
- rossum_agent/prompts/system_prompt.py +24 -0
- rossum_agent/py.typed +0 -0
- rossum_agent/redis_storage.py +482 -0
- rossum_agent/rossum_mcp_integration.py +123 -0
- rossum_agent/skills/hook-debugging.md +31 -0
- rossum_agent/skills/organization-setup.md +60 -0
- rossum_agent/skills/rossum-deployment.md +102 -0
- rossum_agent/skills/schema-patching.md +61 -0
- rossum_agent/skills/schema-pruning.md +23 -0
- rossum_agent/skills/ui-settings.md +45 -0
- rossum_agent/streamlit_app/__init__.py +1 -0
- rossum_agent/streamlit_app/app.py +646 -0
- rossum_agent/streamlit_app/beep_sound.py +36 -0
- rossum_agent/streamlit_app/cli.py +17 -0
- rossum_agent/streamlit_app/render_modules.py +123 -0
- rossum_agent/streamlit_app/response_formatting.py +305 -0
- rossum_agent/tools/__init__.py +214 -0
- rossum_agent/tools/core.py +173 -0
- rossum_agent/tools/deploy.py +404 -0
- rossum_agent/tools/dynamic_tools.py +365 -0
- rossum_agent/tools/file_tools.py +62 -0
- rossum_agent/tools/formula.py +187 -0
- rossum_agent/tools/skills.py +31 -0
- rossum_agent/tools/spawn_mcp.py +227 -0
- rossum_agent/tools/subagents/__init__.py +31 -0
- rossum_agent/tools/subagents/base.py +303 -0
- rossum_agent/tools/subagents/hook_debug.py +591 -0
- rossum_agent/tools/subagents/knowledge_base.py +305 -0
- rossum_agent/tools/subagents/mcp_helpers.py +47 -0
- rossum_agent/tools/subagents/schema_patching.py +471 -0
- rossum_agent/url_context.py +167 -0
- rossum_agent/user_detection.py +100 -0
- rossum_agent/utils.py +128 -0
- rossum_agent-1.0.0rc0.dist-info/METADATA +311 -0
- rossum_agent-1.0.0rc0.dist-info/RECORD +67 -0
- rossum_agent-1.0.0rc0.dist-info/WHEEL +5 -0
- rossum_agent-1.0.0rc0.dist-info/entry_points.txt +3 -0
- rossum_agent-1.0.0rc0.dist-info/licenses/LICENSE +21 -0
- rossum_agent-1.0.0rc0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,932 @@
|
|
|
1
|
+
"""Core agent module implementing the RossumAgent class with Anthropic tool use API.
|
|
2
|
+
|
|
3
|
+
This module provides the main agent loop for interacting with the Rossum platform
|
|
4
|
+
using Claude models via AWS Bedrock and MCP tools.
|
|
5
|
+
|
|
6
|
+
Streaming Architecture & AgentStep Yield Points
|
|
7
|
+
================================================
|
|
8
|
+
|
|
9
|
+
The agent streams responses via `_stream_model_response` which yields `AgentStep` objects
|
|
10
|
+
at multiple points to provide real-time updates to the client. The yield flow is:
|
|
11
|
+
|
|
12
|
+
_stream_model_response
|
|
13
|
+
│
|
|
14
|
+
├── #5 forwards from _process_stream_events ──┬── #1 Timeout flush (buffer stale after 1.5s)
|
|
15
|
+
│ ├── #2 Stream end flush (final text)
|
|
16
|
+
│ ├── #3 Thinking tokens (chain-of-thought)
|
|
17
|
+
│ └── #4 Text deltas (after initial buffer)
|
|
18
|
+
│
|
|
19
|
+
├── #6 Final answer (no tools, response complete)
|
|
20
|
+
│
|
|
21
|
+
└── #7 forwards from _execute_tools_with_progress
|
|
22
|
+
├── Tool starting (which tool is about to run)
|
|
23
|
+
└── Sub-agent progress (from nested agent tools like debug_hook)
|
|
24
|
+
|
|
25
|
+
Key concepts:
|
|
26
|
+
- Initial text buffering (INITIAL_TEXT_BUFFER_DELAY=1.5s) allows determining step type
|
|
27
|
+
(INTERMEDIATE vs FINAL_ANSWER) before streaming to client
|
|
28
|
+
- After initial flush, text tokens stream immediately
|
|
29
|
+
- Tool execution yields progress updates for UI responsiveness
|
|
30
|
+
- In a single step, a thinking block is always followed by an intermediate block
|
|
31
|
+
(tool calls or text response)
|
|
32
|
+
"""
|
|
33
|
+
|
|
34
|
+
from __future__ import annotations
|
|
35
|
+
|
|
36
|
+
import asyncio
|
|
37
|
+
import dataclasses
|
|
38
|
+
import json
|
|
39
|
+
import logging
|
|
40
|
+
import queue
|
|
41
|
+
import random
|
|
42
|
+
import time
|
|
43
|
+
from contextvars import copy_context
|
|
44
|
+
from functools import partial
|
|
45
|
+
from typing import TYPE_CHECKING
|
|
46
|
+
|
|
47
|
+
from anthropic import APIError, APITimeoutError, RateLimitError
|
|
48
|
+
from anthropic._types import Omit
|
|
49
|
+
from anthropic.types import (
|
|
50
|
+
ContentBlockStopEvent,
|
|
51
|
+
InputJSONDelta,
|
|
52
|
+
Message,
|
|
53
|
+
MessageParam,
|
|
54
|
+
MessageStreamEvent,
|
|
55
|
+
RawContentBlockDeltaEvent,
|
|
56
|
+
RawContentBlockStartEvent,
|
|
57
|
+
TextBlockParam,
|
|
58
|
+
TextDelta,
|
|
59
|
+
ThinkingBlock,
|
|
60
|
+
ThinkingConfigEnabledParam,
|
|
61
|
+
ThinkingDelta,
|
|
62
|
+
ToolParam,
|
|
63
|
+
ToolUseBlock,
|
|
64
|
+
)
|
|
65
|
+
from pydantic import BaseModel
|
|
66
|
+
|
|
67
|
+
from rossum_agent.agent.memory import AgentMemory, MemoryStep
|
|
68
|
+
from rossum_agent.agent.models import (
|
|
69
|
+
AgentConfig,
|
|
70
|
+
AgentStep,
|
|
71
|
+
StepType,
|
|
72
|
+
StreamDelta,
|
|
73
|
+
ThinkingBlockData,
|
|
74
|
+
ToolCall,
|
|
75
|
+
ToolResult,
|
|
76
|
+
truncate_content,
|
|
77
|
+
)
|
|
78
|
+
from rossum_agent.agent.request_classifier import RequestScope, classify_request, generate_rejection_response
|
|
79
|
+
from rossum_agent.api.models.schemas import TokenUsageBreakdown
|
|
80
|
+
from rossum_agent.bedrock_client import create_bedrock_client, get_model_id
|
|
81
|
+
from rossum_agent.rossum_mcp_integration import MCPConnection, mcp_tools_to_anthropic_format
|
|
82
|
+
from rossum_agent.tools import (
|
|
83
|
+
DEPLOY_TOOLS,
|
|
84
|
+
DISCOVERY_TOOL_NAME,
|
|
85
|
+
SubAgentProgress,
|
|
86
|
+
SubAgentTokenUsage,
|
|
87
|
+
execute_internal_tool,
|
|
88
|
+
execute_tool,
|
|
89
|
+
get_deploy_tool_names,
|
|
90
|
+
get_deploy_tools,
|
|
91
|
+
get_dynamic_tools,
|
|
92
|
+
get_internal_tool_names,
|
|
93
|
+
get_internal_tools,
|
|
94
|
+
preload_categories_for_request,
|
|
95
|
+
reset_dynamic_tools,
|
|
96
|
+
set_mcp_connection,
|
|
97
|
+
set_progress_callback,
|
|
98
|
+
set_token_callback,
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
if TYPE_CHECKING:
|
|
102
|
+
from collections.abc import AsyncIterator, Iterator
|
|
103
|
+
from typing import Literal
|
|
104
|
+
|
|
105
|
+
from anthropic import AnthropicBedrock
|
|
106
|
+
|
|
107
|
+
from rossum_agent.agent.types import UserContent
|
|
108
|
+
|
|
109
|
+
logger = logging.getLogger(__name__)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
RATE_LIMIT_MAX_RETRIES = 5
|
|
113
|
+
RATE_LIMIT_BASE_DELAY = 2.0
|
|
114
|
+
RATE_LIMIT_MAX_DELAY = 60.0
|
|
115
|
+
|
|
116
|
+
# Buffer text tokens for this duration before first flush to allow time to determine
|
|
117
|
+
# whether this is an intermediate step (with tool calls) or final answer text.
|
|
118
|
+
# This delay helps correctly classify the step type before streaming to the client.
|
|
119
|
+
INITIAL_TEXT_BUFFER_DELAY = 1.5
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def _parse_json_encoded_strings(arguments: dict) -> dict:
|
|
123
|
+
"""Recursively parse JSON-encoded strings in tool arguments.
|
|
124
|
+
|
|
125
|
+
LLMs sometimes generate JSON-encoded strings for list/dict arguments instead of
|
|
126
|
+
actual lists/dicts. This function detects and parses such strings.
|
|
127
|
+
|
|
128
|
+
For example, converts:
|
|
129
|
+
{"fields_to_keep": "[\"a\", \"b\"]"}
|
|
130
|
+
To:
|
|
131
|
+
{"fields_to_keep": ["a", "b"]}
|
|
132
|
+
"""
|
|
133
|
+
# Parameters that should remain as JSON strings (not parsed to lists/dicts)
|
|
134
|
+
keep_as_string = {"changes"}
|
|
135
|
+
|
|
136
|
+
result = {}
|
|
137
|
+
for key, value in arguments.items():
|
|
138
|
+
if key in keep_as_string:
|
|
139
|
+
result[key] = value
|
|
140
|
+
elif isinstance(value, str) and value.startswith(("[", "{")):
|
|
141
|
+
try:
|
|
142
|
+
parsed = json.loads(value)
|
|
143
|
+
if isinstance(parsed, (list, dict)):
|
|
144
|
+
result[key] = parsed
|
|
145
|
+
else:
|
|
146
|
+
result[key] = value
|
|
147
|
+
except json.JSONDecodeError:
|
|
148
|
+
result[key] = value
|
|
149
|
+
elif isinstance(value, dict):
|
|
150
|
+
result[key] = _parse_json_encoded_strings(value)
|
|
151
|
+
else:
|
|
152
|
+
result[key] = value
|
|
153
|
+
return result
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@dataclasses.dataclass
|
|
157
|
+
class _StreamState:
|
|
158
|
+
"""Mutable state for streaming model response.
|
|
159
|
+
|
|
160
|
+
Attributes:
|
|
161
|
+
first_text_token_time: Timestamp of when the first text token was received.
|
|
162
|
+
Used to implement initial buffering delay (see INITIAL_TEXT_BUFFER_DELAY).
|
|
163
|
+
initial_buffer_flushed: Whether the initial buffer has been flushed after
|
|
164
|
+
the delay period. Once True, text tokens are streamed immediately.
|
|
165
|
+
"""
|
|
166
|
+
|
|
167
|
+
thinking_text: str = ""
|
|
168
|
+
response_text: str = ""
|
|
169
|
+
final_message: Message | None = None
|
|
170
|
+
text_buffer: list[str] = dataclasses.field(default_factory=list)
|
|
171
|
+
tool_calls: list[ToolCall] = dataclasses.field(default_factory=list)
|
|
172
|
+
pending_tools: dict[int, dict[str, str]] = dataclasses.field(default_factory=dict)
|
|
173
|
+
first_text_token_time: float | None = None
|
|
174
|
+
initial_buffer_flushed: bool = False
|
|
175
|
+
|
|
176
|
+
def _should_flush_initial_buffer(self) -> bool:
|
|
177
|
+
"""Check if the initial buffer delay has elapsed and buffer should be flushed."""
|
|
178
|
+
if self.initial_buffer_flushed:
|
|
179
|
+
return True
|
|
180
|
+
if self.first_text_token_time is None:
|
|
181
|
+
return False
|
|
182
|
+
return (time.monotonic() - self.first_text_token_time) >= INITIAL_TEXT_BUFFER_DELAY
|
|
183
|
+
|
|
184
|
+
def get_step_type(self) -> StepType:
|
|
185
|
+
"""Get the step type based on whether tool calls are pending."""
|
|
186
|
+
return StepType.INTERMEDIATE if self.pending_tools or self.tool_calls else StepType.FINAL_ANSWER
|
|
187
|
+
|
|
188
|
+
def flush_buffer(self, step_num: int, step_type: StepType) -> AgentStep | None:
|
|
189
|
+
"""Flush text buffer and return AgentStep if buffer had content."""
|
|
190
|
+
if not self.text_buffer:
|
|
191
|
+
return None
|
|
192
|
+
buffered_text = "".join(self.text_buffer)
|
|
193
|
+
self.text_buffer.clear()
|
|
194
|
+
self.response_text += buffered_text
|
|
195
|
+
return AgentStep(
|
|
196
|
+
step_number=step_num,
|
|
197
|
+
thinking=self.thinking_text or None,
|
|
198
|
+
is_streaming=True,
|
|
199
|
+
text_delta=buffered_text,
|
|
200
|
+
accumulated_text=self.response_text,
|
|
201
|
+
step_type=step_type,
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
@property
|
|
205
|
+
def contains_thinking(self) -> bool:
|
|
206
|
+
return bool(self.thinking_text)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class RossumAgent:
|
|
210
|
+
"""Claude-powered agent for Rossum document processing.
|
|
211
|
+
|
|
212
|
+
This agent uses Anthropic's tool use API to interact with the Rossum platform
|
|
213
|
+
via MCP tools. It maintains conversation state across multiple turns and
|
|
214
|
+
supports streaming responses.
|
|
215
|
+
|
|
216
|
+
Memory is stored as structured MemoryStep objects and rebuilt into messages
|
|
217
|
+
each call.
|
|
218
|
+
"""
|
|
219
|
+
|
|
220
|
+
def __init__(
    self,
    client: AnthropicBedrock,
    mcp_connection: MCPConnection,
    system_prompt: str,
    config: AgentConfig | None = None,
    additional_tools: list[ToolParam] | None = None,
) -> None:
    """Store collaborators and start with empty memory and zeroed token counters.

    Args:
        client: Client whose ``messages.stream`` is used for model calls.
        mcp_connection: Connection queried for the available MCP tools.
        system_prompt: System prompt sent with every model request.
        config: Agent settings; a default ``AgentConfig()`` is used when omitted.
        additional_tools: Extra tool definitions appended to the static tool list.
    """
    self.client = client
    self.mcp_connection = mcp_connection
    self.system_prompt = system_prompt
    self.config = config if config else AgentConfig()
    self.additional_tools = additional_tools if additional_tools else []

    self.memory = AgentMemory()
    # Static tool list; populated on the first _get_tools() call.
    self._tools_cache: list[ToolParam] | None = None

    # Grand totals (main agent + sub-agents).
    self._total_input_tokens: int = 0
    self._total_output_tokens: int = 0

    # Token breakdown tracking: main agent vs. sub-agents.
    self._main_agent_input_tokens: int = 0
    self._main_agent_output_tokens: int = 0
    self._sub_agent_input_tokens: int = 0
    self._sub_agent_output_tokens: int = 0
    # tool_name -> (input, output)
    self._sub_agent_usage: dict[str, tuple[int, int]] = {}
|
|
244
|
+
|
|
245
|
+
@property
|
|
246
|
+
def messages(self) -> list[MessageParam]:
|
|
247
|
+
"""Get the current conversation messages (rebuilt from memory)."""
|
|
248
|
+
return self.memory.write_to_messages()
|
|
249
|
+
|
|
250
|
+
def reset(self) -> None:
    """Clear conversation memory, zero every token counter and reset dynamic tools."""
    self.memory.reset()

    # Aggregate, main-agent and sub-agent counters all go back to zero.
    self._total_input_tokens = self._total_output_tokens = 0
    self._main_agent_input_tokens = self._main_agent_output_tokens = 0
    self._sub_agent_input_tokens = self._sub_agent_output_tokens = 0
    self._sub_agent_usage = {}

    reset_dynamic_tools()
|
|
261
|
+
|
|
262
|
+
def _accumulate_sub_agent_tokens(self, usage: SubAgentTokenUsage) -> None:
    """Add one sub-agent usage report to the totals, the sub-agent totals and the per-tool tally."""
    tokens_in = usage.input_tokens
    tokens_out = usage.output_tokens
    tool = usage.tool_name

    self._total_input_tokens += tokens_in
    self._total_output_tokens += tokens_out
    self._sub_agent_input_tokens += tokens_in
    self._sub_agent_output_tokens += tokens_out

    # Per-tool running totals, stored as (input, output).
    seen_in, seen_out = self._sub_agent_usage.get(tool, (0, 0))
    self._sub_agent_usage[tool] = (seen_in + tokens_in, seen_out + tokens_out)

    logger.info(
        f"Sub-agent '{tool}' token usage (iter {usage.iteration}): "
        f"in={tokens_in}, out={tokens_out}, "
        f"cumulative total: in={self._total_input_tokens}, out={self._total_output_tokens}"
    )
|
|
276
|
+
|
|
277
|
+
def get_token_usage_breakdown(self) -> TokenUsageBreakdown:
    """Build a TokenUsageBreakdown from the counters tracked on this agent."""
    raw_counts = {
        "total_input": self._total_input_tokens,
        "total_output": self._total_output_tokens,
        "main_input": self._main_agent_input_tokens,
        "main_output": self._main_agent_output_tokens,
        "sub_input": self._sub_agent_input_tokens,
        "sub_output": self._sub_agent_output_tokens,
        "sub_by_tool": self._sub_agent_usage,
    }
    return TokenUsageBreakdown.from_raw_counts(**raw_counts)
|
|
288
|
+
|
|
289
|
+
def log_token_usage_summary(self) -> None:
    """Emit the token usage breakdown as one multi-line INFO log record."""
    summary_lines = self.get_token_usage_breakdown().format_summary_lines()
    logger.info("\n".join(summary_lines))
|
|
293
|
+
|
|
294
|
+
def add_user_message(self, content: UserContent) -> None:
|
|
295
|
+
"""Add a user message to the conversation history."""
|
|
296
|
+
self.memory.add_task(content)
|
|
297
|
+
|
|
298
|
+
def add_assistant_message(self, content: str) -> None:
    """Append an assistant turn to the conversation history.

    The text is stored as a MemoryStep (step number 0) with only ``text`` set, so it
    is serialized when the conversation is rebuilt from memory. Turns that involve
    tool use should go through run() instead.
    """
    self.memory.add_step(MemoryStep(step_number=0, text=content))
|
|
307
|
+
|
|
308
|
+
async def _get_tools(self) -> list[ToolParam]:
    """Return every tool currently exposed to the model, in Anthropic format.

    The static part is built once and cached: the MCP discovery tool
    (DISCOVERY_TOOL_NAME), the internal tools, the deploy tools and any
    ``additional_tools``. Tools reported by get_dynamic_tools() are appended on
    every call, so newly loaded tools show up without rebuilding the cache.
    """
    if self._tools_cache is None:
        # Only the discovery tool is taken from MCP up front to reduce context usage.
        available = await self.mcp_connection.get_tools()
        discovery_only = [tool for tool in available if tool.name == DISCOVERY_TOOL_NAME]
        static_tools = mcp_tools_to_anthropic_format(discovery_only)
        static_tools = static_tools + get_internal_tools() + get_deploy_tools() + self.additional_tools
        self._tools_cache = static_tools

    # Dynamically loaded tools are looked up fresh each time.
    return self._tools_cache + get_dynamic_tools()
|
|
327
|
+
|
|
328
|
+
def _serialize_tool_result(self, result: object) -> str:
|
|
329
|
+
"""Serialize a tool result to a string for storage in context.
|
|
330
|
+
|
|
331
|
+
Handles pydantic models, dataclasses, dicts, lists, and other objects properly.
|
|
332
|
+
"""
|
|
333
|
+
if result is None:
|
|
334
|
+
return "Tool executed successfully (no output)"
|
|
335
|
+
|
|
336
|
+
# Handle dataclasses (check before pydantic since pydantic models aren't dataclasses)
|
|
337
|
+
if dataclasses.is_dataclass(result) and not isinstance(result, type):
|
|
338
|
+
return json.dumps(dataclasses.asdict(result), indent=2, default=str)
|
|
339
|
+
|
|
340
|
+
# Handle lists of dataclasses
|
|
341
|
+
if isinstance(result, list) and result and dataclasses.is_dataclass(result[0]):
|
|
342
|
+
return json.dumps(
|
|
343
|
+
[
|
|
344
|
+
dataclasses.asdict(item)
|
|
345
|
+
for item in result
|
|
346
|
+
if dataclasses.is_dataclass(item) and not isinstance(item, type)
|
|
347
|
+
],
|
|
348
|
+
indent=2,
|
|
349
|
+
default=str,
|
|
350
|
+
)
|
|
351
|
+
|
|
352
|
+
# Handle pydantic models
|
|
353
|
+
# Use mode='json' to ensure nested models are properly serialized to JSON-compatible dicts
|
|
354
|
+
if isinstance(result, BaseModel):
|
|
355
|
+
return json.dumps(result.model_dump(mode="json"), indent=2, default=str)
|
|
356
|
+
|
|
357
|
+
# Handle lists of pydantic models
|
|
358
|
+
if isinstance(result, list) and result and isinstance(result[0], BaseModel):
|
|
359
|
+
return json.dumps(
|
|
360
|
+
[item.model_dump(mode="json") for item in result if isinstance(item, BaseModel)],
|
|
361
|
+
indent=2,
|
|
362
|
+
default=str,
|
|
363
|
+
)
|
|
364
|
+
|
|
365
|
+
# Handle dicts and regular lists
|
|
366
|
+
if isinstance(result, dict | list):
|
|
367
|
+
return json.dumps(result, indent=2, default=str)
|
|
368
|
+
|
|
369
|
+
# Fallback to string representation
|
|
370
|
+
return str(result)
|
|
371
|
+
|
|
372
|
+
def _sync_stream_events(
    self, model_id: str, messages: list[MessageParam], tools: list[ToolParam]
) -> Iterator[tuple[MessageStreamEvent | None, Message | None]]:
    """Blocking generator over one streamed model response.

    Intended to be driven from a worker thread so the event loop stays free.

    Yields:
        ``(event, None)`` for every stream event, followed by a single
        ``(None, final_message)`` once the stream is exhausted.
    """
    cfg = self.config
    thinking: ThinkingConfigEnabledParam = {
        "type": "enabled",
        "budget_tokens": cfg.thinking_budget_tokens,
    }
    stream_manager = self.client.messages.stream(
        model=model_id,
        max_tokens=cfg.max_output_tokens,
        system=self.system_prompt,
        messages=messages,
        # An empty tool list is omitted from the request entirely.
        tools=tools or Omit(),
        thinking=thinking,
        temperature=cfg.temperature,
    )
    with stream_manager as stream:
        for event in stream:
            yield (event, None)
        yield (None, stream.get_final_message())
|
|
398
|
+
|
|
399
|
+
def _process_stream_event(
    self,
    event: MessageStreamEvent,
    pending_tools: dict[int, dict[str, str]],
    tool_calls: list[ToolCall],
) -> StreamDelta | None:
    """Interpret one stream event, updating tool bookkeeping in place.

    Tool-use blocks are tracked in ``pending_tools`` (keyed by content-block index)
    while their JSON input streams in; when the block stops, the decoded call is
    appended to ``tool_calls``.

    Returns:
        StreamDelta with kind="thinking" or "text" for thinking/text deltas,
        otherwise None.
    """
    if isinstance(event, RawContentBlockStartEvent):
        block = event.content_block
        if isinstance(block, ToolUseBlock):
            pending_tools[event.index] = {"name": block.name, "id": block.id, "json": ""}
        return None

    if isinstance(event, RawContentBlockDeltaEvent):
        delta = event.delta
        if isinstance(delta, ThinkingDelta):
            return StreamDelta(kind="thinking", content=delta.thinking)
        if isinstance(delta, TextDelta):
            return StreamDelta(kind="text", content=delta.text)
        if isinstance(delta, InputJSONDelta) and event.index in pending_tools:
            # Tool input arrives as JSON fragments; stitch them together.
            pending_tools[event.index]["json"] += delta.partial_json
        return None

    if not (isinstance(event, ContentBlockStopEvent) and event.index in pending_tools):
        return None

    # A tracked tool-use block has finished: decode its input and record the call.
    finished = pending_tools.pop(event.index)
    raw_json = finished["json"]
    try:
        decoded = json.loads(raw_json) if raw_json else {}
        arguments = _parse_json_encoded_strings(decoded)
    except json.JSONDecodeError as e:
        logger.warning("Failed to decode tool arguments for %s: %s", finished["name"], e)
        arguments = {}
    tool_calls.append(ToolCall(id=finished["id"], name=finished["name"], arguments=arguments))
    return None
|
|
437
|
+
|
|
438
|
+
def _extract_thinking_blocks(self, message: Message) -> list[ThinkingBlockData]:
|
|
439
|
+
"""Extract thinking blocks from a message for preserving in conversation history."""
|
|
440
|
+
return [
|
|
441
|
+
ThinkingBlockData(thinking=block.thinking, signature=block.signature)
|
|
442
|
+
for block in message.content
|
|
443
|
+
if isinstance(block, ThinkingBlock)
|
|
444
|
+
]
|
|
445
|
+
|
|
446
|
+
def _handle_text_delta(
    self, step_num: int, content: str, delta_kind: Literal["thinking", "text"], state: _StreamState
) -> AgentStep | None:
    """Buffer one delta and decide whether the buffer can be flushed now.

    Returns the AgentStep produced by the flush, or None while text is still
    being held back during the initial buffering window.
    """
    now = time.monotonic()
    if state.first_text_token_time is None:
        # First delta: start the buffering window.
        state.first_text_token_time = now
    elif now - state.first_text_token_time > INITIAL_TEXT_BUFFER_DELAY:
        state.initial_buffer_flushed = True

    state.text_buffer.append(content)

    if not state.initial_buffer_flushed:
        if not (state.pending_tools or state.tool_calls):
            return None
        # A tool call has shown up, so the step is known to be intermediate.
        state.initial_buffer_flushed = True
        return state.flush_buffer(step_num, StepType.INTERMEDIATE)

    if state.contains_thinking and delta_kind == "text":
        step_type = StepType.INTERMEDIATE
    else:
        step_type = state.get_step_type()
    return state.flush_buffer(step_num, step_type)
|
|
467
|
+
|
|
468
|
+
async def _process_stream_events(
    self,
    step_num: int,
    event_queue: queue.Queue[tuple[MessageStreamEvent | None, Message | None] | None],
    state: _StreamState,
) -> AsyncIterator[AgentStep]:
    """Process stream events and yield AgentSteps.

    Text tokens are buffered for INITIAL_TEXT_BUFFER_DELAY seconds after the first
    text token is received. This allows time to determine whether the response will
    include tool calls (intermediate step) or is a final answer, enabling correct
    step type classification before streaming to the client.

    After the initial buffer is flushed, subsequent text tokens are streamed immediately.

    Args:
        step_num: Step number stamped on every yielded AgentStep.
        event_queue: Queue of ``(event, final_message)`` tuples; a bare ``None``
            item marks the end of the stream.
        state: Stream state updated in place (buffers, tool calls, final message).

    Yields:
        Streaming AgentSteps for thinking tokens and flushed text.
    """
    # The loop only exits when the ``None`` end-of-stream item is dequeued.
    while True:
        try:
            # The blocking get() runs in a worker thread; its timeout doubles as the
            # polling interval for the timeout-based flush below.
            item = await asyncio.to_thread(event_queue.get, timeout=INITIAL_TEXT_BUFFER_DELAY)
        except queue.Empty:
            # Yield #1: Timeout-based flush of initial text buffer (ensures responsiveness during model pauses)
            if (
                state.text_buffer
                and state._should_flush_initial_buffer()
                and (step := state.flush_buffer(step_num, state.get_step_type()))
            ):
                state.initial_buffer_flushed = True
                yield step
            continue

        if item is None:
            # Yield #2: Stream ended - flush any remaining buffered text
            if step := state.flush_buffer(step_num, state.get_step_type()):
                yield step
            break

        event, final_msg = item
        if final_msg is not None:
            # The final message carries no event; just remember it for the caller.
            state.final_message = final_msg
            continue

        if event is None:
            continue

        # Tool bookkeeping happens inside _process_stream_event; only thinking/text
        # deltas come back as a StreamDelta.
        delta = self._process_stream_event(event, state.pending_tools, state.tool_calls)
        if not delta:
            continue

        if delta.kind == "thinking":
            state.thinking_text += delta.content
            # Yield #3: Streaming thinking tokens (extended thinking / chain-of-thought)
            yield AgentStep(
                step_number=step_num,
                thinking=state.thinking_text,
                is_streaming=True,
                step_type=StepType.THINKING,
            )
            # Thinking tokens also start the initial-buffer clock if it is not running yet.
            if state.first_text_token_time is None:
                state.first_text_token_time = time.monotonic()
            continue

        # Yield #4: Text delta - immediate flush after initial buffer period or when tool calls detected
        if step := self._handle_text_delta(step_num, delta.content, delta.kind, state):
            yield step
|
|
531
|
+
|
|
532
|
+
async def _stream_model_response(self, step_num: int) -> AsyncIterator[AgentStep]:
    """Stream model response, yielding partial steps as thinking streams in.

    Extended thinking separates the model's internal reasoning (thinking blocks)
    from its final response (text blocks). This allows distinguishing between
    the chain-of-thought process and the actual answer.

    The blocking stream is consumed in an executor thread that feeds a queue; this
    coroutine drains the queue, records token usage, and then either emits the
    final answer or hands over to tool execution.

    Yields:
        AgentStep objects - partial steps while streaming, then final step with tool results.

    Raises:
        RuntimeError: If the stream ends without delivering a final message.
        Exception: Any error raised inside the streaming thread is re-raised here
            once the queue has been drained.
    """
    messages = self.memory.write_to_messages()
    tools = await self._get_tools()
    model_id = get_model_id()
    state = _StreamState()

    event_queue: queue.Queue[tuple[MessageStreamEvent | None, Message | None] | None] = queue.Queue()

    def producer() -> None:
        try:
            for item in self._sync_stream_events(model_id, messages, tools):
                event_queue.put(item)
        finally:
            # Always enqueue the end-of-stream sentinel, even when streaming fails.
            # The consumer only stops on this sentinel; without it an API error
            # would leave it polling the queue forever and the error would never
            # reach ``await producer_task`` below.
            event_queue.put(None)

    # Run the producer with a copy of the current context so context variables
    # set by the caller are visible inside the worker thread.
    ctx = copy_context()
    producer_task = asyncio.get_running_loop().run_in_executor(None, partial(ctx.run, producer))

    # Yield #5: Forward all streaming steps from _process_stream_events (yields #1-4)
    async for step in self._process_stream_events(step_num, event_queue, state):
        yield step

    # Surfaces any exception raised by the producer thread.
    await producer_task

    if state.final_message is None:
        raise RuntimeError("Stream ended without final message")

    thinking_blocks = self._extract_thinking_blocks(state.final_message)
    input_tokens = state.final_message.usage.input_tokens
    output_tokens = state.final_message.usage.output_tokens
    self._total_input_tokens += input_tokens
    self._total_output_tokens += output_tokens
    self._main_agent_input_tokens += input_tokens
    self._main_agent_output_tokens += output_tokens
    logger.info(
        f"Step {step_num}: input_tokens={input_tokens}, output_tokens={output_tokens}, "
        f"total_input={self._total_input_tokens}, total_output={self._total_output_tokens}"
    )

    step = AgentStep(
        step_number=step_num,
        thinking=state.thinking_text if state.thinking_text else None,
        tool_calls=state.tool_calls,
        is_streaming=False,
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        step_type=StepType.FINAL_ANSWER if not state.tool_calls else StepType.INTERMEDIATE,
    )

    if not state.tool_calls:
        # No tools requested: the accumulated text is the final answer.
        step.final_answer = state.response_text or None
        step.is_final = True
        memory_step = MemoryStep(
            step_number=step_num,
            text=state.response_text if state.response_text else None,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
        self.memory.add_step(memory_step)
        # Yield #6: Final answer step (no tool calls, response complete)
        yield step
        return

    # Yield #7: Forward tool execution progress steps from _execute_tools_with_progress
    async for step_or_result in self._execute_tools_with_progress(
        step_num, state.response_text, state.tool_calls, step, input_tokens, output_tokens, thinking_blocks
    ):
        yield step_or_result
|
|
607
|
+
|
|
608
|
+
async def _execute_tools_with_progress(
    self,
    step_num: int,
    thinking_text: str,
    tool_calls: list[ToolCall],
    step: AgentStep,
    input_tokens: int,
    output_tokens: int,
    thinking_blocks: list[ThinkingBlockData] | None = None,
) -> AsyncIterator[AgentStep]:
    """Execute tools concurrently and yield progress updates.

    One asyncio task is started per tool call. Progress steps produced by the tools
    are forwarded as they arrive; once every task has finished, the collected
    results are attached to ``step`` and to a new MemoryStep, which is stored in
    memory before ``step`` is yielded last.

    Args:
        step_num: Step number used for the memory step and progress steps.
        thinking_text: Text stored on the memory step and reported as ``thinking``
            on the initial progress step.
            NOTE(review): the caller in this file passes the accumulated response
            text here, not thinking text - confirm the naming is intentional.
        tool_calls: Tool calls requested by the model in this step.
        step: The step object completed with the tool results and yielded last.
        input_tokens: Input tokens of the model call, stored on the memory step.
        output_tokens: Output tokens of the model call, stored on the memory step.
        thinking_blocks: Thinking blocks to preserve on the memory step.

    Yields:
        Progress AgentSteps while tools run, then ``step`` with ``tool_results`` set.
    """
    memory_step = MemoryStep(
        step_number=step_num,
        text=thinking_text or None,
        tool_calls=tool_calls,
        thinking_blocks=thinking_blocks or [],
        input_tokens=input_tokens,
        output_tokens=output_tokens,
    )

    total_tools = len(tool_calls)

    # Announce the tool batch before anything starts running.
    yield AgentStep(
        step_number=step_num,
        thinking=thinking_text or None,
        tool_calls=tool_calls,
        is_streaming=True,
        current_tool=None,
        tool_progress=(0, total_tools),
        step_type=StepType.INTERMEDIATE,
    )

    progress_queue: asyncio.Queue[AgentStep] = asyncio.Queue()
    results_by_id: dict[str, ToolResult] = {}

    async def execute_single_tool(tool_call: ToolCall, idx: int) -> None:
        tool_progress = (idx, total_tools)
        async for progress_or_result in self._execute_tool_with_progress(
            tool_call, step_num, tool_calls, tool_progress
        ):
            if isinstance(progress_or_result, AgentStep):
                await progress_queue.put(progress_or_result)
            elif isinstance(progress_or_result, ToolResult):
                results_by_id[tool_call.id] = progress_or_result

    tasks = [
        asyncio.create_task(execute_single_tool(tool_call, idx)) for idx, tool_call in enumerate(tool_calls, 1)
    ]

    pending = set(tasks)
    try:
        while pending:
            # Short timeout so queued progress is forwarded even while tools are still running.
            done, pending = await asyncio.wait(pending, timeout=0.05, return_when=asyncio.FIRST_COMPLETED)

            # Retrieve task exceptions so a failed tool is reported instead of
            # vanishing; the failed call simply has no entry in the results.
            for finished in done:
                if finished.cancelled():
                    continue
                error = finished.exception()
                if error is not None:
                    logger.error("Tool execution task failed in step %s: %r", step_num, error)

            while not progress_queue.empty():
                yield progress_queue.get_nowait()
    finally:
        # Reached with tasks still pending only when this generator is closed or
        # cancelled early; do not leave tool tasks running in the background.
        for leftover in pending:
            leftover.cancel()

    # Forward anything queued by the tasks that finished last.
    while not progress_queue.empty():
        yield progress_queue.get_nowait()

    # Keep results in the order the model requested the tools.
    tool_results = [results_by_id[tc.id] for tc in tool_calls if tc.id in results_by_id]

    step.tool_results = tool_results
    memory_step.tool_results = tool_results

    self.memory.add_step(memory_step)

    yield step
|
|
675
|
+
|
|
676
|
+
def _drain_token_queue(self, token_queue: queue.Queue[SubAgentTokenUsage]) -> None:
|
|
677
|
+
"""Drain all pending token usage from the queue."""
|
|
678
|
+
while True:
|
|
679
|
+
try:
|
|
680
|
+
usage = token_queue.get_nowait()
|
|
681
|
+
self._accumulate_sub_agent_tokens(usage)
|
|
682
|
+
except queue.Empty:
|
|
683
|
+
break
|
|
684
|
+
|
|
685
|
+
async def _execute_tool_with_progress(
    self, tool_call: ToolCall, step_num: int, tool_calls: list[ToolCall], tool_progress: tuple[int, int]
) -> AsyncIterator[AgentStep | ToolResult]:
    """Execute a tool and yield progress updates for sub-agents.

    For tools with sub-agents (like debug_hook), this yields AgentStep updates
    with sub_agent_progress. Always yields the final ToolResult.

    Dispatch depends on the tool name:

    * internal tools run ``execute_internal_tool`` in the default executor
      while this coroutine polls the progress and token-usage queues;
    * deploy tools run ``execute_tool`` with ``DEPLOY_TOOLS`` in the default
      executor;
    * everything else is forwarded to ``self.mcp_connection.call_tool``.

    Args:
        tool_call: The tool call to execute.
        step_num: Agent step number stamped on progress steps.
        tool_calls: All tool calls of the current step (copied onto progress steps).
        tool_progress: ``(index, total)`` of this tool within the step.

    Yields:
        Zero or more progress ``AgentStep`` objects, then exactly one
        ``ToolResult``. Any ``Exception`` raised while running the tool is
        logged and reported as a result with ``is_error=True``.
    """
    progress_queue: queue.Queue[SubAgentProgress] = queue.Queue()
    token_queue: queue.Queue[SubAgentTokenUsage] = queue.Queue()

    def progress_callback(progress: SubAgentProgress) -> None:
        progress_queue.put(progress)

    def token_callback(usage: SubAgentTokenUsage) -> None:
        token_queue.put(usage)

    def drain_progress() -> list[AgentStep]:
        """Turn every queued sub-agent progress item into a progress step."""
        steps: list[AgentStep] = []
        while True:
            try:
                progress = progress_queue.get_nowait()
            except queue.Empty:
                return steps
            steps.append(
                AgentStep(
                    step_number=step_num,
                    tool_calls=tool_calls,
                    is_streaming=True,
                    current_tool=tool_call.name,
                    tool_progress=tool_progress,
                    sub_agent_progress=progress,
                    step_type=StepType.INTERMEDIATE,
                )
            )

    try:
        if tool_call.name in get_internal_tool_names():
            set_progress_callback(progress_callback)
            set_token_callback(token_callback)
            try:
                loop = asyncio.get_running_loop()
                # Copy the context after the callbacks are set so the worker
                # thread runs with them in place.
                ctx = copy_context()
                future = loop.run_in_executor(
                    None, partial(ctx.run, execute_internal_tool, tool_call.name, tool_call.arguments)
                )

                while not future.done():
                    # Forward everything queued so far, not just one item per
                    # poll, so progress cannot fall behind the sub-agent.
                    for progress_step in drain_progress():
                        yield progress_step
                    self._drain_token_queue(token_queue)
                    # Wake up as soon as the tool finishes, or after 0.1s to poll again.
                    await asyncio.wait([future], timeout=0.1)

                # Items queued between the last poll and completion would
                # otherwise be lost.
                for progress_step in drain_progress():
                    yield progress_step
                self._drain_token_queue(token_queue)

                result = future.result()
            finally:
                # Runs on success, on failure, and when the generator is closed
                # early, so the callbacks never outlive this tool execution.
                set_progress_callback(None)
                set_token_callback(None)
            content = str(result)
        elif tool_call.name in get_deploy_tool_names():
            loop = asyncio.get_running_loop()
            ctx = copy_context()
            future = loop.run_in_executor(
                None, partial(ctx.run, execute_tool, tool_call.name, tool_call.arguments, DEPLOY_TOOLS)
            )
            result = await future
            content = str(result)
        else:
            result = await self.mcp_connection.call_tool(tool_call.name, tool_call.arguments)
            content = self._serialize_tool_result(result)

        content = truncate_content(content)
        yield ToolResult(tool_call_id=tool_call.id, name=tool_call.name, content=content)

    except Exception as e:
        error_msg = f"Tool {tool_call.name} failed: {e}"
        logger.warning("Tool %s failed: %s", tool_call.name, e, exc_info=True)
        yield ToolResult(tool_call_id=tool_call.id, name=tool_call.name, content=error_msg, is_error=True)
|
|
758
|
+
|
|
759
|
+
def _extract_text_from_prompt(self, prompt: UserContent) -> str:
|
|
760
|
+
"""Extract text content from a user prompt for classification."""
|
|
761
|
+
if isinstance(prompt, str):
|
|
762
|
+
return prompt
|
|
763
|
+
text_parts: list[str] = []
|
|
764
|
+
for block in prompt:
|
|
765
|
+
if block.get("type") == "text":
|
|
766
|
+
text = block.get("text")
|
|
767
|
+
if isinstance(text, str):
|
|
768
|
+
text_parts.append(text)
|
|
769
|
+
return " ".join(text_parts)
|
|
770
|
+
|
|
771
|
+
def _check_request_scope(self, prompt: UserContent) -> AgentStep | None:
    """Classify the request; return a rejection step when it is out of scope.

    The token usage reported by the classifier (and by the rejection
    generator, when it is used) is added to both the overall and the
    main-agent token counters. Returns ``None`` for any scope other than
    ``RequestScope.OUT_OF_SCOPE``.
    """

    def book_usage(usage) -> None:
        # Charge a usage record to the running totals and the main-agent totals.
        self._total_input_tokens += usage.input_tokens
        self._total_output_tokens += usage.output_tokens
        self._main_agent_input_tokens += usage.input_tokens
        self._main_agent_output_tokens += usage.output_tokens

    request_text = self._extract_text_from_prompt(prompt)
    verdict = classify_request(self.client, request_text)
    book_usage(verdict)

    if verdict.scope != RequestScope.OUT_OF_SCOPE:
        return None

    refusal = generate_rejection_response(self.client, request_text)
    combined_input = verdict.input_tokens + refusal.input_tokens
    combined_output = verdict.output_tokens + refusal.output_tokens
    book_usage(refusal)

    # The rejection is a complete answer: it carries the cost of both calls.
    return AgentStep(
        step_number=1,
        final_answer=refusal.response,
        is_final=True,
        input_tokens=combined_input,
        output_tokens=combined_output,
        step_type=StepType.FINAL_ANSWER,
    )
|
|
796
|
+
|
|
797
|
+
def _inject_preload_info(self, prompt: UserContent, preload_result: str) -> UserContent:
|
|
798
|
+
"""Inject preload result info into the user prompt."""
|
|
799
|
+
suffix = (
|
|
800
|
+
f"\n\n[System: {preload_result}. Use these tools directly without calling list_tool_categories first.]"
|
|
801
|
+
)
|
|
802
|
+
if isinstance(prompt, str):
|
|
803
|
+
return prompt + suffix
|
|
804
|
+
system_block: TextBlockParam = {"type": "text", "text": suffix}
|
|
805
|
+
return [*prompt, system_block]
|
|
806
|
+
|
|
807
|
+
def _calculate_rate_limit_delay(self, retries: int) -> float:
    """Return the back-off delay for the given retry attempt.

    The delay is ``RATE_LIMIT_BASE_DELAY * 2 ** (retries - 1)``, capped at
    ``RATE_LIMIT_MAX_DELAY``, plus random jitter of up to 10% of the capped
    value.
    """
    uncapped = RATE_LIMIT_BASE_DELAY * (2 ** (retries - 1))
    capped = min(uncapped, RATE_LIMIT_MAX_DELAY)
    return capped + random.uniform(0, capped * 0.1)
|
|
812
|
+
|
|
813
|
+
async def run(self, prompt: UserContent) -> AsyncIterator[AgentStep]:
    """Run the agent with the given prompt, yielding steps.

    This method implements the main agent loop, calling the model,
    executing tools, and continuing until the model produces a final
    answer or the maximum number of steps is reached.

    Rate limiting is handled with exponential backoff and jitter.

    Args:
        prompt: The user request, either a string or a list of content blocks.

    Yields:
        Every step produced by ``_stream_model_response``. The stream ends
        with a step whose ``is_final`` is true: the out-of-scope rejection,
        the model's final answer, or an error step (rate-limit retries
        exhausted, API error or timeout, maximum steps reached).
    """
    # Inside a coroutine there is always a running loop; get_running_loop()
    # is the recommended accessor and never creates a loop implicitly.
    loop = asyncio.get_running_loop()

    # The scope check is a synchronous method whose helpers report token usage
    # (so they appear to call the model). Run it in the thread pool, like the
    # preload below, so it does not block the event loop.
    rejection = await loop.run_in_executor(None, partial(copy_context().run, self._check_request_scope, prompt))
    if rejection:
        yield rejection
        return

    set_mcp_connection(self.mcp_connection, loop)

    # Pre-load tool categories based on keywords in the user's request
    # Run in thread pool to avoid blocking the event loop (preload uses sync MCP calls)
    request_text = self._extract_text_from_prompt(prompt)
    ctx = copy_context()
    preload_result = await loop.run_in_executor(
        None, partial(ctx.run, preload_categories_for_request, request_text)
    )

    # Inject pre-load info into the task so agent knows what tools are available
    if preload_result:
        prompt = self._inject_preload_info(prompt, preload_result)

    self.memory.add_task(prompt)

    for step_num in range(1, self.config.max_steps + 1):
        rate_limit_retries = 0

        # Throttle requests to avoid rate limiting (skip delay on first step)
        if step_num > 1:
            await asyncio.sleep(self.config.request_delay)

        # Retry loop for the current step: left via ``break`` (step done, go to
        # the next step) or ``return`` (final answer or unrecoverable error).
        while True:
            try:
                final_step: AgentStep | None = None
                async for step in self._stream_model_response(step_num):
                    yield step
                    if not step.is_streaming:
                        final_step = step

                if final_step and final_step.is_final:
                    return

                break

            except RateLimitError as e:
                rate_limit_retries += 1
                if rate_limit_retries > RATE_LIMIT_MAX_RETRIES:
                    logger.error(f"Rate limit retries exhausted at step {step_num}: {e}")
                    yield AgentStep(
                        step_number=step_num,
                        error=f"Rate limit exceeded after {RATE_LIMIT_MAX_RETRIES} retries. Please try again later.",
                        is_final=True,
                        step_type=StepType.FINAL_ANSWER,
                    )
                    return

                wait_time = self._calculate_rate_limit_delay(rate_limit_retries)
                logger.warning(
                    f"Rate limit hit at step {step_num} (attempt {rate_limit_retries}/{RATE_LIMIT_MAX_RETRIES}), "
                    f"retrying in {wait_time:.1f}s: {e}"
                )
                # Tell the consumer why the stream is pausing before sleeping.
                yield AgentStep(
                    step_number=step_num,
                    thinking=f"⏳ Rate limited, waiting {wait_time:.1f}s before retry ({rate_limit_retries}/{RATE_LIMIT_MAX_RETRIES})...",
                    is_streaming=True,
                    step_type=StepType.INTERMEDIATE,
                )
                await asyncio.sleep(wait_time)

            except APIError as e:
                # Timeouts are logged as warnings, other API errors as errors;
                # both end the run with a final error step.
                is_timeout = isinstance(e, APITimeoutError)
                log_fn = logger.warning if is_timeout else logger.error
                log_fn(f"API {'timeout' if is_timeout else 'error'} at step {step_num}: {e}")
                error_msg = (
                    f"Request timed out. Please try again. Details: {e}"
                    if is_timeout
                    else f"API error occurred: {e}"
                )
                yield AgentStep(
                    step_number=step_num,
                    error=error_msg,
                    is_final=True,
                    step_type=StepType.FINAL_ANSWER,
                )
                return

    else:
        # The for-loop ran through every step without returning a final answer.
        yield AgentStep(
            step_number=self.config.max_steps,
            error=f"Maximum steps ({self.config.max_steps}) reached without final answer.",
            is_final=True,
            step_type=StepType.FINAL_ANSWER,
        )
|
|
912
|
+
|
|
913
|
+
|
|
914
|
+
async def create_agent(
    mcp_connection: MCPConnection,
    system_prompt: str,
    config: AgentConfig | None = None,
    additional_tools: list[ToolParam] | None = None,
) -> RossumAgent:
    """Build a ready-to-use RossumAgent.

    Convenience factory: obtains a client from ``create_bedrock_client()``
    and passes it, together with the given arguments, to the ``RossumAgent``
    constructor.
    """
    bedrock_client = create_bedrock_client()
    agent = RossumAgent(
        client=bedrock_client,
        mcp_connection=mcp_connection,
        system_prompt=system_prompt,
        config=config,
        additional_tools=additional_tools,
    )
    return agent
|