openadapt_ml-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/benchmarks/waa.py
@@ -0,0 +1,704 @@
"""Windows Agent Arena (WAA) benchmark adapter.

This module provides integration with the Windows Agent Arena benchmark,
enabling evaluation of GUI agents on 154 Windows tasks across 11 domains.

WAA Repository: https://github.com/microsoft/WindowsAgentArena

Example:
    from openadapt_ml.benchmarks import WAAAdapter, PolicyAgent, evaluate_agent_on_benchmark

    adapter = WAAAdapter(waa_repo_path="/path/to/WindowsAgentArena")
    agent = PolicyAgent(policy)
    results = evaluate_agent_on_benchmark(agent, adapter, max_steps=15)
    print(f"Success rate: {sum(r.success for r in results) / len(results):.1%}")
"""

from __future__ import annotations

import json
import logging
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from openadapt_ml.benchmarks.base import (
    BenchmarkAction,
    BenchmarkAdapter,
    BenchmarkObservation,
    BenchmarkResult,
    BenchmarkTask,
)

logger = logging.getLogger(__name__)


# WAA domain mapping (11 domains, 154 tasks)
WAA_DOMAINS = [
    "browser",
    "office",
    "coding",
    "media",
    "notepad",
    "paint",
    "file_explorer",
    "clock",
    "settings",
    "edge",
    "vscode",
]


@dataclass
class WAAConfig:
    """Configuration for WAA adapter.

    Attributes:
        waa_repo_path: Path to cloned WindowsAgentArena repository.
        use_azure: Whether to use Azure VMs (enables parallelism).
        observation_type: Type of observation to capture.
        a11y_backend: Accessibility backend ("uia" or "win32").
        screen_width: Screen width in pixels.
        screen_height: Screen height in pixels.
        max_steps: Default maximum steps per task.
        action_delay: Delay between actions in seconds.
    """

    waa_repo_path: str
    use_azure: bool = False
    observation_type: str = "screenshot_a11y_tree"  # "screenshot", "a11y_tree", "som"
    a11y_backend: str = "uia"  # "uia" or "win32"
    screen_width: int = 1920
    screen_height: int = 1200
    max_steps: int = 15
    action_delay: float = 0.5
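
# Example (illustrative): building a config explicitly and handing it to the
# adapter defined below; the repository path is a placeholder.
#
#     config = WAAConfig(
#         waa_repo_path="/path/to/WindowsAgentArena",
#         observation_type="screenshot_a11y_tree",
#         max_steps=20,
#     )
#     adapter = WAAAdapter(config=config)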


class WAAAdapter(BenchmarkAdapter):
    """Windows Agent Arena benchmark adapter.

    Integrates with the WAA benchmark to evaluate GUI agents on 154 Windows
    desktop automation tasks spanning 11 application domains.

    The adapter wraps WAA's DesktopEnv and provides:
    - Task loading from WAA's JSON task definitions
    - VM/environment reset to task initial state
    - Action execution via WAA's controller
    - Evaluation using WAA's native evaluators

    Args:
        waa_repo_path: Path to cloned WindowsAgentArena repository.
        use_azure: Use Azure VMs for execution (enables parallelism).
        config: Full WAAConfig (overrides other args if provided).
        **kwargs: Additional config options passed to WAAConfig.

    Raises:
        ValueError: If waa_repo_path doesn't exist.
        ImportError: If WAA dependencies not available.
    """

    def __init__(
        self,
        waa_repo_path: str | Path | None = None,
        use_azure: bool = False,
        config: WAAConfig | None = None,
        **kwargs,
    ):
        if config is not None:
            self.config = config
        else:
            if waa_repo_path is None:
                raise ValueError("waa_repo_path is required")
            self.config = WAAConfig(
                waa_repo_path=str(waa_repo_path),
                use_azure=use_azure,
                **kwargs,
            )

        self.waa_repo = Path(self.config.waa_repo_path)
        if not self.waa_repo.exists():
            raise ValueError(f"WAA repository not found at: {self.waa_repo}")

        # Paths to WAA components
        self._client_path = self.waa_repo / "src" / "win-arena-container" / "client"
        self._tasks_path = self._client_path / "evaluation_examples_windows"

        # Lazy-loaded WAA components
        self._desktop_env = None
        self._task_cache: dict[str, BenchmarkTask] = {}
        self._current_task: BenchmarkTask | None = None
        self._waa_imported = False

    def _ensure_waa_imported(self) -> None:
        """Import WAA modules (lazy loading)."""
        if self._waa_imported:
            return

        # Add WAA client to path
        client_path = str(self._client_path)
        if client_path not in sys.path:
            sys.path.insert(0, client_path)

        try:
            # Import WAA's DesktopEnv
            from desktop_env import DesktopEnv

            self._DesktopEnv = DesktopEnv
            self._waa_imported = True
            logger.info("WAA modules imported successfully")
        except ImportError as e:
            raise ImportError(
                f"Failed to import WAA modules. Ensure WAA is properly installed "
                f"and dependencies are available: {e}"
            ) from e

    @property
    def name(self) -> str:
        """Benchmark name."""
        return "waa"

    @property
    def benchmark_type(self) -> str:
        """Benchmark type (interactive)."""
        return "interactive"

    @property
    def supports_parallel(self) -> bool:
        """Whether parallel execution is supported (requires Azure)."""
        return self.config.use_azure

    def list_tasks(self, domain: str | None = None) -> list[BenchmarkTask]:
        """List available WAA tasks.

        WAA has 154 tasks across 11 domains:
        - browser: Edge/Chrome navigation and settings
        - office: Word, Excel, Outlook
        - coding: VSCode, terminal
        - settings: Windows Settings app
        - file_explorer: File operations
        - notepad: Text editing
        - paint: Drawing operations
        - media: Video/audio playback
        - clock: Alarms, timers
        - edge: Browser-specific
        - vscode: IDE-specific

        Args:
            domain: Optional domain filter.

        Returns:
            List of BenchmarkTask objects.
        """
        tasks = self._load_all_tasks()

        if domain is not None:
            tasks = [t for t in tasks if t.domain == domain]

        return tasks
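
    # Example (illustrative; domain names come from WAA_DOMAINS):
    #
    #     notepad_tasks = adapter.list_tasks(domain="notepad")
    #     all_tasks = adapter.list_tasks()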

    def load_task(self, task_id: str) -> BenchmarkTask:
        """Load a specific task by ID.

        Args:
            task_id: Task identifier (e.g., "notepad_1", "browser_5").

        Returns:
            BenchmarkTask object.

        Raises:
            KeyError: If task_id not found.
        """
        if task_id in self._task_cache:
            return self._task_cache[task_id]

        # Try to load from disk
        tasks = self._load_all_tasks()
        task_map = {t.task_id: t for t in tasks}

        if task_id not in task_map:
            raise KeyError(f"Task '{task_id}' not found. Available: {list(task_map.keys())[:10]}...")

        return task_map[task_id]

    def reset(self, task: BenchmarkTask) -> BenchmarkObservation:
        """Reset environment to task's initial state.

        This initializes the Windows VM/desktop to the state required for
        the task, including opening required applications and setting up
        any pre-conditions.

        Args:
            task: Task to initialize.

        Returns:
            Initial observation (screenshot + accessibility tree).
        """
        self._ensure_waa_imported()
        self._current_task = task

        # Initialize DesktopEnv if needed
        if self._desktop_env is None:
            self._desktop_env = self._create_desktop_env()

        # Load task config and reset environment
        task_config = self._load_waa_task_config(task)
        obs = self._desktop_env.reset(task_config=task_config)

        return self._to_benchmark_observation(obs)

    def step(
        self, action: BenchmarkAction
    ) -> tuple[BenchmarkObservation, bool, dict[str, Any]]:
        """Execute action and return new observation.

        Args:
            action: Action to execute.

        Returns:
            Tuple of (observation, done, info).
        """
        if self._desktop_env is None:
            raise RuntimeError("Call reset() before step()")

        # Convert to WAA action format
        waa_action = self._to_waa_action(action)

        # Execute action
        obs, reward, done, info = self._desktop_env.step(waa_action)

        # Optional delay between actions
        if self.config.action_delay > 0:
            time.sleep(self.config.action_delay)

        return self._to_benchmark_observation(obs), done, info

    def evaluate(self, task: BenchmarkTask) -> BenchmarkResult:
        """Run WAA's native evaluation on current state.

        WAA evaluators check the actual OS state (files, settings, app state)
        to determine if the task was completed successfully.

        Args:
            task: Task to evaluate.

        Returns:
            BenchmarkResult with success/score.
        """
        if self._desktop_env is None:
            raise RuntimeError("Call reset() and step() before evaluate()")

        # Run WAA's evaluator
        try:
            result = self._desktop_env.evaluate()
            success = result.get("success", False)
            score = 1.0 if success else 0.0
            reason = result.get("reason", None)
        except Exception as e:
            logger.error(f"Evaluation failed for task {task.task_id}: {e}")
            success = False
            score = 0.0
            reason = str(e)

        return BenchmarkResult(
            task_id=task.task_id,
            success=success,
            score=score,
            reason=reason,
        )
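
    # Example (illustrative) of the reset/step/evaluate contract; `agent` is a
    # placeholder for any policy that maps observations to BenchmarkActions:
    #
    #     task = adapter.load_task("notepad_1")
    #     obs = adapter.reset(task)
    #     for _ in range(task.time_limit_steps or adapter.config.max_steps):
    #         action = agent.act(obs)
    #         obs, done, info = adapter.step(action)
    #         if done:
    #             break
    #     result = adapter.evaluate(task)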

    def close(self) -> None:
        """Clean up resources."""
        if self._desktop_env is not None:
            try:
                self._desktop_env.close()
            except Exception as e:
                logger.warning(f"Error closing DesktopEnv: {e}")
            self._desktop_env = None

    def _create_desktop_env(self):
        """Create WAA DesktopEnv instance."""
        require_a11y = self.config.observation_type in [
            "a11y_tree",
            "screenshot_a11y_tree",
            "som",
        ]

        return self._DesktopEnv(
            screen_size=(self.config.screen_width, self.config.screen_height),
            require_a11y_tree=require_a11y,
            a11y_backend=self.config.a11y_backend,
        )

    def _load_all_tasks(self) -> list[BenchmarkTask]:
        """Load all WAA tasks from the repository."""
        if self._task_cache:
            return list(self._task_cache.values())

        tasks = []

        # Load test_all.json metadata
        meta_path = self._tasks_path / "test_all.json"
        if meta_path.exists():
            with open(meta_path, encoding="utf-8") as f:
                meta = json.load(f)

            for domain, task_ids in meta.items():
                if domain in WAA_DOMAINS:
                    for task_id in task_ids:
                        task = self._load_task_from_json(domain, task_id)
                        if task:
                            tasks.append(task)
                            self._task_cache[task.task_id] = task
        else:
            # Fallback: scan examples directory
            examples_dir = self._tasks_path / "examples"
            if examples_dir.exists():
                for domain_dir in examples_dir.iterdir():
                    if domain_dir.is_dir() and domain_dir.name in WAA_DOMAINS:
                        for task_file in domain_dir.glob("*.json"):
                            task = self._load_task_from_file(task_file, domain_dir.name)
                            if task:
                                tasks.append(task)
                                self._task_cache[task.task_id] = task

        logger.info(f"Loaded {len(tasks)} WAA tasks")
        return tasks

    def _load_task_from_json(self, domain: str, task_id: str) -> BenchmarkTask | None:
        """Load a task from its JSON file."""
        task_file = self._tasks_path / "examples" / domain / f"{task_id}.json"
        if not task_file.exists():
            logger.warning(f"Task file not found: {task_file}")
            return None

        return self._load_task_from_file(task_file, domain)

    def _load_task_from_file(self, task_file: Path, domain: str) -> BenchmarkTask | None:
        """Load a task from a JSON file."""
        try:
            with open(task_file, encoding="utf-8") as f:
                config = json.load(f)

            task_id = f"{domain}_{task_file.stem}"
            instruction = config.get("instruction", config.get("task", ""))

            return BenchmarkTask(
                task_id=task_id,
                instruction=instruction,
                domain=domain,
                initial_state_ref=config.get("snapshot", None),
                time_limit_steps=config.get("max_steps", self.config.max_steps),
                raw_config=config,
                evaluation_spec=config.get("evaluation", None),
            )
        except Exception as e:
            logger.warning(f"Failed to load task from {task_file}: {e}")
            return None

    def _load_waa_task_config(self, task: BenchmarkTask) -> dict:
        """Convert BenchmarkTask to WAA's task config format."""
        return task.raw_config

    def _to_benchmark_observation(self, waa_obs: dict | Any) -> BenchmarkObservation:
        """Convert WAA observation to canonical format.

        WAA observations may include:
        - screenshot: PIL Image or bytes
        - a11y_tree: UIA accessibility tree dict
        - window_title: Active window title
        """
        # Handle different WAA observation formats
        if isinstance(waa_obs, dict):
            screenshot = waa_obs.get("screenshot")
            a11y_tree = waa_obs.get("a11y_tree", waa_obs.get("accessibility_tree"))
            window_title = waa_obs.get("window_title")
            raw_obs = waa_obs
        else:
            # WAA may return observation as object with attributes
            screenshot = getattr(waa_obs, "screenshot", None)
            a11y_tree = getattr(waa_obs, "a11y_tree", None)
            window_title = getattr(waa_obs, "window_title", None)
            raw_obs = {"waa_obs_type": type(waa_obs).__name__}

        # Convert PIL Image to bytes if needed
        screenshot_bytes = None
        if screenshot is not None:
            if hasattr(screenshot, "save"):
                # PIL Image (duck-typed via its save() method) - convert to PNG bytes
                import io
                buf = io.BytesIO()
                screenshot.save(buf, format="PNG")
                screenshot_bytes = buf.getvalue()
            elif isinstance(screenshot, bytes):
                screenshot_bytes = screenshot

        return BenchmarkObservation(
            screenshot=screenshot_bytes,
            viewport=(self.config.screen_width, self.config.screen_height),
            accessibility_tree=a11y_tree,
            window_title=window_title,
            raw_observation=raw_obs,
        )
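
    # Example (illustrative) for a dict-shaped observation; `png_bytes` stands
    # in for real screenshot bytes:
    #
    #     obs = adapter._to_benchmark_observation(
    #         {"screenshot": png_bytes, "window_title": "Notepad"}
    #     )
    #     assert obs.window_title == "Notepad"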

    def _to_waa_action(self, action: BenchmarkAction) -> dict:
        """Convert canonical action to WAA format.

        WAA action format:
        - click: {"action_type": "click", "coordinate": [x, y]}
        - double_click: {"action_type": "double_click", "coordinate": [x, y]}
        - type: {"action_type": "type", "text": "..."}
        - key: {"action_type": "key", "key": "...", "modifiers": [...]}
        - scroll: {"action_type": "scroll", "direction": "...", "amount": ...}
        - drag: {"action_type": "drag", "start": [x, y], "end": [x, y]}
        """
        action_type = action.type

        # Map canonical action types to WAA format
        if action_type == "click":
            x = action.x or 0
            y = action.y or 0
            # Convert normalized coords to pixels if needed
            if 0 <= x <= 1 and 0 <= y <= 1:
                x = int(x * self.config.screen_width)
                y = int(y * self.config.screen_height)
            return {
                "action_type": "click",
                "coordinate": [int(x), int(y)],
            }

        elif action_type == "double_click":
            x = action.x or 0
            y = action.y or 0
            if 0 <= x <= 1 and 0 <= y <= 1:
                x = int(x * self.config.screen_width)
                y = int(y * self.config.screen_height)
            return {
                "action_type": "double_click",
                "coordinate": [int(x), int(y)],
            }

        elif action_type == "right_click":
            x = action.x or 0
            y = action.y or 0
            if 0 <= x <= 1 and 0 <= y <= 1:
                x = int(x * self.config.screen_width)
                y = int(y * self.config.screen_height)
            return {
                "action_type": "right_click",
                "coordinate": [int(x), int(y)],
            }

        elif action_type == "type":
            return {
                "action_type": "type",
                "text": action.text or "",
            }

        elif action_type == "key":
            waa_action = {
                "action_type": "key",
                "key": action.key or "",
            }
            if action.modifiers:
                waa_action["modifiers"] = action.modifiers
            return waa_action

        elif action_type == "scroll":
            return {
                "action_type": "scroll",
                "direction": action.scroll_direction or "down",
                "amount": action.scroll_amount or 3,  # Default scroll amount
            }

        elif action_type == "drag":
            x1 = action.x or 0
            y1 = action.y or 0
            x2 = action.end_x or 0
            y2 = action.end_y or 0
            # Convert normalized coords to pixels (checking both axes, as for clicks)
            if 0 <= x1 <= 1 and 0 <= y1 <= 1:
                x1 = int(x1 * self.config.screen_width)
                y1 = int(y1 * self.config.screen_height)
            if 0 <= x2 <= 1 and 0 <= y2 <= 1:
                x2 = int(x2 * self.config.screen_width)
                y2 = int(y2 * self.config.screen_height)
            return {
                "action_type": "drag",
                "start": [int(x1), int(y1)],
                "end": [int(x2), int(y2)],
            }

        elif action_type == "done":
            return {"action_type": "done"}

        elif action_type == "wait":
            return {"action_type": "wait"}

        else:
            logger.warning(f"Unknown action type: {action_type}")
            return {"action_type": action_type, "raw": action.raw_action}
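
    # Worked example of the coordinate handling above: with the default
    # 1920x1200 screen, a normalized click at (0.5, 0.5) maps to pixel
    # coordinates (960, 600); values outside [0, 1] pass through as pixels.
    # (BenchmarkAction field names follow the attribute accesses above.)
    #
    #     adapter._to_waa_action(BenchmarkAction(type="click", x=0.5, y=0.5))
    #     # -> {"action_type": "click", "coordinate": [960, 600]}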


class WAAMockAdapter(BenchmarkAdapter):
    """Mock WAA adapter for testing without a Windows VM.

    Useful for:
    - Testing the benchmark integration without actual WAA
    - Development on non-Windows platforms
    - Unit tests

    Args:
        num_tasks: Number of mock tasks to generate.
        domains: Domains to include in mock tasks.
    """

    def __init__(
        self,
        num_tasks: int = 20,
        domains: list[str] | None = None,
    ):
        self._num_tasks = num_tasks
        self._domains = domains or WAA_DOMAINS[:3]  # Default to first 3 domains
        self._tasks: list[BenchmarkTask] = []
        self._current_task: BenchmarkTask | None = None
        self._step_count = 0
        self._temp_dir: Path | None = None
        self._generate_mock_tasks()

    @property
    def name(self) -> str:
        return "waa-mock"

    @property
    def benchmark_type(self) -> str:
        return "interactive"

    def _generate_mock_tasks(self) -> None:
        """Generate mock tasks for testing."""
        tasks_per_domain = self._num_tasks // len(self._domains)
        extra = self._num_tasks % len(self._domains)

        for i, domain in enumerate(self._domains):
            count = tasks_per_domain + (1 if i < extra else 0)
            for j in range(count):
                task_id = f"{domain}_{j + 1}"
                self._tasks.append(
                    BenchmarkTask(
                        task_id=task_id,
                        instruction=f"Mock task {j + 1} in {domain} domain",
                        domain=domain,
                        time_limit_steps=15,
                        raw_config={"mock": True},
                    )
                )

    def list_tasks(self, domain: str | None = None) -> list[BenchmarkTask]:
        if domain is not None:
            return [t for t in self._tasks if t.domain == domain]
        return self._tasks

    def load_task(self, task_id: str) -> BenchmarkTask:
        for task in self._tasks:
            if task.task_id == task_id:
                return task
        raise KeyError(f"Task '{task_id}' not found")

    def reset(self, task: BenchmarkTask) -> BenchmarkObservation:
        self._current_task = task
        self._step_count = 0
        return self._mock_observation()

    def step(
        self, action: BenchmarkAction
    ) -> tuple[BenchmarkObservation, bool, dict[str, Any]]:
        self._step_count += 1
        done = action.type == "done" or self._step_count >= 15
        return self._mock_observation(), done, {"step": self._step_count}

    def evaluate(self, task: BenchmarkTask) -> BenchmarkResult:
        # Random success for testing
        import random
        success = random.random() < 0.2  # ~20% success rate like WAA SOTA
        return BenchmarkResult(
            task_id=task.task_id,
            success=success,
            score=1.0 if success else 0.0,
            num_steps=self._step_count,
        )

    def _mock_observation(self) -> BenchmarkObservation:
        """Generate a mock observation with a real screenshot file."""
        import tempfile

        # Create temp directory if needed
        if self._temp_dir is None:
            self._temp_dir = Path(tempfile.mkdtemp(prefix="waa_mock_"))

        # Generate a simple mock screenshot (gray image with text)
        screenshot_path = self._temp_dir / f"mock_step_{self._step_count}.png"
        self._generate_mock_screenshot(screenshot_path)

        return BenchmarkObservation(
            screenshot=screenshot_path.read_bytes(),
            screenshot_path=str(screenshot_path),
            viewport=(1920, 1200),
            accessibility_tree={
                "role": "window",
                "name": "Mock Window",
                "children": [
                    {"role": "button", "name": "OK", "id": "1"},
                    {"role": "textfield", "name": "Input", "id": "2"},
                    {"role": "button", "name": "Cancel", "id": "3"},
                    {"role": "button", "name": "Submit", "id": "4"},
                ],
            },
            window_title="Mock Window - Testing",
        )

    def _generate_mock_screenshot(self, path: Path) -> None:
        """Generate a simple mock screenshot image."""
        try:
            from PIL import Image, ImageDraw

            # Create a simple gray image with some UI elements
            img = Image.new("RGB", (1920, 1200), color=(240, 240, 240))
            draw = ImageDraw.Draw(img)

            # Draw a title bar
            draw.rectangle([0, 0, 1920, 40], fill=(60, 60, 60))
            draw.text((20, 10), "Mock Application Window", fill=(255, 255, 255))

            # Draw some buttons
            draw.rectangle([100, 100, 200, 140], fill=(0, 120, 215))
            draw.text((120, 110), "OK", fill=(255, 255, 255))

            draw.rectangle([220, 100, 320, 140], fill=(200, 200, 200))
            draw.text((240, 110), "Cancel", fill=(0, 0, 0))

            # Draw a text field
            draw.rectangle([100, 160, 500, 200], outline=(100, 100, 100))
            draw.text((110, 170), "Enter text here...", fill=(150, 150, 150))

            # Draw task instruction
            task_name = self._current_task.task_id if self._current_task else "Unknown"
            draw.text((100, 250), f"Task: {task_name}", fill=(0, 0, 0))
            draw.text((100, 280), f"Step: {self._step_count}", fill=(0, 0, 0))

            img.save(path)
        except ImportError:
            # Fallback: create a minimal valid PNG if PIL not available
            # This is a 1x1 gray PNG
            minimal_png = bytes([
                0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A,  # PNG signature
                0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, 0x44, 0x52,  # IHDR chunk
                0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
                0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53,
                0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41,  # IDAT chunk
                0x54, 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00,
                0x00, 0x00, 0x03, 0x00, 0x01, 0x00, 0x05, 0xFE,
                0xD4, 0xEF, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45,  # IEND chunk
                0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
            ])
            path.write_bytes(minimal_png)
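
# Example (illustrative) smoke test with the mock adapter; assumes
# BenchmarkAction accepts a `type` field, as read in step() above:
#
#     adapter = WAAMockAdapter(num_tasks=5)
#     for task in adapter.list_tasks():
#         obs = adapter.reset(task)
#         obs, done, info = adapter.step(BenchmarkAction(type="done"))
#         result = adapter.evaluate(task)
#         print(task.task_id, result.success)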