hud-python 0.4.28__py3-none-any.whl → 0.4.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/__init__.py +2 -1
- hud/agents/base.py +73 -45
- hud/agents/claude.py +8 -4
- hud/agents/openai_chat_generic.py +65 -40
- hud/agents/tests/test_base.py +0 -4
- hud/agents/tests/test_openai.py +1 -1
- hud/cli/__init__.py +182 -52
- hud/cli/dev.py +8 -9
- hud/cli/eval.py +317 -119
- hud/cli/flows/__init__.py +0 -0
- hud/cli/flows/tasks.py +0 -0
- hud/cli/get.py +160 -0
- hud/cli/rl/__init__.py +563 -71
- hud/cli/rl/config.py +94 -0
- hud/cli/rl/display.py +133 -0
- hud/cli/rl/gpu.py +63 -0
- hud/cli/rl/gpu_utils.py +318 -0
- hud/cli/rl/presets.py +96 -0
- hud/cli/rl/remote_runner.py +348 -0
- hud/cli/rl/rl_api.py +150 -0
- hud/cli/rl/vllm.py +177 -0
- hud/cli/tests/test_analyze_metadata.py +0 -1
- hud/cli/utils/tasks.py +26 -0
- hud/clients/base.py +21 -23
- hud/clients/mcp_use.py +36 -44
- hud/clients/tests/test_mcp_use_retry.py +10 -10
- hud/datasets/__init__.py +4 -3
- hud/datasets/{execution/parallel.py → parallel.py} +1 -1
- hud/datasets/{execution/runner.py → runner.py} +1 -1
- hud/datasets/utils.py +1 -1
- hud/native/tests/test_native_init.py +1 -1
- hud/otel/config.py +1 -1
- hud/otel/instrumentation.py +35 -0
- hud/rl/README.md +31 -0
- hud/rl/__init__.py +1 -0
- hud/rl/actor.py +174 -0
- hud/rl/buffer.py +371 -0
- hud/rl/chat_template.jinja +101 -0
- hud/rl/config.py +184 -0
- hud/rl/distributed.py +95 -0
- hud/rl/learner.py +586 -0
- hud/rl/tests/__init__.py +1 -0
- hud/rl/tests/test_learner.py +171 -0
- hud/rl/train.py +354 -0
- hud/rl/types.py +101 -0
- hud/rl/utils/start_vllm_server.sh +30 -0
- hud/rl/utils.py +524 -0
- hud/rl/vllm_adapter.py +125 -0
- hud/settings.py +6 -0
- hud/telemetry/__init__.py +2 -1
- hud/telemetry/job.py +46 -3
- hud/telemetry/tests/test_trace.py +3 -3
- hud/telemetry/trace.py +85 -13
- hud/tools/tests/test_computer.py +3 -3
- hud/tools/tests/test_computer_actions.py +1 -1
- hud/types.py +123 -2
- hud/utils/group_eval.py +223 -0
- hud/utils/hud_console.py +113 -13
- hud/utils/tasks.py +119 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +1 -1
- {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/METADATA +20 -2
- {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/RECORD +66 -46
- hud/cli/hf.py +0 -406
- hud/cli/rl/README.md +0 -243
- hud/cli/rl/init.py +0 -370
- hud/cli/rl/pod.py +0 -501
- hud/cli/rl/ssh.py +0 -322
- hud/cli/rl/train.py +0 -562
- hud/cli/rl/utils.py +0 -165
- hud/datasets/execution/__init__.py +0 -13
- hud/datasets/task.py +0 -116
- {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/WHEEL +0 -0
- {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.28.dist-info → hud_python-0.4.29.dist-info}/licenses/LICENSE +0 -0
hud/__init__.py
CHANGED
|
@@ -5,9 +5,10 @@ tools for building, evaluating, and training AI agents.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from .telemetry import clear_trace, create_job, get_trace, instrument, job, trace
|
|
8
|
+
from .telemetry import Trace, clear_trace, create_job, get_trace, instrument, job, trace
|
|
9
9
|
|
|
10
10
|
__all__ = [
|
|
11
|
+
"Trace",
|
|
11
12
|
"clear_trace",
|
|
12
13
|
"create_job",
|
|
13
14
|
"get_trace",
|
hud/agents/base.py
CHANGED
|
@@ -45,7 +45,7 @@ class MCPAgent(ABC):
|
|
|
45
45
|
`format_blocks`, and `format_tool_results`.
|
|
46
46
|
"""
|
|
47
47
|
|
|
48
|
-
metadata: dict[str, Any]
|
|
48
|
+
metadata: dict[str, Any] | None = None
|
|
49
49
|
required_tools: ClassVar[list[str]] = [] # Tools that must be available
|
|
50
50
|
|
|
51
51
|
def __init__(
|
|
@@ -54,7 +54,6 @@ class MCPAgent(ABC):
|
|
|
54
54
|
# Filtering
|
|
55
55
|
allowed_tools: list[str] | None = None,
|
|
56
56
|
disallowed_tools: list[str] | None = None,
|
|
57
|
-
lifecycle_tools: list[str] | None = None,
|
|
58
57
|
# Messages
|
|
59
58
|
system_prompt: str = GLOBAL_SYSTEM_PROMPT,
|
|
60
59
|
append_setup_output: bool = True,
|
|
@@ -74,8 +73,6 @@ class MCPAgent(ABC):
|
|
|
74
73
|
that provides `mcp_config`.
|
|
75
74
|
allowed_tools: Names of tools to allow (None means allow all).
|
|
76
75
|
disallowed_tools: Names of tools to always exclude.
|
|
77
|
-
lifecycle_tools: Tools reserved for lifecycle phases (e.g., setup,
|
|
78
|
-
evaluate). These are hidden from normal tool calling.
|
|
79
76
|
system_prompt: System prompt to seed the conversation.
|
|
80
77
|
append_setup_output: Whether to append setup tool output to the
|
|
81
78
|
first turn's messages.
|
|
@@ -98,10 +95,13 @@ class MCPAgent(ABC):
|
|
|
98
95
|
if verbose:
|
|
99
96
|
self.console.set_verbose(True)
|
|
100
97
|
|
|
101
|
-
#
|
|
98
|
+
# User filtering
|
|
102
99
|
self.allowed_tools = allowed_tools
|
|
103
100
|
self.disallowed_tools = disallowed_tools or []
|
|
104
|
-
|
|
101
|
+
|
|
102
|
+
# Task filtering
|
|
103
|
+
self.agent_tools = None
|
|
104
|
+
self.lifecycle_tools = []
|
|
105
105
|
|
|
106
106
|
# Messages
|
|
107
107
|
self.system_prompt = system_prompt
|
|
@@ -112,7 +112,6 @@ class MCPAgent(ABC):
|
|
|
112
112
|
self._available_tools: list[types.Tool] = []
|
|
113
113
|
self._tool_map: dict[str, types.Tool] = {} # Simplified: just name to tool
|
|
114
114
|
self.response_tool_name = None
|
|
115
|
-
self.initialization_complete = False
|
|
116
115
|
|
|
117
116
|
# Trace
|
|
118
117
|
self._auto_trace = auto_trace
|
|
@@ -131,7 +130,7 @@ class MCPAgent(ABC):
|
|
|
131
130
|
|
|
132
131
|
self.mcp_client = MCPClient(mcp_config=task.mcp_config)
|
|
133
132
|
self._auto_created_client = True
|
|
134
|
-
self.console.
|
|
133
|
+
self.console.debug("Auto-created MCPClient from task.mcp_config")
|
|
135
134
|
|
|
136
135
|
# Ensure we have a client
|
|
137
136
|
if self.mcp_client is None:
|
|
@@ -149,17 +148,21 @@ class MCPAgent(ABC):
|
|
|
149
148
|
|
|
150
149
|
# If task is provided, add lifecycle tools
|
|
151
150
|
if isinstance(task, Task):
|
|
151
|
+
if task.agent_tools:
|
|
152
|
+
self.agent_tools = task.agent_tools
|
|
152
153
|
if task.setup_tool:
|
|
153
154
|
if isinstance(task.setup_tool, list):
|
|
154
155
|
for tool in task.setup_tool:
|
|
155
|
-
self.
|
|
156
|
-
|
|
156
|
+
if self.agent_tools and tool.name not in self.agent_tools:
|
|
157
|
+
self.lifecycle_tools.append(tool.name)
|
|
158
|
+
elif self.agent_tools and task.setup_tool.name not in self.agent_tools:
|
|
157
159
|
self.lifecycle_tools.append(task.setup_tool.name)
|
|
158
160
|
if task.evaluate_tool:
|
|
159
161
|
if isinstance(task.evaluate_tool, list):
|
|
160
162
|
for tool in task.evaluate_tool:
|
|
161
|
-
self.
|
|
162
|
-
|
|
163
|
+
if self.agent_tools and tool.name not in self.agent_tools:
|
|
164
|
+
self.lifecycle_tools.append(tool.name)
|
|
165
|
+
elif self.agent_tools and task.evaluate_tool.name not in self.agent_tools:
|
|
163
166
|
self.lifecycle_tools.append(task.evaluate_tool.name)
|
|
164
167
|
if task.system_prompt:
|
|
165
168
|
self.system_prompt += "\n\n" + task.system_prompt
|
|
@@ -167,11 +170,6 @@ class MCPAgent(ABC):
|
|
|
167
170
|
# Re-apply filtering with updated lifecycle tools
|
|
168
171
|
await self._filter_tools()
|
|
169
172
|
|
|
170
|
-
num_tools = len(self._available_tools)
|
|
171
|
-
self.console.success_log(
|
|
172
|
-
f"Agent initialized with {num_tools} available tools (after filtering)"
|
|
173
|
-
)
|
|
174
|
-
|
|
175
173
|
async def run(self, prompt_or_task: str | Task | dict[str, Any], max_steps: int = 10) -> Trace:
|
|
176
174
|
"""
|
|
177
175
|
Run the agent with the given prompt or task.
|
|
@@ -188,12 +186,12 @@ class MCPAgent(ABC):
|
|
|
188
186
|
|
|
189
187
|
if isinstance(prompt_or_task, dict):
|
|
190
188
|
prompt_or_task = Task(**prompt_or_task)
|
|
189
|
+
elif not isinstance(prompt_or_task, str) and not isinstance(prompt_or_task, Task):
|
|
190
|
+
raise TypeError(f"prompt_or_task must be str or Task, got {type(prompt_or_task)}")
|
|
191
191
|
|
|
192
192
|
try:
|
|
193
193
|
# Establish the connection with the MCP server/Environment
|
|
194
|
-
|
|
195
|
-
await self.initialize(prompt_or_task)
|
|
196
|
-
self.initialization_complete = True
|
|
194
|
+
await self.initialize(prompt_or_task)
|
|
197
195
|
|
|
198
196
|
# Handle Task objects with full lifecycle
|
|
199
197
|
if isinstance(prompt_or_task, Task):
|
|
@@ -204,8 +202,6 @@ class MCPAgent(ABC):
|
|
|
204
202
|
context = text_to_blocks(prompt_or_task)
|
|
205
203
|
return await self._run_context(context, max_steps=max_steps)
|
|
206
204
|
|
|
207
|
-
else:
|
|
208
|
-
raise TypeError(f"prompt_or_task must be str or Task, got {type(prompt_or_task)}")
|
|
209
205
|
except Exception as e:
|
|
210
206
|
# Always return a Trace object for any exception
|
|
211
207
|
if self._is_connection_error(e):
|
|
@@ -240,8 +236,6 @@ class MCPAgent(ABC):
|
|
|
240
236
|
Returns:
|
|
241
237
|
Trace with reward from evaluation
|
|
242
238
|
"""
|
|
243
|
-
prompt_result = None
|
|
244
|
-
|
|
245
239
|
try:
|
|
246
240
|
# Setup phase
|
|
247
241
|
start_context: list[types.ContentBlock] = []
|
|
@@ -255,7 +249,13 @@ class MCPAgent(ABC):
|
|
|
255
249
|
self.console.progress_log(f"Setting up tool phase: {task.setup_tool}")
|
|
256
250
|
results = await self.call_tools(task.setup_tool)
|
|
257
251
|
if any(result.isError for result in results):
|
|
258
|
-
|
|
252
|
+
return Trace(
|
|
253
|
+
reward=0.0,
|
|
254
|
+
done=True,
|
|
255
|
+
content=f"Setup tool failed: {results}",
|
|
256
|
+
isError=True,
|
|
257
|
+
task=task,
|
|
258
|
+
)
|
|
259
259
|
|
|
260
260
|
if self.append_setup_output and isinstance(results[0].content, list):
|
|
261
261
|
start_context.extend(results[0].content)
|
|
@@ -268,13 +268,12 @@ class MCPAgent(ABC):
|
|
|
268
268
|
except Exception as e:
|
|
269
269
|
self.console.error_log(f"Task execution failed: {e}")
|
|
270
270
|
# Create an error result but don't return yet - we still want to evaluate
|
|
271
|
-
prompt_result = Trace(reward=0.0, done=True, content=str(e), isError=True)
|
|
271
|
+
prompt_result = Trace(reward=0.0, done=True, content=str(e), isError=True, task=task)
|
|
272
272
|
prompt_result.populate_from_context()
|
|
273
273
|
|
|
274
274
|
# Always evaluate if we have evaluate tool, regardless of errors
|
|
275
275
|
if task.evaluate_tool is not None:
|
|
276
276
|
try:
|
|
277
|
-
self.console.progress_log(f"Evaluating tool phase: {task.evaluate_tool}")
|
|
278
277
|
results = await self.call_tools(task.evaluate_tool)
|
|
279
278
|
|
|
280
279
|
if any(result.isError for result in results):
|
|
@@ -286,18 +285,24 @@ class MCPAgent(ABC):
|
|
|
286
285
|
done=True,
|
|
287
286
|
content="Task failed before evaluation",
|
|
288
287
|
isError=True,
|
|
288
|
+
task=task,
|
|
289
289
|
)
|
|
290
290
|
prompt_result.reward = 0.0 # Default to 0 on error
|
|
291
291
|
else:
|
|
292
292
|
# Extract reward and content from evaluation
|
|
293
293
|
if results:
|
|
294
294
|
reward = find_reward(results[0])
|
|
295
|
+
self.console.info_log(f"Eval: {reward:.4f} {task.evaluate_tool}")
|
|
295
296
|
eval_content = find_content(results[0])
|
|
296
297
|
|
|
297
298
|
# Update the prompt result with evaluation reward
|
|
298
299
|
if prompt_result is None:
|
|
299
300
|
prompt_result = Trace(
|
|
300
|
-
reward=reward,
|
|
301
|
+
reward=reward,
|
|
302
|
+
done=True,
|
|
303
|
+
content=eval_content or "",
|
|
304
|
+
isError=False,
|
|
305
|
+
task=task,
|
|
301
306
|
)
|
|
302
307
|
else:
|
|
303
308
|
prompt_result.reward = reward
|
|
@@ -316,14 +321,16 @@ class MCPAgent(ABC):
|
|
|
316
321
|
# Ensure we have a result even if evaluation failed
|
|
317
322
|
if prompt_result is None:
|
|
318
323
|
prompt_result = Trace(
|
|
319
|
-
reward=0.0,
|
|
324
|
+
reward=0.0,
|
|
325
|
+
done=True,
|
|
326
|
+
content=f"Evaluation failed: {e}",
|
|
327
|
+
isError=True,
|
|
328
|
+
task=task,
|
|
320
329
|
)
|
|
321
330
|
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
else Trace(reward=0.0, done=True, content="No result available", isError=True)
|
|
326
|
-
)
|
|
331
|
+
prompt_result.task = task
|
|
332
|
+
|
|
333
|
+
return prompt_result
|
|
327
334
|
|
|
328
335
|
async def _run_context(
|
|
329
336
|
self, context: list[types.ContentBlock], *, max_steps: int = 10
|
|
@@ -388,7 +395,11 @@ class MCPAgent(ABC):
|
|
|
388
395
|
|
|
389
396
|
# 2. Execute tools
|
|
390
397
|
tool_calls = response.tool_calls
|
|
398
|
+
for tool_call in tool_calls:
|
|
399
|
+
self.console.info_log(f"{tool_call}")
|
|
391
400
|
tool_results = await self.call_tools(tool_calls)
|
|
401
|
+
for tool_result in tool_results:
|
|
402
|
+
self.console.info_log(f"{tool_result}")
|
|
392
403
|
|
|
393
404
|
# 3. Format tool results and add to messages
|
|
394
405
|
tool_messages = await self.format_tool_results(tool_calls, tool_results)
|
|
@@ -422,13 +433,23 @@ class MCPAgent(ABC):
|
|
|
422
433
|
error = str(e)
|
|
423
434
|
|
|
424
435
|
# Build result
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
436
|
+
if error is not None or (
|
|
437
|
+
final_response and hasattr(final_response, "isError") and final_response.isError
|
|
438
|
+
):
|
|
439
|
+
is_error = True
|
|
440
|
+
else:
|
|
441
|
+
is_error = False
|
|
442
|
+
|
|
443
|
+
# Ensure all parameters are the correct type
|
|
444
|
+
trace_params = {
|
|
445
|
+
"reward": 0.0,
|
|
446
|
+
"done": True,
|
|
447
|
+
"messages": messages,
|
|
448
|
+
"content": final_response.content if final_response else error,
|
|
449
|
+
"isError": is_error,
|
|
450
|
+
"info": {"error": error} if error else {},
|
|
451
|
+
}
|
|
452
|
+
trace_result = Trace(**trace_params)
|
|
432
453
|
|
|
433
454
|
# Populate trace steps from current context
|
|
434
455
|
trace_result.populate_from_context()
|
|
@@ -474,16 +495,14 @@ class MCPAgent(ABC):
|
|
|
474
495
|
return results
|
|
475
496
|
|
|
476
497
|
@abstractmethod
|
|
477
|
-
async def get_system_messages(self) -> list[
|
|
498
|
+
async def get_system_messages(self) -> list[types.ContentBlock]:
|
|
478
499
|
"""
|
|
479
500
|
Get the system prompt.
|
|
480
501
|
"""
|
|
481
502
|
raise NotImplementedError
|
|
482
503
|
|
|
483
504
|
@abstractmethod
|
|
484
|
-
async def get_response(
|
|
485
|
-
self, messages: list[Any]
|
|
486
|
-
) -> AgentResponse: # maybe type messages as list[types.ContentBlock]
|
|
505
|
+
async def get_response(self, messages: list[Any]) -> AgentResponse:
|
|
487
506
|
"""
|
|
488
507
|
Get response from the model including any tool calls.
|
|
489
508
|
|
|
@@ -607,6 +626,7 @@ class MCPAgent(ABC):
|
|
|
607
626
|
|
|
608
627
|
self.console.debug(f"All tools: {[t.name for t in all_tools]}")
|
|
609
628
|
self.console.debug(f"Allowed tools: {self.allowed_tools}")
|
|
629
|
+
self.console.debug(f"Agent tools: {self.agent_tools}")
|
|
610
630
|
self.console.debug(f"Disallowed tools: {self.disallowed_tools}")
|
|
611
631
|
self.console.debug(f"Lifecycle tools: {self.lifecycle_tools}")
|
|
612
632
|
|
|
@@ -619,6 +639,9 @@ class MCPAgent(ABC):
|
|
|
619
639
|
if self.allowed_tools and tool.name not in self.allowed_tools:
|
|
620
640
|
self.console.debug(f"Skipping tool '{tool.name}' - not in allowed_tools")
|
|
621
641
|
continue
|
|
642
|
+
if self.agent_tools and tool.name not in self.agent_tools:
|
|
643
|
+
self.console.debug(f"Skipping tool '{tool.name}' - not in agent_tools")
|
|
644
|
+
continue
|
|
622
645
|
if tool.name in self.disallowed_tools:
|
|
623
646
|
self.console.debug(f"Skipping tool '{tool.name}' - in disallowed_tools")
|
|
624
647
|
continue
|
|
@@ -641,6 +664,11 @@ class MCPAgent(ABC):
|
|
|
641
664
|
f"Available tools: {list(available_tool_names)}"
|
|
642
665
|
)
|
|
643
666
|
|
|
667
|
+
available_tools = self.get_available_tools()
|
|
668
|
+
self.console.info(
|
|
669
|
+
f"Agent initialized with {len(available_tools)} tools: {', '.join([t.name for t in available_tools])}" # noqa: E501
|
|
670
|
+
)
|
|
671
|
+
|
|
644
672
|
async def _maybe_submit_response(self, response: AgentResponse, messages: list[Any]) -> None:
|
|
645
673
|
"""Submit response through lifecycle tool if available.
|
|
646
674
|
|
hud/agents/claude.py
CHANGED
|
@@ -28,6 +28,7 @@ import mcp.types as types
|
|
|
28
28
|
from hud.settings import settings
|
|
29
29
|
from hud.tools.computer.settings import computer_settings
|
|
30
30
|
from hud.types import AgentResponse, MCPToolCall, MCPToolResult
|
|
31
|
+
from hud.utils.hud_console import HUDConsole
|
|
31
32
|
|
|
32
33
|
from .base import MCPAgent
|
|
33
34
|
|
|
@@ -78,6 +79,7 @@ class ClaudeAgent(MCPAgent):
|
|
|
78
79
|
self.model = model
|
|
79
80
|
self.max_tokens = max_tokens
|
|
80
81
|
self.use_computer_beta = use_computer_beta
|
|
82
|
+
self.hud_console = HUDConsole(logger=logger)
|
|
81
83
|
|
|
82
84
|
self.model_name = self.model
|
|
83
85
|
|
|
@@ -149,7 +151,7 @@ class ClaudeAgent(MCPAgent):
|
|
|
149
151
|
)
|
|
150
152
|
else:
|
|
151
153
|
# For other types, try to cast but log a warning
|
|
152
|
-
|
|
154
|
+
self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning")
|
|
153
155
|
anthropic_blocks.append(cast("BetaContentBlockParam", block))
|
|
154
156
|
|
|
155
157
|
return [
|
|
@@ -201,7 +203,7 @@ class ClaudeAgent(MCPAgent):
|
|
|
201
203
|
or "request_too_large" in str(e)
|
|
202
204
|
or e.status_code == 413
|
|
203
205
|
):
|
|
204
|
-
|
|
206
|
+
self.hud_console.warning("Prompt too long, truncating message history")
|
|
205
207
|
# Keep first message and last 20 messages
|
|
206
208
|
if len(current_messages) > 21:
|
|
207
209
|
current_messages = [current_messages[0], *current_messages[-20:]]
|
|
@@ -266,7 +268,7 @@ class ClaudeAgent(MCPAgent):
|
|
|
266
268
|
# Extract Claude-specific metadata from extra fields
|
|
267
269
|
tool_use_id = tool_call.id
|
|
268
270
|
if not tool_use_id:
|
|
269
|
-
|
|
271
|
+
self.hud_console.warning(f"No tool_use_id found for {tool_call.name}")
|
|
270
272
|
continue
|
|
271
273
|
|
|
272
274
|
# Convert MCP tool results to Claude format
|
|
@@ -335,7 +337,9 @@ class ClaudeAgent(MCPAgent):
|
|
|
335
337
|
# Map Claude's "computer" back to the actual MCP tool name
|
|
336
338
|
self._claude_to_mcp_tool_map["computer"] = selected_computer_tool.name
|
|
337
339
|
claude_tools.append(claude_tool)
|
|
338
|
-
|
|
340
|
+
self.hud_console.debug(
|
|
341
|
+
f"Using {selected_computer_tool.name} as computer tool for Claude"
|
|
342
|
+
)
|
|
339
343
|
|
|
340
344
|
# Add other non-computer tools
|
|
341
345
|
for tool in self._available_tools:
|
|
@@ -23,6 +23,7 @@ import mcp.types as types
|
|
|
23
23
|
|
|
24
24
|
from hud import instrument
|
|
25
25
|
from hud.types import AgentResponse, MCPToolCall, MCPToolResult
|
|
26
|
+
from hud.utils.hud_console import HUDConsole
|
|
26
27
|
|
|
27
28
|
from .base import MCPAgent
|
|
28
29
|
|
|
@@ -43,7 +44,6 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
43
44
|
*,
|
|
44
45
|
openai_client: AsyncOpenAI,
|
|
45
46
|
model_name: str = "gpt-4o-mini",
|
|
46
|
-
parallel_tool_calls: bool = False,
|
|
47
47
|
completion_kwargs: dict[str, Any] | None = None,
|
|
48
48
|
**agent_kwargs: Any,
|
|
49
49
|
) -> None:
|
|
@@ -51,17 +51,22 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
51
51
|
super().__init__(**agent_kwargs)
|
|
52
52
|
self.oai = openai_client
|
|
53
53
|
self.model_name = model_name
|
|
54
|
-
self.parallel_tool_calls = parallel_tool_calls
|
|
55
54
|
self.completion_kwargs: dict[str, Any] = completion_kwargs or {}
|
|
56
|
-
self.
|
|
55
|
+
self.mcp_schemas = []
|
|
56
|
+
self.hud_console = HUDConsole(logger=logger)
|
|
57
57
|
|
|
58
58
|
@staticmethod
|
|
59
59
|
def _oai_to_mcp(tool_call: Any) -> MCPToolCall: # type: ignore[valid-type]
|
|
60
60
|
"""Convert an OpenAI ``tool_call`` to :class:`MCPToolCall`."""
|
|
61
|
+
args = json.loads(tool_call.function.arguments or "{}")
|
|
62
|
+
if isinstance(args, list):
|
|
63
|
+
args = args[0]
|
|
64
|
+
if not isinstance(args, dict):
|
|
65
|
+
args = {}
|
|
61
66
|
return MCPToolCall(
|
|
62
67
|
id=tool_call.id,
|
|
63
68
|
name=tool_call.function.name,
|
|
64
|
-
arguments=
|
|
69
|
+
arguments=args,
|
|
65
70
|
)
|
|
66
71
|
|
|
67
72
|
async def get_system_messages(self) -> list[Any]:
|
|
@@ -177,45 +182,65 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
177
182
|
# Convert MCP tool schemas to OpenAI format
|
|
178
183
|
mcp_schemas = self.get_tool_schemas()
|
|
179
184
|
|
|
180
|
-
protected_keys = {"model", "messages", "tools"
|
|
185
|
+
protected_keys = {"model", "messages", "tools"}
|
|
181
186
|
extra = {k: v for k, v in (self.completion_kwargs or {}).items() if k not in protected_keys}
|
|
182
187
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
188
|
+
try:
|
|
189
|
+
response = await self.oai.chat.completions.create(
|
|
190
|
+
model=self.model_name,
|
|
191
|
+
messages=messages,
|
|
192
|
+
tools=cast("list[ChatCompletionToolParam]", mcp_schemas),
|
|
193
|
+
**extra,
|
|
194
|
+
)
|
|
195
|
+
except Exception as e:
|
|
196
|
+
error_content = f"Error getting response {e}"
|
|
197
|
+
if "Invalid JSON" in str(e):
|
|
198
|
+
error_content = "Invalid JSON, response was truncated"
|
|
199
|
+
self.hud_console.warning_log(error_content)
|
|
200
|
+
|
|
201
|
+
return AgentResponse(
|
|
202
|
+
content=error_content,
|
|
203
|
+
tool_calls=[],
|
|
204
|
+
done=True,
|
|
205
|
+
isError=True,
|
|
206
|
+
raw=None,
|
|
207
|
+
)
|
|
190
208
|
|
|
191
209
|
choice = response.choices[0]
|
|
192
210
|
msg = choice.message
|
|
193
|
-
|
|
194
211
|
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
|
195
212
|
|
|
196
213
|
if msg.content:
|
|
197
214
|
assistant_msg["content"] = msg.content
|
|
198
215
|
|
|
199
216
|
if msg.tool_calls:
|
|
200
|
-
|
|
217
|
+
serialized_tool_calls = []
|
|
218
|
+
for tc in msg.tool_calls:
|
|
219
|
+
serialized_tc = {
|
|
220
|
+
"id": tc.id,
|
|
221
|
+
"type": "function",
|
|
222
|
+
"function": {"name": tc.function.name, "arguments": tc.function.arguments},
|
|
223
|
+
}
|
|
224
|
+
serialized_tool_calls.append(serialized_tc)
|
|
225
|
+
assistant_msg["tool_calls"] = serialized_tool_calls
|
|
201
226
|
|
|
202
227
|
messages.append(assistant_msg)
|
|
203
228
|
|
|
204
|
-
# Store the complete conversation history
|
|
205
|
-
self.conversation_history = messages.copy()
|
|
206
|
-
|
|
207
229
|
tool_calls = []
|
|
208
230
|
if msg.tool_calls:
|
|
209
231
|
for tc in msg.tool_calls:
|
|
210
232
|
if tc.function.name is not None: # type: ignore
|
|
211
|
-
tool_calls.
|
|
212
|
-
|
|
213
|
-
|
|
233
|
+
tool_calls.extend(self._oai_to_mcp(tc))
|
|
234
|
+
|
|
235
|
+
# Only stop on length (token limit), never on "stop"
|
|
236
|
+
done = choice.finish_reason == "length"
|
|
237
|
+
if done:
|
|
238
|
+
self.hud_console.info_log(f"Done decision: finish_reason={choice.finish_reason}")
|
|
214
239
|
|
|
215
240
|
return AgentResponse(
|
|
216
241
|
content=msg.content or "",
|
|
217
242
|
tool_calls=tool_calls,
|
|
218
|
-
done=
|
|
243
|
+
done=done,
|
|
219
244
|
raw=response, # Include raw response for access to Choice objects
|
|
220
245
|
)
|
|
221
246
|
|
|
@@ -230,15 +255,15 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
230
255
|
When images are present, we return both a tool message and a user message.
|
|
231
256
|
"""
|
|
232
257
|
rendered: list[dict[str, Any]] = []
|
|
258
|
+
|
|
259
|
+
# Separate text and image content
|
|
260
|
+
image_parts = []
|
|
233
261
|
for call, res in zip(tool_calls, tool_results, strict=False):
|
|
234
262
|
# Use structuredContent.result if available, otherwise use content
|
|
235
|
-
items = res.content
|
|
236
|
-
if res.structuredContent and isinstance(res.structuredContent, dict):
|
|
237
|
-
items = res.structuredContent.get("result", res.content)
|
|
238
|
-
|
|
239
|
-
# Separate text and image content
|
|
240
263
|
text_parts = []
|
|
241
|
-
|
|
264
|
+
items = res.content
|
|
265
|
+
if not res.content and res.structuredContent:
|
|
266
|
+
items = [res.structuredContent.get("result", res.content)]
|
|
242
267
|
|
|
243
268
|
for item in items:
|
|
244
269
|
if isinstance(item, dict):
|
|
@@ -272,18 +297,18 @@ class GenericOpenAIChatAgent(MCPAgent):
|
|
|
272
297
|
}
|
|
273
298
|
)
|
|
274
299
|
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
300
|
+
# If there are images, add them as a separate user message
|
|
301
|
+
if image_parts:
|
|
302
|
+
# Add a user message with the images
|
|
303
|
+
content_with_images = [
|
|
304
|
+
{"type": "text", "text": "Tool returned the following:"},
|
|
305
|
+
image_parts[-1],
|
|
306
|
+
]
|
|
307
|
+
rendered.append(
|
|
308
|
+
{
|
|
309
|
+
"role": "user",
|
|
310
|
+
"content": content_with_images,
|
|
311
|
+
}
|
|
312
|
+
)
|
|
288
313
|
|
|
289
314
|
return rendered
|
hud/agents/tests/test_base.py
CHANGED
|
@@ -97,7 +97,6 @@ class TestBaseMCPAgent:
|
|
|
97
97
|
assert agent.disallowed_tools == []
|
|
98
98
|
assert agent.initial_screenshot is True
|
|
99
99
|
assert agent.system_prompt is not None # Default system prompt is set
|
|
100
|
-
assert agent.lifecycle_tools == []
|
|
101
100
|
|
|
102
101
|
def test_init_with_params(self):
|
|
103
102
|
"""Test initialization with custom parameters."""
|
|
@@ -108,7 +107,6 @@ class TestBaseMCPAgent:
|
|
|
108
107
|
disallowed_tools=["bad_tool"],
|
|
109
108
|
initial_screenshot=True,
|
|
110
109
|
system_prompt="Custom prompt",
|
|
111
|
-
lifecycle_tools=["custom_setup", "custom_eval"],
|
|
112
110
|
)
|
|
113
111
|
|
|
114
112
|
assert agent.mcp_client == client
|
|
@@ -116,7 +114,6 @@ class TestBaseMCPAgent:
|
|
|
116
114
|
assert agent.disallowed_tools == ["bad_tool"]
|
|
117
115
|
assert agent.initial_screenshot is True
|
|
118
116
|
assert agent.system_prompt == "Custom prompt"
|
|
119
|
-
assert agent.lifecycle_tools == ["custom_setup", "custom_eval"]
|
|
120
117
|
|
|
121
118
|
@pytest.mark.asyncio
|
|
122
119
|
async def test_init_no_client_no_task(self):
|
|
@@ -631,7 +628,6 @@ class TestMCPAgentExtended:
|
|
|
631
628
|
# Lifecycle tools are specified by name, not as objects
|
|
632
629
|
agent = MockAgentExtended(
|
|
633
630
|
mcp_client=mock_client,
|
|
634
|
-
lifecycle_tools=["screenshot"], # Use tool name
|
|
635
631
|
responses=[{"role": "assistant", "content": "Done", "tool_calls": []}],
|
|
636
632
|
)
|
|
637
633
|
|
hud/agents/tests/test_openai.py
CHANGED
|
@@ -156,7 +156,7 @@ class TestOperatorAgent:
|
|
|
156
156
|
messages = [{"prompt": "What's on the screen?", "screenshot": None}]
|
|
157
157
|
response = await agent.get_response(messages)
|
|
158
158
|
|
|
159
|
-
assert response.content == "I can see the screen content."
|
|
159
|
+
assert response.content[0].text == "I can see the screen content."
|
|
160
160
|
assert response.done is True
|
|
161
161
|
|
|
162
162
|
@pytest.mark.asyncio
|