openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/benchmarks/__init__.py +8 -0
- openadapt_ml/benchmarks/agent.py +90 -11
- openadapt_ml/benchmarks/azure.py +35 -6
- openadapt_ml/benchmarks/cli.py +4449 -201
- openadapt_ml/benchmarks/live_tracker.py +180 -0
- openadapt_ml/benchmarks/runner.py +41 -4
- openadapt_ml/benchmarks/viewer.py +1219 -0
- openadapt_ml/benchmarks/vm_monitor.py +610 -0
- openadapt_ml/benchmarks/waa.py +61 -4
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/benchmarks/waa_live.py +619 -0
- openadapt_ml/cloud/local.py +1555 -1
- openadapt_ml/cloud/ssh_tunnel.py +553 -0
- openadapt_ml/datasets/next_action.py +87 -68
- openadapt_ml/evals/grounding.py +26 -8
- openadapt_ml/evals/trajectory_matching.py +84 -36
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +717 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +265 -0
- openadapt_ml/ingest/__init__.py +3 -4
- openadapt_ml/ingest/capture.py +89 -81
- openadapt_ml/ingest/loader.py +116 -68
- openadapt_ml/ingest/synthetic.py +221 -159
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +817 -0
- openadapt_ml/retrieval/embeddings.py +629 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +160 -0
- openadapt_ml/runtime/policy.py +10 -10
- openadapt_ml/schema/__init__.py +104 -0
- openadapt_ml/schema/converters.py +541 -0
- openadapt_ml/schema/episode.py +457 -0
- openadapt_ml/scripts/compare.py +26 -16
- openadapt_ml/scripts/eval_policy.py +4 -5
- openadapt_ml/scripts/prepare_synthetic.py +14 -17
- openadapt_ml/scripts/train.py +81 -70
- openadapt_ml/training/benchmark_viewer.py +3225 -0
- openadapt_ml/training/trainer.py +120 -363
- openadapt_ml/training/trl_trainer.py +354 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
- openadapt_ml-0.2.0.dist-info/RECORD +86 -0
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,610 @@
|
|
|
1
|
+
"""VM monitoring utilities for WAA benchmark evaluation.
|
|
2
|
+
|
|
3
|
+
This module provides reusable classes for monitoring Windows VMs running WAA.
|
|
4
|
+
Can be used by the viewer, CLI, or as a standalone tool.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
# Monitor a single VM
|
|
8
|
+
from openadapt_ml.benchmarks.vm_monitor import VMMonitor, VMConfig
|
|
9
|
+
|
|
10
|
+
config = VMConfig(
|
|
11
|
+
name="azure-waa-vm",
|
|
12
|
+
ssh_host="172.171.112.41",
|
|
13
|
+
ssh_user="azureuser",
|
|
14
|
+
docker_container="winarena",
|
|
15
|
+
internal_ip="20.20.20.21",
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
monitor = VMMonitor(config)
|
|
19
|
+
status = monitor.check_status()
|
|
20
|
+
print(f"VNC: {status.vnc_reachable}, WAA: {status.waa_ready}")
|
|
21
|
+
|
|
22
|
+
# Or run continuous monitoring
|
|
23
|
+
monitor.run_monitor(callback=lambda s: print(s))
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import json
|
|
29
|
+
import subprocess
|
|
30
|
+
import time
|
|
31
|
+
from dataclasses import dataclass, field, asdict
|
|
32
|
+
from datetime import datetime
|
|
33
|
+
from pathlib import Path
|
|
34
|
+
from typing import Callable
|
|
35
|
+
import urllib.request
|
|
36
|
+
import urllib.error
|
|
37
|
+
import socket
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class VMConfig:
    """Configuration for a WAA VM.

    Holds the connection details needed to reach and monitor a single
    Windows Agent Arena VM: SSH access to the host, plus the ports of
    the VNC console, WAA server, and QMP endpoint.
    """

    name: str  # Human-readable identifier; used as the registry lookup key.
    ssh_host: str  # Public IP or hostname for SSH access to the host VM.
    ssh_user: str = "azureuser"  # SSH login user on the host.
    vnc_port: int = 8006  # Port serving the browser-based VNC console over HTTP.
    waa_port: int = 5000  # Port of the WAA server (probed via /probe).
    qmp_port: int = 7200  # QEMU Machine Protocol port — not used by the checks here.
    docker_container: str = "winarena"  # Name of the WAA docker container on the host.
    internal_ip: str = "20.20.20.21"  # Guest address reachable from the SSH host (used for the WAA probe).

    def to_dict(self) -> dict:
        """Convert to dictionary for JSON serialization."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> VMConfig:
        """Create from dictionary.

        Args:
            data: Mapping whose keys match the dataclass fields, as
                produced by ``to_dict``.
        """
        return cls(**data)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class VMStatus:
    """Snapshot of a WAA VM's health at a single point in time.

    Produced by :class:`VMMonitor`; all fields other than ``config``
    default to the "unknown / not reachable" state so a freshly created
    status can be filled in check by check.
    """

    config: VMConfig
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    ssh_reachable: bool = False
    vnc_reachable: bool = False
    waa_ready: bool = False
    waa_probe_response: str | None = None
    container_running: bool = False
    container_logs: str | None = None
    disk_usage_gb: float | None = None
    error: str | None = None

    def to_dict(self) -> dict:
        """Serialize this status (config included) to plain JSON types.

        ``asdict`` recurses into the nested ``VMConfig`` dataclass, so
        the result matches field declaration order with the config as a
        nested dict.
        """
        return asdict(self)
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class VMMonitor:
    """Monitor a single WAA VM.

    Performs lightweight health checks against one VM described by a
    :class:`VMConfig`: the VNC web endpoint (direct HTTP), and — over
    SSH — the WAA /probe endpoint, the docker container, and the size
    of the guest disk image.
    """

    def __init__(self, config: VMConfig, timeout: int = 5):
        """Initialize monitor.

        Args:
            config: VM configuration.
            timeout: Timeout in seconds for network operations.
        """
        self.config = config
        self.timeout = timeout

    def _ssh_args(self, command: str) -> list[str]:
        """Build the ssh argv for running *command* on the VM.

        Centralizes the options shared by every remote check: no
        host-key prompt, bounded connect time, and BatchMode so a
        password prompt can never hang the monitor.
        """
        return [
            "ssh",
            "-o", "StrictHostKeyChecking=no",
            "-o", f"ConnectTimeout={self.timeout}",
            "-o", "BatchMode=yes",
            f"{self.config.ssh_user}@{self.config.ssh_host}",
            command,
        ]

    def check_vnc(self) -> bool:
        """Check if the VNC web console is reachable via an HTTP HEAD."""
        try:
            url = f"http://{self.config.ssh_host}:{self.config.vnc_port}/"
            req = urllib.request.Request(url, method="HEAD")
            with urllib.request.urlopen(req, timeout=self.timeout):
                return True
        except Exception:
            # Any failure (DNS, refused, timeout, HTTP error) means "not reachable".
            return False

    def check_ssh(self) -> bool:
        """Check if SSH is reachable (non-interactive `echo ok` round trip)."""
        try:
            result = subprocess.run(
                self._ssh_args("echo ok"),
                capture_output=True,
                text=True,
                timeout=self.timeout + 5,
            )
            return result.returncode == 0 and "ok" in result.stdout
        except Exception:
            return False

    def check_waa_probe(self) -> tuple[bool, str | None]:
        """Check if WAA /probe endpoint responds.

        The probe is curl'ed from the SSH host, since the guest's
        internal IP is only reachable from there.

        Returns:
            Tuple of (ready, response_text).
        """
        try:
            cmd = f"curl -s --connect-timeout {self.timeout} http://{self.config.internal_ip}:{self.config.waa_port}/probe"
            result = subprocess.run(
                self._ssh_args(cmd),
                capture_output=True,
                text=True,
                timeout=self.timeout + 10,
            )
            response = result.stdout.strip()
            # Heuristic: any non-empty body that doesn't mention "error" counts as ready.
            if response and "error" not in response.lower():
                return True, response
            return False, response or None
        except Exception as e:
            return False, str(e)

    def get_container_status(self) -> tuple[bool, str | None]:
        """Check container status and get recent logs.

        Returns:
            Tuple of (running, last_log_lines).
        """
        try:
            cmd = f"docker ps -q -f name={self.config.docker_container}"
            result = subprocess.run(
                self._ssh_args(cmd),
                capture_output=True,
                text=True,
                timeout=self.timeout + 5,
            )
            running = bool(result.stdout.strip())

            if running:
                # Fetch the last few log lines for display/diagnosis.
                log_cmd = f"docker logs {self.config.docker_container} 2>&1 | tail -5"
                log_result = subprocess.run(
                    self._ssh_args(log_cmd),
                    capture_output=True,
                    text=True,
                    timeout=self.timeout + 10,
                )
                return True, log_result.stdout.strip()
            return False, None
        except Exception as e:
            return False, str(e)

    def get_disk_usage(self) -> float | None:
        """Get disk usage of data.img in GB, or None if it can't be found."""
        try:
            # Try common paths used by different provisioning setups.
            paths = [
                "/home/azureuser/waa-storage/data.img",
                "/home/ubuntu/waa-storage/data.img",
                "/storage/data.img",
            ]
            for path in paths:
                cmd = f"du -b {path} 2>/dev/null | cut -f1"
                result = subprocess.run(
                    self._ssh_args(cmd),
                    capture_output=True,
                    text=True,
                    timeout=self.timeout + 5,
                )
                if result.returncode == 0 and result.stdout.strip():
                    try:
                        bytes_size = int(result.stdout.strip())
                        return round(bytes_size / (1024 ** 3), 2)
                    except ValueError:
                        continue
            return None
        except Exception:
            return None

    def check_status(self) -> VMStatus:
        """Perform full status check on the VM.

        Returns:
            VMStatus with all checks performed. SSH-dependent checks are
            skipped entirely when SSH itself is unreachable.
        """
        status = VMStatus(config=self.config)

        try:
            # Check VNC first (fastest, no SSH needed)
            status.vnc_reachable = self.check_vnc()

            # Check SSH
            status.ssh_reachable = self.check_ssh()

            if status.ssh_reachable:
                # Check container
                status.container_running, status.container_logs = self.get_container_status()

                # Check WAA probe
                status.waa_ready, status.waa_probe_response = self.check_waa_probe()

                # Get disk usage
                status.disk_usage_gb = self.get_disk_usage()
        except Exception as e:
            status.error = str(e)

        return status

    def run_monitor(
        self,
        callback: Callable[[VMStatus], None] | None = None,
        interval: int = 30,
        stop_on_ready: bool = True,
        output_file: str | Path | None = None,
    ) -> VMStatus:
        """Run continuous monitoring until WAA is ready.

        Args:
            callback: Optional callback function called with each status update.
            interval: Seconds between checks.
            stop_on_ready: Stop monitoring when WAA is ready.
            output_file: Optional file to write status updates (JSON lines).

        Returns:
            Final VMStatus (typically when WAA is ready). Note: with
            ``stop_on_ready=False`` this loops until interrupted.
        """
        output_path = Path(output_file) if output_file else None
        if output_path:
            output_path.parent.mkdir(parents=True, exist_ok=True)

        while True:
            status = self.check_status()

            # Call callback if provided
            if callback:
                callback(status)

            # Append as JSON lines so progress survives interruption.
            if output_path:
                with open(output_path, "a") as f:
                    f.write(json.dumps(status.to_dict()) + "\n")

            # Check if we should stop
            if stop_on_ready and status.waa_ready:
                return status

            time.sleep(interval)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
@dataclass
class PoolWorker:
    """A single worker in a VM pool."""

    name: str  # Worker VM name; lookup key for VMPoolRegistry.update_worker().
    ip: str  # Address used to reach this worker.
    status: str = "creating"  # creating, ready, running, completed, failed, deleted
    docker_container: str = "winarena"  # Name of the WAA container on this worker.
    waa_ready: bool = False  # Whether the WAA server on this worker is ready.
    assigned_tasks: list[str] = field(default_factory=list)  # Task IDs assigned to this worker.
    completed_tasks: list[str] = field(default_factory=list)  # Task IDs finished so far.
    current_task: str | None = None  # Task ID currently in progress, if any.
    error: str | None = None  # Last failure message, if any.
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())  # ISO timestamp at creation.
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())  # Refreshed on each update_worker() call.
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
@dataclass
class VMPool:
    """A pool of worker VMs for parallel WAA evaluation."""

    pool_id: str  # Timestamp-derived identifier (see VMPoolRegistry.create_pool).
    created_at: str  # ISO timestamp of pool creation.
    resource_group: str  # Azure resource group containing the worker VMs.
    location: str  # Azure region the pool was created in.
    vm_size: str  # Azure VM size used for every worker.
    workers: list[PoolWorker]  # Per-worker state entries.
    total_tasks: int = 0  # Total tasks scheduled across the pool.
    completed_tasks: int = 0  # Aggregate completed count (see update_pool_progress).
    failed_tasks: int = 0  # Aggregate failed count (see update_pool_progress).
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
class VMPoolRegistry:
    """Manage VM pools for parallel WAA evaluation.

    Persists a single :class:`VMPool` as JSON on disk so that separate
    CLI invocations can share pool state.
    """

    # Default on-disk location of the persisted pool state.
    REGISTRY_FILE = "benchmark_results/vm_pool_registry.json"

    def __init__(self, registry_file: str | Path | None = None):
        """Initialize pool registry.

        Args:
            registry_file: Path to JSON registry file. Defaults to
                ``REGISTRY_FILE``.
        """
        self.registry_file = Path(registry_file or self.REGISTRY_FILE)
        self._pool: VMPool | None = None
        self.load()

    def load(self) -> None:
        """Load pool from registry file.

        A missing file leaves the registry empty. A corrupt or
        schema-incompatible file is reported as a warning instead of
        crashing the caller at construction time.
        """
        if self.registry_file.exists():
            try:
                with open(self.registry_file) as f:
                    data = json.load(f)
                workers = [PoolWorker(**w) for w in data.get("workers", [])]
                self._pool = VMPool(
                    pool_id=data["pool_id"],
                    created_at=data["created_at"],
                    resource_group=data["resource_group"],
                    location=data["location"],
                    vm_size=data["vm_size"],
                    workers=workers,
                    total_tasks=data.get("total_tasks", 0),
                    completed_tasks=data.get("completed_tasks", 0),
                    failed_tasks=data.get("failed_tasks", 0),
                )
            # TypeError covers worker entries with unexpected or missing
            # keys (PoolWorker(**w)) — previously that escaped this
            # handler and crashed instead of taking the warning path.
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                print(f"Warning: Could not load pool registry: {e}")
                self._pool = None

    def save(self) -> None:
        """Save pool to registry file (no-op when no pool is loaded)."""
        if self._pool is None:
            return
        self.registry_file.parent.mkdir(parents=True, exist_ok=True)
        with open(self.registry_file, "w") as f:
            json.dump(asdict(self._pool), f, indent=2)

    def create_pool(
        self,
        workers: list[tuple[str, str]],  # [(name, ip), ...]
        resource_group: str,
        location: str,
        vm_size: str = "Standard_D4ds_v5",
    ) -> VMPool:
        """Create a new pool from created VMs.

        Args:
            workers: List of (name, ip) tuples.
            resource_group: Azure resource group.
            location: Azure region.
            vm_size: VM size used.

        Returns:
            Created VMPool (also persisted to disk).
        """
        pool_id = datetime.now().strftime("%Y%m%d_%H%M%S")
        self._pool = VMPool(
            pool_id=pool_id,
            created_at=datetime.now().isoformat(),
            resource_group=resource_group,
            location=location,
            vm_size=vm_size,
            workers=[PoolWorker(name=name, ip=ip, status="ready") for name, ip in workers],
        )
        self.save()
        return self._pool

    def get_pool(self) -> VMPool | None:
        """Get current pool, or None when no pool is loaded."""
        return self._pool

    def update_worker(self, name: str, **kwargs) -> None:
        """Update a worker's status.

        Unknown field names in ``kwargs`` are silently ignored; the
        matched worker's ``updated_at`` is refreshed.

        Args:
            name: Worker name.
            **kwargs: Fields to update.
        """
        if self._pool is None:
            return
        for worker in self._pool.workers:
            if worker.name == name:
                for key, value in kwargs.items():
                    if hasattr(worker, key):
                        setattr(worker, key, value)
                worker.updated_at = datetime.now().isoformat()
                break
        self.save()

    def update_pool_progress(self, completed: int = 0, failed: int = 0) -> None:
        """Update pool-level progress.

        Args:
            completed: Increment completed count by this amount.
            failed: Increment failed count by this amount.
        """
        if self._pool is None:
            return
        self._pool.completed_tasks += completed
        self._pool.failed_tasks += failed
        self.save()

    def delete_pool(self) -> bool:
        """Delete the pool registry (VMs must be deleted separately).

        Returns:
            True if a registry file existed and was deleted.
        """
        if self.registry_file.exists():
            self.registry_file.unlink()
            self._pool = None
            return True
        return False
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
class VMRegistry:
    """Manage a registry of VMs and their status.

    Persists a flat list of :class:`VMConfig` entries as JSON.
    """

    def __init__(self, registry_file: str | Path = "benchmark_results/vm_registry.json"):
        """Initialize registry.

        Args:
            registry_file: Path to JSON registry file.
        """
        self.registry_file = Path(registry_file)
        self._vms: list[VMConfig] = []
        self.load()

    def load(self) -> None:
        """Load VMs from registry file.

        A corrupt or schema-incompatible file is reported as a warning
        (consistent with VMPoolRegistry.load) rather than crashing every
        caller at construction time; the registry is left empty.
        """
        if self.registry_file.exists():
            try:
                with open(self.registry_file) as f:
                    data = json.load(f)
                self._vms = [VMConfig.from_dict(vm) for vm in data]
            # TypeError covers entries with unexpected/missing keys
            # (VMConfig.from_dict -> cls(**data)).
            except (json.JSONDecodeError, TypeError) as e:
                print(f"Warning: Could not load VM registry: {e}")
                self._vms = []

    def save(self) -> None:
        """Save VMs to registry file, creating parent dirs as needed."""
        self.registry_file.parent.mkdir(parents=True, exist_ok=True)
        with open(self.registry_file, "w") as f:
            json.dump([vm.to_dict() for vm in self._vms], f, indent=2)

    def add(self, config: VMConfig) -> None:
        """Add a VM to the registry, replacing any VM with the same name."""
        # Remove existing VM with same name
        self._vms = [vm for vm in self._vms if vm.name != config.name]
        self._vms.append(config)
        self.save()

    def remove(self, name: str) -> bool:
        """Remove a VM from the registry.

        Returns:
            True if VM was found and removed.
        """
        original_len = len(self._vms)
        self._vms = [vm for vm in self._vms if vm.name != name]
        if len(self._vms) < original_len:
            self.save()
            return True
        return False

    def get(self, name: str) -> VMConfig | None:
        """Get a VM by name, or None if not registered."""
        for vm in self._vms:
            if vm.name == name:
                return vm
        return None

    def list(self) -> list[VMConfig]:
        """List all VMs (returns a copy; mutations don't affect the registry)."""
        return list(self._vms)

    def check_all(self, timeout: int = 5) -> list[VMStatus]:
        """Check status of all VMs.

        Args:
            timeout: Timeout per VM check.

        Returns:
            List of VMStatus for each registered VM (checked serially).
        """
        statuses = []
        for config in self._vms:
            monitor = VMMonitor(config, timeout=timeout)
            statuses.append(monitor.check_status())
        return statuses
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def main():
    """CLI entry point for VM monitoring.

    Supports three modes: list registered VMs, one-shot check of every
    registered VM, or continuous monitoring of a single host until its
    WAA server is ready.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Monitor WAA VMs")
    parser.add_argument("--host", help="SSH host")
    parser.add_argument("--user", default="azureuser", help="SSH user")
    parser.add_argument("--container", default="winarena", help="Docker container name")
    parser.add_argument("--interval", type=int, default=30, help="Check interval in seconds")
    parser.add_argument("--output", help="Output file for status updates (JSON lines)")
    parser.add_argument("--list", action="store_true", help="List all registered VMs")
    parser.add_argument("--check-all", action="store_true", help="Check all registered VMs")
    args = parser.parse_args()

    # Registry-backed modes return early without touching the network.
    if args.list:
        for vm in VMRegistry().list():
            print(f" {vm.name}: {vm.ssh_user}@{vm.ssh_host} (container: {vm.docker_container})")
        return

    if args.check_all:
        for status in VMRegistry().check_all():
            print(f"\n{status.config.name}:")
            print(f" SSH: {'✓' if status.ssh_reachable else '✗'}")
            print(f" VNC: {'✓' if status.vnc_reachable else '✗'}")
            print(f" WAA: {'✓ READY' if status.waa_ready else '✗ Not ready'}")
            if status.disk_usage_gb:
                print(f" Disk: {status.disk_usage_gb} GB")
        return

    if not args.host:
        parser.error("--host is required for monitoring")

    monitor = VMMonitor(
        VMConfig(
            name="cli-vm",
            ssh_host=args.host,
            ssh_user=args.user,
            docker_container=args.container,
        )
    )

    def print_status(status: VMStatus):
        """Print a single-line status summary (plus last container log line)."""
        stamp = datetime.now().strftime("%H:%M:%S")
        ssh_mark = "✓" if status.ssh_reachable else "✗"
        vnc_mark = "✓" if status.vnc_reachable else "✗"
        waa_str = "READY!" if status.waa_ready else "not ready"
        disk_str = f"{status.disk_usage_gb}GB" if status.disk_usage_gb else "?"
        print(f"[{stamp}] SSH: {ssh_mark} | VNC: {vnc_mark} | WAA: {waa_str} | Disk: {disk_str}")
        if status.container_logs:
            tail = status.container_logs.split('\n')[-1][:80]
            print(f" Log: {tail}")

    print(f"Monitoring {args.host}... (Ctrl+C to stop)")
    try:
        final_status = monitor.run_monitor(
            callback=print_status,
            interval=args.interval,
            output_file=args.output,
        )
        print(f"\n✓ WAA is ready! Probe response: {final_status.waa_probe_response}")
    except KeyboardInterrupt:
        print("\nMonitoring stopped.")
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
if __name__ == "__main__":
    # Run the monitoring CLI when this module is executed as a script.
    main()
|
openadapt_ml/benchmarks/waa.py
CHANGED
|
@@ -565,6 +565,8 @@ class WAAMockAdapter(BenchmarkAdapter):
|
|
|
565
565
|
self._current_task: BenchmarkTask | None = None
|
|
566
566
|
self._step_count = 0
|
|
567
567
|
self._temp_dir: Path | None = None
|
|
568
|
+
self._actions: list[BenchmarkAction] = [] # Track actions for evaluation
|
|
569
|
+
self._text_entered: str | None = None # Track typed text
|
|
568
570
|
self._generate_mock_tasks()
|
|
569
571
|
|
|
570
572
|
@property
|
|
@@ -608,24 +610,79 @@ class WAAMockAdapter(BenchmarkAdapter):
|
|
|
608
610
|
def reset(self, task: BenchmarkTask) -> BenchmarkObservation:
|
|
609
611
|
self._current_task = task
|
|
610
612
|
self._step_count = 0
|
|
613
|
+
self._actions = [] # Clear action history
|
|
614
|
+
self._text_entered = None
|
|
611
615
|
return self._mock_observation()
|
|
612
616
|
|
|
613
617
|
def step(
|
|
614
618
|
self, action: BenchmarkAction
|
|
615
619
|
) -> tuple[BenchmarkObservation, bool, dict[str, Any]]:
|
|
616
620
|
self._step_count += 1
|
|
621
|
+
self._actions.append(action) # Track action for evaluation
|
|
622
|
+
|
|
623
|
+
# Track typed text
|
|
624
|
+
if action.type == "type" and action.text:
|
|
625
|
+
self._text_entered = action.text
|
|
626
|
+
|
|
617
627
|
done = action.type == "done" or self._step_count >= 15
|
|
618
628
|
return self._mock_observation(), done, {"step": self._step_count}
|
|
619
629
|
|
|
620
630
|
def evaluate(self, task: BenchmarkTask) -> BenchmarkResult:
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
631
|
+
"""Evaluate task based on actions taken.
|
|
632
|
+
|
|
633
|
+
Success criteria for mock tasks:
|
|
634
|
+
- Agent clicked the Submit button (ID 4) OR
|
|
635
|
+
- Agent typed text AND clicked OK (ID 1) OR
|
|
636
|
+
- Agent completed with DONE action after meaningful interaction
|
|
637
|
+
|
|
638
|
+
This provides deterministic evaluation based on actual agent behavior,
|
|
639
|
+
not random chance. The mock UI has:
|
|
640
|
+
- ID 1: OK button
|
|
641
|
+
- ID 2: Text input field
|
|
642
|
+
- ID 3: Cancel button
|
|
643
|
+
- ID 4: Submit button
|
|
644
|
+
"""
|
|
645
|
+
# Check what actions were taken
|
|
646
|
+
clicked_ids = set()
|
|
647
|
+
typed_text = False
|
|
648
|
+
called_done = False
|
|
649
|
+
|
|
650
|
+
for action in self._actions:
|
|
651
|
+
if action.type == "click":
|
|
652
|
+
# Extract target node ID from action
|
|
653
|
+
target_id = getattr(action, "target_node_id", None)
|
|
654
|
+
if target_id:
|
|
655
|
+
clicked_ids.add(str(target_id))
|
|
656
|
+
elif action.type == "type" and action.text:
|
|
657
|
+
typed_text = True
|
|
658
|
+
elif action.type == "done":
|
|
659
|
+
called_done = True
|
|
660
|
+
|
|
661
|
+
# Success criteria:
|
|
662
|
+
# 1. Clicked Submit (ID 4) - primary success path
|
|
663
|
+
# 2. Typed something AND clicked OK (ID 1) - form submission path
|
|
664
|
+
# 3. Called DONE after at least 2 actions - reasonable completion
|
|
665
|
+
clicked_submit = "4" in clicked_ids
|
|
666
|
+
clicked_ok = "1" in clicked_ids
|
|
667
|
+
form_submitted = typed_text and clicked_ok
|
|
668
|
+
reasonable_completion = called_done and len(self._actions) >= 2
|
|
669
|
+
|
|
670
|
+
success = clicked_submit or form_submitted or reasonable_completion
|
|
671
|
+
|
|
672
|
+
# Calculate partial credit score
|
|
673
|
+
score = 0.0
|
|
674
|
+
if success:
|
|
675
|
+
score = 1.0
|
|
676
|
+
elif typed_text or clicked_ids:
|
|
677
|
+
# Partial credit for taking meaningful actions
|
|
678
|
+
score = 0.3 + (0.1 * min(len(clicked_ids), 3)) + (0.2 if typed_text else 0.0)
|
|
679
|
+
|
|
624
680
|
return BenchmarkResult(
|
|
625
681
|
task_id=task.task_id,
|
|
626
682
|
success=success,
|
|
627
|
-
score=
|
|
683
|
+
score=score,
|
|
628
684
|
num_steps=self._step_count,
|
|
685
|
+
reason=f"clicked={list(clicked_ids)}, typed={typed_text}, done={called_done}",
|
|
629
686
|
)
|
|
630
687
|
|
|
631
688
|
def _mock_observation(self) -> BenchmarkObservation:
|