hud-python 0.4.28__py3-none-any.whl → 0.4.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/__init__.py +2 -1
- hud/agents/base.py +81 -45
- hud/agents/claude.py +8 -4
- hud/agents/openai_chat_generic.py +66 -40
- hud/agents/tests/test_base.py +0 -4
- hud/agents/tests/test_openai.py +1 -1
- hud/cli/__init__.py +182 -52
- hud/cli/dev.py +8 -9
- hud/cli/eval.py +317 -119
- hud/cli/flows/__init__.py +0 -0
- hud/cli/flows/tasks.py +0 -0
- hud/cli/get.py +160 -0
- hud/cli/rl/__init__.py +567 -71
- hud/cli/rl/config.py +94 -0
- hud/cli/rl/display.py +133 -0
- hud/cli/rl/gpu.py +63 -0
- hud/cli/rl/gpu_utils.py +318 -0
- hud/cli/rl/presets.py +96 -0
- hud/cli/rl/remote_runner.py +347 -0
- hud/cli/rl/rl_api.py +150 -0
- hud/cli/rl/vllm.py +177 -0
- hud/cli/tests/test_analyze_metadata.py +0 -1
- hud/cli/utils/tasks.py +26 -0
- hud/clients/base.py +21 -23
- hud/clients/mcp_use.py +36 -44
- hud/clients/tests/test_mcp_use_retry.py +10 -10
- hud/datasets/__init__.py +4 -3
- hud/datasets/{execution/parallel.py → parallel.py} +1 -1
- hud/datasets/{execution/runner.py → runner.py} +1 -1
- hud/datasets/utils.py +1 -1
- hud/native/comparator.py +6 -6
- hud/native/tests/test_comparator.py +8 -8
- hud/native/tests/test_native_init.py +13 -11
- hud/otel/config.py +1 -1
- hud/otel/instrumentation.py +35 -0
- hud/rl/README.md +30 -0
- hud/rl/__init__.py +1 -0
- hud/rl/actor.py +174 -0
- hud/rl/buffer.py +371 -0
- hud/rl/chat_template.jinja +101 -0
- hud/rl/config.py +184 -0
- hud/rl/distributed.py +95 -0
- hud/rl/learner.py +589 -0
- hud/rl/tests/__init__.py +1 -0
- hud/rl/tests/test_learner.py +171 -0
- hud/rl/train.py +354 -0
- hud/rl/types.py +101 -0
- hud/rl/utils/start_vllm_server.sh +30 -0
- hud/rl/utils.py +524 -0
- hud/rl/vllm_adapter.py +125 -0
- hud/settings.py +6 -0
- hud/telemetry/__init__.py +2 -1
- hud/telemetry/job.py +46 -3
- hud/telemetry/tests/test_trace.py +3 -3
- hud/telemetry/trace.py +85 -13
- hud/tools/tests/test_computer.py +3 -3
- hud/tools/tests/test_computer_actions.py +1 -1
- hud/types.py +123 -2
- hud/utils/group_eval.py +223 -0
- hud/utils/hud_console.py +113 -13
- hud/utils/tasks.py +119 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/METADATA +20 -2
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/RECORD +68 -48
- hud/cli/hf.py +0 -406
- hud/cli/rl/README.md +0 -243
- hud/cli/rl/init.py +0 -370
- hud/cli/rl/pod.py +0 -501
- hud/cli/rl/ssh.py +0 -322
- hud/cli/rl/train.py +0 -562
- hud/cli/rl/utils.py +0 -165
- hud/datasets/execution/__init__.py +0 -13
- hud/datasets/task.py +0 -116
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/WHEEL +0 -0
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.28.dist-info → hud_python-0.4.30.dist-info}/licenses/LICENSE +0 -0
hud/__init__.py
CHANGED
|
@@ -5,9 +5,10 @@ tools for building, evaluating, and training AI agents.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from .telemetry import clear_trace, create_job, get_trace, instrument, job, trace
|
|
8
|
+
from .telemetry import Trace, clear_trace, create_job, get_trace, instrument, job, trace
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
|
+
"Trace",
|
|
11
12
|
"clear_trace",
|
|
12
13
|
"create_job",
|
|
13
14
|
"get_trace",
|
hud/agents/base.py
CHANGED
|
@@ -45,7 +45,7 @@ class MCPAgent(ABC):
|
|
|
45
45
|
`format_blocks`, and `format_tool_results`.
|
|
46
46
|
"""
|
|
47
47
|
|
|
48
|
-
metadata: dict[str, Any]
|
|
48
|
+
metadata: dict[str, Any] | None = None
|
|
49
49
|
required_tools: ClassVar[list[str]] = [] # Tools that must be available
|
|
50
50
|
|
|
51
51
|
def __init__(
|
|
@@ -54,7 +54,6 @@ class MCPAgent(ABC):
|
|
|
54
54
|
# Filtering
|
|
55
55
|
allowed_tools: list[str] | None = None,
|
|
56
56
|
disallowed_tools: list[str] | None = None,
|
|
57
|
-
lifecycle_tools: list[str] | None = None,
|
|
58
57
|
# Messages
|
|
59
58
|
system_prompt: str = GLOBAL_SYSTEM_PROMPT,
|
|
60
59
|
append_setup_output: bool = True,
|
|
@@ -74,8 +73,6 @@ class MCPAgent(ABC):
|
|
|
74
73
|
that provides `mcp_config`.
|
|
75
74
|
allowed_tools: Names of tools to allow (None means allow all).
|
|
76
75
|
disallowed_tools: Names of tools to always exclude.
|
|
77
|
-
lifecycle_tools: Tools reserved for lifecycle phases (e.g., setup,
|
|
78
|
-
evaluate). These are hidden from normal tool calling.
|
|
79
76
|
system_prompt: System prompt to seed the conversation.
|
|
80
77
|
append_setup_output: Whether to append setup tool output to the
|
|
81
78
|
first turn's messages.
|
|
@@ -98,10 +95,13 @@ class MCPAgent(ABC):
|
|
|
98
95
|
if verbose:
|
|
99
96
|
self.console.set_verbose(True)
|
|
100
97
|
|
|
101
|
-
#
|
|
98
|
+
# User filtering
|
|
102
99
|
self.allowed_tools = allowed_tools
|
|
103
100
|
self.disallowed_tools = disallowed_tools or []
|
|
104
|
-
|
|
101
|
+
|
|
102
|
+
# Task filtering
|
|
103
|
+
self.agent_tools = None
|
|
104
|
+
self.lifecycle_tools = []
|
|
105
105
|
|
|
106
106
|
# Messages
|
|
107
107
|
self.system_prompt = system_prompt
|
|
@@ -112,7 +112,6 @@ class MCPAgent(ABC):
|
|
|
112
112
|
self._available_tools: list[types.Tool] = []
|
|
113
113
|
self._tool_map: dict[str, types.Tool] = {} # Simplified: just name to tool
|
|
114
114
|
self.response_tool_name = None
|
|
115
|
-
self.initialization_complete = False
|
|
116
115
|
|
|
117
116
|
# Trace
|
|
118
117
|
self._auto_trace = auto_trace
|
|
@@ -131,7 +130,7 @@ class MCPAgent(ABC):
|
|
|
131
130
|
|
|
132
131
|
self.mcp_client = MCPClient(mcp_config=task.mcp_config)
|
|
133
132
|
self._auto_created_client = True
|
|
134
|
-
self.console.
|
|
133
|
+
self.console.debug("Auto-created MCPClient from task.mcp_config")
|
|
135
134
|
|
|
136
135
|
# Ensure we have a client
|
|
137
136
|
if self.mcp_client is None:
|
|
@@ -149,17 +148,29 @@ class MCPAgent(ABC):
|
|
|
149
148
|
|
|
150
149
|
# If task is provided, add lifecycle tools
|
|
151
150
|
if isinstance(task, Task):
|
|
151
|
+
if task.agent_tools:
|
|
152
|
+
self.agent_tools = task.agent_tools
|
|
152
153
|
if task.setup_tool:
|
|
153
154
|
if isinstance(task.setup_tool, list):
|
|
154
155
|
for tool in task.setup_tool:
|
|
155
|
-
self.
|
|
156
|
-
|
|
156
|
+
if not self.agent_tools or (
|
|
157
|
+
self.agent_tools and tool.name not in self.agent_tools
|
|
158
|
+
):
|
|
159
|
+
self.lifecycle_tools.append(tool.name)
|
|
160
|
+
elif not self.agent_tools or (
|
|
161
|
+
self.agent_tools and task.setup_tool.name not in self.agent_tools
|
|
162
|
+
):
|
|
157
163
|
self.lifecycle_tools.append(task.setup_tool.name)
|
|
158
164
|
if task.evaluate_tool:
|
|
159
165
|
if isinstance(task.evaluate_tool, list):
|
|
160
166
|
for tool in task.evaluate_tool:
|
|
161
|
-
self.
|
|
162
|
-
|
|
167
|
+
if not self.agent_tools or (
|
|
168
|
+
self.agent_tools and tool.name not in self.agent_tools
|
|
169
|
+
):
|
|
170
|
+
self.lifecycle_tools.append(tool.name)
|
|
171
|
+
elif not self.agent_tools or (
|
|
172
|
+
self.agent_tools and task.evaluate_tool.name not in self.agent_tools
|
|
173
|
+
):
|
|
163
174
|
self.lifecycle_tools.append(task.evaluate_tool.name)
|
|
164
175
|
if task.system_prompt:
|
|
165
176
|
self.system_prompt += "\n\n" + task.system_prompt
|
|
@@ -167,11 +178,6 @@ class MCPAgent(ABC):
|
|
|
167
178
|
# Re-apply filtering with updated lifecycle tools
|
|
168
179
|
await self._filter_tools()
|
|
169
180
|
|
|
170
|
-
num_tools = len(self._available_tools)
|
|
171
|
-
self.console.success_log(
|
|
172
|
-
f"Agent initialized with {num_tools} available tools (after filtering)"
|
|
173
|
-
)
|
|
174
|
-
|
|
175
181
|
async def run(self, prompt_or_task: str | Task | dict[str, Any], max_steps: int = 10) -> Trace:
|
|
176
182
|
"""
|
|
177
183
|
Run the agent with the given prompt or task.
|
|
@@ -188,12 +194,12 @@ class MCPAgent(ABC):
|
|
|
188
194
|
|
|
189
195
|
if isinstance(prompt_or_task, dict):
|
|
190
196
|
prompt_or_task = Task(**prompt_or_task)
|
|
197
|
+
elif not isinstance(prompt_or_task, str) and not isinstance(prompt_or_task, Task):
|
|
198
|
+
raise TypeError(f"prompt_or_task must be str or Task, got {type(prompt_or_task)}")
|
|
191
199
|
|
|
192
200
|
try:
|
|
193
201
|
# Establish the connection with the MCP server/Environment
|
|
194
|
-
|
|
195
|
-
await self.initialize(prompt_or_task)
|
|
196
|
-
self.initialization_complete = True
|
|
202
|
+
await self.initialize(prompt_or_task)
|
|
197
203
|
|
|
198
204
|
# Handle Task objects with full lifecycle
|
|
199
205
|
if isinstance(prompt_or_task, Task):
|
|
@@ -204,8 +210,6 @@ class MCPAgent(ABC):
|
|
|
204
210
|
context = text_to_blocks(prompt_or_task)
|
|
205
211
|
return await self._run_context(context, max_steps=max_steps)
|
|
206
212
|
|
|
207
|
-
else:
|
|
208
|
-
raise TypeError(f"prompt_or_task must be str or Task, got {type(prompt_or_task)}")
|
|
209
213
|
except Exception as e:
|
|
210
214
|
# Always return a Trace object for any exception
|
|
211
215
|
if self._is_connection_error(e):
|
|
@@ -240,8 +244,6 @@ class MCPAgent(ABC):
|
|
|
240
244
|
Returns:
|
|
241
245
|
Trace with reward from evaluation
|
|
242
246
|
"""
|
|
243
|
-
prompt_result = None
|
|
244
|
-
|
|
245
247
|
try:
|
|
246
248
|
# Setup phase
|
|
247
249
|
start_context: list[types.ContentBlock] = []
|
|
@@ -255,7 +257,13 @@ class MCPAgent(ABC):
|
|
|
255
257
|
self.console.progress_log(f"Setting up tool phase: {task.setup_tool}")
|
|
256
258
|
results = await self.call_tools(task.setup_tool)
|
|
257
259
|
if any(result.isError for result in results):
|
|
258
|
-
|
|
260
|
+
return Trace(
|
|
261
|
+
reward=0.0,
|
|
262
|
+
done=True,
|
|
263
|
+
content=f"Setup tool failed: {results}",
|
|
264
|
+
isError=True,
|
|
265
|
+
task=task,
|
|
266
|
+
)
|
|
259
267
|
|
|
260
268
|
if self.append_setup_output and isinstance(results[0].content, list):
|
|
261
269
|
start_context.extend(results[0].content)
|
|
@@ -268,13 +276,12 @@ class MCPAgent(ABC):
|
|
|
268
276
|
except Exception as e:
|
|
269
277
|
self.console.error_log(f"Task execution failed: {e}")
|
|
270
278
|
# Create an error result but don't return yet - we still want to evaluate
|
|
271
|
-
prompt_result = Trace(reward=0.0, done=True, content=str(e), isError=True)
|
|
279
|
+
prompt_result = Trace(reward=0.0, done=True, content=str(e), isError=True, task=task)
|
|
272
280
|
prompt_result.populate_from_context()
|
|
273
281
|
|
|
274
282
|
# Always evaluate if we have evaluate tool, regardless of errors
|
|
275
283
|
if task.evaluate_tool is not None:
|
|
276
284
|
try:
|
|
277
|
-
self.console.progress_log(f"Evaluating tool phase: {task.evaluate_tool}")
|
|
278
285
|
results = await self.call_tools(task.evaluate_tool)
|
|
279
286
|
|
|
280
287
|
if any(result.isError for result in results):
|
|
@@ -286,18 +293,24 @@ class MCPAgent(ABC):
|
|
|
286
293
|
done=True,
|
|
287
294
|
content="Task failed before evaluation",
|
|
288
295
|
isError=True,
|
|
296
|
+
task=task,
|
|
289
297
|
)
|
|
290
298
|
prompt_result.reward = 0.0 # Default to 0 on error
|
|
291
299
|
else:
|
|
292
300
|
# Extract reward and content from evaluation
|
|
293
301
|
if results:
|
|
294
302
|
reward = find_reward(results[0])
|
|
303
|
+
self.console.info_log(f"Eval: {reward:.4f} {task.evaluate_tool}")
|
|
295
304
|
eval_content = find_content(results[0])
|
|
296
305
|
|
|
297
306
|
# Update the prompt result with evaluation reward
|
|
298
307
|
if prompt_result is None:
|
|
299
308
|
prompt_result = Trace(
|
|
300
|
-
reward=reward,
|
|
309
|
+
reward=reward,
|
|
310
|
+
done=True,
|
|
311
|
+
content=eval_content or "",
|
|
312
|
+
isError=False,
|
|
313
|
+
task=task,
|
|
301
314
|
)
|
|
302
315
|
else:
|
|
303
316
|
prompt_result.reward = reward
|
|
@@ -316,14 +329,16 @@ class MCPAgent(ABC):
|
|
|
316
329
|
# Ensure we have a result even if evaluation failed
|
|
317
330
|
if prompt_result is None:
|
|
318
331
|
prompt_result = Trace(
|
|
319
|
-
reward=0.0,
|
|
332
|
+
reward=0.0,
|
|
333
|
+
done=True,
|
|
334
|
+
content=f"Evaluation failed: {e}",
|
|
335
|
+
isError=True,
|
|
336
|
+
task=task,
|
|
320
337
|
)
|
|
321
338
|
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
else Trace(reward=0.0, done=True, content="No result available", isError=True)
|
|
326
|
-
)
|
|
339
|
+
prompt_result.task = task
|
|
340
|
+
|
|
341
|
+
return prompt_result
|
|
327
342
|
|
|
328
343
|
async def _run_context(
|
|
329
344
|
self, context: list[types.ContentBlock], *, max_steps: int = 10
|
|
@@ -388,7 +403,11 @@ class MCPAgent(ABC):
|
|
|
388
403
|
|
|
389
404
|
# 2. Execute tools
|
|
390
405
|
tool_calls = response.tool_calls
|
|
406
|
+
for tool_call in tool_calls:
|
|
407
|
+
self.console.info_log(f"{tool_call}")
|
|
391
408
|
tool_results = await self.call_tools(tool_calls)
|
|
409
|
+
for tool_result in tool_results:
|
|
410
|
+
self.console.info_log(f"{tool_result}")
|
|
392
411
|
|
|
393
412
|
# 3. Format tool results and add to messages
|
|
394
413
|
tool_messages = await self.format_tool_results(tool_calls, tool_results)
|
|
@@ -422,13 +441,23 @@ class MCPAgent(ABC):
|
|
|
422
441
|
error = str(e)
|
|
423
442
|
|
|
424
443
|
# Build result
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
444
|
+
if error is not None or (
|
|
445
|
+
final_response and hasattr(final_response, "isError") and final_response.isError
|
|
446
|
+
):
|
|
447
|
+
is_error = True
|
|
448
|
+
else:
|
|
449
|
+
is_error = False
|
|
450
|
+
|
|
451
|
+
# Ensure all parameters are the correct type
|
|
452
|
+
trace_params = {
|
|
453
|
+
"reward": 0.0,
|
|
454
|
+
"done": True,
|
|
455
|
+
"messages": messages,
|
|
456
|
+
"content": final_response.content if final_response else error,
|
|
457
|
+
"isError": is_error,
|
|
458
|
+
"info": {"error": error} if error else {},
|
|
459
|
+
}
|
|
460
|
+
trace_result = Trace(**trace_params)
|
|
432
461
|
|
|
433
462
|
# Populate trace steps from current context
|
|
434
463
|
trace_result.populate_from_context()
|
|
@@ -474,16 +503,14 @@ class MCPAgent(ABC):
|
|
|
474
503
|
return results
|
|
475
504
|
|
|
476
505
|
@abstractmethod
|
|
477
|
-
async def get_system_messages(self) -> list[
|
|
506
|
+
async def get_system_messages(self) -> list[types.ContentBlock]:
|
|
478
507
|
"""
|
|
479
508
|
Get the system prompt.
|
|
480
509
|
"""
|
|
481
510
|
raise NotImplementedError
|
|
482
511
|
|
|
483
512
|
@abstractmethod
|
|
484
|
-
async def get_response(
|
|
485
|
-
self, messages: list[Any]
|
|
486
|
-
) -> AgentResponse: # maybe type messages as list[types.ContentBlock]
|
|
513
|
+
async def get_response(self, messages: list[Any]) -> AgentResponse:
|
|
487
514
|
"""
|
|
488
515
|
Get response from the model including any tool calls.
|
|
489
516
|
|
|
@@ -607,6 +634,7 @@ class MCPAgent(ABC):
|
|
|
607
634
|
|
|
608
635
|
self.console.debug(f"All tools: {[t.name for t in all_tools]}")
|
|
609
636
|
self.console.debug(f"Allowed tools: {self.allowed_tools}")
|
|
637
|
+
self.console.debug(f"Agent tools: {self.agent_tools}")
|
|
610
638
|
self.console.debug(f"Disallowed tools: {self.disallowed_tools}")
|
|
611
639
|
self.console.debug(f"Lifecycle tools: {self.lifecycle_tools}")
|
|
612
640
|
|
|
@@ -619,6 +647,9 @@ class MCPAgent(ABC):
|
|
|
619
647
|
if self.allowed_tools and tool.name not in self.allowed_tools:
|
|
620
648
|
self.console.debug(f"Skipping tool '{tool.name}' - not in allowed_tools")
|
|
621
649
|
continue
|
|
650
|
+
if self.agent_tools and tool.name not in self.agent_tools:
|
|
651
|
+
self.console.debug(f"Skipping tool '{tool.name}' - not in agent_tools")
|
|
652
|
+
continue
|
|
622
653
|
if tool.name in self.disallowed_tools:
|
|
623
654
|
self.console.debug(f"Skipping tool '{tool.name}' - in disallowed_tools")
|
|
624
655
|
continue
|
|
@@ -641,6 +672,11 @@ class MCPAgent(ABC):
|
|
|
641
672
|
f"Available tools: {list(available_tool_names)}"
|
|
642
673
|
)
|
|
643
674
|
|
|
675
|
+
available_tools = self.get_available_tools()
|
|
676
|
+
self.console.info(
|
|
677
|
+
f"Agent initialized with {len(available_tools)} tools: {', '.join([t.name for t in available_tools])}" # noqa: E501
|
|
678
|
+
)
|
|
679
|
+
|
|
644
680
|
async def _maybe_submit_response(self, response: AgentResponse, messages: list[Any]) -> None:
|
|
645
681
|
"""Submit response through lifecycle tool if available.
|
|
646
682
|
|
hud/agents/claude.py
CHANGED
|
@@ -28,6 +28,7 @@ import mcp.types as types
|
|
|
28
28
|
from hud.settings import settings
|
|
29
29
|
from hud.tools.computer.settings import computer_settings
|
|
30
30
|
from hud.types import AgentResponse, MCPToolCall, MCPToolResult
|
|
31
|
+
from hud.utils.hud_console import HUDConsole
|
|
31
32
|
|
|
32
33
|
from .base import MCPAgent
|
|
33
34
|
|
|
@@ -78,6 +79,7 @@ class ClaudeAgent(MCPAgent):
|
|
|
78
79
|
self.model = model
|
|
79
80
|
self.max_tokens = max_tokens
|
|
80
81
|
self.use_computer_beta = use_computer_beta
|
|
82
|
+
self.hud_console = HUDConsole(logger=logger)
|
|
81
83
|
|
|
82
84
|
self.model_name = self.model
|
|
83
85
|
|
|
@@ -149,7 +151,7 @@ class ClaudeAgent(MCPAgent):
|
|
|
149
151
|
)
|
|
150
152
|
else:
|
|
151
153
|
# For other types, try to cast but log a warning
|
|
152
|
-
|
|
154
|
+
self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning")
|
|
153
155
|
anthropic_blocks.append(cast("BetaContentBlockParam", block))
|
|
154
156
|
|
|
155
157
|
return [
|
|
@@ -201,7 +203,7 @@ class ClaudeAgent(MCPAgent):
|
|
|
201
203
|
or "request_too_large" in str(e)
|
|
202
204
|
or e.status_code == 413
|
|
203
205
|
):
|
|
204
|
-
|
|
206
|
+
self.hud_console.warning("Prompt too long, truncating message history")
|
|
205
207
|
# Keep first message and last 20 messages
|
|
206
208
|
if len(current_messages) > 21:
|
|
207
209
|
current_messages = [current_messages[0], *current_messages[-20:]]
|
|
@@ -266,7 +268,7 @@ class ClaudeAgent(MCPAgent):
|
|
|
266
268
|
# Extract Claude-specific metadata from extra fields
|
|
267
269
|
tool_use_id = tool_call.id
|
|
268
270
|
if not tool_use_id:
|
|
269
|
-
|
|
271
|
+
self.hud_console.warning(f"No tool_use_id found for {tool_call.name}")
|
|
270
272
|
continue
|
|
271
273
|
|
|
272
274
|
# Convert MCP tool results to Claude format
|
|
@@ -335,7 +337,9 @@ class ClaudeAgent(MCPAgent):
|
|
|
335
337
|
# Map Claude's "computer" back to the actual MCP tool name
|
|
336
338
|
self._claude_to_mcp_tool_map["computer"] = selected_computer_tool.name
|
|
337
339
|
claude_tools.append(claude_tool)
|
|
338
|
-
|
|
340
|
+
self.hud_console.debug(
|
|
341
|
+
f"Using {selected_computer_tool.name} as computer tool for Claude"
|
|
342
|
+
)
|
|
339
343
|
|
|
340
344
|
# Add other non-computer tools
|
|
341
345
|
for tool in self._available_tools:
|
|
@@ -23,6 +23,7 @@ import mcp.types as types
|
|
|
23
23
|
|
|
24
24
|
from hud import instrument
|
|
25
25
|
from hud.types import AgentResponse, MCPToolCall, MCPToolResult
|
|
26
|
+
from hud.utils.hud_console import HUDConsole
|
|
26
27
|
|
|
27
28
|
from .base import MCPAgent
|
|
28
29
|
|
|
@@ -43,7 +44,6 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
43
44
|
*,
|
|
44
45
|
openai_client: AsyncOpenAI,
|
|
45
46
|
model_name: str = "gpt-4o-mini",
|
|
46
|
-
parallel_tool_calls: bool = False,
|
|
47
47
|
completion_kwargs: dict[str, Any] | None = None,
|
|
48
48
|
**agent_kwargs: Any,
|
|
49
49
|
) -> None:
|
|
@@ -51,17 +51,22 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
51
51
|
super().__init__(**agent_kwargs)
|
|
52
52
|
self.oai = openai_client
|
|
53
53
|
self.model_name = model_name
|
|
54
|
-
self.parallel_tool_calls = parallel_tool_calls
|
|
55
54
|
self.completion_kwargs: dict[str, Any] = completion_kwargs or {}
|
|
56
|
-
self.
|
|
55
|
+
self.mcp_schemas = []
|
|
56
|
+
self.hud_console = HUDConsole(logger=logger)
|
|
57
57
|
|
|
58
58
|
@staticmethod
|
|
59
59
|
def _oai_to_mcp(tool_call: Any) -> MCPToolCall: # type: ignore[valid-type]
|
|
60
60
|
"""Convert an OpenAI ``tool_call`` to :class:`MCPToolCall`."""
|
|
61
|
+
args = json.loads(tool_call.function.arguments or "{}")
|
|
62
|
+
if isinstance(args, list):
|
|
63
|
+
args = args[0]
|
|
64
|
+
if not isinstance(args, dict):
|
|
65
|
+
args = {}
|
|
61
66
|
return MCPToolCall(
|
|
62
67
|
id=tool_call.id,
|
|
63
68
|
name=tool_call.function.name,
|
|
64
|
-
arguments=
|
|
69
|
+
arguments=args,
|
|
65
70
|
)
|
|
66
71
|
|
|
67
72
|
async def get_system_messages(self) -> list[Any]:
|
|
@@ -177,45 +182,66 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
177
182
|
# Convert MCP tool schemas to OpenAI format
|
|
178
183
|
mcp_schemas = self.get_tool_schemas()
|
|
179
184
|
|
|
180
|
-
protected_keys = {"model", "messages", "tools"
|
|
185
|
+
protected_keys = {"model", "messages", "tools"}
|
|
181
186
|
extra = {k: v for k, v in (self.completion_kwargs or {}).items() if k not in protected_keys}
|
|
182
187
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
188
|
+
try:
|
|
189
|
+
response = await self.oai.chat.completions.create(
|
|
190
|
+
model=self.model_name,
|
|
191
|
+
messages=messages,
|
|
192
|
+
tools=cast("list[ChatCompletionToolParam]", mcp_schemas),
|
|
193
|
+
**extra,
|
|
194
|
+
)
|
|
195
|
+
except Exception as e:
|
|
196
|
+
error_content = f"Error getting response {e}"
|
|
197
|
+
if "Invalid JSON" in str(e):
|
|
198
|
+
error_content = "Invalid JSON, response was truncated"
|
|
199
|
+
self.hud_console.warning_log(error_content)
|
|
200
|
+
|
|
201
|
+
return AgentResponse(
|
|
202
|
+
content=error_content,
|
|
203
|
+
tool_calls=[],
|
|
204
|
+
done=True,
|
|
205
|
+
isError=True,
|
|
206
|
+
raw=None,
|
|
207
|
+
)
|
|
190
208
|
|
|
191
209
|
choice = response.choices[0]
|
|
192
210
|
msg = choice.message
|
|
193
|
-
|
|
194
211
|
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
|
195
212
|
|
|
196
213
|
if msg.content:
|
|
197
214
|
assistant_msg["content"] = msg.content
|
|
198
215
|
|
|
199
216
|
if msg.tool_calls:
|
|
200
|
-
|
|
217
|
+
serialized_tool_calls = []
|
|
218
|
+
for tc in msg.tool_calls:
|
|
219
|
+
serialized_tc = {
|
|
220
|
+
"id": tc.id,
|
|
221
|
+
"type": "function",
|
|
222
|
+
"function": {"name": tc.function.name, "arguments": tc.function.arguments},
|
|
223
|
+
}
|
|
224
|
+
serialized_tool_calls.append(serialized_tc)
|
|
225
|
+
assistant_msg["tool_calls"] = serialized_tool_calls
|
|
201
226
|
|
|
202
227
|
messages.append(assistant_msg)
|
|
203
228
|
|
|
204
|
-
# Store the complete conversation history
|
|
205
|
-
self.conversation_history = messages.copy()
|
|
206
|
-
|
|
207
229
|
tool_calls = []
|
|
208
230
|
if msg.tool_calls:
|
|
209
231
|
for tc in msg.tool_calls:
|
|
210
232
|
if tc.function.name is not None: # type: ignore
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
233
|
+
# _oai_to_mcp returns a single MCPToolCall; append it
|
|
234
|
+
tool_calls.append(self._oai_to_mcp(tc)) # noqa: PERF401
|
|
235
|
+
|
|
236
|
+
# Only stop on length (token limit), never on "stop"
|
|
237
|
+
done = choice.finish_reason == "length"
|
|
238
|
+
if done:
|
|
239
|
+
self.hud_console.info_log(f"Done decision: finish_reason={choice.finish_reason}")
|
|
214
240
|
|
|
215
241
|
return AgentResponse(
|
|
216
242
|
content=msg.content or "",
|
|
217
243
|
tool_calls=tool_calls,
|
|
218
|
-
done=
|
|
244
|
+
done=done,
|
|
219
245
|
raw=response, # Include raw response for access to Choice objects
|
|
220
246
|
)
|
|
221
247
|
|
|
@@ -230,15 +256,15 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
230
256
|
When images are present, we return both a tool message and a user message.
|
|
231
257
|
"""
|
|
232
258
|
rendered: list[dict[str, Any]] = []
|
|
259
|
+
|
|
260
|
+
# Separate text and image content
|
|
261
|
+
image_parts = []
|
|
233
262
|
for call, res in zip(tool_calls, tool_results, strict=False):
|
|
234
263
|
# Use structuredContent.result if available, otherwise use content
|
|
235
|
-
items = res.content
|
|
236
|
-
if res.structuredContent and isinstance(res.structuredContent, dict):
|
|
237
|
-
items = res.structuredContent.get("result", res.content)
|
|
238
|
-
|
|
239
|
-
# Separate text and image content
|
|
240
264
|
text_parts = []
|
|
241
|
-
|
|
265
|
+
items = res.content
|
|
266
|
+
if not res.content and res.structuredContent:
|
|
267
|
+
items = [res.structuredContent.get("result", res.content)]
|
|
242
268
|
|
|
243
269
|
for item in items:
|
|
244
270
|
if isinstance(item, dict):
|
|
@@ -272,18 +298,18 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
272
298
|
}
|
|
273
299
|
)
|
|
274
300
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
301
|
+
# If there are images, add them as a separate user message
|
|
302
|
+
if image_parts:
|
|
303
|
+
# Add a user message with the images
|
|
304
|
+
content_with_images = [
|
|
305
|
+
{"type": "text", "text": "Tool returned the following:"},
|
|
306
|
+
image_parts[-1],
|
|
307
|
+
]
|
|
308
|
+
rendered.append(
|
|
309
|
+
{
|
|
310
|
+
"role": "user",
|
|
311
|
+
"content": content_with_images,
|
|
312
|
+
}
|
|
313
|
+
)
|
|
288
314
|
|
|
289
315
|
return rendered
|
hud/agents/tests/test_base.py
CHANGED
|
@@ -97,7 +97,6 @@ class TestBaseMCPAgent:
|
|
|
97
97
|
assert agent.disallowed_tools == []
|
|
98
98
|
assert agent.initial_screenshot is True
|
|
99
99
|
assert agent.system_prompt is not None # Default system prompt is set
|
|
100
|
-
assert agent.lifecycle_tools == []
|
|
101
100
|
|
|
102
101
|
def test_init_with_params(self):
|
|
103
102
|
"""Test initialization with custom parameters."""
|
|
@@ -108,7 +107,6 @@ class TestBaseMCPAgent:
|
|
|
108
107
|
disallowed_tools=["bad_tool"],
|
|
109
108
|
initial_screenshot=True,
|
|
110
109
|
system_prompt="Custom prompt",
|
|
111
|
-
lifecycle_tools=["custom_setup", "custom_eval"],
|
|
112
110
|
)
|
|
113
111
|
|
|
114
112
|
assert agent.mcp_client == client
|
|
@@ -116,7 +114,6 @@ class TestBaseMCPAgent:
|
|
|
116
114
|
assert agent.disallowed_tools == ["bad_tool"]
|
|
117
115
|
assert agent.initial_screenshot is True
|
|
118
116
|
assert agent.system_prompt == "Custom prompt"
|
|
119
|
-
assert agent.lifecycle_tools == ["custom_setup", "custom_eval"]
|
|
120
117
|
|
|
121
118
|
@pytest.mark.asyncio
|
|
122
119
|
async def test_init_no_client_no_task(self):
|
|
@@ -631,7 +628,6 @@ class TestMCPAgentExtended:
|
|
|
631
628
|
# Lifecycle tools are specified by name, not as objects
|
|
632
629
|
agent = MockAgentExtended(
|
|
633
630
|
mcp_client=mock_client,
|
|
634
|
-
lifecycle_tools=["screenshot"], # Use tool name
|
|
635
631
|
responses=[{"role": "assistant", "content": "Done", "tool_calls": []}],
|
|
636
632
|
)
|
|
637
633
|
|
hud/agents/tests/test_openai.py
CHANGED
|
@@ -156,7 +156,7 @@ class TestOperatorAgent:
|
|
|
156
156
|
messages = [{"prompt": "What's on the screen?", "screenshot": None}]
|
|
157
157
|
response = await agent.get_response(messages)
|
|
158
158
|
|
|
159
|
-
assert response.content == "I can see the screen content."
|
|
159
|
+
assert response.content[0].text == "I can see the screen content."
|
|
160
160
|
assert response.done is True
|
|
161
161
|
|
|
162
162
|
@pytest.mark.asyncio
|