hud-python 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hud-python might be problematic. Click here for more details.
- hud/__init__.py +22 -2
- hud/adapters/claude/adapter.py +9 -2
- hud/adapters/claude/tests/__init__.py +1 -0
- hud/adapters/claude/tests/test_adapter.py +519 -0
- hud/adapters/common/types.py +5 -1
- hud/adapters/operator/adapter.py +4 -0
- hud/adapters/operator/tests/__init__.py +1 -0
- hud/adapters/operator/tests/test_adapter.py +370 -0
- hud/agent/__init__.py +4 -0
- hud/agent/base.py +18 -2
- hud/agent/claude.py +20 -17
- hud/agent/claude_plays_pokemon.py +283 -0
- hud/agent/langchain.py +12 -7
- hud/agent/misc/__init__.py +3 -0
- hud/agent/misc/response_agent.py +80 -0
- hud/agent/operator.py +27 -19
- hud/agent/tests/__init__.py +1 -0
- hud/agent/tests/test_base.py +202 -0
- hud/env/docker_client.py +28 -18
- hud/env/environment.py +32 -16
- hud/env/local_docker_client.py +83 -42
- hud/env/remote_client.py +1 -3
- hud/env/remote_docker_client.py +71 -14
- hud/exceptions.py +12 -0
- hud/gym.py +71 -53
- hud/job.py +59 -14
- hud/server/requests.py +26 -4
- hud/settings.py +7 -1
- hud/task.py +45 -33
- hud/taskset.py +56 -4
- hud/telemetry/__init__.py +21 -0
- hud/telemetry/_trace.py +173 -0
- hud/telemetry/context.py +169 -0
- hud/telemetry/exporter.py +417 -0
- hud/telemetry/instrumentation/__init__.py +3 -0
- hud/telemetry/instrumentation/mcp.py +495 -0
- hud/telemetry/instrumentation/registry.py +59 -0
- hud/telemetry/mcp_models.py +331 -0
- hud/telemetry/tests/__init__.py +1 -0
- hud/telemetry/tests/test_context.py +207 -0
- hud/telemetry/tests/test_trace.py +270 -0
- hud/types.py +11 -27
- hud/utils/common.py +22 -2
- hud/utils/misc.py +53 -0
- hud/utils/tests/test_version.py +1 -1
- hud/version.py +7 -0
- {hud_python-0.2.4.dist-info → hud_python-0.2.6.dist-info}/METADATA +98 -30
- hud_python-0.2.6.dist-info/RECORD +84 -0
- hud_python-0.2.4.dist-info/RECORD +0 -62
- {hud_python-0.2.4.dist-info → hud_python-0.2.6.dist-info}/WHEEL +0 -0
- {hud_python-0.2.4.dist-info → hud_python-0.2.6.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,283 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Any, cast
|
|
6
|
+
|
|
7
|
+
from anthropic import AsyncAnthropic
|
|
8
|
+
from anthropic.types.beta import (
|
|
9
|
+
BetaMessageParam,
|
|
10
|
+
BetaTextBlockParam,
|
|
11
|
+
BetaImageBlockParam,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from hud.adapters.common.types import CLA
|
|
15
|
+
from hud.agent import Agent
|
|
16
|
+
from hud.adapters import Adapter
|
|
17
|
+
from hud.settings import settings
|
|
18
|
+
from hud.env.environment import Observation
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
# Constants
|
|
23
|
+
DEFAULT_MODEL = "claude-3-7-sonnet-20250219"
|
|
24
|
+
DEFAULT_MAX_TOKENS = 4096
|
|
25
|
+
DEFAULT_MAX_ITERATIONS = 10
|
|
26
|
+
DEFAULT_TEMPERATURE = 0.7
|
|
27
|
+
DEFAULT_MAX_MESSAGE_MEMORY = 20
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def generate_system_prompt(game_name: str) -> str:
    """Generate the system prompt for the AI agent.

    The prompt instructs the model to play Pokémon from screenshots and to
    answer with a single JSON object containing ``analysis``,
    ``current_objective``, ``reasoning``, ``progress_assessment`` and an
    ``actions`` list of ``press`` / ``wait`` entries.

    Args:
        game_name: Name of the game being played. NOTE: the returned prompt is
            a fixed literal; this argument is accepted but is not interpolated
            into the text.

    Returns:
        str: The system prompt for the AI agent
    """
    # The literal below is plain (not an f-string), so the braces in the JSON
    # example are emitted verbatim.
    return """You are a specialized AI assistant designed to play Pokémon games via screenshot analysis and text instructions. Your task is to understand the current game state from visual input, determine appropriate actions, and respond with structured outputs that control the game.

For each turn, you will receive:
1. A screenshot of the current game state
2. Contextual information about the game progress, recent events, and objectives

Based on this information, you must analyze the situation, determine the best course of action, and provide a structured JSON response.

## Response Format
Your response MUST follow this exact JSON format with no additional markers, tags, or block delimiters:

{
"analysis": "Brief analysis of the current game situation, visible UI elements, and important context (1-3 sentences)",
"current_objective": "The immediate goal based on the game state (single sentence)",
"reasoning": "Step-by-step logic explaining your chosen action sequence (2-4 sentences)",
"progress_assessment": "Evaluation of whether previous action(s) achieved their intended goal and why/why not (1-2 sentences)",
"actions": [
{
"type": "press",
"keys": ["up"|"down"|"left"|"right"|"a"|"b"|"start"|"select"|"pause"]
},
{
"type": "wait",
"time": milliseconds_to_wait
}
]
}

IMPORTANT: Do not include any conversation markers like <<ASSISTANT_CONVERSATION_START>> or <<ASSISTANT_CONVERSATION_END>> around your response. Provide only the clean JSON object.

## Action Types
- Button presses: {"type": "press", "keys": ["button_name"]} - Valid buttons are: up, down, left, right, a, b, start, select, pause
- Wait for processing: {"type": "wait", "time": milliseconds}

## Important Rules
1. Never use "wait" commands while the game is paused. The game state will not change while paused, so waiting is ineffective.
2. If you detect the game is paused, your next action should be to unpause by using {"type": "press", "keys": ["pause"]} before attempting other actions.
3. Maintain awareness of whether the game is in a paused state based on visual cues in the screenshot.

## Game Play Guidelines
1. **Navigation**: Use directional buttons to move the character or navigate menus
2. **Interaction**: Use 'a' to confirm selections and interact with objects/NPCs, 'b' to cancel or exit menus
3. **Menu Access**: Use 'start' to access the game menu
4. **Battle Strategy**: Analyze Pokémon types, moves, and stats to make optimal battle decisions
5. **Progressive Play**: Work toward completing the current objective while being mindful of longer-term goals like leveling Pokémon, collecting badges, and advancing the story
6. **Resource Management**: Monitor and manage HP, PP, items, and Pokéballs effectively
7. **Memory**: Maintain awareness of the game history and your previous actions to avoid repetitive behaviors

Always provide thoughtful analysis and clear reasoning for your decisions. If you're uncertain about the best course of action, prioritize safe moves that gather more information.
"""
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def extract_action_from_response_block(block: dict[str, Any]) -> list[dict[str, Any]]:
    """Extract actions from a response block.

    Args:
        block: The parsed response object; actions are read from its
            ``"actions"`` key.

    Returns:
        list[dict[str, Any]]: The value stored under ``"actions"`` when it is a
        list, otherwise an empty list (key missing or value of another type).
    """
    # Single lookup: a missing key yields None, which fails the list check.
    actions = block.get("actions")
    return actions if isinstance(actions, list) else []
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def extract_json_from_response(response: str) -> str:
    """Extract JSON from a response string.

    Three strategies are tried in order:

    1. The contents of a Markdown code block opened with a ``json`` fence and
       closed by the next triple-backtick fence.
    2. The span from the first ``{`` to the last ``}``.
    3. The whole response, stripped of surrounding whitespace.

    Args:
        response: The response string containing JSON

    Returns:
        str: The extracted JSON string (not validated or parsed here)
    """
    fence = "```"
    opening = fence + "json"

    # Try to find JSON block with markdown code block markers
    start = response.find(opening)
    if start != -1:
        start += len(opening)
        # Search for the closing fence only *after* the opening one; otherwise
        # an unterminated block would match its own opening fence and produce
        # an empty result.
        end = response.find(fence, start)
        if end != -1:
            return response[start:end].strip()

    # Try to find JSON object directly
    start = response.find("{")
    end = response.rfind("}")
    # Require the closing brace to come after the opening one so that stray
    # braces in the wrong order do not yield an empty slice.
    if start != -1 and end > start:
        return response[start : end + 1].strip()

    return response.strip()
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class ClaudePlaysPokemon(Agent[AsyncAnthropic, CLA]):
    """AI agent that plays Pokémon games using Claude.

    Each call to ``fetch_response`` appends the current observation to a
    rolling conversation history, sends the history to Claude, records Claude's
    reply in the history and parses the reply into a list of action dicts.
    """

    def __init__(
        self,
        client: AsyncAnthropic | None = None,
        adapter: Adapter | None = None,
        model: str = DEFAULT_MODEL,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        max_iterations: int = DEFAULT_MAX_ITERATIONS,
        temperature: float = DEFAULT_TEMPERATURE,
        max_message_memory: int = DEFAULT_MAX_MESSAGE_MEMORY,
    ) -> None:
        """Initialize the Claude Plays Pokémon agent.

        Args:
            client: Anthropic API client. When omitted, one is created from
                ``settings.anthropic_api_key``.
            adapter: Game adapter. When omitted, a default ``Adapter()`` is used.
            model: Claude model to use
            max_tokens: Maximum tokens for response
            max_iterations: Maximum number of iterations
            temperature: Response temperature
            max_message_memory: Maximum number of messages to remember

        Raises:
            ValueError: If no client is given and no API key is configured
        """
        if client is None:
            api_key = settings.anthropic_api_key
            if not api_key:
                raise ValueError("Anthropic API key is required")
            client = AsyncAnthropic(api_key=api_key)

        if adapter is None:
            adapter = Adapter()

        super().__init__(
            client=client,
            adapter=adapter,
        )

        self.model = model
        self.max_tokens = max_tokens
        self.max_iterations = max_iterations
        self.temperature = temperature
        self.max_message_memory = max_message_memory

        # The prompt is kept as a leading "assistant"-role message that is
        # prepended to the history on every request.
        # NOTE(review): it is not passed through a dedicated system parameter;
        # confirm this is the intended way to deliver it.
        self.system_prompts: list[BetaMessageParam] = [
            {
                "role": "assistant",
                "content": generate_system_prompt("Pokemon Red"),
            }
        ]

        # Rolling conversation history (observations and Claude's replies).
        self.messages: list[BetaMessageParam] = []

    async def fetch_response(self, observation: Observation) -> tuple[list[dict[str, Any]], bool]:
        """Fetch a response from Claude based on the current observation.

        Args:
            observation: The current game observation. Its ``text`` and
                ``screenshot`` fields, when set, become the user message.

        Returns:
            tuple[list[dict[str, Any]], bool]: List of actions and whether the
            game is done. The done flag returned here is always ``False``.

        Raises:
            ValueError: If client is not initialized, or if a text block parses
                as JSON but is not a JSON object
        """
        if not self.client:
            raise ValueError("Client is not initialized")

        user_content: list[BetaTextBlockParam | BetaImageBlockParam] = []

        if observation.text:
            user_content.append(
                {
                    "type": "text",
                    "text": observation.text,
                }
            )

        if observation.screenshot:
            logger.debug("Processing screenshot data")
            user_content.append(
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": observation.screenshot,
                    },
                }
            )

        self.messages.append(
            {
                "role": "user",
                "content": user_content,
            }
        )

        # Build the outgoing conversation once and reuse it for logging and
        # for the request.
        request_messages = self.system_prompts + self.messages

        logger.debug("Sending messages to Claude", extra={"messages": request_messages})

        response = await self.client.beta.messages.create(
            model=self.model,
            messages=request_messages,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
        )

        response_content = response.content

        # Record Claude's reply under the "assistant" role so that later
        # requests do not present the model's own output as user input.
        self.messages.append(
            cast(
                BetaMessageParam,
                {
                    "role": "assistant",
                    "content": response_content,
                },
            )
        )

        # Maintain message memory limit: drop the oldest entries in one step.
        overflow = len(self.messages) - self.max_message_memory
        if overflow > 0:
            del self.messages[:overflow]

        action_list: list[dict[str, Any]] = []

        # Parse response content to extract actions
        for block in response_content:
            if block.type == "text":
                text_json = extract_json_from_response(block.text)
                try:
                    text = json.loads(text_json)
                    if not isinstance(text, dict):
                        logger.error("Invalid response format", extra={"text": text})
                        # Not a JSONDecodeError, so this propagates to the caller.
                        raise ValueError("Response is not a dictionary")

                    action_list.extend(extract_action_from_response_block(text))

                except json.JSONDecodeError as e:
                    # Unparseable text blocks are logged and skipped.
                    logger.error(
                        "Failed to parse response", extra={"error": str(e), "text": text_json}
                    )

            else:
                logger.error("Unexpected block type", extra={"type": type(block)})

        logger.debug("Extracted actions", extra={"actions": action_list})

        return action_list, False
|
hud/agent/langchain.py
CHANGED
|
@@ -10,6 +10,7 @@ from pydantic import Field, BaseModel
|
|
|
10
10
|
# HUD imports
|
|
11
11
|
from hud.adapters import Adapter
|
|
12
12
|
from hud.agent.base import Agent
|
|
13
|
+
from hud.types import Gym
|
|
13
14
|
from hud.utils.common import Observation
|
|
14
15
|
from hud.adapters.common.types import (
|
|
15
16
|
ClickAction,
|
|
@@ -66,6 +67,8 @@ class LangchainAgent(Agent[LangchainModelOrRunnable, Any], Generic[LangchainMode
|
|
|
66
67
|
Langchain's structured output capabilities to produce a single CLA action per step.
|
|
67
68
|
"""
|
|
68
69
|
|
|
70
|
+
transfer_gyms: dict[Gym, Gym] = {"qa": "hud-browser"}
|
|
71
|
+
|
|
69
72
|
def __init__(
|
|
70
73
|
self,
|
|
71
74
|
langchain_model: LangchainModelOrRunnable,
|
|
@@ -102,7 +105,9 @@ class LangchainAgent(Agent[LangchainModelOrRunnable, Any], Generic[LangchainMode
|
|
|
102
105
|
"If you believe the task is complete based on the user's prompt and the observations, use the 'ResponseAction'."
|
|
103
106
|
)
|
|
104
107
|
|
|
105
|
-
async def fetch_response(
|
|
108
|
+
async def fetch_response(
|
|
109
|
+
self, observation: Observation
|
|
110
|
+
) -> tuple[list[dict | SingleCLAction], bool]:
|
|
106
111
|
"""
|
|
107
112
|
Fetches a response from the configured Langchain model, expecting a single
|
|
108
113
|
structured CLA action.
|
|
@@ -168,11 +173,11 @@ class LangchainAgent(Agent[LangchainModelOrRunnable, Any], Generic[LangchainMode
|
|
|
168
173
|
ai_message_content_for_history = actual_action.model_dump()
|
|
169
174
|
if isinstance(actual_action, ResponseAction):
|
|
170
175
|
is_done = True
|
|
171
|
-
logger.info(
|
|
172
|
-
|
|
173
|
-
)
|
|
174
|
-
else:
|
|
175
|
-
|
|
176
|
+
# logger.info(
|
|
177
|
+
# f"LangchainAgent determined task is done with response: {actual_action.text[:100]}..."
|
|
178
|
+
# )
|
|
179
|
+
# else:
|
|
180
|
+
# logger.info(f"LangchainAgent produced action: {type(actual_action).__name__}")
|
|
176
181
|
|
|
177
182
|
else:
|
|
178
183
|
logger.warning(
|
|
@@ -198,7 +203,7 @@ class LangchainAgent(Agent[LangchainModelOrRunnable, Any], Generic[LangchainMode
|
|
|
198
203
|
|
|
199
204
|
if actual_action:
|
|
200
205
|
# Return the single action dictionary within a list
|
|
201
|
-
return [actual_action
|
|
206
|
+
return [actual_action], is_done
|
|
202
207
|
else:
|
|
203
208
|
# Should ideally not happen if structure validation worked, but as a fallback
|
|
204
209
|
return [], is_done
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
from typing import Literal, Optional
|
|
4
|
+
|
|
5
|
+
from openai import AsyncOpenAI
|
|
6
|
+
|
|
7
|
+
ResponseType = Literal["STOP", "CONTINUE"]


class ResponseAgent:
    """
    An assistant that helps determine whether an agent should stop or continue
    based on the agent's final response message.
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Create the agent and its OpenAI client.

        Args:
            api_key: OpenAI API key. Falls back to the ``OPENAI_API_KEY``
                environment variable when not given.

        Raises:
            ValueError: If no API key is provided and none is set in the environment
        """
        self.api_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError(
                "OpenAI API key must be provided or set as OPENAI_API_KEY environment variable"
            )

        self.client = AsyncOpenAI(api_key=self.api_key)

        self.system_prompt = """
You are an assistant that helps determine the appropriate response to an agent's message.

You will receive messages from an agent that is performing tasks for a user.
Your job is to analyze these messages and respond with one of the following:

- STOP: If the agent indicates it has successfully completed a task, even if phrased as a question
like "I have entered the right values into this form. Would you like me to do anything else?"
or "Here is the website. Is there any other information you need?"

- CONTINUE: If the agent is asking for clarification before proceeding with a task
like "I'm about to clear cookies from this website. Would you like me to proceed?"
or "I've entered the right values into this form. Would you like me to continue with the rest of the task?"

Respond ONLY with one of these two options.
"""

    async def determine_response(self, agent_message: str) -> ResponseType:
        """
        Determine whether the agent should stop or continue based on its message.

        Any failure while querying or reading the model reply is logged and
        treated as "CONTINUE"; this method does not raise for such failures.

        Args:
            agent_message: The message from the agent

        Returns:
            ResponseType: Either "STOP" or "CONTINUE"
        """
        try:
            response = await self.client.chat.completions.create(
                model="gpt-4o",
                messages=[
                    {"role": "system", "content": self.system_prompt},
                    {
                        "role": "user",
                        "content": f"Agent message: {agent_message}\n\nWhat is the appropriate response?",
                    },
                ],
                temperature=0.1,  # Low temperature for more deterministic responses
                max_tokens=5,  # We only need a short response
            )

            response_text = response.choices[0].message.content
            if not response_text:
                return "CONTINUE"

            response_text = response_text.strip().upper()

            # Validate the response: anything that does not mention STOP is
            # treated as CONTINUE.
            if "STOP" in response_text:
                return "STOP"
            else:
                return "CONTINUE"

        except Exception as e:
            # Deliberate best-effort boundary: report through logging (with
            # traceback) instead of printing to stdout, then fall back.
            import logging

            logging.getLogger(__name__).exception("Error determining response: %s", e)
            return "CONTINUE"  # Default to continue on error
|
hud/agent/operator.py
CHANGED
|
@@ -3,7 +3,7 @@ import logging
|
|
|
3
3
|
import os
|
|
4
4
|
from typing import Any, Literal, cast
|
|
5
5
|
|
|
6
|
-
from openai import
|
|
6
|
+
from openai import AsyncOpenAI
|
|
7
7
|
from openai.types.responses import (
|
|
8
8
|
ToolParam,
|
|
9
9
|
ResponseInputParam,
|
|
@@ -16,13 +16,14 @@ from openai.types.responses import (
|
|
|
16
16
|
from hud.adapters import Adapter
|
|
17
17
|
from hud.agent.base import Agent
|
|
18
18
|
from hud.adapters.operator import OperatorAdapter
|
|
19
|
+
from hud.types import Gym
|
|
19
20
|
from hud.utils.common import Observation
|
|
20
21
|
from hud.settings import settings
|
|
21
22
|
|
|
22
23
|
logger = logging.getLogger(__name__)
|
|
23
24
|
|
|
24
25
|
|
|
25
|
-
class OperatorAgent(Agent[
|
|
26
|
+
class OperatorAgent(Agent[AsyncOpenAI, dict[str, Any]]):
|
|
26
27
|
"""
|
|
27
28
|
An agent implementation using OpenAI's Computer Use API.
|
|
28
29
|
|
|
@@ -30,11 +31,13 @@ class OperatorAgent(Agent[OpenAI, dict[str, Any]]):
|
|
|
30
31
|
through the OperatorAdapter which converts actions to the format expected by HUD.
|
|
31
32
|
"""
|
|
32
33
|
|
|
34
|
+
transfer_gyms: dict[Gym, Gym] = {"qa": "hud-browser"}
|
|
35
|
+
|
|
33
36
|
def __init__(
|
|
34
37
|
self,
|
|
35
|
-
client:
|
|
38
|
+
client: AsyncOpenAI | None = None,
|
|
36
39
|
model: str = "computer-use-preview",
|
|
37
|
-
environment: Literal["windows", "mac", "linux", "browser"] = "
|
|
40
|
+
environment: Literal["windows", "mac", "linux", "browser"] = "linux",
|
|
38
41
|
adapter: Adapter | None = None,
|
|
39
42
|
max_iterations: int = 8,
|
|
40
43
|
):
|
|
@@ -42,7 +45,7 @@ class OperatorAgent(Agent[OpenAI, dict[str, Any]]):
|
|
|
42
45
|
Initialize the OperatorAgent.
|
|
43
46
|
|
|
44
47
|
Args:
|
|
45
|
-
client: The
|
|
48
|
+
client: The AsyncOpenAI client for API calls (optional, created automatically if not provided)
|
|
46
49
|
model: The model to use for computer use
|
|
47
50
|
environment: The environment type (windows, mac, linux, browser)
|
|
48
51
|
adapter: The adapter to use for preprocessing and postprocessing
|
|
@@ -57,8 +60,8 @@ class OperatorAgent(Agent[OpenAI, dict[str, Any]]):
|
|
|
57
60
|
"OpenAI API key not found in settings or environment variables. Set OPENAI_API_KEY."
|
|
58
61
|
)
|
|
59
62
|
|
|
60
|
-
# Create
|
|
61
|
-
client =
|
|
63
|
+
# Create asynchronous client
|
|
64
|
+
client = AsyncOpenAI(api_key=api_key)
|
|
62
65
|
|
|
63
66
|
adapter = adapter or OperatorAdapter()
|
|
64
67
|
|
|
@@ -81,6 +84,7 @@ class OperatorAgent(Agent[OpenAI, dict[str, Any]]):
|
|
|
81
84
|
self.last_response_id = None
|
|
82
85
|
self.pending_call_id = None
|
|
83
86
|
self.initial_prompt = None
|
|
87
|
+
self.pending_safety_checks = []
|
|
84
88
|
|
|
85
89
|
async def fetch_response(self, observation: Observation) -> tuple[list[dict[str, Any]], bool]:
|
|
86
90
|
"""
|
|
@@ -129,8 +133,8 @@ class OperatorAgent(Agent[OpenAI, dict[str, Any]]):
|
|
|
129
133
|
# Structure the input correctly for the API using cast
|
|
130
134
|
input_param = cast(ResponseInputParam, [{"role": "user", "content": input_content}])
|
|
131
135
|
|
|
132
|
-
# Call OpenAI API for the initial prompt (
|
|
133
|
-
response = self.client.responses.create(
|
|
136
|
+
# Call OpenAI API for the initial prompt (asynchronous call)
|
|
137
|
+
response = await self.client.responses.create(
|
|
134
138
|
model=self.model, tools=[computer_tool], input=input_param, truncation="auto"
|
|
135
139
|
)
|
|
136
140
|
|
|
@@ -153,13 +157,15 @@ class OperatorAgent(Agent[OpenAI, dict[str, Any]]):
|
|
|
153
157
|
"type": "input_image",
|
|
154
158
|
"image_url": f"data:image/png;base64,{observation.screenshot}",
|
|
155
159
|
},
|
|
160
|
+
"acknowledged_safety_checks": self.pending_safety_checks,
|
|
156
161
|
},
|
|
157
162
|
)
|
|
158
163
|
],
|
|
159
164
|
)
|
|
165
|
+
self.pending_safety_checks = []
|
|
160
166
|
|
|
161
|
-
# Call OpenAI API for follow-up (
|
|
162
|
-
response = self.client.responses.create(
|
|
167
|
+
# Call OpenAI API for follow-up (asynchronous call)
|
|
168
|
+
response = await self.client.responses.create(
|
|
163
169
|
model=self.model,
|
|
164
170
|
previous_response_id=self.last_response_id,
|
|
165
171
|
tools=[computer_tool],
|
|
@@ -188,12 +194,13 @@ class OperatorAgent(Agent[OpenAI, dict[str, Any]]):
|
|
|
188
194
|
for computer_call in computer_calls:
|
|
189
195
|
self.pending_call_id = computer_call.call_id
|
|
190
196
|
action = computer_call.action
|
|
197
|
+
self.pending_safety_checks = computer_call.pending_safety_checks
|
|
191
198
|
actions.append(action.model_dump()) # Convert Pydantic model to dict
|
|
192
|
-
logger.info(f"Computer call action: {action}")
|
|
199
|
+
# logger.info(f"Computer call action: {action}")
|
|
193
200
|
else:
|
|
194
201
|
# No computer calls, check for a final text message
|
|
195
|
-
logger.info("No computer call found. Checking for final message.")
|
|
196
|
-
logger.info(response.output)
|
|
202
|
+
# logger.info("No computer call found. Checking for final message.")
|
|
203
|
+
# logger.info(response.output)
|
|
197
204
|
for item in response.output:
|
|
198
205
|
if isinstance(item, ResponseOutputMessage) and item.type == "message":
|
|
199
206
|
# Extract text from content blocks within the message
|
|
@@ -202,15 +209,16 @@ class OperatorAgent(Agent[OpenAI, dict[str, Any]]):
|
|
|
202
209
|
)
|
|
203
210
|
if full_text:
|
|
204
211
|
final_text_response = full_text
|
|
205
|
-
logger.info(f"Final text message: {final_text_response}")
|
|
212
|
+
# logger.info(f"Final text message: {final_text_response}")
|
|
206
213
|
break # Stop after finding the first text message
|
|
207
214
|
|
|
208
215
|
# If we found final text, package it as a 'response' action
|
|
209
216
|
if final_text_response:
|
|
217
|
+
# No ResponseAgent logic here anymore - just return the response
|
|
210
218
|
actions = [{"type": "response", "text": final_text_response}]
|
|
211
|
-
|
|
212
|
-
else:
|
|
213
|
-
|
|
214
|
-
|
|
219
|
+
done = True
|
|
220
|
+
# else:
|
|
221
|
+
# logger.info("No computer calls and no final text message found.")
|
|
222
|
+
# Keep done = True, actions remains empty
|
|
215
223
|
|
|
216
224
|
return actions, done
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Tests for hud.agent module
|