openadapt-ml 0.1.0 (openadapt_ml-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/benchmarks/azure.py
@@ -0,0 +1,761 @@
"""Azure deployment automation for WAA benchmark.

This module provides Azure VM orchestration for running Windows Agent Arena
at scale across multiple parallel VMs.

Requirements:
    - azure-ai-ml
    - azure-identity
    - Azure subscription with ML workspace

Example:
    from openadapt_ml.benchmarks.azure import AzureWAAOrchestrator, AzureConfig

    config = AzureConfig(
        subscription_id="your-subscription-id",
        resource_group="agents",
        workspace_name="agents_ml",
    )
    orchestrator = AzureWAAOrchestrator(config, waa_repo_path="/path/to/WAA")

    # Run evaluation on 40 parallel VMs
    results = orchestrator.run_evaluation(
        agent=my_agent,
        num_workers=40,
        task_ids=None,  # All tasks
    )
"""

from __future__ import annotations

import json
import logging
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable

from openadapt_ml.benchmarks.agent import BenchmarkAgent
from openadapt_ml.benchmarks.base import BenchmarkResult, BenchmarkTask

logger = logging.getLogger(__name__)


@dataclass
class AzureConfig:
    """Azure configuration for WAA deployment.

    Attributes:
        subscription_id: Azure subscription ID.
        resource_group: Resource group containing ML workspace.
        workspace_name: Azure ML workspace name.
        vm_size: VM size for compute instances (must support nested virtualization).
        idle_timeout_minutes: Auto-shutdown after idle (minutes).
        docker_image: Docker image for agent container.
        storage_account: Storage account for results (auto-detected if None).
        use_managed_identity: Whether to use managed identity for auth.
        managed_identity_name: Name of managed identity (if using).
    """

    subscription_id: str
    resource_group: str
    workspace_name: str
    vm_size: str = "Standard_D2_v3"  # 2 vCPUs (fits free trial with existing usage)
    idle_timeout_minutes: int = 60
    docker_image: str = "ghcr.io/microsoft/windowsagentarena:latest"
    storage_account: str | None = None
    use_managed_identity: bool = False
    managed_identity_name: str | None = None

    @classmethod
    def from_env(cls) -> AzureConfig:
        """Create config from environment variables / .env file.

        Uses settings from openadapt_ml.config which loads from:
            1. Environment variables
            2. .env file
            3. Default values

        Required settings:
            AZURE_SUBSCRIPTION_ID
            AZURE_ML_RESOURCE_GROUP
            AZURE_ML_WORKSPACE_NAME

        Optional settings:
            AZURE_VM_SIZE (default: Standard_D4_v3 for free trial compatibility)
            AZURE_DOCKER_IMAGE (default: ghcr.io/microsoft/windowsagentarena:latest)

        Authentication (one of):
            - AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_TENANT_ID (service principal)
            - Azure CLI login (`az login`)
            - Managed Identity (when running on Azure)

        Raises:
            ValueError: If required settings are not configured.
        """
        from openadapt_ml.config import settings

        # Validate required settings
        if not settings.azure_subscription_id:
            raise ValueError(
                "AZURE_SUBSCRIPTION_ID not set. "
                "Run 'python scripts/setup_azure.py' to configure Azure credentials."
            )
        if not settings.azure_ml_resource_group:
            raise ValueError(
                "AZURE_ML_RESOURCE_GROUP not set. "
                "Run 'python scripts/setup_azure.py' to configure Azure credentials."
            )
        if not settings.azure_ml_workspace_name:
            raise ValueError(
                "AZURE_ML_WORKSPACE_NAME not set. "
                "Run 'python scripts/setup_azure.py' to configure Azure credentials."
            )

        return cls(
            subscription_id=settings.azure_subscription_id,
            resource_group=settings.azure_ml_resource_group,
            workspace_name=settings.azure_ml_workspace_name,
            vm_size=settings.azure_vm_size,
            docker_image=settings.azure_docker_image,
        )

    @classmethod
    def from_json(cls, path: str | Path) -> AzureConfig:
        """Load config from JSON file."""
        with open(path) as f:
            data = json.load(f)
        return cls(**data)

    def to_json(self, path: str | Path) -> None:
        """Save config to JSON file."""
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=2)
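A minimal usage sketch for AzureConfig (illustrative, not part of the file): load settings from the environment with from_env(), persist them with to_json(), and restore them with from_json(). This assumes the three required AZURE_* settings are present in the environment or .env.

    from openadapt_ml.benchmarks.azure import AzureConfig

    config = AzureConfig.from_env()          # raises ValueError if required settings are missing
    config.to_json("azure_config.json")      # writes all dataclass fields as JSON
    restored = AzureConfig.from_json("azure_config.json")
    assert restored == config                # dataclass equality compares every field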
@dataclass
class WorkerState:
    """State of a single worker VM."""

    worker_id: int
    compute_name: str
    status: str = "pending"  # pending, running, completed, failed
    assigned_tasks: list[str] = field(default_factory=list)
    completed_tasks: list[str] = field(default_factory=list)
    results: list[BenchmarkResult] = field(default_factory=list)
    error: str | None = None
    start_time: float | None = None
    end_time: float | None = None


@dataclass
class EvaluationRun:
    """State of an evaluation run across multiple workers."""

    run_id: str
    experiment_name: str
    num_workers: int
    total_tasks: int
    workers: list[WorkerState] = field(default_factory=list)
    status: str = "pending"  # pending, running, completed, failed
    start_time: float | None = None
    end_time: float | None = None

    def to_dict(self) -> dict:
        """Serialize to dict for JSON storage."""
        return {
            "run_id": self.run_id,
            "experiment_name": self.experiment_name,
            "num_workers": self.num_workers,
            "total_tasks": self.total_tasks,
            "status": self.status,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "workers": [
                {
                    "worker_id": w.worker_id,
                    "compute_name": w.compute_name,
                    "status": w.status,
                    "assigned_tasks": w.assigned_tasks,
                    "completed_tasks": w.completed_tasks,
                    "error": w.error,
                }
                for w in self.workers
            ],
        }
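to_dict() keeps run state JSON-serializable; note that per-task results are intentionally left out of the worker entries. A small sketch (illustrative; the IDs and names are made up) of snapshotting a run to disk:

    import json
    from openadapt_ml.benchmarks.azure import EvaluationRun, WorkerState

    run = EvaluationRun(run_id="waa-eval-1700000000", experiment_name="waa-eval",
                        num_workers=2, total_tasks=4)
    run.workers = [WorkerState(worker_id=0, compute_name="waa0000w0"),
                   WorkerState(worker_id=1, compute_name="waa0000w1")]
    with open("run_status.json", "w") as f:
        json.dump(run.to_dict(), f, indent=2)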
class AzureMLClient:
    """Wrapper around Azure ML SDK for compute management.

    This provides a simplified interface for creating and managing
    Azure ML compute instances for WAA evaluation.
    """

    def __init__(self, config: AzureConfig):
        self.config = config
        self._client = None
        self._ensure_sdk_available()

    def _ensure_sdk_available(self) -> None:
        """Check that Azure SDK is available."""
        try:
            from azure.ai.ml import MLClient
            from azure.identity import (
                ClientSecretCredential,
                DefaultAzureCredential,
            )

            self._MLClient = MLClient
            self._DefaultAzureCredential = DefaultAzureCredential
            self._ClientSecretCredential = ClientSecretCredential
        except ImportError as e:
            raise ImportError(
                "Azure ML SDK not installed. Install with: "
                "pip install azure-ai-ml azure-identity"
            ) from e

    @property
    def client(self):
        """Lazy-load ML client.

        Uses service principal credentials if configured in .env,
        otherwise falls back to DefaultAzureCredential (CLI login, managed identity, etc.)
        """
        if self._client is None:
            credential = self._get_credential()
            self._client = self._MLClient(
                credential=credential,
                subscription_id=self.config.subscription_id,
                resource_group_name=self.config.resource_group,
                workspace_name=self.config.workspace_name,
            )
            logger.info(f"Connected to Azure ML workspace: {self.config.workspace_name}")
        return self._client

    def _get_credential(self):
        """Get Azure credential, preferring service principal if configured."""
        from openadapt_ml.config import settings

        # Use service principal if credentials are configured
        if all([
            settings.azure_client_id,
            settings.azure_client_secret,
            settings.azure_tenant_id,
        ]):
            logger.info("Using service principal authentication")
            return self._ClientSecretCredential(
                tenant_id=settings.azure_tenant_id,
                client_id=settings.azure_client_id,
                client_secret=settings.azure_client_secret,
            )

        # Fall back to DefaultAzureCredential (CLI login, managed identity, etc.)
        logger.info(
            "Using DefaultAzureCredential (ensure you're logged in with 'az login' "
            "or have service principal credentials in .env)"
        )
        return self._DefaultAzureCredential()

    def create_compute_instance(
        self,
        name: str,
        startup_script: str | None = None,  # noqa: ARG002 - reserved for future use
    ) -> str:
        """Create a compute instance.

        Args:
            name: Compute instance name.
            startup_script: Optional startup script content (not yet implemented).

        Returns:
            Compute instance name.
        """
        # TODO: Add startup_script support when implementing full WAA integration
        _ = startup_script  # Reserved for future use
        from azure.ai.ml.entities import ComputeInstance

        # Check if already exists
        try:
            existing = self.client.compute.get(name)
            if existing:
                logger.info(f"Compute instance {name} already exists")
                return name
        except Exception:
            pass  # Doesn't exist, create it

        compute = ComputeInstance(
            name=name,
            size=self.config.vm_size,
            idle_time_before_shutdown_minutes=self.config.idle_timeout_minutes,
        )

        # Add managed identity if configured
        if self.config.use_managed_identity and self.config.managed_identity_name:
            identity_id = (
                f"/subscriptions/{self.config.subscription_id}"
                f"/resourceGroups/{self.config.resource_group}"
                f"/providers/Microsoft.ManagedIdentity"
                f"/userAssignedIdentities/{self.config.managed_identity_name}"
            )
            compute.identity = {"type": "UserAssigned", "user_assigned_identities": [identity_id]}

        print(f" Creating VM: {name}...", end="", flush=True)
        self.client.compute.begin_create_or_update(compute).result()
        print(" done")

        return name

    def delete_compute_instance(self, name: str) -> None:
        """Delete a compute instance.

        Args:
            name: Compute instance name.
        """
        try:
            logger.info(f"Deleting compute instance: {name}")
            self.client.compute.begin_delete(name).result()
            logger.info(f"Compute instance {name} deleted")
        except Exception as e:
            logger.warning(f"Failed to delete compute instance {name}: {e}")

    def list_compute_instances(self, prefix: str | None = None) -> list[str]:
        """List compute instances.

        Args:
            prefix: Optional name prefix filter.

        Returns:
            List of compute instance names.
        """
        computes = self.client.compute.list()
        names = [c.name for c in computes if c.type == "ComputeInstance"]
        if prefix:
            names = [n for n in names if n.startswith(prefix)]
        return names

    def get_compute_status(self, name: str) -> str:
        """Get compute instance status.

        Args:
            name: Compute instance name.

        Returns:
            Status string (Running, Stopped, etc.)
        """
        compute = self.client.compute.get(name)
        return compute.state

    def submit_job(
        self,
        compute_name: str,
        command: str,
        environment_variables: dict[str, str] | None = None,
        display_name: str | None = None,
    ) -> str:
        """Submit a job to a compute instance.

        Args:
            compute_name: Target compute instance.
            command: Command to run.
            environment_variables: Environment variables.
            display_name: Job display name.

        Returns:
            Job name/ID.
        """
        from azure.ai.ml import command as ml_command
        from azure.ai.ml.entities import Environment

        # Create environment with Docker image
        env = Environment(
            image=self.config.docker_image,
            name="waa-agent-env",
        )

        job = ml_command(
            command=command,
            environment=env,
            compute=compute_name,
            display_name=display_name or f"waa-job-{compute_name}",
            environment_variables=environment_variables or {},
        )

        submitted = self.client.jobs.create_or_update(job)
        logger.info(f"Job submitted: {submitted.name}")
        return submitted.name

    def wait_for_job(self, job_name: str, timeout_seconds: int = 3600) -> dict:
        """Wait for a job to complete.

        Args:
            job_name: Job name/ID.
            timeout_seconds: Maximum wait time.

        Returns:
            Job result dict.
        """
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            job = self.client.jobs.get(job_name)
            if job.status in ["Completed", "Failed", "Canceled"]:
                return {
                    "status": job.status,
                    "outputs": job.outputs if hasattr(job, "outputs") else {},
                }
            time.sleep(10)

        raise TimeoutError(f"Job {job_name} did not complete within {timeout_seconds}s")
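A sketch of driving AzureMLClient directly (illustrative; the orchestrator below wraps exactly these calls): provision one instance, run a throwaway command on it, wait for the job, then tear the instance down. It assumes working Azure credentials and an existing ML workspace; the instance and job names are made up.

    from openadapt_ml.benchmarks.azure import AzureConfig, AzureMLClient

    client = AzureMLClient(AzureConfig.from_env())

    name = client.create_compute_instance("waa-dev-smoke")
    job_name = client.submit_job(
        compute_name=name,
        command="echo waa-smoke-test",
        display_name="waa-smoke-test",
    )
    outcome = client.wait_for_job(job_name, timeout_seconds=1800)
    print(outcome["status"])                 # Completed / Failed / Canceled
    client.delete_compute_instance(name)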
class AzureWAAOrchestrator:
    """Orchestrates WAA evaluation across multiple Azure VMs.

    This class manages the full lifecycle of a distributed WAA evaluation:
    1. Provisions Azure ML compute instances
    2. Distributes tasks across workers
    3. Monitors progress and collects results
    4. Cleans up resources

    Example:
        config = AzureConfig.from_env()
        orchestrator = AzureWAAOrchestrator(config, waa_repo_path="/path/to/WAA")

        results = orchestrator.run_evaluation(
            agent=my_agent,
            num_workers=40,
        )
        print(f"Success rate: {sum(r.success for r in results) / len(results):.1%}")
    """

    def __init__(
        self,
        config: AzureConfig,
        waa_repo_path: str | Path,
        experiment_name: str = "waa-eval",
    ):
        """Initialize orchestrator.

        Args:
            config: Azure configuration.
            waa_repo_path: Path to WAA repository.
            experiment_name: Name prefix for this evaluation.
        """
        self.config = config
        self.waa_repo_path = Path(waa_repo_path)
        self.experiment_name = experiment_name
        self.ml_client = AzureMLClient(config)
        self._current_run: EvaluationRun | None = None

    def run_evaluation(
        self,
        agent: BenchmarkAgent,
        num_workers: int = 10,
        task_ids: list[str] | None = None,
        max_steps_per_task: int = 15,
        on_worker_complete: Callable[[WorkerState], None] | None = None,
        cleanup_on_complete: bool = True,
    ) -> list[BenchmarkResult]:
        """Run evaluation across multiple Azure VMs.

        Args:
            agent: Agent to evaluate (must be serializable or API-based).
            num_workers: Number of parallel VMs.
            task_ids: Specific tasks to run (None = all 154 tasks).
            max_steps_per_task: Maximum steps per task.
            on_worker_complete: Callback when a worker finishes.
            cleanup_on_complete: Whether to delete VMs after completion.

        Returns:
            List of BenchmarkResult for all tasks.
        """
        # Load tasks
        from openadapt_ml.benchmarks.waa import WAAAdapter

        adapter = WAAAdapter(waa_repo_path=self.waa_repo_path)
        if task_ids:
            tasks = [adapter.load_task(tid) for tid in task_ids]
        else:
            tasks = adapter.list_tasks()

        print(f"[1/4] Loaded {len(tasks)} tasks for {num_workers} worker(s)")

        # Create evaluation run
        run_id = f"{self.experiment_name}-{int(time.time())}"
        self._current_run = EvaluationRun(
            run_id=run_id,
            experiment_name=self.experiment_name,
            num_workers=num_workers,
            total_tasks=len(tasks),
            status="running",
            start_time=time.time(),
        )

        # Distribute tasks across workers
        task_batches = self._distribute_tasks(tasks, num_workers)

        # Create workers
        # VM names: 3-24 chars, letters/numbers/hyphens, must start with a letter,
        # and cannot end with a hyphen followed by a number, so the names below
        # avoid hyphens entirely (e.g. "waa6571w0").
        workers = []
        short_id = str(int(time.time()))[-4:]  # Last 4 digits of timestamp
        for i, batch in enumerate(task_batches):
            worker = WorkerState(
                worker_id=i,
                compute_name=f"waa{short_id}w{i}",  # e.g., "waa6571w0" (no trailing hyphen-number)
                assigned_tasks=[t.task_id for t in batch],
            )
            workers.append(worker)
        self._current_run.workers = workers

        try:
            # Provision VMs in parallel
            print(f"[2/4] Provisioning {num_workers} Azure VM(s)... (this takes 3-5 minutes)")
            self._provision_workers(workers)
            print(" VM(s) ready")

            # Submit jobs to workers
            print("[3/4] Submitting evaluation jobs...")
            self._submit_worker_jobs(workers, task_batches, agent, max_steps_per_task)
            print(" Jobs submitted")

            # Wait for completion and collect results
            print("[4/4] Waiting for workers to complete...")
            results = self._wait_and_collect_results(workers, on_worker_complete)

            self._current_run.status = "completed"
            self._current_run.end_time = time.time()

            return results

        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            self._current_run.status = "failed"
            raise

        finally:
            if cleanup_on_complete:
                self._cleanup_workers(workers)

    def _distribute_tasks(
        self, tasks: list[BenchmarkTask], num_workers: int
    ) -> list[list[BenchmarkTask]]:
        """Distribute tasks evenly across workers."""
        batches: list[list[BenchmarkTask]] = [[] for _ in range(num_workers)]
        for i, task in enumerate(tasks):
            batches[i % num_workers].append(task)
        return batches
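_distribute_tasks deals tasks out round-robin, so worker i receives tasks i, i + num_workers, i + 2*num_workers, and so on, and per-worker counts differ by at most one. A toy trace (illustrative) with five tasks over two workers:

    tasks = ["t0", "t1", "t2", "t3", "t4"]   # stand-ins for BenchmarkTask objects
    batches = [[], []]
    for i, task in enumerate(tasks):
        batches[i % 2].append(task)
    # batches == [["t0", "t2", "t4"], ["t1", "t3"]]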
    def _provision_workers(self, workers: list[WorkerState]) -> None:
        """Provision all worker VMs in parallel."""
        with ThreadPoolExecutor(max_workers=len(workers)) as executor:
            futures = {
                executor.submit(
                    self.ml_client.create_compute_instance,
                    worker.compute_name,
                ): worker
                for worker in workers
            }

            for future in as_completed(futures):
                worker = futures[future]
                try:
                    future.result()
                    worker.status = "provisioned"
                    logger.info(f"Worker {worker.worker_id} provisioned")
                except Exception as e:
                    worker.status = "failed"
                    worker.error = str(e)
                    logger.error(f"Failed to provision worker {worker.worker_id}: {e}")

    def _submit_worker_jobs(
        self,
        workers: list[WorkerState],
        task_batches: list[list[BenchmarkTask]],
        agent: BenchmarkAgent,
        max_steps: int,
    ) -> None:
        """Submit evaluation jobs to workers."""
        for worker, tasks in zip(workers, task_batches):
            if worker.status == "failed":
                continue

            try:
                # Serialize task IDs for this worker
                task_ids = [t.task_id for t in tasks]
                task_ids_json = json.dumps(task_ids)

                # Build command
                command = self._build_worker_command(task_ids_json, max_steps, agent)

                # Submit job
                self.ml_client.submit_job(
                    compute_name=worker.compute_name,
                    command=command,
                    environment_variables={
                        "WAA_TASK_IDS": task_ids_json,
                        "WAA_MAX_STEPS": str(max_steps),
                    },
                    display_name=f"waa-worker-{worker.worker_id}",
                )
                worker.status = "running"
                worker.start_time = time.time()

            except Exception as e:
                worker.status = "failed"
                worker.error = str(e)
                logger.error(f"Failed to submit job for worker {worker.worker_id}: {e}")

    def _build_worker_command(
        self,
        task_ids_json: str,
        max_steps: int,
        agent: BenchmarkAgent,  # noqa: ARG002 - will be used for agent config serialization
    ) -> str:
        """Build the command to run on a worker VM.

        Args:
            task_ids_json: JSON string of task IDs for this worker.
            max_steps: Maximum steps per task.
            agent: Agent to run (TODO: serialize agent config for remote execution).
        """
        # TODO: Serialize agent config and pass to remote worker
        # For now, workers use a default agent configuration
        _ = agent  # Reserved for agent serialization
        return f"""
        cd /workspace/WindowsAgentArena && \
        python -m client.run \
            --task_ids '{task_ids_json}' \
            --max_steps {max_steps} \
            --output_dir /outputs
        """
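For a worker assigned a hypothetical pair of task IDs with max_steps=15, the returned string renders as a single shell line, because the backslash-newline pairs inside the triple-quoted f-string are consumed by Python's escape handling. Roughly (extra whitespace trimmed and wrapped here for readability):

    cd /workspace/WindowsAgentArena && python -m client.run \
        --task_ids '["task_001", "task_002"]' --max_steps 15 --output_dir /outputs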
    def _wait_and_collect_results(
        self,
        workers: list[WorkerState],
        on_worker_complete: Callable[[WorkerState], None] | None,
    ) -> list[BenchmarkResult]:
        """Wait for all workers and collect results."""
        all_results: list[BenchmarkResult] = []

        # Poll workers for completion
        pending_workers = [w for w in workers if w.status == "running"]

        while pending_workers:
            for worker in pending_workers[:]:
                try:
                    status = self.ml_client.get_compute_status(worker.compute_name)

                    # Check if job completed (simplified - real impl would check job status)
                    if status in ["Stopped", "Deallocated"]:
                        worker.status = "completed"
                        worker.end_time = time.time()

                        # Fetch results from blob storage
                        results = self._fetch_worker_results(worker)
                        worker.results = results
                        all_results.extend(results)

                        if on_worker_complete:
                            on_worker_complete(worker)

                        pending_workers.remove(worker)
                        logger.info(
                            f"Worker {worker.worker_id} completed: "
                            f"{len(results)} results"
                        )

                except Exception as e:
                    logger.warning(f"Error checking worker {worker.worker_id}: {e}")

            if pending_workers:
                time.sleep(30)

        return all_results

    def _fetch_worker_results(self, worker: WorkerState) -> list[BenchmarkResult]:
        """Fetch results from a worker's output storage."""
        # In a real implementation, this would download results from blob storage
        # For now, return placeholder results
        results = []
        for task_id in worker.assigned_tasks:
            results.append(
                BenchmarkResult(
                    task_id=task_id,
                    success=False,  # Placeholder
                    score=0.0,
                    num_steps=0,
                )
            )
        return results

    def _cleanup_workers(self, workers: list[WorkerState]) -> None:
        """Delete all worker VMs."""
        logger.info("Cleaning up worker VMs...")
        with ThreadPoolExecutor(max_workers=len(workers)) as executor:
            futures = [
                executor.submit(self.ml_client.delete_compute_instance, w.compute_name)
                for w in workers
            ]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    logger.warning(f"Cleanup error: {e}")

    def get_run_status(self) -> dict | None:
        """Get current run status."""
        if self._current_run is None:
            return None
        return self._current_run.to_dict()

    def cancel_run(self) -> None:
        """Cancel the current run and cleanup resources."""
        if self._current_run is None:
            return

        logger.info("Canceling evaluation run...")
        self._cleanup_workers(self._current_run.workers)
        self._current_run.status = "canceled"
        self._current_run.end_time = time.time()


def estimate_cost(
    num_tasks: int = 154,
    num_workers: int = 1,
    avg_task_duration_minutes: float = 1.0,
    vm_hourly_cost: float = 0.19,  # Standard_D4_v3 in East US (free trial compatible)
) -> dict:
    """Estimate Azure costs for a WAA evaluation run.

    Args:
        num_tasks: Number of tasks to run.
        num_workers: Number of parallel VMs (default: 1 for free trial).
        avg_task_duration_minutes: Average time per task.
        vm_hourly_cost: Hourly cost per VM (D4_v3 = $0.19/hr, D8_v3 = $0.38/hr).

    Returns:
        Dict with cost estimates.
    """
    tasks_per_worker = num_tasks / num_workers
    total_minutes = tasks_per_worker * avg_task_duration_minutes
    total_hours = total_minutes / 60

    # Add overhead for provisioning/cleanup
    overhead_hours = 0.25  # ~15 minutes

    vm_hours = (total_hours + overhead_hours) * num_workers
    total_cost = vm_hours * vm_hourly_cost

    return {
        "num_tasks": num_tasks,
        "num_workers": num_workers,
        "tasks_per_worker": tasks_per_worker,
        "estimated_duration_minutes": total_minutes + (overhead_hours * 60),
        "total_vm_hours": vm_hours,
        "estimated_cost_usd": total_cost,
        "cost_per_task_usd": total_cost / num_tasks,
    }
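A worked example (illustrative) using the defaults except num_workers: the full 154-task suite on 40 workers at one minute per task comes to roughly 12.6 VM-hours and about $2.40 at the D4_v3 rate.

    from openadapt_ml.benchmarks.azure import estimate_cost

    est = estimate_cost(num_tasks=154, num_workers=40)
    # tasks_per_worker = 154 / 40 = 3.85
    # per-worker runtime = 3.85 min ≈ 0.064 h, plus 0.25 h provisioning overhead ≈ 0.314 h
    # total_vm_hours ≈ 0.314 * 40 ≈ 12.57; cost ≈ 12.57 * $0.19 ≈ $2.39
    print(f"${est['estimated_cost_usd']:.2f} for {est['total_vm_hours']:.1f} VM-hours")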