rossum_agent-1.0.0rc0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rossum_agent/__init__.py +9 -0
- rossum_agent/agent/__init__.py +32 -0
- rossum_agent/agent/core.py +932 -0
- rossum_agent/agent/memory.py +176 -0
- rossum_agent/agent/models.py +160 -0
- rossum_agent/agent/request_classifier.py +152 -0
- rossum_agent/agent/skills.py +132 -0
- rossum_agent/agent/types.py +5 -0
- rossum_agent/agent_logging.py +56 -0
- rossum_agent/api/__init__.py +1 -0
- rossum_agent/api/cli.py +51 -0
- rossum_agent/api/dependencies.py +190 -0
- rossum_agent/api/main.py +180 -0
- rossum_agent/api/models/__init__.py +1 -0
- rossum_agent/api/models/schemas.py +301 -0
- rossum_agent/api/routes/__init__.py +1 -0
- rossum_agent/api/routes/chats.py +95 -0
- rossum_agent/api/routes/files.py +113 -0
- rossum_agent/api/routes/health.py +44 -0
- rossum_agent/api/routes/messages.py +218 -0
- rossum_agent/api/services/__init__.py +1 -0
- rossum_agent/api/services/agent_service.py +451 -0
- rossum_agent/api/services/chat_service.py +197 -0
- rossum_agent/api/services/file_service.py +65 -0
- rossum_agent/assets/Primary_light_logo.png +0 -0
- rossum_agent/bedrock_client.py +64 -0
- rossum_agent/prompts/__init__.py +27 -0
- rossum_agent/prompts/base_prompt.py +80 -0
- rossum_agent/prompts/system_prompt.py +24 -0
- rossum_agent/py.typed +0 -0
- rossum_agent/redis_storage.py +482 -0
- rossum_agent/rossum_mcp_integration.py +123 -0
- rossum_agent/skills/hook-debugging.md +31 -0
- rossum_agent/skills/organization-setup.md +60 -0
- rossum_agent/skills/rossum-deployment.md +102 -0
- rossum_agent/skills/schema-patching.md +61 -0
- rossum_agent/skills/schema-pruning.md +23 -0
- rossum_agent/skills/ui-settings.md +45 -0
- rossum_agent/streamlit_app/__init__.py +1 -0
- rossum_agent/streamlit_app/app.py +646 -0
- rossum_agent/streamlit_app/beep_sound.py +36 -0
- rossum_agent/streamlit_app/cli.py +17 -0
- rossum_agent/streamlit_app/render_modules.py +123 -0
- rossum_agent/streamlit_app/response_formatting.py +305 -0
- rossum_agent/tools/__init__.py +214 -0
- rossum_agent/tools/core.py +173 -0
- rossum_agent/tools/deploy.py +404 -0
- rossum_agent/tools/dynamic_tools.py +365 -0
- rossum_agent/tools/file_tools.py +62 -0
- rossum_agent/tools/formula.py +187 -0
- rossum_agent/tools/skills.py +31 -0
- rossum_agent/tools/spawn_mcp.py +227 -0
- rossum_agent/tools/subagents/__init__.py +31 -0
- rossum_agent/tools/subagents/base.py +303 -0
- rossum_agent/tools/subagents/hook_debug.py +591 -0
- rossum_agent/tools/subagents/knowledge_base.py +305 -0
- rossum_agent/tools/subagents/mcp_helpers.py +47 -0
- rossum_agent/tools/subagents/schema_patching.py +471 -0
- rossum_agent/url_context.py +167 -0
- rossum_agent/user_detection.py +100 -0
- rossum_agent/utils.py +128 -0
- rossum_agent-1.0.0rc0.dist-info/METADATA +311 -0
- rossum_agent-1.0.0rc0.dist-info/RECORD +67 -0
- rossum_agent-1.0.0rc0.dist-info/WHEEL +5 -0
- rossum_agent-1.0.0rc0.dist-info/entry_points.txt +3 -0
- rossum_agent-1.0.0rc0.dist-info/licenses/LICENSE +21 -0
- rossum_agent-1.0.0rc0.dist-info/top_level.txt +1 -0
rossum_agent/tools/spawn_mcp.py
@@ -0,0 +1,227 @@
+"""Tools for spawning, calling, and closing MCP connections to different Rossum environments.
+
+This module provides tools to manage secondary MCP connections to different Rossum
+environments at runtime, enabling cross-environment operations like deployments.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import threading
+from concurrent.futures import TimeoutError as FuturesTimeoutError
+from dataclasses import dataclass
+
+from anthropic import beta_tool
+from fastmcp import Client
+
+from rossum_agent.rossum_mcp_integration import MCPConnection, create_mcp_transport
+from rossum_agent.tools.core import get_mcp_event_loop
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SpawnedConnection:
+    """Record for a spawned MCP connection to a different environment."""
+
+    connection: MCPConnection
+    client: Client
+    api_base_url: str
+
+
+# Secondary MCP connections spawned at runtime for different environments
+_spawned_connections: dict[str, SpawnedConnection] = {}
+_spawned_connections_lock = threading.Lock()
+
+
+def get_spawned_connections() -> dict[str, SpawnedConnection]:
+    """Get the spawned connections dict (for internal use only)."""
+    return _spawned_connections
+
+
+def get_spawned_connections_lock() -> threading.Lock:
+    """Get the spawned connections lock (for internal use only)."""
+    return _spawned_connections_lock
+
+
+def clear_spawned_connections() -> None:
+    """Clear all spawned connections. Called when MCP connection is reset."""
+    with _spawned_connections_lock:
+        _spawned_connections.clear()
+
+
+async def _spawn_connection_async(
+    connection_id: str, api_token: str, api_base_url: str, mcp_mode: str = "read-write"
+) -> SpawnedConnection:
+    """Spawn a new MCP connection asynchronously.
+
+    Raises:
+        ValueError: If connection_id already exists.
+    """
+    with _spawned_connections_lock:
+        if connection_id in _spawned_connections:
+            raise ValueError(f"Connection '{connection_id}' already exists")
+
+    transport = create_mcp_transport(api_token, api_base_url, mcp_mode)  # type: ignore[arg-type]
+    client = Client(transport)
+
+    await client.__aenter__()
+    connection = MCPConnection(client=client)
+
+    record = SpawnedConnection(connection=connection, client=client, api_base_url=api_base_url)
+
+    with _spawned_connections_lock:
+        _spawned_connections[connection_id] = record
+
+    return record
+
+
+async def _close_spawned_connection_async(connection_id: str) -> None:
+    """Close a spawned MCP connection."""
+    with _spawned_connections_lock:
+        record = _spawned_connections.pop(connection_id, None)
+
+    if record is not None:
+        await record.client.__aexit__(None, None, None)
+
+
+def cleanup_all_spawned_connections() -> None:
+    """Cleanup all spawned connections. Call this when the agent session ends.
+
+    Should only be called once at session teardown, after no more tool calls are expected.
+    """
+    if (mcp_event_loop := get_mcp_event_loop()) is None:
+        return
+
+    with _spawned_connections_lock:
+        conn_ids = list(_spawned_connections.keys())
+
+    for conn_id in conn_ids:
+        try:
+            future = asyncio.run_coroutine_threadsafe(_close_spawned_connection_async(conn_id), mcp_event_loop)
+            future.result(timeout=10)
+        except FuturesTimeoutError:
+            future.cancel()
+            logger.warning(f"Timeout cleaning up connection {conn_id}")
+        except Exception as e:
+            logger.warning(f"Failed to cleanup connection {conn_id}: {e}")
+
+
+@beta_tool
+def spawn_mcp_connection(connection_id: str, api_token: str, api_base_url: str, mcp_mode: str = "read-write") -> str:
+    """Spawn a new MCP connection to a different Rossum environment.
+
+    Use this when you need to make changes to a different Rossum environment than the one the agent was initialized with.
+    For example, when deploying changes from a source environment to a target environment.
+
+    Args:
+        connection_id: A unique identifier for this connection (e.g., 'target', 'sandbox')
+
+    Returns:
+        Success message with available tools, or error message if failed.
+    """
+    if (mcp_event_loop := get_mcp_event_loop()) is None:
+        return "Error: MCP event loop not set. Agent not properly initialized."
+
+    if not connection_id or not connection_id.strip():
+        return "Error: connection_id must be non-empty."
+
+    if not api_base_url or not api_base_url.startswith("https://"):
+        return "Error: api_base_url must start with https://"
+
+    try:
+        future = asyncio.run_coroutine_threadsafe(
+            _spawn_connection_async(connection_id, api_token, api_base_url, mcp_mode),
+            mcp_event_loop,
+        )
+        record = future.result(timeout=30)
+
+        tools_future = asyncio.run_coroutine_threadsafe(record.connection.get_tools(), mcp_event_loop)
+        tools = tools_future.result(timeout=30)
+        tool_names = [t.name for t in tools]
+
+        return f"Successfully spawned MCP connection '{connection_id}' to {api_base_url}. Available tools: {', '.join(tool_names[:10])}{'...' if len(tool_names) > 10 else ''}"
+    except ValueError as e:
+        return f"Error: {e}"
+    except FuturesTimeoutError:
+        future.cancel()
+        return "Error: Timed out while spawning MCP connection."
+    except RuntimeError as e:
+        logger.exception("Error scheduling MCP call")
+        return f"Error: Failed to schedule MCP call: {e}"
+    except Exception as e:
+        logger.error(f"Failed to spawn connection: {e}")
+        return f"Error spawning connection: {e}"
+
+
+@beta_tool
+def call_on_connection(connection_id: str, tool_name: str, arguments: str | dict) -> str:
+    """Call a tool on a spawned MCP connection.
+
+    Use this to execute MCP tools on a connection that was previously spawned with spawn_mcp_connection.
+
+    Args:
+        connection_id: The identifier of the spawned connection.
+        tool_name: The name of the MCP tool to call.
+        arguments: Arguments to pass to the tool (JSON string or dict).
+
+    Returns:
+        The result of the tool call as a JSON string, or error message.
+    """
+    if (mcp_event_loop := get_mcp_event_loop()) is None:
+        return "Error: MCP event loop not set."
+
+    with _spawned_connections_lock:
+        if connection_id not in _spawned_connections:
+            available = list(_spawned_connections.keys())
+            return f"Error: Connection '{connection_id}' not found. Available: {available}"
+        record = _spawned_connections[connection_id]
+
+    logger.debug(f"call_on_connection: Using connection '{connection_id}' - API URL: {record.api_base_url}")
+
+    if isinstance(arguments, dict):
+        args = arguments
+    elif arguments:
+        try:
+            args = json.loads(arguments)
+        except json.JSONDecodeError as e:
+            return f"Error parsing arguments JSON: {e}"
+    else:
+        args = {}
+
+    try:
+        future = asyncio.run_coroutine_threadsafe(record.connection.call_tool(tool_name, args), mcp_event_loop)
+        result = future.result(timeout=60)
+
+        if isinstance(result, (dict, list)):
+            return f"[{tool_name}] {json.dumps(result, indent=2, default=str)}"
+        return f"[{tool_name}] {result}" if result is not None else f"[{tool_name}] Tool executed successfully"
+    except FuturesTimeoutError:
+        future.cancel()
+        return f"Error: Timed out calling {tool_name} after 60 seconds."
+    except Exception as e:
+        logger.error(f"Error calling tool on connection: {e}")
+        return f"Error calling {tool_name}: {e}"
+
+
+@beta_tool
+def close_connection(connection_id: str) -> str:
+    """Close a spawned MCP connection."""
+    if (mcp_event_loop := get_mcp_event_loop()) is None:
+        return "Error: MCP event loop not set."
+
+    with _spawned_connections_lock:
+        if connection_id not in _spawned_connections:
+            return f"Connection '{connection_id}' not found."
+
+    try:
+        future = asyncio.run_coroutine_threadsafe(_close_spawned_connection_async(connection_id), mcp_event_loop)
+        future.result(timeout=10)
+        return f"Successfully closed connection '{connection_id}'."
+    except FuturesTimeoutError:
+        future.cancel()
+        return f"Error: Timed out closing connection '{connection_id}'."
+    except Exception as e:
+        return f"Error closing connection: {e}"
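
Editor's note: a minimal usage sketch of how the three tools above could be chained for a cross-environment operation. The API token, base URL, MCP tool name ('list_queues'), and its arguments are placeholders, and the sketch assumes the @beta_tool-decorated functions remain directly callable and that the agent runtime has already initialized the MCP event loop for the primary connection.

    # Hypothetical sketch only -- token, URL, tool name, and arguments are placeholders.
    from rossum_agent.tools.spawn_mcp import (
        call_on_connection,
        close_connection,
        spawn_mcp_connection,
    )

    # 1. Open a secondary connection to the target environment (read-write by default).
    print(spawn_mcp_connection(
        connection_id="target",
        api_token="<target-api-token>",                     # placeholder
        api_base_url="https://example.rossum.app/api/v1",   # placeholder
    ))

    # 2. Call an MCP tool on that connection; arguments may be a dict or a JSON string.
    print(call_on_connection("target", "list_queues", arguments={"page_size": 10}))

    # 3. Close the connection once the cross-environment work is done.
    print(close_connection("target"))
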
rossum_agent/tools/subagents/__init__.py
@@ -0,0 +1,31 @@
+"""Sub-agents for the Rossum Agent.
+
+Opus-powered sub-agents for complex iterative tasks:
+- Hook debugging with sandboxed execution
+- Knowledge base search with AI analysis
+- Schema patching with programmatic bulk updates
+"""
+
+from __future__ import annotations
+
+from rossum_agent.bedrock_client import OPUS_MODEL_ID
+from rossum_agent.tools.subagents.base import SubAgent, SubAgentConfig, SubAgentResult
+from rossum_agent.tools.subagents.hook_debug import HookDebugSubAgent, debug_hook, evaluate_python_hook
+from rossum_agent.tools.subagents.knowledge_base import WebSearchError, search_knowledge_base
+from rossum_agent.tools.subagents.mcp_helpers import call_mcp_tool
+from rossum_agent.tools.subagents.schema_patching import SchemaPatchingSubAgent, patch_schema_with_subagent
+
+__all__ = [
+    "OPUS_MODEL_ID",
+    "HookDebugSubAgent",
+    "SchemaPatchingSubAgent",
+    "SubAgent",
+    "SubAgentConfig",
+    "SubAgentResult",
+    "WebSearchError",
+    "call_mcp_tool",
+    "debug_hook",
+    "evaluate_python_hook",
+    "patch_schema_with_subagent",
+    "search_knowledge_base",
+]
rossum_agent/tools/subagents/base.py
@@ -0,0 +1,303 @@
+"""Shared base module for sub-agents.
+
+Provides common infrastructure for sub-agents that use iterative LLM calls with tool use:
+- Unified iteration loop with token tracking
+- Context saving for debugging
+- Consistent logging patterns
+- Progress and token usage reporting
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import time
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from typing import Any
+
+from rossum_agent.bedrock_client import create_bedrock_client, get_model_id
+from rossum_agent.tools.core import (
+    SubAgentProgress,
+    SubAgentTokenUsage,
+    get_output_dir,
+    report_progress,
+    report_token_usage,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SubAgentConfig:
+    """Configuration for a sub-agent iteration loop."""
+
+    tool_name: str
+    system_prompt: str
+    tools: list[dict[str, Any]]
+    max_iterations: int = 15
+    max_tokens: int = 16384
+
+
+@dataclass
+class SubAgentResult:
+    """Result from a sub-agent execution."""
+
+    analysis: str
+    input_tokens: int
+    output_tokens: int
+    iterations_used: int
+
+
+def save_iteration_context(
+    tool_name: str,
+    iteration: int,
+    max_iterations: int,
+    messages: list[dict[str, Any]],
+    system_prompt: str,
+    tools: list[dict[str, Any]],
+    max_tokens: int,
+) -> None:
+    """Save agent input context to file for debugging.
+
+    Args:
+        tool_name: Name of the sub-agent tool (e.g., "debug_hook", "patch_schema").
+        iteration: Current iteration number (1-indexed).
+        max_iterations: Maximum number of iterations.
+        messages: Current conversation messages.
+        system_prompt: System prompt used.
+        tools: Tool definitions.
+        max_tokens: Max tokens setting.
+    """
+    try:
+        output_dir = get_output_dir()
+        context_file = output_dir / f"{tool_name}_context_iter_{iteration}.json"
+        context_data = {
+            "iteration": iteration,
+            "max_iterations": max_iterations,
+            "model": get_model_id(),
+            "max_tokens": max_tokens,
+            "system_prompt": system_prompt,
+            "messages": messages,
+            "tools": tools,
+        }
+        context_file.write_text(json.dumps(context_data, indent=2, default=str))
+        logger.info(f"{tool_name} sub-agent: saved context to {context_file}")
+    except Exception as e:
+        logger.warning(f"Failed to save {tool_name} context: {e}")
+
+
+class SubAgent(ABC):
+    """Base class for sub-agents with iterative tool use.
+
+    Provides a unified iteration loop with:
+    - Token tracking and reporting
+    - Progress reporting
+    - Context saving for debugging
+    - Consistent logging
+    """
+
+    def __init__(self, config: SubAgentConfig) -> None:
+        """Initialize the sub-agent.
+
+        Args:
+            config: Configuration for the sub-agent.
+        """
+        self.config = config
+        self._client = None
+
+    @property
+    def client(self):
+        """Lazily create the Bedrock client."""
+        if self._client is None:
+            client_start = time.perf_counter()
+            self._client = create_bedrock_client()
+            elapsed_ms = (time.perf_counter() - client_start) * 1000
+            logger.info(f"{self.config.tool_name}: Bedrock client created in {elapsed_ms:.1f}ms")
+        return self._client
+
+    @abstractmethod
+    def execute_tool(self, tool_name: str, tool_input: dict[str, Any]) -> str:
+        """Execute a tool call from the LLM.
+
+        Args:
+            tool_name: Name of the tool to execute.
+            tool_input: Input arguments for the tool.
+
+        Returns:
+            Tool result as a string.
+        """
+
+    @abstractmethod
+    def process_response_block(self, block: Any, iteration: int, max_iterations: int) -> dict[str, Any] | None:
+        """Process a response block for special handling (e.g., web search).
+
+        Args:
+            block: Response content block.
+            iteration: Current iteration number (1-indexed).
+            max_iterations: Maximum iterations.
+
+        Returns:
+            Tool result dict if the block was processed, None otherwise.
+        """
+
+    def run(self, initial_message: str) -> SubAgentResult:
+        """Run the sub-agent iteration loop."""
+        messages: list[dict[str, Any]] = [{"role": "user", "content": initial_message}]
+        total_input_tokens = 0
+        total_output_tokens = 0
+        current_iteration = 0
+
+        response = None
+        try:
+            for iteration in range(self.config.max_iterations):
+                current_iteration = iteration + 1
+                iter_start = time.perf_counter()
+
+                logger.info(
+                    f"{self.config.tool_name} sub-agent: iteration {current_iteration}/{self.config.max_iterations}"
+                )
+
+                report_progress(
+                    SubAgentProgress(
+                        tool_name=self.config.tool_name,
+                        iteration=current_iteration,
+                        max_iterations=self.config.max_iterations,
+                        status="thinking",
+                    )
+                )
+
+                save_iteration_context(
+                    tool_name=self.config.tool_name,
+                    iteration=current_iteration,
+                    max_iterations=self.config.max_iterations,
+                    messages=messages,
+                    system_prompt=self.config.system_prompt,
+                    tools=self.config.tools,
+                    max_tokens=self.config.max_tokens,
+                )
+
+                llm_start = time.perf_counter()
+                response = self.client.messages.create(
+                    model=get_model_id(),
+                    max_tokens=self.config.max_tokens,
+                    system=self.config.system_prompt,
+                    messages=messages,
+                    tools=self.config.tools,
+                )
+                llm_elapsed_ms = (time.perf_counter() - llm_start) * 1000
+
+                input_tokens = response.usage.input_tokens
+                output_tokens = response.usage.output_tokens
+                total_input_tokens += input_tokens
+                total_output_tokens += output_tokens
+
+                logger.info(
+                    f"{self.config.tool_name} [iter {current_iteration}]: "
+                    f"LLM {llm_elapsed_ms:.1f}ms, tokens in={input_tokens} out={output_tokens}"
+                )
+
+                report_token_usage(
+                    SubAgentTokenUsage(
+                        tool_name=self.config.tool_name,
+                        input_tokens=input_tokens,
+                        output_tokens=output_tokens,
+                        iteration=current_iteration,
+                    )
+                )
+
+                has_tool_use = any(hasattr(block, "type") and block.type == "tool_use" for block in response.content)
+
+                if response.stop_reason == "end_of_turn" or not has_tool_use:
+                    iter_elapsed_ms = (time.perf_counter() - iter_start) * 1000
+                    logger.info(
+                        f"{self.config.tool_name}: completed after {current_iteration} iterations "
+                        f"in {iter_elapsed_ms:.1f}ms (stop_reason={response.stop_reason}, has_tool_use={has_tool_use})"
+                    )
+                    report_progress(
+                        SubAgentProgress(
+                            tool_name=self.config.tool_name,
+                            iteration=current_iteration,
+                            max_iterations=self.config.max_iterations,
+                            status="completed",
+                        )
+                    )
+                    text_parts = [block.text for block in response.content if hasattr(block, "text")]
+                    return SubAgentResult(
+                        analysis="\n".join(text_parts) if text_parts else "No analysis provided",
+                        input_tokens=total_input_tokens,
+                        output_tokens=total_output_tokens,
+                        iterations_used=current_iteration,
+                    )
+
+                messages.append({"role": "assistant", "content": response.content})
+
+                tool_results: list[dict[str, Any]] = []
+                iteration_tool_calls: list[str] = []
+
+                for block in response.content:
+                    special_result = self.process_response_block(block, current_iteration, self.config.max_iterations)
+                    if special_result:
+                        tool_results.append(special_result)
+
+                for block in response.content:
+                    if hasattr(block, "type") and block.type == "tool_use":
+                        tool_name = block.name
+                        tool_input = block.input
+                        iteration_tool_calls.append(tool_name)
+
+                        logger.info(f"{self.config.tool_name} [iter {current_iteration}]: calling tool '{tool_name}'")
+
+                        report_progress(
+                            SubAgentProgress(
+                                tool_name=self.config.tool_name,
+                                iteration=current_iteration,
+                                max_iterations=self.config.max_iterations,
+                                current_tool=tool_name,
+                                tool_calls=iteration_tool_calls.copy(),
+                                status="running_tool",
+                            )
+                        )
+
+                        try:
+                            tool_start = time.perf_counter()
+                            result = self.execute_tool(tool_name, tool_input)
+                            tool_elapsed_ms = (time.perf_counter() - tool_start) * 1000
+                            logger.info(
+                                f"{self.config.tool_name}: tool '{tool_name}' executed in {tool_elapsed_ms:.1f}ms"
+                            )
+                            tool_results.append({"type": "tool_result", "tool_use_id": block.id, "content": result})
+                        except Exception as e:
+                            logger.warning(f"Tool {tool_name} failed: {e}")
+                            tool_results.append(
+                                {
+                                    "type": "tool_result",
+                                    "tool_use_id": block.id,
+                                    "content": f"Error: {e}",
+                                    "is_error": True,
+                                }
+                            )
+
+                if tool_results:
+                    messages.append({"role": "user", "content": tool_results})
+
+            logger.warning(f"{self.config.tool_name}: max iterations ({self.config.max_iterations}) reached")
+            text_parts = [block.text for block in response.content if hasattr(block, "text")] if response else []
+            return SubAgentResult(
+                analysis="\n".join(text_parts) if text_parts else "Max iterations reached",
+                input_tokens=total_input_tokens,
+                output_tokens=total_output_tokens,
+                iterations_used=self.config.max_iterations,
+            )
+
+        except Exception as e:
+            logger.exception(f"Error in {self.config.tool_name} sub-agent")
+            return SubAgentResult(
+                analysis=f"Error calling Opus sub-agent: {e}",
+                input_tokens=total_input_tokens,
+                output_tokens=total_output_tokens,
+                iterations_used=current_iteration,
+            )