openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/cloud/local.py
CHANGED
|
@@ -27,7 +27,6 @@ import http.server
|
|
|
27
27
|
import json
|
|
28
28
|
import os
|
|
29
29
|
import shutil
|
|
30
|
-
import signal
|
|
31
30
|
import socketserver
|
|
32
31
|
import subprocess
|
|
33
32
|
import sys
|
|
@@ -36,6 +35,8 @@ import webbrowser
|
|
|
36
35
|
from pathlib import Path
|
|
37
36
|
from typing import Any
|
|
38
37
|
|
|
38
|
+
from openadapt_ml.cloud.ssh_tunnel import get_tunnel_manager
|
|
39
|
+
|
|
39
40
|
# Training output directory
|
|
40
41
|
TRAINING_OUTPUT = Path("training_output")
|
|
41
42
|
|
|
@@ -107,7 +108,10 @@ def _is_mock_benchmark(benchmark_dir: Path) -> bool:
|
|
|
107
108
|
|
|
108
109
|
# Check for test runs (but allow waa-mock evaluations with real API models)
|
|
109
110
|
# Only filter out purely synthetic test data directories
|
|
110
|
-
if any(
|
|
111
|
+
if any(
|
|
112
|
+
term in benchmark_dir.name.lower()
|
|
113
|
+
for term in ["test_run", "test_cli", "quick_demo"]
|
|
114
|
+
):
|
|
111
115
|
return True
|
|
112
116
|
|
|
113
117
|
return False
|
|
@@ -143,6 +147,16 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
|
|
|
143
147
|
# No real benchmark data - generate empty state viewer
|
|
144
148
|
try:
|
|
145
149
|
generate_empty_benchmark_viewer(benchmark_html_path)
|
|
150
|
+
|
|
151
|
+
# Still create symlink for azure_jobs.json access (even without real benchmarks)
|
|
152
|
+
if benchmark_results_dir.exists():
|
|
153
|
+
benchmark_results_link = output_dir / "benchmark_results"
|
|
154
|
+
if benchmark_results_link.is_symlink():
|
|
155
|
+
benchmark_results_link.unlink()
|
|
156
|
+
elif benchmark_results_link.exists():
|
|
157
|
+
shutil.rmtree(benchmark_results_link)
|
|
158
|
+
benchmark_results_link.symlink_to(benchmark_results_dir.absolute())
|
|
159
|
+
|
|
146
160
|
print(" Generated benchmark viewer: No real evaluation data yet")
|
|
147
161
|
return True
|
|
148
162
|
except Exception as e:
|
|
@@ -168,11 +182,20 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
|
|
|
168
182
|
tasks_dst = benchmark_tasks_dir / benchmark_dir.name
|
|
169
183
|
shutil.copytree(tasks_src, tasks_dst)
|
|
170
184
|
|
|
185
|
+
# Create symlink for benchmark_results directory (for azure_jobs.json access)
|
|
186
|
+
benchmark_results_link = output_dir / "benchmark_results"
|
|
187
|
+
if benchmark_results_link.is_symlink():
|
|
188
|
+
benchmark_results_link.unlink()
|
|
189
|
+
elif benchmark_results_link.exists():
|
|
190
|
+
shutil.rmtree(benchmark_results_link)
|
|
191
|
+
benchmark_results_link.symlink_to(benchmark_results_dir.absolute())
|
|
192
|
+
|
|
171
193
|
print(f" Regenerated benchmark viewer with {len(real_benchmarks)} run(s)")
|
|
172
194
|
return True
|
|
173
195
|
except Exception as e:
|
|
174
196
|
print(f" Could not regenerate benchmark viewer: {e}")
|
|
175
197
|
import traceback
|
|
198
|
+
|
|
176
199
|
traceback.print_exc()
|
|
177
200
|
return False
|
|
178
201
|
|
|
@@ -181,6 +204,7 @@ def detect_device() -> str:
|
|
|
181
204
|
"""Detect available compute device."""
|
|
182
205
|
try:
|
|
183
206
|
import torch
|
|
207
|
+
|
|
184
208
|
if torch.cuda.is_available():
|
|
185
209
|
device_name = torch.cuda.get_device_name(0)
|
|
186
210
|
return f"cuda ({device_name})"
|
|
@@ -231,10 +255,13 @@ def get_training_status() -> dict[str, Any]:
|
|
|
231
255
|
# Find checkpoints
|
|
232
256
|
checkpoints_dir = Path("checkpoints")
|
|
233
257
|
if checkpoints_dir.exists():
|
|
234
|
-
status["checkpoints"] = sorted(
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
258
|
+
status["checkpoints"] = sorted(
|
|
259
|
+
[
|
|
260
|
+
d.name
|
|
261
|
+
for d in checkpoints_dir.iterdir()
|
|
262
|
+
if d.is_dir() and (d / "adapter_config.json").exists()
|
|
263
|
+
]
|
|
264
|
+
)
|
|
238
265
|
|
|
239
266
|
return status
|
|
240
267
|
|
|
@@ -244,9 +271,9 @@ def cmd_status(args: argparse.Namespace) -> int:
|
|
|
244
271
|
status = get_training_status()
|
|
245
272
|
current_dir = get_current_output_dir()
|
|
246
273
|
|
|
247
|
-
print(f"\n{'='*50}")
|
|
274
|
+
print(f"\n{'=' * 50}")
|
|
248
275
|
print("LOCAL TRAINING STATUS")
|
|
249
|
-
print(f"{'='*50}")
|
|
276
|
+
print(f"{'=' * 50}")
|
|
250
277
|
print(f"Device: {status['device']}")
|
|
251
278
|
print(f"Status: {'RUNNING' if status['running'] else 'IDLE'}")
|
|
252
279
|
if status.get("job_id"):
|
|
@@ -254,7 +281,7 @@ def cmd_status(args: argparse.Namespace) -> int:
|
|
|
254
281
|
print(f"Output: {current_dir}")
|
|
255
282
|
|
|
256
283
|
if status.get("epoch"):
|
|
257
|
-
print(
|
|
284
|
+
print("\nProgress:")
|
|
258
285
|
print(f" Epoch: {status['epoch']}")
|
|
259
286
|
print(f" Step: {status['step']}")
|
|
260
287
|
if status.get("loss"):
|
|
@@ -267,7 +294,9 @@ def cmd_status(args: argparse.Namespace) -> int:
|
|
|
267
294
|
for cp in status["checkpoints"][-5:]: # Show last 5
|
|
268
295
|
print(f" - {cp}")
|
|
269
296
|
|
|
270
|
-
print(
|
|
297
|
+
print(
|
|
298
|
+
f"\nDashboard: {'✓' if status['has_dashboard'] else '✗'} {current_dir}/dashboard.html"
|
|
299
|
+
)
|
|
271
300
|
print(f"Viewer: {'✓' if status['has_viewer'] else '✗'} {current_dir}/viewer.html")
|
|
272
301
|
print()
|
|
273
302
|
|
|
@@ -300,9 +329,9 @@ def cmd_train(args: argparse.Namespace) -> int:
|
|
|
300
329
|
print(f"Error: Config not found: {config_path}")
|
|
301
330
|
return 1
|
|
302
331
|
|
|
303
|
-
print(f"\n{'='*50}")
|
|
332
|
+
print(f"\n{'=' * 50}")
|
|
304
333
|
print("STARTING LOCAL TRAINING")
|
|
305
|
-
print(f"{'='*50}")
|
|
334
|
+
print(f"{'=' * 50}")
|
|
306
335
|
print(f"Capture: {capture_path}")
|
|
307
336
|
print(f"Goal: {goal}")
|
|
308
337
|
print(f"Config: {config}")
|
|
@@ -311,10 +340,15 @@ def cmd_train(args: argparse.Namespace) -> int:
|
|
|
311
340
|
|
|
312
341
|
# Build command
|
|
313
342
|
cmd = [
|
|
314
|
-
sys.executable,
|
|
315
|
-
"
|
|
316
|
-
"
|
|
317
|
-
"--
|
|
343
|
+
sys.executable,
|
|
344
|
+
"-m",
|
|
345
|
+
"openadapt_ml.scripts.train",
|
|
346
|
+
"--config",
|
|
347
|
+
str(config_path),
|
|
348
|
+
"--capture",
|
|
349
|
+
str(capture_path),
|
|
350
|
+
"--goal",
|
|
351
|
+
goal,
|
|
318
352
|
]
|
|
319
353
|
|
|
320
354
|
if args.open:
|
|
@@ -333,14 +367,16 @@ def cmd_check(args: argparse.Namespace) -> int:
|
|
|
333
367
|
"""Check training health and early stopping analysis."""
|
|
334
368
|
status = get_training_status()
|
|
335
369
|
|
|
336
|
-
print(f"\n{'='*50}")
|
|
370
|
+
print(f"\n{'=' * 50}")
|
|
337
371
|
print("TRAINING HEALTH CHECK")
|
|
338
|
-
print(f"{'='*50}")
|
|
372
|
+
print(f"{'=' * 50}")
|
|
339
373
|
|
|
340
374
|
raw_losses = status.get("losses", [])
|
|
341
375
|
if not raw_losses:
|
|
342
376
|
print("No training data found.")
|
|
343
|
-
print(
|
|
377
|
+
print(
|
|
378
|
+
"Run training first with: uv run python -m openadapt_ml.cloud.local train --capture <path>"
|
|
379
|
+
)
|
|
344
380
|
return 1
|
|
345
381
|
|
|
346
382
|
# Extract loss values (handle both dict and float formats)
|
|
@@ -361,7 +397,7 @@ def cmd_check(args: argparse.Namespace) -> int:
|
|
|
361
397
|
min_loss = min(losses)
|
|
362
398
|
max_loss = max(losses)
|
|
363
399
|
|
|
364
|
-
print(
|
|
400
|
+
print("\nLoss progression:")
|
|
365
401
|
print(f" First: {first_loss:.4f}")
|
|
366
402
|
print(f" Last: {last_loss:.4f}")
|
|
367
403
|
print(f" Min: {min_loss:.4f}")
|
|
@@ -372,9 +408,11 @@ def cmd_check(args: argparse.Namespace) -> int:
|
|
|
372
408
|
if len(losses) >= 10:
|
|
373
409
|
recent = losses[-10:]
|
|
374
410
|
recent_avg = sum(recent) / len(recent)
|
|
375
|
-
recent_std = (
|
|
411
|
+
recent_std = (
|
|
412
|
+
sum((x - recent_avg) ** 2 for x in recent) / len(recent)
|
|
413
|
+
) ** 0.5
|
|
376
414
|
|
|
377
|
-
print(
|
|
415
|
+
print("\nRecent stability (last 10 steps):")
|
|
378
416
|
print(f" Avg loss: {recent_avg:.4f}")
|
|
379
417
|
print(f" Std dev: {recent_std:.4f}")
|
|
380
418
|
|
|
@@ -393,14 +431,18 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
393
431
|
"""Start local web server for dashboard.
|
|
394
432
|
|
|
395
433
|
Automatically regenerates dashboard and viewer before serving to ensure
|
|
396
|
-
the latest code and data are reflected.
|
|
434
|
+
the latest code and data are reflected. Also ensures the 'current' symlink
|
|
435
|
+
points to the most recent training run.
|
|
397
436
|
"""
|
|
398
|
-
from openadapt_ml.training.trainer import
|
|
437
|
+
from openadapt_ml.training.trainer import (
|
|
438
|
+
regenerate_local_dashboard,
|
|
439
|
+
update_current_symlink_to_latest,
|
|
440
|
+
)
|
|
399
441
|
|
|
400
442
|
port = args.port
|
|
401
443
|
|
|
402
444
|
# Determine what to serve: benchmark directory or training output
|
|
403
|
-
if hasattr(args,
|
|
445
|
+
if hasattr(args, "benchmark") and args.benchmark:
|
|
404
446
|
serve_dir = Path(args.benchmark).expanduser().resolve()
|
|
405
447
|
if not serve_dir.exists():
|
|
406
448
|
print(f"Error: Benchmark directory not found: {serve_dir}")
|
|
@@ -410,7 +452,10 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
410
452
|
if not args.no_regenerate:
|
|
411
453
|
print("Regenerating benchmark viewer...")
|
|
412
454
|
try:
|
|
413
|
-
from openadapt_ml.training.benchmark_viewer import
|
|
455
|
+
from openadapt_ml.training.benchmark_viewer import (
|
|
456
|
+
generate_benchmark_viewer,
|
|
457
|
+
)
|
|
458
|
+
|
|
414
459
|
generate_benchmark_viewer(serve_dir)
|
|
415
460
|
except Exception as e:
|
|
416
461
|
print(f"Warning: Could not regenerate benchmark viewer: {e}")
|
|
@@ -419,6 +464,17 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
419
464
|
else:
|
|
420
465
|
serve_dir = get_current_output_dir()
|
|
421
466
|
|
|
467
|
+
# If current symlink doesn't exist or is broken, update to latest run
|
|
468
|
+
if not serve_dir.exists() or not serve_dir.is_dir():
|
|
469
|
+
print("Updating 'current' symlink to latest training run...")
|
|
470
|
+
latest = update_current_symlink_to_latest()
|
|
471
|
+
if latest:
|
|
472
|
+
serve_dir = get_current_output_dir()
|
|
473
|
+
print(f" Updated to: {latest.name}")
|
|
474
|
+
else:
|
|
475
|
+
print(f"Error: {serve_dir} not found. Run training first.")
|
|
476
|
+
return 1
|
|
477
|
+
|
|
422
478
|
if not serve_dir.exists():
|
|
423
479
|
print(f"Error: {serve_dir} not found. Run training first.")
|
|
424
480
|
return 1
|
|
@@ -427,7 +483,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
427
483
|
if not args.no_regenerate:
|
|
428
484
|
print("Regenerating dashboard and viewer...")
|
|
429
485
|
try:
|
|
430
|
-
|
|
486
|
+
# Use keep_polling=True so JavaScript fetches live data from training_log.json
|
|
487
|
+
# This ensures the dashboard shows current data instead of stale embedded data
|
|
488
|
+
regenerate_local_dashboard(str(serve_dir), keep_polling=True)
|
|
431
489
|
# Also regenerate viewer if comparison data exists
|
|
432
490
|
_regenerate_viewer_if_possible(serve_dir)
|
|
433
491
|
except Exception as e:
|
|
@@ -436,8 +494,23 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
436
494
|
# Also regenerate benchmark viewer from latest benchmark results
|
|
437
495
|
_regenerate_benchmark_viewer_if_available(serve_dir)
|
|
438
496
|
|
|
497
|
+
# Generate Azure ops dashboard
|
|
498
|
+
try:
|
|
499
|
+
from openadapt_ml.training.azure_ops_viewer import (
|
|
500
|
+
generate_azure_ops_dashboard,
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
generate_azure_ops_dashboard(serve_dir / "azure_ops.html")
|
|
504
|
+
print(" Generated Azure ops dashboard")
|
|
505
|
+
except Exception as e:
|
|
506
|
+
print(f" Warning: Could not generate Azure ops dashboard: {e}")
|
|
507
|
+
|
|
439
508
|
start_page = "dashboard.html"
|
|
440
509
|
|
|
510
|
+
# Override start page if specified
|
|
511
|
+
if hasattr(args, "start_page") and args.start_page:
|
|
512
|
+
start_page = args.start_page
|
|
513
|
+
|
|
441
514
|
# Serve from the specified directory
|
|
442
515
|
os.chdir(serve_dir)
|
|
443
516
|
|
|
@@ -452,33 +525,41 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
452
525
|
super().log_message(format, *log_args)
|
|
453
526
|
|
|
454
527
|
def do_POST(self):
|
|
455
|
-
if self.path ==
|
|
528
|
+
if self.path == "/api/stop":
|
|
456
529
|
# Create stop signal file
|
|
457
530
|
stop_file = serve_dir / "STOP_TRAINING"
|
|
458
531
|
stop_file.touch()
|
|
459
532
|
self.send_response(200)
|
|
460
|
-
self.send_header(
|
|
461
|
-
self.send_header(
|
|
533
|
+
self.send_header("Content-Type", "application/json")
|
|
534
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
462
535
|
self.end_headers()
|
|
463
536
|
self.wfile.write(b'{"status": "stop_signal_created"}')
|
|
464
537
|
print(f"\n⏹ Stop signal created: {stop_file}")
|
|
465
|
-
elif self.path ==
|
|
538
|
+
elif self.path == "/api/run-benchmark":
|
|
466
539
|
# Parse request body for provider
|
|
467
|
-
content_length = int(self.headers.get(
|
|
468
|
-
body =
|
|
540
|
+
content_length = int(self.headers.get("Content-Length", 0))
|
|
541
|
+
body = (
|
|
542
|
+
self.rfile.read(content_length).decode("utf-8")
|
|
543
|
+
if content_length
|
|
544
|
+
else "{}"
|
|
545
|
+
)
|
|
469
546
|
try:
|
|
470
547
|
params = json.loads(body)
|
|
471
548
|
except json.JSONDecodeError:
|
|
472
549
|
params = {}
|
|
473
550
|
|
|
474
|
-
provider = params.get(
|
|
475
|
-
tasks = params.get(
|
|
551
|
+
provider = params.get("provider", "anthropic")
|
|
552
|
+
tasks = params.get("tasks", 5)
|
|
476
553
|
|
|
477
554
|
self.send_response(200)
|
|
478
|
-
self.send_header(
|
|
479
|
-
self.send_header(
|
|
555
|
+
self.send_header("Content-Type", "application/json")
|
|
556
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
480
557
|
self.end_headers()
|
|
481
|
-
self.wfile.write(
|
|
558
|
+
self.wfile.write(
|
|
559
|
+
json.dumps(
|
|
560
|
+
{"status": "started", "provider": provider, "tasks": tasks}
|
|
561
|
+
).encode()
|
|
562
|
+
)
|
|
482
563
|
|
|
483
564
|
# Run benchmark in background thread with progress logging
|
|
484
565
|
def run_benchmark():
|
|
@@ -492,25 +573,45 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
492
573
|
# Create progress log file (in cwd which is serve_dir)
|
|
493
574
|
progress_file = Path("benchmark_progress.json")
|
|
494
575
|
|
|
495
|
-
print(
|
|
576
|
+
print(
|
|
577
|
+
f"\n🚀 Starting {provider} benchmark evaluation ({tasks} tasks)..."
|
|
578
|
+
)
|
|
496
579
|
|
|
497
580
|
# Write initial progress
|
|
498
|
-
progress_file.write_text(
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
581
|
+
progress_file.write_text(
|
|
582
|
+
json.dumps(
|
|
583
|
+
{
|
|
584
|
+
"status": "running",
|
|
585
|
+
"provider": provider,
|
|
586
|
+
"tasks_total": tasks,
|
|
587
|
+
"tasks_complete": 0,
|
|
588
|
+
"message": f"Starting {provider} evaluation...",
|
|
589
|
+
}
|
|
590
|
+
)
|
|
591
|
+
)
|
|
505
592
|
|
|
506
593
|
# Copy environment with loaded vars
|
|
507
594
|
env = os.environ.copy()
|
|
508
595
|
|
|
509
596
|
result = subprocess.run(
|
|
510
|
-
[
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
597
|
+
[
|
|
598
|
+
"uv",
|
|
599
|
+
"run",
|
|
600
|
+
"python",
|
|
601
|
+
"-m",
|
|
602
|
+
"openadapt_ml.benchmarks.cli",
|
|
603
|
+
"run-api",
|
|
604
|
+
"--provider",
|
|
605
|
+
provider,
|
|
606
|
+
"--tasks",
|
|
607
|
+
str(tasks),
|
|
608
|
+
"--model-id",
|
|
609
|
+
f"{provider}-api",
|
|
610
|
+
],
|
|
611
|
+
capture_output=True,
|
|
612
|
+
text=True,
|
|
613
|
+
cwd=str(project_root),
|
|
614
|
+
env=env,
|
|
514
615
|
)
|
|
515
616
|
|
|
516
617
|
print(f"\n📋 Benchmark output:\n{result.stdout}")
|
|
@@ -518,53 +619,2995 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
518
619
|
print(f"Stderr: {result.stderr}")
|
|
519
620
|
|
|
520
621
|
if result.returncode == 0:
|
|
521
|
-
print(
|
|
522
|
-
progress_file.write_text(
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
622
|
+
print("✅ Benchmark complete. Regenerating viewer...")
|
|
623
|
+
progress_file.write_text(
|
|
624
|
+
json.dumps(
|
|
625
|
+
{
|
|
626
|
+
"status": "complete",
|
|
627
|
+
"provider": provider,
|
|
628
|
+
"message": "Evaluation complete! Refreshing results...",
|
|
629
|
+
}
|
|
630
|
+
)
|
|
631
|
+
)
|
|
527
632
|
# Regenerate benchmark viewer
|
|
528
633
|
_regenerate_benchmark_viewer_if_available(serve_dir)
|
|
529
634
|
else:
|
|
530
635
|
print(f"❌ Benchmark failed: {result.stderr}")
|
|
531
|
-
progress_file.write_text(
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
636
|
+
progress_file.write_text(
|
|
637
|
+
json.dumps(
|
|
638
|
+
{
|
|
639
|
+
"status": "error",
|
|
640
|
+
"provider": provider,
|
|
641
|
+
"message": f"Evaluation failed: {result.stderr[:200]}",
|
|
642
|
+
}
|
|
643
|
+
)
|
|
644
|
+
)
|
|
536
645
|
|
|
537
646
|
threading.Thread(target=run_benchmark, daemon=True).start()
|
|
647
|
+
elif self.path == "/api/vms/register":
|
|
648
|
+
# Register a new VM
|
|
649
|
+
content_length = int(self.headers.get("Content-Length", 0))
|
|
650
|
+
body = (
|
|
651
|
+
self.rfile.read(content_length).decode("utf-8")
|
|
652
|
+
if content_length
|
|
653
|
+
else "{}"
|
|
654
|
+
)
|
|
655
|
+
try:
|
|
656
|
+
vm_data = json.loads(body)
|
|
657
|
+
result = self._register_vm(vm_data)
|
|
658
|
+
self.send_response(200)
|
|
659
|
+
self.send_header("Content-Type", "application/json")
|
|
660
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
661
|
+
self.end_headers()
|
|
662
|
+
self.wfile.write(json.dumps(result).encode())
|
|
663
|
+
except Exception as e:
|
|
664
|
+
self.send_response(500)
|
|
665
|
+
self.send_header("Content-Type", "application/json")
|
|
666
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
667
|
+
self.end_headers()
|
|
668
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
669
|
+
elif self.path == "/api/benchmark/start":
|
|
670
|
+
# Start a benchmark run with configurable parameters
|
|
671
|
+
content_length = int(self.headers.get("Content-Length", 0))
|
|
672
|
+
body = (
|
|
673
|
+
self.rfile.read(content_length).decode("utf-8")
|
|
674
|
+
if content_length
|
|
675
|
+
else "{}"
|
|
676
|
+
)
|
|
677
|
+
try:
|
|
678
|
+
params = json.loads(body)
|
|
679
|
+
result = self._start_benchmark_run(params)
|
|
680
|
+
self.send_response(200)
|
|
681
|
+
self.send_header("Content-Type", "application/json")
|
|
682
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
683
|
+
self.end_headers()
|
|
684
|
+
self.wfile.write(json.dumps(result).encode())
|
|
685
|
+
except Exception as e:
|
|
686
|
+
self.send_response(500)
|
|
687
|
+
self.send_header("Content-Type", "application/json")
|
|
688
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
689
|
+
self.end_headers()
|
|
690
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
538
691
|
else:
|
|
539
692
|
self.send_error(404, "Not found")
|
|
540
693
|
|
|
541
694
|
def do_GET(self):
|
|
542
|
-
if self.path.startswith(
|
|
695
|
+
if self.path.startswith("/api/benchmark-progress"):
|
|
543
696
|
# Return benchmark progress
|
|
544
|
-
progress_file = Path(
|
|
697
|
+
progress_file = Path(
|
|
698
|
+
"benchmark_progress.json"
|
|
699
|
+
) # Relative to serve_dir (cwd)
|
|
545
700
|
if progress_file.exists():
|
|
546
701
|
progress = progress_file.read_text()
|
|
547
702
|
else:
|
|
548
703
|
progress = json.dumps({"status": "idle"})
|
|
549
704
|
|
|
550
705
|
self.send_response(200)
|
|
551
|
-
self.send_header(
|
|
552
|
-
self.send_header(
|
|
706
|
+
self.send_header("Content-Type", "application/json")
|
|
707
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
553
708
|
self.end_headers()
|
|
554
709
|
self.wfile.write(progress.encode())
|
|
710
|
+
elif self.path.startswith("/api/benchmark-live"):
|
|
711
|
+
# Return live evaluation state
|
|
712
|
+
live_file = Path("benchmark_live.json") # Relative to serve_dir (cwd)
|
|
713
|
+
if live_file.exists():
|
|
714
|
+
live_state = live_file.read_text()
|
|
715
|
+
else:
|
|
716
|
+
live_state = json.dumps({"status": "idle"})
|
|
717
|
+
|
|
718
|
+
self.send_response(200)
|
|
719
|
+
self.send_header("Content-Type", "application/json")
|
|
720
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
721
|
+
self.end_headers()
|
|
722
|
+
self.wfile.write(live_state.encode())
|
|
723
|
+
elif self.path.startswith("/api/tasks"):
|
|
724
|
+
# Return background task status (VM, Docker, benchmarks)
|
|
725
|
+
try:
|
|
726
|
+
tasks = self._fetch_background_tasks()
|
|
727
|
+
self.send_response(200)
|
|
728
|
+
self.send_header("Content-Type", "application/json")
|
|
729
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
730
|
+
self.end_headers()
|
|
731
|
+
self.wfile.write(json.dumps(tasks).encode())
|
|
732
|
+
except Exception as e:
|
|
733
|
+
self.send_response(500)
|
|
734
|
+
self.send_header("Content-Type", "application/json")
|
|
735
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
736
|
+
self.end_headers()
|
|
737
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
738
|
+
elif self.path.startswith("/api/azure-jobs"):
|
|
739
|
+
# Return LIVE Azure job status from Azure ML
|
|
740
|
+
# Supports ?force=true parameter for manual refresh (always fetches live)
|
|
741
|
+
try:
|
|
742
|
+
from urllib.parse import urlparse, parse_qs
|
|
743
|
+
|
|
744
|
+
query = parse_qs(urlparse(self.path).query)
|
|
745
|
+
force_refresh = query.get("force", ["false"])[0].lower() == "true"
|
|
746
|
+
|
|
747
|
+
# Always fetch live data (force just indicates manual refresh for logging)
|
|
748
|
+
if force_refresh:
|
|
749
|
+
print("Azure Jobs: Manual refresh requested")
|
|
750
|
+
|
|
751
|
+
jobs = self._fetch_live_azure_jobs()
|
|
752
|
+
self.send_response(200)
|
|
753
|
+
self.send_header("Content-Type", "application/json")
|
|
754
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
755
|
+
self.end_headers()
|
|
756
|
+
self.wfile.write(json.dumps(jobs).encode())
|
|
757
|
+
except Exception as e:
|
|
758
|
+
self.send_response(500)
|
|
759
|
+
self.send_header("Content-Type", "application/json")
|
|
760
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
761
|
+
self.end_headers()
|
|
762
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
763
|
+
elif self.path.startswith("/api/benchmark-sse"):
|
|
764
|
+
# Server-Sent Events endpoint for real-time benchmark updates
|
|
765
|
+
try:
|
|
766
|
+
from urllib.parse import urlparse, parse_qs
|
|
767
|
+
|
|
768
|
+
query = parse_qs(urlparse(self.path).query)
|
|
769
|
+
interval = int(query.get("interval", [5])[0])
|
|
770
|
+
|
|
771
|
+
# Validate interval (min 1s, max 60s)
|
|
772
|
+
interval = max(1, min(60, interval))
|
|
773
|
+
|
|
774
|
+
self._stream_benchmark_updates(interval)
|
|
775
|
+
except Exception as e:
|
|
776
|
+
self.send_error(500, f"SSE error: {e}")
|
|
777
|
+
elif self.path.startswith("/api/vms"):
|
|
778
|
+
# Return VM registry with live status
|
|
779
|
+
try:
|
|
780
|
+
vms = self._fetch_vm_registry()
|
|
781
|
+
self.send_response(200)
|
|
782
|
+
self.send_header("Content-Type", "application/json")
|
|
783
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
784
|
+
self.end_headers()
|
|
785
|
+
self.wfile.write(json.dumps(vms).encode())
|
|
786
|
+
except Exception as e:
|
|
787
|
+
self.send_response(500)
|
|
788
|
+
self.send_header("Content-Type", "application/json")
|
|
789
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
790
|
+
self.end_headers()
|
|
791
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
792
|
+
elif self.path.startswith("/api/azure-job-logs"):
|
|
793
|
+
# Return live logs for running Azure job
|
|
794
|
+
try:
|
|
795
|
+
# Parse job_id from query string
|
|
796
|
+
from urllib.parse import urlparse, parse_qs
|
|
797
|
+
|
|
798
|
+
query = parse_qs(urlparse(self.path).query)
|
|
799
|
+
job_id = query.get("job_id", [None])[0]
|
|
800
|
+
|
|
801
|
+
logs = self._fetch_azure_job_logs(job_id)
|
|
802
|
+
self.send_response(200)
|
|
803
|
+
self.send_header("Content-Type", "application/json")
|
|
804
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
805
|
+
self.end_headers()
|
|
806
|
+
self.wfile.write(json.dumps(logs).encode())
|
|
807
|
+
except Exception as e:
|
|
808
|
+
self.send_response(500)
|
|
809
|
+
self.send_header("Content-Type", "application/json")
|
|
810
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
811
|
+
self.end_headers()
|
|
812
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
813
|
+
elif self.path.startswith("/api/probe-vm"):
|
|
814
|
+
# Probe the VM to check if WAA server is responding
|
|
815
|
+
try:
|
|
816
|
+
result = self._probe_vm()
|
|
817
|
+
self.send_response(200)
|
|
818
|
+
self.send_header("Content-Type", "application/json")
|
|
819
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
820
|
+
self.end_headers()
|
|
821
|
+
self.wfile.write(json.dumps(result).encode())
|
|
822
|
+
except Exception as e:
|
|
823
|
+
self.send_response(500)
|
|
824
|
+
self.send_header("Content-Type", "application/json")
|
|
825
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
826
|
+
self.end_headers()
|
|
827
|
+
self.wfile.write(
|
|
828
|
+
json.dumps({"error": str(e), "responding": False}).encode()
|
|
829
|
+
)
|
|
830
|
+
elif self.path.startswith("/api/tunnels"):
|
|
831
|
+
# Return SSH tunnel status
|
|
832
|
+
try:
|
|
833
|
+
tunnel_mgr = get_tunnel_manager()
|
|
834
|
+
status = tunnel_mgr.get_tunnel_status()
|
|
835
|
+
result = {
|
|
836
|
+
name: {
|
|
837
|
+
"active": s.active,
|
|
838
|
+
"local_port": s.local_port,
|
|
839
|
+
"remote_endpoint": s.remote_endpoint,
|
|
840
|
+
"pid": s.pid,
|
|
841
|
+
"error": s.error,
|
|
842
|
+
}
|
|
843
|
+
for name, s in status.items()
|
|
844
|
+
}
|
|
845
|
+
self.send_response(200)
|
|
846
|
+
self.send_header("Content-Type", "application/json")
|
|
847
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
848
|
+
self.end_headers()
|
|
849
|
+
self.wfile.write(json.dumps(result).encode())
|
|
850
|
+
except Exception as e:
|
|
851
|
+
self.send_response(500)
|
|
852
|
+
self.send_header("Content-Type", "application/json")
|
|
853
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
854
|
+
self.end_headers()
|
|
855
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
856
|
+
elif self.path.startswith("/api/current-run"):
|
|
857
|
+
# Return currently running benchmark info
|
|
858
|
+
try:
|
|
859
|
+
result = self._get_current_run()
|
|
860
|
+
self.send_response(200)
|
|
861
|
+
self.send_header("Content-Type", "application/json")
|
|
862
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
863
|
+
self.end_headers()
|
|
864
|
+
self.wfile.write(json.dumps(result).encode())
|
|
865
|
+
except Exception as e:
|
|
866
|
+
self.send_response(500)
|
|
867
|
+
self.send_header("Content-Type", "application/json")
|
|
868
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
869
|
+
self.end_headers()
|
|
870
|
+
self.wfile.write(
|
|
871
|
+
json.dumps({"error": str(e), "running": False}).encode()
|
|
872
|
+
)
|
|
873
|
+
elif self.path.startswith("/api/background-tasks"):
|
|
874
|
+
# Alias for /api/tasks - background task status
|
|
875
|
+
try:
|
|
876
|
+
tasks = self._fetch_background_tasks()
|
|
877
|
+
self.send_response(200)
|
|
878
|
+
self.send_header("Content-Type", "application/json")
|
|
879
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
880
|
+
self.end_headers()
|
|
881
|
+
self.wfile.write(json.dumps(tasks).encode())
|
|
882
|
+
except Exception as e:
|
|
883
|
+
self.send_response(500)
|
|
884
|
+
self.send_header("Content-Type", "application/json")
|
|
885
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
886
|
+
self.end_headers()
|
|
887
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
888
|
+
elif self.path.startswith("/api/benchmark/status"):
|
|
889
|
+
# Return current benchmark job status with ETA
|
|
890
|
+
try:
|
|
891
|
+
status = self._get_benchmark_status()
|
|
892
|
+
self.send_response(200)
|
|
893
|
+
self.send_header("Content-Type", "application/json")
|
|
894
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
895
|
+
self.end_headers()
|
|
896
|
+
self.wfile.write(json.dumps(status).encode())
|
|
897
|
+
except Exception as e:
|
|
898
|
+
self.send_response(500)
|
|
899
|
+
self.send_header("Content-Type", "application/json")
|
|
900
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
901
|
+
self.end_headers()
|
|
902
|
+
self.wfile.write(
|
|
903
|
+
json.dumps({"error": str(e), "status": "error"}).encode()
|
|
904
|
+
)
|
|
905
|
+
elif self.path.startswith("/api/benchmark/costs"):
|
|
906
|
+
# Return cost breakdown (Azure VM, API calls, GPU)
|
|
907
|
+
try:
|
|
908
|
+
costs = self._get_benchmark_costs()
|
|
909
|
+
self.send_response(200)
|
|
910
|
+
self.send_header("Content-Type", "application/json")
|
|
911
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
912
|
+
self.end_headers()
|
|
913
|
+
self.wfile.write(json.dumps(costs).encode())
|
|
914
|
+
except Exception as e:
|
|
915
|
+
self.send_response(500)
|
|
916
|
+
self.send_header("Content-Type", "application/json")
|
|
917
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
918
|
+
self.end_headers()
|
|
919
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
920
|
+
elif self.path.startswith("/api/benchmark/metrics"):
|
|
921
|
+
# Return performance metrics (success rate, domain breakdown)
|
|
922
|
+
try:
|
|
923
|
+
metrics = self._get_benchmark_metrics()
|
|
924
|
+
self.send_response(200)
|
|
925
|
+
self.send_header("Content-Type", "application/json")
|
|
926
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
927
|
+
self.end_headers()
|
|
928
|
+
self.wfile.write(json.dumps(metrics).encode())
|
|
929
|
+
except Exception as e:
|
|
930
|
+
self.send_response(500)
|
|
931
|
+
self.send_header("Content-Type", "application/json")
|
|
932
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
933
|
+
self.end_headers()
|
|
934
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
935
|
+
elif self.path.startswith("/api/benchmark/workers"):
|
|
936
|
+
# Return worker status and utilization
|
|
937
|
+
try:
|
|
938
|
+
workers = self._get_benchmark_workers()
|
|
939
|
+
self.send_response(200)
|
|
940
|
+
self.send_header("Content-Type", "application/json")
|
|
941
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
942
|
+
self.end_headers()
|
|
943
|
+
self.wfile.write(json.dumps(workers).encode())
|
|
944
|
+
except Exception as e:
|
|
945
|
+
self.send_response(500)
|
|
946
|
+
self.send_header("Content-Type", "application/json")
|
|
947
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
948
|
+
self.end_headers()
|
|
949
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
950
|
+
elif self.path.startswith("/api/benchmark/runs"):
|
|
951
|
+
# Return list of all benchmark runs
|
|
952
|
+
try:
|
|
953
|
+
runs = self._get_benchmark_runs()
|
|
954
|
+
self.send_response(200)
|
|
955
|
+
self.send_header("Content-Type", "application/json")
|
|
956
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
957
|
+
self.end_headers()
|
|
958
|
+
self.wfile.write(json.dumps(runs).encode())
|
|
959
|
+
except Exception as e:
|
|
960
|
+
self.send_response(500)
|
|
961
|
+
self.send_header("Content-Type", "application/json")
|
|
962
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
963
|
+
self.end_headers()
|
|
964
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
965
|
+
elif self.path.startswith("/api/benchmark/tasks/"):
|
|
966
|
+
# Return task execution details
|
|
967
|
+
# URL format: /api/benchmark/tasks/{run_name}/{task_id}
|
|
968
|
+
try:
|
|
969
|
+
parts = self.path.split("/")
|
|
970
|
+
if len(parts) >= 6:
|
|
971
|
+
run_name = parts[4]
|
|
972
|
+
task_id = parts[5]
|
|
973
|
+
execution = self._get_task_execution(run_name, task_id)
|
|
974
|
+
self.send_response(200)
|
|
975
|
+
self.send_header("Content-Type", "application/json")
|
|
976
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
977
|
+
self.end_headers()
|
|
978
|
+
self.wfile.write(json.dumps(execution).encode())
|
|
979
|
+
else:
|
|
980
|
+
self.send_error(400, "Invalid path format")
|
|
981
|
+
except Exception as e:
|
|
982
|
+
self.send_response(500)
|
|
983
|
+
self.send_header("Content-Type", "application/json")
|
|
984
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
985
|
+
self.end_headers()
|
|
986
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
987
|
+
elif self.path.startswith("/api/benchmark/screenshots/"):
|
|
988
|
+
# Serve screenshot files
|
|
989
|
+
# URL format: /api/benchmark/screenshots/{run_name}/{task_id}/screenshots/{filename}
|
|
990
|
+
try:
|
|
991
|
+
# Remove /api/benchmark/screenshots/ prefix
|
|
992
|
+
path_parts = self.path.replace(
|
|
993
|
+
"/api/benchmark/screenshots/", ""
|
|
994
|
+
).split("/")
|
|
995
|
+
if len(path_parts) >= 4:
|
|
996
|
+
run_name = path_parts[0]
|
|
997
|
+
task_id = path_parts[1]
|
|
998
|
+
# path_parts[2] should be 'screenshots'
|
|
999
|
+
filename = path_parts[3]
|
|
1000
|
+
|
|
1001
|
+
results_dir = Path("benchmark_results")
|
|
1002
|
+
screenshot_path = (
|
|
1003
|
+
results_dir
|
|
1004
|
+
/ run_name
|
|
1005
|
+
/ "tasks"
|
|
1006
|
+
/ task_id
|
|
1007
|
+
/ "screenshots"
|
|
1008
|
+
/ filename
|
|
1009
|
+
)
|
|
1010
|
+
|
|
1011
|
+
if screenshot_path.exists():
|
|
1012
|
+
self.send_response(200)
|
|
1013
|
+
self.send_header("Content-Type", "image/png")
|
|
1014
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1015
|
+
self.end_headers()
|
|
1016
|
+
with open(screenshot_path, "rb") as f:
|
|
1017
|
+
self.wfile.write(f.read())
|
|
1018
|
+
else:
|
|
1019
|
+
self.send_error(
|
|
1020
|
+
404, f"Screenshot not found: {screenshot_path}"
|
|
1021
|
+
)
|
|
1022
|
+
else:
|
|
1023
|
+
self.send_error(400, "Invalid screenshot path format")
|
|
1024
|
+
except Exception as e:
|
|
1025
|
+
self.send_error(500, f"Error serving screenshot: {e}")
|
|
1026
|
+
elif self.path.startswith("/api/azure-ops-sse"):
|
|
1027
|
+
# Server-Sent Events endpoint for Azure operations status
|
|
1028
|
+
try:
|
|
1029
|
+
self._stream_azure_ops_updates()
|
|
1030
|
+
except Exception as e:
|
|
1031
|
+
self.send_error(500, f"SSE error: {e}")
|
|
1032
|
+
elif self.path.startswith("/api/azure-ops-status"):
|
|
1033
|
+
# Return Azure operations status from JSON file
|
|
1034
|
+
# Session tracker provides elapsed_seconds and cost_usd for
|
|
1035
|
+
# persistence across page refreshes
|
|
1036
|
+
try:
|
|
1037
|
+
from openadapt_ml.benchmarks.azure_ops_tracker import read_status
|
|
1038
|
+
from openadapt_ml.benchmarks.session_tracker import (
|
|
1039
|
+
get_session,
|
|
1040
|
+
update_session_vm_state,
|
|
1041
|
+
)
|
|
1042
|
+
|
|
1043
|
+
# Get operation status (current task)
|
|
1044
|
+
status = read_status()
|
|
1045
|
+
|
|
1046
|
+
# Get session data (persistent across refreshes)
|
|
1047
|
+
session = get_session()
|
|
1048
|
+
|
|
1049
|
+
# Update session based on VM state if we have VM info
|
|
1050
|
+
# IMPORTANT: Only pass vm_ip if it's truthy to avoid
|
|
1051
|
+
# overwriting session's stable vm_ip with None
|
|
1052
|
+
if status.get("vm_state") and status.get("vm_state") != "unknown":
|
|
1053
|
+
status_vm_ip = status.get("vm_ip")
|
|
1054
|
+
# Build update kwargs - only include vm_ip if present
|
|
1055
|
+
update_kwargs = {
|
|
1056
|
+
"vm_state": status["vm_state"],
|
|
1057
|
+
"vm_size": status.get("vm_size"),
|
|
1058
|
+
}
|
|
1059
|
+
if status_vm_ip: # Only include if truthy
|
|
1060
|
+
update_kwargs["vm_ip"] = status_vm_ip
|
|
1061
|
+
session = update_session_vm_state(**update_kwargs)
|
|
1062
|
+
|
|
1063
|
+
# Use session's vm_ip as authoritative source
|
|
1064
|
+
# This prevents IP flickering when status file has stale/None values
|
|
1065
|
+
if session.get("vm_ip"):
|
|
1066
|
+
status["vm_ip"] = session["vm_ip"]
|
|
1067
|
+
|
|
1068
|
+
# Use session's elapsed_seconds and cost_usd for persistence
|
|
1069
|
+
# These survive page refreshes and track total VM runtime
|
|
1070
|
+
if (
|
|
1071
|
+
session.get("is_active")
|
|
1072
|
+
or session.get("accumulated_seconds", 0) > 0
|
|
1073
|
+
):
|
|
1074
|
+
status["elapsed_seconds"] = session.get("elapsed_seconds", 0.0)
|
|
1075
|
+
status["cost_usd"] = session.get("cost_usd", 0.0)
|
|
1076
|
+
status["started_at"] = session.get("started_at")
|
|
1077
|
+
# Include session metadata for debugging
|
|
1078
|
+
status["session_id"] = session.get("session_id")
|
|
1079
|
+
status["session_is_active"] = session.get("is_active", False)
|
|
1080
|
+
# Include accumulated time from previous sessions for hybrid display
|
|
1081
|
+
status["accumulated_seconds"] = session.get(
|
|
1082
|
+
"accumulated_seconds", 0.0
|
|
1083
|
+
)
|
|
1084
|
+
# Calculate current session time (total - accumulated)
|
|
1085
|
+
current_session_seconds = max(
|
|
1086
|
+
0, status["elapsed_seconds"] - status["accumulated_seconds"]
|
|
1087
|
+
)
|
|
1088
|
+
status["current_session_seconds"] = current_session_seconds
|
|
1089
|
+
status["current_session_cost_usd"] = (
|
|
1090
|
+
current_session_seconds / 3600
|
|
1091
|
+
) * session.get("hourly_rate_usd", 0.422)
|
|
1092
|
+
|
|
1093
|
+
try:
|
|
1094
|
+
tunnel_mgr = get_tunnel_manager()
|
|
1095
|
+
tunnel_status = tunnel_mgr.get_tunnel_status()
|
|
1096
|
+
status["tunnels"] = {
|
|
1097
|
+
name: {
|
|
1098
|
+
"active": s.active,
|
|
1099
|
+
"local_port": s.local_port,
|
|
1100
|
+
"remote_endpoint": s.remote_endpoint,
|
|
1101
|
+
"pid": s.pid,
|
|
1102
|
+
"error": s.error,
|
|
1103
|
+
}
|
|
1104
|
+
for name, s in tunnel_status.items()
|
|
1105
|
+
}
|
|
1106
|
+
except Exception as e:
|
|
1107
|
+
status["tunnels"] = {"error": str(e)}
|
|
1108
|
+
|
|
1109
|
+
self.send_response(200)
|
|
1110
|
+
self.send_header("Content-Type", "application/json")
|
|
1111
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1112
|
+
self.end_headers()
|
|
1113
|
+
self.wfile.write(json.dumps(status).encode())
|
|
1114
|
+
except Exception as e:
|
|
1115
|
+
self.send_response(500)
|
|
1116
|
+
self.send_header("Content-Type", "application/json")
|
|
1117
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1118
|
+
self.end_headers()
|
|
1119
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
1120
|
+
elif self.path.startswith("/api/vm-diagnostics"):
|
|
1121
|
+
# Return VM diagnostics: disk usage, Docker stats, memory usage
|
|
1122
|
+
try:
|
|
1123
|
+
diagnostics = self._get_vm_diagnostics()
|
|
1124
|
+
self.send_response(200)
|
|
1125
|
+
self.send_header("Content-Type", "application/json")
|
|
1126
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1127
|
+
self.end_headers()
|
|
1128
|
+
self.wfile.write(json.dumps(diagnostics).encode())
|
|
1129
|
+
except Exception as e:
|
|
1130
|
+
self.send_response(500)
|
|
1131
|
+
self.send_header("Content-Type", "application/json")
|
|
1132
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1133
|
+
self.end_headers()
|
|
1134
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
555
1135
|
else:
|
|
556
1136
|
# Default file serving
|
|
557
1137
|
super().do_GET()
|
|
558
1138
|
|
|
1139
|
+
def _fetch_live_azure_jobs(self):
|
|
1140
|
+
"""Fetch live job status from Azure ML."""
|
|
1141
|
+
import subprocess
|
|
1142
|
+
|
|
1143
|
+
result = subprocess.run(
|
|
1144
|
+
[
|
|
1145
|
+
"az",
|
|
1146
|
+
"ml",
|
|
1147
|
+
"job",
|
|
1148
|
+
"list",
|
|
1149
|
+
"--resource-group",
|
|
1150
|
+
"openadapt-agents",
|
|
1151
|
+
"--workspace-name",
|
|
1152
|
+
"openadapt-ml",
|
|
1153
|
+
"--query",
|
|
1154
|
+
"[].{name:name,display_name:display_name,status:status,creation_context:creation_context.created_at}",
|
|
1155
|
+
"-o",
|
|
1156
|
+
"json",
|
|
1157
|
+
],
|
|
1158
|
+
capture_output=True,
|
|
1159
|
+
text=True,
|
|
1160
|
+
timeout=30,
|
|
1161
|
+
)
|
|
1162
|
+
if result.returncode != 0:
|
|
1163
|
+
raise Exception(f"Azure CLI error: {result.stderr}")
|
|
1164
|
+
|
|
1165
|
+
jobs = json.loads(result.stdout)
|
|
1166
|
+
# Format for frontend
|
|
1167
|
+
experiment_id = "ad29082c-0607-4fda-8cc7-38944eb5a518"
|
|
1168
|
+
wsid = "/subscriptions/78add6c6-c92a-4a53-b751-eb644ac77e59/resourceGroups/openadapt-agents/providers/Microsoft.MachineLearningServices/workspaces/openadapt-ml"
|
|
1169
|
+
|
|
1170
|
+
formatted = []
|
|
1171
|
+
for job in jobs[:10]: # Limit to 10 most recent
|
|
1172
|
+
formatted.append(
|
|
1173
|
+
{
|
|
1174
|
+
"job_id": job.get("name", "unknown"),
|
|
1175
|
+
"display_name": job.get("display_name", ""),
|
|
1176
|
+
"status": job.get("status", "unknown").lower(),
|
|
1177
|
+
"started_at": job.get("creation_context", ""),
|
|
1178
|
+
"azure_dashboard_url": f"https://ml.azure.com/experiments/id/{experiment_id}/runs/{job.get('name', '')}?wsid={wsid}",
|
|
1179
|
+
"is_live": True, # Flag to indicate this is live data
|
|
1180
|
+
}
|
|
1181
|
+
)
|
|
1182
|
+
return formatted
|
|
1183
|
+
|
|
1184
|
+
def _fetch_azure_job_logs(self, job_id: str | None):
|
|
1185
|
+
"""Fetch logs for an Azure ML job (streaming for running jobs)."""
|
|
1186
|
+
import subprocess
|
|
1187
|
+
|
|
1188
|
+
if not job_id:
|
|
1189
|
+
# Get the most recent running job
|
|
1190
|
+
jobs = self._fetch_live_azure_jobs()
|
|
1191
|
+
running = [j for j in jobs if j["status"] == "running"]
|
|
1192
|
+
if running:
|
|
1193
|
+
job_id = running[0]["job_id"]
|
|
1194
|
+
else:
|
|
1195
|
+
return {
|
|
1196
|
+
"logs": "No running jobs found",
|
|
1197
|
+
"job_id": None,
|
|
1198
|
+
"status": "idle",
|
|
1199
|
+
}
|
|
1200
|
+
|
|
1201
|
+
# Try to stream logs for running job using az ml job stream
|
|
1202
|
+
try:
|
|
1203
|
+
result = subprocess.run(
|
|
1204
|
+
[
|
|
1205
|
+
"az",
|
|
1206
|
+
"ml",
|
|
1207
|
+
"job",
|
|
1208
|
+
"stream",
|
|
1209
|
+
"--name",
|
|
1210
|
+
job_id,
|
|
1211
|
+
"--resource-group",
|
|
1212
|
+
"openadapt-agents",
|
|
1213
|
+
"--workspace-name",
|
|
1214
|
+
"openadapt-ml",
|
|
1215
|
+
],
|
|
1216
|
+
capture_output=True,
|
|
1217
|
+
text=True,
|
|
1218
|
+
timeout=3, # Short timeout
|
|
1219
|
+
)
|
|
1220
|
+
if result.returncode == 0 and result.stdout.strip():
|
|
1221
|
+
return {
|
|
1222
|
+
"logs": result.stdout[-5000:],
|
|
1223
|
+
"job_id": job_id,
|
|
1224
|
+
"status": "streaming",
|
|
1225
|
+
}
|
|
1226
|
+
except subprocess.TimeoutExpired:
|
|
1227
|
+
pass # Fall through to job show
|
|
1228
|
+
|
|
1229
|
+
# Get job details instead
|
|
1230
|
+
result = subprocess.run(
|
|
1231
|
+
[
|
|
1232
|
+
"az",
|
|
1233
|
+
"ml",
|
|
1234
|
+
"job",
|
|
1235
|
+
"show",
|
|
1236
|
+
"--name",
|
|
1237
|
+
job_id,
|
|
1238
|
+
"--resource-group",
|
|
1239
|
+
"openadapt-agents",
|
|
1240
|
+
"--workspace-name",
|
|
1241
|
+
"openadapt-ml",
|
|
1242
|
+
"-o",
|
|
1243
|
+
"json",
|
|
1244
|
+
],
|
|
1245
|
+
capture_output=True,
|
|
1246
|
+
text=True,
|
|
1247
|
+
timeout=10,
|
|
1248
|
+
)
|
|
1249
|
+
|
|
1250
|
+
if result.returncode == 0:
|
|
1251
|
+
job_info = json.loads(result.stdout)
|
|
1252
|
+
return {
|
|
1253
|
+
"logs": f"Job {job_id} is {job_info.get('status', 'unknown')}\\n\\nCommand: {job_info.get('command', 'N/A')}",
|
|
1254
|
+
"job_id": job_id,
|
|
1255
|
+
"status": job_info.get("status", "unknown").lower(),
|
|
1256
|
+
"command": job_info.get("command", ""),
|
|
1257
|
+
}
|
|
1258
|
+
|
|
1259
|
+
return {
|
|
1260
|
+
"logs": f"Could not fetch logs: {result.stderr}",
|
|
1261
|
+
"job_id": job_id,
|
|
1262
|
+
"status": "error",
|
|
1263
|
+
}
|
|
1264
|
+
|
|
1265
|
+
def _get_vm_detailed_metadata(
|
|
1266
|
+
self, vm_ip: str, container_name: str, logs: str, phase: str
|
|
1267
|
+
) -> dict:
|
|
1268
|
+
"""Get detailed VM metadata for the VM Details panel.
|
|
1269
|
+
|
|
1270
|
+
Returns:
|
|
1271
|
+
dict with disk_usage_gb, memory_usage_mb, setup_script_phase, probe_response, qmp_connected, dependencies
|
|
1272
|
+
"""
|
|
1273
|
+
import subprocess
|
|
1274
|
+
|
|
1275
|
+
metadata = {
|
|
1276
|
+
"disk_usage_gb": None,
|
|
1277
|
+
"memory_usage_mb": None,
|
|
1278
|
+
"setup_script_phase": None,
|
|
1279
|
+
"probe_response": None,
|
|
1280
|
+
"qmp_connected": False,
|
|
1281
|
+
"dependencies": [],
|
|
1282
|
+
}
|
|
1283
|
+
|
|
1284
|
+
# 1. Get disk usage from docker stats
|
|
1285
|
+
try:
|
|
1286
|
+
disk_result = subprocess.run(
|
|
1287
|
+
[
|
|
1288
|
+
"ssh",
|
|
1289
|
+
"-o",
|
|
1290
|
+
"StrictHostKeyChecking=no",
|
|
1291
|
+
"-o",
|
|
1292
|
+
"ConnectTimeout=5",
|
|
1293
|
+
"-i",
|
|
1294
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1295
|
+
f"azureuser@{vm_ip}",
|
|
1296
|
+
f"docker exec {container_name} df -h /storage 2>/dev/null | tail -1",
|
|
1297
|
+
],
|
|
1298
|
+
capture_output=True,
|
|
1299
|
+
text=True,
|
|
1300
|
+
timeout=10,
|
|
1301
|
+
)
|
|
1302
|
+
if disk_result.returncode == 0 and disk_result.stdout.strip():
|
|
1303
|
+
# Parse: "Filesystem Size Used Avail Use% Mounted on"
|
|
1304
|
+
# Example: "/dev/sda1 30G 9.2G 20G 31% /storage"
|
|
1305
|
+
parts = disk_result.stdout.split()
|
|
1306
|
+
if len(parts) >= 3:
|
|
1307
|
+
used_str = parts[2] # e.g., "9.2G"
|
|
1308
|
+
total_str = parts[1] # e.g., "30G"
|
|
1309
|
+
|
|
1310
|
+
# Convert to GB (handle M/G suffixes)
|
|
1311
|
+
def to_gb(s):
|
|
1312
|
+
if s.endswith("G"):
|
|
1313
|
+
return float(s[:-1])
|
|
1314
|
+
elif s.endswith("M"):
|
|
1315
|
+
return float(s[:-1]) / 1024
|
|
1316
|
+
elif s.endswith("K"):
|
|
1317
|
+
return float(s[:-1]) / (1024 * 1024)
|
|
1318
|
+
return 0
|
|
1319
|
+
|
|
1320
|
+
metadata["disk_usage_gb"] = (
|
|
1321
|
+
f"{to_gb(used_str):.1f} GB / {to_gb(total_str):.0f} GB used"
|
|
1322
|
+
)
|
|
1323
|
+
except Exception:
|
|
1324
|
+
pass
|
|
1325
|
+
|
|
1326
|
+
# 2. Get memory usage from docker stats
|
|
1327
|
+
try:
|
|
1328
|
+
mem_result = subprocess.run(
|
|
1329
|
+
[
|
|
1330
|
+
"ssh",
|
|
1331
|
+
"-o",
|
|
1332
|
+
"StrictHostKeyChecking=no",
|
|
1333
|
+
"-o",
|
|
1334
|
+
"ConnectTimeout=5",
|
|
1335
|
+
"-i",
|
|
1336
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1337
|
+
f"azureuser@{vm_ip}",
|
|
1338
|
+
f"docker stats {container_name} --no-stream --format '{{{{.MemUsage}}}}'",
|
|
1339
|
+
],
|
|
1340
|
+
capture_output=True,
|
|
1341
|
+
text=True,
|
|
1342
|
+
timeout=10,
|
|
1343
|
+
)
|
|
1344
|
+
if mem_result.returncode == 0 and mem_result.stdout.strip():
|
|
1345
|
+
# Example: "1.5GiB / 4GiB"
|
|
1346
|
+
metadata["memory_usage_mb"] = mem_result.stdout.strip()
|
|
1347
|
+
except Exception:
|
|
1348
|
+
pass
|
|
1349
|
+
|
|
1350
|
+
# 3. Parse setup script phase from logs
|
|
1351
|
+
metadata["setup_script_phase"] = self._parse_setup_phase_from_logs(
|
|
1352
|
+
logs, phase
|
|
1353
|
+
)
|
|
1354
|
+
|
|
1355
|
+
# 4. Check /probe endpoint
|
|
1356
|
+
try:
|
|
1357
|
+
probe_result = subprocess.run(
|
|
1358
|
+
[
|
|
1359
|
+
"ssh",
|
|
1360
|
+
"-o",
|
|
1361
|
+
"StrictHostKeyChecking=no",
|
|
1362
|
+
"-o",
|
|
1363
|
+
"ConnectTimeout=5",
|
|
1364
|
+
"-i",
|
|
1365
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1366
|
+
f"azureuser@{vm_ip}",
|
|
1367
|
+
"curl -s --connect-timeout 2 http://20.20.20.21:5000/probe 2>/dev/null",
|
|
1368
|
+
],
|
|
1369
|
+
capture_output=True,
|
|
1370
|
+
text=True,
|
|
1371
|
+
timeout=10,
|
|
1372
|
+
)
|
|
1373
|
+
if probe_result.returncode == 0 and probe_result.stdout.strip():
|
|
1374
|
+
metadata["probe_response"] = probe_result.stdout.strip()
|
|
1375
|
+
else:
|
|
1376
|
+
metadata["probe_response"] = "Not responding"
|
|
1377
|
+
except Exception:
|
|
1378
|
+
metadata["probe_response"] = "Connection failed"
|
|
1379
|
+
|
|
1380
|
+
# 5. Check QMP connection (port 7200)
|
|
1381
|
+
try:
|
|
1382
|
+
qmp_result = subprocess.run(
|
|
1383
|
+
[
|
|
1384
|
+
"ssh",
|
|
1385
|
+
"-o",
|
|
1386
|
+
"StrictHostKeyChecking=no",
|
|
1387
|
+
"-o",
|
|
1388
|
+
"ConnectTimeout=5",
|
|
1389
|
+
"-i",
|
|
1390
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1391
|
+
f"azureuser@{vm_ip}",
|
|
1392
|
+
"nc -z -w2 localhost 7200 2>&1",
|
|
1393
|
+
],
|
|
1394
|
+
capture_output=True,
|
|
1395
|
+
text=True,
|
|
1396
|
+
timeout=10,
|
|
1397
|
+
)
|
|
1398
|
+
metadata["qmp_connected"] = qmp_result.returncode == 0
|
|
1399
|
+
except Exception:
|
|
1400
|
+
pass
|
|
1401
|
+
|
|
1402
|
+
# 6. Parse dependencies from logs
|
|
1403
|
+
metadata["dependencies"] = self._parse_dependencies_from_logs(logs, phase)
|
|
1404
|
+
|
|
1405
|
+
return metadata
|
|
1406
|
+
|
|
1407
|
+
    def _parse_setup_phase_from_logs(self, logs: str, current_phase: str) -> str:
        """Parse the current setup script phase from logs.

        Looks for patterns indicating which script is running:
        - install.bat
        - setup.ps1
        - on-logon.ps1
        """
        if current_phase == "ready":
            return "Setup complete"
        elif current_phase == "oobe":
            # Check for specific script patterns
            if "on-logon.ps1" in logs.lower():
                return "Running on-logon.ps1"
            elif "setup.ps1" in logs.lower():
                return "Running setup.ps1"
            elif "install.bat" in logs.lower():
                return "Running install.bat"
            else:
                return "Windows installation in progress"
        elif current_phase == "booting":
            return "Booting Windows"
        elif current_phase in [
            "downloading",
            "extracting",
            "configuring",
            "building",
        ]:
            return "Preparing Windows VM"
        else:
            return "Initializing..."

    def _parse_dependencies_from_logs(self, logs: str, phase: str) -> list[dict]:
        """Parse dependency installation status from logs.

        Returns list of dependencies with their installation status:
        - Python
        - Chrome
        - LibreOffice
        - VSCode
        - etc.
        """
        dependencies = [
            {"name": "Python", "icon": "🐍", "status": "pending"},
            {"name": "Chrome", "icon": "🌐", "status": "pending"},
            {"name": "LibreOffice", "icon": "📝", "status": "pending"},
            {"name": "VSCode", "icon": "💻", "status": "pending"},
            {"name": "WAA Server", "icon": "🔧", "status": "pending"},
        ]

        if phase not in ["oobe", "ready"]:
            # Not yet at Windows setup phase
            return dependencies

        logs_lower = logs.lower()

        # Check for installation patterns
        if "python" in logs_lower and (
            "installing python" in logs_lower or "python.exe" in logs_lower
        ):
            dependencies[0]["status"] = "installing"
        elif "python" in logs_lower and "installed" in logs_lower:
            dependencies[0]["status"] = "complete"

        if "chrome" in logs_lower and (
            "downloading" in logs_lower or "installing" in logs_lower
        ):
            dependencies[1]["status"] = "installing"
        elif "chrome" in logs_lower and "installed" in logs_lower:
            dependencies[1]["status"] = "complete"

        if "libreoffice" in logs_lower and (
            "downloading" in logs_lower or "installing" in logs_lower
        ):
            dependencies[2]["status"] = "installing"
        elif "libreoffice" in logs_lower and "installed" in logs_lower:
            dependencies[2]["status"] = "complete"

        if "vscode" in logs_lower or "visual studio code" in logs_lower:
            if "installing" in logs_lower:
                dependencies[3]["status"] = "installing"
            elif "installed" in logs_lower:
                dependencies[3]["status"] = "complete"

        if "waa" in logs_lower or "flask" in logs_lower:
            if "starting" in logs_lower or "running" in logs_lower:
                dependencies[4]["status"] = "installing"
            elif phase == "ready":
                dependencies[4]["status"] = "complete"

        return dependencies
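
A quick illustration of how the dependency parser behaves on a sample log excerpt; the log text below is invented for the example and real container logs will differ:

# Illustrative only: exercise the parser with a made-up log excerpt.
sample_logs = """
Installing Python 3.11 via python.exe /quiet
Downloading Chrome installer...
"""
# Inside the viewer class this would be:
#     self._parse_dependencies_from_logs(sample_logs, "oobe")
# Given the keyword matching above, Python and Chrome come back as
# "installing" and the remaining entries stay "pending".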
|
|
1498
|
+
|
|
1499
|
+
def _get_vm_diagnostics(self) -> dict:
|
|
1500
|
+
"""Get VM diagnostics: disk usage, Docker stats, memory usage.
|
|
1501
|
+
|
|
1502
|
+
Returns a dictionary with:
|
|
1503
|
+
- vm_online: bool - whether VM is reachable
|
|
1504
|
+
- disk_usage: list of disk partitions with usage stats
|
|
1505
|
+
- docker_stats: list of container stats (CPU, memory)
|
|
1506
|
+
- memory_usage: VM host memory stats
|
|
1507
|
+
- docker_system: Docker system disk usage
|
|
1508
|
+
- error: str if any error occurred
|
|
1509
|
+
"""
|
|
1510
|
+
import subprocess
|
|
1511
|
+
|
|
1512
|
+
from openadapt_ml.benchmarks.session_tracker import get_session
|
|
1513
|
+
|
|
1514
|
+
diagnostics = {
|
|
1515
|
+
"vm_online": False,
|
|
1516
|
+
"disk_usage": [],
|
|
1517
|
+
"docker_stats": [],
|
|
1518
|
+
"memory_usage": {},
|
|
1519
|
+
"docker_system": {},
|
|
1520
|
+
"docker_images": [],
|
|
1521
|
+
"error": None,
|
|
1522
|
+
}
|
|
1523
|
+
|
|
1524
|
+
# Get VM IP from session
|
|
1525
|
+
session = get_session()
|
|
1526
|
+
vm_ip = session.get("vm_ip")
|
|
1527
|
+
|
|
1528
|
+
if not vm_ip:
|
|
1529
|
+
diagnostics["error"] = (
|
|
1530
|
+
"VM IP not found in session. VM may not be running."
|
|
1531
|
+
)
|
|
1532
|
+
return diagnostics
|
|
1533
|
+
|
|
1534
|
+
# SSH options for Azure VM
|
|
1535
|
+
ssh_opts = [
|
|
1536
|
+
"-o",
|
|
1537
|
+
"StrictHostKeyChecking=no",
|
|
1538
|
+
"-o",
|
|
1539
|
+
"UserKnownHostsFile=/dev/null",
|
|
1540
|
+
"-o",
|
|
1541
|
+
"ConnectTimeout=10",
|
|
1542
|
+
"-o",
|
|
1543
|
+
"ServerAliveInterval=30",
|
|
1544
|
+
]
|
|
1545
|
+
|
|
1546
|
+
# Test VM connectivity
|
|
1547
|
+
try:
|
|
1548
|
+
test_result = subprocess.run(
|
|
1549
|
+
["ssh", *ssh_opts, f"azureuser@{vm_ip}", "echo 'online'"],
|
|
1550
|
+
capture_output=True,
|
|
1551
|
+
text=True,
|
|
1552
|
+
timeout=15,
|
|
1553
|
+
)
|
|
1554
|
+
if test_result.returncode != 0:
|
|
1555
|
+
diagnostics["error"] = f"Cannot connect to VM at {vm_ip}"
|
|
1556
|
+
return diagnostics
|
|
1557
|
+
diagnostics["vm_online"] = True
|
|
1558
|
+
except subprocess.TimeoutExpired:
|
|
1559
|
+
diagnostics["error"] = f"Connection to VM at {vm_ip} timed out"
|
|
1560
|
+
return diagnostics
|
|
1561
|
+
except Exception as e:
|
|
1562
|
+
diagnostics["error"] = f"SSH error: {str(e)}"
|
|
1563
|
+
return diagnostics
|
|
1564
|
+
|
|
1565
|
+
# 1. Disk usage (df -h)
|
|
1566
|
+
try:
|
|
1567
|
+
df_result = subprocess.run(
|
|
1568
|
+
[
|
|
1569
|
+
"ssh",
|
|
1570
|
+
*ssh_opts,
|
|
1571
|
+
f"azureuser@{vm_ip}",
|
|
1572
|
+
"df -h / /mnt 2>/dev/null | tail -n +2",
|
|
1573
|
+
],
|
|
1574
|
+
capture_output=True,
|
|
1575
|
+
text=True,
|
|
1576
|
+
timeout=15,
|
|
1577
|
+
)
|
|
1578
|
+
if df_result.returncode == 0 and df_result.stdout.strip():
|
|
1579
|
+
for line in df_result.stdout.strip().split("\n"):
|
|
1580
|
+
parts = line.split()
|
|
1581
|
+
if len(parts) >= 6:
|
|
1582
|
+
diagnostics["disk_usage"].append(
|
|
1583
|
+
{
|
|
1584
|
+
"filesystem": parts[0],
|
|
1585
|
+
"size": parts[1],
|
|
1586
|
+
"used": parts[2],
|
|
1587
|
+
"available": parts[3],
|
|
1588
|
+
"use_percent": parts[4],
|
|
1589
|
+
"mount_point": parts[5],
|
|
1590
|
+
}
|
|
1591
|
+
)
|
|
1592
|
+
except Exception as e:
|
|
1593
|
+
diagnostics["disk_usage"] = [{"error": str(e)}]
|
|
1594
|
+
|
|
1595
|
+
# 2. Docker container stats
|
|
1596
|
+
try:
|
|
1597
|
+
stats_result = subprocess.run(
|
|
1598
|
+
[
|
|
1599
|
+
"ssh",
|
|
1600
|
+
*ssh_opts,
|
|
1601
|
+
f"azureuser@{vm_ip}",
|
|
1602
|
+
"docker stats --no-stream --format '{{.Name}}|{{.CPUPerc}}|{{.MemUsage}}|{{.MemPerc}}|{{.NetIO}}|{{.BlockIO}}' 2>/dev/null || echo ''",
|
|
1603
|
+
],
|
|
1604
|
+
capture_output=True,
|
|
1605
|
+
text=True,
|
|
1606
|
+
timeout=30,
|
|
1607
|
+
)
|
|
1608
|
+
if stats_result.returncode == 0 and stats_result.stdout.strip():
|
|
1609
|
+
for line in stats_result.stdout.strip().split("\n"):
|
|
1610
|
+
if "|" in line:
|
|
1611
|
+
parts = line.split("|")
|
|
1612
|
+
if len(parts) >= 6:
|
|
1613
|
+
diagnostics["docker_stats"].append(
|
|
1614
|
+
{
|
|
1615
|
+
"container": parts[0],
|
|
1616
|
+
"cpu_percent": parts[1],
|
|
1617
|
+
"memory_usage": parts[2],
|
|
1618
|
+
"memory_percent": parts[3],
|
|
1619
|
+
"net_io": parts[4],
|
|
1620
|
+
"block_io": parts[5],
|
|
1621
|
+
}
|
|
1622
|
+
)
|
|
1623
|
+
except Exception as e:
|
|
1624
|
+
diagnostics["docker_stats"] = [{"error": str(e)}]
|
|
1625
|
+
|
|
1626
|
+
# 3. VM host memory usage (free -h)
|
|
1627
|
+
try:
|
|
1628
|
+
mem_result = subprocess.run(
|
|
1629
|
+
[
|
|
1630
|
+
"ssh",
|
|
1631
|
+
*ssh_opts,
|
|
1632
|
+
f"azureuser@{vm_ip}",
|
|
1633
|
+
"free -h | head -2 | tail -1",
|
|
1634
|
+
],
|
|
1635
|
+
capture_output=True,
|
|
1636
|
+
text=True,
|
|
1637
|
+
timeout=15,
|
|
1638
|
+
)
|
|
1639
|
+
if mem_result.returncode == 0 and mem_result.stdout.strip():
|
|
1640
|
+
parts = mem_result.stdout.strip().split()
|
|
1641
|
+
if len(parts) >= 7:
|
|
1642
|
+
diagnostics["memory_usage"] = {
|
|
1643
|
+
"total": parts[1],
|
|
1644
|
+
"used": parts[2],
|
|
1645
|
+
"free": parts[3],
|
|
1646
|
+
"shared": parts[4],
|
|
1647
|
+
"buff_cache": parts[5],
|
|
1648
|
+
"available": parts[6],
|
|
1649
|
+
}
|
|
1650
|
+
except Exception as e:
|
|
1651
|
+
diagnostics["memory_usage"] = {"error": str(e)}
|
|
1652
|
+
|
|
1653
|
+
# 4. Docker system disk usage
|
|
1654
|
+
try:
|
|
1655
|
+
docker_df_result = subprocess.run(
|
|
1656
|
+
[
|
|
1657
|
+
"ssh",
|
|
1658
|
+
*ssh_opts,
|
|
1659
|
+
f"azureuser@{vm_ip}",
|
|
1660
|
+
"docker system df 2>/dev/null || echo ''",
|
|
1661
|
+
],
|
|
1662
|
+
capture_output=True,
|
|
1663
|
+
text=True,
|
|
1664
|
+
timeout=15,
|
|
1665
|
+
)
|
|
1666
|
+
if docker_df_result.returncode == 0 and docker_df_result.stdout.strip():
|
|
1667
|
+
lines = docker_df_result.stdout.strip().split("\n")
|
|
1668
|
+
# Parse the table: TYPE, TOTAL, ACTIVE, SIZE, RECLAIMABLE
|
|
1669
|
+
for line in lines[1:]: # Skip header
|
|
1670
|
+
parts = line.split()
|
|
1671
|
+
if len(parts) >= 5:
|
|
1672
|
+
dtype = parts[0]
|
|
1673
|
+
diagnostics["docker_system"][dtype.lower()] = {
|
|
1674
|
+
"total": parts[1],
|
|
1675
|
+
"active": parts[2],
|
|
1676
|
+
"size": parts[3],
|
|
1677
|
+
"reclaimable": " ".join(parts[4:]),
|
|
1678
|
+
}
|
|
1679
|
+
except Exception as e:
|
|
1680
|
+
diagnostics["docker_system"] = {"error": str(e)}
|
|
1681
|
+
|
|
1682
|
+
# 5. Docker images
|
|
1683
|
+
try:
|
|
1684
|
+
images_result = subprocess.run(
|
|
1685
|
+
[
|
|
1686
|
+
"ssh",
|
|
1687
|
+
*ssh_opts,
|
|
1688
|
+
f"azureuser@{vm_ip}",
|
|
1689
|
+
"docker images --format '{{.Repository}}:{{.Tag}}|{{.Size}}|{{.CreatedSince}}' 2>/dev/null || echo ''",
|
|
1690
|
+
],
|
|
1691
|
+
capture_output=True,
|
|
1692
|
+
text=True,
|
|
1693
|
+
timeout=15,
|
|
1694
|
+
)
|
|
1695
|
+
if images_result.returncode == 0 and images_result.stdout.strip():
|
|
1696
|
+
for line in images_result.stdout.strip().split("\n"):
|
|
1697
|
+
if "|" in line:
|
|
1698
|
+
parts = line.split("|")
|
|
1699
|
+
if len(parts) >= 3:
|
|
1700
|
+
diagnostics["docker_images"].append(
|
|
1701
|
+
{
|
|
1702
|
+
"image": parts[0],
|
|
1703
|
+
"size": parts[1],
|
|
1704
|
+
"created": parts[2],
|
|
1705
|
+
}
|
|
1706
|
+
)
|
|
1707
|
+
except Exception as e:
|
|
1708
|
+
diagnostics["docker_images"] = [{"error": str(e)}]
|
|
1709
|
+
|
|
1710
|
+
return diagnostics
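
The diagnostics above shell out for `df -h` and split each row on whitespace. A self-contained sketch of that parsing step, runnable without SSH; the sample output is invented and real sizes will differ:

sample_df = (
    "/dev/root        124G   48G   77G  39% /\n"
    "/dev/sdb1        472G  120G  328G  27% /mnt"
)

disk_usage = []
for line in sample_df.strip().split("\n"):
    parts = line.split()
    if len(parts) >= 6:
        disk_usage.append(
            {
                "filesystem": parts[0],
                "size": parts[1],
                "used": parts[2],
                "available": parts[3],
                "use_percent": parts[4],
                "mount_point": parts[5],
            }
        )

print(disk_usage[0]["use_percent"])  # -> "39%"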
|
|
1711
|
+
|
|
1712
|
+
def _fetch_background_tasks(self):
|
|
1713
|
+
"""Fetch status of all background tasks: Azure VM, Docker containers, benchmarks."""
|
|
1714
|
+
import subprocess
|
|
1715
|
+
|
|
1716
|
+
tasks = []
|
|
1717
|
+
|
|
1718
|
+
# Check for VM IP from environment (set by CLI when auto-launching viewer)
|
|
1719
|
+
env_vm_ip = os.environ.get("WAA_VM_IP")
|
|
1720
|
+
env_internal_ip = os.environ.get("WAA_INTERNAL_IP", "172.30.0.2")
|
|
1721
|
+
|
|
1722
|
+
# 1. Check Azure WAA VM status
|
|
1723
|
+
vm_ip = None
|
|
1724
|
+
if env_vm_ip:
|
|
1725
|
+
# Use environment variable - VM IP was provided directly
|
|
1726
|
+
vm_ip = env_vm_ip
|
|
1727
|
+
tasks.append(
|
|
1728
|
+
{
|
|
1729
|
+
"task_id": "azure-vm-waa",
|
|
1730
|
+
"task_type": "vm_provision",
|
|
1731
|
+
"status": "completed",
|
|
1732
|
+
"phase": "ready", # Match status to prevent "Starting" + "completed" conflict
|
|
1733
|
+
"title": "Azure VM Host",
|
|
1734
|
+
"description": f"Linux host running at {vm_ip}",
|
|
1735
|
+
"progress_percent": 100.0,
|
|
1736
|
+
"elapsed_seconds": 0,
|
|
1737
|
+
"metadata": {
|
|
1738
|
+
"vm_name": "waa-eval-vm",
|
|
1739
|
+
"ip_address": vm_ip,
|
|
1740
|
+
"internal_ip": env_internal_ip,
|
|
1741
|
+
},
|
|
1742
|
+
}
|
|
1743
|
+
)
|
|
1744
|
+
else:
|
|
1745
|
+
# Query Azure CLI for VM status
|
|
1746
|
+
try:
|
|
1747
|
+
result = subprocess.run(
|
|
1748
|
+
[
|
|
1749
|
+
"az",
|
|
1750
|
+
"vm",
|
|
1751
|
+
"get-instance-view",
|
|
1752
|
+
"--name",
|
|
1753
|
+
"waa-eval-vm",
|
|
1754
|
+
"--resource-group",
|
|
1755
|
+
"openadapt-agents",
|
|
1756
|
+
"--query",
|
|
1757
|
+
"instanceView.statuses",
|
|
1758
|
+
"-o",
|
|
1759
|
+
"json",
|
|
1760
|
+
],
|
|
1761
|
+
capture_output=True,
|
|
1762
|
+
text=True,
|
|
1763
|
+
timeout=10,
|
|
1764
|
+
)
|
|
1765
|
+
if result.returncode == 0:
|
|
1766
|
+
statuses = json.loads(result.stdout)
|
|
1767
|
+
power_state = "unknown"
|
|
1768
|
+
for s in statuses:
|
|
1769
|
+
if s.get("code", "").startswith("PowerState/"):
|
|
1770
|
+
power_state = s["code"].replace("PowerState/", "")
|
|
1771
|
+
|
|
1772
|
+
# Get VM IP
|
|
1773
|
+
ip_result = subprocess.run(
|
|
1774
|
+
[
|
|
1775
|
+
"az",
|
|
1776
|
+
"vm",
|
|
1777
|
+
"list-ip-addresses",
|
|
1778
|
+
"--name",
|
|
1779
|
+
"waa-eval-vm",
|
|
1780
|
+
"--resource-group",
|
|
1781
|
+
"openadapt-agents",
|
|
1782
|
+
"--query",
|
|
1783
|
+
"[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
|
|
1784
|
+
"-o",
|
|
1785
|
+
"tsv",
|
|
1786
|
+
],
|
|
1787
|
+
capture_output=True,
|
|
1788
|
+
text=True,
|
|
1789
|
+
timeout=10,
|
|
1790
|
+
)
|
|
1791
|
+
vm_ip = (
|
|
1792
|
+
ip_result.stdout.strip()
|
|
1793
|
+
if ip_result.returncode == 0
|
|
1794
|
+
else None
|
|
1795
|
+
)
|
|
1796
|
+
|
|
1797
|
+
if power_state == "running":
|
|
1798
|
+
tasks.append(
|
|
1799
|
+
{
|
|
1800
|
+
"task_id": "azure-vm-waa",
|
|
1801
|
+
"task_type": "vm_provision",
|
|
1802
|
+
"status": "completed",
|
|
1803
|
+
"phase": "ready", # Match status to prevent "Starting" + "completed" conflict
|
|
1804
|
+
"title": "Azure VM Host",
|
|
1805
|
+
"description": f"Linux host running at {vm_ip}"
|
|
1806
|
+
if vm_ip
|
|
1807
|
+
else "Linux host running",
|
|
1808
|
+
"progress_percent": 100.0,
|
|
1809
|
+
"elapsed_seconds": 0,
|
|
1810
|
+
"metadata": {
|
|
1811
|
+
"vm_name": "waa-eval-vm",
|
|
1812
|
+
"ip_address": vm_ip,
|
|
1813
|
+
# No VNC link - that's for the Windows container
|
|
1814
|
+
},
|
|
1815
|
+
}
|
|
1816
|
+
)
|
|
1817
|
+
except subprocess.TimeoutExpired:
|
|
1818
|
+
pass
|
|
1819
|
+
except Exception:
|
|
1820
|
+
pass
|
|
1821
|
+
|
|
1822
|
+
# 2. Check Docker container status on VM (if we have an IP)
|
|
1823
|
+
if vm_ip:
|
|
1824
|
+
try:
|
|
1825
|
+
docker_result = subprocess.run(
|
|
1826
|
+
[
|
|
1827
|
+
"ssh",
|
|
1828
|
+
"-o",
|
|
1829
|
+
"StrictHostKeyChecking=no",
|
|
1830
|
+
"-o",
|
|
1831
|
+
"ConnectTimeout=5",
|
|
1832
|
+
"-i",
|
|
1833
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1834
|
+
f"azureuser@{vm_ip}",
|
|
1835
|
+
"docker ps --format '{{.Names}}|{{.Status}}|{{.Image}}'",
|
|
1836
|
+
],
|
|
1837
|
+
capture_output=True,
|
|
1838
|
+
text=True,
|
|
1839
|
+
timeout=15,
|
|
1840
|
+
)
|
|
1841
|
+
if docker_result.returncode == 0 and docker_result.stdout.strip():
|
|
1842
|
+
for line in docker_result.stdout.strip().split("\n"):
|
|
1843
|
+
parts = line.split("|")
|
|
1844
|
+
if len(parts) >= 3:
|
|
1845
|
+
container_name, status, image = (
|
|
1846
|
+
parts[0],
|
|
1847
|
+
parts[1],
|
|
1848
|
+
parts[2],
|
|
1849
|
+
)
|
|
1850
|
+
# Parse "Up X minutes" to determine if healthy
|
|
1851
|
+
|
|
1852
|
+
# Check for Windows VM specifically
|
|
1853
|
+
if (
|
|
1854
|
+
"windows" in image.lower()
|
|
1855
|
+
or container_name == "winarena"
|
|
1856
|
+
):
|
|
1857
|
+
# Get detailed progress from docker logs
|
|
1858
|
+
log_check = subprocess.run(
|
|
1859
|
+
[
|
|
1860
|
+
"ssh",
|
|
1861
|
+
"-o",
|
|
1862
|
+
"StrictHostKeyChecking=no",
|
|
1863
|
+
"-o",
|
|
1864
|
+
"ConnectTimeout=5",
|
|
1865
|
+
"-i",
|
|
1866
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1867
|
+
f"azureuser@{vm_ip}",
|
|
1868
|
+
f"docker logs {container_name} 2>&1 | tail -30",
|
|
1869
|
+
],
|
|
1870
|
+
capture_output=True,
|
|
1871
|
+
text=True,
|
|
1872
|
+
timeout=10,
|
|
1873
|
+
)
|
|
1874
|
+
logs = (
|
|
1875
|
+
log_check.stdout
|
|
1876
|
+
if log_check.returncode == 0
|
|
1877
|
+
else ""
|
|
1878
|
+
)
|
|
1879
|
+
|
|
1880
|
+
# Parse progress from logs
|
|
1881
|
+
phase = "unknown"
|
|
1882
|
+
progress = 0.0
|
|
1883
|
+
description = "Starting..."
|
|
1884
|
+
|
|
1885
|
+
if "Windows started successfully" in logs:
|
|
1886
|
+
# Check if WAA server is ready via Docker port forwarding
|
|
1887
|
+
# See docs/waa_network_architecture.md - always use localhost
|
|
1888
|
+
server_check = subprocess.run(
|
|
1889
|
+
[
|
|
1890
|
+
"ssh",
|
|
1891
|
+
"-o",
|
|
1892
|
+
"StrictHostKeyChecking=no",
|
|
1893
|
+
"-o",
|
|
1894
|
+
"ConnectTimeout=5",
|
|
1895
|
+
"-i",
|
|
1896
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1897
|
+
f"azureuser@{vm_ip}",
|
|
1898
|
+
"curl -s --connect-timeout 2 http://localhost:5000/probe 2>/dev/null",
|
|
1899
|
+
],
|
|
1900
|
+
capture_output=True,
|
|
1901
|
+
text=True,
|
|
1902
|
+
timeout=10,
|
|
1903
|
+
)
|
|
1904
|
+
waa_ready = (
|
|
1905
|
+
server_check.returncode == 0
|
|
1906
|
+
and "Service is operational"
|
|
1907
|
+
in server_check.stdout
|
|
1908
|
+
)
|
|
1909
|
+
if waa_ready:
|
|
1910
|
+
phase = "ready"
|
|
1911
|
+
progress = 100.0
|
|
1912
|
+
description = (
|
|
1913
|
+
"WAA Server ready - benchmarks can run"
|
|
1914
|
+
)
|
|
1915
|
+
else:
|
|
1916
|
+
phase = "oobe"
|
|
1917
|
+
progress = 80.0 # Phase 5/6 - VM install in progress
|
|
1918
|
+
description = "Phase 5/6: Windows installing (check VNC for %). OEM scripts will run after."
|
|
1919
|
+
elif "Booting Windows" in logs:
|
|
1920
|
+
phase = "booting"
|
|
1921
|
+
progress = 70.0 # Phase 4/6
|
|
1922
|
+
description = "Phase 4/6: Booting Windows from installer..."
|
|
1923
|
+
elif (
|
|
1924
|
+
"Building Windows" in logs
|
|
1925
|
+
or "Creating a" in logs
|
|
1926
|
+
):
|
|
1927
|
+
phase = "building"
|
|
1928
|
+
progress = 60.0 # Phase 3/6
|
|
1929
|
+
description = (
|
|
1930
|
+
"Phase 3/6: Building Windows VM disk..."
|
|
1931
|
+
)
|
|
1932
|
+
elif "Adding" in logs and "image" in logs:
|
|
1933
|
+
phase = "configuring"
|
|
1934
|
+
progress = 50.0 # Phase 2/6
|
|
1935
|
+
description = "Phase 2/6: Configuring Windows image with WAA scripts..."
|
|
1936
|
+
elif "Extracting" in logs:
|
|
1937
|
+
phase = "extracting"
|
|
1938
|
+
progress = 35.0 # Phase 1/6 (after download)
|
|
1939
|
+
description = (
|
|
1940
|
+
"Phase 1/6: Extracting Windows ISO..."
|
|
1941
|
+
)
|
|
1942
|
+
else:
|
|
1943
|
+
# Check for download progress (e.g., "1234K ........ 45% 80M 30s")
|
|
1944
|
+
import re
|
|
1945
|
+
|
|
1946
|
+
download_match = re.search(
|
|
1947
|
+
r"(\d+)%\s+[\d.]+[KMG]\s+(\d+)s", logs
|
|
1948
|
+
)
|
|
1949
|
+
if download_match:
|
|
1950
|
+
phase = "downloading"
|
|
1951
|
+
dl_pct = float(download_match.group(1))
|
|
1952
|
+
progress = (
|
|
1953
|
+
dl_pct * 0.30
|
|
1954
|
+
) # 0-30% for download phase
|
|
1955
|
+
eta = download_match.group(2)
|
|
1956
|
+
description = f"Phase 0/6: Downloading Windows 11... {download_match.group(1)}% ({eta}s left)"
|
|
1957
|
+
|
|
1958
|
+
# Improve phase detection - if Windows is booted but WAA not ready,
|
|
1959
|
+
# it might be at login screen waiting for OEM scripts or running install.bat
|
|
1960
|
+
if phase == "oobe" and "Boot0004" in logs:
|
|
1961
|
+
# Windows finished installing, at login/desktop
|
|
1962
|
+
# install.bat should auto-run from FirstLogonCommands (see Dockerfile)
|
|
1963
|
+
description = "Phase 5/6: Windows at desktop, OEM scripts running... (WAA server starting)"
|
|
1964
|
+
progress = 90.0
|
|
1965
|
+
|
|
1966
|
+
# Get detailed metadata for VM Details panel
|
|
1967
|
+
vm_metadata = self._get_vm_detailed_metadata(
|
|
1968
|
+
vm_ip, container_name, logs, phase
|
|
1969
|
+
)
|
|
1970
|
+
|
|
1971
|
+
tasks.append(
|
|
1972
|
+
{
|
|
1973
|
+
"task_id": f"docker-{container_name}",
|
|
1974
|
+
"task_type": "docker_container",
|
|
1975
|
+
"status": "completed"
|
|
1976
|
+
if phase == "ready"
|
|
1977
|
+
else "running",
|
|
1978
|
+
"title": "Windows 11 + WAA Server",
|
|
1979
|
+
"description": description,
|
|
1980
|
+
"progress_percent": progress,
|
|
1981
|
+
"elapsed_seconds": 0,
|
|
1982
|
+
"phase": phase,
|
|
1983
|
+
"metadata": {
|
|
1984
|
+
"container": container_name,
|
|
1985
|
+
"image": image,
|
|
1986
|
+
"status": status,
|
|
1987
|
+
"phase": phase,
|
|
1988
|
+
"windows_ready": phase
|
|
1989
|
+
in ["oobe", "ready"],
|
|
1990
|
+
"waa_server_ready": phase == "ready",
|
|
1991
|
+
# Use localhost - SSH tunnel handles routing to VM
|
|
1992
|
+
# See docs/waa_network_architecture.md
|
|
1993
|
+
"vnc_url": "http://localhost:8006",
|
|
1994
|
+
"windows_username": "Docker",
|
|
1995
|
+
"windows_password": "admin",
|
|
1996
|
+
"recent_logs": logs[-500:]
|
|
1997
|
+
if logs
|
|
1998
|
+
else "",
|
|
1999
|
+
# Enhanced VM details
|
|
2000
|
+
"disk_usage_gb": vm_metadata[
|
|
2001
|
+
"disk_usage_gb"
|
|
2002
|
+
],
|
|
2003
|
+
"memory_usage_mb": vm_metadata[
|
|
2004
|
+
"memory_usage_mb"
|
|
2005
|
+
],
|
|
2006
|
+
"setup_script_phase": vm_metadata[
|
|
2007
|
+
"setup_script_phase"
|
|
2008
|
+
],
|
|
2009
|
+
"probe_response": vm_metadata[
|
|
2010
|
+
"probe_response"
|
|
2011
|
+
],
|
|
2012
|
+
"qmp_connected": vm_metadata[
|
|
2013
|
+
"qmp_connected"
|
|
2014
|
+
],
|
|
2015
|
+
"dependencies": vm_metadata[
|
|
2016
|
+
"dependencies"
|
|
2017
|
+
],
|
|
2018
|
+
},
|
|
2019
|
+
}
|
|
2020
|
+
)
|
|
2021
|
+
except Exception:
|
|
2022
|
+
# SSH failed, VM might still be starting
|
|
2023
|
+
pass
|
|
2024
|
+
|
|
2025
|
+
# 3. Check local benchmark progress
|
|
2026
|
+
progress_file = Path("benchmark_progress.json")
|
|
2027
|
+
if progress_file.exists():
|
|
2028
|
+
try:
|
|
2029
|
+
progress = json.loads(progress_file.read_text())
|
|
2030
|
+
if progress.get("status") == "running":
|
|
2031
|
+
tasks.append(
|
|
2032
|
+
{
|
|
2033
|
+
"task_id": "benchmark-local",
|
|
2034
|
+
"task_type": "benchmark_run",
|
|
2035
|
+
"status": "running",
|
|
2036
|
+
"title": f"{progress.get('provider', 'API').upper()} Benchmark",
|
|
2037
|
+
"description": progress.get(
|
|
2038
|
+
"message", "Running benchmark..."
|
|
2039
|
+
),
|
|
2040
|
+
"progress_percent": (
|
|
2041
|
+
progress.get("tasks_complete", 0)
|
|
2042
|
+
/ max(progress.get("tasks_total", 1), 1)
|
|
2043
|
+
)
|
|
2044
|
+
* 100,
|
|
2045
|
+
"elapsed_seconds": 0,
|
|
2046
|
+
"metadata": progress,
|
|
2047
|
+
}
|
|
2048
|
+
)
|
|
2049
|
+
except Exception:
|
|
2050
|
+
pass
|
|
2051
|
+
|
|
2052
|
+
return tasks
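
The phase detection above maps container log milestones onto an overall progress percentage (download scaled into 0-30%, extracting 35%, configuring 50%, building 60%, booting 70%, OOBE 80-90%, ready 100%). A minimal standalone sketch of that mapping; the function name is illustrative and the bands are taken from the code above:

def estimate_progress(phase: str, download_pct: float = 0.0) -> float:
    """Map a setup phase (and optional download percentage) to overall progress."""
    bands = {
        "downloading": download_pct * 0.30,  # download occupies the 0-30% band
        "extracting": 35.0,
        "configuring": 50.0,
        "building": 60.0,
        "booting": 70.0,
        "oobe": 80.0,
        "ready": 100.0,
    }
    return bands.get(phase, 0.0)

print(estimate_progress("downloading", 45.0))  # -> 13.5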
|
|
2053
|
+
|
|
2054
|
+
def _fetch_vm_registry(self):
|
|
2055
|
+
"""Fetch VM registry with live status checks.
|
|
2056
|
+
|
|
2057
|
+
NOTE: We now fetch the VM IP from Azure CLI at runtime to avoid
|
|
2058
|
+
stale IP issues. The registry file is only used as a fallback.
|
|
2059
|
+
"""
|
|
2060
|
+
import subprocess
|
|
2061
|
+
from datetime import datetime
|
|
2062
|
+
|
|
2063
|
+
# Try to get VM IP from Azure CLI (always fresh)
|
|
2064
|
+
vm_ip = None
|
|
2065
|
+
resource_group = "openadapt-agents"
|
|
2066
|
+
vm_name = "azure-waa-vm"
|
|
2067
|
+
try:
|
|
2068
|
+
result = subprocess.run(
|
|
2069
|
+
[
|
|
2070
|
+
"az",
|
|
2071
|
+
"vm",
|
|
2072
|
+
"show",
|
|
2073
|
+
"-d",
|
|
2074
|
+
"-g",
|
|
2075
|
+
resource_group,
|
|
2076
|
+
"-n",
|
|
2077
|
+
vm_name,
|
|
2078
|
+
"--query",
|
|
2079
|
+
"publicIps",
|
|
2080
|
+
"-o",
|
|
2081
|
+
"tsv",
|
|
2082
|
+
],
|
|
2083
|
+
capture_output=True,
|
|
2084
|
+
text=True,
|
|
2085
|
+
timeout=10,
|
|
2086
|
+
)
|
|
2087
|
+
if result.returncode == 0 and result.stdout.strip():
|
|
2088
|
+
vm_ip = result.stdout.strip()
|
|
2089
|
+
except Exception:
|
|
2090
|
+
pass
|
|
2091
|
+
|
|
2092
|
+
# If we have a fresh IP from Azure, use it
|
|
2093
|
+
if vm_ip:
|
|
2094
|
+
vms = [
|
|
2095
|
+
{
|
|
2096
|
+
"name": vm_name,
|
|
2097
|
+
"ssh_host": vm_ip,
|
|
2098
|
+
"ssh_user": "azureuser",
|
|
2099
|
+
"vnc_port": 8006,
|
|
2100
|
+
"waa_port": 5000,
|
|
2101
|
+
"docker_container": "winarena",
|
|
2102
|
+
"internal_ip": "localhost",
|
|
2103
|
+
}
|
|
2104
|
+
]
|
|
2105
|
+
else:
|
|
2106
|
+
# Fallback to registry file
|
|
2107
|
+
project_root = Path(__file__).parent.parent.parent
|
|
2108
|
+
registry_file = project_root / "benchmark_results" / "vm_registry.json"
|
|
2109
|
+
|
|
2110
|
+
if not registry_file.exists():
|
|
2111
|
+
return []
|
|
2112
|
+
|
|
2113
|
+
try:
|
|
2114
|
+
with open(registry_file) as f:
|
|
2115
|
+
vms = json.load(f)
|
|
2116
|
+
except Exception as e:
|
|
2117
|
+
return {"error": f"Failed to read VM registry: {e}"}
|
|
2118
|
+
|
|
2119
|
+
# Check status for each VM
|
|
2120
|
+
for vm in vms:
|
|
2121
|
+
vm["status"] = "unknown"
|
|
2122
|
+
vm["last_checked"] = datetime.now().isoformat()
|
|
2123
|
+
vm["vnc_reachable"] = False
|
|
2124
|
+
vm["waa_probe_status"] = "unknown"
|
|
2125
|
+
|
|
2126
|
+
# Check VNC (HTTP HEAD request)
|
|
2127
|
+
try:
|
|
2128
|
+
vnc_url = f"http://{vm['ssh_host']}:{vm['vnc_port']}"
|
|
2129
|
+
result = subprocess.run(
|
|
2130
|
+
["curl", "-I", "-s", "--connect-timeout", "3", vnc_url],
|
|
2131
|
+
capture_output=True,
|
|
2132
|
+
text=True,
|
|
2133
|
+
timeout=5,
|
|
2134
|
+
)
|
|
2135
|
+
if result.returncode == 0 and "200" in result.stdout:
|
|
2136
|
+
vm["vnc_reachable"] = True
|
|
2137
|
+
except Exception:
|
|
2138
|
+
pass
|
|
2139
|
+
|
|
2140
|
+
# Check WAA probe via SSH
|
|
2141
|
+
# Probe WAA via localhost (Docker port forwarding handles routing)
|
|
2142
|
+
# See docs/waa_network_architecture.md for architecture details
|
|
2143
|
+
try:
|
|
2144
|
+
waa_port = vm.get("waa_port", 5000)
|
|
2145
|
+
ssh_cmd = f"curl -s --connect-timeout 2 http://localhost:{waa_port}/probe 2>/dev/null"
|
|
2146
|
+
result = subprocess.run(
|
|
2147
|
+
[
|
|
2148
|
+
"ssh",
|
|
2149
|
+
"-o",
|
|
2150
|
+
"StrictHostKeyChecking=no",
|
|
2151
|
+
"-o",
|
|
2152
|
+
"ConnectTimeout=3",
|
|
2153
|
+
"-i",
|
|
2154
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
2155
|
+
f"{vm['ssh_user']}@{vm['ssh_host']}",
|
|
2156
|
+
ssh_cmd,
|
|
2157
|
+
],
|
|
2158
|
+
capture_output=True,
|
|
2159
|
+
text=True,
|
|
2160
|
+
timeout=5,
|
|
2161
|
+
)
|
|
2162
|
+
probe_success = (
|
|
2163
|
+
result.returncode == 0
|
|
2164
|
+
and "Service is operational" in result.stdout
|
|
2165
|
+
)
|
|
2166
|
+
if probe_success:
|
|
2167
|
+
vm["waa_probe_status"] = "ready"
|
|
2168
|
+
vm["status"] = "online"
|
|
2169
|
+
# Auto-start SSH tunnels for VNC and WAA
|
|
2170
|
+
try:
|
|
2171
|
+
tunnel_mgr = get_tunnel_manager()
|
|
2172
|
+
tunnel_status = tunnel_mgr.ensure_tunnels_for_vm(
|
|
2173
|
+
vm_ip=vm["ssh_host"],
|
|
2174
|
+
ssh_user=vm.get("ssh_user", "azureuser"),
|
|
2175
|
+
)
|
|
2176
|
+
vm["tunnels"] = {
|
|
2177
|
+
name: {
|
|
2178
|
+
"active": s.active,
|
|
2179
|
+
"local_port": s.local_port,
|
|
2180
|
+
"error": s.error,
|
|
2181
|
+
}
|
|
2182
|
+
for name, s in tunnel_status.items()
|
|
2183
|
+
}
|
|
2184
|
+
except Exception as e:
|
|
2185
|
+
vm["tunnels"] = {"error": str(e)}
|
|
2186
|
+
else:
|
|
2187
|
+
vm["waa_probe_status"] = "not responding"
|
|
2188
|
+
vm["status"] = "offline"
|
|
2189
|
+
# Stop tunnels when VM goes offline
|
|
2190
|
+
try:
|
|
2191
|
+
tunnel_mgr = get_tunnel_manager()
|
|
2192
|
+
tunnel_mgr.stop_all_tunnels()
|
|
2193
|
+
vm["tunnels"] = {}
|
|
2194
|
+
except Exception:
|
|
2195
|
+
pass
|
|
2196
|
+
except Exception:
|
|
2197
|
+
vm["waa_probe_status"] = "ssh failed"
|
|
2198
|
+
vm["status"] = "offline"
|
|
2199
|
+
|
|
2200
|
+
return vms
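
The fallback path above reads `benchmark_results/vm_registry.json`. The shape below is inferred from the keys the registry loop reads; the values are placeholders, not data shipped with the package:

# Inferred registry entry shape (placeholder values).
example_registry = [
    {
        "name": "azure-waa-vm",
        "ssh_host": "<public-ip>",
        "ssh_user": "azureuser",
        "vnc_port": 8006,
        "waa_port": 5000,
        "docker_container": "winarena",
        "internal_ip": "localhost",
    }
]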
|
|
2201
|
+
|
|
2202
|
+
def _probe_vm(self) -> dict:
|
|
2203
|
+
"""Probe the Azure VM to check if WAA server is responding.
|
|
2204
|
+
|
|
2205
|
+
Returns:
|
|
2206
|
+
dict with:
|
|
2207
|
+
- responding: bool - whether the WAA server is responding
|
|
2208
|
+
- vm_ip: str - the VM's IP address
|
|
2209
|
+
- container: str - the container name
|
|
2210
|
+
- probe_result: str - the raw probe response or error message
|
|
2211
|
+
- last_checked: str - ISO timestamp
|
|
2212
|
+
"""
|
|
2213
|
+
import subprocess
|
|
2214
|
+
from datetime import datetime
|
|
2215
|
+
|
|
2216
|
+
result = {
|
|
2217
|
+
"responding": False,
|
|
2218
|
+
"vm_ip": None,
|
|
2219
|
+
"container": None,
|
|
2220
|
+
"probe_result": None,
|
|
2221
|
+
"last_checked": datetime.now().isoformat(),
|
|
2222
|
+
}
|
|
2223
|
+
|
|
2224
|
+
# First get VM IP
|
|
2225
|
+
try:
|
|
2226
|
+
ip_result = subprocess.run(
|
|
2227
|
+
[
|
|
2228
|
+
"az",
|
|
2229
|
+
"vm",
|
|
2230
|
+
"list-ip-addresses",
|
|
2231
|
+
"--name",
|
|
2232
|
+
"waa-eval-vm",
|
|
2233
|
+
"--resource-group",
|
|
2234
|
+
"openadapt-agents",
|
|
2235
|
+
"--query",
|
|
2236
|
+
"[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
|
|
2237
|
+
"-o",
|
|
2238
|
+
"tsv",
|
|
2239
|
+
],
|
|
2240
|
+
capture_output=True,
|
|
2241
|
+
text=True,
|
|
2242
|
+
timeout=10,
|
|
2243
|
+
)
|
|
2244
|
+
if ip_result.returncode == 0 and ip_result.stdout.strip():
|
|
2245
|
+
vm_ip = ip_result.stdout.strip()
|
|
2246
|
+
result["vm_ip"] = vm_ip
|
|
2247
|
+
|
|
2248
|
+
# Try to probe WAA server via SSH
|
|
2249
|
+
# Use the correct internal IP for the Windows VM inside Docker
|
|
2250
|
+
probe_result = subprocess.run(
|
|
2251
|
+
[
|
|
2252
|
+
"ssh",
|
|
2253
|
+
"-o",
|
|
2254
|
+
"StrictHostKeyChecking=no",
|
|
2255
|
+
"-o",
|
|
2256
|
+
"ConnectTimeout=5",
|
|
2257
|
+
"-i",
|
|
2258
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
2259
|
+
f"azureuser@{vm_ip}",
|
|
2260
|
+
"docker exec waa-container curl -s --connect-timeout 3 http://172.30.0.2:5000/probe 2>/dev/null || echo 'probe_failed'",
|
|
2261
|
+
],
|
|
2262
|
+
capture_output=True,
|
|
2263
|
+
text=True,
|
|
2264
|
+
timeout=15,
|
|
2265
|
+
)
|
|
2266
|
+
|
|
2267
|
+
result["container"] = "waa-container"
|
|
2268
|
+
|
|
2269
|
+
if probe_result.returncode == 0:
|
|
2270
|
+
probe_output = probe_result.stdout.strip()
|
|
2271
|
+
if probe_output and "probe_failed" not in probe_output:
|
|
2272
|
+
result["responding"] = True
|
|
2273
|
+
result["probe_result"] = probe_output
|
|
2274
|
+
else:
|
|
2275
|
+
result["probe_result"] = "WAA server not responding"
|
|
2276
|
+
else:
|
|
2277
|
+
result["probe_result"] = (
|
|
2278
|
+
f"SSH/Docker error: {probe_result.stderr[:200]}"
|
|
2279
|
+
)
|
|
2280
|
+
else:
|
|
2281
|
+
result["probe_result"] = "Could not get VM IP"
|
|
2282
|
+
except subprocess.TimeoutExpired:
|
|
2283
|
+
result["probe_result"] = "Connection timeout"
|
|
2284
|
+
except Exception as e:
|
|
2285
|
+
result["probe_result"] = f"Error: {str(e)}"
|
|
2286
|
+
|
|
2287
|
+
return result
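
Elsewhere in this module the probe body is considered healthy only when it contains "Service is operational"; `_probe_vm` itself only checks for a non-empty, non-`probe_failed` response. A tiny sketch of the stricter check, with an illustrative function name:

def is_waa_ready(probe_output: str) -> bool:
    """Treat the WAA /probe response as ready only on the known health string."""
    return bool(probe_output) and "Service is operational" in probe_output

assert is_waa_ready("Service is operational")
assert not is_waa_ready("probe_failed")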
|
|
2288
|
+
|
|
2289
|
+
def _get_current_run(self) -> dict:
|
|
2290
|
+
"""Get info about any currently running benchmark.
|
|
2291
|
+
|
|
2292
|
+
Checks:
|
|
2293
|
+
1. Local benchmark_progress.json for API benchmarks
|
|
2294
|
+
2. Azure VM for WAA benchmarks running via SSH
|
|
2295
|
+
|
|
2296
|
+
Returns:
|
|
2297
|
+
dict with:
|
|
2298
|
+
- running: bool - whether a benchmark is running
|
|
2299
|
+
- type: str - 'local' or 'azure_vm'
|
|
2300
|
+
- model: str - model being evaluated
|
|
2301
|
+
- progress: dict with tasks_completed, total_tasks
|
|
2302
|
+
- current_task: str - current task ID
|
|
2303
|
+
- started_at: str - ISO timestamp
|
|
2304
|
+
- elapsed_minutes: int
|
|
2305
|
+
"""
|
|
2306
|
+
import subprocess
|
|
2307
|
+
import re
|
|
2308
|
+
|
|
2309
|
+
result = {
|
|
2310
|
+
"running": False,
|
|
2311
|
+
"type": None,
|
|
2312
|
+
"model": None,
|
|
2313
|
+
"progress": {"tasks_completed": 0, "total_tasks": 0},
|
|
2314
|
+
"current_task": None,
|
|
2315
|
+
"started_at": None,
|
|
2316
|
+
"elapsed_minutes": 0,
|
|
2317
|
+
}
|
|
2318
|
+
|
|
2319
|
+
# Check local benchmark progress first
|
|
2320
|
+
progress_file = Path("benchmark_progress.json")
|
|
2321
|
+
if progress_file.exists():
|
|
2322
|
+
try:
|
|
2323
|
+
progress = json.loads(progress_file.read_text())
|
|
2324
|
+
if progress.get("status") == "running":
|
|
2325
|
+
result["running"] = True
|
|
2326
|
+
result["type"] = "local"
|
|
2327
|
+
result["model"] = progress.get("provider", "unknown")
|
|
2328
|
+
result["progress"]["tasks_completed"] = progress.get(
|
|
2329
|
+
"tasks_complete", 0
|
|
2330
|
+
)
|
|
2331
|
+
result["progress"]["total_tasks"] = progress.get(
|
|
2332
|
+
"tasks_total", 0
|
|
2333
|
+
)
|
|
2334
|
+
return result
|
|
2335
|
+
except Exception:
|
|
2336
|
+
pass
|
|
2337
|
+
|
|
2338
|
+
# Check Azure VM for running benchmark
|
|
2339
|
+
try:
|
|
2340
|
+
# Get VM IP
|
|
2341
|
+
ip_result = subprocess.run(
|
|
2342
|
+
[
|
|
2343
|
+
"az",
|
|
2344
|
+
"vm",
|
|
2345
|
+
"list-ip-addresses",
|
|
2346
|
+
"--name",
|
|
2347
|
+
"waa-eval-vm",
|
|
2348
|
+
"--resource-group",
|
|
2349
|
+
"openadapt-agents",
|
|
2350
|
+
"--query",
|
|
2351
|
+
"[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
|
|
2352
|
+
"-o",
|
|
2353
|
+
"tsv",
|
|
2354
|
+
],
|
|
2355
|
+
capture_output=True,
|
|
2356
|
+
text=True,
|
|
2357
|
+
timeout=10,
|
|
2358
|
+
)
|
|
2359
|
+
|
|
2360
|
+
if ip_result.returncode == 0 and ip_result.stdout.strip():
|
|
2361
|
+
vm_ip = ip_result.stdout.strip()
|
|
2362
|
+
|
|
2363
|
+
# Check if benchmark process is running
|
|
2364
|
+
process_check = subprocess.run(
|
|
2365
|
+
[
|
|
2366
|
+
"ssh",
|
|
2367
|
+
"-o",
|
|
2368
|
+
"StrictHostKeyChecking=no",
|
|
2369
|
+
"-o",
|
|
2370
|
+
"ConnectTimeout=5",
|
|
2371
|
+
"-i",
|
|
2372
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
2373
|
+
f"azureuser@{vm_ip}",
|
|
2374
|
+
"docker exec waa-container pgrep -f 'python.*run.py' 2>/dev/null && echo 'RUNNING' || echo 'NOT_RUNNING'",
|
|
2375
|
+
],
|
|
2376
|
+
capture_output=True,
|
|
2377
|
+
text=True,
|
|
2378
|
+
timeout=10,
|
|
2379
|
+
)
|
|
2380
|
+
|
|
2381
|
+
if (
|
|
2382
|
+
process_check.returncode == 0
|
|
2383
|
+
and "RUNNING" in process_check.stdout
|
|
2384
|
+
):
|
|
2385
|
+
result["running"] = True
|
|
2386
|
+
result["type"] = "azure_vm"
|
|
2387
|
+
|
|
2388
|
+
# Get log file for more details
|
|
2389
|
+
log_check = subprocess.run(
|
|
2390
|
+
[
|
|
2391
|
+
"ssh",
|
|
2392
|
+
"-o",
|
|
2393
|
+
"StrictHostKeyChecking=no",
|
|
2394
|
+
"-o",
|
|
2395
|
+
"ConnectTimeout=5",
|
|
2396
|
+
"-i",
|
|
2397
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
2398
|
+
f"azureuser@{vm_ip}",
|
|
2399
|
+
"tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
|
|
2400
|
+
],
|
|
2401
|
+
capture_output=True,
|
|
2402
|
+
text=True,
|
|
2403
|
+
timeout=10,
|
|
2404
|
+
)
|
|
2405
|
+
|
|
2406
|
+
if log_check.returncode == 0 and log_check.stdout.strip():
|
|
2407
|
+
logs = log_check.stdout
|
|
2408
|
+
|
|
2409
|
+
# Parse model from logs
|
|
2410
|
+
model_match = re.search(
|
|
2411
|
+
r"model[=:\s]+([^\s,]+)", logs, re.IGNORECASE
|
|
2412
|
+
)
|
|
2413
|
+
if model_match:
|
|
2414
|
+
result["model"] = model_match.group(1)
|
|
2415
|
+
|
|
2416
|
+
# Parse progress
|
|
2417
|
+
task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
|
|
2418
|
+
if task_match:
|
|
2419
|
+
result["progress"]["tasks_completed"] = int(
|
|
2420
|
+
task_match.group(1)
|
|
2421
|
+
)
|
|
2422
|
+
result["progress"]["total_tasks"] = int(
|
|
2423
|
+
task_match.group(2)
|
|
2424
|
+
)
|
|
2425
|
+
|
|
2426
|
+
# Parse current task
|
|
2427
|
+
task_id_match = re.search(
|
|
2428
|
+
r"(?:Running|Processing|task)[:\s]+([a-f0-9-]+)",
|
|
2429
|
+
logs,
|
|
2430
|
+
re.IGNORECASE,
|
|
2431
|
+
)
|
|
2432
|
+
if task_id_match:
|
|
2433
|
+
result["current_task"] = task_id_match.group(1)
|
|
2434
|
+
|
|
2435
|
+
except Exception:
|
|
2436
|
+
pass
|
|
2437
|
+
|
|
2438
|
+
return result
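
The progress parsing above relies on a "Task <n>/<m>" pattern in the benchmark log. A standalone example of that regex on an invented log line:

import re

line = "Task 5/30: running task 3ce1f2..."
match = re.search(r"Task\s+(\d+)/(\d+)", line)
if match:
    completed, total = int(match.group(1)), int(match.group(2))
    print(completed, total)  # -> 5 30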
|
|
2439
|
+
|
|
2440
|
+
def _get_benchmark_status(self) -> dict:
|
|
2441
|
+
"""Get current benchmark job status with ETA calculation.
|
|
2442
|
+
|
|
2443
|
+
Returns:
|
|
2444
|
+
dict with job status, progress, ETA, and current task info
|
|
2445
|
+
"""
|
|
2446
|
+
import time
|
|
2447
|
+
|
|
2448
|
+
# Check for live evaluation state
|
|
2449
|
+
live_file = Path("benchmark_live.json")
|
|
2450
|
+
if live_file.exists():
|
|
2451
|
+
try:
|
|
2452
|
+
live_state = json.loads(live_file.read_text())
|
|
2453
|
+
if live_state.get("status") == "running":
|
|
2454
|
+
total_tasks = live_state.get("total_tasks", 0)
|
|
2455
|
+
completed_tasks = live_state.get("tasks_completed", 0)
|
|
2456
|
+
current_task = live_state.get("current_task", {})
|
|
2457
|
+
|
|
2458
|
+
# Calculate ETA based on completed tasks
|
|
2459
|
+
eta_seconds = None
|
|
2460
|
+
avg_task_seconds = None
|
|
2461
|
+
if completed_tasks > 0 and total_tasks > 0:
|
|
2462
|
+
# Estimate from live state timestamp or use fallback
|
|
2463
|
+
elapsed = time.time() - live_state.get(
|
|
2464
|
+
"start_time", time.time()
|
|
2465
|
+
)
|
|
2466
|
+
avg_task_seconds = (
|
|
2467
|
+
elapsed / completed_tasks
|
|
2468
|
+
if completed_tasks > 0
|
|
2469
|
+
else 30.0
|
|
2470
|
+
)
|
|
2471
|
+
remaining_tasks = total_tasks - completed_tasks
|
|
2472
|
+
eta_seconds = remaining_tasks * avg_task_seconds
|
|
2473
|
+
|
|
2474
|
+
return {
|
|
2475
|
+
"status": "running",
|
|
2476
|
+
"current_job": {
|
|
2477
|
+
"run_id": live_state.get("run_id", "unknown"),
|
|
2478
|
+
"model_id": live_state.get("model_id", "unknown"),
|
|
2479
|
+
"total_tasks": total_tasks,
|
|
2480
|
+
"completed_tasks": completed_tasks,
|
|
2481
|
+
"current_task": current_task,
|
|
2482
|
+
"eta_seconds": eta_seconds,
|
|
2483
|
+
"avg_task_seconds": avg_task_seconds,
|
|
2484
|
+
},
|
|
2485
|
+
"queue": [], # TODO: implement queue tracking
|
|
2486
|
+
}
|
|
2487
|
+
except Exception as e:
|
|
2488
|
+
return {"status": "error", "error": str(e)}
|
|
2489
|
+
|
|
2490
|
+
# Fallback to current_run check
|
|
2491
|
+
current_run = self._get_current_run()
|
|
2492
|
+
if current_run.get("running"):
|
|
2493
|
+
return {
|
|
2494
|
+
"status": "running",
|
|
2495
|
+
"current_job": {
|
|
2496
|
+
"run_id": "unknown",
|
|
2497
|
+
"model_id": current_run.get("model", "unknown"),
|
|
2498
|
+
"total_tasks": current_run["progress"]["total_tasks"],
|
|
2499
|
+
"completed_tasks": current_run["progress"]["tasks_completed"],
|
|
2500
|
+
"current_task": {"task_id": current_run.get("current_task")},
|
|
2501
|
+
},
|
|
2502
|
+
"queue": [],
|
|
2503
|
+
}
|
|
2504
|
+
|
|
2505
|
+
return {"status": "idle"}
|
|
2506
|
+
|
|
2507
|
+
    def _get_benchmark_costs(self) -> dict:
        """Get cost breakdown for current benchmark run.

        Returns:
            dict with Azure VM, API calls, and GPU costs
        """

        # Check for cost tracking file
        cost_file = Path("benchmark_costs.json")
        if cost_file.exists():
            try:
                return json.loads(cost_file.read_text())
            except Exception:
                pass

        # Return placeholder structure
        return {
            "azure_vm": {
                "instance_type": "Standard_D4ds_v5",
                "hourly_rate_usd": 0.192,
                "hours_elapsed": 0.0,
                "cost_usd": 0.0,
            },
            "api_calls": {
                "anthropic": {"cost_usd": 0.0},
                "openai": {"cost_usd": 0.0},
            },
            "gpu_time": {
                "lambda_labs": {"cost_usd": 0.0},
            },
            "total_cost_usd": 0.0,
        }
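
The placeholder structure above leaves `hours_elapsed` and `cost_usd` at zero. A minimal sketch of how the Azure VM line item could be filled in, reusing the Standard_D4ds_v5 rate shown above; the helper name is illustrative:

def azure_vm_cost(hours_elapsed: float, hourly_rate_usd: float = 0.192) -> dict:
    """Compute the Azure VM cost line item from elapsed hours and an hourly rate."""
    return {
        "instance_type": "Standard_D4ds_v5",
        "hourly_rate_usd": hourly_rate_usd,
        "hours_elapsed": hours_elapsed,
        "cost_usd": round(hours_elapsed * hourly_rate_usd, 4),
    }

print(azure_vm_cost(3.5)["cost_usd"])  # -> 0.672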
|
|
2539
|
+
|
|
2540
|
+
def _get_benchmark_metrics(self) -> dict:
|
|
2541
|
+
"""Get performance metrics for current/completed benchmarks.
|
|
2542
|
+
|
|
2543
|
+
Returns:
|
|
2544
|
+
dict with success rate trends, domain breakdown, episode metrics
|
|
2545
|
+
"""
|
|
2546
|
+
# Check for metrics file
|
|
2547
|
+
metrics_file = Path("benchmark_metrics.json")
|
|
2548
|
+
if metrics_file.exists():
|
|
2549
|
+
try:
|
|
2550
|
+
return json.loads(metrics_file.read_text())
|
|
2551
|
+
except Exception:
|
|
2552
|
+
pass
|
|
2553
|
+
|
|
2554
|
+
# Load completed runs from benchmark_results/
|
|
2555
|
+
benchmark_results_dir = Path("benchmark_results")
|
|
2556
|
+
if not benchmark_results_dir.exists():
|
|
2557
|
+
return {"error": "No benchmark results found"}
|
|
2558
|
+
|
|
2559
|
+
# Find most recent run
|
|
2560
|
+
runs = sorted(
|
|
2561
|
+
benchmark_results_dir.iterdir(),
|
|
2562
|
+
key=lambda p: p.stat().st_mtime,
|
|
2563
|
+
reverse=True,
|
|
2564
|
+
)
|
|
2565
|
+
if not runs:
|
|
2566
|
+
return {"error": "No benchmark runs found"}
|
|
2567
|
+
|
|
2568
|
+
recent_run = runs[0]
|
|
2569
|
+
summary_path = recent_run / "summary.json"
|
|
2570
|
+
if not summary_path.exists():
|
|
2571
|
+
return {"error": f"No summary.json in {recent_run.name}"}
|
|
2572
|
+
|
|
2573
|
+
try:
|
|
2574
|
+
summary = json.loads(summary_path.read_text())
|
|
2575
|
+
|
|
2576
|
+
# Build domain breakdown from tasks
|
|
2577
|
+
domain_breakdown = {}
|
|
2578
|
+
tasks_dir = recent_run / "tasks"
|
|
2579
|
+
if tasks_dir.exists():
|
|
2580
|
+
for task_dir in tasks_dir.iterdir():
|
|
2581
|
+
if not task_dir.is_dir():
|
|
2582
|
+
continue
|
|
2583
|
+
|
|
2584
|
+
task_json = task_dir / "task.json"
|
|
2585
|
+
execution_json = task_dir / "execution.json"
|
|
2586
|
+
if not (task_json.exists() and execution_json.exists()):
|
|
2587
|
+
continue
|
|
2588
|
+
|
|
2589
|
+
try:
|
|
2590
|
+
task_def = json.loads(task_json.read_text())
|
|
2591
|
+
execution = json.loads(execution_json.read_text())
|
|
2592
|
+
|
|
2593
|
+
domain = task_def.get("domain", "unknown")
|
|
2594
|
+
if domain not in domain_breakdown:
|
|
2595
|
+
domain_breakdown[domain] = {
|
|
2596
|
+
"total": 0,
|
|
2597
|
+
"success": 0,
|
|
2598
|
+
"rate": 0.0,
|
|
2599
|
+
"avg_steps": 0.0,
|
|
2600
|
+
"total_steps": 0,
|
|
2601
|
+
}
|
|
2602
|
+
|
|
2603
|
+
domain_breakdown[domain]["total"] += 1
|
|
2604
|
+
if execution.get("success"):
|
|
2605
|
+
domain_breakdown[domain]["success"] += 1
|
|
2606
|
+
domain_breakdown[domain]["total_steps"] += execution.get(
|
|
2607
|
+
"num_steps", 0
|
|
2608
|
+
)
|
|
2609
|
+
|
|
2610
|
+
except Exception:
|
|
2611
|
+
continue
|
|
2612
|
+
|
|
2613
|
+
# Calculate averages
|
|
2614
|
+
for domain, stats in domain_breakdown.items():
|
|
2615
|
+
if stats["total"] > 0:
|
|
2616
|
+
stats["rate"] = stats["success"] / stats["total"]
|
|
2617
|
+
stats["avg_steps"] = stats["total_steps"] / stats["total"]
|
|
2618
|
+
|
|
2619
|
+
return {
|
|
2620
|
+
"success_rate_over_time": [], # TODO: implement trend tracking
|
|
2621
|
+
"avg_steps_per_task": [], # TODO: implement trend tracking
|
|
2622
|
+
"domain_breakdown": domain_breakdown,
|
|
2623
|
+
"episode_success_metrics": {
|
|
2624
|
+
"first_action_accuracy": summary.get(
|
|
2625
|
+
"first_action_accuracy", 0.0
|
|
2626
|
+
),
|
|
2627
|
+
"episode_success_rate": summary.get("success_rate", 0.0),
|
|
2628
|
+
"avg_steps_to_success": summary.get("avg_steps", 0.0),
|
|
2629
|
+
"avg_steps_to_failure": 0.0, # TODO: calculate from failed tasks
|
|
2630
|
+
},
|
|
2631
|
+
}
|
|
2632
|
+
except Exception as e:
|
|
2633
|
+
return {"error": f"Failed to load metrics: {str(e)}"}
|
|
2634
|
+
|
|
2635
|
+
    def _get_benchmark_workers(self) -> dict:
        """Get worker status and utilization.

        Returns:
            dict with total/active/idle workers and per-worker stats
        """
        # Get VM registry
        vms = self._fetch_vm_registry()

        active_workers = [v for v in vms if v.get("status") == "online"]
        idle_workers = [v for v in vms if v.get("status") != "online"]

        workers = []
        for vm in vms:
            workers.append(
                {
                    "worker_id": vm.get("name", "unknown"),
                    "status": "running" if vm.get("status") == "online" else "idle",
                    "current_task": vm.get("current_task"),
                    "tasks_completed": vm.get("tasks_completed", 0),
                    "uptime_seconds": vm.get("uptime_seconds", 0),
                    "idle_time_seconds": vm.get("idle_time_seconds", 0),
                }
            )

        return {
            "total_workers": len(vms),
            "active_workers": len(active_workers),
            "idle_workers": len(idle_workers),
            "workers": workers,
        }

    def _get_benchmark_runs(self) -> list[dict]:
        """Load all benchmark runs from benchmark_results directory.

        Returns:
            List of benchmark run summaries sorted by run name (newest first)
        """
        results_dir = Path("benchmark_results")
        if not results_dir.exists():
            return []

        runs = []
        for run_dir in results_dir.iterdir():
            if run_dir.is_dir():
                summary_file = run_dir / "summary.json"
                if summary_file.exists():
                    try:
                        summary = json.loads(summary_file.read_text())
                        runs.append(summary)
                    except (json.JSONDecodeError, IOError) as e:
                        print(f"Warning: Failed to load {summary_file}: {e}")

        # Sort by run_name descending (newest first)
        runs.sort(key=lambda r: r.get("run_name", ""), reverse=True)
        return runs
|
|
2691
|
+
|
|
2692
|
+
    def _get_task_execution(self, run_name: str, task_id: str) -> dict:
        """Load task execution details from execution.json.

        Args:
            run_name: Name of the benchmark run
            task_id: Task identifier

        Returns:
            Task execution data with steps and screenshots
        """
        results_dir = Path("benchmark_results")
        execution_file = (
            results_dir / run_name / "tasks" / task_id / "execution.json"
        )

        if not execution_file.exists():
            raise FileNotFoundError(f"Execution file not found: {execution_file}")

        try:
            return json.loads(execution_file.read_text())
        except (json.JSONDecodeError, IOError) as e:
            raise Exception(f"Failed to load execution data: {e}")
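
A small illustrative reader for the same on-disk layout; the run name and task ID below are placeholders, not shipped data:

import json
from pathlib import Path

execution_file = (
    Path("benchmark_results") / "run_20251231" / "tasks" / "task_001" / "execution.json"
)
if execution_file.exists():
    execution = json.loads(execution_file.read_text())
    print(execution.get("success"), execution.get("num_steps"))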
|
|
2714
|
+
|
|
2715
|
+
async def _detect_running_benchmark(
|
|
2716
|
+
self, vm_ip: str, container_name: str = "winarena"
|
|
2717
|
+
) -> dict:
|
|
2718
|
+
"""Detect if a benchmark is running on the VM and extract progress.
|
|
2719
|
+
|
|
2720
|
+
SSH into VM and check:
|
|
2721
|
+
1. Process running: docker exec {container} pgrep -f 'python.*run.py'
|
|
2722
|
+
2. Log progress: tail /tmp/waa_benchmark.log
|
|
2723
|
+
|
|
2724
|
+
Returns:
|
|
2725
|
+
dict with:
|
|
2726
|
+
- running: bool
|
|
2727
|
+
- current_task: str (task ID or description)
|
|
2728
|
+
- progress: dict with tasks_completed, total_tasks, current_step
|
|
2729
|
+
- recent_logs: str (last few log lines)
|
|
2730
|
+
"""
|
|
2731
|
+
import subprocess
|
|
2732
|
+
import re
|
|
2733
|
+
|
|
2734
|
+
result = {
|
|
2735
|
+
"running": False,
|
|
2736
|
+
"current_task": None,
|
|
2737
|
+
"progress": {
|
|
2738
|
+
"tasks_completed": 0,
|
|
2739
|
+
"total_tasks": 0,
|
|
2740
|
+
"current_step": 0,
|
|
2741
|
+
},
|
|
2742
|
+
"recent_logs": "",
|
|
2743
|
+
}
|
|
2744
|
+
|
|
2745
|
+
try:
|
|
2746
|
+
# Check if benchmark process is running
|
|
2747
|
+
process_check = subprocess.run(
|
|
2748
|
+
[
|
|
2749
|
+
"ssh",
|
|
2750
|
+
"-o",
|
|
2751
|
+
"StrictHostKeyChecking=no",
|
|
2752
|
+
"-o",
|
|
2753
|
+
"ConnectTimeout=5",
|
|
2754
|
+
"-i",
|
|
2755
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
2756
|
+
f"azureuser@{vm_ip}",
|
|
2757
|
+
f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''",
|
|
2758
|
+
],
|
|
2759
|
+
capture_output=True,
|
|
2760
|
+
text=True,
|
|
2761
|
+
timeout=10,
|
|
2762
|
+
)
|
|
2763
|
+
|
|
2764
|
+
if process_check.returncode == 0 and process_check.stdout.strip():
|
|
2765
|
+
result["running"] = True
|
|
2766
|
+
|
|
2767
|
+
# Get benchmark log
|
|
2768
|
+
log_check = subprocess.run(
|
|
2769
|
+
[
|
|
2770
|
+
"ssh",
|
|
2771
|
+
"-o",
|
|
2772
|
+
"StrictHostKeyChecking=no",
|
|
2773
|
+
"-o",
|
|
2774
|
+
"ConnectTimeout=5",
|
|
2775
|
+
"-i",
|
|
2776
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
2777
|
+
f"azureuser@{vm_ip}",
|
|
2778
|
+
"tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
|
|
2779
|
+
],
|
|
2780
|
+
capture_output=True,
|
|
2781
|
+
text=True,
|
|
2782
|
+
timeout=10,
|
|
2783
|
+
)
|
|
2784
|
+
|
|
2785
|
+
if log_check.returncode == 0 and log_check.stdout.strip():
|
|
2786
|
+
logs = log_check.stdout
|
|
2787
|
+
result["recent_logs"] = logs[-500:] # Last 500 chars
|
|
2788
|
+
|
|
2789
|
+
# Parse progress from logs
|
|
2790
|
+
# Look for patterns like "Task 5/30" or "Completed: 5, Remaining: 25"
|
|
2791
|
+
task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
|
|
2792
|
+
if task_match:
|
|
2793
|
+
result["progress"]["tasks_completed"] = int(
|
|
2794
|
+
task_match.group(1)
|
|
2795
|
+
)
|
|
2796
|
+
result["progress"]["total_tasks"] = int(task_match.group(2))
|
|
2797
|
+
|
|
2798
|
+
# Extract current task ID
|
|
2799
|
+
task_id_match = re.search(
|
|
2800
|
+
r"(?:Running|Processing) task:\s*(\S+)", logs
|
|
2801
|
+
)
|
|
2802
|
+
if task_id_match:
|
|
2803
|
+
result["current_task"] = task_id_match.group(1)
|
|
2804
|
+
|
|
2805
|
+
# Extract step info
|
|
2806
|
+
step_match = re.search(r"Step\s+(\d+)", logs)
|
|
2807
|
+
if step_match:
|
|
2808
|
+
result["progress"]["current_step"] = int(
|
|
2809
|
+
step_match.group(1)
|
|
2810
|
+
)
|
|
2811
|
+
|
|
2812
|
+
except Exception:
|
|
2813
|
+
# SSH or parsing failed - leave defaults
|
|
2814
|
+
pass
|
|
2815
|
+
|
|
2816
|
+
return result
|
|
2817
|
+
|
|
2818
|
+
    def _parse_task_result(self, log_lines: list[str], task_id: str) -> dict:
        """Parse task success/failure from log output.

        WAA log patterns:
        - Success: "Task task_001 completed successfully"
        - Success: "Result: PASS"
        - Failure: "Task task_001 failed"
        - Failure: "Result: FAIL"
        - Score: "Score: 0.85"
        """
        import re

        success = None
        score = None

        # Search backwards from most recent
        for line in reversed(log_lines):
            # Check for explicit result
            if "Result: PASS" in line or "completed successfully" in line:
                success = True
            elif "Result: FAIL" in line or "failed" in line.lower():
                success = False

            # Check for score
            score_match = re.search(r"Score:\s*([\d.]+)", line)
            if score_match:
                try:
                    score = float(score_match.group(1))
                except ValueError:
                    pass

            # Check for task-specific completion
            if task_id in line:
                if "success" in line.lower() or "pass" in line.lower():
                    success = True
                elif "fail" in line.lower() or "error" in line.lower():
                    success = False

        # Default to True if no explicit failure found (backwards compatible)
        if success is None:
            success = True

        return {"success": success, "score": score}
|
|
2861
|
+
|
|
2862
|
+
def _stream_benchmark_updates(self, interval: int):
|
|
2863
|
+
"""Stream Server-Sent Events for benchmark status updates.
|
|
2864
|
+
|
|
2865
|
+
Streams events:
|
|
2866
|
+
- connected: Initial connection event
|
|
2867
|
+
- status: VM status and probe results
|
|
2868
|
+
- progress: Benchmark progress (tasks completed, current task)
|
|
2869
|
+
- task_complete: When a task finishes
|
|
2870
|
+
- heartbeat: Keep-alive signal every 30 seconds
|
|
2871
|
+
- error: Error messages
|
|
2872
|
+
|
|
2873
|
+
Uses a generator-based approach to avoid blocking the main thread
|
|
2874
|
+
and properly handles client disconnection.
|
|
2875
|
+
"""
|
|
2876
|
+
import time
|
|
2877
|
+
import select
|
|
2878
|
+
|
|
2879
|
+
HEARTBEAT_INTERVAL = 30 # seconds
|
|
2880
|
+
|
|
2881
|
+
# Set SSE headers
|
|
2882
|
+
self.send_response(200)
|
|
2883
|
+
self.send_header("Content-Type", "text/event-stream")
|
|
2884
|
+
self.send_header("Cache-Control", "no-cache")
|
|
2885
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
2886
|
+
self.send_header("Connection", "keep-alive")
|
|
2887
|
+
self.send_header("X-Accel-Buffering", "no") # Disable nginx buffering
|
|
2888
|
+
self.end_headers()
|
|
2889
|
+
|
|
2890
|
+
# Track connection state
|
|
2891
|
+
client_connected = True
|
|
2892
|
+
|
|
2893
|
+
def send_event(event_type: str, data: dict) -> bool:
|
|
2894
|
+
"""Send an SSE event. Returns False if client disconnected."""
|
|
2895
|
+
nonlocal client_connected
|
|
2896
|
+
if not client_connected:
|
|
2897
|
+
return False
|
|
2898
|
+
try:
|
|
2899
|
+
event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
|
|
2900
|
+
self.wfile.write(event_str.encode("utf-8"))
|
|
2901
|
+
self.wfile.flush()
|
|
2902
|
+
return True
|
|
2903
|
+
except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
|
|
2904
|
+
# Client disconnected
|
|
2905
|
+
client_connected = False
|
|
2906
|
+
return False
|
|
2907
|
+
except Exception as e:
|
|
2908
|
+
# Other error - log and assume disconnected
|
|
2909
|
+
print(f"SSE send error: {e}")
|
|
2910
|
+
client_connected = False
|
|
2911
|
+
return False
|
|
2912
|
+
|
|
2913
|
+
def check_client_connected() -> bool:
|
|
2914
|
+
"""Check if client is still connected using socket select."""
|
|
2915
|
+
nonlocal client_connected
|
|
2916
|
+
if not client_connected:
|
|
2917
|
+
return False
|
|
2918
|
+
try:
|
|
2919
|
+
# Check if socket has data (would indicate client sent something or closed)
|
|
2920
|
+
# Use non-blocking check with 0 timeout
|
|
2921
|
+
rlist, _, xlist = select.select([self.rfile], [], [self.rfile], 0)
|
|
2922
|
+
if xlist:
|
|
2923
|
+
# Error condition on socket
|
|
2924
|
+
client_connected = False
|
|
2925
|
+
return False
|
|
2926
|
+
if rlist:
|
|
2927
|
+
# Client sent data - for SSE this usually means disconnect
|
|
2928
|
+
# (SSE is server-push only, client doesn't send data)
|
|
2929
|
+
data = self.rfile.read(1)
|
|
2930
|
+
if not data:
|
|
2931
|
+
client_connected = False
|
|
2932
|
+
return False
|
|
2933
|
+
return True
|
|
2934
|
+
except Exception:
|
|
2935
|
+
client_connected = False
|
|
2936
|
+
return False
|
|
2937
|
+
|
|
2938
|
+
last_task = None
|
|
2939
|
+
last_heartbeat = time.time()
|
|
+            recent_log_lines = []
+
+            # Send initial connected event
+            if not send_event(
+                "connected",
+                {"timestamp": time.time(), "interval": interval, "version": "1.0"},
+            ):
+                return
+
+            try:
+                iteration_count = 0
+                max_iterations = 3600 // interval  # Max 1 hour of streaming
+
+                while client_connected and iteration_count < max_iterations:
+                    iteration_count += 1
+                    current_time = time.time()
+
+                    # Check client connection before doing work
+                    if not check_client_connected():
+                        break
+
+                    # Send heartbeat every 30 seconds to prevent proxy/LB timeouts
+                    if current_time - last_heartbeat >= HEARTBEAT_INTERVAL:
+                        if not send_event("heartbeat", {"timestamp": current_time}):
+                            break
+                        last_heartbeat = current_time
+
+                    # Fetch background tasks (includes VM status)
+                    tasks = self._fetch_background_tasks()
+
+                    # Send VM status event
+                    vm_task = next(
+                        (t for t in tasks if t.get("task_type") == "docker_container"),
+                        None,
+                    )
+                    if vm_task:
+                        vm_data = {
+                            "type": "vm_status",
+                            "connected": vm_task.get("status")
+                            in ["running", "completed"],
+                            "phase": vm_task.get("phase", "unknown"),
+                            "waa_ready": vm_task.get("metadata", {}).get(
+                                "waa_server_ready", False
+                            ),
+                            "probe": {
+                                "status": vm_task.get("metadata", {}).get(
+                                    "probe_response", "unknown"
+                                ),
+                                "vnc_url": vm_task.get("metadata", {}).get("vnc_url"),
+                            },
+                        }
+
+                        if not send_event("status", vm_data):
+                            break
+
+                        # If VM is ready, check for running benchmark
+                        if vm_data["waa_ready"]:
+                            # Get VM IP from tasks
+                            vm_ip = None
+                            azure_vm = next(
+                                (
+                                    t
+                                    for t in tasks
+                                    if t.get("task_type") == "vm_provision"
+                                ),
+                                None,
+                            )
+                            if azure_vm:
+                                vm_ip = azure_vm.get("metadata", {}).get("ip_address")
+
+                            if vm_ip:
+                                # Detect running benchmark using sync version
+                                benchmark_status = self._detect_running_benchmark_sync(
+                                    vm_ip,
+                                    vm_task.get("metadata", {}).get(
+                                        "container", "winarena"
+                                    ),
+                                )
+
+                                if benchmark_status["running"]:
+                                    # Store log lines for result parsing
+                                    if benchmark_status.get("recent_logs"):
+                                        recent_log_lines = benchmark_status[
+                                            "recent_logs"
+                                        ].split("\n")
+
+                                    # Send progress event
+                                    progress_data = {
+                                        "tasks_completed": benchmark_status["progress"][
+                                            "tasks_completed"
+                                        ],
+                                        "total_tasks": benchmark_status["progress"][
+                                            "total_tasks"
+                                        ],
+                                        "current_task": benchmark_status[
+                                            "current_task"
+                                        ],
+                                        "current_step": benchmark_status["progress"][
+                                            "current_step"
+                                        ],
+                                    }
+
+                                    if not send_event("progress", progress_data):
+                                        break
+
+                                    # Check if task completed
+                                    current_task = benchmark_status["current_task"]
+                                    if current_task and current_task != last_task:
+                                        if last_task is not None:
+                                            # Previous task completed - parse result from logs
+                                            result = self._parse_task_result(
+                                                recent_log_lines, last_task
+                                            )
+                                            complete_data = {
+                                                "task_id": last_task,
+                                                "success": result["success"],
+                                                "score": result["score"],
+                                            }
+                                            if not send_event(
+                                                "task_complete", complete_data
+                                            ):
+                                                break
+
+                                        last_task = current_task
+
+                    # Check local benchmark progress file
+                    progress_file = Path("benchmark_progress.json")
+                    if progress_file.exists():
+                        try:
+                            progress = json.loads(progress_file.read_text())
+                            if progress.get("status") == "running":
+                                progress_data = {
+                                    "tasks_completed": progress.get(
+                                        "tasks_complete", 0
+                                    ),
+                                    "total_tasks": progress.get("tasks_total", 0),
+                                    "current_task": progress.get("provider", "unknown"),
+                                }
+                                if not send_event("progress", progress_data):
+                                    break
+                        except Exception:
+                            pass
+
+                    # Non-blocking sleep using select with timeout
+                    # This allows checking for client disconnect during sleep
+                    try:
+                        select.select([self.rfile], [], [], interval)
+                    except Exception:
+                        break
+
+            except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+                # Client disconnected - this is normal, don't log as error
+                pass
+            except Exception as e:
+                # Send error event if still connected
+                send_event("error", {"message": str(e)})
+            finally:
+                # Cleanup - connection is ending
+                client_connected = False
+
+        def _stream_azure_ops_updates(self):
+            """Stream Server-Sent Events for Azure operations status updates.
+
+            Monitors azure_ops_status.json for changes and streams updates.
+            Uses file modification time to detect changes efficiently.
+
+            Streams events:
+            - connected: Initial connection event
+            - status: Azure ops status update when file changes
+            - heartbeat: Keep-alive signal every 30 seconds
+            - error: Error messages
+            """
+            import time
+            import select
+            from pathlib import Path
+
+            HEARTBEAT_INTERVAL = 30  # seconds
+            CHECK_INTERVAL = 1  # Check file every second
+
+            # Set SSE headers
+            self.send_response(200)
+            self.send_header("Content-Type", "text/event-stream")
+            self.send_header("Cache-Control", "no-cache")
+            self.send_header("Access-Control-Allow-Origin", "*")
+            self.send_header("Connection", "keep-alive")
+            self.send_header("X-Accel-Buffering", "no")  # Disable nginx buffering
+            self.end_headers()
+
+            # Track connection state
+            client_connected = True
+            last_mtime = 0.0
+            last_session_mtime = 0.0
+            last_heartbeat = time.time()
+
+            def send_event(event_type: str, data: dict) -> bool:
+                """Send an SSE event. Returns False if client disconnected."""
+                nonlocal client_connected
+                if not client_connected:
+                    return False
+                try:
+                    event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+                    self.wfile.write(event_str.encode("utf-8"))
+                    self.wfile.flush()
+                    return True
+                except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+                    client_connected = False
+                    return False
+                except Exception as e:
+                    print(f"Azure ops SSE send error: {e}")
+                    client_connected = False
+                    return False
+
+            def check_client_connected() -> bool:
+                """Check if client is still connected using socket select."""
+                nonlocal client_connected
+                if not client_connected:
+                    return False
+                try:
+                    rlist, _, xlist = select.select([self.rfile], [], [self.rfile], 0)
+                    if xlist:
+                        client_connected = False
+                        return False
+                    if rlist:
+                        data = self.rfile.read(1)
+                        if not data:
+                            client_connected = False
+                            return False
+                    return True
+                except Exception:
+                    client_connected = False
+                    return False
+
+            # Status file path
+            from openadapt_ml.benchmarks.azure_ops_tracker import (
+                DEFAULT_OUTPUT_FILE,
+                read_status,
+            )
+            from openadapt_ml.benchmarks.session_tracker import (
+                get_session,
+                update_session_vm_state,
+                DEFAULT_SESSION_FILE,
+            )
+
+            status_file = Path(DEFAULT_OUTPUT_FILE)
+            session_file = Path(DEFAULT_SESSION_FILE)
+
+            def compute_server_side_values(status: dict) -> dict:
+                """Get elapsed_seconds and cost_usd from session tracker for persistence."""
+                # Get session data (persistent across refreshes)
+                session = get_session()
+
+                # Update session based on VM state if we have VM info
+                # IMPORTANT: Only pass vm_ip if it's truthy to avoid
+                # overwriting session's stable vm_ip with None
+                if status.get("vm_state") and status.get("vm_state") != "unknown":
+                    status_vm_ip = status.get("vm_ip")
+                    # Build update kwargs - only include vm_ip if present
+                    update_kwargs = {
+                        "vm_state": status["vm_state"],
+                        "vm_size": status.get("vm_size"),
+                    }
+                    if status_vm_ip:  # Only include if truthy
+                        update_kwargs["vm_ip"] = status_vm_ip
+                    session = update_session_vm_state(**update_kwargs)
+
+                # Use session's vm_ip as authoritative source
+                # This prevents IP flickering when status file has stale/None values
+                if session.get("vm_ip"):
+                    status["vm_ip"] = session["vm_ip"]
+
+                # Use session's elapsed_seconds and cost_usd for persistence
+                if (
+                    session.get("is_active")
+                    or session.get("accumulated_seconds", 0) > 0
+                ):
+                    status["elapsed_seconds"] = session.get("elapsed_seconds", 0.0)
+                    status["cost_usd"] = session.get("cost_usd", 0.0)
+                    status["started_at"] = session.get("started_at")
+                    status["session_id"] = session.get("session_id")
+                    status["session_is_active"] = session.get("is_active", False)
+                    # Include accumulated time from previous sessions for hybrid display
+                    status["accumulated_seconds"] = session.get(
+                        "accumulated_seconds", 0.0
+                    )
+                    # Calculate current session time (total - accumulated)
+                    current_session_seconds = max(
+                        0, status["elapsed_seconds"] - status["accumulated_seconds"]
+                    )
+                    status["current_session_seconds"] = current_session_seconds
+                    hourly_rate = session.get("hourly_rate_usd", 0.422)
+                    status["current_session_cost_usd"] = (
+                        current_session_seconds / 3600
+                    ) * hourly_rate
+
+                try:
+                    tunnel_mgr = get_tunnel_manager()
+                    tunnel_status = tunnel_mgr.get_tunnel_status()
+                    status["tunnels"] = {
+                        name: {
+                            "active": s.active,
+                            "local_port": s.local_port,
+                            "remote_endpoint": s.remote_endpoint,
+                            "pid": s.pid,
+                            "error": s.error,
+                        }
+                        for name, s in tunnel_status.items()
+                    }
+                except Exception as e:
+                    status["tunnels"] = {"error": str(e)}
+
+                return status
+
+            # Send initial connected event
+            if not send_event(
+                "connected",
+                {"timestamp": time.time(), "version": "1.0"},
+            ):
+                return
+
+            # Send initial status immediately
+            try:
+                status = compute_server_side_values(read_status())
+                if not send_event("status", status):
+                    return
+                if status_file.exists():
+                    last_mtime = status_file.stat().st_mtime
+            except Exception as e:
+                send_event("error", {"message": str(e)})
+
+            try:
+                iteration_count = 0
+                max_iterations = 3600  # Max 1 hour of streaming
+                last_status_send = 0.0
+                STATUS_SEND_INTERVAL = 2  # Send status every 2 seconds for live updates
+
+                while client_connected and iteration_count < max_iterations:
+                    iteration_count += 1
+                    current_time = time.time()
+
+                    # Check client connection
+                    if not check_client_connected():
+                        break
+
+                    # Send heartbeat every 30 seconds
+                    if current_time - last_heartbeat >= HEARTBEAT_INTERVAL:
+                        if not send_event("heartbeat", {"timestamp": current_time}):
+                            break
+                        last_heartbeat = current_time
+
+                    # Check if status or session file changed OR if enough time passed
+                    try:
+                        status_changed = False
+                        session_changed = False
+                        time_to_send = (
+                            current_time - last_status_send >= STATUS_SEND_INTERVAL
+                        )
+
+                        if status_file.exists():
+                            current_mtime = status_file.stat().st_mtime
+                            if current_mtime > last_mtime:
+                                status_changed = True
+                                last_mtime = current_mtime
+
+                        if session_file.exists():
+                            current_session_mtime = session_file.stat().st_mtime
+                            if current_session_mtime > last_session_mtime:
+                                session_changed = True
+                                last_session_mtime = current_session_mtime
+
+                        # Send status if file changed OR periodic timer expired
+                        # This ensures live elapsed time/cost updates even without file changes
+                        if status_changed or session_changed or time_to_send:
+                            # File changed or time to send - send update with session values
+                            status = compute_server_side_values(read_status())
+                            if not send_event("status", status):
+                                break
+                            last_status_send = current_time
+                    except Exception as e:
+                        # File access error - log but continue
+                        print(f"Azure ops SSE file check error: {e}")
+
+                    # Sleep briefly before next check
+                    try:
+                        select.select([self.rfile], [], [], CHECK_INTERVAL)
+                    except Exception:
+                        break
+
+            except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+                # Client disconnected - normal
+                pass
+            except Exception as e:
+                send_event("error", {"message": str(e)})
+            finally:
+                client_connected = False
+
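The handler above writes each update in the standard SSE wire format (`event: <type>`, `data: <json>`, blank line), so any EventSource-compatible client can consume it. A minimal client sketch follows; the host/port (`localhost:8765`) and the `/api/azure-ops/stream` path are assumptions, since the route dispatch is outside this part of the diff.

```python
"""Minimal SSE client sketch for the Azure ops status stream.

Assumed (not shown in this diff): server at localhost:8765, stream mounted
at /api/azure-ops/stream. Adjust to the route the handler actually serves.
"""

import json

import requests  # third-party: pip install requests


def follow_azure_ops(url: str = "http://localhost:8765/api/azure-ops/stream") -> None:
    # stream=True so lines arrive as the server flushes them
    with requests.get(url, stream=True, timeout=(5, None)) as resp:
        event_type = None
        for raw in resp.iter_lines(decode_unicode=True):
            if not raw:  # blank line terminates one SSE event
                event_type = None
                continue
            if raw.startswith("event: "):
                event_type = raw[len("event: "):]
            elif raw.startswith("data: ") and event_type == "status":
                status = json.loads(raw[len("data: "):])
                print(
                    status.get("vm_state"),
                    status.get("elapsed_seconds"),
                    status.get("cost_usd"),
                )


if __name__ == "__main__":
    follow_azure_ops()
```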
+        def _detect_running_benchmark_sync(
+            self, vm_ip: str, container_name: str = "winarena"
+        ) -> dict:
+            """Synchronous version of _detect_running_benchmark.
+
+            Avoids creating a new event loop on each call which causes issues
+            when called from a synchronous context.
+            """
+            import subprocess
+            import re
+
+            result = {
+                "running": False,
+                "current_task": None,
+                "progress": {
+                    "tasks_completed": 0,
+                    "total_tasks": 0,
+                    "current_step": 0,
+                },
+                "recent_logs": "",
+            }
+
+            try:
+                # Check if benchmark process is running
+                process_check = subprocess.run(
+                    [
+                        "ssh",
+                        "-o",
+                        "StrictHostKeyChecking=no",
+                        "-o",
+                        "ConnectTimeout=5",
+                        "-i",
+                        str(Path.home() / ".ssh" / "id_rsa"),
+                        f"azureuser@{vm_ip}",
+                        f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''",
+                    ],
+                    capture_output=True,
+                    text=True,
+                    timeout=10,
+                )
+
+                if process_check.returncode == 0 and process_check.stdout.strip():
+                    result["running"] = True
+
+                    # Get benchmark log
+                    log_check = subprocess.run(
+                        [
+                            "ssh",
+                            "-o",
+                            "StrictHostKeyChecking=no",
+                            "-o",
+                            "ConnectTimeout=5",
+                            "-i",
+                            str(Path.home() / ".ssh" / "id_rsa"),
+                            f"azureuser@{vm_ip}",
+                            "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
+                        ],
+                        capture_output=True,
+                        text=True,
+                        timeout=10,
+                    )
+
+                    if log_check.returncode == 0 and log_check.stdout.strip():
+                        logs = log_check.stdout
+                        result["recent_logs"] = logs[-500:]  # Last 500 chars
+
+                        # Parse progress from logs
+                        task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
+                        if task_match:
+                            result["progress"]["tasks_completed"] = int(
+                                task_match.group(1)
+                            )
+                            result["progress"]["total_tasks"] = int(task_match.group(2))
+
+                        # Extract current task ID
+                        task_id_match = re.search(
+                            r"(?:Running|Processing) task:\s*(\S+)", logs
+                        )
+                        if task_id_match:
+                            result["current_task"] = task_id_match.group(1)
+
+                        # Extract step info
+                        step_match = re.search(r"Step\s+(\d+)", logs)
+                        if step_match:
+                            result["progress"]["current_step"] = int(
+                                step_match.group(1)
+                            )
+
+            except Exception:
+                # SSH or parsing failed - leave defaults
+                pass
+
+            return result
+
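The three regular expressions above drive the progress reporting. Applied to a hypothetical log excerpt (the sample lines are invented for illustration, not actual WAA output), they extract the task counter, the current task id, and the step number:

```python
import re

# Hypothetical excerpt shaped like the lines the regexes above look for.
logs = """\
Running task: notepad_save_file
Task 3/10
Step 7: click(x=412, y=233)
"""

task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
task_id_match = re.search(r"(?:Running|Processing) task:\s*(\S+)", logs)
step_match = re.search(r"Step\s+(\d+)", logs)

print(task_match.groups())     # ('3', '10') -> tasks_completed, total_tasks
print(task_id_match.group(1))  # 'notepad_save_file'
print(step_match.group(1))     # '7'
```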
+        def _register_vm(self, vm_data):
+            """Register a new VM in the registry."""
+            # Path to VM registry file (relative to project root)
+            project_root = Path(__file__).parent.parent.parent
+            registry_file = project_root / "benchmark_results" / "vm_registry.json"
+
+            # Load existing registry
+            vms = []
+            if registry_file.exists():
+                try:
+                    with open(registry_file) as f:
+                        vms = json.load(f)
+                except Exception:
+                    pass
+
+            # Add new VM
+            new_vm = {
+                "name": vm_data.get("name", "unnamed-vm"),
+                "ssh_host": vm_data.get("ssh_host", ""),
+                "ssh_user": vm_data.get("ssh_user", "azureuser"),
+                "vnc_port": vm_data.get("vnc_port", 8006),
+                "waa_port": vm_data.get("waa_port", 5000),
+                "docker_container": vm_data.get("docker_container", "win11-waa"),
+                "internal_ip": vm_data.get("internal_ip", "20.20.20.21"),
+            }
+
+            vms.append(new_vm)
+
+            # Save registry
+            try:
+                registry_file.parent.mkdir(parents=True, exist_ok=True)
+                with open(registry_file, "w") as f:
+                    json.dump(vms, f, indent=2)
+                return {"status": "success", "vm": new_vm}
+            except Exception as e:
+                return {"status": "error", "message": str(e)}
+
+        def _start_benchmark_run(self, params: dict) -> dict:
+            """Start a benchmark run with the given parameters.
+
+            Runs the benchmark in a background thread and returns immediately.
+            Progress is tracked via benchmark_progress.json.
+
+            Expected params:
+                {
+                    "model": "gpt-4o",
+                    "num_tasks": 5,
+                    "agent": "navi",
+                    "task_selection": "all" | "domain" | "task_ids",
+                    "domain": "general",  // if task_selection == "domain"
+                    "task_ids": ["task_001", "task_015"]  // if task_selection == "task_ids"
+                }
+
+            Returns:
+                dict with status and params
+            """
+            from dotenv import load_dotenv
+
+            # Load .env file for API keys
+            project_root = Path(__file__).parent.parent.parent
+            load_dotenv(project_root / ".env")
+
+            # Build CLI command
+            cmd = [
+                "uv",
+                "run",
+                "python",
+                "-m",
+                "openadapt_ml.benchmarks.cli",
+                "vm",
+                "run-waa",
+                "--num-tasks",
+                str(params.get("num_tasks", 5)),
+                "--model",
+                params.get("model", "gpt-4o"),
+                "--agent",
+                params.get("agent", "navi"),
+                "--no-open",  # Don't open viewer (already open)
+            ]
+
+            # Add task selection args
+            task_selection = params.get("task_selection", "all")
+            if task_selection == "domain":
+                domain = params.get("domain", "general")
+                cmd.extend(["--domain", domain])
+            elif task_selection == "task_ids":
+                task_ids = params.get("task_ids", [])
+                if task_ids:
+                    cmd.extend(["--task-ids", ",".join(task_ids)])
+
+            # Create progress log file (in cwd which is serve_dir)
+            progress_file = Path("benchmark_progress.json")
+
+            # Write initial progress
+            model = params.get("model", "gpt-4o")
+            num_tasks = params.get("num_tasks", 5)
+            agent = params.get("agent", "navi")
+
+            print(
+                f"\n[Benchmark] Starting WAA benchmark: model={model}, tasks={num_tasks}, agent={agent}"
+            )
+            print(f"[Benchmark] Task selection: {task_selection}")
+            if task_selection == "domain":
+                print(f"[Benchmark] Domain: {params.get('domain', 'general')}")
+            elif task_selection == "task_ids":
+                print(f"[Benchmark] Task IDs: {params.get('task_ids', [])}")
+            print(f"[Benchmark] Command: {' '.join(cmd)}")
+
+            progress_file.write_text(
+                json.dumps(
+                    {
+                        "status": "running",
+                        "model": model,
+                        "num_tasks": num_tasks,
+                        "agent": agent,
+                        "task_selection": task_selection,
+                        "tasks_complete": 0,
+                        "message": f"Starting {model} benchmark with {num_tasks} tasks...",
+                    }
+                )
+            )
+
+            # Copy environment with loaded vars
+            env = os.environ.copy()
+
+            # Run in background thread
+            def run():
+                result = subprocess.run(
+                    cmd, capture_output=True, text=True, cwd=str(project_root), env=env
+                )
+
+                print(f"\n[Benchmark] Output:\n{result.stdout}")
+                if result.stderr:
+                    print(f"[Benchmark] Stderr: {result.stderr}")
+
+                if result.returncode == 0:
+                    print("[Benchmark] Complete. Regenerating viewer...")
+                    progress_file.write_text(
+                        json.dumps(
+                            {
+                                "status": "complete",
+                                "model": model,
+                                "num_tasks": num_tasks,
+                                "message": "Benchmark complete. Refresh to see results.",
+                            }
+                        )
+                    )
+                    # Regenerate benchmark viewer
+                    _regenerate_benchmark_viewer_if_available(serve_dir)
+                else:
+                    error_msg = (
+                        result.stderr[:200] if result.stderr else "Unknown error"
+                    )
+                    print(f"[Benchmark] Failed: {error_msg}")
+                    progress_file.write_text(
+                        json.dumps(
+                            {
+                                "status": "error",
+                                "model": model,
+                                "num_tasks": num_tasks,
+                                "message": f"Benchmark failed: {error_msg}",
+                            }
+                        )
+                    )
+
+            threading.Thread(target=run, daemon=True).start()
+
+            return {"status": "started", "params": params}
+
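`_start_benchmark_run` only builds the CLI command and the progress file; it expects a params dict shaped like its docstring shows. A request sketch follows, with port 8765 and the `/api/benchmark/run` path as assumptions, since the POST routing that reaches this method is not part of this hunk:

```python
import json
import urllib.request

# Params shaped like the "Expected params" block in the docstring above.
params = {
    "model": "gpt-4o",
    "num_tasks": 5,
    "agent": "navi",
    "task_selection": "task_ids",
    "task_ids": ["task_001", "task_015"],
}

# Hypothetical endpoint; adjust to the route the handler actually dispatches.
req = urllib.request.Request(
    "http://localhost:8765/api/benchmark/run",
    data=json.dumps(params).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req, timeout=10) as resp:
    print(json.loads(resp.read()))  # expected: {"status": "started", "params": {...}}
```

Progress can then be polled from `benchmark_progress.json` in the served directory, which the background thread rewrites on completion or failure.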
         def do_OPTIONS(self):
             # Handle CORS preflight
             self.send_response(200)
-            self.send_header(
-            self.send_header(
-            self.send_header(
+            self.send_header("Access-Control-Allow-Origin", "*")
+            self.send_header("Access-Control-Allow-Methods", "POST, OPTIONS")
+            self.send_header("Access-Control-Allow-Headers", "Content-Type")
             self.end_headers()

-
+    class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
+        allow_reuse_address = True
+        daemon_threads = True  # Don't block shutdown
+
+    with ThreadedTCPServer(("", port), StopHandler) as httpd:
         url = f"http://localhost:{port}/{start_page}"
         print(f"\nServing at: {url}")
         print(f"Directory: {serve_dir}")
@@ -611,7 +3654,40 @@ def cmd_viewer(args: argparse.Namespace) -> int:
         state.learning_rate = data.get("learning_rate", 0)
         state.losses = data.get("losses", [])
         state.status = data.get("status", "completed")
-        state.elapsed_time = data.get(
+        state.elapsed_time = data.get(
+            "elapsed_time", 0.0
+        )  # Load elapsed time for completed training
+        state.goal = data.get("goal", "")
+        state.config_path = data.get("config_path", "")
+        state.capture_path = data.get("capture_path", "")
+
+        # Load model config from training_log.json or fall back to reading config file
+        state.model_name = data.get("model_name", "")
+        state.lora_r = data.get("lora_r", 0)
+        state.lora_alpha = data.get("lora_alpha", 0)
+        state.load_in_4bit = data.get("load_in_4bit", False)
+
+        # If model config not in JSON, try to read from config file
+        if not state.model_name and state.config_path:
+            try:
+                import yaml
+
+                # Try relative to project root first, then as absolute path
+                project_root = Path(__file__).parent.parent.parent
+                config_file = project_root / state.config_path
+                if not config_file.exists():
+                    config_file = Path(state.config_path)
+                if config_file.exists():
+                    with open(config_file) as cf:
+                        cfg = yaml.safe_load(cf)
+                    if cfg and "model" in cfg:
+                        state.model_name = cfg["model"].get("name", "")
+                        state.load_in_4bit = cfg["model"].get("load_in_4bit", False)
+                    if cfg and "lora" in cfg:
+                        state.lora_r = cfg["lora"].get("r", 0)
+                        state.lora_alpha = cfg["lora"].get("lora_alpha", 0)
+            except Exception as e:
+                print(f" Warning: Could not read config file: {e}")

         config = TrainingConfig(
             num_train_epochs=data.get("total_epochs", 5),
@@ -620,14 +3696,16 @@ def cmd_viewer(args: argparse.Namespace) -> int:

         dashboard_html = generate_training_dashboard(state, config)
         (current_dir / "dashboard.html").write_text(dashboard_html)
-        print(
+        print(" Regenerated: dashboard.html")

         # Generate unified viewer using consolidated function
         viewer_path = generate_unified_viewer_from_output_dir(current_dir)
         if viewer_path:
             print(f"\nGenerated: {viewer_path}")
         else:
-            print(
+            print(
+                "\nNo comparison data found. Run comparison first or copy from capture directory."
+            )

         # Also regenerate benchmark viewer from latest benchmark results
         _regenerate_benchmark_viewer_if_available(current_dir)
@@ -647,9 +3725,9 @@ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
         print(f"Error: Benchmark directory not found: {benchmark_dir}")
         return 1

-    print(f"\n{'='*50}")
+    print(f"\n{'=' * 50}")
     print("GENERATING BENCHMARK VIEWER")
-    print(f"{'='*50}")
+    print(f"{'=' * 50}")
     print(f"Benchmark dir: {benchmark_dir}")
     print()

@@ -664,6 +3742,7 @@ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
     except Exception as e:
         print(f"Error generating benchmark viewer: {e}")
         import traceback
+
         traceback.print_exc()
         return 1

@@ -680,16 +3759,19 @@ def cmd_compare(args: argparse.Namespace) -> int:
         print(f"Error: Checkpoint not found: {checkpoint}")
         return 1

-    print(f"\n{'='*50}")
+    print(f"\n{'=' * 50}")
     print("RUNNING COMPARISON")
-    print(f"{'='*50}")
+    print(f"{'=' * 50}")
     print(f"Capture: {capture_path}")
     print(f"Checkpoint: {checkpoint or 'None (capture only)'}")
     print()

     cmd = [
-        sys.executable,
-        "
+        sys.executable,
+        "-m",
+        "openadapt_ml.scripts.compare",
+        "--capture",
+        str(capture_path),
     ]

     if checkpoint:
@@ -728,7 +3810,7 @@ Examples:

    # Run comparison
    uv run python -m openadapt_ml.cloud.local compare --capture ~/captures/my-workflow --checkpoint checkpoints/model
-    """
+    """,
     )

     subparsers = parser.add_subparsers(dest="command", help="Command")
@@ -740,9 +3822,15 @@ Examples:
     # train
     p_train = subparsers.add_parser("train", help="Run training locally")
     p_train.add_argument("--capture", required=True, help="Path to capture directory")
-    p_train.add_argument(
-
-
+    p_train.add_argument(
+        "--goal", help="Task goal (default: derived from capture name)"
+    )
+    p_train.add_argument(
+        "--config", help="Config file (default: auto-select based on device)"
+    )
+    p_train.add_argument(
+        "--open", action="store_true", help="Open dashboard in browser"
+    )
     p_train.set_defaults(func=cmd_train)

     # check
@@ -753,10 +3841,21 @@ Examples:
     p_serve = subparsers.add_parser("serve", help="Start web server for dashboard")
     p_serve.add_argument("--port", type=int, default=8765, help="Port number")
     p_serve.add_argument("--open", action="store_true", help="Open in browser")
-    p_serve.add_argument(
-
-
-    p_serve.add_argument(
+    p_serve.add_argument(
+        "--quiet", "-q", action="store_true", help="Suppress request logging"
+    )
+    p_serve.add_argument(
+        "--no-regenerate",
+        action="store_true",
+        help="Skip regenerating dashboard/viewer (serve existing files)",
+    )
+    p_serve.add_argument(
+        "--benchmark",
+        help="Serve benchmark results directory instead of training output",
+    )
+    p_serve.add_argument(
+        "--start-page", help="Override default start page (e.g., benchmark.html)"
+    )
     p_serve.set_defaults(func=cmd_serve)

     # viewer
@@ -765,9 +3864,15 @@ Examples:
     p_viewer.set_defaults(func=cmd_viewer)

     # benchmark_viewer
-    p_benchmark = subparsers.add_parser(
-
-
+    p_benchmark = subparsers.add_parser(
+        "benchmark-viewer", help="Generate benchmark viewer"
+    )
+    p_benchmark.add_argument(
+        "benchmark_dir", help="Path to benchmark results directory"
+    )
+    p_benchmark.add_argument(
+        "--open", action="store_true", help="Open viewer in browser"
+    )
     p_benchmark.set_defaults(func=cmd_benchmark_viewer)

     # compare