openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,825 @@
|
|
|
1
|
+
"""Agent interface for benchmark evaluation.
|
|
2
|
+
|
|
3
|
+
This module provides the BenchmarkAgent interface that agents must implement
|
|
4
|
+
to be evaluated on benchmarks, plus adapters to wrap existing openadapt-ml
|
|
5
|
+
components.
|
|
6
|
+
|
|
7
|
+
Example:
|
|
8
|
+
from openadapt_ml.benchmarks import PolicyAgent
|
|
9
|
+
from openadapt_ml.runtime.policy import AgentPolicy
|
|
10
|
+
|
|
11
|
+
policy = AgentPolicy(adapter)
|
|
12
|
+
agent = PolicyAgent(policy)
|
|
13
|
+
results = evaluate_agent_on_benchmark(agent, benchmark_adapter)
|
|
14
|
+
|
|
15
|
+
# API-backed agents (GPT-5.1, Claude)
|
|
16
|
+
from openadapt_ml.benchmarks import APIBenchmarkAgent
|
|
17
|
+
|
|
18
|
+
agent = APIBenchmarkAgent(provider="anthropic") # Uses Claude
|
|
19
|
+
agent = APIBenchmarkAgent(provider="openai") # Uses GPT-5.1
|
|
20
|
+
results = evaluate_agent_on_benchmark(agent, benchmark_adapter)
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
from __future__ import annotations
|
|
24
|
+
|
|
25
|
+
import json
|
|
26
|
+
import re
|
|
27
|
+
from abc import ABC, abstractmethod
|
|
28
|
+
from typing import TYPE_CHECKING, Any
|
|
29
|
+
|
|
30
|
+
from openadapt_ml.benchmarks.base import (
|
|
31
|
+
BenchmarkAction,
|
|
32
|
+
BenchmarkObservation,
|
|
33
|
+
BenchmarkTask,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
if TYPE_CHECKING:
|
|
37
|
+
from openadapt_ml.models.api_adapter import ApiVLMAdapter
|
|
38
|
+
from openadapt_ml.runtime.policy import AgentPolicy
|
|
39
|
+
from openadapt_ml.schemas.sessions import Action
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class BenchmarkAgent(ABC):
    """Base contract for agents driven by a benchmark harness.

    Concrete agents implement :meth:`act`, which maps the current
    observation (plus the task description and optional step history)
    to the next action. Implementations may keep per-episode state;
    the harness calls :meth:`reset` before each new task so that any
    such state can be cleared.
    """

    @abstractmethod
    def act(
        self,
        observation: BenchmarkObservation,
        task: BenchmarkTask,
        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
    ) -> BenchmarkAction:
        """Decide the next action for the current step.

        Args:
            observation: Latest observation from the environment.
            task: The task currently being attempted.
            history: Optional prior (observation, action) pairs for context.

        Returns:
            The action the harness should execute next.
        """
        ...

    def reset(self) -> None:
        """Clear per-episode state before starting a new task.

        The default implementation does nothing; stateful agents
        should override this.
        """
class PolicyAgent(BenchmarkAgent):
    """Expose an openadapt-ml ``AgentPolicy`` through the benchmark agent API.

    Translates a ``BenchmarkObservation``/``BenchmarkTask`` pair into the
    SFT-style sample dict the policy consumes, and maps the policy's
    predicted ``Action`` back into a ``BenchmarkAction``.

    Args:
        policy: AgentPolicy instance to wrap.
        use_accessibility_tree: Whether to include accessibility tree in prompt.
        use_history: Whether to include action history in prompt.
    """

    def __init__(
        self,
        policy: AgentPolicy,
        use_accessibility_tree: bool = True,
        use_history: bool = True,
    ):
        self.policy = policy
        self.use_accessibility_tree = use_accessibility_tree
        self.use_history = use_history

    def act(
        self,
        observation: BenchmarkObservation,
        task: BenchmarkTask,
        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
    ) -> BenchmarkAction:
        """Run one policy step: build the prompt, predict, convert back.

        Args:
            observation: Benchmark observation.
            task: Benchmark task.
            history: Previous observations and actions.

        Returns:
            BenchmarkAction predicted by the wrapped policy.
        """
        prompt_sample = self._build_sample(observation, task, history)
        predicted, thought = self.policy.predict(prompt_sample)
        return self._to_benchmark_action(predicted, thought)

    def _build_sample(
        self,
        observation: BenchmarkObservation,
        task: BenchmarkTask,
        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None,
    ) -> dict:
        """Assemble the SFT-style sample dict for the policy.

        Args:
            observation: Current observation.
            task: Current task.
            history: Action history.

        Returns:
            Sample dict with 'messages' (and 'images' when a screenshot exists).
        """
        sections = [f"Goal: {task.instruction}"]

        # Optional accessibility-tree context.
        tree = observation.accessibility_tree
        if self.use_accessibility_tree and tree:
            sections.append(f"UI Elements:\n{self._format_accessibility_tree(tree)}")

        # Browser / window context, when the benchmark provides it.
        if observation.url:
            sections.append(f"URL: {observation.url}")
        if observation.window_title:
            sections.append(f"Window: {observation.window_title}")

        if self.use_history and history:
            sections.append(f"Previous actions:\n{self._format_history(history)}")

        sections.append("What action should be taken next?")

        sample = {
            "messages": [
                {"role": "user", "content": "\n\n".join(sections)},
            ],
        }
        if observation.screenshot_path:
            sample["images"] = [observation.screenshot_path]
        return sample

    def _format_accessibility_tree(self, tree: dict, indent: int = 0) -> str:
        """Render an accessibility tree as an indented text outline.

        Args:
            tree: Accessibility tree dict.
            indent: Current indentation level.

        Returns:
            Formatted string representation.
        """
        # Generic formatting; subclasses may override for platform specifics.
        pad = "  " * indent
        node_role = tree.get("role", "unknown")
        node_name = tree.get("name", "")
        # Some sources use "id", others "node_id"; accept either.
        node_id = tree.get("id", tree.get("node_id", ""))

        header = f"{pad}[{node_id}] {node_role}"
        if node_name:
            header = f"{header}: {node_name}"

        rendered = [header]
        rendered.extend(
            self._format_accessibility_tree(child, indent + 1)
            for child in tree.get("children", [])
        )
        return "\n".join(rendered)

    def _format_history(
        self, history: list[tuple[BenchmarkObservation, BenchmarkAction]]
    ) -> str:
        """Render the most recent actions as a numbered list.

        Args:
            history: List of (observation, action) pairs.

        Returns:
            Formatted string (at most the last 5 actions).
        """
        recent = history[-5:]
        return "\n".join(
            f"{idx}. {self._action_to_string(past_action)}"
            for idx, (_obs, past_action) in enumerate(recent, 1)
        )

    def _action_to_string(self, action: BenchmarkAction) -> str:
        """Render a BenchmarkAction in the DSL-style text form used in prompts.

        Args:
            action: Action to convert.

        Returns:
            String representation.
        """
        kind = action.type
        if kind == "click":
            # Prefer the human-readable target name over raw coordinates.
            if action.target_name:
                return f"CLICK({action.target_name})"
            return f"CLICK(x={action.x:.3f}, y={action.y:.3f})"
        if kind == "type":
            return f"TYPE({action.text!r})"
        if kind == "key":
            combo = "+".join(action.modifiers or [])
            if combo:
                return f"KEY({combo}+{action.key})"
            return f"KEY({action.key})"
        if kind == "scroll":
            return f"SCROLL({action.scroll_direction})"
        if kind == "done":
            return "DONE()"
        if kind == "answer":
            return f"ANSWER({action.answer!r})"
        return f"{kind.upper()}()"

    def _to_benchmark_action(
        self, action: Action, thought: str | None
    ) -> BenchmarkAction:
        """Map an openadapt-ml Action onto a BenchmarkAction.

        Args:
            action: Action from policy.
            thought: Optional thought/reasoning.

        Returns:
            BenchmarkAction.
        """
        # Fields that may or may not exist on the policy's Action type.
        optional_fields = (
            "target_node_id",
            "target_role",
            "target_name",
            "key",
            "modifiers",
            "scroll_direction",
            "scroll_amount",
            "end_x",
            "end_y",
            "answer",
        )
        extras = {name: getattr(action, name, None) for name in optional_fields}
        return BenchmarkAction(
            type=action.type,
            x=action.x,
            y=action.y,
            text=action.text,
            target_bbox=action.bbox,
            raw_action={"thought": thought} if thought else None,
            **extras,
        )

    def reset(self) -> None:
        """Reset agent state (no-op: PolicyAgent holds no episode state)."""
class ScriptedAgent(BenchmarkAgent):
    """Replays a fixed sequence of actions, ignoring all observations.

    Handy for exercising benchmark adapters or deterministically
    re-running a recorded trajectory.

    Args:
        actions: Actions to emit, in order.
    """

    def __init__(self, actions: list[BenchmarkAction]):
        self.actions = actions
        self._step = 0  # index of the next action to emit

    def act(
        self,
        observation: BenchmarkObservation,
        task: BenchmarkTask,
        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
    ) -> BenchmarkAction:
        """Return the next scripted action.

        Args:
            observation: Ignored.
            task: Ignored.
            history: Ignored.

        Returns:
            Next action from the script, or DONE once the script is exhausted.
        """
        if self._step >= len(self.actions):
            return BenchmarkAction(type="done")
        next_action = self.actions[self._step]
        self._step += 1
        return next_action

    def reset(self) -> None:
        """Rewind to the beginning of the script."""
        self._step = 0
class RandomAgent(BenchmarkAgent):
    """Baseline agent that samples uniformly random actions.

    Useful as a floor for benchmark comparisons.

    Args:
        action_types: Candidate action types to sample from.
        seed: Seed for the internal RNG, for reproducibility.
    """

    def __init__(
        self,
        action_types: list[str] | None = None,
        seed: int | None = None,
    ):
        # Local import keeps the module namespace unchanged.
        import random

        self.action_types = action_types or ["click", "type", "scroll", "done"]
        self.rng = random.Random(seed)

    def act(
        self,
        observation: BenchmarkObservation,
        task: BenchmarkTask,
        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
    ) -> BenchmarkAction:
        """Return a random action.

        Args:
            observation: Used to get viewport bounds.
            task: Ignored.
            history: Used to decide when to stop.

        Returns:
            Random action (DONE after more than 20 recorded steps).
        """
        # Hard stop so a random walk cannot run forever.
        if history and len(history) > 20:
            return BenchmarkAction(type="done")

        # NOTE: RNG call order (choice, then random()/choice as needed)
        # is what makes runs with equal seeds reproducible.
        chosen = self.rng.choice(self.action_types)

        if chosen == "click":
            return BenchmarkAction(
                type="click",
                x=self.rng.random(),
                y=self.rng.random(),
            )
        if chosen == "type":
            return BenchmarkAction(
                type="type",
                text="test",
            )
        if chosen == "scroll":
            return BenchmarkAction(
                type="scroll",
                scroll_direction=self.rng.choice(["up", "down"]),
            )
        return BenchmarkAction(type="done")

    def reset(self) -> None:
        """No per-episode state to clear."""
class APIBenchmarkAgent(BenchmarkAgent):
    """Agent that uses hosted VLM APIs (Claude, GPT-5.1) for benchmark evaluation.

    This agent wraps ApiVLMAdapter to provide Claude or GPT-5.1 baselines
    for benchmark evaluation. It converts BenchmarkObservation to the
    API format and parses VLM responses into BenchmarkActions.

    The adapter is created lazily on first use (see ``_get_adapter``), so
    constructing this agent performs no network or credential work.

    Args:
        provider: API provider - "anthropic" (Claude) or "openai" (GPT-5.1).
        api_key: Optional API key override. If not provided, uses env vars.
        model: Optional model name override. Defaults to provider's best VLM.
        max_tokens: Maximum tokens for VLM response.
        use_accessibility_tree: Whether to include accessibility tree in prompt.
        use_history: Whether to include action history in prompt.

    Example:
        # Claude baseline
        agent = APIBenchmarkAgent(provider="anthropic")
        results = evaluate_agent_on_benchmark(agent, waa_adapter)

        # GPT-5.1 baseline
        agent = APIBenchmarkAgent(provider="openai")
        results = evaluate_agent_on_benchmark(agent, waa_adapter)
    """

    # System prompt for GUI automation. The action grammar here must stay in
    # sync with the parsers in _parse_response below.
    SYSTEM_PROMPT = """You are a GUI automation agent. Given a screenshot and task instruction, determine the next action to take.

Available actions:
- CLICK(x, y) - Click at coordinates (can be pixel values or normalized 0.0-1.0)
- CLICK([id]) - Click element with given ID from accessibility tree
- TYPE("text") - Type the given text
- KEY(key) - Press a key (e.g., Enter, Tab, Escape)
- KEY(modifier+key) - Press key combination (e.g., Ctrl+c, Alt+Tab)
- SCROLL(direction) - Scroll up or down
- DRAG(x1, y1, x2, y2) - Drag from (x1,y1) to (x2,y2) (pixel or normalized)
- DONE() - Task is complete
- ANSWER("response") - For QA tasks, provide the answer

Respond with exactly ONE action in the format shown above.
If the task appears complete, use DONE().

Think step by step:
1. What is the current state of the UI?
2. What is the goal?
3. What is the next logical action?

Then output the action on a new line starting with "ACTION:"
"""

    def __init__(
        self,
        provider: str = "anthropic",
        api_key: str | None = None,
        model: str | None = None,
        max_tokens: int = 512,
        use_accessibility_tree: bool = True,
        use_history: bool = True,
    ):
        self.provider = provider
        self.api_key = api_key
        # NOTE(review): `model` is stored but never passed to ApiVLMAdapter in
        # _get_adapter — confirm whether the adapter picks it up elsewhere.
        self.model = model
        self.max_tokens = max_tokens
        self.use_accessibility_tree = use_accessibility_tree
        self.use_history = use_history
        # Created on demand by _get_adapter().
        self._adapter: ApiVLMAdapter | None = None

    def _get_adapter(self) -> "ApiVLMAdapter":
        """Lazily initialize and cache the API adapter."""
        if self._adapter is None:
            # Imported here so the heavy adapter module is only loaded when
            # the agent is actually used.
            from openadapt_ml.models.api_adapter import ApiVLMAdapter

            self._adapter = ApiVLMAdapter(
                provider=self.provider,
                api_key=self.api_key,
            )
        return self._adapter

    def act(
        self,
        observation: BenchmarkObservation,
        task: BenchmarkTask,
        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None = None,
    ) -> BenchmarkAction:
        """Use VLM API to determine next action.

        Args:
            observation: Current observation with screenshot.
            task: Task being performed.
            history: Previous observations and actions.

        Returns:
            BenchmarkAction parsed from VLM response. On any API error the
            agent returns a "done" action carrying the error text, so one
            failing call cannot hang the episode.
        """
        adapter = self._get_adapter()

        # Build the sample for the API
        sample = self._build_sample(observation, task, history)

        # Call the VLM API
        try:
            response = adapter.generate(sample, max_new_tokens=self.max_tokens)
        except Exception as e:
            # On API error, return done to avoid infinite loops
            return BenchmarkAction(
                type="done",
                raw_action={"error": str(e)},
            )

        # Parse the response into a BenchmarkAction
        return self._parse_response(response, observation)

    def _build_sample(
        self,
        observation: BenchmarkObservation,
        task: BenchmarkTask,
        history: list[tuple[BenchmarkObservation, BenchmarkAction]] | None,
    ) -> dict[str, Any]:
        """Build API sample from benchmark observation.

        Args:
            observation: Current observation.
            task: Current task.
            history: Action history.

        Returns:
            Sample dict with 'messages' (system + user) and, when a
            screenshot is available, 'images'.
        """
        # Build user message content
        content_parts = [f"GOAL: {task.instruction}"]

        # Add context
        if observation.url:
            content_parts.append(f"URL: {observation.url}")
        if observation.window_title:
            content_parts.append(f"Window: {observation.window_title}")

        # Add accessibility tree if available and enabled
        if self.use_accessibility_tree and observation.accessibility_tree:
            tree_str = self._format_accessibility_tree(observation.accessibility_tree)
            # Truncate if too long — keeps prompt size bounded for the API.
            if len(tree_str) > 4000:
                tree_str = tree_str[:4000] + "\n... (truncated)"
            content_parts.append(f"UI Elements:\n{tree_str}")

        # Add history if enabled
        if self.use_history and history:
            history_str = self._format_history(history)
            content_parts.append(f"Previous actions:\n{history_str}")

        content_parts.append("\nWhat is the next action?")

        # Build sample
        sample: dict[str, Any] = {
            "messages": [
                {"role": "system", "content": self.SYSTEM_PROMPT},
                {"role": "user", "content": "\n\n".join(content_parts)},
            ],
        }

        # Add image if available
        if observation.screenshot_path:
            sample["images"] = [observation.screenshot_path]

        return sample

    def _format_accessibility_tree(self, tree: dict, indent: int = 0) -> str:
        """Format accessibility tree for prompt as an indented outline.

        Args:
            tree: Accessibility tree dict.
            indent: Current indentation level.

        Returns:
            Formatted string representation.
        """
        lines = []
        prefix = "  " * indent

        role = tree.get("role", "unknown")
        name = tree.get("name", "")
        # Some benchmark sources use "id", others "node_id"; accept either.
        node_id = tree.get("id", tree.get("node_id", ""))

        line = f"{prefix}[{node_id}] {role}"
        if name:
            line += f": {name}"
        lines.append(line)

        # Recurse one level deeper per child.
        for child in tree.get("children", []):
            lines.append(self._format_accessibility_tree(child, indent + 1))

        return "\n".join(lines)

    def _format_history(
        self, history: list[tuple[BenchmarkObservation, BenchmarkAction]]
    ) -> str:
        """Format action history for prompt as a numbered list.

        Args:
            history: List of (observation, action) pairs.

        Returns:
            Formatted string covering at most the last 5 actions.
        """
        lines = []
        for i, (obs, action) in enumerate(history[-5:], 1):  # Last 5 actions
            action_str = self._action_to_string(action)
            lines.append(f"{i}. {action_str}")
        return "\n".join(lines)

    def _action_to_string(self, action: BenchmarkAction) -> str:
        """Convert BenchmarkAction to the DSL-style string used in prompts.

        Args:
            action: Action to convert.

        Returns:
            String representation. Click rendering prefers node id, then
            target name, then raw coordinates.
        """
        if action.type == "click":
            if action.target_node_id:
                return f"CLICK([{action.target_node_id}])"
            if action.target_name:
                return f"CLICK({action.target_name})"
            return f"CLICK({action.x:.3f}, {action.y:.3f})"
        elif action.type == "type":
            return f"TYPE({action.text!r})"
        elif action.type == "key":
            mods = "+".join(action.modifiers or [])
            key = action.key
            if mods:
                return f"KEY({mods}+{key})"
            return f"KEY({key})"
        elif action.type == "scroll":
            return f"SCROLL({action.scroll_direction})"
        elif action.type == "drag":
            return f"DRAG({action.x:.3f}, {action.y:.3f}, {action.end_x:.3f}, {action.end_y:.3f})"
        elif action.type == "done":
            return "DONE()"
        elif action.type == "answer":
            return f"ANSWER({action.answer!r})"
        else:
            return f"{action.type.upper()}()"

    def _parse_response(
        self, response: str, observation: BenchmarkObservation | None = None
    ) -> BenchmarkAction:
        """Parse VLM response into BenchmarkAction.

        Handles various response formats:
        - ACTION: CLICK(0.5, 0.3)
        - CLICK(0.5, 0.3)
        - I'll click at coordinates (0.5, 0.3) -> CLICK(0.5, 0.3)

        Any response that cannot be parsed falls through to a "done"
        action carrying a "parse_error" entry in raw_action.

        Args:
            response: Raw VLM response text.
            observation: Current observation (used for coordinate normalization).

        Returns:
            Parsed BenchmarkAction.
        """
        # Store raw response for debugging
        raw_action: dict[str, Any] = {"response": response}

        # Extract action line (look for ACTION: prefix or action pattern)
        action_line = None

        # Try to find ACTION: prefix
        action_match = re.search(r"ACTION:\s*(.+)", response, re.IGNORECASE)
        if action_match:
            action_line = action_match.group(1).strip()
        else:
            # Look for action pattern anywhere in response.
            # NOTE(review): `[^)]+` stops at the first ')', so an argument
            # containing ')' — e.g. TYPE("a (b)") — gets truncated here and
            # may then fail the stricter parsers below; confirm acceptable.
            patterns = [
                r"(CLICK\s*\([^)]+\))",
                r"(TYPE\s*\([^)]+\))",
                r"(KEY\s*\([^)]+\))",
                r"(SCROLL\s*\([^)]+\))",
                r"(DRAG\s*\([^)]+\))",
                r"(DONE\s*\(\s*\))",
                r"(ANSWER\s*\([^)]+\))",
            ]
            for pattern in patterns:
                match = re.search(pattern, response, re.IGNORECASE)
                if match:
                    action_line = match.group(1).strip()
                    break

        if not action_line:
            # Could not parse action, return done
            raw_action["parse_error"] = "No action pattern found"
            return BenchmarkAction(type="done", raw_action=raw_action)

        # Parse CLICK action.
        # NOTE(review): the id form is tried first, so a bare integer like
        # CLICK(5) is always treated as an element id, never as a pixel
        # coordinate — confirm this matches the prompt grammar's intent.
        click_match = re.match(
            r"CLICK\s*\(\s*\[?(\d+)\]?\s*\)", action_line, re.IGNORECASE
        )
        if click_match:
            # CLICK([id]) - element ID
            node_id = click_match.group(1)
            return BenchmarkAction(
                type="click",
                target_node_id=node_id,
                raw_action=raw_action,
            )

        click_coords = re.match(
            r"CLICK\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)", action_line, re.IGNORECASE
        )
        if click_coords:
            # CLICK(x, y) - coordinates
            x = float(click_coords.group(1))
            y = float(click_coords.group(2))

            # Normalize coordinates if they appear to be pixel values
            # If x or y > 1.0, assume pixel coordinates and normalize using viewport
            if observation and observation.viewport and (x > 1.0 or y > 1.0):
                width, height = observation.viewport
                x_norm = x / width
                y_norm = y / height
                # Keep the originals for debugging/auditing.
                raw_action["original_coords"] = {"x": x, "y": y}
                raw_action["normalized"] = True
                x = x_norm
                y = y_norm

            return BenchmarkAction(
                type="click",
                x=x,
                y=y,
                raw_action=raw_action,
            )

        # Parse TYPE action
        type_match = re.match(
            r"TYPE\s*\(\s*[\"'](.+?)[\"']\s*\)", action_line, re.IGNORECASE
        )
        if type_match:
            text = type_match.group(1)
            return BenchmarkAction(
                type="type",
                text=text,
                raw_action=raw_action,
            )

        # Parse KEY action
        key_match = re.match(r"KEY\s*\(\s*(.+?)\s*\)", action_line, re.IGNORECASE)
        if key_match:
            key_str = key_match.group(1)
            # Handle modifier+key format: everything before the last '+' is
            # treated as modifiers, the final token is the key itself.
            if "+" in key_str:
                parts = key_str.split("+")
                key = parts[-1]
                modifiers = parts[:-1]
                return BenchmarkAction(
                    type="key",
                    key=key,
                    modifiers=modifiers,
                    raw_action=raw_action,
                )
            return BenchmarkAction(
                type="key",
                key=key_str,
                raw_action=raw_action,
            )

        # Parse SCROLL action
        scroll_match = re.match(
            r"SCROLL\s*\(\s*(up|down)\s*\)", action_line, re.IGNORECASE
        )
        if scroll_match:
            direction = scroll_match.group(1).lower()
            return BenchmarkAction(
                type="scroll",
                scroll_direction=direction,
                raw_action=raw_action,
            )

        # Parse DRAG action
        drag_match = re.match(
            r"DRAG\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)\s*\)",
            action_line,
            re.IGNORECASE,
        )
        if drag_match:
            x = float(drag_match.group(1))
            y = float(drag_match.group(2))
            end_x = float(drag_match.group(3))
            end_y = float(drag_match.group(4))

            # Normalize coordinates if they appear to be pixel values
            if observation and observation.viewport and (x > 1.0 or y > 1.0 or end_x > 1.0 or end_y > 1.0):
                width, height = observation.viewport
                raw_action["original_coords"] = {"x": x, "y": y, "end_x": end_x, "end_y": end_y}
                raw_action["normalized"] = True
                x = x / width
                y = y / height
                end_x = end_x / width
                end_y = end_y / height

        return BenchmarkAction(
                type="drag",
                x=x,
                y=y,
                end_x=end_x,
                end_y=end_y,
                raw_action=raw_action,
            )

        # Parse DONE action
        if re.match(r"DONE\s*\(\s*\)", action_line, re.IGNORECASE):
            return BenchmarkAction(type="done", raw_action=raw_action)

        # Parse ANSWER action
        answer_match = re.match(
            r"ANSWER\s*\(\s*[\"'](.+?)[\"']\s*\)", action_line, re.IGNORECASE
        )
        if answer_match:
            answer = answer_match.group(1)
            return BenchmarkAction(
                type="answer",
                answer=answer,
                raw_action=raw_action,
            )

        # Unknown action format
        raw_action["parse_error"] = f"Unknown action format: {action_line}"
        return BenchmarkAction(type="done", raw_action=raw_action)

    def reset(self) -> None:
        """Reset agent state."""
        # APIBenchmarkAgent is stateless, nothing to reset
        pass