openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
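
The largest change in this release is openadapt_ml/cloud/local.py (shown below), which reworks the local dashboard server and adds a set of JSON endpoints under /api/ (for example /api/benchmark-progress, /api/tasks, /api/azure-jobs, and /api/azure-ops-status). As a minimal sketch of how those endpoints could be consumed, assuming the server was started with something like `python -m openadapt_ml.cloud.local serve --port 8080` (the exact serve flags and port are assumptions, not confirmed by this diff):

    # Hypothetical polling client for the dashboard API added in the diff below.
    # BASE is an assumption; use whatever host/port your serve command reports.
    import json
    import time
    from urllib.request import urlopen

    BASE = "http://localhost:8080"

    def fetch(path: str) -> dict:
        # Each /api/ endpoint in the handler's do_GET writes a JSON body.
        with urlopen(BASE + path, timeout=5) as resp:
            return json.loads(resp.read().decode("utf-8"))

    # Poll benchmark progress until the handler reports a terminal status.
    while True:
        progress = fetch("/api/benchmark-progress")
        print(progress.get("status"), progress.get("message", ""))
        if progress.get("status") in ("complete", "error", "idle"):
            break
        time.sleep(5)
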
openadapt_ml/cloud/local.py
CHANGED
|
@@ -27,7 +27,6 @@ import http.server
|
|
|
27
27
|
import json
|
|
28
28
|
import os
|
|
29
29
|
import shutil
|
|
30
|
-
import signal
|
|
31
30
|
import socketserver
|
|
32
31
|
import subprocess
|
|
33
32
|
import sys
|
|
@@ -109,7 +108,10 @@ def _is_mock_benchmark(benchmark_dir: Path) -> bool:
|
|
|
109
108
|
|
|
110
109
|
# Check for test runs (but allow waa-mock evaluations with real API models)
|
|
111
110
|
# Only filter out purely synthetic test data directories
|
|
112
|
-
if any(
|
|
111
|
+
if any(
|
|
112
|
+
term in benchmark_dir.name.lower()
|
|
113
|
+
for term in ["test_run", "test_cli", "quick_demo"]
|
|
114
|
+
):
|
|
113
115
|
return True
|
|
114
116
|
|
|
115
117
|
return False
|
|
@@ -193,6 +195,7 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
|
|
|
193
195
|
except Exception as e:
|
|
194
196
|
print(f" Could not regenerate benchmark viewer: {e}")
|
|
195
197
|
import traceback
|
|
198
|
+
|
|
196
199
|
traceback.print_exc()
|
|
197
200
|
return False
|
|
198
201
|
|
|
@@ -201,6 +204,7 @@ def detect_device() -> str:
|
|
|
201
204
|
"""Detect available compute device."""
|
|
202
205
|
try:
|
|
203
206
|
import torch
|
|
207
|
+
|
|
204
208
|
if torch.cuda.is_available():
|
|
205
209
|
device_name = torch.cuda.get_device_name(0)
|
|
206
210
|
return f"cuda ({device_name})"
|
|
@@ -251,10 +255,13 @@ def get_training_status() -> dict[str, Any]:
|
|
|
251
255
|
# Find checkpoints
|
|
252
256
|
checkpoints_dir = Path("checkpoints")
|
|
253
257
|
if checkpoints_dir.exists():
|
|
254
|
-
status["checkpoints"] = sorted(
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
+
status["checkpoints"] = sorted(
|
|
259
|
+
[
|
|
260
|
+
d.name
|
|
261
|
+
for d in checkpoints_dir.iterdir()
|
|
262
|
+
if d.is_dir() and (d / "adapter_config.json").exists()
|
|
263
|
+
]
|
|
264
|
+
)
|
|
258
265
|
|
|
259
266
|
return status
|
|
260
267
|
|
|
@@ -264,9 +271,9 @@ def cmd_status(args: argparse.Namespace) -> int:
|
|
|
264
271
|
status = get_training_status()
|
|
265
272
|
current_dir = get_current_output_dir()
|
|
266
273
|
|
|
267
|
-
print(f"\n{'='*50}")
|
|
274
|
+
print(f"\n{'=' * 50}")
|
|
268
275
|
print("LOCAL TRAINING STATUS")
|
|
269
|
-
print(f"{'='*50}")
|
|
276
|
+
print(f"{'=' * 50}")
|
|
270
277
|
print(f"Device: {status['device']}")
|
|
271
278
|
print(f"Status: {'RUNNING' if status['running'] else 'IDLE'}")
|
|
272
279
|
if status.get("job_id"):
|
|
@@ -274,7 +281,7 @@ def cmd_status(args: argparse.Namespace) -> int:
|
|
|
274
281
|
print(f"Output: {current_dir}")
|
|
275
282
|
|
|
276
283
|
if status.get("epoch"):
|
|
277
|
-
print(
|
|
284
|
+
print("\nProgress:")
|
|
278
285
|
print(f" Epoch: {status['epoch']}")
|
|
279
286
|
print(f" Step: {status['step']}")
|
|
280
287
|
if status.get("loss"):
|
|
@@ -287,7 +294,9 @@ def cmd_status(args: argparse.Namespace) -> int:
|
|
|
287
294
|
for cp in status["checkpoints"][-5:]: # Show last 5
|
|
288
295
|
print(f" - {cp}")
|
|
289
296
|
|
|
290
|
-
print(
|
|
297
|
+
print(
|
|
298
|
+
f"\nDashboard: {'✓' if status['has_dashboard'] else '✗'} {current_dir}/dashboard.html"
|
|
299
|
+
)
|
|
291
300
|
print(f"Viewer: {'✓' if status['has_viewer'] else '✗'} {current_dir}/viewer.html")
|
|
292
301
|
print()
|
|
293
302
|
|
|
@@ -320,9 +329,9 @@ def cmd_train(args: argparse.Namespace) -> int:
|
|
|
320
329
|
print(f"Error: Config not found: {config_path}")
|
|
321
330
|
return 1
|
|
322
331
|
|
|
323
|
-
print(f"\n{'='*50}")
|
|
332
|
+
print(f"\n{'=' * 50}")
|
|
324
333
|
print("STARTING LOCAL TRAINING")
|
|
325
|
-
print(f"{'='*50}")
|
|
334
|
+
print(f"{'=' * 50}")
|
|
326
335
|
print(f"Capture: {capture_path}")
|
|
327
336
|
print(f"Goal: {goal}")
|
|
328
337
|
print(f"Config: {config}")
|
|
@@ -331,10 +340,15 @@ def cmd_train(args: argparse.Namespace) -> int:
|
|
|
331
340
|
|
|
332
341
|
# Build command
|
|
333
342
|
cmd = [
|
|
334
|
-
sys.executable,
|
|
335
|
-
"
|
|
336
|
-
"
|
|
337
|
-
"--
|
|
343
|
+
sys.executable,
|
|
344
|
+
"-m",
|
|
345
|
+
"openadapt_ml.scripts.train",
|
|
346
|
+
"--config",
|
|
347
|
+
str(config_path),
|
|
348
|
+
"--capture",
|
|
349
|
+
str(capture_path),
|
|
350
|
+
"--goal",
|
|
351
|
+
goal,
|
|
338
352
|
]
|
|
339
353
|
|
|
340
354
|
if args.open:
|
|
@@ -353,14 +367,16 @@ def cmd_check(args: argparse.Namespace) -> int:
|
|
|
353
367
|
"""Check training health and early stopping analysis."""
|
|
354
368
|
status = get_training_status()
|
|
355
369
|
|
|
356
|
-
print(f"\n{'='*50}")
|
|
370
|
+
print(f"\n{'=' * 50}")
|
|
357
371
|
print("TRAINING HEALTH CHECK")
|
|
358
|
-
print(f"{'='*50}")
|
|
372
|
+
print(f"{'=' * 50}")
|
|
359
373
|
|
|
360
374
|
raw_losses = status.get("losses", [])
|
|
361
375
|
if not raw_losses:
|
|
362
376
|
print("No training data found.")
|
|
363
|
-
print(
|
|
377
|
+
print(
|
|
378
|
+
"Run training first with: uv run python -m openadapt_ml.cloud.local train --capture <path>"
|
|
379
|
+
)
|
|
364
380
|
return 1
|
|
365
381
|
|
|
366
382
|
# Extract loss values (handle both dict and float formats)
|
|
@@ -381,7 +397,7 @@ def cmd_check(args: argparse.Namespace) -> int:
|
|
|
381
397
|
min_loss = min(losses)
|
|
382
398
|
max_loss = max(losses)
|
|
383
399
|
|
|
384
|
-
print(
|
|
400
|
+
print("\nLoss progression:")
|
|
385
401
|
print(f" First: {first_loss:.4f}")
|
|
386
402
|
print(f" Last: {last_loss:.4f}")
|
|
387
403
|
print(f" Min: {min_loss:.4f}")
|
|
@@ -392,9 +408,11 @@ def cmd_check(args: argparse.Namespace) -> int:
|
|
|
392
408
|
if len(losses) >= 10:
|
|
393
409
|
recent = losses[-10:]
|
|
394
410
|
recent_avg = sum(recent) / len(recent)
|
|
395
|
-
recent_std = (
|
|
411
|
+
recent_std = (
|
|
412
|
+
sum((x - recent_avg) ** 2 for x in recent) / len(recent)
|
|
413
|
+
) ** 0.5
|
|
396
414
|
|
|
397
|
-
print(
|
|
415
|
+
print("\nRecent stability (last 10 steps):")
|
|
398
416
|
print(f" Avg loss: {recent_avg:.4f}")
|
|
399
417
|
print(f" Std dev: {recent_std:.4f}")
|
|
400
418
|
|
|
@@ -413,14 +431,18 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
413
431
|
"""Start local web server for dashboard.
|
|
414
432
|
|
|
415
433
|
Automatically regenerates dashboard and viewer before serving to ensure
|
|
416
|
-
the latest code and data are reflected.
|
|
434
|
+
the latest code and data are reflected. Also ensures the 'current' symlink
|
|
435
|
+
points to the most recent training run.
|
|
417
436
|
"""
|
|
418
|
-
from openadapt_ml.training.trainer import
|
|
437
|
+
from openadapt_ml.training.trainer import (
|
|
438
|
+
regenerate_local_dashboard,
|
|
439
|
+
update_current_symlink_to_latest,
|
|
440
|
+
)
|
|
419
441
|
|
|
420
442
|
port = args.port
|
|
421
443
|
|
|
422
444
|
# Determine what to serve: benchmark directory or training output
|
|
423
|
-
if hasattr(args,
|
|
445
|
+
if hasattr(args, "benchmark") and args.benchmark:
|
|
424
446
|
serve_dir = Path(args.benchmark).expanduser().resolve()
|
|
425
447
|
if not serve_dir.exists():
|
|
426
448
|
print(f"Error: Benchmark directory not found: {serve_dir}")
|
|
@@ -430,7 +452,10 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
430
452
|
if not args.no_regenerate:
|
|
431
453
|
print("Regenerating benchmark viewer...")
|
|
432
454
|
try:
|
|
433
|
-
from openadapt_ml.training.benchmark_viewer import
|
|
455
|
+
from openadapt_ml.training.benchmark_viewer import (
|
|
456
|
+
generate_benchmark_viewer,
|
|
457
|
+
)
|
|
458
|
+
|
|
434
459
|
generate_benchmark_viewer(serve_dir)
|
|
435
460
|
except Exception as e:
|
|
436
461
|
print(f"Warning: Could not regenerate benchmark viewer: {e}")
|
|
@@ -439,6 +464,17 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
439
464
|
else:
|
|
440
465
|
serve_dir = get_current_output_dir()
|
|
441
466
|
|
|
467
|
+
# If current symlink doesn't exist or is broken, update to latest run
|
|
468
|
+
if not serve_dir.exists() or not serve_dir.is_dir():
|
|
469
|
+
print("Updating 'current' symlink to latest training run...")
|
|
470
|
+
latest = update_current_symlink_to_latest()
|
|
471
|
+
if latest:
|
|
472
|
+
serve_dir = get_current_output_dir()
|
|
473
|
+
print(f" Updated to: {latest.name}")
|
|
474
|
+
else:
|
|
475
|
+
print(f"Error: {serve_dir} not found. Run training first.")
|
|
476
|
+
return 1
|
|
477
|
+
|
|
442
478
|
if not serve_dir.exists():
|
|
443
479
|
print(f"Error: {serve_dir} not found. Run training first.")
|
|
444
480
|
return 1
|
|
@@ -447,7 +483,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
447
483
|
if not args.no_regenerate:
|
|
448
484
|
print("Regenerating dashboard and viewer...")
|
|
449
485
|
try:
|
|
450
|
-
|
|
486
|
+
# Use keep_polling=True so JavaScript fetches live data from training_log.json
|
|
487
|
+
# This ensures the dashboard shows current data instead of stale embedded data
|
|
488
|
+
regenerate_local_dashboard(str(serve_dir), keep_polling=True)
|
|
451
489
|
# Also regenerate viewer if comparison data exists
|
|
452
490
|
_regenerate_viewer_if_possible(serve_dir)
|
|
453
491
|
except Exception as e:
|
|
@@ -456,10 +494,21 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
456
494
|
# Also regenerate benchmark viewer from latest benchmark results
|
|
457
495
|
_regenerate_benchmark_viewer_if_available(serve_dir)
|
|
458
496
|
|
|
497
|
+
# Generate Azure ops dashboard
|
|
498
|
+
try:
|
|
499
|
+
from openadapt_ml.training.azure_ops_viewer import (
|
|
500
|
+
generate_azure_ops_dashboard,
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
generate_azure_ops_dashboard(serve_dir / "azure_ops.html")
|
|
504
|
+
print(" Generated Azure ops dashboard")
|
|
505
|
+
except Exception as e:
|
|
506
|
+
print(f" Warning: Could not generate Azure ops dashboard: {e}")
|
|
507
|
+
|
|
459
508
|
start_page = "dashboard.html"
|
|
460
509
|
|
|
461
510
|
# Override start page if specified
|
|
462
|
-
if hasattr(args,
|
|
511
|
+
if hasattr(args, "start_page") and args.start_page:
|
|
463
512
|
start_page = args.start_page
|
|
464
513
|
|
|
465
514
|
# Serve from the specified directory
|
|
@@ -476,33 +525,41 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
476
525
|
super().log_message(format, *log_args)
|
|
477
526
|
|
|
478
527
|
def do_POST(self):
|
|
479
|
-
if self.path ==
|
|
528
|
+
if self.path == "/api/stop":
|
|
480
529
|
# Create stop signal file
|
|
481
530
|
stop_file = serve_dir / "STOP_TRAINING"
|
|
482
531
|
stop_file.touch()
|
|
483
532
|
self.send_response(200)
|
|
484
|
-
self.send_header(
|
|
485
|
-
self.send_header(
|
|
533
|
+
self.send_header("Content-Type", "application/json")
|
|
534
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
486
535
|
self.end_headers()
|
|
487
536
|
self.wfile.write(b'{"status": "stop_signal_created"}')
|
|
488
537
|
print(f"\n⏹ Stop signal created: {stop_file}")
|
|
489
|
-
elif self.path ==
|
|
538
|
+
elif self.path == "/api/run-benchmark":
|
|
490
539
|
# Parse request body for provider
|
|
491
|
-
content_length = int(self.headers.get(
|
|
492
|
-
body =
|
|
540
|
+
content_length = int(self.headers.get("Content-Length", 0))
|
|
541
|
+
body = (
|
|
542
|
+
self.rfile.read(content_length).decode("utf-8")
|
|
543
|
+
if content_length
|
|
544
|
+
else "{}"
|
|
545
|
+
)
|
|
493
546
|
try:
|
|
494
547
|
params = json.loads(body)
|
|
495
548
|
except json.JSONDecodeError:
|
|
496
549
|
params = {}
|
|
497
550
|
|
|
498
|
-
provider = params.get(
|
|
499
|
-
tasks = params.get(
|
|
551
|
+
provider = params.get("provider", "anthropic")
|
|
552
|
+
tasks = params.get("tasks", 5)
|
|
500
553
|
|
|
501
554
|
self.send_response(200)
|
|
502
|
-
self.send_header(
|
|
503
|
-
self.send_header(
|
|
555
|
+
self.send_header("Content-Type", "application/json")
|
|
556
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
504
557
|
self.end_headers()
|
|
505
|
-
self.wfile.write(
|
|
558
|
+
self.wfile.write(
|
|
559
|
+
json.dumps(
|
|
560
|
+
{"status": "started", "provider": provider, "tasks": tasks}
|
|
561
|
+
).encode()
|
|
562
|
+
)
|
|
506
563
|
|
|
507
564
|
# Run benchmark in background thread with progress logging
|
|
508
565
|
def run_benchmark():
|
|
@@ -516,25 +573,45 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
516
573
|
# Create progress log file (in cwd which is serve_dir)
|
|
517
574
|
progress_file = Path("benchmark_progress.json")
|
|
518
575
|
|
|
519
|
-
print(
|
|
576
|
+
print(
|
|
577
|
+
f"\n🚀 Starting {provider} benchmark evaluation ({tasks} tasks)..."
|
|
578
|
+
)
|
|
520
579
|
|
|
521
580
|
# Write initial progress
|
|
522
|
-
progress_file.write_text(
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
581
|
+
progress_file.write_text(
|
|
582
|
+
json.dumps(
|
|
583
|
+
{
|
|
584
|
+
"status": "running",
|
|
585
|
+
"provider": provider,
|
|
586
|
+
"tasks_total": tasks,
|
|
587
|
+
"tasks_complete": 0,
|
|
588
|
+
"message": f"Starting {provider} evaluation...",
|
|
589
|
+
}
|
|
590
|
+
)
|
|
591
|
+
)
|
|
529
592
|
|
|
530
593
|
# Copy environment with loaded vars
|
|
531
594
|
env = os.environ.copy()
|
|
532
595
|
|
|
533
596
|
result = subprocess.run(
|
|
534
|
-
[
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
597
|
+
[
|
|
598
|
+
"uv",
|
|
599
|
+
"run",
|
|
600
|
+
"python",
|
|
601
|
+
"-m",
|
|
602
|
+
"openadapt_ml.benchmarks.cli",
|
|
603
|
+
"run-api",
|
|
604
|
+
"--provider",
|
|
605
|
+
provider,
|
|
606
|
+
"--tasks",
|
|
607
|
+
str(tasks),
|
|
608
|
+
"--model-id",
|
|
609
|
+
f"{provider}-api",
|
|
610
|
+
],
|
|
611
|
+
capture_output=True,
|
|
612
|
+
text=True,
|
|
613
|
+
cwd=str(project_root),
|
|
614
|
+
env=env,
|
|
538
615
|
)
|
|
539
616
|
|
|
540
617
|
print(f"\n📋 Benchmark output:\n{result.stdout}")
|
|
@@ -542,77 +619,95 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
542
619
|
print(f"Stderr: {result.stderr}")
|
|
543
620
|
|
|
544
621
|
if result.returncode == 0:
|
|
545
|
-
print(
|
|
546
|
-
progress_file.write_text(
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
622
|
+
print("✅ Benchmark complete. Regenerating viewer...")
|
|
623
|
+
progress_file.write_text(
|
|
624
|
+
json.dumps(
|
|
625
|
+
{
|
|
626
|
+
"status": "complete",
|
|
627
|
+
"provider": provider,
|
|
628
|
+
"message": "Evaluation complete! Refreshing results...",
|
|
629
|
+
}
|
|
630
|
+
)
|
|
631
|
+
)
|
|
551
632
|
# Regenerate benchmark viewer
|
|
552
633
|
_regenerate_benchmark_viewer_if_available(serve_dir)
|
|
553
634
|
else:
|
|
554
635
|
print(f"❌ Benchmark failed: {result.stderr}")
|
|
555
|
-
progress_file.write_text(
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
636
|
+
progress_file.write_text(
|
|
637
|
+
json.dumps(
|
|
638
|
+
{
|
|
639
|
+
"status": "error",
|
|
640
|
+
"provider": provider,
|
|
641
|
+
"message": f"Evaluation failed: {result.stderr[:200]}",
|
|
642
|
+
}
|
|
643
|
+
)
|
|
644
|
+
)
|
|
560
645
|
|
|
561
646
|
threading.Thread(target=run_benchmark, daemon=True).start()
|
|
562
|
-
elif self.path ==
|
|
647
|
+
elif self.path == "/api/vms/register":
|
|
563
648
|
# Register a new VM
|
|
564
|
-
content_length = int(self.headers.get(
|
|
565
|
-
body =
|
|
649
|
+
content_length = int(self.headers.get("Content-Length", 0))
|
|
650
|
+
body = (
|
|
651
|
+
self.rfile.read(content_length).decode("utf-8")
|
|
652
|
+
if content_length
|
|
653
|
+
else "{}"
|
|
654
|
+
)
|
|
566
655
|
try:
|
|
567
656
|
vm_data = json.loads(body)
|
|
568
657
|
result = self._register_vm(vm_data)
|
|
569
658
|
self.send_response(200)
|
|
570
|
-
self.send_header(
|
|
571
|
-
self.send_header(
|
|
659
|
+
self.send_header("Content-Type", "application/json")
|
|
660
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
572
661
|
self.end_headers()
|
|
573
662
|
self.wfile.write(json.dumps(result).encode())
|
|
574
663
|
except Exception as e:
|
|
575
664
|
self.send_response(500)
|
|
576
|
-
self.send_header(
|
|
577
|
-
self.send_header(
|
|
665
|
+
self.send_header("Content-Type", "application/json")
|
|
666
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
578
667
|
self.end_headers()
|
|
579
668
|
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
580
|
-
elif self.path ==
|
|
669
|
+
elif self.path == "/api/benchmark/start":
|
|
581
670
|
# Start a benchmark run with configurable parameters
|
|
582
|
-
content_length = int(self.headers.get(
|
|
583
|
-
body =
|
|
671
|
+
content_length = int(self.headers.get("Content-Length", 0))
|
|
672
|
+
body = (
|
|
673
|
+
self.rfile.read(content_length).decode("utf-8")
|
|
674
|
+
if content_length
|
|
675
|
+
else "{}"
|
|
676
|
+
)
|
|
584
677
|
try:
|
|
585
678
|
params = json.loads(body)
|
|
586
679
|
result = self._start_benchmark_run(params)
|
|
587
680
|
self.send_response(200)
|
|
588
|
-
self.send_header(
|
|
589
|
-
self.send_header(
|
|
681
|
+
self.send_header("Content-Type", "application/json")
|
|
682
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
590
683
|
self.end_headers()
|
|
591
684
|
self.wfile.write(json.dumps(result).encode())
|
|
592
685
|
except Exception as e:
|
|
593
686
|
self.send_response(500)
|
|
594
|
-
self.send_header(
|
|
595
|
-
self.send_header(
|
|
687
|
+
self.send_header("Content-Type", "application/json")
|
|
688
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
596
689
|
self.end_headers()
|
|
597
690
|
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
598
691
|
else:
|
|
599
692
|
self.send_error(404, "Not found")
|
|
600
693
|
|
|
601
694
|
def do_GET(self):
|
|
602
|
-
if self.path.startswith(
|
|
695
|
+
if self.path.startswith("/api/benchmark-progress"):
|
|
603
696
|
# Return benchmark progress
|
|
604
|
-
progress_file = Path(
|
|
697
|
+
progress_file = Path(
|
|
698
|
+
"benchmark_progress.json"
|
|
699
|
+
) # Relative to serve_dir (cwd)
|
|
605
700
|
if progress_file.exists():
|
|
606
701
|
progress = progress_file.read_text()
|
|
607
702
|
else:
|
|
608
703
|
progress = json.dumps({"status": "idle"})
|
|
609
704
|
|
|
610
705
|
self.send_response(200)
|
|
611
|
-
self.send_header(
|
|
612
|
-
self.send_header(
|
|
706
|
+
self.send_header("Content-Type", "application/json")
|
|
707
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
613
708
|
self.end_headers()
|
|
614
709
|
self.wfile.write(progress.encode())
|
|
615
|
-
elif self.path.startswith(
|
|
710
|
+
elif self.path.startswith("/api/benchmark-live"):
|
|
616
711
|
# Return live evaluation state
|
|
617
712
|
live_file = Path("benchmark_live.json") # Relative to serve_dir (cwd)
|
|
618
713
|
if live_file.exists():
|
|
@@ -621,32 +716,33 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
621
716
|
live_state = json.dumps({"status": "idle"})
|
|
622
717
|
|
|
623
718
|
self.send_response(200)
|
|
624
|
-
self.send_header(
|
|
625
|
-
self.send_header(
|
|
719
|
+
self.send_header("Content-Type", "application/json")
|
|
720
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
626
721
|
self.end_headers()
|
|
627
722
|
self.wfile.write(live_state.encode())
|
|
628
|
-
elif self.path.startswith(
|
|
723
|
+
elif self.path.startswith("/api/tasks"):
|
|
629
724
|
# Return background task status (VM, Docker, benchmarks)
|
|
630
725
|
try:
|
|
631
726
|
tasks = self._fetch_background_tasks()
|
|
632
727
|
self.send_response(200)
|
|
633
|
-
self.send_header(
|
|
634
|
-
self.send_header(
|
|
728
|
+
self.send_header("Content-Type", "application/json")
|
|
729
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
635
730
|
self.end_headers()
|
|
636
731
|
self.wfile.write(json.dumps(tasks).encode())
|
|
637
732
|
except Exception as e:
|
|
638
733
|
self.send_response(500)
|
|
639
|
-
self.send_header(
|
|
640
|
-
self.send_header(
|
|
734
|
+
self.send_header("Content-Type", "application/json")
|
|
735
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
641
736
|
self.end_headers()
|
|
642
737
|
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
643
|
-
elif self.path.startswith(
|
|
738
|
+
elif self.path.startswith("/api/azure-jobs"):
|
|
644
739
|
# Return LIVE Azure job status from Azure ML
|
|
645
740
|
# Supports ?force=true parameter for manual refresh (always fetches live)
|
|
646
741
|
try:
|
|
647
742
|
from urllib.parse import urlparse, parse_qs
|
|
743
|
+
|
|
648
744
|
query = parse_qs(urlparse(self.path).query)
|
|
649
|
-
force_refresh = query.get(
|
|
745
|
+
force_refresh = query.get("force", ["false"])[0].lower() == "true"
|
|
650
746
|
|
|
651
747
|
# Always fetch live data (force just indicates manual refresh for logging)
|
|
652
748
|
if force_refresh:
|
|
@@ -654,22 +750,23 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
654
750
|
|
|
655
751
|
jobs = self._fetch_live_azure_jobs()
|
|
656
752
|
self.send_response(200)
|
|
657
|
-
self.send_header(
|
|
658
|
-
self.send_header(
|
|
753
|
+
self.send_header("Content-Type", "application/json")
|
|
754
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
659
755
|
self.end_headers()
|
|
660
756
|
self.wfile.write(json.dumps(jobs).encode())
|
|
661
757
|
except Exception as e:
|
|
662
758
|
self.send_response(500)
|
|
663
|
-
self.send_header(
|
|
664
|
-
self.send_header(
|
|
759
|
+
self.send_header("Content-Type", "application/json")
|
|
760
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
665
761
|
self.end_headers()
|
|
666
762
|
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
667
|
-
elif self.path.startswith(
|
|
763
|
+
elif self.path.startswith("/api/benchmark-sse"):
|
|
668
764
|
# Server-Sent Events endpoint for real-time benchmark updates
|
|
669
765
|
try:
|
|
670
766
|
from urllib.parse import urlparse, parse_qs
|
|
767
|
+
|
|
671
768
|
query = parse_qs(urlparse(self.path).query)
|
|
672
|
-
interval = int(query.get(
|
|
769
|
+
interval = int(query.get("interval", [5])[0])
|
|
673
770
|
|
|
674
771
|
# Validate interval (min 1s, max 60s)
|
|
675
772
|
interval = max(1, min(60, interval))
|
|
@@ -677,57 +774,60 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
677
774
|
self._stream_benchmark_updates(interval)
|
|
678
775
|
except Exception as e:
|
|
679
776
|
self.send_error(500, f"SSE error: {e}")
|
|
680
|
-
elif self.path.startswith(
|
|
777
|
+
elif self.path.startswith("/api/vms"):
|
|
681
778
|
# Return VM registry with live status
|
|
682
779
|
try:
|
|
683
780
|
vms = self._fetch_vm_registry()
|
|
684
781
|
self.send_response(200)
|
|
685
|
-
self.send_header(
|
|
686
|
-
self.send_header(
|
|
782
|
+
self.send_header("Content-Type", "application/json")
|
|
783
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
687
784
|
self.end_headers()
|
|
688
785
|
self.wfile.write(json.dumps(vms).encode())
|
|
689
786
|
except Exception as e:
|
|
690
787
|
self.send_response(500)
|
|
691
|
-
self.send_header(
|
|
692
|
-
self.send_header(
|
|
788
|
+
self.send_header("Content-Type", "application/json")
|
|
789
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
693
790
|
self.end_headers()
|
|
694
791
|
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
695
|
-
elif self.path.startswith(
|
|
792
|
+
elif self.path.startswith("/api/azure-job-logs"):
|
|
696
793
|
# Return live logs for running Azure job
|
|
697
794
|
try:
|
|
698
795
|
# Parse job_id from query string
|
|
699
796
|
from urllib.parse import urlparse, parse_qs
|
|
797
|
+
|
|
700
798
|
query = parse_qs(urlparse(self.path).query)
|
|
701
|
-
job_id = query.get(
|
|
799
|
+
job_id = query.get("job_id", [None])[0]
|
|
702
800
|
|
|
703
801
|
logs = self._fetch_azure_job_logs(job_id)
|
|
704
802
|
self.send_response(200)
|
|
705
|
-
self.send_header(
|
|
706
|
-
self.send_header(
|
|
803
|
+
self.send_header("Content-Type", "application/json")
|
|
804
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
707
805
|
self.end_headers()
|
|
708
806
|
self.wfile.write(json.dumps(logs).encode())
|
|
709
807
|
except Exception as e:
|
|
710
808
|
self.send_response(500)
|
|
711
|
-
self.send_header(
|
|
712
|
-
self.send_header(
|
|
809
|
+
self.send_header("Content-Type", "application/json")
|
|
810
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
713
811
|
self.end_headers()
|
|
714
812
|
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
715
|
-
elif self.path.startswith(
|
|
813
|
+
elif self.path.startswith("/api/probe-vm"):
|
|
716
814
|
# Probe the VM to check if WAA server is responding
|
|
717
815
|
try:
|
|
718
816
|
result = self._probe_vm()
|
|
719
817
|
self.send_response(200)
|
|
720
|
-
self.send_header(
|
|
721
|
-
self.send_header(
|
|
818
|
+
self.send_header("Content-Type", "application/json")
|
|
819
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
722
820
|
self.end_headers()
|
|
723
821
|
self.wfile.write(json.dumps(result).encode())
|
|
724
822
|
except Exception as e:
|
|
725
823
|
self.send_response(500)
|
|
726
|
-
self.send_header(
|
|
727
|
-
self.send_header(
|
|
824
|
+
self.send_header("Content-Type", "application/json")
|
|
825
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
728
826
|
self.end_headers()
|
|
729
|
-
self.wfile.write(
|
|
730
|
-
|
|
827
|
+
self.wfile.write(
|
|
828
|
+
json.dumps({"error": str(e), "responding": False}).encode()
|
|
829
|
+
)
|
|
830
|
+
elif self.path.startswith("/api/tunnels"):
|
|
731
831
|
# Return SSH tunnel status
|
|
732
832
|
try:
|
|
733
833
|
tunnel_mgr = get_tunnel_manager()
|
|
@@ -743,44 +843,293 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
743
843
|
for name, s in status.items()
|
|
744
844
|
}
|
|
745
845
|
self.send_response(200)
|
|
746
|
-
self.send_header(
|
|
747
|
-
self.send_header(
|
|
846
|
+
self.send_header("Content-Type", "application/json")
|
|
847
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
748
848
|
self.end_headers()
|
|
749
849
|
self.wfile.write(json.dumps(result).encode())
|
|
750
850
|
except Exception as e:
|
|
751
851
|
self.send_response(500)
|
|
752
|
-
self.send_header(
|
|
753
|
-
self.send_header(
|
|
852
|
+
self.send_header("Content-Type", "application/json")
|
|
853
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
754
854
|
self.end_headers()
|
|
755
855
|
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
756
|
-
elif self.path.startswith(
|
|
856
|
+
elif self.path.startswith("/api/current-run"):
|
|
757
857
|
# Return currently running benchmark info
|
|
758
858
|
try:
|
|
759
859
|
result = self._get_current_run()
|
|
760
860
|
self.send_response(200)
|
|
761
|
-
self.send_header(
|
|
762
|
-
self.send_header(
|
|
861
|
+
self.send_header("Content-Type", "application/json")
|
|
862
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
763
863
|
self.end_headers()
|
|
764
864
|
self.wfile.write(json.dumps(result).encode())
|
|
765
865
|
except Exception as e:
|
|
766
866
|
self.send_response(500)
|
|
767
|
-
self.send_header(
|
|
768
|
-
self.send_header(
|
|
867
|
+
self.send_header("Content-Type", "application/json")
|
|
868
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
769
869
|
self.end_headers()
|
|
770
|
-
self.wfile.write(
|
|
771
|
-
|
|
870
|
+
self.wfile.write(
|
|
871
|
+
json.dumps({"error": str(e), "running": False}).encode()
|
|
872
|
+
)
|
|
873
|
+
elif self.path.startswith("/api/background-tasks"):
|
|
772
874
|
# Alias for /api/tasks - background task status
|
|
773
875
|
try:
|
|
774
876
|
tasks = self._fetch_background_tasks()
|
|
775
877
|
self.send_response(200)
|
|
776
|
-
self.send_header(
|
|
777
|
-
self.send_header(
|
|
878
|
+
self.send_header("Content-Type", "application/json")
|
|
879
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
778
880
|
self.end_headers()
|
|
779
881
|
self.wfile.write(json.dumps(tasks).encode())
|
|
780
882
|
except Exception as e:
|
|
781
883
|
self.send_response(500)
|
|
782
|
-
self.send_header(
|
|
783
|
-
self.send_header(
|
|
884
|
+
self.send_header("Content-Type", "application/json")
|
|
885
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
886
|
+
self.end_headers()
|
|
887
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
888
|
+
elif self.path.startswith("/api/benchmark/status"):
|
|
889
|
+
# Return current benchmark job status with ETA
|
|
890
|
+
try:
|
|
891
|
+
status = self._get_benchmark_status()
|
|
892
|
+
self.send_response(200)
|
|
893
|
+
self.send_header("Content-Type", "application/json")
|
|
894
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
895
|
+
self.end_headers()
|
|
896
|
+
self.wfile.write(json.dumps(status).encode())
|
|
897
|
+
except Exception as e:
|
|
898
|
+
self.send_response(500)
|
|
899
|
+
self.send_header("Content-Type", "application/json")
|
|
900
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
901
|
+
self.end_headers()
|
|
902
|
+
self.wfile.write(
|
|
903
|
+
json.dumps({"error": str(e), "status": "error"}).encode()
|
|
904
|
+
)
|
|
905
|
+
elif self.path.startswith("/api/benchmark/costs"):
|
|
906
|
+
# Return cost breakdown (Azure VM, API calls, GPU)
|
|
907
|
+
try:
|
|
908
|
+
costs = self._get_benchmark_costs()
|
|
909
|
+
self.send_response(200)
|
|
910
|
+
self.send_header("Content-Type", "application/json")
|
|
911
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
912
|
+
self.end_headers()
|
|
913
|
+
self.wfile.write(json.dumps(costs).encode())
|
|
914
|
+
except Exception as e:
|
|
915
|
+
self.send_response(500)
|
|
916
|
+
self.send_header("Content-Type", "application/json")
|
|
917
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
918
|
+
self.end_headers()
|
|
919
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
920
|
+
elif self.path.startswith("/api/benchmark/metrics"):
|
|
921
|
+
# Return performance metrics (success rate, domain breakdown)
|
|
922
|
+
try:
|
|
923
|
+
metrics = self._get_benchmark_metrics()
|
|
924
|
+
self.send_response(200)
|
|
925
|
+
self.send_header("Content-Type", "application/json")
|
|
926
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
927
|
+
self.end_headers()
|
|
928
|
+
self.wfile.write(json.dumps(metrics).encode())
|
|
929
|
+
except Exception as e:
|
|
930
|
+
self.send_response(500)
|
|
931
|
+
self.send_header("Content-Type", "application/json")
|
|
932
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
933
|
+
self.end_headers()
|
|
934
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
935
|
+
elif self.path.startswith("/api/benchmark/workers"):
|
|
936
|
+
# Return worker status and utilization
|
|
937
|
+
try:
|
|
938
|
+
workers = self._get_benchmark_workers()
|
|
939
|
+
self.send_response(200)
|
|
940
|
+
self.send_header("Content-Type", "application/json")
|
|
941
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
942
|
+
self.end_headers()
|
|
943
|
+
self.wfile.write(json.dumps(workers).encode())
|
|
944
|
+
except Exception as e:
|
|
945
|
+
self.send_response(500)
|
|
946
|
+
self.send_header("Content-Type", "application/json")
|
|
947
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
948
|
+
self.end_headers()
|
|
949
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
950
|
+
elif self.path.startswith("/api/benchmark/runs"):
|
|
951
|
+
# Return list of all benchmark runs
|
|
952
|
+
try:
|
|
953
|
+
runs = self._get_benchmark_runs()
|
|
954
|
+
self.send_response(200)
|
|
955
|
+
self.send_header("Content-Type", "application/json")
|
|
956
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
957
|
+
self.end_headers()
|
|
958
|
+
self.wfile.write(json.dumps(runs).encode())
|
|
959
|
+
except Exception as e:
|
|
960
|
+
self.send_response(500)
|
|
961
|
+
self.send_header("Content-Type", "application/json")
|
|
962
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
963
|
+
self.end_headers()
|
|
964
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
965
|
+
elif self.path.startswith("/api/benchmark/tasks/"):
|
|
966
|
+
# Return task execution details
|
|
967
|
+
# URL format: /api/benchmark/tasks/{run_name}/{task_id}
|
|
968
|
+
try:
|
|
969
|
+
parts = self.path.split("/")
|
|
970
|
+
if len(parts) >= 6:
|
|
971
|
+
run_name = parts[4]
|
|
972
|
+
task_id = parts[5]
|
|
973
|
+
execution = self._get_task_execution(run_name, task_id)
|
|
974
|
+
self.send_response(200)
|
|
975
|
+
self.send_header("Content-Type", "application/json")
|
|
976
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
977
|
+
self.end_headers()
|
|
978
|
+
self.wfile.write(json.dumps(execution).encode())
|
|
979
|
+
else:
|
|
980
|
+
self.send_error(400, "Invalid path format")
|
|
981
|
+
except Exception as e:
|
|
982
|
+
self.send_response(500)
|
|
983
|
+
self.send_header("Content-Type", "application/json")
|
|
984
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
985
|
+
self.end_headers()
|
|
986
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
987
|
+
elif self.path.startswith("/api/benchmark/screenshots/"):
|
|
988
|
+
# Serve screenshot files
|
|
989
|
+
# URL format: /api/benchmark/screenshots/{run_name}/{task_id}/screenshots/{filename}
|
|
990
|
+
try:
|
|
991
|
+
# Remove /api/benchmark/screenshots/ prefix
|
|
992
|
+
path_parts = self.path.replace(
|
|
993
|
+
"/api/benchmark/screenshots/", ""
|
|
994
|
+
).split("/")
|
|
995
|
+
if len(path_parts) >= 4:
|
|
996
|
+
run_name = path_parts[0]
|
|
997
|
+
task_id = path_parts[1]
|
|
998
|
+
# path_parts[2] should be 'screenshots'
|
|
999
|
+
filename = path_parts[3]
|
|
1000
|
+
|
|
1001
|
+
results_dir = Path("benchmark_results")
|
|
1002
|
+
screenshot_path = (
|
|
1003
|
+
results_dir
|
|
1004
|
+
/ run_name
|
|
1005
|
+
/ "tasks"
|
|
1006
|
+
/ task_id
|
|
1007
|
+
/ "screenshots"
|
|
1008
|
+
/ filename
|
|
1009
|
+
)
|
|
1010
|
+
|
|
1011
|
+
if screenshot_path.exists():
|
|
1012
|
+
self.send_response(200)
|
|
1013
|
+
self.send_header("Content-Type", "image/png")
|
|
1014
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1015
|
+
self.end_headers()
|
|
1016
|
+
with open(screenshot_path, "rb") as f:
|
|
1017
|
+
self.wfile.write(f.read())
|
|
1018
|
+
else:
|
|
1019
|
+
self.send_error(
|
|
1020
|
+
404, f"Screenshot not found: {screenshot_path}"
|
|
1021
|
+
)
|
|
1022
|
+
else:
|
|
1023
|
+
self.send_error(400, "Invalid screenshot path format")
|
|
1024
|
+
except Exception as e:
|
|
1025
|
+
self.send_error(500, f"Error serving screenshot: {e}")
|
|
1026
|
+
elif self.path.startswith("/api/azure-ops-sse"):
|
|
1027
|
+
# Server-Sent Events endpoint for Azure operations status
|
|
1028
|
+
try:
|
|
1029
|
+
self._stream_azure_ops_updates()
|
|
1030
|
+
except Exception as e:
|
|
1031
|
+
self.send_error(500, f"SSE error: {e}")
|
|
1032
|
+
elif self.path.startswith("/api/azure-ops-status"):
|
|
1033
|
+
# Return Azure operations status from JSON file
|
|
1034
|
+
# Session tracker provides elapsed_seconds and cost_usd for
|
|
1035
|
+
# persistence across page refreshes
|
|
1036
|
+
try:
|
|
1037
|
+
from openadapt_ml.benchmarks.azure_ops_tracker import read_status
|
|
1038
|
+
from openadapt_ml.benchmarks.session_tracker import (
|
|
1039
|
+
get_session,
|
|
1040
|
+
update_session_vm_state,
|
|
1041
|
+
)
|
|
1042
|
+
|
|
1043
|
+
# Get operation status (current task)
|
|
1044
|
+
status = read_status()
|
|
1045
|
+
|
|
1046
|
+
# Get session data (persistent across refreshes)
|
|
1047
|
+
session = get_session()
|
|
1048
|
+
|
|
1049
|
+
# Update session based on VM state if we have VM info
|
|
1050
|
+
# IMPORTANT: Only pass vm_ip if it's truthy to avoid
|
|
1051
|
+
# overwriting session's stable vm_ip with None
|
|
1052
|
+
if status.get("vm_state") and status.get("vm_state") != "unknown":
|
|
1053
|
+
status_vm_ip = status.get("vm_ip")
|
|
1054
|
+
# Build update kwargs - only include vm_ip if present
|
|
1055
|
+
update_kwargs = {
|
|
1056
|
+
"vm_state": status["vm_state"],
|
|
1057
|
+
"vm_size": status.get("vm_size"),
|
|
1058
|
+
}
|
|
1059
|
+
if status_vm_ip: # Only include if truthy
|
|
1060
|
+
update_kwargs["vm_ip"] = status_vm_ip
|
|
1061
|
+
session = update_session_vm_state(**update_kwargs)
|
|
1062
|
+
|
|
1063
|
+
# Use session's vm_ip as authoritative source
|
|
1064
|
+
# This prevents IP flickering when status file has stale/None values
|
|
1065
|
+
if session.get("vm_ip"):
|
|
1066
|
+
status["vm_ip"] = session["vm_ip"]
|
|
1067
|
+
|
|
1068
|
+
# Use session's elapsed_seconds and cost_usd for persistence
|
|
1069
|
+
# These survive page refreshes and track total VM runtime
|
|
1070
|
+
if (
|
|
1071
|
+
session.get("is_active")
|
|
1072
|
+
or session.get("accumulated_seconds", 0) > 0
|
|
1073
|
+
):
|
|
1074
|
+
status["elapsed_seconds"] = session.get("elapsed_seconds", 0.0)
|
|
1075
|
+
status["cost_usd"] = session.get("cost_usd", 0.0)
|
|
1076
|
+
status["started_at"] = session.get("started_at")
|
|
1077
|
+
# Include session metadata for debugging
|
|
1078
|
+
status["session_id"] = session.get("session_id")
|
|
1079
|
+
status["session_is_active"] = session.get("is_active", False)
|
|
1080
|
+
# Include accumulated time from previous sessions for hybrid display
|
|
1081
|
+
status["accumulated_seconds"] = session.get(
|
|
1082
|
+
"accumulated_seconds", 0.0
|
|
1083
|
+
)
|
|
1084
|
+
# Calculate current session time (total - accumulated)
|
|
1085
|
+
current_session_seconds = max(
|
|
1086
|
+
0, status["elapsed_seconds"] - status["accumulated_seconds"]
|
|
1087
|
+
)
|
|
1088
|
+
status["current_session_seconds"] = current_session_seconds
|
|
1089
|
+
status["current_session_cost_usd"] = (
|
|
1090
|
+
current_session_seconds / 3600
|
|
1091
|
+
) * session.get("hourly_rate_usd", 0.422)
|
|
1092
|
+
|
|
1093
|
+
try:
|
|
1094
|
+
tunnel_mgr = get_tunnel_manager()
|
|
1095
|
+
tunnel_status = tunnel_mgr.get_tunnel_status()
|
|
1096
|
+
status["tunnels"] = {
|
|
1097
|
+
name: {
|
|
1098
|
+
"active": s.active,
|
|
1099
|
+
"local_port": s.local_port,
|
|
1100
|
+
"remote_endpoint": s.remote_endpoint,
|
|
1101
|
+
"pid": s.pid,
|
|
1102
|
+
"error": s.error,
|
|
1103
|
+
}
|
|
1104
|
+
for name, s in tunnel_status.items()
|
|
1105
|
+
}
|
|
1106
|
+
except Exception as e:
|
|
1107
|
+
status["tunnels"] = {"error": str(e)}
|
|
1108
|
+
|
|
1109
|
+
self.send_response(200)
|
|
1110
|
+
self.send_header("Content-Type", "application/json")
|
|
1111
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1112
|
+
self.end_headers()
|
|
1113
|
+
self.wfile.write(json.dumps(status).encode())
|
|
1114
|
+
except Exception as e:
|
|
1115
|
+
self.send_response(500)
|
|
1116
|
+
self.send_header("Content-Type", "application/json")
|
|
1117
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1118
|
+
self.end_headers()
|
|
1119
|
+
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
1120
|
+
elif self.path.startswith("/api/vm-diagnostics"):
|
|
1121
|
+
# Return VM diagnostics: disk usage, Docker stats, memory usage
|
|
1122
|
+
try:
|
|
1123
|
+
diagnostics = self._get_vm_diagnostics()
|
|
1124
|
+
self.send_response(200)
|
|
1125
|
+
self.send_header("Content-Type", "application/json")
|
|
1126
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
1127
|
+
self.end_headers()
|
|
1128
|
+
self.wfile.write(json.dumps(diagnostics).encode())
|
|
1129
|
+
except Exception as e:
|
|
1130
|
+
self.send_response(500)
|
|
1131
|
+
self.send_header("Content-Type", "application/json")
|
|
1132
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
784
1133
|
self.end_headers()
|
|
785
1134
|
self.wfile.write(json.dumps({"error": str(e)}).encode())
|
|
786
1135
|
else:
|
|
@@ -790,13 +1139,25 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
790
1139
|
def _fetch_live_azure_jobs(self):
|
|
791
1140
|
"""Fetch live job status from Azure ML."""
|
|
792
1141
|
import subprocess
|
|
1142
|
+
|
|
793
1143
|
result = subprocess.run(
|
|
794
|
-
[
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
1144
|
+
[
|
|
1145
|
+
"az",
|
|
1146
|
+
"ml",
|
|
1147
|
+
"job",
|
|
1148
|
+
"list",
|
|
1149
|
+
"--resource-group",
|
|
1150
|
+
"openadapt-agents",
|
|
1151
|
+
"--workspace-name",
|
|
1152
|
+
"openadapt-ml",
|
|
1153
|
+
"--query",
|
|
1154
|
+
"[].{name:name,display_name:display_name,status:status,creation_context:creation_context.created_at}",
|
|
1155
|
+
"-o",
|
|
1156
|
+
"json",
|
|
1157
|
+
],
|
|
1158
|
+
capture_output=True,
|
|
1159
|
+
text=True,
|
|
1160
|
+
timeout=30,
|
|
800
1161
|
)
|
|
801
1162
|
if result.returncode != 0:
|
|
802
1163
|
raise Exception(f"Azure CLI error: {result.stderr}")
|
|
@@ -808,14 +1169,16 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
808
1169
|
|
|
809
1170
|
formatted = []
|
|
810
1171
|
for job in jobs[:10]: # Limit to 10 most recent
|
|
811
|
-
formatted.append(
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
1172
|
+
formatted.append(
|
|
1173
|
+
{
|
|
1174
|
+
"job_id": job.get("name", "unknown"),
|
|
1175
|
+
"display_name": job.get("display_name", ""),
|
|
1176
|
+
"status": job.get("status", "unknown").lower(),
|
|
1177
|
+
"started_at": job.get("creation_context", ""),
|
|
1178
|
+
"azure_dashboard_url": f"https://ml.azure.com/experiments/id/{experiment_id}/runs/{job.get('name', '')}?wsid={wsid}",
|
|
1179
|
+
"is_live": True, # Flag to indicate this is live data
|
|
1180
|
+
}
|
|
1181
|
+
)
|
|
819
1182
|
return formatted
|
|
820
1183
|
|
|
821
1184
|
def _fetch_azure_job_logs(self, job_id: str | None):
|
|
@@ -825,34 +1188,63 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
825
1188
|
if not job_id:
|
|
826
1189
|
# Get the most recent running job
|
|
827
1190
|
jobs = self._fetch_live_azure_jobs()
|
|
828
|
-
running = [j for j in jobs if j[
|
|
1191
|
+
running = [j for j in jobs if j["status"] == "running"]
|
|
829
1192
|
if running:
|
|
830
|
-
job_id = running[0][
|
|
1193
|
+
job_id = running[0]["job_id"]
|
|
831
1194
|
else:
|
|
832
|
-
return {
|
|
1195
|
+
return {
|
|
1196
|
+
"logs": "No running jobs found",
|
|
1197
|
+
"job_id": None,
|
|
1198
|
+
"status": "idle",
|
|
1199
|
+
}
|
|
833
1200
|
|
|
834
1201
|
# Try to stream logs for running job using az ml job stream
|
|
835
1202
|
try:
|
|
836
1203
|
result = subprocess.run(
|
|
837
|
-
[
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
1204
|
+
[
|
|
1205
|
+
"az",
|
|
1206
|
+
"ml",
|
|
1207
|
+
"job",
|
|
1208
|
+
"stream",
|
|
1209
|
+
"--name",
|
|
1210
|
+
job_id,
|
|
1211
|
+
"--resource-group",
|
|
1212
|
+
"openadapt-agents",
|
|
1213
|
+
"--workspace-name",
|
|
1214
|
+
"openadapt-ml",
|
|
1215
|
+
],
|
|
1216
|
+
capture_output=True,
|
|
1217
|
+
text=True,
|
|
1218
|
+
timeout=3, # Short timeout
|
|
842
1219
|
)
|
|
843
1220
|
if result.returncode == 0 and result.stdout.strip():
|
|
844
|
-
return {
|
|
1221
|
+
return {
|
|
1222
|
+
"logs": result.stdout[-5000:],
|
|
1223
|
+
"job_id": job_id,
|
|
1224
|
+
"status": "streaming",
|
|
1225
|
+
}
|
|
845
1226
|
except subprocess.TimeoutExpired:
|
|
846
1227
|
pass # Fall through to job show
|
|
847
1228
|
|
|
848
1229
|
# Get job details instead
|
|
849
1230
|
result = subprocess.run(
|
|
850
|
-
[
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
1231
|
+
[
|
|
1232
|
+
"az",
|
|
1233
|
+
"ml",
|
|
1234
|
+
"job",
|
|
1235
|
+
"show",
|
|
1236
|
+
"--name",
|
|
1237
|
+
job_id,
|
|
1238
|
+
"--resource-group",
|
|
1239
|
+
"openadapt-agents",
|
|
1240
|
+
"--workspace-name",
|
|
1241
|
+
"openadapt-ml",
|
|
1242
|
+
"-o",
|
|
1243
|
+
"json",
|
|
1244
|
+
],
|
|
1245
|
+
capture_output=True,
|
|
1246
|
+
text=True,
|
|
1247
|
+
timeout=10,
|
|
856
1248
|
)
|
|
857
1249
|
|
|
858
1250
|
if result.returncode == 0:
|
|
@@ -860,20 +1252,25 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
860
1252
|
return {
|
|
861
1253
|
"logs": f"Job {job_id} is {job_info.get('status', 'unknown')}\\n\\nCommand: {job_info.get('command', 'N/A')}",
|
|
862
1254
|
"job_id": job_id,
|
|
863
|
-
"status": job_info.get(
|
|
864
|
-
"command": job_info.get(
|
|
1255
|
+
"status": job_info.get("status", "unknown").lower(),
|
|
1256
|
+
"command": job_info.get("command", ""),
|
|
865
1257
|
}
|
|
866
1258
|
|
|
867
|
-
return {
|
|
1259
|
+
return {
|
|
1260
|
+
"logs": f"Could not fetch logs: {result.stderr}",
|
|
1261
|
+
"job_id": job_id,
|
|
1262
|
+
"status": "error",
|
|
1263
|
+
}
|
|
868
1264
|
|
|
869
|
-
def _get_vm_detailed_metadata(
|
|
1265
|
+
def _get_vm_detailed_metadata(
|
|
1266
|
+
self, vm_ip: str, container_name: str, logs: str, phase: str
|
|
1267
|
+
) -> dict:
|
|
870
1268
|
"""Get detailed VM metadata for the VM Details panel.
|
|
871
1269
|
|
|
872
1270
|
Returns:
|
|
873
1271
|
dict with disk_usage_gb, memory_usage_mb, setup_script_phase, probe_response, qmp_connected, dependencies
|
|
874
1272
|
"""
|
|
875
1273
|
import subprocess
|
|
876
|
-
import re
|
|
877
1274
|
|
|
878
1275
|
metadata = {
|
|
879
1276
|
"disk_usage_gb": None,
|
|
@@ -881,17 +1278,26 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
881
1278
|
"setup_script_phase": None,
|
|
882
1279
|
"probe_response": None,
|
|
883
1280
|
"qmp_connected": False,
|
|
884
|
-
"dependencies": []
|
|
1281
|
+
"dependencies": [],
|
|
885
1282
|
}
|
|
886
1283
|
|
|
887
1284
|
# 1. Get disk usage from docker stats
|
|
888
1285
|
try:
|
|
889
1286
|
disk_result = subprocess.run(
|
|
890
|
-
[
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
1287
|
+
[
|
|
1288
|
+
"ssh",
|
|
1289
|
+
"-o",
|
|
1290
|
+
"StrictHostKeyChecking=no",
|
|
1291
|
+
"-o",
|
|
1292
|
+
"ConnectTimeout=5",
|
|
1293
|
+
"-i",
|
|
1294
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1295
|
+
f"azureuser@{vm_ip}",
|
|
1296
|
+
f"docker exec {container_name} df -h /storage 2>/dev/null | tail -1",
|
|
1297
|
+
],
|
|
1298
|
+
capture_output=True,
|
|
1299
|
+
text=True,
|
|
1300
|
+
timeout=10,
|
|
895
1301
|
)
|
|
896
1302
|
if disk_result.returncode == 0 and disk_result.stdout.strip():
|
|
897
1303
|
# Parse: "Filesystem Size Used Avail Use% Mounted on"
|
|
@@ -900,27 +1306,40 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
900
1306
|
if len(parts) >= 3:
|
|
901
1307
|
used_str = parts[2] # e.g., "9.2G"
|
|
902
1308
|
total_str = parts[1] # e.g., "30G"
|
|
1309
|
+
|
|
903
1310
|
# Convert to GB (handle M/G suffixes)
|
|
904
1311
|
def to_gb(s):
|
|
905
|
-
if s.endswith(
|
|
1312
|
+
if s.endswith("G"):
|
|
906
1313
|
return float(s[:-1])
|
|
907
|
-
elif s.endswith(
|
|
1314
|
+
elif s.endswith("M"):
|
|
908
1315
|
return float(s[:-1]) / 1024
|
|
909
|
-
elif s.endswith(
|
|
1316
|
+
elif s.endswith("K"):
|
|
910
1317
|
return float(s[:-1]) / (1024 * 1024)
|
|
911
1318
|
return 0
|
|
912
|
-
|
|
1319
|
+
|
|
1320
|
+
metadata["disk_usage_gb"] = (
|
|
1321
|
+
f"{to_gb(used_str):.1f} GB / {to_gb(total_str):.0f} GB used"
|
|
1322
|
+
)
|
|
913
1323
|
except Exception:
|
|
914
1324
|
pass
|
|
915
1325
|
|
|
916
1326
|
# 2. Get memory usage from docker stats
|
|
917
1327
|
try:
|
|
918
1328
|
mem_result = subprocess.run(
|
|
919
|
-
[
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
1329
|
+
[
|
|
1330
|
+
"ssh",
|
|
1331
|
+
"-o",
|
|
1332
|
+
"StrictHostKeyChecking=no",
|
|
1333
|
+
"-o",
|
|
1334
|
+
"ConnectTimeout=5",
|
|
1335
|
+
"-i",
|
|
1336
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1337
|
+
f"azureuser@{vm_ip}",
|
|
1338
|
+
f"docker stats {container_name} --no-stream --format '{{{{.MemUsage}}}}'",
|
|
1339
|
+
],
|
|
1340
|
+
capture_output=True,
|
|
1341
|
+
text=True,
|
|
1342
|
+
timeout=10,
|
|
924
1343
|
)
|
|
925
1344
|
if mem_result.returncode == 0 and mem_result.stdout.strip():
|
|
926
1345
|
# Example: "1.5GiB / 4GiB"
|
|
@@ -929,16 +1348,27 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
929
1348
|
pass
|
|
930
1349
|
|
|
931
1350
|
# 3. Parse setup script phase from logs
|
|
932
|
-
metadata["setup_script_phase"] = self._parse_setup_phase_from_logs(
|
|
1351
|
+
metadata["setup_script_phase"] = self._parse_setup_phase_from_logs(
|
|
1352
|
+
logs, phase
|
|
1353
|
+
)
|
|
933
1354
|
|
|
934
1355
|
# 4. Check /probe endpoint
|
|
935
1356
|
try:
|
|
936
1357
|
probe_result = subprocess.run(
|
|
937
|
-
[
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
1358
|
+
[
|
|
1359
|
+
"ssh",
|
|
1360
|
+
"-o",
|
|
1361
|
+
"StrictHostKeyChecking=no",
|
|
1362
|
+
"-o",
|
|
1363
|
+
"ConnectTimeout=5",
|
|
1364
|
+
"-i",
|
|
1365
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1366
|
+
f"azureuser@{vm_ip}",
|
|
1367
|
+
"curl -s --connect-timeout 2 http://20.20.20.21:5000/probe 2>/dev/null",
|
|
1368
|
+
],
|
|
1369
|
+
capture_output=True,
|
|
1370
|
+
text=True,
|
|
1371
|
+
timeout=10,
|
|
942
1372
|
)
|
|
943
1373
|
if probe_result.returncode == 0 and probe_result.stdout.strip():
|
|
944
1374
|
metadata["probe_response"] = probe_result.stdout.strip()
|
|
@@ -950,11 +1380,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
950
1380
|
# 5. Check QMP connection (port 7200)
|
|
951
1381
|
try:
|
|
952
1382
|
qmp_result = subprocess.run(
|
|
953
|
-
[
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
1383
|
+
[
|
|
1384
|
+
"ssh",
|
|
1385
|
+
"-o",
|
|
1386
|
+
"StrictHostKeyChecking=no",
|
|
1387
|
+
"-o",
|
|
1388
|
+
"ConnectTimeout=5",
|
|
1389
|
+
"-i",
|
|
1390
|
+
str(Path.home() / ".ssh" / "id_rsa"),
|
|
1391
|
+
f"azureuser@{vm_ip}",
|
|
1392
|
+
"nc -z -w2 localhost 7200 2>&1",
|
|
1393
|
+
],
|
|
1394
|
+
capture_output=True,
|
|
1395
|
+
text=True,
|
|
1396
|
+
timeout=10,
|
|
958
1397
|
)
|
|
959
1398
|
metadata["qmp_connected"] = qmp_result.returncode == 0
|
|
960
1399
|
except Exception:
|
|
@@ -987,7 +1426,12 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
987
1426
|
return "Windows installation in progress"
|
|
988
1427
|
elif current_phase == "booting":
|
|
989
1428
|
return "Booting Windows"
|
|
990
|
-
elif current_phase in [
|
|
1429
|
+
elif current_phase in [
|
|
1430
|
+
"downloading",
|
|
1431
|
+
"extracting",
|
|
1432
|
+
"configuring",
|
|
1433
|
+
"building",
|
|
1434
|
+
]:
|
|
991
1435
|
return "Preparing Windows VM"
|
|
992
1436
|
else:
|
|
993
1437
|
return "Initializing..."
|
|
@@ -1017,17 +1461,23 @@ def cmd_serve(args: argparse.Namespace) -> int:
|
|
|
1017
1461
|
logs_lower = logs.lower()
|
|
1018
1462
|
|
|
1019
1463
|
# Check for installation patterns
|
|
1020
|
-
if "python" in logs_lower and (
|
|
1464
|
+
if "python" in logs_lower and (
|
|
1465
|
+
"installing python" in logs_lower or "python.exe" in logs_lower
|
|
1466
|
+
):
|
|
1021
1467
|
dependencies[0]["status"] = "installing"
|
|
1022
1468
|
elif "python" in logs_lower and "installed" in logs_lower:
|
|
1023
1469
|
dependencies[0]["status"] = "complete"
|
|
1024
1470
|
|
|
1025
|
-
if "chrome" in logs_lower and (
|
|
1471
|
+
if "chrome" in logs_lower and (
|
|
1472
|
+
"downloading" in logs_lower or "installing" in logs_lower
|
|
1473
|
+
):
|
|
1026
1474
|
dependencies[1]["status"] = "installing"
|
|
1027
1475
|
elif "chrome" in logs_lower and "installed" in logs_lower:
|
|
1028
1476
|
dependencies[1]["status"] = "complete"
|
|
1029
1477
|
|
|
1030
|
-
if "libreoffice" in logs_lower and (
|
|
1478
|
+
if "libreoffice" in logs_lower and (
|
|
1479
|
+
"downloading" in logs_lower or "installing" in logs_lower
|
|
1480
|
+
):
|
|
1031
1481
|
dependencies[2]["status"] = "installing"
|
|
1032
1482
|
elif "libreoffice" in logs_lower and "installed" in logs_lower:
|
|
1033
1483
|
dependencies[2]["status"] = "complete"
|
|
@@ -1046,11 +1496,222 @@ def cmd_serve(args: argparse.Namespace) -> int:

     return dependencies

+def _get_vm_diagnostics(self) -> dict:
+    """Get VM diagnostics: disk usage, Docker stats, memory usage.
+
+    Returns a dictionary with:
+    - vm_online: bool - whether VM is reachable
+    - disk_usage: list of disk partitions with usage stats
+    - docker_stats: list of container stats (CPU, memory)
+    - memory_usage: VM host memory stats
+    - docker_system: Docker system disk usage
+    - error: str if any error occurred
+    """
+    import subprocess
+
+    from openadapt_ml.benchmarks.session_tracker import get_session
+
+    diagnostics = {
+        "vm_online": False,
+        "disk_usage": [],
+        "docker_stats": [],
+        "memory_usage": {},
+        "docker_system": {},
+        "docker_images": [],
+        "error": None,
+    }
+
+    # Get VM IP from session
+    session = get_session()
+    vm_ip = session.get("vm_ip")
+
+    if not vm_ip:
+        diagnostics["error"] = (
+            "VM IP not found in session. VM may not be running."
+        )
+        return diagnostics
+
+    # SSH options for Azure VM
+    ssh_opts = [
+        "-o",
+        "StrictHostKeyChecking=no",
+        "-o",
+        "UserKnownHostsFile=/dev/null",
+        "-o",
+        "ConnectTimeout=10",
+        "-o",
+        "ServerAliveInterval=30",
+    ]
+
+    # Test VM connectivity
+    try:
+        test_result = subprocess.run(
+            ["ssh", *ssh_opts, f"azureuser@{vm_ip}", "echo 'online'"],
+            capture_output=True,
+            text=True,
+            timeout=15,
+        )
+        if test_result.returncode != 0:
+            diagnostics["error"] = f"Cannot connect to VM at {vm_ip}"
+            return diagnostics
+        diagnostics["vm_online"] = True
+    except subprocess.TimeoutExpired:
+        diagnostics["error"] = f"Connection to VM at {vm_ip} timed out"
+        return diagnostics
+    except Exception as e:
+        diagnostics["error"] = f"SSH error: {str(e)}"
+        return diagnostics
+
+    # 1. Disk usage (df -h)
+    try:
+        df_result = subprocess.run(
+            [
+                "ssh",
+                *ssh_opts,
+                f"azureuser@{vm_ip}",
+                "df -h / /mnt 2>/dev/null | tail -n +2",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=15,
+        )
+        if df_result.returncode == 0 and df_result.stdout.strip():
+            for line in df_result.stdout.strip().split("\n"):
+                parts = line.split()
+                if len(parts) >= 6:
+                    diagnostics["disk_usage"].append(
+                        {
+                            "filesystem": parts[0],
+                            "size": parts[1],
+                            "used": parts[2],
+                            "available": parts[3],
+                            "use_percent": parts[4],
+                            "mount_point": parts[5],
+                        }
+                    )
+    except Exception as e:
+        diagnostics["disk_usage"] = [{"error": str(e)}]
+
+    # 2. Docker container stats
+    try:
+        stats_result = subprocess.run(
+            [
+                "ssh",
+                *ssh_opts,
+                f"azureuser@{vm_ip}",
+                "docker stats --no-stream --format '{{.Name}}|{{.CPUPerc}}|{{.MemUsage}}|{{.MemPerc}}|{{.NetIO}}|{{.BlockIO}}' 2>/dev/null || echo ''",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=30,
+        )
+        if stats_result.returncode == 0 and stats_result.stdout.strip():
+            for line in stats_result.stdout.strip().split("\n"):
+                if "|" in line:
+                    parts = line.split("|")
+                    if len(parts) >= 6:
+                        diagnostics["docker_stats"].append(
+                            {
+                                "container": parts[0],
+                                "cpu_percent": parts[1],
+                                "memory_usage": parts[2],
+                                "memory_percent": parts[3],
+                                "net_io": parts[4],
+                                "block_io": parts[5],
+                            }
+                        )
+    except Exception as e:
+        diagnostics["docker_stats"] = [{"error": str(e)}]
+
+    # 3. VM host memory usage (free -h)
+    try:
+        mem_result = subprocess.run(
+            [
+                "ssh",
+                *ssh_opts,
+                f"azureuser@{vm_ip}",
+                "free -h | head -2 | tail -1",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=15,
+        )
+        if mem_result.returncode == 0 and mem_result.stdout.strip():
+            parts = mem_result.stdout.strip().split()
+            if len(parts) >= 7:
+                diagnostics["memory_usage"] = {
+                    "total": parts[1],
+                    "used": parts[2],
+                    "free": parts[3],
+                    "shared": parts[4],
+                    "buff_cache": parts[5],
+                    "available": parts[6],
+                }
+    except Exception as e:
+        diagnostics["memory_usage"] = {"error": str(e)}
+
+    # 4. Docker system disk usage
+    try:
+        docker_df_result = subprocess.run(
+            [
+                "ssh",
+                *ssh_opts,
+                f"azureuser@{vm_ip}",
+                "docker system df 2>/dev/null || echo ''",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=15,
+        )
+        if docker_df_result.returncode == 0 and docker_df_result.stdout.strip():
+            lines = docker_df_result.stdout.strip().split("\n")
+            # Parse the table: TYPE, TOTAL, ACTIVE, SIZE, RECLAIMABLE
+            for line in lines[1:]:  # Skip header
+                parts = line.split()
+                if len(parts) >= 5:
+                    dtype = parts[0]
+                    diagnostics["docker_system"][dtype.lower()] = {
+                        "total": parts[1],
+                        "active": parts[2],
+                        "size": parts[3],
+                        "reclaimable": " ".join(parts[4:]),
+                    }
+    except Exception as e:
+        diagnostics["docker_system"] = {"error": str(e)}
+
+    # 5. Docker images
+    try:
+        images_result = subprocess.run(
+            [
+                "ssh",
+                *ssh_opts,
+                f"azureuser@{vm_ip}",
+                "docker images --format '{{.Repository}}:{{.Tag}}|{{.Size}}|{{.CreatedSince}}' 2>/dev/null || echo ''",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=15,
+        )
+        if images_result.returncode == 0 and images_result.stdout.strip():
+            for line in images_result.stdout.strip().split("\n"):
+                if "|" in line:
+                    parts = line.split("|")
+                    if len(parts) >= 3:
+                        diagnostics["docker_images"].append(
+                            {
+                                "image": parts[0],
+                                "size": parts[1],
+                                "created": parts[2],
+                            }
+                        )
+    except Exception as e:
+        diagnostics["docker_images"] = [{"error": str(e)}]
+
+    return diagnostics
+
 def _fetch_background_tasks(self):
     """Fetch status of all background tasks: Azure VM, Docker containers, benchmarks."""
     import subprocess
-    from datetime import datetime
-    import time

     tasks = []

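For orientation, a small sketch of how a caller might summarize the dictionary returned by the new `_get_vm_diagnostics`; only the key names come from the diff above, the helper itself is hypothetical:

```python
# Illustrative only; key names mirror the diagnostics dict built above.
def summarize_diagnostics(diag: dict) -> str:
    if not diag.get("vm_online"):
        return f"VM offline: {diag.get('error')}"
    disks = ", ".join(
        f"{d['mount_point']} {d['use_percent']}"
        for d in diag.get("disk_usage", [])
        if "error" not in d
    )
    return (
        f"VM online; disk: {disks or 'n/a'}; "
        f"containers reporting stats: {len(diag.get('docker_stats', []))}"
    )
```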
@@ -1063,31 +1724,43 @@ def cmd_serve(args: argparse.Namespace) -> int:
 if env_vm_ip:
     # Use environment variable - VM IP was provided directly
     vm_ip = env_vm_ip
-    tasks.append(
-
-
-
-
-
-
-
-
-
-        "
-
-
+    tasks.append(
+        {
+            "task_id": "azure-vm-waa",
+            "task_type": "vm_provision",
+            "status": "completed",
+            "phase": "ready",  # Match status to prevent "Starting" + "completed" conflict
+            "title": "Azure VM Host",
+            "description": f"Linux host running at {vm_ip}",
+            "progress_percent": 100.0,
+            "elapsed_seconds": 0,
+            "metadata": {
+                "vm_name": "waa-eval-vm",
+                "ip_address": vm_ip,
+                "internal_ip": env_internal_ip,
+            },
         }
-
+    )
 else:
     # Query Azure CLI for VM status
     try:
         result = subprocess.run(
-            [
-
-
-
-
+            [
+                "az",
+                "vm",
+                "get-instance-view",
+                "--name",
+                "waa-eval-vm",
+                "--resource-group",
+                "openadapt-agents",
+                "--query",
+                "instanceView.statuses",
+                "-o",
+                "json",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=10,
         )
         if result.returncode == 0:
             statuses = json.loads(result.stdout)
@@ -1098,31 +1771,49 @@ def cmd_serve(args: argparse.Namespace) -> int:

         # Get VM IP
         ip_result = subprocess.run(
-            [
-
-
-
-
-
+            [
+                "az",
+                "vm",
+                "list-ip-addresses",
+                "--name",
+                "waa-eval-vm",
+                "--resource-group",
+                "openadapt-agents",
+                "--query",
+                "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+                "-o",
+                "tsv",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=10,
+        )
+        vm_ip = (
+            ip_result.stdout.strip()
+            if ip_result.returncode == 0
+            else None
         )
-        vm_ip = ip_result.stdout.strip() if ip_result.returncode == 0 else None

         if power_state == "running":
-            tasks.append(
-
-
-
-
-
-
-
-
-
-                "
-                "
-
+            tasks.append(
+                {
+                    "task_id": "azure-vm-waa",
+                    "task_type": "vm_provision",
+                    "status": "completed",
+                    "phase": "ready",  # Match status to prevent "Starting" + "completed" conflict
+                    "title": "Azure VM Host",
+                    "description": f"Linux host running at {vm_ip}"
+                    if vm_ip
+                    else "Linux host running",
+                    "progress_percent": 100.0,
+                    "elapsed_seconds": 0,
+                    "metadata": {
+                        "vm_name": "waa-eval-vm",
+                        "ip_address": vm_ip,
+                        # No VNC link - that's for the Windows container
+                    },
                 }
-
+            )
     except subprocess.TimeoutExpired:
         pass
     except Exception:
@@ -1132,31 +1823,59 @@ def cmd_serve(args: argparse.Namespace) -> int:
 if vm_ip:
     try:
         docker_result = subprocess.run(
-            [
-
-
-
-
+            [
+                "ssh",
+                "-o",
+                "StrictHostKeyChecking=no",
+                "-o",
+                "ConnectTimeout=5",
+                "-i",
+                str(Path.home() / ".ssh" / "id_rsa"),
+                f"azureuser@{vm_ip}",
+                "docker ps --format '{{.Names}}|{{.Status}}|{{.Image}}'",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=15,
        )
        if docker_result.returncode == 0 and docker_result.stdout.strip():
-            for line in docker_result.stdout.strip().split(
-                parts = line.split(
+            for line in docker_result.stdout.strip().split("\n"):
+                parts = line.split("|")
                if len(parts) >= 3:
-                    container_name, status, image =
+                    container_name, status, image = (
+                        parts[0],
+                        parts[1],
+                        parts[2],
+                    )
                    # Parse "Up X minutes" to determine if healthy
-                    is_healthy = "Up" in status

                    # Check for Windows VM specifically
-                    if
+                    if (
+                        "windows" in image.lower()
+                        or container_name == "winarena"
+                    ):
                        # Get detailed progress from docker logs
                        log_check = subprocess.run(
-                            [
-
-
-
-
+                            [
+                                "ssh",
+                                "-o",
+                                "StrictHostKeyChecking=no",
+                                "-o",
+                                "ConnectTimeout=5",
+                                "-i",
+                                str(Path.home() / ".ssh" / "id_rsa"),
+                                f"azureuser@{vm_ip}",
+                                f"docker logs {container_name} 2>&1 | tail -30",
+                            ],
+                            capture_output=True,
+                            text=True,
+                            timeout=10,
+                        )
+                        logs = (
+                            log_check.stdout
+                            if log_check.returncode == 0
+                            else ""
                        )
-                        logs = log_check.stdout if log_check.returncode == 0 else ""

                        # Parse progress from logs
                        phase = "unknown"
@@ -1167,17 +1886,32 @@ def cmd_serve(args: argparse.Namespace) -> int:
 # Check if WAA server is ready via Docker port forwarding
 # See docs/waa_network_architecture.md - always use localhost
 server_check = subprocess.run(
-    [
-
-
-
-
+    [
+        "ssh",
+        "-o",
+        "StrictHostKeyChecking=no",
+        "-o",
+        "ConnectTimeout=5",
+        "-i",
+        str(Path.home() / ".ssh" / "id_rsa"),
+        f"azureuser@{vm_ip}",
+        "curl -s --connect-timeout 2 http://localhost:5000/probe 2>/dev/null",
+    ],
+    capture_output=True,
+    text=True,
+    timeout=10,
+)
+waa_ready = (
+    server_check.returncode == 0
+    and "Service is operational"
+    in server_check.stdout
 )
-waa_ready = server_check.returncode == 0 and "Service is operational" in server_check.stdout
 if waa_ready:
     phase = "ready"
     progress = 100.0
-    description =
+    description = (
+        "WAA Server ready - benchmarks can run"
+    )
 else:
     phase = "oobe"
     progress = 80.0  # Phase 5/6 - VM install in progress
@@ -1186,10 +1920,15 @@ def cmd_serve(args: argparse.Namespace) -> int:
     phase = "booting"
     progress = 70.0  # Phase 4/6
     description = "Phase 4/6: Booting Windows from installer..."
-elif
+elif (
+    "Building Windows" in logs
+    or "Creating a" in logs
+):
     phase = "building"
     progress = 60.0  # Phase 3/6
-    description =
+    description = (
+        "Phase 3/6: Building Windows VM disk..."
+    )
 elif "Adding" in logs and "image" in logs:
     phase = "configuring"
     progress = 50.0  # Phase 2/6
@@ -1197,15 +1936,22 @@ def cmd_serve(args: argparse.Namespace) -> int:
 elif "Extracting" in logs:
     phase = "extracting"
     progress = 35.0  # Phase 1/6 (after download)
-    description =
+    description = (
+        "Phase 1/6: Extracting Windows ISO..."
+    )
 else:
     # Check for download progress (e.g., "1234K ........ 45% 80M 30s")
     import re
-
+
+    download_match = re.search(
+        r"(\d+)%\s+[\d.]+[KMG]\s+(\d+)s", logs
+    )
     if download_match:
         phase = "downloading"
         dl_pct = float(download_match.group(1))
-        progress =
+        progress = (
+            dl_pct * 0.30
+        )  # 0-30% for download phase
         eta = download_match.group(2)
         description = f"Phase 0/6: Downloading Windows 11... {download_match.group(1)}% ({eta}s left)"

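The download-progress pattern added above can be sanity-checked against a wget-style log line; the sample line below is invented for illustration, only the regex comes from the diff:

```python
import re

sample = "1234K ........ 45% 80M 30s"  # assumed example of a download log line
m = re.search(r"(\d+)%\s+[\d.]+[KMG]\s+(\d+)s", sample)
if m:
    # group(1) is percent complete, group(2) the estimated seconds remaining
    print(m.group(1), m.group(2))  # -> 45 30
```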
@@ -1218,40 +1964,61 @@ def cmd_serve(args: argparse.Namespace) -> int:
     progress = 90.0

 # Get detailed metadata for VM Details panel
-vm_metadata = self._get_vm_detailed_metadata(
-
-
-
-
-
-
-
-
-
-
-
-    "
-    "
-    "
+vm_metadata = self._get_vm_detailed_metadata(
+    vm_ip, container_name, logs, phase
+)
+
+tasks.append(
+    {
+        "task_id": f"docker-{container_name}",
+        "task_type": "docker_container",
+        "status": "completed"
+        if phase == "ready"
+        else "running",
+        "title": "Windows 11 + WAA Server",
+        "description": description,
+        "progress_percent": progress,
+        "elapsed_seconds": 0,
         "phase": phase,
-        "
-
-
-
-
-
-
-
-
-
-
-
-
-
+        "metadata": {
+            "container": container_name,
+            "image": image,
+            "status": status,
+            "phase": phase,
+            "windows_ready": phase
+            in ["oobe", "ready"],
+            "waa_server_ready": phase == "ready",
+            # Use localhost - SSH tunnel handles routing to VM
+            # See docs/waa_network_architecture.md
+            "vnc_url": "http://localhost:8006",
+            "windows_username": "Docker",
+            "windows_password": "admin",
+            "recent_logs": logs[-500:]
+            if logs
+            else "",
+            # Enhanced VM details
+            "disk_usage_gb": vm_metadata[
+                "disk_usage_gb"
+            ],
+            "memory_usage_mb": vm_metadata[
+                "memory_usage_mb"
+            ],
+            "setup_script_phase": vm_metadata[
+                "setup_script_phase"
+            ],
+            "probe_response": vm_metadata[
+                "probe_response"
+            ],
+            "qmp_connected": vm_metadata[
+                "qmp_connected"
+            ],
+            "dependencies": vm_metadata[
+                "dependencies"
+            ],
+        },
     }
-
-except Exception
+)
+except Exception:
     # SSH failed, VM might still be starting
     pass

@@ -1261,38 +2028,93 @@ def cmd_serve(args: argparse.Namespace) -> int:
     try:
         progress = json.loads(progress_file.read_text())
         if progress.get("status") == "running":
-            tasks.append(
-
-
-
-
-
-
-
-
-
+            tasks.append(
+                {
+                    "task_id": "benchmark-local",
+                    "task_type": "benchmark_run",
+                    "status": "running",
+                    "title": f"{progress.get('provider', 'API').upper()} Benchmark",
+                    "description": progress.get(
+                        "message", "Running benchmark..."
+                    ),
+                    "progress_percent": (
+                        progress.get("tasks_complete", 0)
+                        / max(progress.get("tasks_total", 1), 1)
+                    )
+                    * 100,
+                    "elapsed_seconds": 0,
+                    "metadata": progress,
+                }
+            )
     except Exception:
         pass

     return tasks

 def _fetch_vm_registry(self):
-    """Fetch VM registry with live status checks.
+    """Fetch VM registry with live status checks.
+
+    NOTE: We now fetch the VM IP from Azure CLI at runtime to avoid
+    stale IP issues. The registry file is only used as a fallback.
+    """
     import subprocess
     from datetime import datetime

-    #
-
-
+    # Try to get VM IP from Azure CLI (always fresh)
+    vm_ip = None
+    resource_group = "openadapt-agents"
+    vm_name = "azure-waa-vm"
+    try:
+        result = subprocess.run(
+            [
+                "az",
+                "vm",
+                "show",
+                "-d",
+                "-g",
+                resource_group,
+                "-n",
+                vm_name,
+                "--query",
+                "publicIps",
+                "-o",
+                "tsv",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=10,
+        )
+        if result.returncode == 0 and result.stdout.strip():
+            vm_ip = result.stdout.strip()
+    except Exception:
+        pass

-
-
+    # If we have a fresh IP from Azure, use it
+    if vm_ip:
+        vms = [
+            {
+                "name": vm_name,
+                "ssh_host": vm_ip,
+                "ssh_user": "azureuser",
+                "vnc_port": 8006,
+                "waa_port": 5000,
+                "docker_container": "winarena",
+                "internal_ip": "localhost",
+            }
+        ]
+    else:
+        # Fallback to registry file
+        project_root = Path(__file__).parent.parent.parent
+        registry_file = project_root / "benchmark_results" / "vm_registry.json"

-
-
-
-
-
+        if not registry_file.exists():
+            return []
+
+        try:
+            with open(registry_file) as f:
+                vms = json.load(f)
+        except Exception as e:
+            return {"error": f"Failed to read VM registry: {e}"}

     # Check status for each VM
     for vm in vms:
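The fallback path reads `benchmark_results/vm_registry.json`; a minimal sketch of a registry entry carrying the keys the code above consumes. The values are placeholders, not real deployment data:

```python
import json

# Hypothetical registry entry; keys mirror those read by _fetch_vm_registry above.
registry = [
    {
        "name": "azure-waa-vm",
        "ssh_host": "203.0.113.10",  # example IP only
        "ssh_user": "azureuser",
        "vnc_port": 8006,
        "waa_port": 5000,
        "docker_container": "winarena",
        "internal_ip": "localhost",
    }
]
print(json.dumps(registry, indent=2))  # content one could write to vm_registry.json
```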
@@ -1306,7 +2128,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
     vnc_url = f"http://{vm['ssh_host']}:{vm['vnc_port']}"
     result = subprocess.run(
         ["curl", "-I", "-s", "--connect-timeout", "3", vnc_url],
-        capture_output=True,
+        capture_output=True,
+        text=True,
+        timeout=5,
     )
     if result.returncode == 0 and "200" in result.stdout:
         vm["vnc_reachable"] = True
@@ -1320,13 +2144,25 @@ def cmd_serve(args: argparse.Namespace) -> int:
     waa_port = vm.get("waa_port", 5000)
     ssh_cmd = f"curl -s --connect-timeout 2 http://localhost:{waa_port}/probe 2>/dev/null"
     result = subprocess.run(
-        [
-
-
-
-
+        [
+            "ssh",
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "ConnectTimeout=3",
+            "-i",
+            str(Path.home() / ".ssh" / "id_rsa"),
+            f"{vm['ssh_user']}@{vm['ssh_host']}",
+            ssh_cmd,
+        ],
+        capture_output=True,
+        text=True,
+        timeout=5,
+    )
+    probe_success = (
+        result.returncode == 0
+        and "Service is operational" in result.stdout
     )
-    probe_success = result.returncode == 0 and "Service is operational" in result.stdout
     if probe_success:
         vm["waa_probe_status"] = "ready"
         vm["status"] = "online"
@@ -1338,7 +2174,11 @@ def cmd_serve(args: argparse.Namespace) -> int:
         ssh_user=vm.get("ssh_user", "azureuser"),
     )
     vm["tunnels"] = {
-        name: {
+        name: {
+            "active": s.active,
+            "local_port": s.local_port,
+            "error": s.error,
+        }
         for name, s in tunnel_status.items()
     }
 except Exception as e:
@@ -1384,12 +2224,22 @@ def cmd_serve(args: argparse.Namespace) -> int:
 # First get VM IP
 try:
     ip_result = subprocess.run(
-        [
-
-
-
-
-
+        [
+            "az",
+            "vm",
+            "list-ip-addresses",
+            "--name",
+            "waa-eval-vm",
+            "--resource-group",
+            "openadapt-agents",
+            "--query",
+            "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+            "-o",
+            "tsv",
+        ],
+        capture_output=True,
+        text=True,
+        timeout=10,
     )
     if ip_result.returncode == 0 and ip_result.stdout.strip():
         vm_ip = ip_result.stdout.strip()
@@ -1398,11 +2248,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
 # Try to probe WAA server via SSH
 # Use the correct internal IP for the Windows VM inside Docker
 probe_result = subprocess.run(
-    [
-
-
-
-
+    [
+        "ssh",
+        "-o",
+        "StrictHostKeyChecking=no",
+        "-o",
+        "ConnectTimeout=5",
+        "-i",
+        str(Path.home() / ".ssh" / "id_rsa"),
+        f"azureuser@{vm_ip}",
+        "docker exec waa-container curl -s --connect-timeout 3 http://172.30.0.2:5000/probe 2>/dev/null || echo 'probe_failed'",
+    ],
+    capture_output=True,
+    text=True,
+    timeout=15,
 )

 result["container"] = "waa-container"
@@ -1415,7 +2274,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
         else:
             result["probe_result"] = "WAA server not responding"
     else:
-        result["probe_result"] =
+        result["probe_result"] = (
+            f"SSH/Docker error: {probe_result.stderr[:200]}"
+        )
 else:
     result["probe_result"] = "Could not get VM IP"
 except subprocess.TimeoutExpired:
@@ -1443,7 +2304,6 @@ def cmd_serve(args: argparse.Namespace) -> int:
         - elapsed_minutes: int
     """
     import subprocess
-    from datetime import datetime
     import re

     result = {
@@ -1465,8 +2325,12 @@ def cmd_serve(args: argparse.Namespace) -> int:
         result["running"] = True
         result["type"] = "local"
         result["model"] = progress.get("provider", "unknown")
-        result["progress"]["tasks_completed"] = progress.get(
-
+        result["progress"]["tasks_completed"] = progress.get(
+            "tasks_complete", 0
+        )
+        result["progress"]["total_tasks"] = progress.get(
+            "tasks_total", 0
+        )
         return result
     except Exception:
         pass
@@ -1475,12 +2339,22 @@ def cmd_serve(args: argparse.Namespace) -> int:
 try:
     # Get VM IP
     ip_result = subprocess.run(
-        [
-
-
-
-
-
+        [
+            "az",
+            "vm",
+            "list-ip-addresses",
+            "--name",
+            "waa-eval-vm",
+            "--resource-group",
+            "openadapt-agents",
+            "--query",
+            "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+            "-o",
+            "tsv",
+        ],
+        capture_output=True,
+        text=True,
+        timeout=10,
    )

    if ip_result.returncode == 0 and ip_result.stdout.strip():
@@ -1488,42 +2362,73 @@ def cmd_serve(args: argparse.Namespace) -> int:

 # Check if benchmark process is running
 process_check = subprocess.run(
-    [
-
-
-
-
+    [
+        "ssh",
+        "-o",
+        "StrictHostKeyChecking=no",
+        "-o",
+        "ConnectTimeout=5",
+        "-i",
+        str(Path.home() / ".ssh" / "id_rsa"),
+        f"azureuser@{vm_ip}",
+        "docker exec waa-container pgrep -f 'python.*run.py' 2>/dev/null && echo 'RUNNING' || echo 'NOT_RUNNING'",
+    ],
+    capture_output=True,
+    text=True,
+    timeout=10,
 )

-if
+if (
+    process_check.returncode == 0
+    and "RUNNING" in process_check.stdout
+):
     result["running"] = True
     result["type"] = "azure_vm"

     # Get log file for more details
     log_check = subprocess.run(
-        [
-
-
-
-
+        [
+            "ssh",
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "ConnectTimeout=5",
+            "-i",
+            str(Path.home() / ".ssh" / "id_rsa"),
+            f"azureuser@{vm_ip}",
+            "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
+        ],
+        capture_output=True,
+        text=True,
+        timeout=10,
     )

     if log_check.returncode == 0 and log_check.stdout.strip():
         logs = log_check.stdout

         # Parse model from logs
-        model_match = re.search(
+        model_match = re.search(
+            r"model[=:\s]+([^\s,]+)", logs, re.IGNORECASE
+        )
         if model_match:
             result["model"] = model_match.group(1)

         # Parse progress
-        task_match = re.search(r
+        task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
         if task_match:
-            result["progress"]["tasks_completed"] = int(
-
+            result["progress"]["tasks_completed"] = int(
+                task_match.group(1)
+            )
+            result["progress"]["total_tasks"] = int(
+                task_match.group(2)
+            )

         # Parse current task
-        task_id_match = re.search(
+        task_id_match = re.search(
+            r"(?:Running|Processing|task)[:\s]+([a-f0-9-]+)",
+            logs,
+            re.IGNORECASE,
+        )
         if task_id_match:
             result["current_task"] = task_id_match.group(1)

@@ -1532,7 +2437,284 @@ def cmd_serve(args: argparse.Namespace) -> int:

     return result

-
+def _get_benchmark_status(self) -> dict:
+    """Get current benchmark job status with ETA calculation.
+
+    Returns:
+        dict with job status, progress, ETA, and current task info
+    """
+    import time
+
+    # Check for live evaluation state
+    live_file = Path("benchmark_live.json")
+    if live_file.exists():
+        try:
+            live_state = json.loads(live_file.read_text())
+            if live_state.get("status") == "running":
+                total_tasks = live_state.get("total_tasks", 0)
+                completed_tasks = live_state.get("tasks_completed", 0)
+                current_task = live_state.get("current_task", {})
+
+                # Calculate ETA based on completed tasks
+                eta_seconds = None
+                avg_task_seconds = None
+                if completed_tasks > 0 and total_tasks > 0:
+                    # Estimate from live state timestamp or use fallback
+                    elapsed = time.time() - live_state.get(
+                        "start_time", time.time()
+                    )
+                    avg_task_seconds = (
+                        elapsed / completed_tasks
+                        if completed_tasks > 0
+                        else 30.0
+                    )
+                    remaining_tasks = total_tasks - completed_tasks
+                    eta_seconds = remaining_tasks * avg_task_seconds
+
+                return {
+                    "status": "running",
+                    "current_job": {
+                        "run_id": live_state.get("run_id", "unknown"),
+                        "model_id": live_state.get("model_id", "unknown"),
+                        "total_tasks": total_tasks,
+                        "completed_tasks": completed_tasks,
+                        "current_task": current_task,
+                        "eta_seconds": eta_seconds,
+                        "avg_task_seconds": avg_task_seconds,
+                    },
+                    "queue": [],  # TODO: implement queue tracking
+                }
+        except Exception as e:
+            return {"status": "error", "error": str(e)}
+
+    # Fallback to current_run check
+    current_run = self._get_current_run()
+    if current_run.get("running"):
+        return {
+            "status": "running",
+            "current_job": {
+                "run_id": "unknown",
+                "model_id": current_run.get("model", "unknown"),
+                "total_tasks": current_run["progress"]["total_tasks"],
+                "completed_tasks": current_run["progress"]["tasks_completed"],
+                "current_task": {"task_id": current_run.get("current_task")},
+            },
+            "queue": [],
+        }
+
+    return {"status": "idle"}
+
+def _get_benchmark_costs(self) -> dict:
+    """Get cost breakdown for current benchmark run.
+
+    Returns:
+        dict with Azure VM, API calls, and GPU costs
+    """
+
+    # Check for cost tracking file
+    cost_file = Path("benchmark_costs.json")
+    if cost_file.exists():
+        try:
+            return json.loads(cost_file.read_text())
+        except Exception:
+            pass
+
+    # Return placeholder structure
+    return {
+        "azure_vm": {
+            "instance_type": "Standard_D4ds_v5",
+            "hourly_rate_usd": 0.192,
+            "hours_elapsed": 0.0,
+            "cost_usd": 0.0,
+        },
+        "api_calls": {
+            "anthropic": {"cost_usd": 0.0},
+            "openai": {"cost_usd": 0.0},
+        },
+        "gpu_time": {
+            "lambda_labs": {"cost_usd": 0.0},
+        },
+        "total_cost_usd": 0.0,
+    }
+
+def _get_benchmark_metrics(self) -> dict:
+    """Get performance metrics for current/completed benchmarks.
+
+    Returns:
+        dict with success rate trends, domain breakdown, episode metrics
+    """
+    # Check for metrics file
+    metrics_file = Path("benchmark_metrics.json")
+    if metrics_file.exists():
+        try:
+            return json.loads(metrics_file.read_text())
+        except Exception:
+            pass
+
+    # Load completed runs from benchmark_results/
+    benchmark_results_dir = Path("benchmark_results")
+    if not benchmark_results_dir.exists():
+        return {"error": "No benchmark results found"}
+
+    # Find most recent run
+    runs = sorted(
+        benchmark_results_dir.iterdir(),
+        key=lambda p: p.stat().st_mtime,
+        reverse=True,
+    )
+    if not runs:
+        return {"error": "No benchmark runs found"}
+
+    recent_run = runs[0]
+    summary_path = recent_run / "summary.json"
+    if not summary_path.exists():
+        return {"error": f"No summary.json in {recent_run.name}"}
+
+    try:
+        summary = json.loads(summary_path.read_text())
+
+        # Build domain breakdown from tasks
+        domain_breakdown = {}
+        tasks_dir = recent_run / "tasks"
+        if tasks_dir.exists():
+            for task_dir in tasks_dir.iterdir():
+                if not task_dir.is_dir():
+                    continue
+
+                task_json = task_dir / "task.json"
+                execution_json = task_dir / "execution.json"
+                if not (task_json.exists() and execution_json.exists()):
+                    continue
+
+                try:
+                    task_def = json.loads(task_json.read_text())
+                    execution = json.loads(execution_json.read_text())
+
+                    domain = task_def.get("domain", "unknown")
+                    if domain not in domain_breakdown:
+                        domain_breakdown[domain] = {
+                            "total": 0,
+                            "success": 0,
+                            "rate": 0.0,
+                            "avg_steps": 0.0,
+                            "total_steps": 0,
+                        }
+
+                    domain_breakdown[domain]["total"] += 1
+                    if execution.get("success"):
+                        domain_breakdown[domain]["success"] += 1
+                    domain_breakdown[domain]["total_steps"] += execution.get(
+                        "num_steps", 0
+                    )
+
+                except Exception:
+                    continue
+
+        # Calculate averages
+        for domain, stats in domain_breakdown.items():
+            if stats["total"] > 0:
+                stats["rate"] = stats["success"] / stats["total"]
+                stats["avg_steps"] = stats["total_steps"] / stats["total"]
+
+        return {
+            "success_rate_over_time": [],  # TODO: implement trend tracking
+            "avg_steps_per_task": [],  # TODO: implement trend tracking
+            "domain_breakdown": domain_breakdown,
+            "episode_success_metrics": {
+                "first_action_accuracy": summary.get(
+                    "first_action_accuracy", 0.0
+                ),
+                "episode_success_rate": summary.get("success_rate", 0.0),
+                "avg_steps_to_success": summary.get("avg_steps", 0.0),
+                "avg_steps_to_failure": 0.0,  # TODO: calculate from failed tasks
+            },
+        }
+    except Exception as e:
+        return {"error": f"Failed to load metrics: {str(e)}"}
+
+def _get_benchmark_workers(self) -> dict:
+    """Get worker status and utilization.
+
+    Returns:
+        dict with total/active/idle workers and per-worker stats
+    """
+    # Get VM registry
+    vms = self._fetch_vm_registry()
+
+    active_workers = [v for v in vms if v.get("status") == "online"]
+    idle_workers = [v for v in vms if v.get("status") != "online"]
+
+    workers = []
+    for vm in vms:
+        workers.append(
+            {
+                "worker_id": vm.get("name", "unknown"),
+                "status": "running" if vm.get("status") == "online" else "idle",
+                "current_task": vm.get("current_task"),
+                "tasks_completed": vm.get("tasks_completed", 0),
+                "uptime_seconds": vm.get("uptime_seconds", 0),
+                "idle_time_seconds": vm.get("idle_time_seconds", 0),
+            }
+        )
+
+    return {
+        "total_workers": len(vms),
+        "active_workers": len(active_workers),
+        "idle_workers": len(idle_workers),
+        "workers": workers,
+    }
+
+def _get_benchmark_runs(self) -> list[dict]:
+    """Load all benchmark runs from benchmark_results directory.
+
+    Returns:
+        List of benchmark run summaries sorted by timestamp (newest first)
+    """
+    results_dir = Path("benchmark_results")
+    if not results_dir.exists():
+        return []
+
+    runs = []
+    for run_dir in results_dir.iterdir():
+        if run_dir.is_dir():
+            summary_file = run_dir / "summary.json"
+            if summary_file.exists():
+                try:
+                    summary = json.loads(summary_file.read_text())
+                    runs.append(summary)
+                except (json.JSONDecodeError, IOError) as e:
+                    print(f"Warning: Failed to load {summary_file}: {e}")
+
+    # Sort by run_name descending (newest first)
+    runs.sort(key=lambda r: r.get("run_name", ""), reverse=True)
+    return runs
+
+def _get_task_execution(self, run_name: str, task_id: str) -> dict:
+    """Load task execution details from execution.json.
+
+    Args:
+        run_name: Name of the benchmark run
+        task_id: Task identifier
+
+    Returns:
+        Task execution data with steps and screenshots
+    """
+    results_dir = Path("benchmark_results")
+    execution_file = (
+        results_dir / run_name / "tasks" / task_id / "execution.json"
+    )
+
+    if not execution_file.exists():
+        raise FileNotFoundError(f"Execution file not found: {execution_file}")
+
+    try:
+        return json.loads(execution_file.read_text())
+    except (json.JSONDecodeError, IOError) as e:
+        raise Exception(f"Failed to load execution data: {e}")
+
+async def _detect_running_benchmark(
+    self, vm_ip: str, container_name: str = "winarena"
+) -> dict:
     """Detect if a benchmark is running on the VM and extract progress.

     SSH into VM and check:
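The ETA added in `_get_benchmark_status` is a linear extrapolation from the average task duration observed so far. The same arithmetic in isolation, with illustrative numbers:

```python
# Illustrative numbers: 8 of 30 tasks done, started 1200 s ago.
total_tasks, completed_tasks = 30, 8
elapsed = 1200.0                                      # seconds since start_time
avg_task_seconds = elapsed / completed_tasks          # 150 s per task
eta_seconds = (total_tasks - completed_tasks) * avg_task_seconds  # 22 * 150 = 3300 s
print(f"ETA ~{eta_seconds / 60:.0f} min")             # -> ETA ~55 min
```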
@@ -1563,11 +2745,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
 try:
     # Check if benchmark process is running
     process_check = subprocess.run(
-        [
-
-
-
-
+        [
+            "ssh",
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "ConnectTimeout=5",
+            "-i",
+            str(Path.home() / ".ssh" / "id_rsa"),
+            f"azureuser@{vm_ip}",
+            f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''",
+        ],
+        capture_output=True,
+        text=True,
+        timeout=10,
     )

     if process_check.returncode == 0 and process_check.stdout.strip():
@@ -1575,11 +2766,20 @@ def cmd_serve(args: argparse.Namespace) -> int:

     # Get benchmark log
     log_check = subprocess.run(
-        [
-
-
-
-
+        [
+            "ssh",
+            "-o",
+            "StrictHostKeyChecking=no",
+            "-o",
+            "ConnectTimeout=5",
+            "-i",
+            str(Path.home() / ".ssh" / "id_rsa"),
+            f"azureuser@{vm_ip}",
+            "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
+        ],
+        capture_output=True,
+        text=True,
+        timeout=10,
     )

     if log_check.returncode == 0 and log_check.stdout.strip():
@@ -1588,22 +2788,28 @@ def cmd_serve(args: argparse.Namespace) -> int:

     # Parse progress from logs
     # Look for patterns like "Task 5/30" or "Completed: 5, Remaining: 25"
-    task_match = re.search(r
+    task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
     if task_match:
-        result["progress"]["tasks_completed"] = int(
+        result["progress"]["tasks_completed"] = int(
+            task_match.group(1)
+        )
         result["progress"]["total_tasks"] = int(task_match.group(2))

     # Extract current task ID
-    task_id_match = re.search(
+    task_id_match = re.search(
+        r"(?:Running|Processing) task:\s*(\S+)", logs
+    )
     if task_id_match:
         result["current_task"] = task_id_match.group(1)

     # Extract step info
-    step_match = re.search(r
+    step_match = re.search(r"Step\s+(\d+)", logs)
     if step_match:
-        result["progress"]["current_step"] = int(
+        result["progress"]["current_step"] = int(
+            step_match.group(1)
+        )

-except Exception
+except Exception:
     # SSH or parsing failed - leave defaults
     pass

@@ -1627,13 +2833,13 @@ def cmd_serve(args: argparse.Namespace) -> int:
 # Search backwards from most recent
 for line in reversed(log_lines):
     # Check for explicit result
-    if
+    if "Result: PASS" in line or "completed successfully" in line:
         success = True
-    elif
+    elif "Result: FAIL" in line or "failed" in line.lower():
         success = False

     # Check for score
-    score_match = re.search(r
+    score_match = re.search(r"Score:\s*([\d.]+)", line)
     if score_match:
         try:
             score = float(score_match.group(1))
@@ -1642,9 +2848,9 @@ def cmd_serve(args: argparse.Namespace) -> int:

     # Check for task-specific completion
     if task_id in line:
-        if
+        if "success" in line.lower() or "pass" in line.lower():
             success = True
-        elif
+        elif "fail" in line.lower() or "error" in line.lower():
             success = False

 # Default to True if no explicit failure found (backwards compatible)
@@ -1674,11 +2880,11 @@ def cmd_serve(args: argparse.Namespace) -> int:

 # Set SSE headers
 self.send_response(200)
-self.send_header(
-self.send_header(
-self.send_header(
-self.send_header(
-self.send_header(
+self.send_header("Content-Type", "text/event-stream")
+self.send_header("Cache-Control", "no-cache")
+self.send_header("Access-Control-Allow-Origin", "*")
+self.send_header("Connection", "keep-alive")
+self.send_header("X-Accel-Buffering", "no")  # Disable nginx buffering
 self.end_headers()

 # Track connection state
@@ -1691,7 +2897,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
         return False
     try:
         event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
-        self.wfile.write(event_str.encode(
+        self.wfile.write(event_str.encode("utf-8"))
         self.wfile.flush()
         return True
     except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
@@ -1734,11 +2940,10 @@ def cmd_serve(args: argparse.Namespace) -> int:
 recent_log_lines = []

 # Send initial connected event
-if not send_event(
-    "
-    "interval": interval,
-
-}):
+if not send_event(
+    "connected",
+    {"timestamp": time.time(), "interval": interval, "version": "1.0"},
+):
     return

 try:
@@ -1763,17 +2968,25 @@ def cmd_serve(args: argparse.Namespace) -> int:
 tasks = self._fetch_background_tasks()

 # Send VM status event
-vm_task = next(
+vm_task = next(
+    (t for t in tasks if t.get("task_type") == "docker_container"),
+    None,
+)
 if vm_task:
     vm_data = {
         "type": "vm_status",
-        "connected": vm_task.get("status")
+        "connected": vm_task.get("status")
+        in ["running", "completed"],
         "phase": vm_task.get("phase", "unknown"),
-        "waa_ready": vm_task.get("metadata", {}).get(
+        "waa_ready": vm_task.get("metadata", {}).get(
+            "waa_server_ready", False
+        ),
         "probe": {
-            "status": vm_task.get("metadata", {}).get(
+            "status": vm_task.get("metadata", {}).get(
+                "probe_response", "unknown"
+            ),
             "vnc_url": vm_task.get("metadata", {}).get("vnc_url"),
-        }
+        },
     }

     if not send_event("status", vm_data):
@@ -1783,27 +2996,47 @@ def cmd_serve(args: argparse.Namespace) -> int:
 if vm_data["waa_ready"]:
     # Get VM IP from tasks
     vm_ip = None
-    azure_vm = next(
+    azure_vm = next(
+        (
+            t
+            for t in tasks
+            if t.get("task_type") == "vm_provision"
+        ),
+        None,
+    )
     if azure_vm:
         vm_ip = azure_vm.get("metadata", {}).get("ip_address")

     if vm_ip:
         # Detect running benchmark using sync version
         benchmark_status = self._detect_running_benchmark_sync(
-            vm_ip,
+            vm_ip,
+            vm_task.get("metadata", {}).get(
+                "container", "winarena"
+            ),
         )

         if benchmark_status["running"]:
             # Store log lines for result parsing
             if benchmark_status.get("recent_logs"):
-                recent_log_lines = benchmark_status[
+                recent_log_lines = benchmark_status[
+                    "recent_logs"
+                ].split("\n")

             # Send progress event
             progress_data = {
-                "tasks_completed": benchmark_status["progress"][
-
-
-                "
+                "tasks_completed": benchmark_status["progress"][
+                    "tasks_completed"
+                ],
+                "total_tasks": benchmark_status["progress"][
+                    "total_tasks"
+                ],
+                "current_task": benchmark_status[
+                    "current_task"
+                ],
+                "current_step": benchmark_status["progress"][
+                    "current_step"
+                ],
             }

             if not send_event("progress", progress_data):
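The `send_event` helper above emits standard SSE frames (an `event:` line and a `data:` line followed by a blank line). A minimal client-side parser for that frame format; the helper below is illustrative, only the wire format comes from the diff:

```python
import json

def parse_sse_frame(frame: str) -> tuple[str, dict]:
    """Split one 'event:/data:' SSE frame into (event_type, payload)."""
    event, data = "message", "{}"
    for line in frame.splitlines():
        if line.startswith("event: "):
            event = line[len("event: "):]
        elif line.startswith("data: "):
            data = line[len("data: "):]
    return event, json.loads(data)

# e.g. a "progress" frame shaped like the one produced above:
frame = 'event: progress\ndata: {"tasks_completed": 8, "total_tasks": 30}\n\n'
print(parse_sse_frame(frame))  # ('progress', {'tasks_completed': 8, 'total_tasks': 30})
```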
@@ -1814,13 +3047,17 @@ def cmd_serve(args: argparse.Namespace) -> int:
 if current_task and current_task != last_task:
     if last_task is not None:
         # Previous task completed - parse result from logs
-        result = self._parse_task_result(
+        result = self._parse_task_result(
+            recent_log_lines, last_task
+        )
         complete_data = {
             "task_id": last_task,
             "success": result["success"],
             "score": result["score"],
         }
-        if not send_event(
+        if not send_event(
+            "task_complete", complete_data
+        ):
             break

     last_task = current_task
@@ -1832,7 +3069,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
     progress = json.loads(progress_file.read_text())
     if progress.get("status") == "running":
         progress_data = {
-            "tasks_completed": progress.get(
+            "tasks_completed": progress.get(
+                "tasks_complete", 0
+            ),
             "total_tasks": progress.get("tasks_total", 0),
             "current_task": progress.get("provider", "unknown"),
         }
@@ -1858,7 +3097,244 @@ def cmd_serve(args: argparse.Namespace) -> int:
     # Cleanup - connection is ending
     client_connected = False

-def
+def _stream_azure_ops_updates(self):
+    """Stream Server-Sent Events for Azure operations status updates.
+
+    Monitors azure_ops_status.json for changes and streams updates.
+    Uses file modification time to detect changes efficiently.
+
+    Streams events:
+    - connected: Initial connection event
+    - status: Azure ops status update when file changes
+    - heartbeat: Keep-alive signal every 30 seconds
+    - error: Error messages
+    """
+    import time
+    import select
+    from pathlib import Path
+
+    HEARTBEAT_INTERVAL = 30  # seconds
+    CHECK_INTERVAL = 1  # Check file every second
+
+    # Set SSE headers
+    self.send_response(200)
+    self.send_header("Content-Type", "text/event-stream")
+    self.send_header("Cache-Control", "no-cache")
+    self.send_header("Access-Control-Allow-Origin", "*")
+    self.send_header("Connection", "keep-alive")
+    self.send_header("X-Accel-Buffering", "no")  # Disable nginx buffering
+    self.end_headers()
+
+    # Track connection state
+    client_connected = True
+    last_mtime = 0.0
+    last_session_mtime = 0.0
+    last_heartbeat = time.time()
+
+    def send_event(event_type: str, data: dict) -> bool:
+        """Send an SSE event. Returns False if client disconnected."""
+        nonlocal client_connected
+        if not client_connected:
+            return False
+        try:
+            event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+            self.wfile.write(event_str.encode("utf-8"))
+            self.wfile.flush()
+            return True
+        except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+            client_connected = False
+            return False
+        except Exception as e:
+            print(f"Azure ops SSE send error: {e}")
+            client_connected = False
+            return False
+
+    def check_client_connected() -> bool:
+        """Check if client is still connected using socket select."""
+        nonlocal client_connected
+        if not client_connected:
+            return False
+        try:
+            rlist, _, xlist = select.select([self.rfile], [], [self.rfile], 0)
+            if xlist:
+                client_connected = False
+                return False
+            if rlist:
+                data = self.rfile.read(1)
+                if not data:
+                    client_connected = False
+                    return False
+            return True
+        except Exception:
+            client_connected = False
+            return False
+
+    # Status file path
+    from openadapt_ml.benchmarks.azure_ops_tracker import (
+        DEFAULT_OUTPUT_FILE,
+        read_status,
+    )
+    from openadapt_ml.benchmarks.session_tracker import (
+        get_session,
+        update_session_vm_state,
+        DEFAULT_SESSION_FILE,
+    )
+
+    status_file = Path(DEFAULT_OUTPUT_FILE)
+    session_file = Path(DEFAULT_SESSION_FILE)
+
+    def compute_server_side_values(status: dict) -> dict:
+        """Get elapsed_seconds and cost_usd from session tracker for persistence."""
+        # Get session data (persistent across refreshes)
+        session = get_session()
+
+        # Update session based on VM state if we have VM info
+        # IMPORTANT: Only pass vm_ip if it's truthy to avoid
+        # overwriting session's stable vm_ip with None
+        if status.get("vm_state") and status.get("vm_state") != "unknown":
+            status_vm_ip = status.get("vm_ip")
+            # Build update kwargs - only include vm_ip if present
+            update_kwargs = {
+                "vm_state": status["vm_state"],
+                "vm_size": status.get("vm_size"),
+            }
+            if status_vm_ip:  # Only include if truthy
+                update_kwargs["vm_ip"] = status_vm_ip
+            session = update_session_vm_state(**update_kwargs)
+
+        # Use session's vm_ip as authoritative source
+        # This prevents IP flickering when status file has stale/None values
+        if session.get("vm_ip"):
+            status["vm_ip"] = session["vm_ip"]
+
+        # Use session's elapsed_seconds and cost_usd for persistence
+        if (
+            session.get("is_active")
+            or session.get("accumulated_seconds", 0) > 0
+        ):
+            status["elapsed_seconds"] = session.get("elapsed_seconds", 0.0)
+            status["cost_usd"] = session.get("cost_usd", 0.0)
+            status["started_at"] = session.get("started_at")
+            status["session_id"] = session.get("session_id")
+            status["session_is_active"] = session.get("is_active", False)
+            # Include accumulated time from previous sessions for hybrid display
+            status["accumulated_seconds"] = session.get(
+                "accumulated_seconds", 0.0
+            )
+            # Calculate current session time (total - accumulated)
+            current_session_seconds = max(
+                0, status["elapsed_seconds"] - status["accumulated_seconds"]
+            )
+            status["current_session_seconds"] = current_session_seconds
+            hourly_rate = session.get("hourly_rate_usd", 0.422)
+            status["current_session_cost_usd"] = (
+                current_session_seconds / 3600
+            ) * hourly_rate
+
+        try:
+            tunnel_mgr = get_tunnel_manager()
+            tunnel_status = tunnel_mgr.get_tunnel_status()
+            status["tunnels"] = {
+                name: {
+                    "active": s.active,
+                    "local_port": s.local_port,
+                    "remote_endpoint": s.remote_endpoint,
+                    "pid": s.pid,
+                    "error": s.error,
+                }
+                for name, s in tunnel_status.items()
+            }
+        except Exception as e:
+            status["tunnels"] = {"error": str(e)}
+
+        return status
+
+    # Send initial connected event
+    if not send_event(
+        "connected",
+        {"timestamp": time.time(), "version": "1.0"},
+    ):
+        return
+
+    # Send initial status immediately
+    try:
+        status = compute_server_side_values(read_status())
+        if not send_event("status", status):
+            return
+        if status_file.exists():
+            last_mtime = status_file.stat().st_mtime
+    except Exception as e:
+        send_event("error", {"message": str(e)})
+
+    try:
+        iteration_count = 0
+        max_iterations = 3600  # Max 1 hour of streaming
+        last_status_send = 0.0
+        STATUS_SEND_INTERVAL = 2  # Send status every 2 seconds for live updates
|
|
3274
|
+
|
|
3275
|
+
while client_connected and iteration_count < max_iterations:
|
|
3276
|
+
iteration_count += 1
|
|
3277
|
+
current_time = time.time()
|
|
3278
|
+
|
|
3279
|
+
# Check client connection
|
|
3280
|
+
if not check_client_connected():
|
|
3281
|
+
break
|
|
3282
|
+
|
|
3283
|
+
# Send heartbeat every 30 seconds
|
|
3284
|
+
if current_time - last_heartbeat >= HEARTBEAT_INTERVAL:
|
|
3285
|
+
if not send_event("heartbeat", {"timestamp": current_time}):
|
|
3286
|
+
break
|
|
3287
|
+
last_heartbeat = current_time
|
|
3288
|
+
|
|
3289
|
+
# Check if status or session file changed OR if enough time passed
|
|
3290
|
+
try:
|
|
3291
|
+
status_changed = False
|
|
3292
|
+
session_changed = False
|
|
3293
|
+
time_to_send = (
|
|
3294
|
+
current_time - last_status_send >= STATUS_SEND_INTERVAL
|
|
3295
|
+
)
|
|
3296
|
+
|
|
3297
|
+
if status_file.exists():
|
|
3298
|
+
current_mtime = status_file.stat().st_mtime
|
|
3299
|
+
if current_mtime > last_mtime:
|
|
3300
|
+
status_changed = True
|
|
3301
|
+
last_mtime = current_mtime
|
|
3302
|
+
|
|
3303
|
+
if session_file.exists():
|
|
3304
|
+
current_session_mtime = session_file.stat().st_mtime
|
|
3305
|
+
if current_session_mtime > last_session_mtime:
|
|
3306
|
+
session_changed = True
|
|
3307
|
+
last_session_mtime = current_session_mtime
|
|
3308
|
+
|
|
3309
|
+
# Send status if file changed OR periodic timer expired
|
|
3310
|
+
# This ensures live elapsed time/cost updates even without file changes
|
|
3311
|
+
if status_changed or session_changed or time_to_send:
|
|
3312
|
+
# File changed or time to send - send update with session values
|
|
3313
|
+
status = compute_server_side_values(read_status())
|
|
3314
|
+
if not send_event("status", status):
|
|
3315
|
+
break
|
|
3316
|
+
last_status_send = current_time
|
|
3317
|
+
except Exception as e:
|
|
3318
|
+
# File access error - log but continue
|
|
3319
|
+
print(f"Azure ops SSE file check error: {e}")
|
|
3320
|
+
|
|
3321
|
+
# Sleep briefly before next check
|
|
3322
|
+
try:
|
|
3323
|
+
select.select([self.rfile], [], [], CHECK_INTERVAL)
|
|
3324
|
+
except Exception:
|
|
3325
|
+
break
|
|
3326
|
+
|
|
3327
|
+
except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
|
|
3328
|
+
# Client disconnected - normal
|
|
3329
|
+
pass
|
|
3330
|
+
except Exception as e:
|
|
3331
|
+
send_event("error", {"message": str(e)})
|
|
3332
|
+
finally:
|
|
3333
|
+
client_connected = False
|
|
3334
|
+
|
|
3335
|
+
def _detect_running_benchmark_sync(
|
|
3336
|
+
self, vm_ip: str, container_name: str = "winarena"
|
|
3337
|
+
) -> dict:
|
|
1862
3338
|
"""Synchronous version of _detect_running_benchmark.
|
|
1863
3339
|
|
|
1864
3340
|
Avoids creating a new event loop on each call which causes issues
|
|
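The four event types declared in the stream handler's docstring (connected, status, heartbeat, error) can be consumed with the standard library alone. A minimal sketch, assuming the stream is exposed at /api/azure-ops/stream on the default serve port 8765 (the path is an assumption for illustration; the status keys come from compute_server_side_values above):

import json
import urllib.request

def follow_azure_ops(url: str = "http://localhost:8765/api/azure-ops/stream") -> None:
    """Print status updates from the SSE stream until the server closes it."""
    with urllib.request.urlopen(url) as resp:
        event_type = None
        for raw in resp:
            line = raw.decode("utf-8").rstrip("\n")
            if line.startswith("event: "):
                event_type = line[len("event: "):]
            elif line.startswith("data: ") and event_type == "status":
                status = json.loads(line[len("data: "):])
                print(status.get("vm_state"), status.get("elapsed_seconds"), status.get("cost_usd"))
            elif not line:
                event_type = None  # a blank line terminates one SSE event

Each SSE event is an "event:" line followed by a "data:" line and a blank line, which is why the parser resets event_type on empty lines.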
@@ -1881,11 +3357,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
     try:
         # Check if benchmark process is running
         process_check = subprocess.run(
-            [
+            [
+                "ssh",
+                "-o",
+                "StrictHostKeyChecking=no",
+                "-o",
+                "ConnectTimeout=5",
+                "-i",
+                str(Path.home() / ".ssh" / "id_rsa"),
+                f"azureuser@{vm_ip}",
+                f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''",
+            ],
+            capture_output=True,
+            text=True,
+            timeout=10,
         )

         if process_check.returncode == 0 and process_check.stdout.strip():
@@ -1893,11 +3378,20 @@ def cmd_serve(args: argparse.Namespace) -> int:

             # Get benchmark log
             log_check = subprocess.run(
-                [
+                [
+                    "ssh",
+                    "-o",
+                    "StrictHostKeyChecking=no",
+                    "-o",
+                    "ConnectTimeout=5",
+                    "-i",
+                    str(Path.home() / ".ssh" / "id_rsa"),
+                    f"azureuser@{vm_ip}",
+                    "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
+                ],
+                capture_output=True,
+                text=True,
+                timeout=10,
             )

             if log_check.returncode == 0 and log_check.stdout.strip():
@@ -1905,20 +3399,26 @@ def cmd_serve(args: argparse.Namespace) -> int:
                 result["recent_logs"] = logs[-500:]  # Last 500 chars

                 # Parse progress from logs
-                task_match = re.search(r
+                task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
                 if task_match:
-                    result["progress"]["tasks_completed"] = int(
+                    result["progress"]["tasks_completed"] = int(
+                        task_match.group(1)
+                    )
                     result["progress"]["total_tasks"] = int(task_match.group(2))

                 # Extract current task ID
-                task_id_match = re.search(
+                task_id_match = re.search(
+                    r"(?:Running|Processing) task:\s*(\S+)", logs
+                )
                 if task_id_match:
                     result["current_task"] = task_id_match.group(1)

                 # Extract step info
-                step_match = re.search(r
+                step_match = re.search(r"Step\s+(\d+)", logs)
                 if step_match:
-                    result["progress"]["current_step"] = int(
+                    result["progress"]["current_step"] = int(
+                        step_match.group(1)
+                    )

     except Exception:
         # SSH or parsing failed - leave defaults
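The three patterns above are plain re.search calls over the tail of the benchmark log. A quick self-contained check against an invented log excerpt (the log text is made up for illustration; the patterns are the ones from this hunk):

import re

logs = "Running task: notepad_basic\nTask 3/10\nStep 12: clicking Save"
task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
task_id_match = re.search(r"(?:Running|Processing) task:\s*(\S+)", logs)
step_match = re.search(r"Step\s+(\d+)", logs)
print(task_match.group(1), task_match.group(2))  # 3 10
print(task_id_match.group(1))                    # notepad_basic
print(step_match.group(1))                       # 12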
@@ -1949,7 +3449,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
         "vnc_port": vm_data.get("vnc_port", 8006),
         "waa_port": vm_data.get("waa_port", 5000),
         "docker_container": vm_data.get("docker_container", "win11-waa"),
-        "internal_ip": vm_data.get("internal_ip", "20.20.20.21")
+        "internal_ip": vm_data.get("internal_ip", "20.20.20.21"),
     }

     vms.append(new_vm)
@@ -1957,7 +3457,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
     # Save registry
     try:
         registry_file.parent.mkdir(parents=True, exist_ok=True)
-        with open(registry_file,
+        with open(registry_file, "w") as f:
             json.dump(vms, f, indent=2)
         return {"status": "success", "vm": new_vm}
     except Exception as e:
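The registry write above produces a JSON list of VM entries. A minimal sketch of the same write using only the defaults visible in this hunk; the registry file path and the list contents are assumptions for the example:

import json
from pathlib import Path

new_vm = {
    "vnc_port": 8006,                 # default VNC port from this hunk
    "waa_port": 5000,                 # default WAA API port
    "docker_container": "win11-waa",
    "internal_ip": "20.20.20.21",
}
registry_file = Path.home() / ".openadapt" / "vm_registry.json"  # assumed path
registry_file.parent.mkdir(parents=True, exist_ok=True)
with open(registry_file, "w") as f:
    json.dump([new_vm], f, indent=2)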
@@ -1990,12 +3490,20 @@ def cmd_serve(args: argparse.Namespace) -> int:

     # Build CLI command
     cmd = [
-        "uv",
+        "uv",
+        "run",
+        "python",
+        "-m",
+        "openadapt_ml.benchmarks.cli",
+        "vm",
+        "run-waa",
+        "--num-tasks",
+        str(params.get("num_tasks", 5)),
+        "--model",
+        params.get("model", "gpt-4o"),
+        "--agent",
+        params.get("agent", "navi"),
+        "--no-open",  # Don't open viewer (already open)
     ]

     # Add task selection args
@@ -2016,7 +3524,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
     num_tasks = params.get("num_tasks", 5)
     agent = params.get("agent", "navi")

-    print(
+    print(
+        f"\n[Benchmark] Starting WAA benchmark: model={model}, tasks={num_tasks}, agent={agent}"
+    )
     print(f"[Benchmark] Task selection: {task_selection}")
     if task_selection == "domain":
         print(f"[Benchmark] Domain: {params.get('domain', 'general')}")
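With the defaults used above (num_tasks=5, model=gpt-4o, agent=navi), the assembled command that later appears in the "[Benchmark] Command:" log line is:

uv run python -m openadapt_ml.benchmarks.cli vm run-waa --num-tasks 5 --model gpt-4o --agent navi --no-open

Task-selection arguments are appended after this base command, as the surrounding hunks show.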
@@ -2024,15 +3534,19 @@ def cmd_serve(args: argparse.Namespace) -> int:
         print(f"[Benchmark] Task IDs: {params.get('task_ids', [])}")
     print(f"[Benchmark] Command: {' '.join(cmd)}")

-    progress_file.write_text(
+    progress_file.write_text(
+        json.dumps(
+            {
+                "status": "running",
+                "model": model,
+                "num_tasks": num_tasks,
+                "agent": agent,
+                "task_selection": task_selection,
+                "tasks_complete": 0,
+                "message": f"Starting {model} benchmark with {num_tasks} tasks...",
+            }
+        )
+    )

     # Copy environment with loaded vars
     env = os.environ.copy()
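The progress file written above is plain JSON, so any dashboard or script can poll it. A minimal polling sketch; the file name "benchmark_progress.json" is an assumption for the example (the real path lives in progress_file in the surrounding code), and the terminal statuses "complete" and "error" are the ones written by the completion handler later in this file:

import json
import time
from pathlib import Path

progress_file = Path("benchmark_progress.json")  # assumed name for the example

while True:
    if progress_file.exists():
        progress = json.loads(progress_file.read_text())
        print(progress.get("status"), "-", progress.get("message"))
        if progress.get("status") in ("complete", "error"):
            break
    time.sleep(2)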
@@ -2040,11 +3554,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
     # Run in background thread
     def run():
         result = subprocess.run(
-            cmd,
-            capture_output=True,
-            text=True,
-            cwd=str(project_root),
-            env=env
+            cmd, capture_output=True, text=True, cwd=str(project_root), env=env
         )

         print(f"\n[Benchmark] Output:\n{result.stdout}")
@@ -2052,24 +3562,34 @@ def cmd_serve(args: argparse.Namespace) -> int:
             print(f"[Benchmark] Stderr: {result.stderr}")

         if result.returncode == 0:
-            print(
-            progress_file.write_text(
+            print("[Benchmark] Complete. Regenerating viewer...")
+            progress_file.write_text(
+                json.dumps(
+                    {
+                        "status": "complete",
+                        "model": model,
+                        "num_tasks": num_tasks,
+                        "message": "Benchmark complete. Refresh to see results.",
+                    }
+                )
+            )
             # Regenerate benchmark viewer
             _regenerate_benchmark_viewer_if_available(serve_dir)
         else:
-            error_msg =
+            error_msg = (
+                result.stderr[:200] if result.stderr else "Unknown error"
+            )
             print(f"[Benchmark] Failed: {error_msg}")
-            progress_file.write_text(
+            progress_file.write_text(
+                json.dumps(
+                    {
+                        "status": "error",
+                        "model": model,
+                        "num_tasks": num_tasks,
+                        "message": f"Benchmark failed: {error_msg}",
+                    }
+                )
+            )

     threading.Thread(target=run, daemon=True).start()

@@ -2078,9 +3598,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
         def do_OPTIONS(self):
             # Handle CORS preflight
             self.send_response(200)
-            self.send_header(
-            self.send_header(
-            self.send_header(
+            self.send_header("Access-Control-Allow-Origin", "*")
+            self.send_header("Access-Control-Allow-Methods", "POST, OPTIONS")
+            self.send_header("Access-Control-Allow-Headers", "Content-Type")
             self.end_headers()

     class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
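The do_OPTIONS handler above answers CORS preflight requests on any path with the same three Access-Control headers. A minimal check against a locally running server, assuming the default serve port 8765:

import urllib.request

req = urllib.request.Request("http://localhost:8765/", method="OPTIONS")
with urllib.request.urlopen(req) as resp:
    print(resp.status)                                        # 200
    print(resp.headers.get("Access-Control-Allow-Origin"))    # *
    print(resp.headers.get("Access-Control-Allow-Methods"))   # POST, OPTIONS
    print(resp.headers.get("Access-Control-Allow-Headers"))   # Content-Type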
@@ -2134,7 +3654,9 @@ def cmd_viewer(args: argparse.Namespace) -> int:
         state.learning_rate = data.get("learning_rate", 0)
         state.losses = data.get("losses", [])
         state.status = data.get("status", "completed")
-        state.elapsed_time = data.get(
+        state.elapsed_time = data.get(
+            "elapsed_time", 0.0
+        )  # Load elapsed time for completed training
         state.goal = data.get("goal", "")
         state.config_path = data.get("config_path", "")
         state.capture_path = data.get("capture_path", "")
@@ -2149,6 +3671,7 @@ def cmd_viewer(args: argparse.Namespace) -> int:
     if not state.model_name and state.config_path:
         try:
             import yaml
+
            # Try relative to project root first, then as absolute path
            project_root = Path(__file__).parent.parent.parent
            config_file = project_root / state.config_path
@@ -2173,14 +3696,16 @@ def cmd_viewer(args: argparse.Namespace) -> int:

        dashboard_html = generate_training_dashboard(state, config)
        (current_dir / "dashboard.html").write_text(dashboard_html)
-        print(
+        print(" Regenerated: dashboard.html")

    # Generate unified viewer using consolidated function
    viewer_path = generate_unified_viewer_from_output_dir(current_dir)
    if viewer_path:
        print(f"\nGenerated: {viewer_path}")
    else:
-        print(
+        print(
+            "\nNo comparison data found. Run comparison first or copy from capture directory."
+        )

    # Also regenerate benchmark viewer from latest benchmark results
    _regenerate_benchmark_viewer_if_available(current_dir)
@@ -2200,9 +3725,9 @@ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
        print(f"Error: Benchmark directory not found: {benchmark_dir}")
        return 1

-    print(f"\n{'='*50}")
+    print(f"\n{'=' * 50}")
    print("GENERATING BENCHMARK VIEWER")
-    print(f"{'='*50}")
+    print(f"{'=' * 50}")
    print(f"Benchmark dir: {benchmark_dir}")
    print()

@@ -2217,6 +3742,7 @@ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
    except Exception as e:
        print(f"Error generating benchmark viewer: {e}")
        import traceback
+
        traceback.print_exc()
        return 1

@@ -2233,16 +3759,19 @@ def cmd_compare(args: argparse.Namespace) -> int:
        print(f"Error: Checkpoint not found: {checkpoint}")
        return 1

-    print(f"\n{'='*50}")
+    print(f"\n{'=' * 50}")
    print("RUNNING COMPARISON")
-    print(f"{'='*50}")
+    print(f"{'=' * 50}")
    print(f"Capture: {capture_path}")
    print(f"Checkpoint: {checkpoint or 'None (capture only)'}")
    print()

    cmd = [
-        sys.executable,
-        "
+        sys.executable,
+        "-m",
+        "openadapt_ml.scripts.compare",
+        "--capture",
+        str(capture_path),
    ]

    if checkpoint:
@@ -2281,7 +3810,7 @@ Examples:

  # Run comparison
  uv run python -m openadapt_ml.cloud.local compare --capture ~/captures/my-workflow --checkpoint checkpoints/model
-"""
+""",
    )

    subparsers = parser.add_subparsers(dest="command", help="Command")
@@ -2293,9 +3822,15 @@ Examples:
    # train
    p_train = subparsers.add_parser("train", help="Run training locally")
    p_train.add_argument("--capture", required=True, help="Path to capture directory")
-    p_train.add_argument(
+    p_train.add_argument(
+        "--goal", help="Task goal (default: derived from capture name)"
+    )
+    p_train.add_argument(
+        "--config", help="Config file (default: auto-select based on device)"
+    )
+    p_train.add_argument(
+        "--open", action="store_true", help="Open dashboard in browser"
+    )
    p_train.set_defaults(func=cmd_train)

    # check
@@ -2306,11 +3841,21 @@ Examples:
    p_serve = subparsers.add_parser("serve", help="Start web server for dashboard")
    p_serve.add_argument("--port", type=int, default=8765, help="Port number")
    p_serve.add_argument("--open", action="store_true", help="Open in browser")
-    p_serve.add_argument(
-    p_serve.add_argument(
+    p_serve.add_argument(
+        "--quiet", "-q", action="store_true", help="Suppress request logging"
+    )
+    p_serve.add_argument(
+        "--no-regenerate",
+        action="store_true",
+        help="Skip regenerating dashboard/viewer (serve existing files)",
+    )
+    p_serve.add_argument(
+        "--benchmark",
+        help="Serve benchmark results directory instead of training output",
+    )
+    p_serve.add_argument(
+        "--start-page", help="Override default start page (e.g., benchmark.html)"
+    )
    p_serve.set_defaults(func=cmd_serve)

    # viewer
@@ -2319,9 +3864,15 @@ Examples:
    p_viewer.set_defaults(func=cmd_viewer)

    # benchmark_viewer
-    p_benchmark = subparsers.add_parser(
+    p_benchmark = subparsers.add_parser(
+        "benchmark-viewer", help="Generate benchmark viewer"
+    )
+    p_benchmark.add_argument(
+        "benchmark_dir", help="Path to benchmark results directory"
+    )
+    p_benchmark.add_argument(
+        "--open", action="store_true", help="Open viewer in browser"
+    )
    p_benchmark.set_defaults(func=cmd_benchmark_viewer)

    # compare