openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/benchmarks/waa.py
@@ -0,0 +1,704 @@
"""Windows Agent Arena (WAA) benchmark adapter.

This module provides integration with the Windows Agent Arena benchmark,
enabling evaluation of GUI agents on 154 Windows tasks across 11 domains.

WAA Repository: https://github.com/microsoft/WindowsAgentArena

Example:
    from openadapt_ml.benchmarks import WAAAdapter, PolicyAgent, evaluate_agent_on_benchmark

    adapter = WAAAdapter(waa_repo_path="/path/to/WindowsAgentArena")
    agent = PolicyAgent(policy)
    results = evaluate_agent_on_benchmark(agent, adapter, max_steps=15)
    print(f"Success rate: {sum(r.success for r in results) / len(results):.1%}")
"""

from __future__ import annotations

import json
import logging
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from openadapt_ml.benchmarks.base import (
    BenchmarkAction,
    BenchmarkAdapter,
    BenchmarkObservation,
    BenchmarkResult,
    BenchmarkTask,
)

logger = logging.getLogger(__name__)


# WAA domain mapping (11 domains, 154 tasks)
WAA_DOMAINS = [
    "browser",
    "office",
    "coding",
    "media",
    "notepad",
    "paint",
    "file_explorer",
    "clock",
    "settings",
    "edge",
    "vscode",
]


@dataclass
class WAAConfig:
    """Configuration for WAA adapter.

    Attributes:
        waa_repo_path: Path to cloned WindowsAgentArena repository.
        use_azure: Whether to use Azure VMs (enables parallelism).
        observation_type: Type of observation to capture.
        a11y_backend: Accessibility backend ("uia" or "win32").
        screen_width: Screen width in pixels.
        screen_height: Screen height in pixels.
        max_steps: Default maximum steps per task.
        action_delay: Delay between actions in seconds.
    """

    waa_repo_path: str
    use_azure: bool = False
    observation_type: str = "screenshot_a11y_tree"  # "screenshot", "a11y_tree", "som"
    a11y_backend: str = "uia"  # "uia" or "win32"
    screen_width: int = 1920
    screen_height: int = 1200
    max_steps: int = 15
    action_delay: float = 0.5
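

# A minimal construction sketch, assuming a local WindowsAgentArena checkout at
# the path shown; every WAAConfig field besides waa_repo_path has a default, so
# overrides are opt-in. The config can also be assembled implicitly by passing
# the same keyword arguments straight to WAAAdapter.
def _sketch_build_adapter() -> WAAAdapter:
    config = WAAConfig(
        waa_repo_path="/path/to/WindowsAgentArena",
        observation_type="screenshot",  # skip the a11y tree for faster captures
        action_delay=1.0,
    )
    return WAAAdapter(config=config)

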
class WAAAdapter(BenchmarkAdapter):
    """Windows Agent Arena benchmark adapter.

    Integrates with the WAA benchmark to evaluate GUI agents on 154 Windows
    desktop automation tasks spanning 11 application domains.

    The adapter wraps WAA's DesktopEnv and provides:
    - Task loading from WAA's JSON task definitions
    - VM/environment reset to task initial state
    - Action execution via WAA's controller
    - Evaluation using WAA's native evaluators

    Args:
        waa_repo_path: Path to cloned WindowsAgentArena repository.
        use_azure: Use Azure VMs for execution (enables parallelism).
        config: Full WAAConfig (overrides other args if provided).
        **kwargs: Additional config options passed to WAAConfig.

    Raises:
        ValueError: If waa_repo_path doesn't exist.
        ImportError: If WAA dependencies not available.
    """

    def __init__(
        self,
        waa_repo_path: str | Path | None = None,
        use_azure: bool = False,
        config: WAAConfig | None = None,
        **kwargs,
    ):
        if config is not None:
            self.config = config
        else:
            if waa_repo_path is None:
                raise ValueError("waa_repo_path is required")
            self.config = WAAConfig(
                waa_repo_path=str(waa_repo_path),
                use_azure=use_azure,
                **kwargs,
            )

        self.waa_repo = Path(self.config.waa_repo_path)
        if not self.waa_repo.exists():
            raise ValueError(f"WAA repository not found at: {self.waa_repo}")

        # Paths to WAA components
        self._client_path = self.waa_repo / "src" / "win-arena-container" / "client"
        self._tasks_path = self._client_path / "evaluation_examples_windows"

        # Lazy-loaded WAA components
        self._desktop_env = None
        self._task_cache: dict[str, BenchmarkTask] = {}
        self._current_task: BenchmarkTask | None = None
        self._waa_imported = False

    def _ensure_waa_imported(self) -> None:
        """Import WAA modules (lazy loading)."""
        if self._waa_imported:
            return

        # Add WAA client to path
        client_path = str(self._client_path)
        if client_path not in sys.path:
            sys.path.insert(0, client_path)

        try:
            # Import WAA's DesktopEnv
            from desktop_env import DesktopEnv

            self._DesktopEnv = DesktopEnv
            self._waa_imported = True
            logger.info("WAA modules imported successfully")
        except ImportError as e:
            raise ImportError(
                f"Failed to import WAA modules. Ensure WAA is properly installed "
                f"and dependencies are available: {e}"
            ) from e

    @property
    def name(self) -> str:
        """Benchmark name."""
        return "waa"

    @property
    def benchmark_type(self) -> str:
        """Benchmark type (interactive)."""
        return "interactive"

    @property
    def supports_parallel(self) -> bool:
        """Whether parallel execution is supported (requires Azure)."""
        return self.config.use_azure

    def list_tasks(self, domain: str | None = None) -> list[BenchmarkTask]:
        """List available WAA tasks.

        WAA has 154 tasks across 11 domains:
        - browser: Edge/Chrome navigation and settings
        - office: Word, Excel, Outlook
        - coding: VSCode, terminal
        - settings: Windows Settings app
        - file_explorer: File operations
        - notepad: Text editing
        - paint: Drawing operations
        - media: Video/audio playback
        - clock: Alarms, timers
        - edge: Browser-specific
        - vscode: IDE-specific

        Args:
            domain: Optional domain filter.

        Returns:
            List of BenchmarkTask objects.
        """
        tasks = self._load_all_tasks()

        if domain is not None:
            tasks = [t for t in tasks if t.domain == domain]

        return tasks

    def load_task(self, task_id: str) -> BenchmarkTask:
        """Load a specific task by ID.

        Args:
            task_id: Task identifier (e.g., "notepad_1", "browser_5").

        Returns:
            BenchmarkTask object.

        Raises:
            KeyError: If task_id not found.
        """
        if task_id in self._task_cache:
            return self._task_cache[task_id]

        # Try to load from disk
        tasks = self._load_all_tasks()
        task_map = {t.task_id: t for t in tasks}

        if task_id not in task_map:
            raise KeyError(f"Task '{task_id}' not found. Available: {list(task_map.keys())[:10]}...")

        return task_map[task_id]

    def reset(self, task: BenchmarkTask) -> BenchmarkObservation:
        """Reset environment to task's initial state.

        This initializes the Windows VM/desktop to the state required for
        the task, including opening required applications and setting up
        any pre-conditions.

        Args:
            task: Task to initialize.

        Returns:
            Initial observation (screenshot + accessibility tree).
        """
        self._ensure_waa_imported()
        self._current_task = task

        # Initialize DesktopEnv if needed
        if self._desktop_env is None:
            self._desktop_env = self._create_desktop_env()

        # Load task config and reset environment
        task_config = self._load_waa_task_config(task)
        obs = self._desktop_env.reset(task_config=task_config)

        return self._to_benchmark_observation(obs)

    def step(
        self, action: BenchmarkAction
    ) -> tuple[BenchmarkObservation, bool, dict[str, Any]]:
        """Execute action and return new observation.

        Args:
            action: Action to execute.

        Returns:
            Tuple of (observation, done, info).
        """
        if self._desktop_env is None:
            raise RuntimeError("Call reset() before step()")

        # Convert to WAA action format
        waa_action = self._to_waa_action(action)

        # Execute action
        obs, reward, done, info = self._desktop_env.step(waa_action)

        # Optional delay between actions
        if self.config.action_delay > 0:
            time.sleep(self.config.action_delay)

        return self._to_benchmark_observation(obs), done, info

    def evaluate(self, task: BenchmarkTask) -> BenchmarkResult:
        """Run WAA's native evaluation on current state.

        WAA evaluators check the actual OS state (files, settings, app state)
        to determine if the task was completed successfully.

        Args:
            task: Task to evaluate.

        Returns:
            BenchmarkResult with success/score.
        """
        if self._desktop_env is None:
            raise RuntimeError("Call reset() and step() before evaluate()")

        # Run WAA's evaluator
        try:
            result = self._desktop_env.evaluate()
            success = result.get("success", False)
            score = 1.0 if success else 0.0
            reason = result.get("reason", None)
        except Exception as e:
            logger.error(f"Evaluation failed for task {task.task_id}: {e}")
            success = False
            score = 0.0
            reason = str(e)

        return BenchmarkResult(
            task_id=task.task_id,
            success=success,
            score=score,
            reason=reason,
        )

    def close(self) -> None:
        """Clean up resources."""
        if self._desktop_env is not None:
            try:
                self._desktop_env.close()
            except Exception as e:
                logger.warning(f"Error closing DesktopEnv: {e}")
            self._desktop_env = None
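
    # A typical episode with this adapter, as a sketch: reset() produces the
    # first observation, step() runs until done or the step budget is spent,
    # and evaluate() then inspects OS state. agent.act() is a hypothetical
    # policy call, not part of this module:
    #
    #     obs = adapter.reset(task)
    #     for _ in range(task.time_limit_steps or adapter.config.max_steps):
    #         obs, done, info = adapter.step(agent.act(obs))
    #         if done:
    #             break
    #     result = adapter.evaluate(task)
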
    def _create_desktop_env(self):
        """Create WAA DesktopEnv instance."""
        require_a11y = self.config.observation_type in [
            "a11y_tree",
            "screenshot_a11y_tree",
            "som",
        ]

        return self._DesktopEnv(
            screen_size=(self.config.screen_width, self.config.screen_height),
            require_a11y_tree=require_a11y,
            a11y_backend=self.config.a11y_backend,
        )

    def _load_all_tasks(self) -> list[BenchmarkTask]:
        """Load all WAA tasks from the repository."""
        if self._task_cache:
            return list(self._task_cache.values())

        tasks = []

        # Load test_all.json metadata
        meta_path = self._tasks_path / "test_all.json"
        if meta_path.exists():
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)

            for domain, task_ids in meta.items():
                if domain in WAA_DOMAINS:
                    for task_id in task_ids:
                        task = self._load_task_from_json(domain, task_id)
                        if task:
                            tasks.append(task)
                            self._task_cache[task.task_id] = task
        else:
            # Fallback: scan examples directory
            examples_dir = self._tasks_path / "examples"
            if examples_dir.exists():
                for domain_dir in examples_dir.iterdir():
                    if domain_dir.is_dir() and domain_dir.name in WAA_DOMAINS:
                        for task_file in domain_dir.glob("*.json"):
                            task = self._load_task_from_file(task_file, domain_dir.name)
                            if task:
                                tasks.append(task)
                                self._task_cache[task.task_id] = task

        logger.info(f"Loaded {len(tasks)} WAA tasks")
        return tasks

    def _load_task_from_json(self, domain: str, task_id: str) -> BenchmarkTask | None:
        """Load a task from its JSON file."""
        task_file = self._tasks_path / "examples" / domain / f"{task_id}.json"
        if not task_file.exists():
            logger.warning(f"Task file not found: {task_file}")
            return None

        return self._load_task_from_file(task_file, domain)

    def _load_task_from_file(self, task_file: Path, domain: str) -> BenchmarkTask | None:
        """Load a task from a JSON file."""
        try:
            with open(task_file, encoding="utf-8") as f:
                config = json.load(f)

            task_id = f"{domain}_{task_file.stem}"
            instruction = config.get("instruction", config.get("task", ""))

            return BenchmarkTask(
                task_id=task_id,
                instruction=instruction,
                domain=domain,
                initial_state_ref=config.get("snapshot", None),
                time_limit_steps=config.get("max_steps", self.config.max_steps),
                raw_config=config,
                evaluation_spec=config.get("evaluation", None),
            )
        except Exception as e:
            logger.warning(f"Failed to load task from {task_file}: {e}")
            return None

    def _load_waa_task_config(self, task: BenchmarkTask) -> dict:
        """Convert BenchmarkTask to WAA's task config format."""
        return task.raw_config

    def _to_benchmark_observation(self, waa_obs: dict | Any) -> BenchmarkObservation:
        """Convert WAA observation to canonical format.

        WAA observations may include:
        - screenshot: PIL Image or bytes
        - a11y_tree: UIA accessibility tree dict
        - window_title: Active window title
        """
        # Handle different WAA observation formats
        if isinstance(waa_obs, dict):
            screenshot = waa_obs.get("screenshot")
            a11y_tree = waa_obs.get("a11y_tree", waa_obs.get("accessibility_tree"))
            window_title = waa_obs.get("window_title")
            raw_obs = waa_obs
        else:
            # WAA may return observation as object with attributes
            screenshot = getattr(waa_obs, "screenshot", None)
            a11y_tree = getattr(waa_obs, "a11y_tree", None)
            window_title = getattr(waa_obs, "window_title", None)
            raw_obs = {"waa_obs_type": type(waa_obs).__name__}

        # Convert PIL Image to bytes if needed
        screenshot_bytes = None
        if screenshot is not None:
            if hasattr(screenshot, "tobytes"):
                # PIL Image - convert to PNG bytes
                import io
                buf = io.BytesIO()
                screenshot.save(buf, format="PNG")
                screenshot_bytes = buf.getvalue()
            elif isinstance(screenshot, bytes):
                screenshot_bytes = screenshot

        return BenchmarkObservation(
            screenshot=screenshot_bytes,
            viewport=(self.config.screen_width, self.config.screen_height),
            accessibility_tree=a11y_tree,
            window_title=window_title,
            raw_observation=raw_obs,
        )
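
    # Decoding sketch: after the conversion above, obs.screenshot holds PNG
    # bytes whether WAA produced a PIL Image or raw bytes, so a consumer can
    # reload it (assuming Pillow is installed):
    #
    #     import io
    #     from PIL import Image
    #     img = Image.open(io.BytesIO(obs.screenshot))
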
    def _to_waa_action(self, action: BenchmarkAction) -> dict:
        """Convert canonical action to WAA format.

        WAA action format:
        - click: {"action_type": "click", "coordinate": [x, y]}
        - double_click: {"action_type": "double_click", "coordinate": [x, y]}
        - type: {"action_type": "type", "text": "..."}
        - key: {"action_type": "key", "key": "...", "modifiers": [...]}
        - scroll: {"action_type": "scroll", "direction": "...", "amount": ...}
        - drag: {"action_type": "drag", "start": [x, y], "end": [x, y]}
        """
        action_type = action.type

        # Map canonical action types to WAA format
        if action_type == "click":
            x = action.x or 0
            y = action.y or 0
            # Convert normalized coords to pixels if needed
            if 0 <= x <= 1 and 0 <= y <= 1:
                x = int(x * self.config.screen_width)
                y = int(y * self.config.screen_height)
            return {
                "action_type": "click",
                "coordinate": [int(x), int(y)],
            }

        elif action_type == "double_click":
            x = action.x or 0
            y = action.y or 0
            if 0 <= x <= 1 and 0 <= y <= 1:
                x = int(x * self.config.screen_width)
                y = int(y * self.config.screen_height)
            return {
                "action_type": "double_click",
                "coordinate": [int(x), int(y)],
            }

        elif action_type == "right_click":
            x = action.x or 0
            y = action.y or 0
            if 0 <= x <= 1 and 0 <= y <= 1:
                x = int(x * self.config.screen_width)
                y = int(y * self.config.screen_height)
            return {
                "action_type": "right_click",
                "coordinate": [int(x), int(y)],
            }

        elif action_type == "type":
            return {
                "action_type": "type",
                "text": action.text or "",
            }

        elif action_type == "key":
            waa_action = {
                "action_type": "key",
                "key": action.key or "",
            }
            if action.modifiers:
                waa_action["modifiers"] = action.modifiers
            return waa_action

        elif action_type == "scroll":
            return {
                "action_type": "scroll",
                "direction": action.scroll_direction or "down",
                "amount": action.scroll_amount or 3,  # Default scroll amount
            }

        elif action_type == "drag":
            x1 = action.x or 0
            y1 = action.y or 0
            x2 = action.end_x or 0
            y2 = action.end_y or 0
            # Convert normalized coords
            if 0 <= x1 <= 1:
                x1 = int(x1 * self.config.screen_width)
                y1 = int(y1 * self.config.screen_height)
            if 0 <= x2 <= 1:
                x2 = int(x2 * self.config.screen_width)
                y2 = int(y2 * self.config.screen_height)
            return {
                "action_type": "drag",
                "start": [int(x1), int(y1)],
                "end": [int(x2), int(y2)],
            }

        elif action_type == "done":
            return {"action_type": "done"}

        elif action_type == "wait":
            return {"action_type": "wait"}

        else:
            logger.warning(f"Unknown action type: {action_type}")
            return {"action_type": action_type, "raw": action.raw_action}
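

# A coordinate-conversion sketch: _to_waa_action treats x/y values in [0, 1]
# as normalized and scales them to the configured screen size, so a centered
# click on the default 1920x1200 screen lands at pixel (960, 600). Building
# BenchmarkAction from keyword arguments is an assumption for illustration.
def _sketch_normalized_click(adapter: WAAAdapter) -> dict:
    action = BenchmarkAction(type="click", x=0.5, y=0.5)
    waa_action = adapter._to_waa_action(action)
    assert waa_action == {"action_type": "click", "coordinate": [960, 600]}
    return waa_action

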
class WAAMockAdapter(BenchmarkAdapter):
    """Mock WAA adapter for testing without Windows VM.

    Useful for:
    - Testing the benchmark integration without actual WAA
    - Development on non-Windows platforms
    - Unit tests

    Args:
        num_tasks: Number of mock tasks to generate.
        domains: Domains to include in mock tasks.
    """

    def __init__(
        self,
        num_tasks: int = 20,
        domains: list[str] | None = None,
    ):
        self._num_tasks = num_tasks
        self._domains = domains or WAA_DOMAINS[:3]  # Default to first 3 domains
        self._tasks: list[BenchmarkTask] = []
        self._current_task: BenchmarkTask | None = None
        self._step_count = 0
        self._temp_dir: Path | None = None
        self._generate_mock_tasks()

    @property
    def name(self) -> str:
        return "waa-mock"

    @property
    def benchmark_type(self) -> str:
        return "interactive"

    def _generate_mock_tasks(self) -> None:
        """Generate mock tasks for testing."""
        tasks_per_domain = self._num_tasks // len(self._domains)
        extra = self._num_tasks % len(self._domains)

        for i, domain in enumerate(self._domains):
            count = tasks_per_domain + (1 if i < extra else 0)
            for j in range(count):
                task_id = f"{domain}_{j + 1}"
                self._tasks.append(
                    BenchmarkTask(
                        task_id=task_id,
                        instruction=f"Mock task {j + 1} in {domain} domain",
                        domain=domain,
                        time_limit_steps=15,
                        raw_config={"mock": True},
                    )
                )

    def list_tasks(self, domain: str | None = None) -> list[BenchmarkTask]:
        if domain is not None:
            return [t for t in self._tasks if t.domain == domain]
        return self._tasks

    def load_task(self, task_id: str) -> BenchmarkTask:
        for task in self._tasks:
            if task.task_id == task_id:
                return task
        raise KeyError(f"Task '{task_id}' not found")

    def reset(self, task: BenchmarkTask) -> BenchmarkObservation:
        self._current_task = task
        self._step_count = 0
        return self._mock_observation()

    def step(
        self, action: BenchmarkAction
    ) -> tuple[BenchmarkObservation, bool, dict[str, Any]]:
        self._step_count += 1
        done = action.type == "done" or self._step_count >= 15
        return self._mock_observation(), done, {"step": self._step_count}

    def evaluate(self, task: BenchmarkTask) -> BenchmarkResult:
        # Random success for testing
        import random
        success = random.random() < 0.2  # ~20% success rate like WAA SOTA
        return BenchmarkResult(
            task_id=task.task_id,
            success=success,
            score=1.0 if success else 0.0,
            num_steps=self._step_count,
        )

    def _mock_observation(self) -> BenchmarkObservation:
        """Generate a mock observation with a real screenshot file."""
        import tempfile

        # Create temp directory if needed
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="waa_mock_"))

        # Generate a simple mock screenshot (gray image with text)
        screenshot_path = self._temp_dir / f"mock_step_{self._step_count}.png"
        self._generate_mock_screenshot(screenshot_path)

        return BenchmarkObservation(
            screenshot=screenshot_path.read_bytes(),
            screenshot_path=str(screenshot_path),
            viewport=(1920, 1200),
            accessibility_tree={
                "role": "window",
                "name": "Mock Window",
                "children": [
                    {"role": "button", "name": "OK", "id": "1"},
                    {"role": "textfield", "name": "Input", "id": "2"},
                    {"role": "button", "name": "Cancel", "id": "3"},
                    {"role": "button", "name": "Submit", "id": "4"},
                ],
            },
            window_title="Mock Window - Testing",
        )

    def _generate_mock_screenshot(self, path: Path) -> None:
        """Generate a simple mock screenshot image."""
        try:
            from PIL import Image, ImageDraw, ImageFont

            # Create a simple gray image with some UI elements
            img = Image.new("RGB", (1920, 1200), color=(240, 240, 240))
            draw = ImageDraw.Draw(img)

            # Draw a title bar
            draw.rectangle([0, 0, 1920, 40], fill=(60, 60, 60))
            draw.text((20, 10), "Mock Application Window", fill=(255, 255, 255))

            # Draw some buttons
            draw.rectangle([100, 100, 200, 140], fill=(0, 120, 215))
            draw.text((120, 110), "OK", fill=(255, 255, 255))

            draw.rectangle([220, 100, 320, 140], fill=(200, 200, 200))
            draw.text((240, 110), "Cancel", fill=(0, 0, 0))

            # Draw a text field
            draw.rectangle([100, 160, 500, 200], outline=(100, 100, 100))
            draw.text((110, 170), "Enter text here...", fill=(150, 150, 150))

            # Draw task instruction
            task_name = self._current_task.task_id if self._current_task else "Unknown"
            draw.text((100, 250), f"Task: {task_name}", fill=(0, 0, 0))
            draw.text((100, 280), f"Step: {self._step_count}", fill=(0, 0, 0))

            img.save(path)
        except ImportError:
            # Fallback: create a minimal valid PNG if PIL not available
            # This is a 1x1 gray PNG
            minimal_png = bytes([
                0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,  # PNG signature
                0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,  # IHDR chunk
                0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
                0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53,
                0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41,  # IDAT chunk
                0x54, 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00,
                0x00, 0x00, 0x03, 0x00, 0x01, 0x00, 0x05, 0xFE,
                0xD4, 0xEF, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45,  # IEND chunk
                0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82
            ])
            path.write_bytes(minimal_png)
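

# A self-contained smoke-test sketch for the mock adapter: no Windows VM is
# required, observations carry a generated screenshot, and evaluate() returns
# a randomized result, so only the control flow is exercised. Building
# BenchmarkAction from keyword arguments is an assumption for illustration.
def _sketch_mock_episode() -> BenchmarkResult:
    adapter = WAAMockAdapter(num_tasks=3, domains=["notepad"])
    task = adapter.list_tasks(domain="notepad")[0]
    obs = adapter.reset(task)
    done = False
    while not done:
        obs, done, info = adapter.step(BenchmarkAction(type="wait"))
    return adapter.evaluate(task)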