openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares the contents of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/experiments/waa_demo/runner.py
@@ -0,0 +1,732 @@
+"""Runner for WAA demo-conditioned experiment.
+
+Usage:
+    # List all tasks and demo status
+    python -m openadapt_ml.experiments.waa_demo.runner list
+
+    # Show a specific demo
+    python -m openadapt_ml.experiments.waa_demo.runner show 8
+
+    # Run experiment (requires WAA environment)
+    python -m openadapt_ml.experiments.waa_demo.runner run --condition demo
+
+    # Run with mock adapter (no Windows required)
+    python -m openadapt_ml.experiments.waa_demo.runner run --condition demo --mock
+
+Integration with benchmarks runner:
+    # Via benchmarks CLI
+    python -m openadapt_ml.benchmarks.cli waa-demo --condition demo --tasks 5
+"""
+
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+from typing import TYPE_CHECKING, Any
+
+from openadapt_ml.experiments.waa_demo.demos import (
+    format_demo_for_prompt,
+    get_complete_demos,
+    get_demo,
+    get_placeholder_demos,
+)
+from openadapt_ml.experiments.waa_demo.tasks import (
+    TASKS,
+    get_recorded_tasks,
+    get_task,
+)
+
+if TYPE_CHECKING:
+    from openadapt_evals import (
+        BenchmarkAction,
+        BenchmarkObservation,
+        BenchmarkTask,
+    )
+
+logger = logging.getLogger(__name__)
+
+
+def cmd_list(args: argparse.Namespace) -> int:
+    """List all tasks with their demo status."""
+    print("WAA Demo Experiment - Task List")
+    print("=" * 80)
+    print()
+
+    complete = get_complete_demos()
+    placeholder = get_placeholder_demos()
+
+    print(f"Tasks: {len(TASKS)} total")
+    print(f" Manual demos written: {len(complete)}")
+    print(f" Recorded demos needed: {len(placeholder)}")
+    print()
+    print("-" * 80)
+    print(f"{'#':<3} {'Domain':<18} {'Difficulty':<8} {'Demo':<10} {'Instruction'}")
+    print("-" * 80)
+
+    for num, task in TASKS.items():
+        demo_status = "Ready" if num in complete else "NEEDS REC"
+        print(
+            f"{num:<3} {task.domain.value:<18} {task.difficulty.value:<8} "
+            f"{demo_status:<10} {task.instruction[:45]}..."
+        )
+
+    print()
+    print("Tasks needing recorded demos on Windows:")
+    for task in get_recorded_tasks():
+        print(
+            f" - #{list(TASKS.keys())[list(TASKS.values()).index(task)]}: {task.instruction}"
+        )
+
+    return 0
+
+
+def cmd_show(args: argparse.Namespace) -> int:
+    """Show a specific demo."""
+    task_num = args.task
+    task = get_task(task_num)
+    demo = get_demo(task_num)
+
+    if not task:
+        print(f"Error: Task {task_num} not found (valid: 1-10)")
+        return 1
+
+    print(f"Task #{task_num}: {task.instruction}")
+    print(f"Domain: {task.domain.value}")
+    print(f"Difficulty: {task.difficulty.value}")
+    print(f"Demo method: {task.demo_method}")
+    print()
+    print("=" * 80)
+    print("DEMO:")
+    print("=" * 80)
+    print(demo or "No demo available")
+
+    return 0
+
+
+def cmd_prompt(args: argparse.Namespace) -> int:
+    """Generate a prompt for a task with optional demo."""
+    task_num = args.task
+    task = get_task(task_num)
+    demo = get_demo(task_num) if args.with_demo else None
+
+    if not task:
+        print(f"Error: Task {task_num} not found")
+        return 1
+
+    print("=" * 80)
+    print("GENERATED PROMPT")
+    print("=" * 80)
+    print()
+
+    if demo and "[PLACEHOLDER" not in demo:
+        prompt = format_demo_for_prompt(demo, task.instruction)
+        print(prompt)
+    else:
+        print(f"Task: {task.instruction}")
+        print()
+        print(
+            "Analyze the screenshot and provide the next action to complete this task."
+        )
+        if demo and "[PLACEHOLDER" in demo:
+            print()
+            print("[Note: Demo not available - this would be zero-shot]")
+
+    return 0
+
+
+class DemoConditionedAgent:
+    """Agent that uses demo-conditioned prompting for WAA tasks.
+
+    This agent extends the APIBenchmarkAgent approach but injects relevant
+    demos into the prompt based on the current task. It supports:
+    - Zero-shot mode: Standard VLM prompting without demos
+    - Demo-conditioned mode: Includes task-specific demonstration in prompt
+
+    The demo-conditioned approach was validated to improve first-action accuracy
+    from 33% (zero-shot) to 100% (with demo) in initial experiments.
+
+    Args:
+        provider: API provider ("anthropic" or "openai")
+        condition: "zero-shot" or "demo"
+        api_key: Optional API key override
+        model: Optional model name override
+        max_tokens: Maximum tokens for response
+        use_accessibility_tree: Include accessibility tree in prompt
+        use_history: Include action history in prompt
+
+    Example:
+        agent = DemoConditionedAgent(provider="anthropic", condition="demo")
+        results = evaluate_agent_on_benchmark(agent, waa_adapter)
+    """
+
+    # System prompt for demo-conditioned GUI automation
+    SYSTEM_PROMPT = """You are a GUI automation agent. Given a screenshot and task instruction, determine the next action to take.
+
+Available actions:
+- CLICK(x, y) - Click at coordinates (normalized 0.0-1.0 or pixels)
+- CLICK([id]) - Click element with given ID from accessibility tree
+- TYPE("text") - Type the given text
+- KEY(key) - Press a key (e.g., Enter, Tab, Escape)
+- KEY(modifier+key) - Press key combination (e.g., Ctrl+c, Alt+Tab)
+- SCROLL(direction) - Scroll up or down
+- DONE() - Task is complete
+
+If a demonstration is provided, use it as a reference for understanding the UI navigation pattern.
+Focus on the current state of the screen and select the appropriate next action.
+
+Respond with exactly ONE action in the format shown above.
+Think step by step, then output the action on a new line starting with "ACTION:"
+"""
+
+    def __init__(
+        self,
+        provider: str = "anthropic",
+        condition: str = "demo",
+        api_key: str | None = None,
+        model: str | None = None,
+        max_tokens: int = 512,
+        use_accessibility_tree: bool = True,
+        use_history: bool = True,
+    ):
+        self.provider = provider
+        self.condition = condition
+        self.api_key = api_key
+        self.model = model
+        self.max_tokens = max_tokens
+        self.use_accessibility_tree = use_accessibility_tree
+        self.use_history = use_history
+        self._adapter = None
+        self._task_demo_map: dict[str, str] = {}
+        self._build_task_demo_map()
+
+    def _build_task_demo_map(self) -> None:
+        """Build mapping from WAA task IDs to demo text."""
+        for task_num, task in TASKS.items():
+            demo = get_demo(task_num)
+            if demo and "[PLACEHOLDER" not in demo:
+                # Map both task number and full task_id
+                self._task_demo_map[task_num] = demo
+                self._task_demo_map[task.task_id] = demo
+
+    def _get_adapter(self):
+        """Lazily initialize the API adapter."""
+        if self._adapter is None:
+            from openadapt_ml.models.api_adapter import ApiVLMAdapter
+
+            self._adapter = ApiVLMAdapter(
+                provider=self.provider,
+                api_key=self.api_key,
+            )
+        return self._adapter
+
+    def _get_demo_for_task(self, task: "BenchmarkTask") -> str | None:
+        """Get the demo for a task if available.
+
+        Args:
+            task: The benchmark task
+
+        Returns:
+            Demo text or None if not available
+        """
+        if self.condition == "zero-shot":
+            return None
+
+        # Try to find demo by task_id
+        task_id = task.task_id
+
+        # Check direct mapping
+        if task_id in self._task_demo_map:
+            return self._task_demo_map[task_id]
+
+        # Try to extract task number from task_id patterns
+        for task_num, wa_task in TASKS.items():
+            if wa_task.task_id in task_id or task_id in wa_task.task_id:
+                return self._task_demo_map.get(task_num)
+
+        # Check if instruction matches
+        for task_num, wa_task in TASKS.items():
+            if wa_task.instruction.lower() in task.instruction.lower():
+                return self._task_demo_map.get(task_num)
+
+        return None
+
+    def act(
+        self,
+        observation: "BenchmarkObservation",
+        task: "BenchmarkTask",
+        history: list[tuple["BenchmarkObservation", "BenchmarkAction"]] | None = None,
+    ) -> "BenchmarkAction":
+        """Use VLM API with optional demo to determine next action.
+
+        Args:
+            observation: Current observation with screenshot
+            task: Task being performed
+            history: Previous observations and actions
+
+        Returns:
+            BenchmarkAction parsed from VLM response
+        """
+        from openadapt_evals import BenchmarkAction
+
+        adapter = self._get_adapter()
+
+        # Build the sample for the API
+        sample = self._build_sample(observation, task, history)
+
+        # Call the VLM API
+        try:
+            response = adapter.generate(sample, max_new_tokens=self.max_tokens)
+        except Exception as e:
+            logger.error(f"API error: {e}")
+            return BenchmarkAction(
+                type="done",
+                raw_action={"error": str(e)},
+            )
+
+        # Parse the response into a BenchmarkAction
+        return self._parse_response(response, observation)
+
+    def _build_sample(
+        self,
+        observation: "BenchmarkObservation",
+        task: "BenchmarkTask",
+        history: list[tuple["BenchmarkObservation", "BenchmarkAction"]] | None,
+    ) -> dict[str, Any]:
+        """Build API sample with optional demo.
+
+        Args:
+            observation: Current observation
+            task: Current task
+            history: Action history
+
+        Returns:
+            Sample dict with 'images' and 'messages'
+        """
+        content_parts = []
+
+        # Add demo if available and in demo condition
+        demo = self._get_demo_for_task(task)
+        if demo:
+            formatted_demo = format_demo_for_prompt(demo, task.instruction)
+            content_parts.append(formatted_demo)
+        else:
+            content_parts.append(f"GOAL: {task.instruction}")
+
+        # Add context
+        if observation.url:
+            content_parts.append(f"URL: {observation.url}")
+        if observation.window_title:
+            content_parts.append(f"Window: {observation.window_title}")
+
+        # Add accessibility tree if available and enabled
+        if self.use_accessibility_tree and observation.accessibility_tree:
+            tree_str = self._format_accessibility_tree(observation.accessibility_tree)
+            if len(tree_str) > 4000:
+                tree_str = tree_str[:4000] + "\n... (truncated)"
+            content_parts.append(f"UI Elements:\n{tree_str}")
+
+        # Add history if enabled
+        if self.use_history and history:
+            history_str = self._format_history(history)
+            content_parts.append(f"Previous actions:\n{history_str}")
+
+        content_parts.append(
+            "\nAnalyze the current screenshot and provide the next action."
+        )
+
+        sample: dict[str, Any] = {
+            "messages": [
+                {"role": "system", "content": self.SYSTEM_PROMPT},
+                {"role": "user", "content": "\n\n".join(content_parts)},
+            ],
+        }
+
+        if observation.screenshot_path:
+            sample["images"] = [observation.screenshot_path]
+
+        return sample
+
+    def _format_accessibility_tree(self, tree: dict, indent: int = 0) -> str:
+        """Format accessibility tree for prompt."""
+        lines = []
+        prefix = " " * indent
+
+        role = tree.get("role", "unknown")
+        name = tree.get("name", "")
+        node_id = tree.get("id", tree.get("node_id", ""))
+
+        line = f"{prefix}[{node_id}] {role}"
+        if name:
+            line += f": {name}"
+        lines.append(line)
+
+        for child in tree.get("children", []):
+            lines.append(self._format_accessibility_tree(child, indent + 1))
+
+        return "\n".join(lines)
+
+    def _format_history(
+        self, history: list[tuple["BenchmarkObservation", "BenchmarkAction"]]
+    ) -> str:
+        """Format action history for prompt."""
+        lines = []
+        for i, (obs, action) in enumerate(history[-5:], 1):
+            action_str = self._action_to_string(action)
+            lines.append(f"{i}. {action_str}")
+        return "\n".join(lines)
+
+    def _action_to_string(self, action: "BenchmarkAction") -> str:
+        """Convert BenchmarkAction to string."""
+        if action.type == "click":
+            if action.target_node_id:
+                return f"CLICK([{action.target_node_id}])"
+            if action.target_name:
+                return f"CLICK({action.target_name})"
+            if action.x is not None and action.y is not None:
+                return f"CLICK({action.x:.3f}, {action.y:.3f})"
+            return "CLICK()"
+        elif action.type == "type":
+            return f"TYPE({action.text!r})"
+        elif action.type == "key":
+            mods = "+".join(action.modifiers or [])
+            key = action.key or ""
+            if mods:
+                return f"KEY({mods}+{key})"
+            return f"KEY({key})"
+        elif action.type == "scroll":
+            return f"SCROLL({action.scroll_direction})"
+        elif action.type == "done":
+            return "DONE()"
+        else:
+            return f"{action.type.upper()}()"
+
+    def _parse_response(
+        self, response: str, observation: "BenchmarkObservation" | None = None
+    ) -> "BenchmarkAction":
+        """Parse VLM response into BenchmarkAction.
+
+        Uses the same parsing logic as APIBenchmarkAgent.
+        """
+        import re
+        from openadapt_evals import BenchmarkAction
+
+        raw_action = {"response": response}
+
+        # Extract action line
+        action_line = None
+        action_match = re.search(r"ACTION:\s*(.+)", response, re.IGNORECASE)
+        if action_match:
+            action_line = action_match.group(1).strip()
+        else:
+            patterns = [
+                r"(CLICK\s*\([^)]+\))",
+                r"(TYPE\s*\([^)]+\))",
+                r"(KEY\s*\([^)]+\))",
+                r"(SCROLL\s*\([^)]+\))",
+                r"(DONE\s*\(\s*\))",
+            ]
+            for pattern in patterns:
+                match = re.search(pattern, response, re.IGNORECASE)
+                if match:
+                    action_line = match.group(1).strip()
+                    break
+
+        if not action_line:
+            raw_action["parse_error"] = "No action pattern found"
+            return BenchmarkAction(type="done", raw_action=raw_action)
+
+        # Parse CLICK with element ID
+        click_id_match = re.match(
+            r"CLICK\s*\(\s*\[?(\d+)\]?\s*\)", action_line, re.IGNORECASE
+        )
+        if click_id_match:
+            return BenchmarkAction(
+                type="click",
+                target_node_id=click_id_match.group(1),
+                raw_action=raw_action,
+            )
+
+        # Parse CLICK with coordinates
+        click_coords = re.match(
+            r"CLICK\s*\(\s*([\d.]+)\s*,\s*([\d.]+)\s*\)", action_line, re.IGNORECASE
+        )
+        if click_coords:
+            x = float(click_coords.group(1))
+            y = float(click_coords.group(2))
+            if observation and observation.viewport and (x > 1.0 or y > 1.0):
+                width, height = observation.viewport
+                x = x / width
+                y = y / height
+            return BenchmarkAction(type="click", x=x, y=y, raw_action=raw_action)
+
+        # Parse TYPE
+        type_match = re.match(
+            r"TYPE\s*\(\s*[\"'](.+?)[\"']\s*\)", action_line, re.IGNORECASE
+        )
+        if type_match:
+            return BenchmarkAction(
+                type="type", text=type_match.group(1), raw_action=raw_action
+            )
+
+        # Parse KEY
+        key_match = re.match(r"KEY\s*\(\s*(.+?)\s*\)", action_line, re.IGNORECASE)
+        if key_match:
+            key_str = key_match.group(1)
+            if "+" in key_str:
+                parts = key_str.split("+")
+                return BenchmarkAction(
+                    type="key",
+                    key=parts[-1],
+                    modifiers=parts[:-1],
+                    raw_action=raw_action,
+                )
+            return BenchmarkAction(type="key", key=key_str, raw_action=raw_action)
+
+        # Parse SCROLL
+        scroll_match = re.match(
+            r"SCROLL\s*\(\s*(up|down)\s*\)", action_line, re.IGNORECASE
+        )
+        if scroll_match:
+            return BenchmarkAction(
+                type="scroll",
+                scroll_direction=scroll_match.group(1).lower(),
+                raw_action=raw_action,
+            )
+
+        # Parse DONE
+        if re.match(r"DONE\s*\(\s*\)", action_line, re.IGNORECASE):
+            return BenchmarkAction(type="done", raw_action=raw_action)
+
+        raw_action["parse_error"] = f"Unknown action format: {action_line}"
+        return BenchmarkAction(type="done", raw_action=raw_action)
+
+    def reset(self) -> None:
+        """Reset agent state between episodes."""
+        pass
+
+
+def cmd_run(args: argparse.Namespace) -> int:
+    """Run the WAA demo-conditioned experiment.
+
+    This integrates with the benchmarks infrastructure to run either
+    zero-shot or demo-conditioned evaluation on WAA tasks.
+    """
+    from openadapt_evals import (
+        EvaluationConfig,
+        WAAMockAdapter,
+        compute_metrics,
+        evaluate_agent_on_benchmark,
+    )
+
+    print("WAA Demo-Conditioned Experiment Runner")
+    print("=" * 80)
+    print()
+    print(f"Condition: {args.condition}")
+    print(f"Provider: {args.provider}")
+    print(f"Tasks: {args.tasks or 'all with demos'}")
+    print()
+
+    # Determine which tasks to run
+    task_ids = None
+    if args.tasks:
+        task_nums = [t.strip() for t in args.tasks.split(",")]
+        # Map task numbers to WAA task IDs
+        task_ids = []
+        for num in task_nums:
+            task = get_task(num)
+            if task:
+                task_ids.append(task.task_id)
+            else:
+                print(f"Warning: Task {num} not found")
+    else:
+        # Default to all tasks with complete demos
+        complete_demos = get_complete_demos()
+        task_ids = []
+        for num in complete_demos.keys():
+            task = get_task(num)
+            if task:
+                task_ids.append(task.task_id)
+        print(f"Running {len(task_ids)} tasks with complete demos")
+
+    # Check for mock mode or real WAA
+    use_mock = getattr(args, "mock", False)
+
+    if use_mock:
+        print("Using mock adapter (no Windows required)")
+        adapter = WAAMockAdapter(num_tasks=len(task_ids) if task_ids else 10)
+        # Override task_ids since mock adapter has different IDs
+        task_ids = None
+    elif args.waa_url:
+        print(f"WAA URL: {args.waa_url}")
+        print("Note: Real WAA integration requires a running Windows VM")
+        print()
+        print("To set up WAA:")
+        print(" uv run python -m openadapt_ml.benchmarks.cli vm setup-waa")
+        print(" uv run python -m openadapt_ml.benchmarks.cli vm prepare-windows")
+        print()
+        # For now, fall back to mock since we can't connect to real WAA without VM
+        print("Falling back to mock adapter for demonstration...")
+        adapter = WAAMockAdapter(num_tasks=len(task_ids) if task_ids else 10)
+        task_ids = None
+    else:
+        print("No WAA URL provided, using mock adapter")
+        adapter = WAAMockAdapter(num_tasks=len(task_ids) if task_ids else 10)
+        task_ids = None
+
+    # Create the demo-conditioned agent
+    agent = DemoConditionedAgent(
+        provider=args.provider,
+        condition=args.condition,
+        max_tokens=512,
+        use_accessibility_tree=True,
+        use_history=True,
+    )
+
+    # Configure evaluation
+    config = EvaluationConfig(
+        max_steps=args.max_steps,
+        parallel=1,
+        save_trajectories=True,
+        save_execution_traces=True,
+        model_id=f"{args.provider}-{args.condition}",
+        output_dir=args.output or "benchmark_results",
+        run_name=args.run_name,
+        verbose=True,
+    )
+
+    print()
+    print("Starting evaluation...")
+    print("(Each step calls the VLM API - this may take a while)")
+    print()
+
+    try:
+        results = evaluate_agent_on_benchmark(
+            agent=agent,
+            adapter=adapter,
+            task_ids=task_ids,
+            config=config,
+        )
+    except Exception as e:
+        print(f"Error during evaluation: {e}")
+        if "API key" in str(e) or "api_key" in str(e).lower():
+            key_name = (
+                "ANTHROPIC_API_KEY"
+                if args.provider == "anthropic"
+                else "OPENAI_API_KEY"
+            )
+            print(f"\nMake sure {key_name} is set in your environment or .env file.")
+        return 1
+
+    # Print results
+    metrics = compute_metrics(results)
+    print()
+    print("=" * 80)
+    print("RESULTS")
+    print("=" * 80)
+    print(f"Condition: {args.condition}")
+    print(f"Tasks: {metrics['num_tasks']}")
+    print(f"Success rate: {metrics['success_rate']:.1%}")
+    print(f"Successes: {metrics['success_count']}")
+    print(f"Failures: {metrics['fail_count']}")
+    print(f"Avg steps: {metrics['avg_steps']:.1f}")
+    print()
+
+    # Show per-task results
+    print("Per-task results:")
+    for result in results:
+        status = "PASS" if result.success else "FAIL"
+        print(f" {result.task_id}: {status} ({result.num_steps} steps)")
+
+    return 0
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(
+        description="WAA Demo-Conditioned Experiment Runner",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        epilog="""
+Examples:
+  # List tasks and their demo status
+  python -m openadapt_ml.experiments.waa_demo.runner list
+
+  # Show a specific demo
+  python -m openadapt_ml.experiments.waa_demo.runner show 8
+
+  # Run with demo conditioning (mock adapter, no Windows needed)
+  python -m openadapt_ml.experiments.waa_demo.runner run --condition demo --mock
+
+  # Run zero-shot for comparison
+  python -m openadapt_ml.experiments.waa_demo.runner run --condition zero-shot --mock
+
+  # Run with real WAA (requires running Windows VM)
+  python -m openadapt_ml.experiments.waa_demo.runner run --condition demo --waa-url http://<vm-ip>:5000
+""",
+    )
+    subparsers = parser.add_subparsers(dest="command", required=True)
+
+    # list command
+    list_parser = subparsers.add_parser("list", help="List all tasks")
+    list_parser.set_defaults(func=cmd_list)
+
+    # show command
+    show_parser = subparsers.add_parser("show", help="Show a specific demo")
+    show_parser.add_argument("task", help="Task number (1-10)")
+    show_parser.set_defaults(func=cmd_show)
+
+    # prompt command
+    prompt_parser = subparsers.add_parser("prompt", help="Generate prompt for a task")
+    prompt_parser.add_argument("task", help="Task number (1-10)")
+    prompt_parser.add_argument("--with-demo", action="store_true", help="Include demo")
+    prompt_parser.set_defaults(func=cmd_prompt)
+
+    # run command
+    run_parser = subparsers.add_parser("run", help="Run experiment")
+    run_parser.add_argument(
+        "--condition",
+        choices=["zero-shot", "demo"],
+        default="demo",
+        help="Experiment condition (default: demo)",
+    )
+    run_parser.add_argument(
+        "--provider",
+        choices=["anthropic", "openai"],
+        default="anthropic",
+        help="VLM API provider (default: anthropic)",
+    )
+    run_parser.add_argument(
+        "--tasks",
+        help="Comma-separated task numbers (default: all with demos)",
+    )
+    run_parser.add_argument(
+        "--max-steps",
+        type=int,
+        default=15,
+        help="Maximum steps per task (default: 15)",
+    )
+    run_parser.add_argument(
+        "--mock",
+        action="store_true",
+        help="Use mock adapter (no Windows required)",
+    )
+    run_parser.add_argument(
+        "--waa-url",
+        help="WAA server URL (e.g., http://vm-ip:5000)",
+    )
+    run_parser.add_argument(
+        "--output",
+        default="benchmark_results",
+        help="Output directory (default: benchmark_results)",
+    )
+    run_parser.add_argument(
+        "--run-name",
+        help="Run name (default: auto-generated)",
+    )
+    run_parser.set_defaults(func=cmd_run)
+
+    args = parser.parse_args()
+    return args.func(args)
+
+
+if __name__ == "__main__":
+    sys.exit(main())
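
For orientation, the snippet below is a minimal programmatic sketch of the same flow the new `cmd_run --mock` path wires up. It assumes `openadapt_evals` is installed and exposes the names that `runner.py` itself imports (`WAAMockAdapter`, `EvaluationConfig`, `evaluate_agent_on_benchmark`, `compute_metrics`), and that the `EvaluationConfig` fields not set here have usable defaults; it is an illustration, not code shipped in the wheel.

    # Sketch: run the demo-conditioned agent against the mock WAA adapter.
    # Assumes openadapt_evals exposes the same names runner.py imports above;
    # EvaluationConfig fields not set here are assumed to have defaults.
    from openadapt_evals import (
        EvaluationConfig,
        WAAMockAdapter,
        compute_metrics,
        evaluate_agent_on_benchmark,
    )

    from openadapt_ml.experiments.waa_demo.runner import DemoConditionedAgent

    agent = DemoConditionedAgent(provider="anthropic", condition="demo")
    adapter = WAAMockAdapter(num_tasks=3)  # mock adapter, no Windows VM needed
    config = EvaluationConfig(max_steps=15, model_id="anthropic-demo", verbose=True)

    results = evaluate_agent_on_benchmark(agent=agent, adapter=adapter, config=config)
    metrics = compute_metrics(results)
    print(f"Success rate: {metrics['success_rate']:.1%} over {metrics['num_tasks']} tasks")

The same comparison can be made for the zero-shot baseline by constructing the agent with `condition="zero-shot"`, mirroring the CLI's `run --condition zero-shot --mock`.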