hud-python 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hud-python might be problematic. Click here for more details.

Files changed (58)
  1. hud/__init__.py +4 -3
  2. hud/adapters/claude/adapter.py +5 -14
  3. hud/adapters/common/adapter.py +3 -3
  4. hud/adapters/common/tests/__init__.py +0 -0
  5. hud/adapters/common/tests/test_adapter.py +277 -0
  6. hud/adapters/common/types.py +3 -3
  7. hud/adapters/operator/adapter.py +16 -23
  8. hud/agent/__init__.py +8 -1
  9. hud/agent/base.py +28 -28
  10. hud/agent/claude.py +69 -60
  11. hud/agent/langchain.py +32 -26
  12. hud/agent/operator.py +75 -67
  13. hud/env/__init__.py +5 -5
  14. hud/env/client.py +2 -2
  15. hud/env/docker_client.py +37 -39
  16. hud/env/environment.py +91 -66
  17. hud/env/local_docker_client.py +5 -7
  18. hud/env/remote_client.py +39 -32
  19. hud/env/remote_docker_client.py +13 -3
  20. hud/evaluators/__init__.py +2 -3
  21. hud/evaluators/base.py +4 -3
  22. hud/evaluators/inspect.py +3 -8
  23. hud/evaluators/judge.py +34 -58
  24. hud/evaluators/match.py +42 -49
  25. hud/evaluators/remote.py +13 -26
  26. hud/evaluators/tests/__init__.py +0 -0
  27. hud/evaluators/tests/test_inspect.py +12 -0
  28. hud/evaluators/tests/test_judge.py +231 -0
  29. hud/evaluators/tests/test_match.py +115 -0
  30. hud/evaluators/tests/test_remote.py +98 -0
  31. hud/exceptions.py +167 -0
  32. hud/gym.py +9 -7
  33. hud/job.py +179 -109
  34. hud/server/__init__.py +2 -2
  35. hud/server/requests.py +148 -186
  36. hud/server/tests/__init__.py +0 -0
  37. hud/server/tests/test_requests.py +275 -0
  38. hud/settings.py +3 -2
  39. hud/task.py +9 -19
  40. hud/taskset.py +44 -11
  41. hud/trajectory.py +6 -9
  42. hud/types.py +12 -9
  43. hud/utils/__init__.py +2 -2
  44. hud/utils/common.py +36 -15
  45. hud/utils/config.py +45 -30
  46. hud/utils/progress.py +34 -21
  47. hud/utils/telemetry.py +10 -11
  48. hud/utils/tests/__init__.py +0 -0
  49. hud/utils/tests/test_common.py +52 -0
  50. hud/utils/tests/test_config.py +129 -0
  51. hud/utils/tests/test_progress.py +225 -0
  52. hud/utils/tests/test_telemetry.py +37 -0
  53. hud/utils/tests/test_version.py +8 -0
  54. {hud_python-0.2.2.dist-info → hud_python-0.2.4.dist-info}/METADATA +9 -6
  55. hud_python-0.2.4.dist-info/RECORD +62 -0
  56. hud_python-0.2.2.dist-info/RECORD +0 -46
  57. {hud_python-0.2.2.dist-info → hud_python-0.2.4.dist-info}/WHEEL +0 -0
  58. {hud_python-0.2.2.dist-info → hud_python-0.2.4.dist-info}/licenses/LICENSE +0 -0
hud/__init__.py CHANGED
@@ -5,19 +5,20 @@ HUD Gym SDK - A Python SDK for interacting with HUD environments.
5
5
  from __future__ import annotations
6
6
 
7
7
  from . import agent, env, gym, settings, task, taskset, types, utils
8
- from .job import create_job, job, load_job, run_job
8
+ from .job import create_job, load_job, run_job
9
+ from .job import job as register_job
9
10
  from .taskset import load_taskset
10
11
 
11
- __version__ = "0.2.2"
12
+ __version__ = "0.2.4"
12
13
 
13
14
  __all__ = [
14
15
  "agent",
15
16
  "create_job",
16
17
  "env",
17
18
  "gym",
18
- "job",
19
19
  "load_job",
20
20
  "load_taskset",
21
+ "register_job",
21
22
  "run_job",
22
23
  "settings",
23
24
  "task",
hud/adapters/claude/adapter.py CHANGED
@@ -39,7 +39,7 @@ class ClaudeAdapter(Adapter):
39
39
  def _map_key(self, key: str) -> CLAKey:
40
40
  """Map a key to its standardized form."""
41
41
  return self.KEY_MAP.get(key.lower(), key.lower()) # type: ignore
42
-
42
+
43
43
  def convert(self, data: Any) -> CLA:
44
44
  try:
45
45
  action_type = data.get("action")
@@ -47,9 +47,7 @@ class ClaudeAdapter(Adapter):
47
47
  if action_type == "key":
48
48
  assert "text" in data
49
49
  if "+" in data["text"]:
50
- keys: list[CLAKey] = [
51
- self._map_key(k) for k in (data["text"].split("+"))
52
- ]
50
+ keys: list[CLAKey] = [self._map_key(k) for k in (data["text"].split("+"))]
53
51
  assert len(keys) > 0
54
52
  return PressAction(keys=keys)
55
53
  return PressAction(keys=[self._map_key(data["text"])])
@@ -83,19 +81,12 @@ class ClaudeAdapter(Adapter):
83
81
  assert len(coord) == 2
84
82
  if (
85
83
  len(self.memory) == 0
86
- or (
87
- self.memory[-1] is not MoveAction
88
- and self.memory[-1] is not ClickAction
89
- )
84
+ or (self.memory[-1] is not MoveAction and self.memory[-1] is not ClickAction)
90
85
  or self.memory[-1].point is None
91
86
  ):
92
- raise ValueError(
93
- "Left click drag must be preceded by a move or click action"
94
- )
87
+ raise ValueError("Left click drag must be preceded by a move or click action")
95
88
  else:
96
- return DragAction(
97
- path=[self.memory[-1].point, Point(x=coord[0], y=coord[1])]
98
- )
89
+ return DragAction(path=[self.memory[-1].point, Point(x=coord[0], y=coord[1])])
99
90
 
100
91
  elif action_type == "right_click":
101
92
  assert "coordinate" in data
hud/adapters/common/adapter.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import TYPE_CHECKING, Any
3
+ from typing import TYPE_CHECKING, Any, TypeAlias
4
4
 
5
5
  import numpy as np
6
6
  from PIL import Image
@@ -11,7 +11,7 @@ from .types import CLA
11
11
  if TYPE_CHECKING:
12
12
  from typing_extensions import TypeIs
13
13
 
14
- ImageType = np.ndarray[Any, Any] | Image.Image | str | None
14
+ ImageType: TypeAlias = np.ndarray[Any, Any] | Image.Image | str | None
15
15
 
16
16
 
17
17
  def _is_numpy_array(observation: Any) -> TypeIs[np.ndarray]:
@@ -164,5 +164,5 @@ class Adapter:
164
164
  def adapt_list(self, actions: list[Any]) -> list[CLA]:
165
165
  if not isinstance(actions, list):
166
166
  raise ValueError("Please provide a list of actions")
167
-
167
+
168
168
  return [self.adapt(action) for action in actions]
hud/adapters/common/tests/__init__.py ADDED
File without changes
hud/adapters/common/tests/test_adapter.py ADDED
@@ -0,0 +1,277 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ import io
5
+ from unittest.mock import MagicMock, patch
6
+
7
+ import numpy as np
8
+ import pytest
9
+ from PIL import Image
10
+
11
+ from hud.adapters.common import Adapter
12
+ from hud.adapters.common.types import ClickAction, Point, TypeAction
13
+
14
+
15
+ @pytest.fixture
16
+ def adapter():
17
+ """Fixture providing a clean adapter instance."""
18
+ return Adapter()
19
+
20
+
21
+ @pytest.fixture
22
+ def test_image():
23
+ """Fixture providing test image in various formats."""
24
+ img = Image.new("RGB", (100, 80), color="red")
25
+ img_bytes = io.BytesIO()
26
+ img.save(img_bytes, format="PNG")
27
+ img_base64 = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
28
+ img_array = np.array(img)
29
+
30
+ return {
31
+ "pil": img,
32
+ "bytes": img_bytes.getvalue(),
33
+ "base64": img_base64,
34
+ "array": img_array,
35
+ }
36
+
37
+
38
+ def test_init(adapter):
39
+ """Test adapter initialization."""
40
+ assert adapter.agent_width == 1920
41
+ assert adapter.agent_height == 1080
42
+ assert adapter.env_width == 1920
43
+ assert adapter.env_height == 1080
44
+ assert adapter.memory == []
45
+
46
+
47
+ def test_preprocess(adapter):
48
+ """Test preprocess method (default implementation)."""
49
+ action = {"type": "click", "point": {"x": 100, "y": 100}}
50
+ result = adapter.preprocess(action)
51
+ assert result == action # Default implementation returns unchanged
52
+
53
+
54
+ def test_convert_valid(adapter):
55
+ """Test convert method with valid action."""
56
+ action = ClickAction(point=Point(x=100, y=100))
57
+ result = adapter.convert(action)
58
+ # Fix: Instead of checking against CLA, check it's the same type as the input
59
+ assert isinstance(result, ClickAction)
60
+ assert result == action
61
+
62
+
63
+ def test_convert_invalid(adapter):
64
+ """Test convert method with invalid action."""
65
+ with pytest.raises(ValueError):
66
+ adapter.convert(None) # type: ignore
67
+
68
+
69
+ def test_json_valid(adapter):
70
+ """Test json method with valid action."""
71
+ action = ClickAction(point=Point(x=100, y=100))
72
+ result = adapter.json(action)
73
+ assert isinstance(result, dict)
74
+ assert result["type"] == "click"
75
+ assert result["point"]["x"] == 100
76
+ assert result["point"]["y"] == 100
77
+
78
+
79
+ def test_json_invalid(adapter):
80
+ """Test json method with invalid action."""
81
+ with pytest.raises(ValueError):
82
+ adapter.json(None) # type: ignore
83
+
84
+
85
+ def test_rescale_pil_image(adapter, test_image):
86
+ """Test rescaling PIL Image."""
87
+ result = adapter.rescale(test_image["pil"])
88
+
89
+ # Verify result is base64 string
90
+ assert isinstance(result, str)
91
+
92
+ # Verify environment dimensions were updated
93
+ assert adapter.env_width == 100
94
+ assert adapter.env_height == 80
95
+
96
+ # Decode and verify image dimensions
97
+ img_bytes = base64.b64decode(result)
98
+ img = Image.open(io.BytesIO(img_bytes))
99
+ assert img.size == (adapter.agent_width, adapter.agent_height)
100
+
101
+
102
+ def test_rescale_numpy_array(adapter, test_image):
103
+ """Test rescaling numpy array."""
104
+ result = adapter.rescale(test_image["array"])
105
+
106
+ # Verify result is base64 string
107
+ assert isinstance(result, str)
108
+
109
+ # Verify environment dimensions were updated
110
+ assert adapter.env_width == 100
111
+ assert adapter.env_height == 80
112
+
113
+
114
+ def test_rescale_base64(adapter, test_image):
115
+ """Test rescaling base64 string."""
116
+ result = adapter.rescale(test_image["base64"])
117
+
118
+ # Verify result is base64 string
119
+ assert isinstance(result, str)
120
+
121
+ # Verify environment dimensions were updated
122
+ assert adapter.env_width == 100
123
+ assert adapter.env_height == 80
124
+
125
+
126
+ def test_rescale_base64_with_header(adapter, test_image):
127
+ """Test rescaling base64 string with header."""
128
+ base64_with_header = f"data:image/png;base64,{test_image['base64']}"
129
+ result = adapter.rescale(base64_with_header)
130
+
131
+ # Verify result is base64 string
132
+ assert isinstance(result, str)
133
+
134
+ # Verify environment dimensions were updated
135
+ assert adapter.env_width == 100
136
+ assert adapter.env_height == 80
137
+
138
+
139
+ def test_rescale_invalid_type(adapter):
140
+ """Test rescaling with invalid type."""
141
+ with pytest.raises(ValueError):
142
+ adapter.rescale(123) # type: ignore
143
+
144
+
145
+ def test_rescale_none(adapter):
146
+ """Test rescaling with None."""
147
+ result = adapter.rescale(None)
148
+ assert result is None
149
+
150
+
151
+ def test_postprocess_action_click(adapter):
152
+ """Test postprocess_action with click action."""
153
+ # Set different agent and env dimensions
154
+ adapter.agent_width = 1000
155
+ adapter.agent_height = 800
156
+ adapter.env_width = 2000
157
+ adapter.env_height = 1600
158
+
159
+ action = {"type": "click", "point": {"x": 500, "y": 400}}
160
+ result = adapter.postprocess_action(action)
161
+
162
+ # Coordinates should be doubled
163
+ assert result["point"]["x"] == 1000
164
+ assert result["point"]["y"] == 800
165
+
166
+
167
+ def test_postprocess_action_drag(adapter):
168
+ """Test postprocess_action with drag action."""
169
+ # Set different agent and env dimensions
170
+ adapter.agent_width = 1000
171
+ adapter.agent_height = 800
172
+ adapter.env_width = 2000
173
+ adapter.env_height = 1600
174
+
175
+ action = {"type": "drag", "path": [{"x": 100, "y": 200}, {"x": 300, "y": 400}]}
176
+ result = adapter.postprocess_action(action)
177
+
178
+ # Coordinates should be doubled
179
+ assert result["path"][0]["x"] == 200
180
+ assert result["path"][0]["y"] == 400
181
+ assert result["path"][1]["x"] == 600
182
+ assert result["path"][1]["y"] == 800
183
+
184
+
185
+ def test_postprocess_action_scroll(adapter):
186
+ """Test postprocess_action with scroll action."""
187
+ # Set different agent and env dimensions
188
+ adapter.agent_width = 1000
189
+ adapter.agent_height = 800
190
+ adapter.env_width = 2000
191
+ adapter.env_height = 1600
192
+
193
+ action = {"type": "scroll", "point": {"x": 500, "y": 400}, "scroll": {"x": 0, "y": 10}}
194
+ result = adapter.postprocess_action(action)
195
+
196
+ # Point coordinates should be doubled
197
+ assert result["point"]["x"] == 1000
198
+ assert result["point"]["y"] == 800
199
+ # Scroll amount should be scaled
200
+ assert result["scroll"]["x"] == 0
201
+ assert result["scroll"]["y"] == 20
202
+
203
+
204
+ def test_postprocess_action_empty(adapter):
205
+ """Test postprocess_action with empty action."""
206
+ result = adapter.postprocess_action({})
207
+ assert result == {}
208
+
209
+
210
+ def test_adapt(adapter):
211
+ """Test adapt method."""
212
+ # Mock the needed methods
213
+ with (
214
+ patch.object(adapter, "preprocess", return_value={"preprocessed": True}),
215
+ patch.object(adapter, "convert", return_value=TypeAction(text="test")),
216
+ patch.object(adapter, "json", return_value={"type": "type", "text": "test"}),
217
+ patch.object(adapter, "postprocess_action", return_value={"type": "type", "text": "test"}),
218
+ patch("hud.adapters.common.adapter.TypeAdapter") as mock_adapter,
219
+ ):
220
+ mock_validator = MagicMock()
221
+ mock_adapter.return_value = mock_validator
222
+ mock_validator.validate_python.return_value = TypeAction(text="test")
223
+
224
+ adapter.adapt({"raw": "action"})
225
+
226
+ # Verify the method chain was called correctly
227
+ adapter.preprocess.assert_called_once_with({"raw": "action"})
228
+ adapter.convert.assert_called_once_with({"preprocessed": True})
229
+ adapter.json.assert_called_once_with(TypeAction(text="test"))
230
+ adapter.postprocess_action.assert_called_once_with({"type": "type", "text": "test"})
231
+
232
+ # Verify the memory was updated
233
+ assert len(adapter.memory) == 1
234
+ assert adapter.memory[0] == TypeAction(text="test")
235
+
236
+
237
+ def test_adapt_list(adapter):
238
+ """Test adapt_list method."""
239
+ # Fix: Use side_effect to return different values for each call to adapt
240
+ click_action = ClickAction(point=Point(x=100, y=100))
241
+ type_action = TypeAction(text="test")
242
+
243
+ mock_adapt = MagicMock(side_effect=[click_action, type_action])
244
+ with patch.object(adapter, "adapt", mock_adapt):
245
+ actions = [{"type": "click"}, {"type": "type"}]
246
+ result = adapter.adapt_list(actions)
247
+
248
+ assert adapter.adapt.call_count == 2
249
+ assert len(result) == 2
250
+ assert result[0] == click_action
251
+ assert result[1] == type_action
252
+
253
+
254
+ def test_adapt_list_invalid(adapter):
255
+ """Test adapt_list with invalid input."""
256
+ with pytest.raises(ValueError):
257
+ adapter.adapt_list("not a list") # type: ignore
258
+
259
+
260
+ def test_integration(adapter):
261
+ """Integration test for the full adapter pipeline."""
262
+ adapter.agent_width = 1000
263
+ adapter.agent_height = 800
264
+ adapter.env_width = 2000
265
+ adapter.env_height = 1600
266
+
267
+ # Create a click action
268
+ action = ClickAction(point=Point(x=500, y=400))
269
+
270
+ result = adapter.adapt(action)
271
+
272
+ assert isinstance(result, ClickAction)
273
+ assert result.point is not None
274
+ assert result.point.x == 1000 # Scaled from 500 to 1000
275
+ assert result.point.y == 800 # Scaled from 400 to 800
276
+
277
+ assert len(adapter.memory) == 1
hud/adapters/common/types.py CHANGED
@@ -1,6 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Annotated, Literal
3
+ from typing import Annotated, Literal, TypeAlias
4
4
 
5
5
  from pydantic import BaseModel, Field
6
6
 
@@ -82,7 +82,7 @@ class DragAction(CLAAction):
82
82
  # RESPONSE ACTION from agent
83
83
  class ResponseAction(CLAAction):
84
84
  type: Literal["response"] = "response"
85
- text: str # The final textual response from the agent
85
+ text: str # The final textual response from the agent
86
86
 
87
87
 
88
88
  # SCREENSHOT ACTION
@@ -118,7 +118,7 @@ CLA = Annotated[
118
118
  ]
119
119
 
120
120
 
121
- CLAKey = Literal[
121
+ CLAKey: TypeAlias = Literal[
122
122
  # Control keys
123
123
  "backspace",
124
124
  "tab",
hud/adapters/operator/adapter.py CHANGED
@@ -26,72 +26,65 @@ class OperatorAdapter(Adapter):
26
26
  "arrowleft": "left",
27
27
  "arrowright": "right",
28
28
  }
29
-
29
+
30
30
  def __init__(self) -> None:
31
31
  super().__init__()
32
32
  # OpenAI Computer Use default dimensions
33
33
  self.agent_width = 1024
34
34
  self.agent_height = 768
35
-
35
+
36
36
  def _map_key(self, key: str) -> CLAKey:
37
37
  """Map a key to its standardized form."""
38
38
  return self.KEY_MAP.get(key.lower(), key.lower()) # type: ignore
39
-
39
+
40
40
  def convert(self, data: Any) -> CLA:
41
41
  """Convert a Computer Use action to a HUD action"""
42
42
  try:
43
43
  action_type = data.get("type")
44
-
44
+
45
45
  if action_type == "click":
46
46
  x, y = data.get("x", 0), data.get("y", 0)
47
47
  button = data.get("button", "left")
48
48
  return ClickAction(point=Point(x=x, y=y), button=button)
49
-
49
+
50
50
  elif action_type == "double_click":
51
51
  x, y = data.get("x", 0), data.get("y", 0)
52
- return ClickAction(
53
- point=Point(x=x, y=y),
54
- button="left",
55
- pattern=[100]
56
- )
57
-
52
+ return ClickAction(point=Point(x=x, y=y), button="left", pattern=[100])
53
+
58
54
  elif action_type == "scroll":
59
55
  x, y = data.get("x", 0), data.get("y", 0)
60
56
  scroll_x = data.get("scroll_x", 0)
61
57
  scroll_y = data.get("scroll_y", 0)
62
- return ScrollAction(
63
- point=Point(x=x, y=y),
64
- scroll=Point(x=scroll_x, y=scroll_y)
65
- )
66
-
58
+ return ScrollAction(point=Point(x=x, y=y), scroll=Point(x=scroll_x, y=scroll_y))
59
+
67
60
  elif action_type == "type":
68
61
  text = data.get("text", "")
69
62
  return TypeAction(text=text, enter_after=False)
70
-
63
+
71
64
  elif action_type == "wait":
72
65
  ms = data.get("ms", 1000)
73
66
  return WaitAction(time=ms)
74
-
67
+
75
68
  elif action_type == "move":
76
69
  x, y = data.get("x", 0), data.get("y", 0)
77
70
  return MoveAction(point=Point(x=x, y=y))
78
-
71
+
79
72
  elif action_type == "keypress":
80
73
  keys = data.get("keys", [])
81
74
  return PressAction(keys=[self._map_key(k) for k in keys])
82
-
75
+
83
76
  elif action_type == "drag":
84
77
  path = data.get("path", [])
85
78
  points = [Point(x=p.get("x", 0), y=p.get("y", 0)) for p in path]
86
79
  return DragAction(path=points)
87
-
80
+
88
81
  elif action_type == "screenshot":
89
82
  return ScreenshotFetch()
90
-
83
+
91
84
  elif action_type == "response":
92
85
  return ResponseAction(text=data.get("text", ""))
93
86
  else:
94
87
  raise ValueError(f"Unsupported action type: {action_type}")
95
-
88
+
96
89
  except Exception as e:
97
90
  raise ValueError(f"Invalid action: {data}. Error: {e!s}") from e
hud/agent/__init__.py CHANGED
@@ -5,4 +5,11 @@ from .langchain import LangchainAgent
5
5
 
6
6
  from hud.adapters import OperatorAdapter, ClaudeAdapter
7
7
 
8
- __all__ = ["Agent", "ClaudeAgent", "OperatorAgent", "OperatorAdapter", "ClaudeAdapter", "LangchainAgent"]
8
+ __all__ = [
9
+ "Agent",
10
+ "ClaudeAgent",
11
+ "OperatorAgent",
12
+ "OperatorAdapter",
13
+ "ClaudeAdapter",
14
+ "LangchainAgent",
15
+ ]
hud/agent/base.py CHANGED
@@ -2,108 +2,108 @@ from abc import ABC, abstractmethod
2
2
  from typing import Sequence, TypeVar, Generic
3
3
 
4
4
  from hud.adapters import Adapter, CLA
5
- from hud.env.environment import Observation
5
+ from hud.utils.common import Observation
6
6
 
7
7
  # Generic type for different client types (Anthropic, OpenAI, etc.)
8
- ClientT = TypeVar('ClientT')
9
- ActionT = TypeVar('ActionT')
8
+ ClientT = TypeVar("ClientT")
9
+ ActionT = TypeVar("ActionT")
10
+
10
11
 
11
12
  class Agent(Generic[ClientT, ActionT], ABC):
12
13
  """
13
14
  Base class for all agents.
14
-
15
+
15
16
  Implements a three-stage prediction process:
16
17
  1. preprocess - Prepare observation data (e.g., rescale screenshot)
17
18
  2. fetch_response - Make API calls to get model response
18
19
  3. postprocess - Convert model actions to HUD format
19
-
20
+
20
21
  Subclasses only need to implement the fetch_response method.
21
22
  """
22
-
23
+
23
24
  def __init__(self, client: ClientT | None = None, adapter: Adapter | None = None):
24
25
  """
25
26
  Initialize the agent.
26
-
27
+
27
28
  Args:
28
29
  client: The client to use for API calls
29
30
  adapter: The adapter to use for preprocessing and postprocessing
30
31
  """
31
32
  self.client = client
32
33
  self.adapter = adapter
33
-
34
+
34
35
  def preprocess(self, observation: Observation) -> Observation:
35
36
  """
36
37
  Preprocess the observation before sending to the model.
37
-
38
+
38
39
  Args:
39
40
  observation: The raw observation from the environment
40
-
41
+
41
42
  Returns:
42
43
  Observation: The processed observation ready for the model
43
44
  """
44
45
  if not self.adapter or not observation.screenshot:
45
46
  return observation
46
-
47
+
47
48
  # Create a new observation with the rescaled screenshot
48
49
  processed_obs = Observation(
49
- text=observation.text,
50
- screenshot=self.adapter.rescale(observation.screenshot)
50
+ text=observation.text, screenshot=self.adapter.rescale(observation.screenshot)
51
51
  )
52
52
  return processed_obs
53
-
53
+
54
54
  @abstractmethod
55
55
  async def fetch_response(self, observation: Observation) -> tuple[list[ActionT], bool]:
56
56
  """
57
57
  Fetch a response from the model based on the observation.
58
-
58
+
59
59
  Args:
60
60
  observation: The preprocessed observation
61
-
61
+
62
62
  Returns:
63
63
  tuple[list[ActionT], bool]: A tuple containing the list of raw actions and a
64
64
  boolean indicating if the agent believes it has
65
65
  completed the task
66
66
  """
67
67
  pass
68
-
68
+
69
69
  def postprocess(self, actions: list[ActionT]) -> list[CLA]:
70
70
  """
71
71
  Convert model actions to HUD actions.
72
-
72
+
73
73
  Args:
74
74
  actions: The raw actions from the model
75
-
75
+
76
76
  Returns:
77
77
  Sequence[CLA]: The actions converted to HUD format
78
78
  """
79
79
  if not self.adapter:
80
80
  raise ValueError("Cannot postprocess actions without an adapter")
81
-
81
+
82
82
  return self.adapter.adapt_list(actions)
83
-
83
+
84
84
  async def predict(self, observation: Observation) -> tuple[list[CLA] | list[ActionT], bool]:
85
85
  """
86
86
  Predict the next action based on the observation.
87
-
87
+
88
88
  Implements the full three-stage prediction process.
89
-
89
+
90
90
  Args:
91
91
  observation: The observation from the environment
92
-
92
+
93
93
  Returns:
94
94
  tuple[list[CLA] | list[ActionT], bool]: A tuple containing the list of actions and a boolean
95
95
  indicating if the agent believes it has completed the task
96
96
  """
97
97
  # Stage 1: Preprocess the observation
98
98
  processed_obs = self.preprocess(observation)
99
-
99
+
100
100
  # Stage 2: Fetch response from the model
101
101
  actions, done = await self.fetch_response(processed_obs)
102
-
102
+
103
103
  # Stage 3: Postprocess the actions if we have an adapter
104
104
  if self.adapter and actions:
105
105
  hud_actions = self.postprocess(actions)
106
106
  return hud_actions, done
107
-
107
+
108
108
  # If no adapter, return actions as is
109
- return actions, done
109
+ return actions, done