openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
Files changed (112)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/benchmarks/vm_monitor.py (new file)
@@ -0,0 +1,1111 @@
+ """VM monitoring utilities for WAA benchmark evaluation.
+
+ This module provides reusable classes for monitoring Windows VMs running WAA.
+ Can be used by the viewer, CLI, or as a standalone tool.
+
+ Enhanced with Azure ML job tracking, cost estimation, and activity detection.
+
+ Usage:
+     # Monitor a single VM
+     from openadapt_ml.benchmarks.vm_monitor import VMMonitor, VMConfig
+
+     config = VMConfig(
+         name="azure-waa-vm",
+         ssh_host="172.171.112.41",
+         ssh_user="azureuser",
+         docker_container="winarena",
+         internal_ip="20.20.20.21",
+     )
+
+     monitor = VMMonitor(config)
+     status = monitor.check_status()
+     print(f"VNC: {status.vnc_reachable}, WAA: {status.waa_ready}")
+
+     # Or run continuous monitoring
+     monitor.run_monitor(callback=lambda s: print(s))
+
+     # Fetch Azure ML jobs
+     jobs = fetch_azure_ml_jobs(days=7)
+     print(f"Found {len(jobs)} jobs in last 7 days")
+
+     # Calculate VM costs
+     costs = calculate_vm_costs(vm_size="Standard_D4ds_v5", hours=2.5)
+     print(f"Estimated cost: ${costs.cost_usd:.2f}")
+ """
+
+ from __future__ import annotations
+
+ import json
+ import subprocess
+ import time
+ from dataclasses import dataclass, field, asdict
+ from datetime import datetime, timedelta
+ from pathlib import Path
+ from typing import Callable
+ import urllib.request
+ import urllib.error
+ import socket
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class VMConfig:
+     """Configuration for a WAA VM."""
+
+     name: str
+     ssh_host: str
+     ssh_user: str = "azureuser"
+     vnc_port: int = 8006
+     waa_port: int = 5000
+     qmp_port: int = 7200
+     docker_container: str = "winarena"
+     internal_ip: str = "20.20.20.21"
+
+     def to_dict(self) -> dict:
+         """Convert to dictionary for JSON serialization."""
+         return asdict(self)
+
+     @classmethod
+     def from_dict(cls, data: dict) -> VMConfig:
+         """Create from dictionary."""
+         return cls(**data)
+
+
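An illustrative sketch (not part of the diff itself) of the VMConfig serialization helpers added above, assuming the module is importable as shown in its docstring:

    from openadapt_ml.benchmarks.vm_monitor import VMConfig

    config = VMConfig(name="azure-waa-vm", ssh_host="172.171.112.41")
    payload = config.to_dict()           # plain dict, ready for json.dump()
    restored = VMConfig.from_dict(payload)
    assert restored == config            # dataclass equality holds after the round trip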
76
+ @dataclass
+ class VMStatus:
+     """Status of a WAA VM at a point in time."""
+
+     config: VMConfig
+     timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
+     ssh_reachable: bool = False
+     vnc_reachable: bool = False
+     waa_ready: bool = False
+     waa_probe_response: str | None = None
+     container_running: bool = False
+     container_logs: str | None = None
+     disk_usage_gb: float | None = None
+     error: str | None = None
+
+     def to_dict(self) -> dict:
+         """Convert to dictionary for JSON serialization."""
+         return {
+             "config": self.config.to_dict(),
+             "timestamp": self.timestamp,
+             "ssh_reachable": self.ssh_reachable,
+             "vnc_reachable": self.vnc_reachable,
+             "waa_ready": self.waa_ready,
+             "waa_probe_response": self.waa_probe_response,
+             "container_running": self.container_running,
+             "container_logs": self.container_logs,
+             "disk_usage_gb": self.disk_usage_gb,
+             "error": self.error,
+         }
+
+
+ class VMMonitor:
+     """Monitor a single WAA VM."""
+
+     def __init__(self, config: VMConfig, timeout: int = 5):
+         """Initialize monitor.
+
+         Args:
+             config: VM configuration.
+             timeout: Timeout in seconds for network operations.
+         """
+         self.config = config
+         self.timeout = timeout
+
+     def check_vnc(self) -> bool:
+         """Check if VNC port is reachable via SSH tunnel (localhost)."""
+         try:
+             # VNC is only accessible via SSH tunnel at localhost, not the public IP
+             url = f"http://localhost:{self.config.vnc_port}/"
+             req = urllib.request.Request(url, method="HEAD")
+             with urllib.request.urlopen(req, timeout=self.timeout):
+                 return True
+         except (urllib.error.URLError, socket.timeout, Exception):
+             return False
+
+     def check_ssh(self) -> bool:
+         """Check if SSH is reachable."""
+         try:
+             result = subprocess.run(
+                 [
+                     "ssh",
+                     "-o",
+                     "StrictHostKeyChecking=no",
+                     "-o",
+                     f"ConnectTimeout={self.timeout}",
+                     "-o",
+                     "BatchMode=yes",
+                     f"{self.config.ssh_user}@{self.config.ssh_host}",
+                     "echo ok",
+                 ],
+                 capture_output=True,
+                 text=True,
+                 timeout=self.timeout + 5,
+             )
+             return result.returncode == 0 and "ok" in result.stdout
+         except (subprocess.TimeoutExpired, Exception):
+             return False
+
+     def check_waa_probe(self) -> tuple[bool, str | None]:
+         """Check if WAA /probe endpoint responds.
+
+         Returns:
+             Tuple of (ready, response_text).
+         """
+         try:
+             cmd = f"curl -s --connect-timeout {self.timeout} http://{self.config.internal_ip}:{self.config.waa_port}/probe"
+             result = subprocess.run(
+                 [
+                     "ssh",
+                     "-o",
+                     "StrictHostKeyChecking=no",
+                     "-o",
+                     f"ConnectTimeout={self.timeout}",
+                     "-o",
+                     "BatchMode=yes",
+                     f"{self.config.ssh_user}@{self.config.ssh_host}",
+                     cmd,
+                 ],
+                 capture_output=True,
+                 text=True,
+                 timeout=self.timeout + 10,
+             )
+             response = result.stdout.strip()
+             if response and "error" not in response.lower():
+                 return True, response
+             return False, response or None
+         except (subprocess.TimeoutExpired, Exception) as e:
+             return False, str(e)
+
+     def get_container_status(self) -> tuple[bool, str | None]:
+         """Check container status and get recent logs.
+
+         Returns:
+             Tuple of (running, last_log_lines).
+         """
+         try:
+             cmd = f"docker ps -q -f name={self.config.docker_container}"
+             result = subprocess.run(
+                 [
+                     "ssh",
+                     "-o",
+                     "StrictHostKeyChecking=no",
+                     "-o",
+                     f"ConnectTimeout={self.timeout}",
+                     "-o",
+                     "BatchMode=yes",
+                     f"{self.config.ssh_user}@{self.config.ssh_host}",
+                     cmd,
+                 ],
+                 capture_output=True,
+                 text=True,
+                 timeout=self.timeout + 5,
+             )
+             running = bool(result.stdout.strip())
+
+             if running:
+                 # Get last few log lines
+                 log_cmd = f"docker logs {self.config.docker_container} 2>&1 | tail -5"
+                 log_result = subprocess.run(
+                     [
+                         "ssh",
+                         "-o",
+                         "StrictHostKeyChecking=no",
+                         "-o",
+                         f"ConnectTimeout={self.timeout}",
+                         "-o",
+                         "BatchMode=yes",
+                         f"{self.config.ssh_user}@{self.config.ssh_host}",
+                         log_cmd,
+                     ],
+                     capture_output=True,
+                     text=True,
+                     timeout=self.timeout + 10,
+                 )
+                 return True, log_result.stdout.strip()
+             return False, None
+         except (subprocess.TimeoutExpired, Exception) as e:
+             return False, str(e)
+
+     def get_disk_usage(self) -> float | None:
+         """Get disk usage of data.img in GB."""
+         try:
+             # Try common paths
+             paths = [
+                 "/home/azureuser/waa-storage/data.img",
+                 "/home/ubuntu/waa-storage/data.img",
+                 "/storage/data.img",
+             ]
+             for path in paths:
+                 cmd = f"du -b {path} 2>/dev/null | cut -f1"
+                 result = subprocess.run(
+                     [
+                         "ssh",
+                         "-o",
+                         "StrictHostKeyChecking=no",
+                         "-o",
+                         f"ConnectTimeout={self.timeout}",
+                         "-o",
+                         "BatchMode=yes",
+                         f"{self.config.ssh_user}@{self.config.ssh_host}",
+                         cmd,
+                     ],
+                     capture_output=True,
+                     text=True,
+                     timeout=self.timeout + 5,
+                 )
+                 if result.returncode == 0 and result.stdout.strip():
+                     try:
+                         bytes_size = int(result.stdout.strip())
+                         return round(bytes_size / (1024**3), 2)
+                     except ValueError:
+                         continue
+             return None
+         except (subprocess.TimeoutExpired, Exception):
+             return None
+
+     def check_status(self) -> VMStatus:
+         """Perform full status check on the VM.
+
+         Returns:
+             VMStatus with all checks performed.
+         """
+         status = VMStatus(config=self.config)
+
+         try:
+             # Check VNC first (fastest, no SSH needed)
+             status.vnc_reachable = self.check_vnc()
+
+             # Check SSH
+             status.ssh_reachable = self.check_ssh()
+
+             if status.ssh_reachable:
+                 # Check container
+                 status.container_running, status.container_logs = (
+                     self.get_container_status()
+                 )
+
+                 # Check WAA probe
+                 status.waa_ready, status.waa_probe_response = self.check_waa_probe()
+
+                 # Get disk usage
+                 status.disk_usage_gb = self.get_disk_usage()
+         except Exception as e:
+             status.error = str(e)
+
+         return status
+
+     def run_monitor(
+         self,
+         callback: Callable[[VMStatus], None] | None = None,
+         interval: int = 30,
+         stop_on_ready: bool = True,
+         output_file: str | Path | None = None,
+     ) -> VMStatus:
+         """Run continuous monitoring until WAA is ready.
+
+         Args:
+             callback: Optional callback function called with each status update.
+             interval: Seconds between checks.
+             stop_on_ready: Stop monitoring when WAA is ready.
+             output_file: Optional file to write status updates (JSON lines).
+
+         Returns:
+             Final VMStatus (typically when WAA is ready).
+         """
+         output_path = Path(output_file) if output_file else None
+         if output_path:
+             output_path.parent.mkdir(parents=True, exist_ok=True)
+
+         while True:
+             status = self.check_status()
+
+             # Call callback if provided
+             if callback:
+                 callback(status)
+
+             # Write to file if provided
+             if output_path:
+                 with open(output_path, "a") as f:
+                     f.write(json.dumps(status.to_dict()) + "\n")
+
+             # Check if we should stop
+             if stop_on_ready and status.waa_ready:
+                 return status
+
+             time.sleep(interval)
+
+
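As a usage sketch (not from the package), run_monitor above can stream status updates to a JSON-lines file while printing progress; the output path below is hypothetical:

    from openadapt_ml.benchmarks.vm_monitor import VMConfig, VMMonitor

    monitor = VMMonitor(VMConfig(name="azure-waa-vm", ssh_host="172.171.112.41"))
    final = monitor.run_monitor(
        callback=lambda s: print(s.timestamp, s.ssh_reachable, s.waa_ready),
        interval=60,                        # check once a minute
        output_file="monitor_log.jsonl",    # hypothetical JSON-lines output path
    )
    print("WAA probe said:", final.waa_probe_response)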
344
+ @dataclass
+ class PoolWorker:
+     """A single worker in a VM pool."""
+
+     name: str
+     ip: str
+     status: str = "creating"  # creating, ready, running, completed, failed, deleted
+     docker_container: str = "winarena"
+     waa_ready: bool = False
+     assigned_tasks: list[str] = field(default_factory=list)
+     completed_tasks: list[str] = field(default_factory=list)
+     current_task: str | None = None
+     error: str | None = None
+     created_at: str = field(default_factory=lambda: datetime.now().isoformat())
+     updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
+
+
+ @dataclass
+ class VMPool:
+     """A pool of worker VMs for parallel WAA evaluation."""
+
+     pool_id: str
+     created_at: str
+     resource_group: str
+     location: str
+     vm_size: str
+     workers: list[PoolWorker]
+     total_tasks: int = 0
+     completed_tasks: int = 0
+     failed_tasks: int = 0
+
+
+ class VMPoolRegistry:
+     """Manage VM pools for parallel WAA evaluation."""
+
+     REGISTRY_FILE = "benchmark_results/vm_pool_registry.json"
+
+     def __init__(self, registry_file: str | Path | None = None):
+         """Initialize pool registry.
+
+         Args:
+             registry_file: Path to JSON registry file.
+         """
+         self.registry_file = Path(registry_file or self.REGISTRY_FILE)
+         self._pool: VMPool | None = None
+         self.load()
+
+     def load(self) -> None:
+         """Load pool from registry file."""
+         if self.registry_file.exists():
+             try:
+                 with open(self.registry_file) as f:
+                     data = json.load(f)
+                 workers = [PoolWorker(**w) for w in data.get("workers", [])]
+                 self._pool = VMPool(
+                     pool_id=data["pool_id"],
+                     created_at=data["created_at"],
+                     resource_group=data["resource_group"],
+                     location=data["location"],
+                     vm_size=data["vm_size"],
+                     workers=workers,
+                     total_tasks=data.get("total_tasks", 0),
+                     completed_tasks=data.get("completed_tasks", 0),
+                     failed_tasks=data.get("failed_tasks", 0),
+                 )
+             except (json.JSONDecodeError, KeyError) as e:
+                 print(f"Warning: Could not load pool registry: {e}")
+                 self._pool = None
+
+     def save(self) -> None:
+         """Save pool to registry file."""
+         if self._pool is None:
+             return
+         self.registry_file.parent.mkdir(parents=True, exist_ok=True)
+         with open(self.registry_file, "w") as f:
+             json.dump(asdict(self._pool), f, indent=2)
+
+     def create_pool(
+         self,
+         workers: list[tuple[str, str]],  # [(name, ip), ...]
+         resource_group: str,
+         location: str,
+         vm_size: str = "Standard_D4ds_v5",
+     ) -> VMPool:
+         """Create a new pool from created VMs.
+
+         Args:
+             workers: List of (name, ip) tuples.
+             resource_group: Azure resource group.
+             location: Azure region.
+             vm_size: VM size used.
+
+         Returns:
+             Created VMPool.
+         """
+         pool_id = datetime.now().strftime("%Y%m%d_%H%M%S")
+         self._pool = VMPool(
+             pool_id=pool_id,
+             created_at=datetime.now().isoformat(),
+             resource_group=resource_group,
+             location=location,
+             vm_size=vm_size,
+             workers=[
+                 PoolWorker(name=name, ip=ip, status="ready") for name, ip in workers
+             ],
+         )
+         self.save()
+         return self._pool
+
+     def get_pool(self) -> VMPool | None:
+         """Get current pool."""
+         return self._pool
+
+     def update_worker(self, name: str, **kwargs) -> None:
+         """Update a worker's status.
+
+         Args:
+             name: Worker name.
+             **kwargs: Fields to update.
+         """
+         if self._pool is None:
+             return
+         for worker in self._pool.workers:
+             if worker.name == name:
+                 for key, value in kwargs.items():
+                     if hasattr(worker, key):
+                         setattr(worker, key, value)
+                 worker.updated_at = datetime.now().isoformat()
+                 break
+         self.save()
+
+     def update_pool_progress(self, completed: int = 0, failed: int = 0) -> None:
+         """Update pool-level progress.
+
+         Args:
+             completed: Increment completed count by this amount.
+             failed: Increment failed count by this amount.
+         """
+         if self._pool is None:
+             return
+         self._pool.completed_tasks += completed
+         self._pool.failed_tasks += failed
+         self.save()
+
+     def delete_pool(self) -> bool:
+         """Delete the pool registry (VMs must be deleted separately).
+
+         Returns:
+             True if pool was deleted.
+         """
+         if self.registry_file.exists():
+             self.registry_file.unlink()
+             self._pool = None
+             return True
+         return False
+
+
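A sketch of the pool bookkeeping flow implemented by VMPoolRegistry above (illustrative only; worker names, IPs, and the region are made up):

    from openadapt_ml.benchmarks.vm_monitor import VMPoolRegistry

    registry = VMPoolRegistry()  # defaults to benchmark_results/vm_pool_registry.json
    pool = registry.create_pool(
        workers=[("waa-worker-0", "10.0.0.4"), ("waa-worker-1", "10.0.0.5")],
        resource_group="openadapt-agents",
        location="eastus",
    )
    registry.update_worker("waa-worker-0", status="running", current_task="task-001")
    registry.update_pool_progress(completed=1)  # increments the persisted counters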
501
+ class VMRegistry:
+     """Manage a registry of VMs and their status."""
+
+     def __init__(
+         self, registry_file: str | Path = "benchmark_results/vm_registry.json"
+     ):
+         """Initialize registry.
+
+         Args:
+             registry_file: Path to JSON registry file.
+         """
+         self.registry_file = Path(registry_file)
+         self._vms: list[VMConfig] = []
+         self.load()
+
+     def load(self) -> None:
+         """Load VMs from registry file."""
+         if self.registry_file.exists():
+             with open(self.registry_file) as f:
+                 data = json.load(f)
+             self._vms = [VMConfig.from_dict(vm) for vm in data]
+
+     def save(self) -> None:
+         """Save VMs to registry file."""
+         self.registry_file.parent.mkdir(parents=True, exist_ok=True)
+         with open(self.registry_file, "w") as f:
+             json.dump([vm.to_dict() for vm in self._vms], f, indent=2)
+
+     def add(self, config: VMConfig) -> None:
+         """Add a VM to the registry."""
+         # Remove existing VM with same name
+         self._vms = [vm for vm in self._vms if vm.name != config.name]
+         self._vms.append(config)
+         self.save()
+
+     def remove(self, name: str) -> bool:
+         """Remove a VM from the registry.
+
+         Returns:
+             True if VM was found and removed.
+         """
+         original_len = len(self._vms)
+         self._vms = [vm for vm in self._vms if vm.name != name]
+         if len(self._vms) < original_len:
+             self.save()
+             return True
+         return False
+
+     def get(self, name: str) -> VMConfig | None:
+         """Get a VM by name."""
+         for vm in self._vms:
+             if vm.name == name:
+                 return vm
+         return None
+
+     def list(self) -> list[VMConfig]:
+         """List all VMs."""
+         return list(self._vms)
+
+     def check_all(self, timeout: int = 5) -> list[VMStatus]:
+         """Check status of all VMs.
+
+         Args:
+             timeout: Timeout per VM check.
+
+         Returns:
+             List of VMStatus for each registered VM.
+         """
+         statuses = []
+         for config in self._vms:
+             monitor = VMMonitor(config, timeout=timeout)
+             statuses.append(monitor.check_status())
+         return statuses
+
+
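A corresponding sketch for the single-VM registry (again illustrative, with a made-up host):

    from openadapt_ml.benchmarks.vm_monitor import VMConfig, VMRegistry

    registry = VMRegistry()  # defaults to benchmark_results/vm_registry.json
    registry.add(VMConfig(name="azure-waa-vm", ssh_host="172.171.112.41"))
    for status in registry.check_all(timeout=5):
        print(status.config.name, "WAA ready:", status.waa_ready)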
576
+ def main():
+     """CLI entry point for VM monitoring."""
+     import argparse
+
+     parser = argparse.ArgumentParser(description="Monitor WAA VMs")
+     parser.add_argument("--host", help="SSH host")
+     parser.add_argument("--user", default="azureuser", help="SSH user")
+     parser.add_argument("--container", default="winarena", help="Docker container name")
+     parser.add_argument(
+         "--interval", type=int, default=30, help="Check interval in seconds"
+     )
+     parser.add_argument("--output", help="Output file for status updates (JSON lines)")
+     parser.add_argument("--list", action="store_true", help="List all registered VMs")
+     parser.add_argument(
+         "--check-all", action="store_true", help="Check all registered VMs"
+     )
+
+     args = parser.parse_args()
+
+     if args.list:
+         registry = VMRegistry()
+         for vm in registry.list():
+             print(
+                 f" {vm.name}: {vm.ssh_user}@{vm.ssh_host} (container: {vm.docker_container})"
+             )
+         return
+
+     if args.check_all:
+         registry = VMRegistry()
+         for status in registry.check_all():
+             print(f"\n{status.config.name}:")
+             print(f" SSH: {'✓' if status.ssh_reachable else '✗'}")
+             print(f" VNC: {'✓' if status.vnc_reachable else '✗'}")
+             print(f" WAA: {'✓ READY' if status.waa_ready else '✗ Not ready'}")
+             if status.disk_usage_gb:
+                 print(f" Disk: {status.disk_usage_gb} GB")
+         return
+
+     if not args.host:
+         parser.error("--host is required for monitoring")
+
+     config = VMConfig(
+         name="cli-vm",
+         ssh_host=args.host,
+         ssh_user=args.user,
+         docker_container=args.container,
+     )
+
+     monitor = VMMonitor(config)
+
+     def print_status(status: VMStatus):
+         ts = datetime.now().strftime("%H:%M:%S")
+         waa_str = "READY!" if status.waa_ready else "not ready"
+         disk_str = f"{status.disk_usage_gb}GB" if status.disk_usage_gb else "?"
+         print(
+             f"[{ts}] SSH: {'✓' if status.ssh_reachable else '✗'} | "
+             f"VNC: {'✓' if status.vnc_reachable else '✗'} | "
+             f"WAA: {waa_str} | Disk: {disk_str}"
+         )
+         if status.container_logs:
+             # Show last log line
+             last_line = status.container_logs.split("\n")[-1][:80]
+             print(f" Log: {last_line}")
+
+     print(f"Monitoring {args.host}... (Ctrl+C to stop)")
+     try:
+         final_status = monitor.run_monitor(
+             callback=print_status,
+             interval=args.interval,
+             output_file=args.output,
+         )
+         print(f"\n✓ WAA is ready! Probe response: {final_status.waa_probe_response}")
+     except KeyboardInterrupt:
+         print("\nMonitoring stopped.")
+
+
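Since main() only parses the flags defined above and the module guards it behind __main__, direct invocation should work; the exact entry point is an assumption:

    # Hypothetical invocations (assuming `python -m` resolution of the installed package):
    #   python -m openadapt_ml.benchmarks.vm_monitor --host 172.171.112.41 --interval 60 --output status.jsonl
    #   python -m openadapt_ml.benchmarks.vm_monitor --list
    #   python -m openadapt_ml.benchmarks.vm_monitor --check-all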
652
+ # ============================================================================
+ # Azure ML Job Tracking
+ # ============================================================================
+
+
+ @dataclass
+ class AzureMLJob:
+     """Represents an Azure ML job."""
+
+     job_id: str
+     display_name: str
+     status: str  # running, completed, failed, canceled
+     created_at: str
+     compute_target: str | None = None
+     duration_minutes: float | None = None
+     cost_usd: float | None = None
+     azure_dashboard_url: str | None = None
+
+
+ def fetch_azure_ml_jobs(
+     resource_group: str = "openadapt-agents",
+     workspace_name: str = "openadapt-ml",
+     days: int = 7,
+     max_results: int = 20,
+ ) -> list[AzureMLJob]:
+     """Fetch recent Azure ML jobs.
+
+     Args:
+         resource_group: Azure resource group name.
+         workspace_name: Azure ML workspace name.
+         days: Number of days to look back.
+         max_results: Maximum number of jobs to return.
+
+     Returns:
+         List of AzureMLJob objects, sorted by creation time (newest first).
+     """
+     try:
+         result = subprocess.run(
+             [
+                 "az",
+                 "ml",
+                 "job",
+                 "list",
+                 "--resource-group",
+                 resource_group,
+                 "--workspace-name",
+                 workspace_name,
+                 "--query",
+                 "[].{name:name,display_name:display_name,status:status,created_at:creation_context.created_at,compute:compute}",
+                 "-o",
+                 "json",
+             ],
+             capture_output=True,
+             text=True,
+             timeout=30,
+         )
+
+         if result.returncode != 0:
+             logger.error(f"Azure CLI error: {result.stderr}")
+             return []
+
+         jobs_raw = json.loads(result.stdout)
+
+         # Filter by date
+         cutoff_date = datetime.now() - timedelta(days=days)
+         jobs = []
+
+         for job in jobs_raw[:max_results]:
+             created_at = job.get("created_at", "")
+             try:
+                 # Parse ISO format: 2026-01-17T10:30:00Z
+                 job_date = datetime.fromisoformat(
+                     created_at.replace("Z", "+00:00")
+                     if created_at
+                     else datetime.now().isoformat()
+                 )
+                 if job_date < cutoff_date.replace(tzinfo=job_date.tzinfo):
+                     continue
+             except (ValueError, AttributeError):
+                 # If date parsing fails, include the job
+                 pass
+
+             # Calculate duration for completed jobs
+             duration_minutes = None
+             status = job.get("status", "unknown").lower()
+
+             # Build Azure dashboard URL
+             subscription_id = get_azure_subscription_id()
+             wsid = f"/subscriptions/{subscription_id}/resourceGroups/{resource_group}/providers/Microsoft.MachineLearningServices/workspaces/{workspace_name}"
+             dashboard_url = (
+                 f"https://ml.azure.com/runs/{job.get('name', '')}?wsid={wsid}"
+             )
+
+             jobs.append(
+                 AzureMLJob(
+                     job_id=job.get("name", "unknown"),
+                     display_name=job.get("display_name", ""),
+                     status=status,
+                     created_at=created_at,
+                     compute_target=job.get("compute", None),
+                     duration_minutes=duration_minutes,
+                     azure_dashboard_url=dashboard_url,
+                 )
+             )
+
+         return jobs
+
+     except Exception as e:
+         logger.error(f"Error fetching Azure ML jobs: {e}")
+         return []
+
+
+ def get_azure_subscription_id() -> str:
+     """Get the current Azure subscription ID."""
+     try:
+         result = subprocess.run(
+             ["az", "account", "show", "--query", "id", "-o", "tsv"],
+             capture_output=True,
+             text=True,
+             timeout=10,
+         )
+         if result.returncode == 0:
+             return result.stdout.strip()
+     except Exception:
+         pass
+     return "unknown"
+
+
+ # ============================================================================
+ # Cost Tracking
+ # ============================================================================
+
+
+ @dataclass
+ class VMCostEstimate:
+     """Estimated costs for VM usage."""
+
+     vm_size: str
+     hourly_rate_usd: float
+     hours_elapsed: float
+     cost_usd: float
+     cost_per_hour_usd: float
+     cost_per_day_usd: float
+     cost_per_week_usd: float
+
+
+ # Azure VM pricing (US East, as of Jan 2025)
+ VM_PRICING = {
+     "Standard_D2_v3": 0.096,
+     "Standard_D4_v3": 0.192,
+     "Standard_D8_v3": 0.384,
+     "Standard_D4s_v3": 0.192,
+     "Standard_D8s_v3": 0.384,
+     "Standard_D4ds_v5": 0.192,
+     "Standard_D8ds_v5": 0.384,
+     "Standard_D16ds_v5": 0.768,
+     "Standard_D32ds_v5": 1.536,
+ }
+
+
+ def calculate_vm_costs(
+     vm_size: str, hours: float, hourly_rate_override: float | None = None
+ ) -> VMCostEstimate:
+     """Calculate VM cost estimates.
+
+     Args:
+         vm_size: Azure VM size (e.g., "Standard_D4ds_v5").
+         hours: Number of hours the VM has been running.
+         hourly_rate_override: Override default hourly rate (for custom pricing).
+
+     Returns:
+         VMCostEstimate with cost breakdown.
+     """
+     hourly_rate = hourly_rate_override or VM_PRICING.get(vm_size, 0.20)
+     cost_usd = hourly_rate * hours
+
+     return VMCostEstimate(
+         vm_size=vm_size,
+         hourly_rate_usd=hourly_rate,
+         hours_elapsed=hours,
+         cost_usd=cost_usd,
+         cost_per_hour_usd=hourly_rate,
+         cost_per_day_usd=hourly_rate * 24,
+         cost_per_week_usd=hourly_rate * 24 * 7,
+     )
+
+
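To make the arithmetic concrete: with the table above, a Standard_D4ds_v5 VM bills at $0.192/hour, so 2.5 hours costs 0.192 × 2.5 = $0.48, and a full week left running would be 0.192 × 24 × 7 ≈ $32.26. A quick check (illustrative):

    from openadapt_ml.benchmarks.vm_monitor import calculate_vm_costs

    est = calculate_vm_costs(vm_size="Standard_D4ds_v5", hours=2.5)
    print(f"${est.cost_usd:.2f}/run, ${est.cost_per_day_usd:.2f}/day, ${est.cost_per_week_usd:.2f}/week")
    # -> $0.48/run, $4.61/day, $32.26/week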
839
+ def get_vm_uptime_hours(
+     resource_group: str, vm_name: str, check_actual_state: bool = True
+ ) -> float:
+     """Get VM uptime in hours.
+
+     Args:
+         resource_group: Azure resource group.
+         vm_name: VM name.
+         check_actual_state: If True, check if VM is actually running.
+
+     Returns:
+         Hours since VM started, or 0 if VM is not running.
+     """
+     try:
+         # Get VM creation time or last start time
+         result = subprocess.run(
+             [
+                 "az",
+                 "vm",
+                 "show",
+                 "-d",
+                 "-g",
+                 resource_group,
+                 "-n",
+                 vm_name,
+                 "--query",
+                 "{powerState:powerState}",
+                 "-o",
+                 "json",
+             ],
+             capture_output=True,
+             text=True,
+             timeout=10,
+         )
+
+         if result.returncode != 0:
+             return 0.0
+
+         info = json.loads(result.stdout)
+         power_state = info.get("powerState", "")
+
+         # Check if VM is running
+         if check_actual_state and "running" not in power_state.lower():
+             return 0.0
+
+         # Try to get activity logs for last start time
+         result = subprocess.run(
+             [
+                 "az",
+                 "monitor",
+                 "activity-log",
+                 "list",
+                 "--resource-group",
+                 resource_group,
+                 "--resource-id",
+                 f"/subscriptions/{get_azure_subscription_id()}/resourceGroups/{resource_group}/providers/Microsoft.Compute/virtualMachines/{vm_name}",
+                 "--query",
+                 "[?operationName.localizedValue=='Start Virtual Machine' || operationName.localizedValue=='Create or Update Virtual Machine'].eventTimestamp | [0]",
+                 "-o",
+                 "tsv",
+             ],
+             capture_output=True,
+             text=True,
+             timeout=15,
+         )
+
+         if result.returncode == 0 and result.stdout.strip():
+             start_time_str = result.stdout.strip()
+             start_time = datetime.fromisoformat(start_time_str.replace("Z", "+00:00"))
+             elapsed = datetime.now(start_time.tzinfo) - start_time
+             return elapsed.total_seconds() / 3600
+
+         # Fallback: assume started 1 hour ago if we can't determine
+         return 1.0
+
+     except Exception as e:
+         logger.debug(f"Error getting VM uptime: {e}")
+         return 0.0
+
+
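Combining the two helpers above gives a rough "what is this VM costing right now" check (sketch; the resource group and VM name are placeholders):

    from openadapt_ml.benchmarks.vm_monitor import calculate_vm_costs, get_vm_uptime_hours

    hours = get_vm_uptime_hours(resource_group="openadapt-agents", vm_name="azure-waa-vm")
    if hours > 0:
        est = calculate_vm_costs(vm_size="Standard_D4ds_v5", hours=hours)
        print(f"Running for {hours:.1f} h, roughly ${est.cost_usd:.2f} so far")
    else:
        print("VM is stopped (or uptime could not be determined)")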
919
+ # ============================================================================
+ # VM Activity Detection
+ # ============================================================================
+
+
+ @dataclass
+ class VMActivity:
+     """Current VM activity information."""
+
+     is_active: bool
+     activity_type: str  # idle, benchmark_running, training, setup, unknown
+     description: str
+     benchmark_progress: dict | None = None  # If benchmark is running
+     last_action_time: str | None = None
+
+
+ def detect_vm_activity(
+     ip: str,
+     ssh_user: str = "azureuser",
+     docker_container: str = "winarena",
+     internal_ip: str = "localhost",  # WAA server bound to localhost via Docker port forward
+ ) -> VMActivity:
+     """Detect what the VM is currently doing.
+
+     Args:
+         ip: VM IP address.
+         ssh_user: SSH username.
+         docker_container: Docker container name.
+         internal_ip: Internal IP for WAA server.
+
+     Returns:
+         VMActivity with current activity information.
+     """
+     try:
+         # Check if container is running
+         result = subprocess.run(
+             [
+                 "ssh",
+                 "-o",
+                 "StrictHostKeyChecking=no",
+                 "-o",
+                 "ConnectTimeout=5",
+                 f"{ssh_user}@{ip}",
+                 f"docker ps -q -f name={docker_container}",
+             ],
+             capture_output=True,
+             text=True,
+             timeout=10,
+         )
+
+         if result.returncode != 0 or not result.stdout.strip():
+             return VMActivity(
+                 is_active=False,
+                 activity_type="idle",
+                 description="Container not running",
+             )
+
+         # Check WAA probe for benchmark status
+         result = subprocess.run(
+             [
+                 "ssh",
+                 "-o",
+                 "StrictHostKeyChecking=no",
+                 "-o",
+                 "ConnectTimeout=5",
+                 f"{ssh_user}@{ip}",
+                 f"curl -s --connect-timeout 3 http://{internal_ip}:5000/probe",
+             ],
+             capture_output=True,
+             text=True,
+             timeout=10,
+         )
+
+         if result.returncode == 0 and result.stdout.strip():
+             probe_response = result.stdout.strip()
+             try:
+                 probe_data = json.loads(probe_response)
+                 # WAA is ready and responsive - check if benchmark is actually running
+                 # by looking for python processes (Navi agent or our client)
+                 python_check = subprocess.run(
+                     [
+                         "ssh",
+                         "-o",
+                         "StrictHostKeyChecking=no",
+                         "-o",
+                         "ConnectTimeout=5",
+                         f"{ssh_user}@{ip}",
+                         f"docker exec {docker_container} pgrep -f 'python.*run' 2>/dev/null | head -1",
+                     ],
+                     capture_output=True,
+                     text=True,
+                     timeout=10,
+                 )
+                 is_running = bool(python_check.stdout.strip())
+
+                 return VMActivity(
+                     is_active=is_running,
+                     activity_type="benchmark_running" if is_running else "idle",
+                     description="WAA benchmark running"
+                     if is_running
+                     else "WAA ready - idle",
+                     benchmark_progress=probe_data,
+                 )
+             except json.JSONDecodeError:
+                 # Got response but not JSON - maybe setup phase
+                 return VMActivity(
+                     is_active=True,
+                     activity_type="setup",
+                     description="WAA starting up",
+                 )
+
+         # Container running but WAA not ready
+         return VMActivity(
+             is_active=True,
+             activity_type="setup",
+             description="Windows VM booting or WAA initializing",
+         )
+
+     except Exception as e:
+         logger.debug(f"Error detecting VM activity: {e}")
+         return VMActivity(
+             is_active=False,
+             activity_type="unknown",
+             description=f"Error checking activity: {str(e)[:100]}",
+         )
+
+
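A sketch of how the activity probe above might be polled (the IP is a placeholder; the function returns a VMActivity even on errors):

    from openadapt_ml.benchmarks.vm_monitor import detect_vm_activity

    activity = detect_vm_activity(ip="172.171.112.41")
    print(activity.activity_type, "-", activity.description)
    if activity.benchmark_progress:
        print("probe payload:", activity.benchmark_progress)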
1046
+ # ============================================================================
+ # Evaluation History
+ # ============================================================================
+
+
+ @dataclass
+ class EvaluationRun:
+     """Historical evaluation run."""
+
+     run_id: str
+     started_at: str
+     completed_at: str | None
+     num_tasks: int
+     success_rate: float | None
+     agent_type: str
+     status: str  # running, completed, failed
+
+
+ def get_evaluation_history(
+     results_dir: Path | str = "benchmark_results", max_runs: int = 10
+ ) -> list[EvaluationRun]:
+     """Get history of evaluation runs from results directory.
+
+     Args:
+         results_dir: Path to benchmark results directory.
+         max_runs: Maximum number of runs to return.
+
+     Returns:
+         List of EvaluationRun objects, sorted by start time (newest first).
+     """
+     results_path = Path(results_dir)
+     if not results_path.exists():
+         return []
+
+     runs = []
+
+     # Look for run directories or result files
+     for item in sorted(results_path.iterdir(), reverse=True):
+         if item.is_dir():
+             # Check for summary.json or similar
+             summary_file = item / "summary.json"
+             if summary_file.exists():
+                 try:
+                     summary = json.loads(summary_file.read_text())
+                     runs.append(
+                         EvaluationRun(
+                             run_id=item.name,
+                             started_at=summary.get("started_at", "unknown"),
+                             completed_at=summary.get("completed_at", None),
+                             num_tasks=summary.get("num_tasks", 0),
+                             success_rate=summary.get("success_rate", None),
+                             agent_type=summary.get("agent_type", "unknown"),
+                             status=summary.get("status", "completed"),
+                         )
+                     )
+                 except (json.JSONDecodeError, KeyError):
+                     continue
+
+         if len(runs) >= max_runs:
+             break
+
+     return runs
+
+
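Finally, a sketch of reading back past runs with get_evaluation_history (it only picks up run directories that contain a summary.json):

    from openadapt_ml.benchmarks.vm_monitor import get_evaluation_history

    for run in get_evaluation_history("benchmark_results", max_runs=5):
        print(run.run_id, run.num_tasks, "tasks, success:", run.success_rate, run.status)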
1110
+ if __name__ == "__main__":
+     main()