openadapt-ml 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.2.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.2.dist-info}/licenses/LICENSE +0 -0
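Reading the file list: 0.2.2 removes the in-tree benchmark infrastructure (base.py, runner.py, waa.py, waa_live.py, data_collection.py, live_tracker.py) in favor of the separate openadapt_evals package, and adds a new openadapt_ml.baselines package that puts Claude, GPT, and Gemini baselines behind one adapter. A minimal sketch of the new baselines entry point follows, using only names that appear in the agent.py diff below (UnifiedBaselineAdapter.from_alias, TrackConfig.track_a/b/c, predict, and the ParsedAction fields); treat the exact signatures as assumptions, not the verified 0.2.2 API:

    # Sketch only; names taken from the agent.py diff below, signatures assumed.
    from PIL import Image

    from openadapt_ml.baselines import TrackConfig, UnifiedBaselineAdapter

    # Tracks per the new docstrings: A = coordinates, B = ReAct, C = Set-of-Marks.
    adapter = UnifiedBaselineAdapter.from_alias(
        "claude-opus-4.5",
        track=TrackConfig.track_c(),
        temperature=0.1,
        max_tokens=1024,
    )
    parsed = adapter.predict(
        screenshot=Image.open("screenshot.png"),  # PIL image or None
        goal="Open the settings dialog",
        a11y_tree=None,
        history=None,  # e.g. [{"type": "click", "x": 0.4, "y": 0.2}]
    )
    if parsed.is_valid:
        print(parsed.action_type, parsed.thought)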
openadapt_ml/benchmarks/agent.py
CHANGED
@@ -1,8 +1,15 @@
-"""
+"""ML-specific agents for benchmark evaluation.
 
-This module provides
-
-
+This module provides agents that wrap openadapt-ml components (VLM adapters,
+policies, baselines) for benchmark evaluation.
+
+For standalone agents without ML dependencies, use openadapt_evals:
+    from openadapt_evals import ApiAgent, ScriptedAgent, RandomAgent
+
+ML-specific agents in this module:
+- PolicyAgent: Wraps openadapt_ml.runtime.policy.AgentPolicy
+- APIBenchmarkAgent: Uses openadapt_ml.models.api_adapter.ApiVLMAdapter
+- UnifiedBaselineAgent: Uses openadapt_ml.baselines adapters
 
 Example:
     from openadapt_ml.benchmarks import PolicyAgent
@@ -12,7 +19,7 @@ Example:
     agent = PolicyAgent(policy)
     results = evaluate_agent_on_benchmark(agent, benchmark_adapter)
 
-    # API-backed agents (GPT-5.1, Claude)
+    # API-backed agents (GPT-5.1, Claude) using openadapt-ml adapters
    from openadapt_ml.benchmarks import APIBenchmarkAgent
 
     agent = APIBenchmarkAgent(provider="anthropic")  # Uses Claude
@@ -22,13 +29,13 @@ Example:
 
 from __future__ import annotations
 
-import json
 import re
-from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any
 
-from openadapt_ml.benchmarks.base import (
+# Import base classes from openadapt-evals (canonical location)
+from openadapt_evals import (
     BenchmarkAction,
+    BenchmarkAgent,
     BenchmarkObservation,
     BenchmarkTask,
 )
@@ -36,43 +43,7 @@ from openadapt_ml.benchmarks.base import (
 if TYPE_CHECKING:
     from openadapt_ml.models.api_adapter import ApiVLMAdapter
     from openadapt_ml.runtime.policy import AgentPolicy
-    from openadapt_ml.schema import Action
-
-
-class BenchmarkAgent(ABC):
-    """Abstract interface for agents evaluated on benchmarks.
-
-    Agents must implement the `act` method to receive observations
-    and return actions. The agent can maintain internal state across
-    steps within an episode.
-    """
-
-    @abstractmethod
-    def act(
-        self,
-        observation: BenchmarkObservation,
-        task: BenchmarkTask,
-        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
-    ) -> BenchmarkAction:
-        """Given observation and task, return next action.
-
-        Args:
-            observation: Current observation from the environment.
-            task: Task being performed.
-            history: Optional list of previous (observation, action) pairs.
-
-        Returns:
-            Action to execute.
-        """
-        pass
-
-    def reset(self) -> None:
-        """Reset agent state between episodes.
-
-        Called before starting a new task. Override to clear any
-        internal state.
-        """
-        pass
+    from openadapt_ml.schema import Action
 
 
 class PolicyAgent(BenchmarkAgent):
@@ -128,61 +99,37 @@ class PolicyAgent(BenchmarkAgent):
         task: BenchmarkTask,
         history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None,
     ) -> dict:
-        """Build SFT-style sample from benchmark observation.
-
-        Args:
-            observation: Current observation.
-            task: Current task.
-            history: Action history.
-
-        Returns:
-            Sample dict with 'images' and 'messages'.
-        """
-        # Build user message content
+        """Build SFT-style sample from benchmark observation."""
         content_parts = [f"Goal: {task.instruction}"]
 
-        # Add accessibility tree if available and enabled
         if self.use_accessibility_tree and observation.accessibility_tree:
             tree_str = self._format_accessibility_tree(observation.accessibility_tree)
             content_parts.append(f"UI Elements:\n{tree_str}")
 
-        # Add context
         if observation.url:
             content_parts.append(f"URL: {observation.url}")
         if observation.window_title:
             content_parts.append(f"Window: {observation.window_title}")
 
-        # Add history if enabled
         if self.use_history and history:
             history_str = self._format_history(history)
             content_parts.append(f"Previous actions:\n{history_str}")
 
         content_parts.append("What action should be taken next?")
 
-        # Build sample
         sample = {
             "messages": [
                 {"role": "user", "content": "\n\n".join(content_parts)},
             ],
         }
 
-        # Add image if available
         if observation.screenshot_path:
             sample["images"] = [observation.screenshot_path]
 
         return sample
 
     def _format_accessibility_tree(self, tree: dict, indent: int = 0) -> str:
-        """Format accessibility tree for prompt.
-
-        Args:
-            tree: Accessibility tree dict.
-            indent: Current indentation level.
-
-        Returns:
-            Formatted string representation.
-        """
-        # Simple formatting - can be overridden for platform-specific formatting
+        """Format accessibility tree for prompt."""
         lines = []
         prefix = " " * indent
 
@@ -203,29 +150,15 @@ class PolicyAgent(BenchmarkAgent):
     def _format_history(
         self, history: list[tuple[BenchmarkObservation, BenchmarkAction]]
     ) -> str:
-        """Format action history for prompt.
-
-        Args:
-            history: List of (observation, action) pairs.
-
-        Returns:
-            Formatted string.
-        """
+        """Format action history for prompt."""
         lines = []
-        for i, (obs, action) in enumerate(history[-5:], 1):
+        for i, (obs, action) in enumerate(history[-5:], 1):
             action_str = self._action_to_string(action)
             lines.append(f"{i}. {action_str}")
         return "\n".join(lines)
 
     def _action_to_string(self, action: BenchmarkAction) -> str:
-        """Convert BenchmarkAction to string representation.
-
-        Args:
-            action: Action to convert.
-
-        Returns:
-            String representation.
-        """
+        """Convert BenchmarkAction to string representation."""
         if action.type == "click":
             if action.target_name:
                 return f"CLICK({action.target_name})"
@@ -250,29 +183,19 @@ class PolicyAgent(BenchmarkAgent):
     def _to_benchmark_action(
         self, action: Action, thought: str | None
     ) -> BenchmarkAction:
-        """Convert openadapt-ml Action to BenchmarkAction.
-
-        Args:
-            action: Action from policy.
-            thought: Optional thought/reasoning.
-
-        Returns:
-            BenchmarkAction.
-        """
-        # Extract normalized coordinates
+        """Convert openadapt-ml Action to BenchmarkAction."""
         x, y = None, None
         if action.normalized_coordinates is not None:
             x, y = action.normalized_coordinates
 
-        # Extract end coordinates for drag
         end_x, end_y = None, None
         if action.normalized_end is not None:
             end_x, end_y = action.normalized_end
 
-
-
+        action_type = (
+            action.type.value if hasattr(action.type, "value") else action.type
+        )
 
-        # Extract element info if available
         target_node_id = None
         target_role = None
         target_name = None
@@ -310,192 +233,28 @@ class PolicyAgent(BenchmarkAgent):
 
     def reset(self) -> None:
         """Reset agent state."""
-        # PolicyAgent is stateless, nothing to reset
-        pass
-
-
-class ScriptedAgent(BenchmarkAgent):
-    """Agent that follows a predefined script of actions.
-
-    Useful for testing benchmark adapters or replaying trajectories.
-
-    Args:
-        actions: List of actions to execute in order.
-    """
-
-    def __init__(self, actions: list[BenchmarkAction]):
-        self.actions = actions
-        self._step = 0
-
-    def act(
-        self,
-        observation: BenchmarkObservation,
-        task: BenchmarkTask,
-        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
-    ) -> BenchmarkAction:
-        """Return the next scripted action.
-
-        Args:
-            observation: Ignored.
-            task: Ignored.
-            history: Ignored.
-
-        Returns:
-            Next action from script, or DONE if script exhausted.
-        """
-        if self._step < len(self.actions):
-            action = self.actions[self._step]
-            self._step += 1
-            return action
-        return BenchmarkAction(type="done")
-
-    def reset(self) -> None:
-        """Reset step counter."""
-        self._step = 0
-
-
-class RandomAgent(BenchmarkAgent):
-    """Agent that takes random actions.
-
-    Useful for baseline comparisons.
-
-    Args:
-        action_types: List of action types to randomly select from.
-        seed: Random seed for reproducibility.
-    """
-
-    def __init__(
-        self,
-        action_types: list[str] | None = None,
-        seed: int | None = None,
-    ):
-        import random
-
-        self.action_types = action_types or ["click", "type", "scroll", "done"]
-        self.rng = random.Random(seed)
-
-    def act(
-        self,
-        observation: BenchmarkObservation,
-        task: BenchmarkTask,
-        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
-    ) -> BenchmarkAction:
-        """Return a random action.
-
-        Args:
-            observation: Used to get viewport bounds.
-            task: Ignored.
-            history: Used to decide when to stop.
-
-        Returns:
-            Random action.
-        """
-        # Stop after many actions
-        if history and len(history) > 20:
-            return BenchmarkAction(type="done")
-
-        action_type = self.rng.choice(self.action_types)
-
-        if action_type == "click":
-            return BenchmarkAction(
-                type="click",
-                x=self.rng.random(),
-                y=self.rng.random(),
-            )
-        elif action_type == "type":
-            return BenchmarkAction(
-                type="type",
-                text="test",
-            )
-        elif action_type == "scroll":
-            return BenchmarkAction(
-                type="scroll",
-                scroll_direction=self.rng.choice(["up", "down"]),
-            )
-        else:
-            return BenchmarkAction(type="done")
-
-    def reset(self) -> None:
-        """Nothing to reset."""
         pass
 
 
-class SmartMockAgent(BenchmarkAgent):
-    """Agent designed to pass WAAMockAdapter evaluation.
-
-    Performs a fixed sequence of actions that satisfy the mock adapter's
-    success criteria. Use for validating the benchmark pipeline locally.
-
-    The mock adapter evaluates success based on:
-    - Clicking Submit (ID 4) - primary success path
-    - Typing something AND clicking OK (ID 1) - form submission path
-    - Calling DONE after at least 2 actions - reasonable completion
-
-    This agent clicks Submit (ID 4) which is the simplest success path.
-    """
-
-    def __init__(self):
-        """Initialize the agent."""
-        self._step = 0
-        # Simple action sequence: click Submit button (ID 4), then done
-        self._actions = [
-            BenchmarkAction(type="click", target_node_id="4"),  # Click Submit
-            BenchmarkAction(type="done"),
-        ]
-
-    def act(
-        self,
-        observation: BenchmarkObservation,
-        task: BenchmarkTask,
-        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
-    ) -> BenchmarkAction:
-        """Return the next scripted action.
-
-        Args:
-            observation: Ignored.
-            task: Ignored.
-            history: Ignored.
-
-        Returns:
-            Next action from script, or DONE if script exhausted.
-        """
-        if self._step < len(self._actions):
-            action = self._actions[self._step]
-            self._step += 1
-            return action
-        return BenchmarkAction(type="done")
-
-    def reset(self) -> None:
-        """Reset step counter."""
-        self._step = 0
-
-
 class APIBenchmarkAgent(BenchmarkAgent):
-    """Agent that uses hosted VLM APIs
+    """Agent that uses hosted VLM APIs via openadapt-ml ApiVLMAdapter.
 
     This agent wraps ApiVLMAdapter to provide Claude or GPT-5.1 baselines
     for benchmark evaluation. It converts BenchmarkObservation to the
     API format and parses VLM responses into BenchmarkActions.
 
+    Note: For standalone API evaluation without openadapt-ml, use
+    openadapt_evals.ApiAgent instead (has P0 demo persistence fix).
+
     Args:
         provider: API provider - "anthropic" (Claude) or "openai" (GPT-5.1).
         api_key: Optional API key override. If not provided, uses env vars.
-        model: Optional model name override.
+        model: Optional model name override.
         max_tokens: Maximum tokens for VLM response.
         use_accessibility_tree: Whether to include accessibility tree in prompt.
         use_history: Whether to include action history in prompt.
-
-    Example:
-        # Claude baseline
-        agent = APIBenchmarkAgent(provider="anthropic")
-        results = evaluate_agent_on_benchmark(agent, waa_adapter)
-
-        # GPT-5.1 baseline
-        agent = APIBenchmarkAgent(provider="openai")
-        results = evaluate_agent_on_benchmark(agent, waa_adapter)
     """
 
-    # System prompt for GUI automation
     SYSTEM_PROMPT = """You are a GUI automation agent. Given a screenshot and task instruction, determine the next action to take.
 
 Available actions:
@@ -505,7 +264,7 @@ Available actions:
 - KEY(key) - Press a key (e.g., Enter, Tab, Escape)
 - KEY(modifier+key) - Press key combination (e.g., Ctrl+c, Alt+Tab)
 - SCROLL(direction) - Scroll up or down
-- DRAG(x1, y1, x2, y2) - Drag from (x1,y1) to (x2,y2)
+- DRAG(x1, y1, x2, y2) - Drag from (x1,y1) to (x2,y2)
 - DONE() - Task is complete
 - ANSWER("response") - For QA tasks, provide the answer
 
@@ -554,32 +313,15 @@ Then output the action on a new line starting with "ACTION:"
         task: BenchmarkTask,
         history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
     ) -> BenchmarkAction:
-        """Use VLM API to determine next action.
-
-        Args:
-            observation: Current observation with screenshot.
-            task: Task being performed.
-            history: Previous observations and actions.
-
-        Returns:
-            BenchmarkAction parsed from VLM response.
-        """
+        """Use VLM API to determine next action."""
         adapter = self._get_adapter()
-
-        # Build the sample for the API
         sample = self._build_sample(observation, task, history)
 
-        # Call the VLM API
         try:
             response = adapter.generate(sample, max_new_tokens=self.max_tokens)
         except Exception as e:
-
-            return BenchmarkAction(
-                type="done",
-                raw_action={"error": str(e)},
-            )
+            return BenchmarkAction(type="done", raw_action={"error": str(e)})
 
-        # Parse the response into a BenchmarkAction
         return self._parse_response(response, observation)
 
     def _build_sample(
@@ -588,41 +330,26 @@ Then output the action on a new line starting with "ACTION:"
         task: BenchmarkTask,
         history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None,
     ) -> dict[str, Any]:
-        """Build API sample from benchmark observation.
-
-        Args:
-            observation: Current observation.
-            task: Current task.
-            history: Action history.
-
-        Returns:
-            Sample dict with 'images' and 'messages'.
-        """
-        # Build user message content
+        """Build API sample from benchmark observation."""
         content_parts = [f"GOAL: {task.instruction}"]
 
-        # Add context
         if observation.url:
             content_parts.append(f"URL: {observation.url}")
         if observation.window_title:
             content_parts.append(f"Window: {observation.window_title}")
 
-        # Add accessibility tree if available and enabled
         if self.use_accessibility_tree and observation.accessibility_tree:
             tree_str = self._format_accessibility_tree(observation.accessibility_tree)
-            # Truncate if too long
             if len(tree_str) > 4000:
                 tree_str = tree_str[:4000] + "\n... (truncated)"
             content_parts.append(f"UI Elements:\n{tree_str}")
 
-        # Add history if enabled
         if self.use_history and history:
             history_str = self._format_history(history)
             content_parts.append(f"Previous actions:\n{history_str}")
 
         content_parts.append("\nWhat is the next action?")
 
-        # Build sample
         sample: dict[str, Any] = {
             "messages": [
                 {"role": "system", "content": self.SYSTEM_PROMPT},
@@ -630,22 +357,13 @@ Then output the action on a new line starting with "ACTION:"
             ],
         }
 
-        # Add image if available
         if observation.screenshot_path:
             sample["images"] = [observation.screenshot_path]
 
         return sample
 
     def _format_accessibility_tree(self, tree: dict, indent: int = 0) -> str:
-        """Format accessibility tree for prompt.
-
-        Args:
-            tree: Accessibility tree dict.
-            indent: Current indentation level.
-
-        Returns:
-            Formatted string representation.
-        """
+        """Format accessibility tree for prompt."""
         lines = []
         prefix = " " * indent
 
@@ -666,29 +384,15 @@ Then output the action on a new line starting with "ACTION:"
     def _format_history(
         self, history: list[tuple[BenchmarkObservation, BenchmarkAction]]
     ) -> str:
-        """Format action history for prompt.
-
-        Args:
-            history: List of (observation, action) pairs.
-
-        Returns:
-            Formatted string.
-        """
+        """Format action history for prompt."""
         lines = []
-        for i, (obs, action) in enumerate(history[-5:], 1):
+        for i, (obs, action) in enumerate(history[-5:], 1):
             action_str = self._action_to_string(action)
             lines.append(f"{i}. {action_str}")
         return "\n".join(lines)
 
     def _action_to_string(self, action: BenchmarkAction) -> str:
-        """Convert BenchmarkAction to string representation.
-
-        Args:
-            action: Action to convert.
-
-        Returns:
-            String representation.
-        """
+        """Convert BenchmarkAction to string representation."""
         if action.type == "click":
             if action.target_node_id:
                 return f"CLICK([{action.target_node_id}])"
@@ -717,32 +421,14 @@ Then output the action on a new line starting with "ACTION:"
     def _parse_response(
         self, response: str, observation: BenchmarkObservation | None = None
     ) -> BenchmarkAction:
-        """Parse VLM response into BenchmarkAction.
-
-        Handles various response formats:
-        - ACTION: CLICK(0.5, 0.3)
-        - CLICK(0.5, 0.3)
-        - I'll click at coordinates (0.5, 0.3) -> CLICK(0.5, 0.3)
-
-        Args:
-            response: Raw VLM response text.
-            observation: Current observation (used for coordinate normalization).
-
-        Returns:
-            Parsed BenchmarkAction.
-        """
-        # Store raw response for debugging
+        """Parse VLM response into BenchmarkAction."""
         raw_action = {"response": response}
 
-        # Extract action line (look for ACTION: prefix or action pattern)
         action_line = None
-
-        # Try to find ACTION: prefix
         action_match = re.search(r"ACTION:\s*(.+)", response, re.IGNORECASE)
         if action_match:
             action_line = action_match.group(1).strip()
         else:
-            # Look for action pattern anywhere in response
             patterns = [
                 r"(CLICK\s*\([^)]+\))",
                 r"(TYPE\s*\([^)]+\))",
@@ -759,146 +445,304 @@ Then output the action on a new line starting with "ACTION:"
                 break
 
         if not action_line:
-            # Could not parse action, return done
             raw_action["parse_error"] = "No action pattern found"
             return BenchmarkAction(type="done", raw_action=raw_action)
 
-        # Parse CLICK
+        # Parse CLICK([id])
         click_match = re.match(
             r"CLICK\s*\(\s*\[?(\d+)\]?\s*\)", action_line, re.IGNORECASE
         )
         if click_match:
-            # CLICK([id]) - element ID
             node_id = click_match.group(1)
             return BenchmarkAction(
-                type="click",
-                target_node_id=node_id,
-                raw_action=raw_action,
+                type="click", target_node_id=node_id, raw_action=raw_action
             )
 
+        # Parse CLICK(x, y)
         click_coords = re.match(
             r"CLICK\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)", action_line, re.IGNORECASE
         )
         if click_coords:
-            # CLICK(x, y) - coordinates
             x = float(click_coords.group(1))
             y = float(click_coords.group(2))
-
-            # Normalize coordinates if they appear to be pixel values
-            # If x or y > 1.0, assume pixel coordinates and normalize using viewport
             if observation and observation.viewport and (x > 1.0 or y > 1.0):
                 width, height = observation.viewport
-                x_norm = x / width
-                y_norm = y / height
                 raw_action["original_coords"] = {"x": x, "y": y}
                 raw_action["normalized"] = True
-                x = x_norm
-                y = y_norm
-
-            return BenchmarkAction(
-                type="click",
-                x=x,
-                y=y,
-                raw_action=raw_action,
-            )
+                x, y = x / width, y / height
+            return BenchmarkAction(type="click", x=x, y=y, raw_action=raw_action)
 
-        # Parse TYPE
+        # Parse TYPE
         type_match = re.match(
             r"TYPE\s*\(\s*[\"'](.+?)[\"']\s*\)", action_line, re.IGNORECASE
         )
         if type_match:
-            text = type_match.group(1)
             return BenchmarkAction(
-                type="type",
-                text=text,
-                raw_action=raw_action,
+                type="type", text=type_match.group(1), raw_action=raw_action
             )
 
-        # Parse KEY
+        # Parse KEY
         key_match = re.match(r"KEY\s*\(\s*(.+?)\s*\)", action_line, re.IGNORECASE)
         if key_match:
             key_str = key_match.group(1)
-            # Handle modifier+key format
             if "+" in key_str:
                 parts = key_str.split("+")
-                key = parts[-1]
-                modifiers = parts[:-1]
                 return BenchmarkAction(
                     type="key",
-                    key=key,
-                    modifiers=modifiers,
+                    key=parts[-1],
+                    modifiers=parts[:-1],
                     raw_action=raw_action,
                 )
-            return BenchmarkAction(
-                type="key",
-                key=key_str,
-                raw_action=raw_action,
-            )
+            return BenchmarkAction(type="key", key=key_str, raw_action=raw_action)
 
-        # Parse SCROLL
+        # Parse SCROLL
         scroll_match = re.match(
             r"SCROLL\s*\(\s*(up|down)\s*\)", action_line, re.IGNORECASE
         )
         if scroll_match:
-            direction = scroll_match.group(1).lower()
             return BenchmarkAction(
                 type="scroll",
-                scroll_direction=direction,
+                scroll_direction=scroll_match.group(1).lower(),
                 raw_action=raw_action,
             )
 
-        # Parse DRAG
+        # Parse DRAG
         drag_match = re.match(
             r"DRAG\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*\)",
             action_line,
             re.IGNORECASE,
         )
         if drag_match:
-            x = float(drag_match.group(1))
-            y = float(drag_match.group(2))
-            end_x = float(drag_match.group(3))
-            end_y = float(drag_match.group(4))
-
-            # Normalize coordinates if they appear to be pixel values
-            if observation and observation.viewport and (x > 1.0 or y > 1.0 or end_x > 1.0 or end_y > 1.0):
+            x, y = float(drag_match.group(1)), float(drag_match.group(2))
+            end_x, end_y = float(drag_match.group(3)), float(drag_match.group(4))
+            if (
+                observation
+                and observation.viewport
+                and (x > 1.0 or y > 1.0 or end_x > 1.0 or end_y > 1.0)
+            ):
                 width, height = observation.viewport
-                raw_action["original_coords"] = {"x": x, "y": y, "end_x": end_x, "end_y": end_y}
+                raw_action["original_coords"] = {
+                    "x": x,
+                    "y": y,
+                    "end_x": end_x,
+                    "end_y": end_y,
+                }
                 raw_action["normalized"] = True
-                x = x / width
-                y = y / height
-                end_x = end_x / width
-                end_y = end_y / height
-
+                x, y, end_x, end_y = (
+                    x / width,
+                    y / height,
+                    end_x / width,
+                    end_y / height,
+                )
             return BenchmarkAction(
-                type="drag",
-                x=x,
-                y=y,
-                end_x=end_x,
-                end_y=end_y,
-                raw_action=raw_action,
+                type="drag", x=x, y=y, end_x=end_x, end_y=end_y, raw_action=raw_action
             )
 
-        # Parse DONE
+        # Parse DONE
         if re.match(r"DONE\s*\(\s*\)", action_line, re.IGNORECASE):
             return BenchmarkAction(type="done", raw_action=raw_action)
 
-        # Parse ANSWER
+        # Parse ANSWER
         answer_match = re.match(
             r"ANSWER\s*\(\s*[\"'](.+?)[\"']\s*\)", action_line, re.IGNORECASE
         )
        if answer_match:
-            answer = answer_match.group(1)
             return BenchmarkAction(
-                type="answer",
-                answer=answer,
-                raw_action=raw_action,
+                type="answer", answer=answer_match.group(1), raw_action=raw_action
             )
 
-        # Unknown action format
         raw_action["parse_error"] = f"Unknown action format: {action_line}"
         return BenchmarkAction(type="done", raw_action=raw_action)
 
     def reset(self) -> None:
         """Reset agent state."""
-        # APIBenchmarkAgent is stateless, nothing to reset
         pass
+
+
+class UnifiedBaselineAgent(BenchmarkAgent):
+    """Agent that uses UnifiedBaselineAdapter for benchmark evaluation.
+
+    Provides unified interface for Claude, GPT, and Gemini baselines
+    across multiple tracks (A: coordinates, B: ReAct, C: SoM).
+
+    Args:
+        model_alias: Model alias (e.g., 'claude-opus-4.5', 'gpt-5.2').
+        track: Track type ('A', 'B', or 'C'). Defaults to 'A'.
+        api_key: Optional API key override.
+        temperature: Sampling temperature.
+        max_tokens: Maximum tokens for response.
+        demo: Optional demo text for prompts.
+        verbose: Whether to print debug output.
+    """
+
+    def __init__(
+        self,
+        model_alias: str = "claude-opus-4.5",
+        track: str = "A",
+        api_key: str | None = None,
+        temperature: float = 0.1,
+        max_tokens: int = 1024,
+        demo: str | None = None,
+        verbose: bool = False,
+    ):
+        self.model_alias = model_alias
+        self.track = track.upper()
+        self.api_key = api_key
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.demo = demo
+        self.verbose = verbose
+        self._adapter = None
+
+    def _get_adapter(self):
+        """Lazily initialize the UnifiedBaselineAdapter."""
+        if self._adapter is None:
+            from openadapt_ml.baselines import TrackConfig, UnifiedBaselineAdapter
+
+            track_configs = {
+                "A": TrackConfig.track_a(),
+                "B": TrackConfig.track_b(),
+                "C": TrackConfig.track_c(),
+            }
+            track_config = track_configs.get(self.track, TrackConfig.track_a())
+
+            self._adapter = UnifiedBaselineAdapter.from_alias(
+                self.model_alias,
+                track=track_config,
+                api_key=self.api_key,
+                temperature=self.temperature,
+                max_tokens=self.max_tokens,
+                demo=self.demo,
+                verbose=self.verbose,
+            )
+        return self._adapter
+
+    def act(
+        self,
+        observation: BenchmarkObservation,
+        task: BenchmarkTask,
+        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
+    ) -> BenchmarkAction:
+        """Use UnifiedBaselineAdapter to determine next action."""
+        from PIL import Image
+
+        adapter = self._get_adapter()
+
+        screenshot = None
+        if observation.screenshot_path:
+            try:
+                screenshot = Image.open(observation.screenshot_path)
+            except Exception as e:
+                if self.verbose:
+                    print(f"[UnifiedBaselineAgent] Failed to load screenshot: {e}")
+
+        a11y_tree = (
+            observation.accessibility_tree if observation.accessibility_tree else None
+        )
+
+        adapter_history = None
+        if history:
+            adapter_history = [
+                self._benchmark_action_to_dict(a) for _, a in history[-5:]
+            ]
+
+        try:
+            parsed_action = adapter.predict(
+                screenshot=screenshot,
+                goal=task.instruction,
+                a11y_tree=a11y_tree,
+                history=adapter_history,
+            )
+        except Exception as e:
+            if self.verbose:
+                print(f"[UnifiedBaselineAgent] Adapter error: {e}")
+            return BenchmarkAction(type="done", raw_action={"error": str(e)})
+
+        return self._parsed_to_benchmark_action(parsed_action, observation)
+
+    def _benchmark_action_to_dict(self, action: BenchmarkAction) -> dict[str, Any]:
+        """Convert BenchmarkAction to dict for history."""
+        result = {"type": action.type}
+        if action.x is not None:
+            result["x"] = action.x
+        if action.y is not None:
+            result["y"] = action.y
+        if action.text:
+            result["text"] = action.text
+        if action.key:
+            result["key"] = action.key
+        if action.target_node_id:
+            result["element_id"] = action.target_node_id
+        if action.scroll_direction:
+            result["direction"] = action.scroll_direction
+        return result
+
+    def _parsed_to_benchmark_action(
+        self, parsed_action, observation: BenchmarkObservation | None = None
+    ) -> BenchmarkAction:
+        """Convert ParsedAction to BenchmarkAction."""
+        raw_action = {
+            "raw_response": parsed_action.raw_response,
+            "thought": parsed_action.thought,
+        }
+
+        if not parsed_action.is_valid:
+            raw_action["parse_error"] = parsed_action.parse_error
+            return BenchmarkAction(type="done", raw_action=raw_action)
+
+        action_type = parsed_action.action_type
+
+        if action_type == "click":
+            if parsed_action.element_id is not None:
+                return BenchmarkAction(
+                    type="click",
+                    target_node_id=str(parsed_action.element_id),
+                    raw_action=raw_action,
+                )
+            elif parsed_action.x is not None and parsed_action.y is not None:
+                x, y = parsed_action.x, parsed_action.y
+                if observation and observation.viewport and (x > 1.0 or y > 1.0):
+                    width, height = observation.viewport
+                    raw_action["original_coords"] = {"x": x, "y": y}
+                    x, y = x / width, y / height
+                return BenchmarkAction(type="click", x=x, y=y, raw_action=raw_action)
+
+        elif action_type == "type":
+            return BenchmarkAction(
+                type="type", text=parsed_action.text, raw_action=raw_action
+            )
+
+        elif action_type == "key":
+            return BenchmarkAction(
+                type="key", key=parsed_action.key, raw_action=raw_action
+            )
+
+        elif action_type == "scroll":
+            return BenchmarkAction(
+                type="scroll",
+                scroll_direction=parsed_action.direction,
+                raw_action=raw_action,
+            )
+
+        elif action_type == "done":
+            return BenchmarkAction(type="done", raw_action=raw_action)
+
+        elif action_type == "drag":
+            return BenchmarkAction(
+                type="drag",
+                x=parsed_action.x,
+                y=parsed_action.y,
+                end_x=getattr(parsed_action, "end_x", None),
+                end_y=getattr(parsed_action, "end_y", None),
+                raw_action=raw_action,
+            )
+
+        raw_action["unknown_action"] = action_type
+        return BenchmarkAction(type="done", raw_action=raw_action)
+
+    def reset(self) -> None:
+        """Reset agent state."""
+        pass
+
+    def __repr__(self) -> str:
+        return f"UnifiedBaselineAgent(model={self.model_alias}, track={self.track})"