hud-python 0.4.45__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to their respective public registries, and is provided for informational purposes only.
- hud/__init__.py +27 -7
- hud/agents/__init__.py +11 -5
- hud/agents/base.py +220 -500
- hud/agents/claude.py +200 -240
- hud/agents/gemini.py +275 -0
- hud/agents/gemini_cua.py +335 -0
- hud/agents/grounded_openai.py +98 -100
- hud/agents/misc/integration_test_agent.py +51 -20
- hud/agents/misc/response_agent.py +41 -36
- hud/agents/openai.py +291 -292
- hud/agents/{openai_chat_generic.py → openai_chat.py} +80 -34
- hud/agents/operator.py +211 -0
- hud/agents/tests/conftest.py +133 -0
- hud/agents/tests/test_base.py +300 -622
- hud/agents/tests/test_base_runtime.py +233 -0
- hud/agents/tests/test_claude.py +379 -210
- hud/agents/tests/test_client.py +9 -10
- hud/agents/tests/test_gemini.py +369 -0
- hud/agents/tests/test_grounded_openai_agent.py +65 -50
- hud/agents/tests/test_openai.py +376 -140
- hud/agents/tests/test_operator.py +362 -0
- hud/agents/tests/test_run_eval.py +179 -0
- hud/cli/__init__.py +461 -545
- hud/cli/analyze.py +43 -5
- hud/cli/build.py +664 -110
- hud/cli/debug.py +8 -5
- hud/cli/dev.py +882 -734
- hud/cli/eval.py +782 -668
- hud/cli/flows/dev.py +167 -0
- hud/cli/flows/init.py +191 -0
- hud/cli/flows/tasks.py +153 -56
- hud/cli/flows/templates.py +151 -0
- hud/cli/flows/tests/__init__.py +1 -0
- hud/cli/flows/tests/test_dev.py +126 -0
- hud/cli/init.py +60 -58
- hud/cli/push.py +29 -11
- hud/cli/rft.py +311 -0
- hud/cli/rft_status.py +145 -0
- hud/cli/tests/test_analyze.py +5 -5
- hud/cli/tests/test_analyze_metadata.py +3 -2
- hud/cli/tests/test_analyze_module.py +120 -0
- hud/cli/tests/test_build.py +108 -6
- hud/cli/tests/test_build_failure.py +41 -0
- hud/cli/tests/test_build_module.py +50 -0
- hud/cli/tests/test_cli_init.py +6 -1
- hud/cli/tests/test_cli_more_wrappers.py +30 -0
- hud/cli/tests/test_cli_root.py +140 -0
- hud/cli/tests/test_convert.py +361 -0
- hud/cli/tests/test_debug.py +12 -10
- hud/cli/tests/test_dev.py +197 -0
- hud/cli/tests/test_eval.py +251 -0
- hud/cli/tests/test_eval_bedrock.py +51 -0
- hud/cli/tests/test_init.py +124 -0
- hud/cli/tests/test_main_module.py +11 -5
- hud/cli/tests/test_mcp_server.py +12 -100
- hud/cli/tests/test_push_happy.py +74 -0
- hud/cli/tests/test_push_wrapper.py +23 -0
- hud/cli/tests/test_registry.py +1 -1
- hud/cli/tests/test_utils.py +1 -1
- hud/cli/{rl → utils}/celebrate.py +14 -12
- hud/cli/utils/config.py +18 -1
- hud/cli/utils/docker.py +130 -4
- hud/cli/utils/env_check.py +9 -9
- hud/cli/utils/git.py +136 -0
- hud/cli/utils/interactive.py +39 -5
- hud/cli/utils/metadata.py +69 -0
- hud/cli/utils/runner.py +1 -1
- hud/cli/utils/server.py +2 -2
- hud/cli/utils/source_hash.py +3 -3
- hud/cli/utils/tasks.py +4 -1
- hud/cli/utils/tests/__init__.py +0 -0
- hud/cli/utils/tests/test_config.py +58 -0
- hud/cli/utils/tests/test_docker.py +93 -0
- hud/cli/utils/tests/test_docker_hints.py +71 -0
- hud/cli/utils/tests/test_env_check.py +74 -0
- hud/cli/utils/tests/test_environment.py +42 -0
- hud/cli/utils/tests/test_git.py +142 -0
- hud/cli/utils/tests/test_interactive_module.py +60 -0
- hud/cli/utils/tests/test_local_runner.py +50 -0
- hud/cli/utils/tests/test_logging_utils.py +23 -0
- hud/cli/utils/tests/test_metadata.py +49 -0
- hud/cli/utils/tests/test_package_runner.py +35 -0
- hud/cli/utils/tests/test_registry_utils.py +49 -0
- hud/cli/utils/tests/test_remote_runner.py +25 -0
- hud/cli/utils/tests/test_runner_modules.py +52 -0
- hud/cli/utils/tests/test_source_hash.py +36 -0
- hud/cli/utils/tests/test_tasks.py +80 -0
- hud/cli/utils/version_check.py +258 -0
- hud/cli/{rl → utils}/viewer.py +2 -2
- hud/clients/README.md +12 -11
- hud/clients/__init__.py +4 -3
- hud/clients/base.py +166 -26
- hud/clients/environment.py +51 -0
- hud/clients/fastmcp.py +13 -6
- hud/clients/mcp_use.py +40 -15
- hud/clients/tests/test_analyze_scenarios.py +206 -0
- hud/clients/tests/test_protocol.py +9 -3
- hud/datasets/__init__.py +23 -20
- hud/datasets/loader.py +327 -0
- hud/datasets/runner.py +192 -105
- hud/datasets/tests/__init__.py +0 -0
- hud/datasets/tests/test_loader.py +221 -0
- hud/datasets/tests/test_utils.py +315 -0
- hud/datasets/utils.py +270 -90
- hud/environment/__init__.py +50 -0
- hud/environment/connection.py +206 -0
- hud/environment/connectors/__init__.py +33 -0
- hud/environment/connectors/base.py +68 -0
- hud/environment/connectors/local.py +177 -0
- hud/environment/connectors/mcp_config.py +109 -0
- hud/environment/connectors/openai.py +101 -0
- hud/environment/connectors/remote.py +172 -0
- hud/environment/environment.py +694 -0
- hud/environment/integrations/__init__.py +45 -0
- hud/environment/integrations/adk.py +67 -0
- hud/environment/integrations/anthropic.py +196 -0
- hud/environment/integrations/gemini.py +92 -0
- hud/environment/integrations/langchain.py +82 -0
- hud/environment/integrations/llamaindex.py +68 -0
- hud/environment/integrations/openai.py +238 -0
- hud/environment/mock.py +306 -0
- hud/environment/router.py +112 -0
- hud/environment/scenarios.py +493 -0
- hud/environment/tests/__init__.py +1 -0
- hud/environment/tests/test_connection.py +317 -0
- hud/environment/tests/test_connectors.py +218 -0
- hud/environment/tests/test_environment.py +161 -0
- hud/environment/tests/test_integrations.py +257 -0
- hud/environment/tests/test_local_connectors.py +201 -0
- hud/environment/tests/test_scenarios.py +280 -0
- hud/environment/tests/test_tools.py +208 -0
- hud/environment/types.py +23 -0
- hud/environment/utils/__init__.py +35 -0
- hud/environment/utils/formats.py +215 -0
- hud/environment/utils/schema.py +171 -0
- hud/environment/utils/tool_wrappers.py +113 -0
- hud/eval/__init__.py +67 -0
- hud/eval/context.py +674 -0
- hud/eval/display.py +299 -0
- hud/eval/instrument.py +185 -0
- hud/eval/manager.py +466 -0
- hud/eval/parallel.py +268 -0
- hud/eval/task.py +340 -0
- hud/eval/tests/__init__.py +1 -0
- hud/eval/tests/test_context.py +178 -0
- hud/eval/tests/test_eval.py +210 -0
- hud/eval/tests/test_manager.py +152 -0
- hud/eval/tests/test_parallel.py +168 -0
- hud/eval/tests/test_task.py +145 -0
- hud/eval/types.py +63 -0
- hud/eval/utils.py +183 -0
- hud/patches/__init__.py +19 -0
- hud/patches/mcp_patches.py +151 -0
- hud/patches/warnings.py +54 -0
- hud/samples/browser.py +4 -4
- hud/server/__init__.py +2 -1
- hud/server/low_level.py +2 -1
- hud/server/router.py +164 -0
- hud/server/server.py +567 -80
- hud/server/tests/test_mcp_server_integration.py +11 -11
- hud/server/tests/test_mcp_server_more.py +1 -1
- hud/server/tests/test_server_extra.py +2 -0
- hud/settings.py +45 -3
- hud/shared/exceptions.py +36 -10
- hud/shared/hints.py +26 -1
- hud/shared/requests.py +15 -3
- hud/shared/tests/test_exceptions.py +40 -31
- hud/shared/tests/test_hints.py +167 -0
- hud/telemetry/__init__.py +20 -19
- hud/telemetry/exporter.py +201 -0
- hud/telemetry/instrument.py +158 -253
- hud/telemetry/tests/test_eval_telemetry.py +356 -0
- hud/telemetry/tests/test_exporter.py +258 -0
- hud/telemetry/tests/test_instrument.py +401 -0
- hud/tools/__init__.py +16 -2
- hud/tools/apply_patch.py +639 -0
- hud/tools/base.py +54 -4
- hud/tools/bash.py +2 -2
- hud/tools/computer/__init__.py +4 -0
- hud/tools/computer/anthropic.py +2 -2
- hud/tools/computer/gemini.py +385 -0
- hud/tools/computer/hud.py +23 -6
- hud/tools/computer/openai.py +20 -21
- hud/tools/computer/qwen.py +434 -0
- hud/tools/computer/settings.py +37 -0
- hud/tools/edit.py +3 -7
- hud/tools/executors/base.py +4 -2
- hud/tools/executors/pyautogui.py +1 -1
- hud/tools/grounding/grounded_tool.py +13 -18
- hud/tools/grounding/grounder.py +10 -31
- hud/tools/grounding/tests/test_grounded_tool.py +26 -44
- hud/tools/jupyter.py +330 -0
- hud/tools/playwright.py +18 -3
- hud/tools/shell.py +308 -0
- hud/tools/tests/test_apply_patch.py +718 -0
- hud/tools/tests/test_computer.py +4 -9
- hud/tools/tests/test_computer_actions.py +24 -2
- hud/tools/tests/test_jupyter_tool.py +181 -0
- hud/tools/tests/test_shell.py +596 -0
- hud/tools/tests/test_submit.py +85 -0
- hud/tools/tests/test_types.py +193 -0
- hud/tools/types.py +21 -1
- hud/types.py +167 -57
- hud/utils/__init__.py +2 -0
- hud/utils/env.py +67 -0
- hud/utils/hud_console.py +61 -3
- hud/utils/mcp.py +15 -58
- hud/utils/strict_schema.py +162 -0
- hud/utils/tests/test_init.py +1 -2
- hud/utils/tests/test_mcp.py +1 -28
- hud/utils/tests/test_pretty_errors.py +186 -0
- hud/utils/tests/test_tool_shorthand.py +154 -0
- hud/utils/tests/test_version.py +1 -1
- hud/utils/types.py +20 -0
- hud/version.py +1 -1
- hud_python-0.5.1.dist-info/METADATA +264 -0
- hud_python-0.5.1.dist-info/RECORD +299 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/WHEEL +1 -1
- hud/agents/langchain.py +0 -261
- hud/agents/lite_llm.py +0 -72
- hud/cli/rl/__init__.py +0 -180
- hud/cli/rl/config.py +0 -101
- hud/cli/rl/display.py +0 -133
- hud/cli/rl/gpu.py +0 -63
- hud/cli/rl/gpu_utils.py +0 -321
- hud/cli/rl/local_runner.py +0 -595
- hud/cli/rl/presets.py +0 -96
- hud/cli/rl/remote_runner.py +0 -463
- hud/cli/rl/rl_api.py +0 -150
- hud/cli/rl/vllm.py +0 -177
- hud/cli/rl/wait_utils.py +0 -89
- hud/datasets/parallel.py +0 -687
- hud/misc/__init__.py +0 -1
- hud/misc/claude_plays_pokemon.py +0 -292
- hud/otel/__init__.py +0 -35
- hud/otel/collector.py +0 -142
- hud/otel/config.py +0 -181
- hud/otel/context.py +0 -570
- hud/otel/exporters.py +0 -369
- hud/otel/instrumentation.py +0 -135
- hud/otel/processors.py +0 -121
- hud/otel/tests/__init__.py +0 -1
- hud/otel/tests/test_processors.py +0 -197
- hud/rl/README.md +0 -30
- hud/rl/__init__.py +0 -1
- hud/rl/actor.py +0 -176
- hud/rl/buffer.py +0 -405
- hud/rl/chat_template.jinja +0 -101
- hud/rl/config.py +0 -192
- hud/rl/distributed.py +0 -132
- hud/rl/learner.py +0 -637
- hud/rl/tests/__init__.py +0 -1
- hud/rl/tests/test_learner.py +0 -186
- hud/rl/train.py +0 -382
- hud/rl/types.py +0 -101
- hud/rl/utils/start_vllm_server.sh +0 -30
- hud/rl/utils.py +0 -524
- hud/rl/vllm_adapter.py +0 -143
- hud/telemetry/job.py +0 -352
- hud/telemetry/replay.py +0 -74
- hud/telemetry/tests/test_replay.py +0 -40
- hud/telemetry/tests/test_trace.py +0 -63
- hud/telemetry/trace.py +0 -158
- hud/utils/agent_factories.py +0 -86
- hud/utils/async_utils.py +0 -65
- hud/utils/group_eval.py +0 -223
- hud/utils/progress.py +0 -149
- hud/utils/tasks.py +0 -127
- hud/utils/tests/test_async_utils.py +0 -173
- hud/utils/tests/test_progress.py +0 -261
- hud_python-0.4.45.dist-info/METADATA +0 -552
- hud_python-0.4.45.dist-info/RECORD +0 -228
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/entry_points.txt +0 -0
- {hud_python-0.4.45.dist-info → hud_python-0.5.1.dist-info}/licenses/LICENSE +0 -0
hud/agents/gemini.py
ADDED
@@ -0,0 +1,275 @@
"""Gemini MCP Agent implementation."""

from __future__ import annotations

import logging
from typing import Any, ClassVar, cast

import mcp.types as types
from google import genai
from google.genai import types as genai_types
from pydantic import ConfigDict

from hud.settings import settings
from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult
from hud.utils.hud_console import HUDConsole
from hud.utils.types import with_signature

from .base import BaseCreateParams, MCPAgent

logger = logging.getLogger(__name__)


class GeminiConfig(BaseAgentConfig):
    """Configuration for `GeminiAgent`."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    model_name: str = "Gemini"
    model: str = "gemini-3-pro-preview"
    model_client: genai.Client | None = None
    temperature: float = 1.0
    top_p: float = 0.95
    top_k: int = 40
    max_output_tokens: int = 8192
    validate_api_key: bool = True


class GeminiCreateParams(BaseCreateParams, GeminiConfig):
    pass


class GeminiAgent(MCPAgent):
    """
    Gemini agent that uses MCP servers for tool execution.

    This agent uses Gemini's native tool calling capabilities but executes
    tools through MCP servers instead of direct implementation.
    """

    metadata: ClassVar[dict[str, Any] | None] = None
    config_cls: ClassVar[type[BaseAgentConfig]] = GeminiConfig

    @with_signature(GeminiCreateParams)
    @classmethod
    def create(cls, **kwargs: Any) -> GeminiAgent:  # pyright: ignore[reportIncompatibleMethodOverride]
        return MCPAgent.create.__func__(cls, **kwargs)  # type: ignore[return-value]

    def __init__(self, params: GeminiCreateParams | None = None, **kwargs: Any) -> None:
        super().__init__(params, **kwargs)
        self.config: GeminiConfig

        model_client = self.config.model_client
        if model_client is None:
            api_key = settings.gemini_api_key
            if not api_key:
                raise ValueError("Gemini API key not found. Set GEMINI_API_KEY.")
            model_client = genai.Client(api_key=api_key)

        if self.config.validate_api_key:
            try:
                list(model_client.models.list(config=genai_types.ListModelsConfig(page_size=1)))
            except Exception as e:
                raise ValueError(f"Gemini API key is invalid: {e}") from e

        self.gemini_client = model_client
        self.temperature = self.config.temperature
        self.top_p = self.config.top_p
        self.top_k = self.config.top_k
        self.max_output_tokens = self.config.max_output_tokens
        self.hud_console = HUDConsole(logger=logger)

        # Track mapping from Gemini tool names to MCP tool names
        self._gemini_to_mcp_tool_map: dict[str, str] = {}
        self.gemini_tools: genai_types.ToolListUnion = []

    def _on_tools_ready(self) -> None:
        """Build Gemini-specific tool mappings after tools are discovered."""
        self._convert_tools_for_gemini()

    async def get_system_messages(self) -> list[genai_types.Content]:
        """No system messages for Gemini because applied in get_response"""
        return []

    async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[genai_types.Content]:
        """Format messages for Gemini."""
        # Convert MCP content types to Gemini content types
        gemini_parts: list[genai_types.Part] = []

        for block in blocks:
            if isinstance(block, types.TextContent):
                gemini_parts.append(genai_types.Part(text=block.text))
            elif isinstance(block, types.ImageContent):
                # Convert MCP ImageContent to Gemini format
                # Need to decode base64 string to bytes
                import base64

                image_bytes = base64.b64decode(block.data)
                gemini_parts.append(
                    genai_types.Part.from_bytes(data=image_bytes, mime_type=block.mimeType)
                )
            else:
                # For other types, try to handle but log a warning
                self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning")

        return [genai_types.Content(role="user", parts=gemini_parts)]

    async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse:
        """Get response from Gemini including any tool calls."""
        # Build generate content config
        generate_config = genai_types.GenerateContentConfig(
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            max_output_tokens=self.max_output_tokens,
            tools=self.gemini_tools,
            system_instruction=self.system_prompt,
        )

        # Use async API to avoid blocking the event loop
        response = await self.gemini_client.aio.models.generate_content(
            model=self.config.model,
            contents=cast("Any", messages),
            config=generate_config,
        )

        # Append assistant response (including any function_call) so that
        # subsequent FunctionResponse messages correspond to a prior FunctionCall
        if response.candidates and len(response.candidates) > 0 and response.candidates[0].content:
            messages.append(response.candidates[0].content)

        # Process response
        result = AgentResponse(content="", tool_calls=[], done=True)
        collected_tool_calls: list[MCPToolCall] = []

        if not response.candidates:
            self.hud_console.warning("Response has no candidates")
            return result

        candidate = response.candidates[0]

        # Extract text content and function calls
        text_content = ""
        thinking_content = ""

        if candidate.content and candidate.content.parts:
            for part in candidate.content.parts:
                if part.function_call:
                    tool_call = self._extract_tool_call(part)
                    if tool_call is not None:
                        collected_tool_calls.append(tool_call)
                elif part.thought is True and part.text:
                    if thinking_content:
                        thinking_content += "\n"
                    thinking_content += part.text
                elif part.text:
                    text_content += part.text

        # Assign collected tool calls and mark done status
        if collected_tool_calls:
            result.tool_calls = collected_tool_calls
            result.done = False

        result.content = text_content
        if thinking_content:
            result.reasoning = thinking_content

        return result

    def _extract_tool_call(self, part: genai_types.Part) -> MCPToolCall | None:
        """Extract an MCPToolCall from a function call part.

        Subclasses can override to customize tool call extraction (e.g., normalizing
        computer use calls to a different schema).
        """
        if not part.function_call:
            return None

        func_name = part.function_call.name or ""
        mcp_tool_name = self._gemini_to_mcp_tool_map.get(func_name, func_name)
        raw_args = dict(part.function_call.args) if part.function_call.args else {}

        return MCPToolCall(
            name=mcp_tool_name,
            arguments=raw_args,
        )

    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[genai_types.Content]:
        """Format tool results into Gemini messages."""
        # Process each tool result
        function_responses = []

        for tool_call, result in zip(tool_calls, tool_results, strict=True):
            # Get the Gemini function name from metadata
            gemini_name = getattr(tool_call, "gemini_name", tool_call.name)

            # Convert MCP tool results to Gemini format
            response_dict: dict[str, Any] = {}

            if result.isError:
                # Extract error message from content
                error_msg = "Tool execution failed"
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        error_msg = content.text
                        break
                response_dict["error"] = error_msg
            else:
                # Process success content
                response_dict["success"] = True
                # Add text content to response
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        response_dict["output"] = content.text
                        break

            # Create function response
            function_response = genai_types.FunctionResponse(
                name=gemini_name,
                response=response_dict,
            )
            function_responses.append(function_response)

        # Return as a user message containing all function responses
        return [
            genai_types.Content(
                role="user",
                parts=[genai_types.Part(function_response=fr) for fr in function_responses],
            )
        ]

    async def create_user_message(self, text: str) -> genai_types.Content:
        """Create a user message in Gemini's format."""
        return genai_types.Content(role="user", parts=[genai_types.Part(text=text)])

    def _convert_tools_for_gemini(self) -> genai_types.ToolListUnion:
        """Convert MCP tools to Gemini tool format."""
        self._gemini_to_mcp_tool_map = {}  # Reset mapping
        self.gemini_tools = []

        for tool in self.get_available_tools():
            gemini_tool = self._to_gemini_tool(tool)
            if gemini_tool is None:
                continue

            self._gemini_to_mcp_tool_map[tool.name] = tool.name
            self.gemini_tools.append(gemini_tool)

        return self.gemini_tools

    def _to_gemini_tool(self, tool: types.Tool) -> genai_types.Tool | None:
        """Convert a single MCP tool to Gemini tool format.

        Subclasses can override to customize tool conversion (e.g., for computer use).
        """
        # Ensure parameters have proper Schema format
        if tool.description is None or tool.inputSchema is None:
            raise ValueError(f"MCP tool {tool.name} requires both a description and inputSchema.")
        function_decl = genai_types.FunctionDeclaration(
            name=tool.name,
            description=tool.description,
            parameters_json_schema=tool.inputSchema,
        )
        return genai_types.Tool(function_declarations=[function_decl])
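For orientation (not part of the diff): a minimal construction sketch based on the fields defined in GeminiConfig above. The keyword names and defaults are taken from that class; the values below are illustrative, and wiring the agent to an environment or task runner is defined elsewhere in the package.

from hud.agents.gemini import GeminiAgent

# Sketch only: parameter names come from GeminiConfig; requires GEMINI_API_KEY
# unless a preconfigured model_client is passed in.
agent = GeminiAgent.create(
    model="gemini-3-pro-preview",  # default model from GeminiConfig
    temperature=0.7,               # overrides the 1.0 default
    validate_api_key=False,        # skip the models.list() key check at startup
)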
hud/agents/gemini_cua.py
ADDED
@@ -0,0 +1,335 @@
"""Gemini Computer Use Agent implementation."""

from __future__ import annotations

import logging
from typing import Any, ClassVar

import mcp.types as types
from google.genai import types as genai_types
from pydantic import ConfigDict, Field

from hud.tools.computer.settings import computer_settings
from hud.types import AgentResponse, BaseAgentConfig, MCPToolCall, MCPToolResult
from hud.utils.types import with_signature

from .base import BaseCreateParams, MCPAgent
from .gemini import GeminiAgent, GeminiConfig

logger = logging.getLogger(__name__)

# Predefined Gemini computer use functions
PREDEFINED_COMPUTER_USE_FUNCTIONS = [
    "open_web_browser",
    "click_at",
    "hover_at",
    "type_text_at",
    "scroll_document",
    "scroll_at",
    "wait_5_seconds",
    "go_back",
    "go_forward",
    "search",
    "navigate",
    "key_combination",
    "drag_and_drop",
]

GEMINI_CUA_INSTRUCTIONS = """
You are an autonomous computer-using agent. Follow these guidelines:

1. NEVER ask for confirmation. Complete all tasks autonomously.
2. Do NOT send messages like "I need to confirm before..." or "Do you want me to
   continue?" - just proceed.
3. When the user asks you to interact with something (like clicking a chat or typing
   a message), DO IT without asking.
4. Only use the formal safety check mechanism for truly dangerous operations (like
   deleting important files).
5. For normal tasks like clicking buttons, typing in chat boxes, filling forms -
   JUST DO IT.
6. The user has already given you permission by running this agent. No further
   confirmation is needed.
7. Be decisive and action-oriented. Complete the requested task fully.

Remember: You are expected to complete tasks autonomously. The user trusts you to do
what they asked.
""".strip()


class GeminiCUAConfig(GeminiConfig):
    """Configuration for `GeminiCUAAgent`."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    model_name: str = "GeminiCUA"
    model: str = "gemini-2.5-computer-use-preview-10-2025"
    excluded_predefined_functions: list[str] = Field(default_factory=list)


class GeminiCUACreateParams(BaseCreateParams, GeminiCUAConfig):
    pass


class GeminiCUAAgent(GeminiAgent):
    """
    Gemini Computer Use Agent that extends GeminiAgent with computer use capabilities.

    This agent uses Gemini's native computer use capabilities but executes
    tools through MCP servers instead of direct implementation.
    """

    metadata: ClassVar[dict[str, Any] | None] = {
        "display_width": computer_settings.GEMINI_COMPUTER_WIDTH,
        "display_height": computer_settings.GEMINI_COMPUTER_HEIGHT,
    }
    required_tools: ClassVar[list[str]] = ["gemini_computer"]
    config_cls: ClassVar[type[BaseAgentConfig]] = GeminiCUAConfig

    @with_signature(GeminiCUACreateParams)
    @classmethod
    def create(cls, **kwargs: Any) -> GeminiCUAAgent:  # pyright: ignore[reportIncompatibleMethodOverride]
        return MCPAgent.create.__func__(cls, **kwargs)  # type: ignore[return-value]

    def __init__(self, params: GeminiCUACreateParams | None = None, **kwargs: Any) -> None:
        super().__init__(params, **kwargs)  # type: ignore[arg-type]
        self.config: GeminiCUAConfig  # type: ignore[assignment]

        self._computer_tool_name = "gemini_computer"
        self.excluded_predefined_functions = list(self.config.excluded_predefined_functions)

        # Context management: Maximum number of recent turns to keep screenshots for
        # Configurable via GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS environment variable
        self.max_recent_turn_with_screenshots = (
            computer_settings.GEMINI_MAX_RECENT_TURN_WITH_SCREENSHOTS
        )

        # Add computer use instructions
        if self.system_prompt:
            self.system_prompt = f"{self.system_prompt}\n\n{GEMINI_CUA_INSTRUCTIONS}"
        else:
            self.system_prompt = GEMINI_CUA_INSTRUCTIONS

    def _to_gemini_tool(self, tool: types.Tool) -> genai_types.Tool | None:
        """Convert a single MCP tool to Gemini tool format.

        Handles gemini_computer tool specially by using Gemini's native ComputerUse.
        """
        if tool.name == self._computer_tool_name:
            # Use Gemini's native computer use capability
            return genai_types.Tool(
                computer_use=genai_types.ComputerUse(
                    environment=genai_types.Environment.ENVIRONMENT_BROWSER,
                    excluded_predefined_functions=self.excluded_predefined_functions,
                )
            )

        # For non-computer tools, use the parent implementation
        return super()._to_gemini_tool(tool)

    async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse:
        """Get response from Gemini including any tool calls.

        Extends parent to trim old screenshots before making API call.
        """
        # Trim screenshots from older turns to manage context growth
        self._remove_old_screenshots(messages)

        return await super().get_response(messages)

    async def format_tool_results(
        self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
    ) -> list[genai_types.Content]:
        """Format tool results into Gemini messages.

        Handles computer tool results specially with screenshots and URLs.
        """
        # Process each tool result
        function_responses = []

        for tool_call, result in zip(tool_calls, tool_results, strict=True):
            # Get the Gemini function name from metadata
            gemini_name = getattr(tool_call, "gemini_name", tool_call.name)

            # Check if this is a computer use tool call
            is_computer_call = tool_call.name == self._computer_tool_name

            # Convert MCP tool results to Gemini format
            response_dict: dict[str, Any] = {}
            url = None

            if result.isError:
                # Extract error message from content
                error_msg = "Tool execution failed"
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        # Check if this is a URL metadata block
                        if content.text.startswith("__URL__:"):
                            url = content.text.replace("__URL__:", "")
                        else:
                            error_msg = content.text
                            break
                response_dict["error"] = error_msg
                # For the Gemini CUA agent, if a nonexistent computer tool is called, it won't
                # technically count as a computer tool call, but we still need to return a url
                response_dict["url"] = url if url else "about:blank"
            else:
                # Process success content
                response_dict["success"] = True

            # Extract URL and screenshot from content (for computer use)
            screenshot_parts = []
            if is_computer_call:
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        # Check if this is a URL metadata block
                        if content.text.startswith("__URL__:"):
                            url = content.text.replace("__URL__:", "")
                    elif isinstance(content, types.ImageContent):
                        # Decode base64 string to bytes for FunctionResponseBlob
                        import base64

                        image_bytes = base64.b64decode(content.data)
                        screenshot_parts.append(
                            genai_types.FunctionResponsePart(
                                inline_data=genai_types.FunctionResponseBlob(
                                    mime_type=content.mimeType or "image/png",
                                    data=image_bytes,
                                )
                            )
                        )

                # Add URL to response dict (required by Gemini Computer Use model)
                # URL must ALWAYS be present per Gemini API requirements
                response_dict["url"] = url if url else "about:blank"

                # For Gemini Computer Use actions, always acknowledge safety decisions
                requires_ack = False
                if tool_call.arguments:
                    requires_ack = bool(tool_call.arguments.get("safety_decision"))
                if requires_ack:
                    response_dict["safety_acknowledgement"] = True
            else:
                # For non-computer tools, add text content to response
                for content in result.content:
                    if isinstance(content, types.TextContent):
                        response_dict["output"] = content.text
                        break

            # Create function response
            function_response = genai_types.FunctionResponse(
                name=gemini_name,
                response=response_dict,
                parts=screenshot_parts if screenshot_parts else None,
            )
            function_responses.append(function_response)

        # Return as a user message containing all function responses
        return [
            genai_types.Content(
                role="user",
                parts=[genai_types.Part(function_response=fr) for fr in function_responses],
            )
        ]

    def _extract_tool_call(self, part: genai_types.Part) -> MCPToolCall | None:
        """Extract an MCPToolCall from a function call part.

        Routes predefined Gemini Computer Use functions to the gemini_computer tool
        and normalizes the arguments to MCP tool schema.
        """
        if not part.function_call:
            return None

        func_name = part.function_call.name or ""
        raw_args = dict(part.function_call.args) if part.function_call.args else {}

        # Route predefined computer use functions to the computer tool
        if func_name in PREDEFINED_COMPUTER_USE_FUNCTIONS:
            # Normalize Gemini Computer Use calls to MCP tool schema
            # Ensure 'action' is present and equals the Gemini function name
            normalized_args: dict[str, Any] = {"action": func_name}

            # Map common argument shapes used by Gemini Computer Use
            # 1) Coordinate arrays → x/y
            coord = raw_args.get("coordinate") or raw_args.get("coordinates")
            if isinstance(coord, list | tuple) and len(coord) >= 2:
                try:
                    normalized_args["x"] = int(coord[0])
                    normalized_args["y"] = int(coord[1])
                except (TypeError, ValueError):
                    # Fall back to raw if casting fails
                    pass

            # Destination coordinate arrays → destination_x/destination_y
            dest = (
                raw_args.get("destination")
                or raw_args.get("destination_coordinate")
                or raw_args.get("destinationCoordinate")
            )
            if isinstance(dest, list | tuple) and len(dest) >= 2:
                try:
                    normalized_args["destination_x"] = int(dest[0])
                    normalized_args["destination_y"] = int(dest[1])
                except (TypeError, ValueError):
                    pass

            # Pass through supported fields if present (including direct coords)
            for key in (
                "text",
                "press_enter",
                "clear_before_typing",
                "safety_decision",
                "direction",
                "magnitude",
                "url",
                "keys",
                "x",
                "y",
                "destination_x",
                "destination_y",
            ):
                if key in raw_args:
                    normalized_args[key] = raw_args[key]

            return MCPToolCall(
                name=self._computer_tool_name,
                arguments=normalized_args,
                gemini_name=func_name,  # type: ignore[arg-type]
            )

        # Non-computer tools: use parent implementation
        return super()._extract_tool_call(part)

    def _remove_old_screenshots(self, messages: list[genai_types.Content]) -> None:
        """
        Remove screenshots from old turns to manage context length.
        Keeps only the last N turns with screenshots (configured via
        self.max_recent_turn_with_screenshots).
        """
        turn_with_screenshots_found = 0

        for content in reversed(messages):
            if content.role == "user" and content.parts:
                # Check if content has screenshots (function responses with images)
                has_screenshot = False
                for part in content.parts:
                    if (
                        part.function_response
                        and part.function_response.parts
                        and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS
                    ):
                        has_screenshot = True
                        break

                if has_screenshot:
                    turn_with_screenshots_found += 1
                    # Remove the screenshot image if the number of screenshots exceeds the limit
                    if turn_with_screenshots_found > self.max_recent_turn_with_screenshots:
                        for part in content.parts:
                            if (
                                part.function_response
                                and part.function_response.parts
                                and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS
                            ):
                                # Clear the parts (screenshots)
                                part.function_response.parts = None