openadapt-ml 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/benchmarks/azure.py

@@ -0,0 +1,761 @@
"""Azure deployment automation for WAA benchmark.

This module provides Azure VM orchestration for running Windows Agent Arena
at scale across multiple parallel VMs.

Requirements:
    - azure-ai-ml
    - azure-identity
    - Azure subscription with ML workspace

Example:
    from openadapt_ml.benchmarks.azure import AzureWAAOrchestrator, AzureConfig

    config = AzureConfig(
        subscription_id="your-subscription-id",
        resource_group="agents",
        workspace_name="agents_ml",
    )
    orchestrator = AzureWAAOrchestrator(config, waa_repo_path="/path/to/WAA")

    # Run evaluation on 40 parallel VMs
    results = orchestrator.run_evaluation(
        agent=my_agent,
        num_workers=40,
        task_ids=None,  # All tasks
    )
"""

from __future__ import annotations

import json
import logging
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable

from openadapt_ml.benchmarks.agent import BenchmarkAgent
from openadapt_ml.benchmarks.base import BenchmarkResult, BenchmarkTask

logger = logging.getLogger(__name__)


@dataclass
class AzureConfig:
    """Azure configuration for WAA deployment.

    Attributes:
        subscription_id: Azure subscription ID.
        resource_group: Resource group containing ML workspace.
        workspace_name: Azure ML workspace name.
        vm_size: VM size for compute instances (must support nested virtualization).
        idle_timeout_minutes: Auto-shutdown after idle (minutes).
        docker_image: Docker image for agent container.
        storage_account: Storage account for results (auto-detected if None).
        use_managed_identity: Whether to use managed identity for auth.
        managed_identity_name: Name of managed identity (if using).
    """

    subscription_id: str
    resource_group: str
    workspace_name: str
    vm_size: str = "Standard_D2_v3"  # 2 vCPUs (fits free trial with existing usage)
    idle_timeout_minutes: int = 60
    docker_image: str = "ghcr.io/microsoft/windowsagentarena:latest"
    storage_account: str | None = None
    use_managed_identity: bool = False
    managed_identity_name: str | None = None

    @classmethod
    def from_env(cls) -> AzureConfig:
        """Create config from environment variables / .env file.

        Uses settings from openadapt_ml.config, which loads from:
        1. Environment variables
        2. .env file
        3. Default values

        Required settings:
            AZURE_SUBSCRIPTION_ID
            AZURE_ML_RESOURCE_GROUP
            AZURE_ML_WORKSPACE_NAME

        Optional settings:
            AZURE_VM_SIZE (default: Standard_D4_v3 for free trial compatibility)
            AZURE_DOCKER_IMAGE (default: ghcr.io/microsoft/windowsagentarena:latest)

        Authentication (one of):
            - AZURE_CLIENT_ID, AZURE_CLIENT_SECRET, AZURE_TENANT_ID (service principal)
            - Azure CLI login (`az login`)
            - Managed Identity (when running on Azure)

        Raises:
            ValueError: If required settings are not configured.
        """
        from openadapt_ml.config import settings

        # Validate required settings
        if not settings.azure_subscription_id:
            raise ValueError(
                "AZURE_SUBSCRIPTION_ID not set. "
                "Run 'python scripts/setup_azure.py' to configure Azure credentials."
            )
        if not settings.azure_ml_resource_group:
            raise ValueError(
                "AZURE_ML_RESOURCE_GROUP not set. "
                "Run 'python scripts/setup_azure.py' to configure Azure credentials."
            )
        if not settings.azure_ml_workspace_name:
            raise ValueError(
                "AZURE_ML_WORKSPACE_NAME not set. "
                "Run 'python scripts/setup_azure.py' to configure Azure credentials."
            )

        return cls(
            subscription_id=settings.azure_subscription_id,
            resource_group=settings.azure_ml_resource_group,
            workspace_name=settings.azure_ml_workspace_name,
            vm_size=settings.azure_vm_size,
            docker_image=settings.azure_docker_image,
        )

    @classmethod
    def from_json(cls, path: str | Path) -> AzureConfig:
        """Load config from JSON file."""
        with open(path) as f:
            data = json.load(f)
        return cls(**data)

    def to_json(self, path: str | Path) -> None:
        """Save config to JSON file."""
        with open(path, "w") as f:
            json.dump(self.__dict__, f, indent=2)
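
A minimal configuration sketch using only the API above. The environment variable names are the ones from_env documents; the values and the azure_config.json filename are placeholders. Because to_json dumps self.__dict__ and every field is JSON-serializable, the round trip through from_json is lossless:

    # Hypothetical .env contents:
    #   AZURE_SUBSCRIPTION_ID=00000000-0000-0000-0000-000000000000
    #   AZURE_ML_RESOURCE_GROUP=agents
    #   AZURE_ML_WORKSPACE_NAME=agents_ml

    from openadapt_ml.benchmarks.azure import AzureConfig

    config = AzureConfig.from_env()      # raises ValueError if a required var is missing
    config.to_json("azure_config.json")  # persist for later runs
    same = AzureConfig.from_json("azure_config.json")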


@dataclass
class WorkerState:
    """State of a single worker VM."""

    worker_id: int
    compute_name: str
    status: str = "pending"  # pending, running, completed, failed
    assigned_tasks: list[str] = field(default_factory=list)
    completed_tasks: list[str] = field(default_factory=list)
    results: list[BenchmarkResult] = field(default_factory=list)
    error: str | None = None
    start_time: float | None = None
    end_time: float | None = None


@dataclass
class EvaluationRun:
    """State of an evaluation run across multiple workers."""

    run_id: str
    experiment_name: str
    num_workers: int
    total_tasks: int
    workers: list[WorkerState] = field(default_factory=list)
    status: str = "pending"  # pending, running, completed, failed
    start_time: float | None = None
    end_time: float | None = None

    def to_dict(self) -> dict:
        """Serialize to dict for JSON storage."""
        return {
            "run_id": self.run_id,
            "experiment_name": self.experiment_name,
            "num_workers": self.num_workers,
            "total_tasks": self.total_tasks,
            "status": self.status,
            "start_time": self.start_time,
            "end_time": self.end_time,
            "workers": [
                {
                    "worker_id": w.worker_id,
                    "compute_name": w.compute_name,
                    "status": w.status,
                    "assigned_tasks": w.assigned_tasks,
                    "completed_tasks": w.completed_tasks,
                    "error": w.error,
                }
                for w in self.workers
            ],
        }
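
Note that to_dict serializes worker metadata but not the per-worker BenchmarkResult objects. A short sketch of persisting run state with these dataclasses (the run ID, compute names, and filename are illustrative):

    run = EvaluationRun(run_id="waa-eval-1700000000", experiment_name="waa-eval",
                        num_workers=2, total_tasks=4)
    run.workers = [WorkerState(worker_id=0, compute_name="waa0000w0"),
                   WorkerState(worker_id=1, compute_name="waa0000w1")]

    with open("run_state.json", "w") as f:
        json.dump(run.to_dict(), f, indent=2)  # results list is intentionally omitted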


class AzureMLClient:
    """Wrapper around Azure ML SDK for compute management.

    This provides a simplified interface for creating and managing
    Azure ML compute instances for WAA evaluation.
    """

    def __init__(self, config: AzureConfig):
        self.config = config
        self._client = None
        self._ensure_sdk_available()

    def _ensure_sdk_available(self) -> None:
        """Check that Azure SDK is available."""
        try:
            from azure.ai.ml import MLClient
            from azure.identity import (
                ClientSecretCredential,
                DefaultAzureCredential,
            )

            self._MLClient = MLClient
            self._DefaultAzureCredential = DefaultAzureCredential
            self._ClientSecretCredential = ClientSecretCredential
        except ImportError as e:
            raise ImportError(
                "Azure ML SDK not installed. Install with: "
                "pip install azure-ai-ml azure-identity"
            ) from e

    @property
    def client(self):
        """Lazy-load ML client.

        Uses service principal credentials if configured in .env,
        otherwise falls back to DefaultAzureCredential (CLI login, managed identity, etc.).
        """
        if self._client is None:
            credential = self._get_credential()
            self._client = self._MLClient(
                credential=credential,
                subscription_id=self.config.subscription_id,
                resource_group_name=self.config.resource_group,
                workspace_name=self.config.workspace_name,
            )
            logger.info(f"Connected to Azure ML workspace: {self.config.workspace_name}")
        return self._client

    def _get_credential(self):
        """Get Azure credential, preferring service principal if configured."""
        from openadapt_ml.config import settings

        # Use service principal if credentials are configured
        if all([
            settings.azure_client_id,
            settings.azure_client_secret,
            settings.azure_tenant_id,
        ]):
            logger.info("Using service principal authentication")
            return self._ClientSecretCredential(
                tenant_id=settings.azure_tenant_id,
                client_id=settings.azure_client_id,
                client_secret=settings.azure_client_secret,
            )

        # Fall back to DefaultAzureCredential (CLI login, managed identity, etc.)
        logger.info(
            "Using DefaultAzureCredential (ensure you're logged in with 'az login' "
            "or have service principal credentials in .env)"
        )
        return self._DefaultAzureCredential()

    def create_compute_instance(
        self,
        name: str,
        startup_script: str | None = None,  # noqa: ARG002 - reserved for future use
    ) -> str:
        """Create a compute instance.

        Args:
            name: Compute instance name.
            startup_script: Optional startup script content (not yet implemented).

        Returns:
            Compute instance name.
        """
        # TODO: Add startup_script support when implementing full WAA integration
        _ = startup_script  # Reserved for future use
        from azure.ai.ml.entities import ComputeInstance

        # Check if already exists
        try:
            existing = self.client.compute.get(name)
            if existing:
                logger.info(f"Compute instance {name} already exists")
                return name
        except Exception:
            pass  # Doesn't exist, create it

        compute = ComputeInstance(
            name=name,
            size=self.config.vm_size,
            idle_time_before_shutdown_minutes=self.config.idle_timeout_minutes,
        )

        # Add managed identity if configured
        if self.config.use_managed_identity and self.config.managed_identity_name:
            identity_id = (
                f"/subscriptions/{self.config.subscription_id}"
                f"/resourceGroups/{self.config.resource_group}"
                f"/providers/Microsoft.ManagedIdentity"
                f"/userAssignedIdentities/{self.config.managed_identity_name}"
            )
            compute.identity = {"type": "UserAssigned", "user_assigned_identities": [identity_id]}

        print(f" Creating VM: {name}...", end="", flush=True)
        self.client.compute.begin_create_or_update(compute).result()
        print(" done")

        return name

    def delete_compute_instance(self, name: str) -> None:
        """Delete a compute instance.

        Args:
            name: Compute instance name.
        """
        try:
            logger.info(f"Deleting compute instance: {name}")
            self.client.compute.begin_delete(name).result()
            logger.info(f"Compute instance {name} deleted")
        except Exception as e:
            logger.warning(f"Failed to delete compute instance {name}: {e}")

    def list_compute_instances(self, prefix: str | None = None) -> list[str]:
        """List compute instances.

        Args:
            prefix: Optional name prefix filter.

        Returns:
            List of compute instance names.
        """
        computes = self.client.compute.list()
        names = [c.name for c in computes if c.type == "ComputeInstance"]
        if prefix:
            names = [n for n in names if n.startswith(prefix)]
        return names

    def get_compute_status(self, name: str) -> str:
        """Get compute instance status.

        Args:
            name: Compute instance name.

        Returns:
            Status string (Running, Stopped, etc.)
        """
        compute = self.client.compute.get(name)
        return compute.state

    def submit_job(
        self,
        compute_name: str,
        command: str,
        environment_variables: dict[str, str] | None = None,
        display_name: str | None = None,
    ) -> str:
        """Submit a job to a compute instance.

        Args:
            compute_name: Target compute instance.
            command: Command to run.
            environment_variables: Environment variables.
            display_name: Job display name.

        Returns:
            Job name/ID.
        """
        from azure.ai.ml import command as ml_command
        from azure.ai.ml.entities import Environment

        # Create environment with Docker image
        env = Environment(
            image=self.config.docker_image,
            name="waa-agent-env",
        )

        job = ml_command(
            command=command,
            environment=env,
            compute=compute_name,
            display_name=display_name or f"waa-job-{compute_name}",
            environment_variables=environment_variables or {},
        )

        submitted = self.client.jobs.create_or_update(job)
        logger.info(f"Job submitted: {submitted.name}")
        return submitted.name

    def wait_for_job(self, job_name: str, timeout_seconds: int = 3600) -> dict:
        """Wait for a job to complete.

        Args:
            job_name: Job name/ID.
            timeout_seconds: Maximum wait time.

        Returns:
            Job result dict.
        """
        start_time = time.time()
        while time.time() - start_time < timeout_seconds:
            job = self.client.jobs.get(job_name)
            if job.status in ["Completed", "Failed", "Canceled"]:
                return {
                    "status": job.status,
                    "outputs": job.outputs if hasattr(job, "outputs") else {},
                }
            time.sleep(10)

        raise TimeoutError(f"Job {job_name} did not complete within {timeout_seconds}s")
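
A minimal sketch of driving AzureMLClient directly, using only the methods defined above. The VM name and command are placeholders; credentials resolve through _get_credential as described:

    client = AzureMLClient(AzureConfig.from_env())

    name = client.create_compute_instance("waa-dev-vm")
    job = client.submit_job(
        compute_name=name,
        command="echo hello-from-waa",
        display_name="waa-smoke-test",
    )
    outcome = client.wait_for_job(job, timeout_seconds=1800)  # polls every 10 s
    print(outcome["status"])                                  # Completed / Failed / Canceled
    client.delete_compute_instance(name)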


class AzureWAAOrchestrator:
    """Orchestrates WAA evaluation across multiple Azure VMs.

    This class manages the full lifecycle of a distributed WAA evaluation:
    1. Provisions Azure ML compute instances
    2. Distributes tasks across workers
    3. Monitors progress and collects results
    4. Cleans up resources

    Example:
        config = AzureConfig.from_env()
        orchestrator = AzureWAAOrchestrator(config, waa_repo_path="/path/to/WAA")

        results = orchestrator.run_evaluation(
            agent=my_agent,
            num_workers=40,
        )
        print(f"Success rate: {sum(r.success for r in results) / len(results):.1%}")
    """

    def __init__(
        self,
        config: AzureConfig,
        waa_repo_path: str | Path,
        experiment_name: str = "waa-eval",
    ):
        """Initialize orchestrator.

        Args:
            config: Azure configuration.
            waa_repo_path: Path to WAA repository.
            experiment_name: Name prefix for this evaluation.
        """
        self.config = config
        self.waa_repo_path = Path(waa_repo_path)
        self.experiment_name = experiment_name
        self.ml_client = AzureMLClient(config)
        self._current_run: EvaluationRun | None = None

    def run_evaluation(
        self,
        agent: BenchmarkAgent,
        num_workers: int = 10,
        task_ids: list[str] | None = None,
        max_steps_per_task: int = 15,
        on_worker_complete: Callable[[WorkerState], None] | None = None,
        cleanup_on_complete: bool = True,
    ) -> list[BenchmarkResult]:
        """Run evaluation across multiple Azure VMs.

        Args:
            agent: Agent to evaluate (must be serializable or API-based).
            num_workers: Number of parallel VMs.
            task_ids: Specific tasks to run (None = all 154 tasks).
            max_steps_per_task: Maximum steps per task.
            on_worker_complete: Callback when a worker finishes.
            cleanup_on_complete: Whether to delete VMs after completion.

        Returns:
            List of BenchmarkResult for all tasks.
        """
        # Load tasks
        from openadapt_ml.benchmarks.waa import WAAAdapter

        adapter = WAAAdapter(waa_repo_path=self.waa_repo_path)
        if task_ids:
            tasks = [adapter.load_task(tid) for tid in task_ids]
        else:
            tasks = adapter.list_tasks()

        print(f"[1/4] Loaded {len(tasks)} tasks for {num_workers} worker(s)")

        # Create evaluation run
        run_id = f"{self.experiment_name}-{int(time.time())}"
        self._current_run = EvaluationRun(
            run_id=run_id,
            experiment_name=self.experiment_name,
            num_workers=num_workers,
            total_tasks=len(tasks),
            status="running",
            start_time=time.time(),
        )

        # Distribute tasks across workers
        task_batches = self._distribute_tasks(tasks, num_workers)

        # Create workers.
        # VM names: 3-24 chars, letters/numbers/hyphens, must start with a letter,
        # and cannot end with a number directly after a hyphen, so the name format
        # below separates the worker index with a "w" instead of a hyphen.
        workers = []
        short_id = str(int(time.time()))[-4:]  # Last 4 digits of timestamp
        for i, batch in enumerate(task_batches):
            worker = WorkerState(
                worker_id=i,
                compute_name=f"waa{short_id}w{i}",  # e.g., "waa6571w0" (no trailing hyphen-number)
                assigned_tasks=[t.task_id for t in batch],
            )
            workers.append(worker)
        self._current_run.workers = workers

        try:
            # Provision VMs in parallel
            print(f"[2/4] Provisioning {num_workers} Azure VM(s)... (this takes 3-5 minutes)")
            self._provision_workers(workers)
            print(" VM(s) ready")

            # Submit jobs to workers
            print("[3/4] Submitting evaluation jobs...")
            self._submit_worker_jobs(workers, task_batches, agent, max_steps_per_task)
            print(" Jobs submitted")

            # Wait for completion and collect results
            print("[4/4] Waiting for workers to complete...")
            results = self._wait_and_collect_results(workers, on_worker_complete)

            self._current_run.status = "completed"
            self._current_run.end_time = time.time()

            return results

        except Exception as e:
            logger.error(f"Evaluation failed: {e}")
            self._current_run.status = "failed"
            raise

        finally:
            if cleanup_on_complete:
                self._cleanup_workers(workers)
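
A sketch of a progress callback for run_evaluation, using only fields WorkerState defines. The orchestrator and my_agent are assumed to exist as in the class docstring, and the task ID is hypothetical:

    def report(worker: WorkerState) -> None:
        # Called once per worker as it finishes
        took = (worker.end_time or 0) - (worker.start_time or 0)
        print(f"worker {worker.worker_id}: {len(worker.results)} results in {took:.0f}s")

    results = orchestrator.run_evaluation(
        agent=my_agent,                # assumed: a concrete BenchmarkAgent
        num_workers=4,
        task_ids=["notepad_1"],        # hypothetical task ID
        on_worker_complete=report,
        cleanup_on_complete=True,      # VMs are deleted in the finally block
    )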

    def _distribute_tasks(
        self, tasks: list[BenchmarkTask], num_workers: int
    ) -> list[list[BenchmarkTask]]:
        """Distribute tasks evenly across workers."""
        batches: list[list[BenchmarkTask]] = [[] for _ in range(num_workers)]
        for i, task in enumerate(tasks):
            batches[i % num_workers].append(task)
        return batches
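
The round-robin split means batch sizes differ by at most one. For example, with plain integers standing in for BenchmarkTask objects:

    tasks = list(range(10))
    batches = [[] for _ in range(3)]
    for i, t in enumerate(tasks):
        batches[i % 3].append(t)
    # batches == [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]  -> sizes 4, 3, 3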

    def _provision_workers(self, workers: list[WorkerState]) -> None:
        """Provision all worker VMs in parallel."""
        with ThreadPoolExecutor(max_workers=len(workers)) as executor:
            futures = {
                executor.submit(
                    self.ml_client.create_compute_instance,
                    worker.compute_name,
                ): worker
                for worker in workers
            }

            for future in as_completed(futures):
                worker = futures[future]
                try:
                    future.result()
                    worker.status = "provisioned"
                    logger.info(f"Worker {worker.worker_id} provisioned")
                except Exception as e:
                    worker.status = "failed"
                    worker.error = str(e)
                    logger.error(f"Failed to provision worker {worker.worker_id}: {e}")

    def _submit_worker_jobs(
        self,
        workers: list[WorkerState],
        task_batches: list[list[BenchmarkTask]],
        agent: BenchmarkAgent,
        max_steps: int,
    ) -> None:
        """Submit evaluation jobs to workers."""
        for worker, tasks in zip(workers, task_batches):
            if worker.status == "failed":
                continue

            try:
                # Serialize task IDs for this worker
                task_ids = [t.task_id for t in tasks]
                task_ids_json = json.dumps(task_ids)

                # Build command
                command = self._build_worker_command(task_ids_json, max_steps, agent)

                # Submit job
                self.ml_client.submit_job(
                    compute_name=worker.compute_name,
                    command=command,
                    environment_variables={
                        "WAA_TASK_IDS": task_ids_json,
                        "WAA_MAX_STEPS": str(max_steps),
                    },
                    display_name=f"waa-worker-{worker.worker_id}",
                )
                worker.status = "running"
                worker.start_time = time.time()

            except Exception as e:
                worker.status = "failed"
                worker.error = str(e)
                logger.error(f"Failed to submit job for worker {worker.worker_id}: {e}")

    def _build_worker_command(
        self,
        task_ids_json: str,
        max_steps: int,
        agent: BenchmarkAgent,  # noqa: ARG002 - will be used for agent config serialization
    ) -> str:
        """Build the command to run on a worker VM.

        Args:
            task_ids_json: JSON string of task IDs for this worker.
            max_steps: Maximum steps per task.
            agent: Agent to run (TODO: serialize agent config for remote execution).
        """
        # TODO: Serialize agent config and pass to remote worker
        # For now, workers use a default agent configuration
        _ = agent  # Reserved for agent serialization
        return f"""
cd /workspace/WindowsAgentArena && \
python -m client.run \
    --task_ids '{task_ids_json}' \
    --max_steps {max_steps} \
    --output_dir /outputs
"""

    def _wait_and_collect_results(
        self,
        workers: list[WorkerState],
        on_worker_complete: Callable[[WorkerState], None] | None,
    ) -> list[BenchmarkResult]:
        """Wait for all workers and collect results."""
        all_results: list[BenchmarkResult] = []

        # Poll workers for completion
        pending_workers = [w for w in workers if w.status == "running"]

        while pending_workers:
            for worker in pending_workers[:]:
                try:
                    status = self.ml_client.get_compute_status(worker.compute_name)

                    # Check if job completed (simplified - real impl would check job status)
                    if status in ["Stopped", "Deallocated"]:
                        worker.status = "completed"
                        worker.end_time = time.time()

                        # Fetch results from blob storage
                        results = self._fetch_worker_results(worker)
                        worker.results = results
                        all_results.extend(results)

                        if on_worker_complete:
                            on_worker_complete(worker)

                        pending_workers.remove(worker)
                        logger.info(
                            f"Worker {worker.worker_id} completed: "
                            f"{len(results)} results"
                        )

                except Exception as e:
                    logger.warning(f"Error checking worker {worker.worker_id}: {e}")

            if pending_workers:
                time.sleep(30)

        return all_results

    def _fetch_worker_results(self, worker: WorkerState) -> list[BenchmarkResult]:
        """Fetch results from a worker's output storage."""
        # In a real implementation, this would download results from blob storage
        # For now, return placeholder results
        results = []
        for task_id in worker.assigned_tasks:
            results.append(
                BenchmarkResult(
                    task_id=task_id,
                    success=False,  # Placeholder
                    score=0.0,
                    num_steps=0,
                )
            )
        return results

    def _cleanup_workers(self, workers: list[WorkerState]) -> None:
        """Delete all worker VMs."""
        logger.info("Cleaning up worker VMs...")
        with ThreadPoolExecutor(max_workers=len(workers)) as executor:
            futures = [
                executor.submit(self.ml_client.delete_compute_instance, w.compute_name)
                for w in workers
            ]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    logger.warning(f"Cleanup error: {e}")

    def get_run_status(self) -> dict | None:
        """Get current run status."""
        if self._current_run is None:
            return None
        return self._current_run.to_dict()

    def cancel_run(self) -> None:
        """Cancel the current run and cleanup resources."""
        if self._current_run is None:
            return

        logger.info("Canceling evaluation run...")
        self._cleanup_workers(self._current_run.workers)
        self._current_run.status = "canceled"
        self._current_run.end_time = time.time()
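
Because run_evaluation blocks until all workers finish, get_run_status is most useful from another thread. A minimal monitoring sketch under that assumption (orchestrator and my_agent as before):

    import threading

    t = threading.Thread(
        target=lambda: orchestrator.run_evaluation(agent=my_agent, num_workers=2)
    )
    t.start()
    while t.is_alive():
        snapshot = orchestrator.get_run_status()  # None until the run is created
        if snapshot:
            done = sum(1 for w in snapshot["workers"] if w["status"] == "completed")
            print(f"{done}/{snapshot['num_workers']} workers completed")
        time.sleep(60)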


def estimate_cost(
    num_tasks: int = 154,
    num_workers: int = 1,
    avg_task_duration_minutes: float = 1.0,
    vm_hourly_cost: float = 0.19,  # Standard_D4_v3 in East US (free trial compatible)
) -> dict:
    """Estimate Azure costs for a WAA evaluation run.

    Args:
        num_tasks: Number of tasks to run.
        num_workers: Number of parallel VMs (default: 1 for free trial).
        avg_task_duration_minutes: Average time per task.
        vm_hourly_cost: Hourly cost per VM (D4_v3 = $0.19/hr, D8_v3 = $0.38/hr).

    Returns:
        Dict with cost estimates.
    """
    tasks_per_worker = num_tasks / num_workers
    total_minutes = tasks_per_worker * avg_task_duration_minutes
    total_hours = total_minutes / 60

    # Add overhead for provisioning/cleanup
    overhead_hours = 0.25  # ~15 minutes

    vm_hours = (total_hours + overhead_hours) * num_workers
    total_cost = vm_hours * vm_hourly_cost

    return {
        "num_tasks": num_tasks,
        "num_workers": num_workers,
        "tasks_per_worker": tasks_per_worker,
        "estimated_duration_minutes": total_minutes + (overhead_hours * 60),
        "total_vm_hours": vm_hours,
        "estimated_cost_usd": total_cost,
        "cost_per_task_usd": total_cost / num_tasks,
    }
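
Worked through with the defaults, and again with the 40-worker setup from the module docstring, the estimator shows how parallelism trades a little extra cost (per-VM provisioning overhead) for wall-clock time:

    est = estimate_cost()  # 154 tasks, 1 worker, 1 min/task, $0.19/hr
    # tasks_per_worker = 154; 154 min ≈ 2.57 VM-hours + 0.25 h overhead ≈ 2.82 VM-hours
    # estimated_cost_usd ≈ 2.82 * 0.19 ≈ $0.54, over roughly 169 minutes

    est = estimate_cost(num_workers=40)
    # 3.85 tasks/worker -> ~0.064 h each + 0.25 h overhead, times 40 VMs ≈ 12.6 VM-hours
    # estimated_cost_usd ≈ $2.39, but wall-clock drops to ~19 minutes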