openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -115
  8. openadapt_ml/benchmarks/agent.py +265 -421
  9. openadapt_ml/benchmarks/azure.py +28 -19
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1722 -4847
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +22 -5
  14. openadapt_ml/benchmarks/vm_monitor.py +530 -29
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
  16. openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
  17. openadapt_ml/cloud/azure_inference.py +3 -5
  18. openadapt_ml/cloud/lambda_labs.py +722 -307
  19. openadapt_ml/cloud/local.py +2038 -487
  20. openadapt_ml/cloud/ssh_tunnel.py +68 -26
  21. openadapt_ml/datasets/next_action.py +40 -30
  22. openadapt_ml/evals/grounding.py +8 -3
  23. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  24. openadapt_ml/evals/trajectory_matching.py +41 -26
  25. openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
  26. openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
  27. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  28. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  29. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  30. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  31. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  32. openadapt_ml/experiments/waa_demo/runner.py +29 -14
  33. openadapt_ml/export/parquet.py +36 -24
  34. openadapt_ml/grounding/detector.py +18 -14
  35. openadapt_ml/ingest/__init__.py +8 -6
  36. openadapt_ml/ingest/capture.py +25 -22
  37. openadapt_ml/ingest/loader.py +7 -4
  38. openadapt_ml/ingest/synthetic.py +189 -100
  39. openadapt_ml/models/api_adapter.py +14 -4
  40. openadapt_ml/models/base_adapter.py +10 -2
  41. openadapt_ml/models/providers/__init__.py +288 -0
  42. openadapt_ml/models/providers/anthropic.py +266 -0
  43. openadapt_ml/models/providers/base.py +299 -0
  44. openadapt_ml/models/providers/google.py +376 -0
  45. openadapt_ml/models/providers/openai.py +342 -0
  46. openadapt_ml/models/qwen_vl.py +46 -19
  47. openadapt_ml/perception/__init__.py +35 -0
  48. openadapt_ml/perception/integration.py +399 -0
  49. openadapt_ml/retrieval/demo_retriever.py +50 -24
  50. openadapt_ml/retrieval/embeddings.py +9 -8
  51. openadapt_ml/retrieval/retriever.py +3 -1
  52. openadapt_ml/runtime/__init__.py +50 -0
  53. openadapt_ml/runtime/policy.py +18 -5
  54. openadapt_ml/runtime/safety_gate.py +471 -0
  55. openadapt_ml/schema/__init__.py +9 -0
  56. openadapt_ml/schema/converters.py +74 -27
  57. openadapt_ml/schema/episode.py +31 -18
  58. openadapt_ml/scripts/capture_screenshots.py +530 -0
  59. openadapt_ml/scripts/compare.py +85 -54
  60. openadapt_ml/scripts/demo_policy.py +4 -1
  61. openadapt_ml/scripts/eval_policy.py +15 -9
  62. openadapt_ml/scripts/make_gif.py +1 -1
  63. openadapt_ml/scripts/prepare_synthetic.py +3 -1
  64. openadapt_ml/scripts/train.py +21 -9
  65. openadapt_ml/segmentation/README.md +920 -0
  66. openadapt_ml/segmentation/__init__.py +97 -0
  67. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  68. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  69. openadapt_ml/segmentation/annotator.py +610 -0
  70. openadapt_ml/segmentation/cache.py +290 -0
  71. openadapt_ml/segmentation/cli.py +674 -0
  72. openadapt_ml/segmentation/deduplicator.py +656 -0
  73. openadapt_ml/segmentation/frame_describer.py +788 -0
  74. openadapt_ml/segmentation/pipeline.py +340 -0
  75. openadapt_ml/segmentation/schemas.py +622 -0
  76. openadapt_ml/segmentation/segment_extractor.py +634 -0
  77. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  78. openadapt_ml/training/benchmark_viewer.py +52 -41
  79. openadapt_ml/training/shared_ui.py +7 -7
  80. openadapt_ml/training/stub_provider.py +57 -35
  81. openadapt_ml/training/trainer.py +143 -86
  82. openadapt_ml/training/trl_trainer.py +70 -21
  83. openadapt_ml/training/viewer.py +323 -108
  84. openadapt_ml/training/viewer_components.py +180 -0
  85. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
  86. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  87. openadapt_ml/benchmarks/base.py +0 -366
  88. openadapt_ml/benchmarks/data_collection.py +0 -432
  89. openadapt_ml/benchmarks/live_tracker.py +0 -180
  90. openadapt_ml/benchmarks/runner.py +0 -418
  91. openadapt_ml/benchmarks/waa.py +0 -761
  92. openadapt_ml/benchmarks/waa_live.py +0 -619
  93. openadapt_ml-0.2.0.dist-info/RECORD +0 -86
  94. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  95. {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -27,7 +27,6 @@ import http.server
  import json
  import os
  import shutil
- import signal
  import socketserver
  import subprocess
  import sys
@@ -109,7 +108,10 @@ def _is_mock_benchmark(benchmark_dir: Path) -> bool:

  # Check for test runs (but allow waa-mock evaluations with real API models)
  # Only filter out purely synthetic test data directories
- if any(term in benchmark_dir.name.lower() for term in ["test_run", "test_cli", "quick_demo"]):
+ if any(
+ term in benchmark_dir.name.lower()
+ for term in ["test_run", "test_cli", "quick_demo"]
+ ):
  return True

  return False
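The reflowed predicate above is behavior-preserving; a minimal standalone sketch of the same name filter (the directory names below are made up for illustration):

```python
from pathlib import Path

def _is_test_dir(benchmark_dir: Path) -> bool:
    # Same predicate as in the hunk above, extracted for illustration.
    return any(
        term in benchmark_dir.name.lower()
        for term in ["test_run", "test_cli", "quick_demo"]
    )

print(_is_test_dir(Path("results/Test_Run_2024")))    # True  -> filtered out
print(_is_test_dir(Path("results/waa-mock-claude")))  # False -> kept
```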
@@ -193,6 +195,7 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
  except Exception as e:
  print(f" Could not regenerate benchmark viewer: {e}")
  import traceback
+
  traceback.print_exc()
  return False

@@ -201,6 +204,7 @@ def detect_device() -> str:
  """Detect available compute device."""
  try:
  import torch
+
  if torch.cuda.is_available():
  device_name = torch.cuda.get_device_name(0)
  return f"cuda ({device_name})"
@@ -251,10 +255,13 @@ def get_training_status() -> dict[str, Any]:
  # Find checkpoints
  checkpoints_dir = Path("checkpoints")
  if checkpoints_dir.exists():
- status["checkpoints"] = sorted([
- d.name for d in checkpoints_dir.iterdir()
- if d.is_dir() and (d / "adapter_config.json").exists()
- ])
+ status["checkpoints"] = sorted(
+ [
+ d.name
+ for d in checkpoints_dir.iterdir()
+ if d.is_dir() and (d / "adapter_config.json").exists()
+ ]
+ )

  return status

@@ -264,9 +271,9 @@ def cmd_status(args: argparse.Namespace) -> int:
  status = get_training_status()
  current_dir = get_current_output_dir()

- print(f"\n{'='*50}")
+ print(f"\n{'=' * 50}")
  print("LOCAL TRAINING STATUS")
- print(f"{'='*50}")
+ print(f"{'=' * 50}")
  print(f"Device: {status['device']}")
  print(f"Status: {'RUNNING' if status['running'] else 'IDLE'}")
  if status.get("job_id"):
@@ -274,7 +281,7 @@ def cmd_status(args: argparse.Namespace) -> int:
  print(f"Output: {current_dir}")

  if status.get("epoch"):
- print(f"\nProgress:")
+ print("\nProgress:")
  print(f" Epoch: {status['epoch']}")
  print(f" Step: {status['step']}")
  if status.get("loss"):
@@ -287,7 +294,9 @@ def cmd_status(args: argparse.Namespace) -> int:
  for cp in status["checkpoints"][-5:]: # Show last 5
  print(f" - {cp}")

- print(f"\nDashboard: {'✓' if status['has_dashboard'] else '✗'} {current_dir}/dashboard.html")
+ print(
+ f"\nDashboard: {'✓' if status['has_dashboard'] else '✗'} {current_dir}/dashboard.html"
+ )
  print(f"Viewer: {'✓' if status['has_viewer'] else '✗'} {current_dir}/viewer.html")
  print()

@@ -320,9 +329,9 @@ def cmd_train(args: argparse.Namespace) -> int:
  print(f"Error: Config not found: {config_path}")
  return 1

- print(f"\n{'='*50}")
+ print(f"\n{'=' * 50}")
  print("STARTING LOCAL TRAINING")
- print(f"{'='*50}")
+ print(f"{'=' * 50}")
  print(f"Capture: {capture_path}")
  print(f"Goal: {goal}")
  print(f"Config: {config}")
@@ -331,10 +340,15 @@ def cmd_train(args: argparse.Namespace) -> int:

  # Build command
  cmd = [
- sys.executable, "-m", "openadapt_ml.scripts.train",
- "--config", str(config_path),
- "--capture", str(capture_path),
- "--goal", goal,
+ sys.executable,
+ "-m",
+ "openadapt_ml.scripts.train",
+ "--config",
+ str(config_path),
+ "--capture",
+ str(capture_path),
+ "--goal",
+ goal,
  ]

  if args.open:
@@ -353,14 +367,16 @@ def cmd_check(args: argparse.Namespace) -> int:
  """Check training health and early stopping analysis."""
  status = get_training_status()

- print(f"\n{'='*50}")
+ print(f"\n{'=' * 50}")
  print("TRAINING HEALTH CHECK")
- print(f"{'='*50}")
+ print(f"{'=' * 50}")

  raw_losses = status.get("losses", [])
  if not raw_losses:
  print("No training data found.")
- print("Run training first with: uv run python -m openadapt_ml.cloud.local train --capture <path>")
+ print(
+ "Run training first with: uv run python -m openadapt_ml.cloud.local train --capture <path>"
+ )
  return 1

  # Extract loss values (handle both dict and float formats)
@@ -381,7 +397,7 @@ def cmd_check(args: argparse.Namespace) -> int:
  min_loss = min(losses)
  max_loss = max(losses)

- print(f"\nLoss progression:")
+ print("\nLoss progression:")
  print(f" First: {first_loss:.4f}")
  print(f" Last: {last_loss:.4f}")
  print(f" Min: {min_loss:.4f}")
@@ -392,9 +408,11 @@ def cmd_check(args: argparse.Namespace) -> int:
  if len(losses) >= 10:
  recent = losses[-10:]
  recent_avg = sum(recent) / len(recent)
- recent_std = (sum((x - recent_avg) ** 2 for x in recent) / len(recent)) ** 0.5
+ recent_std = (
+ sum((x - recent_avg) ** 2 for x in recent) / len(recent)
+ ) ** 0.5

- print(f"\nRecent stability (last 10 steps):")
+ print("\nRecent stability (last 10 steps):")
  print(f" Avg loss: {recent_avg:.4f}")
  print(f" Std dev: {recent_std:.4f}")

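The stability change above only rewraps the expression; it still computes the population standard deviation over the last ten loss values. A quick standalone check of the arithmetic (the loss values are made up):

```python
# Mean and population std dev over the last 10 steps, as in cmd_check above.
losses = [2.31, 1.94, 1.52, 1.18, 0.97, 0.84, 0.79, 0.76, 0.74, 0.73, 0.72, 0.71]

recent = losses[-10:]
recent_avg = sum(recent) / len(recent)
recent_std = (
    sum((x - recent_avg) ** 2 for x in recent) / len(recent)
) ** 0.5
print(f"avg={recent_avg:.4f} std={recent_std:.4f}")  # low std => training has plateaued
```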
@@ -413,14 +431,18 @@ def cmd_serve(args: argparse.Namespace) -> int:
  """Start local web server for dashboard.

  Automatically regenerates dashboard and viewer before serving to ensure
- the latest code and data are reflected.
+ the latest code and data are reflected. Also ensures the 'current' symlink
+ points to the most recent training run.
  """
- from openadapt_ml.training.trainer import regenerate_local_dashboard
+ from openadapt_ml.training.trainer import (
+ regenerate_local_dashboard,
+ update_current_symlink_to_latest,
+ )

  port = args.port

  # Determine what to serve: benchmark directory or training output
- if hasattr(args, 'benchmark') and args.benchmark:
+ if hasattr(args, "benchmark") and args.benchmark:
  serve_dir = Path(args.benchmark).expanduser().resolve()
  if not serve_dir.exists():
  print(f"Error: Benchmark directory not found: {serve_dir}")
@@ -430,7 +452,10 @@ def cmd_serve(args: argparse.Namespace) -> int:
  if not args.no_regenerate:
  print("Regenerating benchmark viewer...")
  try:
- from openadapt_ml.training.benchmark_viewer import generate_benchmark_viewer
+ from openadapt_ml.training.benchmark_viewer import (
+ generate_benchmark_viewer,
+ )
+
  generate_benchmark_viewer(serve_dir)
  except Exception as e:
  print(f"Warning: Could not regenerate benchmark viewer: {e}")
@@ -439,6 +464,17 @@ def cmd_serve(args: argparse.Namespace) -> int:
  else:
  serve_dir = get_current_output_dir()

+ # If current symlink doesn't exist or is broken, update to latest run
+ if not serve_dir.exists() or not serve_dir.is_dir():
+ print("Updating 'current' symlink to latest training run...")
+ latest = update_current_symlink_to_latest()
+ if latest:
+ serve_dir = get_current_output_dir()
+ print(f" Updated to: {latest.name}")
+ else:
+ print(f"Error: {serve_dir} not found. Run training first.")
+ return 1
+
  if not serve_dir.exists():
  print(f"Error: {serve_dir} not found. Run training first.")
  return 1
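This diff shows only the call site of `update_current_symlink_to_latest()`; the helper itself lives in `openadapt_ml.training.trainer` and its body is not part of this diff. A plausible sketch of what such a helper does, with the output directory layout and name `training_output/` assumed, not taken from the package:

```python
from pathlib import Path

def update_current_symlink_to_latest(base: Path = Path("training_output")) -> Path | None:
    # Hypothetical reimplementation: pick the most recently modified run
    # directory and repoint the 'current' symlink at it.
    if not base.exists():
        return None
    runs = sorted(
        (d for d in base.iterdir() if d.is_dir() and d.name != "current"),
        key=lambda d: d.stat().st_mtime,
    )
    if not runs:
        return None
    current = base / "current"
    if current.is_symlink() or current.exists():
        current.unlink()  # also removes a broken symlink
    current.symlink_to(runs[-1], target_is_directory=True)
    return runs[-1]
```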
@@ -447,7 +483,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
  if not args.no_regenerate:
  print("Regenerating dashboard and viewer...")
  try:
- regenerate_local_dashboard(str(serve_dir))
+ # Use keep_polling=True so JavaScript fetches live data from training_log.json
+ # This ensures the dashboard shows current data instead of stale embedded data
+ regenerate_local_dashboard(str(serve_dir), keep_polling=True)
  # Also regenerate viewer if comparison data exists
  _regenerate_viewer_if_possible(serve_dir)
  except Exception as e:
@@ -456,10 +494,21 @@ def cmd_serve(args: argparse.Namespace) -> int:
  # Also regenerate benchmark viewer from latest benchmark results
  _regenerate_benchmark_viewer_if_available(serve_dir)

+ # Generate Azure ops dashboard
+ try:
+ from openadapt_ml.training.azure_ops_viewer import (
+ generate_azure_ops_dashboard,
+ )
+
+ generate_azure_ops_dashboard(serve_dir / "azure_ops.html")
+ print(" Generated Azure ops dashboard")
+ except Exception as e:
+ print(f" Warning: Could not generate Azure ops dashboard: {e}")
+
  start_page = "dashboard.html"

  # Override start page if specified
- if hasattr(args, 'start_page') and args.start_page:
+ if hasattr(args, "start_page") and args.start_page:
  start_page = args.start_page

  # Serve from the specified directory
@@ -476,33 +525,41 @@ def cmd_serve(args: argparse.Namespace) -> int:
  super().log_message(format, *log_args)

  def do_POST(self):
- if self.path == '/api/stop':
+ if self.path == "/api/stop":
  # Create stop signal file
  stop_file = serve_dir / "STOP_TRAINING"
  stop_file.touch()
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(b'{"status": "stop_signal_created"}')
  print(f"\n⏹ Stop signal created: {stop_file}")
- elif self.path == '/api/run-benchmark':
+ elif self.path == "/api/run-benchmark":
  # Parse request body for provider
- content_length = int(self.headers.get('Content-Length', 0))
- body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+ content_length = int(self.headers.get("Content-Length", 0))
+ body = (
+ self.rfile.read(content_length).decode("utf-8")
+ if content_length
+ else "{}"
+ )
  try:
  params = json.loads(body)
  except json.JSONDecodeError:
  params = {}

- provider = params.get('provider', 'anthropic')
- tasks = params.get('tasks', 5)
+ provider = params.get("provider", "anthropic")
+ tasks = params.get("tasks", 5)

  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
- self.wfile.write(json.dumps({"status": "started", "provider": provider, "tasks": tasks}).encode())
+ self.wfile.write(
+ json.dumps(
+ {"status": "started", "provider": provider, "tasks": tasks}
+ ).encode()
+ )

  # Run benchmark in background thread with progress logging
  def run_benchmark():
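The `/api/stop` handler above only creates the `STOP_TRAINING` file; the training loop is presumably what consumes it. A minimal sketch of the consumer side, assuming the trainer polls its output directory between steps (this check is not shown in the diff):

```python
from pathlib import Path

def should_stop(output_dir: Path) -> bool:
    # Assumed trainer-side counterpart to the /api/stop endpoint: the server
    # touches STOP_TRAINING, and the training loop checks for it each step.
    stop_file = output_dir / "STOP_TRAINING"
    if stop_file.exists():
        stop_file.unlink()  # consume the signal so the next run starts clean
        return True
    return False
```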
@@ -516,25 +573,45 @@ def cmd_serve(args: argparse.Namespace) -> int:
  # Create progress log file (in cwd which is serve_dir)
  progress_file = Path("benchmark_progress.json")

- print(f"\n🚀 Starting {provider} benchmark evaluation ({tasks} tasks)...")
+ print(
+ f"\n🚀 Starting {provider} benchmark evaluation ({tasks} tasks)..."
+ )

  # Write initial progress
- progress_file.write_text(json.dumps({
- "status": "running",
- "provider": provider,
- "tasks_total": tasks,
- "tasks_complete": 0,
- "message": f"Starting {provider} evaluation..."
- }))
+ progress_file.write_text(
+ json.dumps(
+ {
+ "status": "running",
+ "provider": provider,
+ "tasks_total": tasks,
+ "tasks_complete": 0,
+ "message": f"Starting {provider} evaluation...",
+ }
+ )
+ )

  # Copy environment with loaded vars
  env = os.environ.copy()

  result = subprocess.run(
- ["uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli", "run-api",
- "--provider", provider, "--tasks", str(tasks),
- "--model-id", f"{provider}-api"],
- capture_output=True, text=True, cwd=str(project_root), env=env
+ [
+ "uv",
+ "run",
+ "python",
+ "-m",
+ "openadapt_ml.benchmarks.cli",
+ "run-api",
+ "--provider",
+ provider,
+ "--tasks",
+ str(tasks),
+ "--model-id",
+ f"{provider}-api",
+ ],
+ capture_output=True,
+ text=True,
+ cwd=str(project_root),
+ env=env,
  )

  print(f"\n📋 Benchmark output:\n{result.stdout}")
@@ -542,77 +619,95 @@ def cmd_serve(args: argparse.Namespace) -> int:
  print(f"Stderr: {result.stderr}")

  if result.returncode == 0:
- print(f"✅ Benchmark complete. Regenerating viewer...")
- progress_file.write_text(json.dumps({
- "status": "complete",
- "provider": provider,
- "message": "Evaluation complete! Refreshing results..."
- }))
+ print("✅ Benchmark complete. Regenerating viewer...")
+ progress_file.write_text(
+ json.dumps(
+ {
+ "status": "complete",
+ "provider": provider,
+ "message": "Evaluation complete! Refreshing results...",
+ }
+ )
+ )
  # Regenerate benchmark viewer
  _regenerate_benchmark_viewer_if_available(serve_dir)
  else:
  print(f"❌ Benchmark failed: {result.stderr}")
- progress_file.write_text(json.dumps({
- "status": "error",
- "provider": provider,
- "message": f"Evaluation failed: {result.stderr[:200]}"
- }))
+ progress_file.write_text(
+ json.dumps(
+ {
+ "status": "error",
+ "provider": provider,
+ "message": f"Evaluation failed: {result.stderr[:200]}",
+ }
+ )
+ )

  threading.Thread(target=run_benchmark, daemon=True).start()
- elif self.path == '/api/vms/register':
+ elif self.path == "/api/vms/register":
  # Register a new VM
- content_length = int(self.headers.get('Content-Length', 0))
- body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+ content_length = int(self.headers.get("Content-Length", 0))
+ body = (
+ self.rfile.read(content_length).decode("utf-8")
+ if content_length
+ else "{}"
+ )
  try:
  vm_data = json.loads(body)
  result = self._register_vm(vm_data)
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(result).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps({"error": str(e)}).encode())
- elif self.path == '/api/benchmark/start':
+ elif self.path == "/api/benchmark/start":
  # Start a benchmark run with configurable parameters
- content_length = int(self.headers.get('Content-Length', 0))
- body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+ content_length = int(self.headers.get("Content-Length", 0))
+ body = (
+ self.rfile.read(content_length).decode("utf-8")
+ if content_length
+ else "{}"
+ )
  try:
  params = json.loads(body)
  result = self._start_benchmark_run(params)
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(result).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps({"error": str(e)}).encode())
  else:
  self.send_error(404, "Not found")

  def do_GET(self):
- if self.path.startswith('/api/benchmark-progress'):
+ if self.path.startswith("/api/benchmark-progress"):
  # Return benchmark progress
- progress_file = Path("benchmark_progress.json") # Relative to serve_dir (cwd)
+ progress_file = Path(
+ "benchmark_progress.json"
+ ) # Relative to serve_dir (cwd)
  if progress_file.exists():
  progress = progress_file.read_text()
  else:
  progress = json.dumps({"status": "idle"})

  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(progress.encode())
- elif self.path.startswith('/api/benchmark-live'):
+ elif self.path.startswith("/api/benchmark-live"):
  # Return live evaluation state
  live_file = Path("benchmark_live.json") # Relative to serve_dir (cwd)
  if live_file.exists():
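The `benchmark_progress.json` file written by the background thread above is surfaced through `GET /api/benchmark-progress`. A small client-side poll against that endpoint, as a sketch (the base URL and port are assumed, not specified by the diff):

```python
import json
import time
import urllib.request

def wait_for_benchmark(base_url: str = "http://localhost:8000") -> dict:
    # Poll /api/benchmark-progress until the run leaves the "running" state.
    # Status values mirror the progress file above: running/complete/error,
    # plus the endpoint's "idle" default when no file exists yet.
    while True:
        with urllib.request.urlopen(f"{base_url}/api/benchmark-progress") as resp:
            progress = json.loads(resp.read())
        if progress.get("status") in ("complete", "error", "idle"):
            return progress
        time.sleep(5)
```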
@@ -621,32 +716,33 @@ def cmd_serve(args: argparse.Namespace) -> int:
  live_state = json.dumps({"status": "idle"})

  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(live_state.encode())
- elif self.path.startswith('/api/tasks'):
+ elif self.path.startswith("/api/tasks"):
  # Return background task status (VM, Docker, benchmarks)
  try:
  tasks = self._fetch_background_tasks()
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(tasks).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps({"error": str(e)}).encode())
- elif self.path.startswith('/api/azure-jobs'):
+ elif self.path.startswith("/api/azure-jobs"):
  # Return LIVE Azure job status from Azure ML
  # Supports ?force=true parameter for manual refresh (always fetches live)
  try:
  from urllib.parse import urlparse, parse_qs
+
  query = parse_qs(urlparse(self.path).query)
- force_refresh = query.get('force', ['false'])[0].lower() == 'true'
+ force_refresh = query.get("force", ["false"])[0].lower() == "true"

  # Always fetch live data (force just indicates manual refresh for logging)
  if force_refresh:
@@ -654,22 +750,23 @@ def cmd_serve(args: argparse.Namespace) -> int:

  jobs = self._fetch_live_azure_jobs()
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(jobs).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps({"error": str(e)}).encode())
- elif self.path.startswith('/api/benchmark-sse'):
+ elif self.path.startswith("/api/benchmark-sse"):
  # Server-Sent Events endpoint for real-time benchmark updates
  try:
  from urllib.parse import urlparse, parse_qs
+
  query = parse_qs(urlparse(self.path).query)
- interval = int(query.get('interval', [5])[0])
+ interval = int(query.get("interval", [5])[0])

  # Validate interval (min 1s, max 60s)
  interval = max(1, min(60, interval))
@@ -677,57 +774,60 @@ def cmd_serve(args: argparse.Namespace) -> int:
  self._stream_benchmark_updates(interval)
  except Exception as e:
  self.send_error(500, f"SSE error: {e}")
- elif self.path.startswith('/api/vms'):
+ elif self.path.startswith("/api/vms"):
  # Return VM registry with live status
  try:
  vms = self._fetch_vm_registry()
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(vms).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps({"error": str(e)}).encode())
- elif self.path.startswith('/api/azure-job-logs'):
+ elif self.path.startswith("/api/azure-job-logs"):
  # Return live logs for running Azure job
  try:
  # Parse job_id from query string
  from urllib.parse import urlparse, parse_qs
+
  query = parse_qs(urlparse(self.path).query)
- job_id = query.get('job_id', [None])[0]
+ job_id = query.get("job_id", [None])[0]

  logs = self._fetch_azure_job_logs(job_id)
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(logs).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps({"error": str(e)}).encode())
- elif self.path.startswith('/api/probe-vm'):
+ elif self.path.startswith("/api/probe-vm"):
  # Probe the VM to check if WAA server is responding
  try:
  result = self._probe_vm()
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(result).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
- self.wfile.write(json.dumps({"error": str(e), "responding": False}).encode())
- elif self.path.startswith('/api/tunnels'):
+ self.wfile.write(
+ json.dumps({"error": str(e), "responding": False}).encode()
+ )
+ elif self.path.startswith("/api/tunnels"):
  # Return SSH tunnel status
  try:
  tunnel_mgr = get_tunnel_manager()
@@ -743,44 +843,293 @@ def cmd_serve(args: argparse.Namespace) -> int:
  for name, s in status.items()
  }
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(result).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps({"error": str(e)}).encode())
- elif self.path.startswith('/api/current-run'):
+ elif self.path.startswith("/api/current-run"):
  # Return currently running benchmark info
  try:
  result = self._get_current_run()
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(result).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
- self.wfile.write(json.dumps({"error": str(e), "running": False}).encode())
- elif self.path.startswith('/api/background-tasks'):
+ self.wfile.write(
+ json.dumps({"error": str(e), "running": False}).encode()
+ )
+ elif self.path.startswith("/api/background-tasks"):
  # Alias for /api/tasks - background task status
  try:
  tasks = self._fetch_background_tasks()
  self.send_response(200)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps(tasks).encode())
  except Exception as e:
  self.send_response(500)
- self.send_header('Content-Type', 'application/json')
- self.send_header('Access-Control-Allow-Origin', '*')
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
+ elif self.path.startswith("/api/benchmark/status"):
+ # Return current benchmark job status with ETA
+ try:
+ status = self._get_benchmark_status()
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps(status).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(
+ json.dumps({"error": str(e), "status": "error"}).encode()
+ )
+ elif self.path.startswith("/api/benchmark/costs"):
+ # Return cost breakdown (Azure VM, API calls, GPU)
+ try:
+ costs = self._get_benchmark_costs()
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps(costs).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
+ elif self.path.startswith("/api/benchmark/metrics"):
+ # Return performance metrics (success rate, domain breakdown)
+ try:
+ metrics = self._get_benchmark_metrics()
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps(metrics).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
+ elif self.path.startswith("/api/benchmark/workers"):
+ # Return worker status and utilization
+ try:
+ workers = self._get_benchmark_workers()
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps(workers).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
+ elif self.path.startswith("/api/benchmark/runs"):
+ # Return list of all benchmark runs
+ try:
+ runs = self._get_benchmark_runs()
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps(runs).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
+ elif self.path.startswith("/api/benchmark/tasks/"):
+ # Return task execution details
+ # URL format: /api/benchmark/tasks/{run_name}/{task_id}
+ try:
+ parts = self.path.split("/")
+ if len(parts) >= 6:
+ run_name = parts[4]
+ task_id = parts[5]
+ execution = self._get_task_execution(run_name, task_id)
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps(execution).encode())
+ else:
+ self.send_error(400, "Invalid path format")
+ except Exception as e:
+ self.send_response(500)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
+ elif self.path.startswith("/api/benchmark/screenshots/"):
+ # Serve screenshot files
+ # URL format: /api/benchmark/screenshots/{run_name}/{task_id}/screenshots/{filename}
+ try:
+ # Remove /api/benchmark/screenshots/ prefix
+ path_parts = self.path.replace(
+ "/api/benchmark/screenshots/", ""
+ ).split("/")
+ if len(path_parts) >= 4:
+ run_name = path_parts[0]
+ task_id = path_parts[1]
+ # path_parts[2] should be 'screenshots'
+ filename = path_parts[3]
+
+ results_dir = Path("benchmark_results")
+ screenshot_path = (
+ results_dir
+ / run_name
+ / "tasks"
+ / task_id
+ / "screenshots"
+ / filename
+ )
+
+ if screenshot_path.exists():
+ self.send_response(200)
+ self.send_header("Content-Type", "image/png")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ with open(screenshot_path, "rb") as f:
+ self.wfile.write(f.read())
+ else:
+ self.send_error(
+ 404, f"Screenshot not found: {screenshot_path}"
+ )
+ else:
+ self.send_error(400, "Invalid screenshot path format")
+ except Exception as e:
+ self.send_error(500, f"Error serving screenshot: {e}")
+ elif self.path.startswith("/api/azure-ops-sse"):
+ # Server-Sent Events endpoint for Azure operations status
+ try:
+ self._stream_azure_ops_updates()
+ except Exception as e:
+ self.send_error(500, f"SSE error: {e}")
+ elif self.path.startswith("/api/azure-ops-status"):
+ # Return Azure operations status from JSON file
+ # Session tracker provides elapsed_seconds and cost_usd for
+ # persistence across page refreshes
+ try:
+ from openadapt_ml.benchmarks.azure_ops_tracker import read_status
+ from openadapt_ml.benchmarks.session_tracker import (
+ get_session,
+ update_session_vm_state,
+ )
+
+ # Get operation status (current task)
+ status = read_status()
+
+ # Get session data (persistent across refreshes)
+ session = get_session()
+
+ # Update session based on VM state if we have VM info
+ # IMPORTANT: Only pass vm_ip if it's truthy to avoid
+ # overwriting session's stable vm_ip with None
+ if status.get("vm_state") and status.get("vm_state") != "unknown":
+ status_vm_ip = status.get("vm_ip")
+ # Build update kwargs - only include vm_ip if present
+ update_kwargs = {
+ "vm_state": status["vm_state"],
+ "vm_size": status.get("vm_size"),
+ }
+ if status_vm_ip: # Only include if truthy
+ update_kwargs["vm_ip"] = status_vm_ip
+ session = update_session_vm_state(**update_kwargs)
+
+ # Use session's vm_ip as authoritative source
+ # This prevents IP flickering when status file has stale/None values
+ if session.get("vm_ip"):
+ status["vm_ip"] = session["vm_ip"]
+
+ # Use session's elapsed_seconds and cost_usd for persistence
+ # These survive page refreshes and track total VM runtime
+ if (
+ session.get("is_active")
+ or session.get("accumulated_seconds", 0) > 0
+ ):
+ status["elapsed_seconds"] = session.get("elapsed_seconds", 0.0)
+ status["cost_usd"] = session.get("cost_usd", 0.0)
+ status["started_at"] = session.get("started_at")
+ # Include session metadata for debugging
+ status["session_id"] = session.get("session_id")
+ status["session_is_active"] = session.get("is_active", False)
+ # Include accumulated time from previous sessions for hybrid display
+ status["accumulated_seconds"] = session.get(
+ "accumulated_seconds", 0.0
+ )
+ # Calculate current session time (total - accumulated)
+ current_session_seconds = max(
+ 0, status["elapsed_seconds"] - status["accumulated_seconds"]
+ )
+ status["current_session_seconds"] = current_session_seconds
+ status["current_session_cost_usd"] = (
+ current_session_seconds / 3600
+ ) * session.get("hourly_rate_usd", 0.422)
+
+ try:
+ tunnel_mgr = get_tunnel_manager()
+ tunnel_status = tunnel_mgr.get_tunnel_status()
+ status["tunnels"] = {
+ name: {
+ "active": s.active,
+ "local_port": s.local_port,
+ "remote_endpoint": s.remote_endpoint,
+ "pid": s.pid,
+ "error": s.error,
+ }
+ for name, s in tunnel_status.items()
+ }
+ except Exception as e:
+ status["tunnels"] = {"error": str(e)}
+
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps(status).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
+ elif self.path.startswith("/api/vm-diagnostics"):
+ # Return VM diagnostics: disk usage, Docker stats, memory usage
+ try:
+ diagnostics = self._get_vm_diagnostics()
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.end_headers()
+ self.wfile.write(json.dumps(diagnostics).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Access-Control-Allow-Origin", "*")
  self.end_headers()
  self.wfile.write(json.dumps({"error": str(e)}).encode())
  else:
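The `/api/azure-ops-status` handler above derives the per-session cost as cost = (seconds / 3600) * hourly_rate, with a 0.422 USD/h default rate. A quick client that reads the endpoint and re-checks that arithmetic, as a sketch (the base URL/port and the presence of these fields in a given response are assumptions):

```python
import json
import urllib.request

# Hypothetical local serve address; the diff does not fix a port.
with urllib.request.urlopen("http://localhost:8000/api/azure-ops-status") as resp:
    status = json.loads(resp.read())

hourly_rate = 0.422  # default shown in the handler above
expected = (status.get("current_session_seconds", 0) / 3600) * hourly_rate
print(status.get("vm_state"), status.get("vm_ip"), f"~${expected:.2f} this session")
```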
@@ -790,13 +1139,25 @@ def cmd_serve(args: argparse.Namespace) -> int:
790
1139
  def _fetch_live_azure_jobs(self):
791
1140
  """Fetch live job status from Azure ML."""
792
1141
  import subprocess
1142
+
793
1143
  result = subprocess.run(
794
- ["az", "ml", "job", "list",
795
- "--resource-group", "openadapt-agents",
796
- "--workspace-name", "openadapt-ml",
797
- "--query", "[].{name:name,display_name:display_name,status:status,creation_context:creation_context.created_at}",
798
- "-o", "json"],
799
- capture_output=True, text=True, timeout=30
1144
+ [
1145
+ "az",
1146
+ "ml",
1147
+ "job",
1148
+ "list",
1149
+ "--resource-group",
1150
+ "openadapt-agents",
1151
+ "--workspace-name",
1152
+ "openadapt-ml",
1153
+ "--query",
1154
+ "[].{name:name,display_name:display_name,status:status,creation_context:creation_context.created_at}",
1155
+ "-o",
1156
+ "json",
1157
+ ],
1158
+ capture_output=True,
1159
+ text=True,
1160
+ timeout=30,
800
1161
  )
801
1162
  if result.returncode != 0:
802
1163
  raise Exception(f"Azure CLI error: {result.stderr}")
@@ -808,14 +1169,16 @@ def cmd_serve(args: argparse.Namespace) -> int:
808
1169
 
809
1170
  formatted = []
810
1171
  for job in jobs[:10]: # Limit to 10 most recent
811
- formatted.append({
812
- "job_id": job.get("name", "unknown"),
813
- "display_name": job.get("display_name", ""),
814
- "status": job.get("status", "unknown").lower(),
815
- "started_at": job.get("creation_context", ""),
816
- "azure_dashboard_url": f"https://ml.azure.com/experiments/id/{experiment_id}/runs/{job.get('name', '')}?wsid={wsid}",
817
- "is_live": True # Flag to indicate this is live data
818
- })
1172
+ formatted.append(
1173
+ {
1174
+ "job_id": job.get("name", "unknown"),
1175
+ "display_name": job.get("display_name", ""),
1176
+ "status": job.get("status", "unknown").lower(),
1177
+ "started_at": job.get("creation_context", ""),
1178
+ "azure_dashboard_url": f"https://ml.azure.com/experiments/id/{experiment_id}/runs/{job.get('name', '')}?wsid={wsid}",
1179
+ "is_live": True, # Flag to indicate this is live data
1180
+ }
1181
+ )
819
1182
  return formatted
820
1183
 
821
1184
  def _fetch_azure_job_logs(self, job_id: str | None):
@@ -825,34 +1188,63 @@ def cmd_serve(args: argparse.Namespace) -> int:
825
1188
  if not job_id:
826
1189
  # Get the most recent running job
827
1190
  jobs = self._fetch_live_azure_jobs()
828
- running = [j for j in jobs if j['status'] == 'running']
1191
+ running = [j for j in jobs if j["status"] == "running"]
829
1192
  if running:
830
- job_id = running[0]['job_id']
1193
+ job_id = running[0]["job_id"]
831
1194
  else:
832
- return {"logs": "No running jobs found", "job_id": None, "status": "idle"}
1195
+ return {
1196
+ "logs": "No running jobs found",
1197
+ "job_id": None,
1198
+ "status": "idle",
1199
+ }
833
1200
 
834
1201
  # Try to stream logs for running job using az ml job stream
835
1202
  try:
836
1203
  result = subprocess.run(
837
- ["az", "ml", "job", "stream",
838
- "--name", job_id,
839
- "--resource-group", "openadapt-agents",
840
- "--workspace-name", "openadapt-ml"],
841
- capture_output=True, text=True, timeout=3 # Short timeout
1204
+ [
1205
+ "az",
1206
+ "ml",
1207
+ "job",
1208
+ "stream",
1209
+ "--name",
1210
+ job_id,
1211
+ "--resource-group",
1212
+ "openadapt-agents",
1213
+ "--workspace-name",
1214
+ "openadapt-ml",
1215
+ ],
1216
+ capture_output=True,
1217
+ text=True,
1218
+ timeout=3, # Short timeout
842
1219
  )
843
1220
  if result.returncode == 0 and result.stdout.strip():
844
- return {"logs": result.stdout[-5000:], "job_id": job_id, "status": "streaming"}
1221
+ return {
1222
+ "logs": result.stdout[-5000:],
1223
+ "job_id": job_id,
1224
+ "status": "streaming",
1225
+ }
845
1226
  except subprocess.TimeoutExpired:
846
1227
  pass # Fall through to job show
847
1228
 
848
1229
  # Get job details instead
849
1230
  result = subprocess.run(
850
- ["az", "ml", "job", "show",
851
- "--name", job_id,
852
- "--resource-group", "openadapt-agents",
853
- "--workspace-name", "openadapt-ml",
854
- "-o", "json"],
855
- capture_output=True, text=True, timeout=10
1231
+ [
1232
+ "az",
1233
+ "ml",
1234
+ "job",
1235
+ "show",
1236
+ "--name",
1237
+ job_id,
1238
+ "--resource-group",
1239
+ "openadapt-agents",
1240
+ "--workspace-name",
1241
+ "openadapt-ml",
1242
+ "-o",
1243
+ "json",
1244
+ ],
1245
+ capture_output=True,
1246
+ text=True,
1247
+ timeout=10,
856
1248
  )
857
1249
 
858
1250
  if result.returncode == 0:
@@ -860,20 +1252,25 @@ def cmd_serve(args: argparse.Namespace) -> int:
860
1252
  return {
861
1253
  "logs": f"Job {job_id} is {job_info.get('status', 'unknown')}\\n\\nCommand: {job_info.get('command', 'N/A')}",
862
1254
  "job_id": job_id,
863
- "status": job_info.get('status', 'unknown').lower(),
864
- "command": job_info.get('command', '')
1255
+ "status": job_info.get("status", "unknown").lower(),
1256
+ "command": job_info.get("command", ""),
865
1257
  }
866
1258
 
867
- return {"logs": f"Could not fetch logs: {result.stderr}", "job_id": job_id, "status": "error"}
1259
+ return {
1260
+ "logs": f"Could not fetch logs: {result.stderr}",
1261
+ "job_id": job_id,
1262
+ "status": "error",
1263
+ }
868
1264
 
869
- def _get_vm_detailed_metadata(self, vm_ip: str, container_name: str, logs: str, phase: str) -> dict:
1265
+ def _get_vm_detailed_metadata(
1266
+ self, vm_ip: str, container_name: str, logs: str, phase: str
1267
+ ) -> dict:
870
1268
  """Get detailed VM metadata for the VM Details panel.
871
1269
 
872
1270
  Returns:
873
1271
  dict with disk_usage_gb, memory_usage_mb, setup_script_phase, probe_response, qmp_connected, dependencies
874
1272
  """
875
1273
  import subprocess
876
- import re
877
1274
 
878
1275
  metadata = {
879
1276
  "disk_usage_gb": None,
@@ -881,17 +1278,26 @@ def cmd_serve(args: argparse.Namespace) -> int:
881
1278
  "setup_script_phase": None,
882
1279
  "probe_response": None,
883
1280
  "qmp_connected": False,
884
- "dependencies": []
1281
+ "dependencies": [],
885
1282
  }
886
1283
 
887
1284
  # 1. Get disk usage from docker stats
888
1285
  try:
889
1286
  disk_result = subprocess.run(
890
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
891
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
892
- f"azureuser@{vm_ip}",
893
- f"docker exec {container_name} df -h /storage 2>/dev/null | tail -1"],
894
- capture_output=True, text=True, timeout=10
1287
+ [
1288
+ "ssh",
1289
+ "-o",
1290
+ "StrictHostKeyChecking=no",
1291
+ "-o",
1292
+ "ConnectTimeout=5",
1293
+ "-i",
1294
+ str(Path.home() / ".ssh" / "id_rsa"),
1295
+ f"azureuser@{vm_ip}",
1296
+ f"docker exec {container_name} df -h /storage 2>/dev/null | tail -1",
1297
+ ],
1298
+ capture_output=True,
1299
+ text=True,
1300
+ timeout=10,
895
1301
  )
896
1302
  if disk_result.returncode == 0 and disk_result.stdout.strip():
897
1303
  # Parse: "Filesystem Size Used Avail Use% Mounted on"
@@ -900,27 +1306,40 @@ def cmd_serve(args: argparse.Namespace) -> int:
900
1306
  if len(parts) >= 3:
901
1307
  used_str = parts[2] # e.g., "9.2G"
902
1308
  total_str = parts[1] # e.g., "30G"
1309
+
903
1310
  # Convert to GB (handle M/G suffixes)
904
1311
  def to_gb(s):
905
- if s.endswith('G'):
1312
+ if s.endswith("G"):
906
1313
  return float(s[:-1])
907
- elif s.endswith('M'):
1314
+ elif s.endswith("M"):
908
1315
  return float(s[:-1]) / 1024
909
- elif s.endswith('K'):
1316
+ elif s.endswith("K"):
910
1317
  return float(s[:-1]) / (1024 * 1024)
911
1318
  return 0
912
- metadata["disk_usage_gb"] = f"{to_gb(used_str):.1f} GB / {to_gb(total_str):.0f} GB used"
1319
+
1320
+ metadata["disk_usage_gb"] = (
1321
+ f"{to_gb(used_str):.1f} GB / {to_gb(total_str):.0f} GB used"
1322
+ )
913
1323
  except Exception:
914
1324
  pass
915
1325
 
916
1326
  # 2. Get memory usage from docker stats
917
1327
  try:
918
1328
  mem_result = subprocess.run(
919
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
920
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
921
- f"azureuser@{vm_ip}",
922
- f"docker stats {container_name} --no-stream --format '{{{{.MemUsage}}}}'"],
923
- capture_output=True, text=True, timeout=10
1329
+ [
1330
+ "ssh",
1331
+ "-o",
1332
+ "StrictHostKeyChecking=no",
1333
+ "-o",
1334
+ "ConnectTimeout=5",
1335
+ "-i",
1336
+ str(Path.home() / ".ssh" / "id_rsa"),
1337
+ f"azureuser@{vm_ip}",
1338
+ f"docker stats {container_name} --no-stream --format '{{{{.MemUsage}}}}'",
1339
+ ],
1340
+ capture_output=True,
1341
+ text=True,
1342
+ timeout=10,
924
1343
  )
925
1344
  if mem_result.returncode == 0 and mem_result.stdout.strip():
926
1345
  # Example: "1.5GiB / 4GiB"
@@ -929,16 +1348,27 @@ def cmd_serve(args: argparse.Namespace) -> int:
929
1348
  pass
930
1349
 
931
1350
  # 3. Parse setup script phase from logs
932
- metadata["setup_script_phase"] = self._parse_setup_phase_from_logs(logs, phase)
1351
+ metadata["setup_script_phase"] = self._parse_setup_phase_from_logs(
1352
+ logs, phase
1353
+ )
933
1354
 
934
1355
  # 4. Check /probe endpoint
935
1356
  try:
936
1357
  probe_result = subprocess.run(
937
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
938
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
939
- f"azureuser@{vm_ip}",
940
- "curl -s --connect-timeout 2 http://20.20.20.21:5000/probe 2>/dev/null"],
941
- capture_output=True, text=True, timeout=10
1358
+ [
1359
+ "ssh",
1360
+ "-o",
1361
+ "StrictHostKeyChecking=no",
1362
+ "-o",
1363
+ "ConnectTimeout=5",
1364
+ "-i",
1365
+ str(Path.home() / ".ssh" / "id_rsa"),
1366
+ f"azureuser@{vm_ip}",
1367
+ "curl -s --connect-timeout 2 http://20.20.20.21:5000/probe 2>/dev/null",
1368
+ ],
1369
+ capture_output=True,
1370
+ text=True,
1371
+ timeout=10,
942
1372
  )
943
1373
  if probe_result.returncode == 0 and probe_result.stdout.strip():
944
1374
  metadata["probe_response"] = probe_result.stdout.strip()
@@ -950,11 +1380,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
950
1380
  # 5. Check QMP connection (port 7200)
951
1381
  try:
952
1382
  qmp_result = subprocess.run(
953
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
954
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
955
- f"azureuser@{vm_ip}",
956
- "nc -z -w2 localhost 7200 2>&1"],
957
- capture_output=True, text=True, timeout=10
1383
+ [
1384
+ "ssh",
1385
+ "-o",
1386
+ "StrictHostKeyChecking=no",
1387
+ "-o",
1388
+ "ConnectTimeout=5",
1389
+ "-i",
1390
+ str(Path.home() / ".ssh" / "id_rsa"),
1391
+ f"azureuser@{vm_ip}",
1392
+ "nc -z -w2 localhost 7200 2>&1",
1393
+ ],
1394
+ capture_output=True,
1395
+ text=True,
1396
+ timeout=10,
958
1397
  )
959
1398
  metadata["qmp_connected"] = qmp_result.returncode == 0
960
1399
  except Exception:
@@ -987,7 +1426,12 @@ def cmd_serve(args: argparse.Namespace) -> int:
987
1426
  return "Windows installation in progress"
988
1427
  elif current_phase == "booting":
989
1428
  return "Booting Windows"
990
- elif current_phase in ["downloading", "extracting", "configuring", "building"]:
1429
+ elif current_phase in [
1430
+ "downloading",
1431
+ "extracting",
1432
+ "configuring",
1433
+ "building",
1434
+ ]:
991
1435
  return "Preparing Windows VM"
992
1436
  else:
993
1437
  return "Initializing..."
@@ -1017,17 +1461,23 @@ def cmd_serve(args: argparse.Namespace) -> int:
1017
1461
  logs_lower = logs.lower()
1018
1462
 
1019
1463
  # Check for installation patterns
1020
- if "python" in logs_lower and ("installing python" in logs_lower or "python.exe" in logs_lower):
1464
+ if "python" in logs_lower and (
1465
+ "installing python" in logs_lower or "python.exe" in logs_lower
1466
+ ):
1021
1467
  dependencies[0]["status"] = "installing"
1022
1468
  elif "python" in logs_lower and "installed" in logs_lower:
1023
1469
  dependencies[0]["status"] = "complete"
1024
1470
 
1025
- if "chrome" in logs_lower and ("downloading" in logs_lower or "installing" in logs_lower):
1471
+ if "chrome" in logs_lower and (
1472
+ "downloading" in logs_lower or "installing" in logs_lower
1473
+ ):
1026
1474
  dependencies[1]["status"] = "installing"
1027
1475
  elif "chrome" in logs_lower and "installed" in logs_lower:
1028
1476
  dependencies[1]["status"] = "complete"
1029
1477
 
1030
- if "libreoffice" in logs_lower and ("downloading" in logs_lower or "installing" in logs_lower):
1478
+ if "libreoffice" in logs_lower and (
1479
+ "downloading" in logs_lower or "installing" in logs_lower
1480
+ ):
1031
1481
  dependencies[2]["status"] = "installing"
1032
1482
  elif "libreoffice" in logs_lower and "installed" in logs_lower:
1033
1483
  dependencies[2]["status"] = "complete"
@@ -1046,11 +1496,222 @@ def cmd_serve(args: argparse.Namespace) -> int:
1046
1496
 
1047
1497
  return dependencies
1048
1498
 
1499
+ def _get_vm_diagnostics(self) -> dict:
1500
+ """Get VM diagnostics: disk usage, Docker stats, memory usage.
1501
+
1502
+ Returns a dictionary with:
1503
+ - vm_online: bool - whether VM is reachable
1504
+ - disk_usage: list of disk partitions with usage stats
1505
+ - docker_stats: list of container stats (CPU, memory)
1506
+ - memory_usage: VM host memory stats
1507
+ - docker_system: Docker system disk usage
1508
+ - docker_images: list of Docker images (image, size, created)
+ - error: str if any error occurred
1509
+ """
1510
+ import subprocess
1511
+
1512
+ from openadapt_ml.benchmarks.session_tracker import get_session
1513
+
1514
+ diagnostics = {
1515
+ "vm_online": False,
1516
+ "disk_usage": [],
1517
+ "docker_stats": [],
1518
+ "memory_usage": {},
1519
+ "docker_system": {},
1520
+ "docker_images": [],
1521
+ "error": None,
1522
+ }
1523
+
1524
+ # Get VM IP from session
1525
+ session = get_session()
1526
+ vm_ip = session.get("vm_ip")
1527
+
1528
+ if not vm_ip:
1529
+ diagnostics["error"] = (
1530
+ "VM IP not found in session. VM may not be running."
1531
+ )
1532
+ return diagnostics
1533
+
1534
+ # SSH options for Azure VM
1535
+ ssh_opts = [
1536
+ "-o",
1537
+ "StrictHostKeyChecking=no",
1538
+ "-o",
1539
+ "UserKnownHostsFile=/dev/null",
1540
+ "-o",
1541
+ "ConnectTimeout=10",
1542
+ "-o",
1543
+ "ServerAliveInterval=30",
1544
+ ]
1545
+
1546
+ # Test VM connectivity
1547
+ try:
1548
+ test_result = subprocess.run(
1549
+ ["ssh", *ssh_opts, f"azureuser@{vm_ip}", "echo 'online'"],
1550
+ capture_output=True,
1551
+ text=True,
1552
+ timeout=15,
1553
+ )
1554
+ if test_result.returncode != 0:
1555
+ diagnostics["error"] = f"Cannot connect to VM at {vm_ip}"
1556
+ return diagnostics
1557
+ diagnostics["vm_online"] = True
1558
+ except subprocess.TimeoutExpired:
1559
+ diagnostics["error"] = f"Connection to VM at {vm_ip} timed out"
1560
+ return diagnostics
1561
+ except Exception as e:
1562
+ diagnostics["error"] = f"SSH error: {str(e)}"
1563
+ return diagnostics
1564
+
1565
+ # 1. Disk usage (df -h)
1566
+ try:
1567
+ df_result = subprocess.run(
1568
+ [
1569
+ "ssh",
1570
+ *ssh_opts,
1571
+ f"azureuser@{vm_ip}",
1572
+ "df -h / /mnt 2>/dev/null | tail -n +2",
1573
+ ],
1574
+ capture_output=True,
1575
+ text=True,
1576
+ timeout=15,
1577
+ )
1578
+ if df_result.returncode == 0 and df_result.stdout.strip():
1579
+ for line in df_result.stdout.strip().split("\n"):
1580
+ parts = line.split()
1581
+ if len(parts) >= 6:
1582
+ diagnostics["disk_usage"].append(
1583
+ {
1584
+ "filesystem": parts[0],
1585
+ "size": parts[1],
1586
+ "used": parts[2],
1587
+ "available": parts[3],
1588
+ "use_percent": parts[4],
1589
+ "mount_point": parts[5],
1590
+ }
1591
+ )
1592
+ except Exception as e:
1593
+ diagnostics["disk_usage"] = [{"error": str(e)}]
1594
+
1595
+ # 2. Docker container stats
1596
+ try:
1597
+ stats_result = subprocess.run(
1598
+ [
1599
+ "ssh",
1600
+ *ssh_opts,
1601
+ f"azureuser@{vm_ip}",
1602
+ "docker stats --no-stream --format '{{.Name}}|{{.CPUPerc}}|{{.MemUsage}}|{{.MemPerc}}|{{.NetIO}}|{{.BlockIO}}' 2>/dev/null || echo ''",
1603
+ ],
1604
+ capture_output=True,
1605
+ text=True,
1606
+ timeout=30,
1607
+ )
1608
+ if stats_result.returncode == 0 and stats_result.stdout.strip():
1609
+ for line in stats_result.stdout.strip().split("\n"):
1610
+ if "|" in line:
1611
+ parts = line.split("|")
1612
+ if len(parts) >= 6:
1613
+ diagnostics["docker_stats"].append(
1614
+ {
1615
+ "container": parts[0],
1616
+ "cpu_percent": parts[1],
1617
+ "memory_usage": parts[2],
1618
+ "memory_percent": parts[3],
1619
+ "net_io": parts[4],
1620
+ "block_io": parts[5],
1621
+ }
1622
+ )
1623
+ except Exception as e:
1624
+ diagnostics["docker_stats"] = [{"error": str(e)}]
1625
+
1626
+ # 3. VM host memory usage (free -h)
1627
+ try:
1628
+ mem_result = subprocess.run(
1629
+ [
1630
+ "ssh",
1631
+ *ssh_opts,
1632
+ f"azureuser@{vm_ip}",
1633
+ "free -h | head -2 | tail -1",
1634
+ ],
1635
+ capture_output=True,
1636
+ text=True,
1637
+ timeout=15,
1638
+ )
1639
+ if mem_result.returncode == 0 and mem_result.stdout.strip():
1640
+ parts = mem_result.stdout.strip().split()
1641
+ if len(parts) >= 7:
1642
+ diagnostics["memory_usage"] = {
1643
+ "total": parts[1],
1644
+ "used": parts[2],
1645
+ "free": parts[3],
1646
+ "shared": parts[4],
1647
+ "buff_cache": parts[5],
1648
+ "available": parts[6],
1649
+ }
1650
+ except Exception as e:
1651
+ diagnostics["memory_usage"] = {"error": str(e)}
1652
+
1653
+ # 4. Docker system disk usage
1654
+ try:
1655
+ docker_df_result = subprocess.run(
1656
+ [
1657
+ "ssh",
1658
+ *ssh_opts,
1659
+ f"azureuser@{vm_ip}",
1660
+ "docker system df 2>/dev/null || echo ''",
1661
+ ],
1662
+ capture_output=True,
1663
+ text=True,
1664
+ timeout=15,
1665
+ )
1666
+ if docker_df_result.returncode == 0 and docker_df_result.stdout.strip():
1667
+ lines = docker_df_result.stdout.strip().split("\n")
1668
+ # Parse the table: TYPE, TOTAL, ACTIVE, SIZE, RECLAIMABLE
1669
+ for line in lines[1:]: # Skip header
1670
+ parts = line.split()
1671
+ if len(parts) >= 5:
1672
+ dtype = parts[0]
1673
+ diagnostics["docker_system"][dtype.lower()] = {
1674
+ "total": parts[1],
1675
+ "active": parts[2],
1676
+ "size": parts[3],
1677
+ "reclaimable": " ".join(parts[4:]),
1678
+ }
1679
+ except Exception as e:
1680
+ diagnostics["docker_system"] = {"error": str(e)}
1681
+
1682
+ # 5. Docker images
1683
+ try:
1684
+ images_result = subprocess.run(
1685
+ [
1686
+ "ssh",
1687
+ *ssh_opts,
1688
+ f"azureuser@{vm_ip}",
1689
+ "docker images --format '{{.Repository}}:{{.Tag}}|{{.Size}}|{{.CreatedSince}}' 2>/dev/null || echo ''",
1690
+ ],
1691
+ capture_output=True,
1692
+ text=True,
1693
+ timeout=15,
1694
+ )
1695
+ if images_result.returncode == 0 and images_result.stdout.strip():
1696
+ for line in images_result.stdout.strip().split("\n"):
1697
+ if "|" in line:
1698
+ parts = line.split("|")
1699
+ if len(parts) >= 3:
1700
+ diagnostics["docker_images"].append(
1701
+ {
1702
+ "image": parts[0],
1703
+ "size": parts[1],
1704
+ "created": parts[2],
1705
+ }
1706
+ )
1707
+ except Exception as e:
1708
+ diagnostics["docker_images"] = [{"error": str(e)}]
1709
+
1710
+ return diagnostics
1711
+
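_get_vm_diagnostics above shells out over SSH and splits pipe-delimited docker stats output rather than parsing JSON. A self-contained sketch of just that parsing step; the sample line is fabricated:

    def parse_docker_stats_line(line: str) -> dict | None:
        """Split one line of `docker stats --format '{{.Name}}|{{.CPUPerc}}|...'` output."""
        parts = line.split("|")
        if len(parts) < 6:
            return None
        keys = ("container", "cpu_percent", "memory_usage",
                "memory_percent", "net_io", "block_io")
        return dict(zip(keys, parts))

    # parse_docker_stats_line("winarena|12.3%|1.5GiB / 7.8GiB|19.2%|1.2MB / 800kB|4MB / 0B")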
1049
1712
  def _fetch_background_tasks(self):
1050
1713
  """Fetch status of all background tasks: Azure VM, Docker containers, benchmarks."""
1051
1714
  import subprocess
1052
- from datetime import datetime
1053
- import time
1054
1715
 
1055
1716
  tasks = []
1056
1717
 
@@ -1063,31 +1724,43 @@ def cmd_serve(args: argparse.Namespace) -> int:
1063
1724
  if env_vm_ip:
1064
1725
  # Use environment variable - VM IP was provided directly
1065
1726
  vm_ip = env_vm_ip
1066
- tasks.append({
1067
- "task_id": "azure-vm-waa",
1068
- "task_type": "vm_provision",
1069
- "status": "completed",
1070
- "phase": "ready", # Match status to prevent "Starting" + "completed" conflict
1071
- "title": "Azure VM Host",
1072
- "description": f"Linux host running at {vm_ip}",
1073
- "progress_percent": 100.0,
1074
- "elapsed_seconds": 0,
1075
- "metadata": {
1076
- "vm_name": "waa-eval-vm",
1077
- "ip_address": vm_ip,
1078
- "internal_ip": env_internal_ip
1727
+ tasks.append(
1728
+ {
1729
+ "task_id": "azure-vm-waa",
1730
+ "task_type": "vm_provision",
1731
+ "status": "completed",
1732
+ "phase": "ready", # Match status to prevent "Starting" + "completed" conflict
1733
+ "title": "Azure VM Host",
1734
+ "description": f"Linux host running at {vm_ip}",
1735
+ "progress_percent": 100.0,
1736
+ "elapsed_seconds": 0,
1737
+ "metadata": {
1738
+ "vm_name": "waa-eval-vm",
1739
+ "ip_address": vm_ip,
1740
+ "internal_ip": env_internal_ip,
1741
+ },
1079
1742
  }
1080
- })
1743
+ )
1081
1744
  else:
1082
1745
  # Query Azure CLI for VM status
1083
1746
  try:
1084
1747
  result = subprocess.run(
1085
- ["az", "vm", "get-instance-view",
1086
- "--name", "waa-eval-vm",
1087
- "--resource-group", "openadapt-agents",
1088
- "--query", "instanceView.statuses",
1089
- "-o", "json"],
1090
- capture_output=True, text=True, timeout=10
1748
+ [
1749
+ "az",
1750
+ "vm",
1751
+ "get-instance-view",
1752
+ "--name",
1753
+ "waa-eval-vm",
1754
+ "--resource-group",
1755
+ "openadapt-agents",
1756
+ "--query",
1757
+ "instanceView.statuses",
1758
+ "-o",
1759
+ "json",
1760
+ ],
1761
+ capture_output=True,
1762
+ text=True,
1763
+ timeout=10,
1091
1764
  )
1092
1765
  if result.returncode == 0:
1093
1766
  statuses = json.loads(result.stdout)
@@ -1098,31 +1771,49 @@ def cmd_serve(args: argparse.Namespace) -> int:
1098
1771
 
1099
1772
  # Get VM IP
1100
1773
  ip_result = subprocess.run(
1101
- ["az", "vm", "list-ip-addresses",
1102
- "--name", "waa-eval-vm",
1103
- "--resource-group", "openadapt-agents",
1104
- "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
1105
- "-o", "tsv"],
1106
- capture_output=True, text=True, timeout=10
1774
+ [
1775
+ "az",
1776
+ "vm",
1777
+ "list-ip-addresses",
1778
+ "--name",
1779
+ "waa-eval-vm",
1780
+ "--resource-group",
1781
+ "openadapt-agents",
1782
+ "--query",
1783
+ "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
1784
+ "-o",
1785
+ "tsv",
1786
+ ],
1787
+ capture_output=True,
1788
+ text=True,
1789
+ timeout=10,
1790
+ )
1791
+ vm_ip = (
1792
+ ip_result.stdout.strip()
1793
+ if ip_result.returncode == 0
1794
+ else None
1107
1795
  )
1108
- vm_ip = ip_result.stdout.strip() if ip_result.returncode == 0 else None
1109
1796
 
1110
1797
  if power_state == "running":
1111
- tasks.append({
1112
- "task_id": "azure-vm-waa",
1113
- "task_type": "vm_provision",
1114
- "status": "completed",
1115
- "phase": "ready", # Match status to prevent "Starting" + "completed" conflict
1116
- "title": "Azure VM Host",
1117
- "description": f"Linux host running at {vm_ip}" if vm_ip else "Linux host running",
1118
- "progress_percent": 100.0,
1119
- "elapsed_seconds": 0,
1120
- "metadata": {
1121
- "vm_name": "waa-eval-vm",
1122
- "ip_address": vm_ip
1123
- # No VNC link - that's for the Windows container
1798
+ tasks.append(
1799
+ {
1800
+ "task_id": "azure-vm-waa",
1801
+ "task_type": "vm_provision",
1802
+ "status": "completed",
1803
+ "phase": "ready", # Match status to prevent "Starting" + "completed" conflict
1804
+ "title": "Azure VM Host",
1805
+ "description": f"Linux host running at {vm_ip}"
1806
+ if vm_ip
1807
+ else "Linux host running",
1808
+ "progress_percent": 100.0,
1809
+ "elapsed_seconds": 0,
1810
+ "metadata": {
1811
+ "vm_name": "waa-eval-vm",
1812
+ "ip_address": vm_ip,
1813
+ # No VNC link - that's for the Windows container
1814
+ },
1124
1815
  }
1125
- })
1816
+ )
1126
1817
  except subprocess.TimeoutExpired:
1127
1818
  pass
1128
1819
  except Exception:
@@ -1132,31 +1823,59 @@ def cmd_serve(args: argparse.Namespace) -> int:
1132
1823
  if vm_ip:
1133
1824
  try:
1134
1825
  docker_result = subprocess.run(
1135
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1136
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1137
- f"azureuser@{vm_ip}",
1138
- "docker ps --format '{{.Names}}|{{.Status}}|{{.Image}}'"],
1139
- capture_output=True, text=True, timeout=15
1826
+ [
1827
+ "ssh",
1828
+ "-o",
1829
+ "StrictHostKeyChecking=no",
1830
+ "-o",
1831
+ "ConnectTimeout=5",
1832
+ "-i",
1833
+ str(Path.home() / ".ssh" / "id_rsa"),
1834
+ f"azureuser@{vm_ip}",
1835
+ "docker ps --format '{{.Names}}|{{.Status}}|{{.Image}}'",
1836
+ ],
1837
+ capture_output=True,
1838
+ text=True,
1839
+ timeout=15,
1140
1840
  )
1141
1841
  if docker_result.returncode == 0 and docker_result.stdout.strip():
1142
- for line in docker_result.stdout.strip().split('\n'):
1143
- parts = line.split('|')
1842
+ for line in docker_result.stdout.strip().split("\n"):
1843
+ parts = line.split("|")
1144
1844
  if len(parts) >= 3:
1145
- container_name, status, image = parts[0], parts[1], parts[2]
1845
+ container_name, status, image = (
1846
+ parts[0],
1847
+ parts[1],
1848
+ parts[2],
1849
+ )
1146
1850
  # Parse "Up X minutes" to determine if healthy
1147
- is_healthy = "Up" in status
1148
1851
 
1149
1852
  # Check for Windows VM specifically
1150
- if "windows" in image.lower() or container_name == "winarena":
1853
+ if (
1854
+ "windows" in image.lower()
1855
+ or container_name == "winarena"
1856
+ ):
1151
1857
  # Get detailed progress from docker logs
1152
1858
  log_check = subprocess.run(
1153
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1154
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1155
- f"azureuser@{vm_ip}",
1156
- f"docker logs {container_name} 2>&1 | tail -30"],
1157
- capture_output=True, text=True, timeout=10
1859
+ [
1860
+ "ssh",
1861
+ "-o",
1862
+ "StrictHostKeyChecking=no",
1863
+ "-o",
1864
+ "ConnectTimeout=5",
1865
+ "-i",
1866
+ str(Path.home() / ".ssh" / "id_rsa"),
1867
+ f"azureuser@{vm_ip}",
1868
+ f"docker logs {container_name} 2>&1 | tail -30",
1869
+ ],
1870
+ capture_output=True,
1871
+ text=True,
1872
+ timeout=10,
1873
+ )
1874
+ logs = (
1875
+ log_check.stdout
1876
+ if log_check.returncode == 0
1877
+ else ""
1158
1878
  )
1159
- logs = log_check.stdout if log_check.returncode == 0 else ""
1160
1879
 
1161
1880
  # Parse progress from logs
1162
1881
  phase = "unknown"
@@ -1167,17 +1886,32 @@ def cmd_serve(args: argparse.Namespace) -> int:
1167
1886
  # Check if WAA server is ready via Docker port forwarding
1168
1887
  # See docs/waa_network_architecture.md - always use localhost
1169
1888
  server_check = subprocess.run(
1170
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1171
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1172
- f"azureuser@{vm_ip}",
1173
- "curl -s --connect-timeout 2 http://localhost:5000/probe 2>/dev/null"],
1174
- capture_output=True, text=True, timeout=10
1889
+ [
1890
+ "ssh",
1891
+ "-o",
1892
+ "StrictHostKeyChecking=no",
1893
+ "-o",
1894
+ "ConnectTimeout=5",
1895
+ "-i",
1896
+ str(Path.home() / ".ssh" / "id_rsa"),
1897
+ f"azureuser@{vm_ip}",
1898
+ "curl -s --connect-timeout 2 http://localhost:5000/probe 2>/dev/null",
1899
+ ],
1900
+ capture_output=True,
1901
+ text=True,
1902
+ timeout=10,
1903
+ )
1904
+ waa_ready = (
1905
+ server_check.returncode == 0
1906
+ and "Service is operational"
1907
+ in server_check.stdout
1175
1908
  )
1176
- waa_ready = server_check.returncode == 0 and "Service is operational" in server_check.stdout
1177
1909
  if waa_ready:
1178
1910
  phase = "ready"
1179
1911
  progress = 100.0
1180
- description = "WAA Server ready - benchmarks can run"
1912
+ description = (
1913
+ "WAA Server ready - benchmarks can run"
1914
+ )
1181
1915
  else:
1182
1916
  phase = "oobe"
1183
1917
  progress = 80.0 # Phase 5/6 - VM install in progress
@@ -1186,10 +1920,15 @@ def cmd_serve(args: argparse.Namespace) -> int:
1186
1920
  phase = "booting"
1187
1921
  progress = 70.0 # Phase 4/6
1188
1922
  description = "Phase 4/6: Booting Windows from installer..."
1189
- elif "Building Windows" in logs or "Creating a" in logs:
1923
+ elif (
1924
+ "Building Windows" in logs
1925
+ or "Creating a" in logs
1926
+ ):
1190
1927
  phase = "building"
1191
1928
  progress = 60.0 # Phase 3/6
1192
- description = "Phase 3/6: Building Windows VM disk..."
1929
+ description = (
1930
+ "Phase 3/6: Building Windows VM disk..."
1931
+ )
1193
1932
  elif "Adding" in logs and "image" in logs:
1194
1933
  phase = "configuring"
1195
1934
  progress = 50.0 # Phase 2/6
@@ -1197,15 +1936,22 @@ def cmd_serve(args: argparse.Namespace) -> int:
1197
1936
  elif "Extracting" in logs:
1198
1937
  phase = "extracting"
1199
1938
  progress = 35.0 # Phase 1/6 (after download)
1200
- description = "Phase 1/6: Extracting Windows ISO..."
1939
+ description = (
1940
+ "Phase 1/6: Extracting Windows ISO..."
1941
+ )
1201
1942
  else:
1202
1943
  # Check for download progress (e.g., "1234K ........ 45% 80M 30s")
1203
1944
  import re
1204
- download_match = re.search(r'(\d+)%\s+[\d.]+[KMG]\s+(\d+)s', logs)
1945
+
1946
+ download_match = re.search(
1947
+ r"(\d+)%\s+[\d.]+[KMG]\s+(\d+)s", logs
1948
+ )
1205
1949
  if download_match:
1206
1950
  phase = "downloading"
1207
1951
  dl_pct = float(download_match.group(1))
1208
- progress = dl_pct * 0.30 # 0-30% for download phase
1952
+ progress = (
1953
+ dl_pct * 0.30
1954
+ ) # 0-30% for download phase
1209
1955
  eta = download_match.group(2)
1210
1956
  description = f"Phase 0/6: Downloading Windows 11... {download_match.group(1)}% ({eta}s left)"
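The branch above folds wget-style download output into the first 30% of the overall setup progress bar. The mapping in isolation, under the same regex, with a fabricated log line:

    import re

    def download_progress(logs: str) -> tuple[float, str] | None:
        """Map a wget-style '45% ... 30s' log tail to (overall_percent, description)."""
        m = re.search(r"(\d+)%\s+[\d.]+[KMG]\s+(\d+)s", logs)
        if not m:
            return None
        dl_pct, eta = float(m.group(1)), m.group(2)
        # Download occupies 0-30% of the six-phase setup shown above
        return dl_pct * 0.30, f"Downloading Windows 11... {m.group(1)}% ({eta}s left)"

    # download_progress("1234K ........ 45% 80M 30s")
    # -> (13.5, "Downloading Windows 11... 45% (30s left)")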
1211
1957
 
@@ -1218,40 +1964,61 @@ def cmd_serve(args: argparse.Namespace) -> int:
1218
1964
  progress = 90.0
1219
1965
 
1220
1966
  # Get detailed metadata for VM Details panel
1221
- vm_metadata = self._get_vm_detailed_metadata(vm_ip, container_name, logs, phase)
1222
-
1223
- tasks.append({
1224
- "task_id": f"docker-{container_name}",
1225
- "task_type": "docker_container",
1226
- "status": "completed" if phase == "ready" else "running",
1227
- "title": "Windows 11 + WAA Server",
1228
- "description": description,
1229
- "progress_percent": progress,
1230
- "elapsed_seconds": 0,
1231
- "phase": phase,
1232
- "metadata": {
1233
- "container": container_name,
1234
- "image": image,
1235
- "status": status,
1967
+ vm_metadata = self._get_vm_detailed_metadata(
1968
+ vm_ip, container_name, logs, phase
1969
+ )
1970
+
1971
+ tasks.append(
1972
+ {
1973
+ "task_id": f"docker-{container_name}",
1974
+ "task_type": "docker_container",
1975
+ "status": "completed"
1976
+ if phase == "ready"
1977
+ else "running",
1978
+ "title": "Windows 11 + WAA Server",
1979
+ "description": description,
1980
+ "progress_percent": progress,
1981
+ "elapsed_seconds": 0,
1236
1982
  "phase": phase,
1237
- "windows_ready": phase in ["oobe", "ready"],
1238
- "waa_server_ready": phase == "ready",
1239
- # Use localhost - SSH tunnel handles routing to VM
1240
- # See docs/waa_network_architecture.md
1241
- "vnc_url": "http://localhost:8006",
1242
- "windows_username": "Docker",
1243
- "windows_password": "admin",
1244
- "recent_logs": logs[-500:] if logs else "",
1245
- # Enhanced VM details
1246
- "disk_usage_gb": vm_metadata["disk_usage_gb"],
1247
- "memory_usage_mb": vm_metadata["memory_usage_mb"],
1248
- "setup_script_phase": vm_metadata["setup_script_phase"],
1249
- "probe_response": vm_metadata["probe_response"],
1250
- "qmp_connected": vm_metadata["qmp_connected"],
1251
- "dependencies": vm_metadata["dependencies"],
1983
+ "metadata": {
1984
+ "container": container_name,
1985
+ "image": image,
1986
+ "status": status,
1987
+ "phase": phase,
1988
+ "windows_ready": phase
1989
+ in ["oobe", "ready"],
1990
+ "waa_server_ready": phase == "ready",
1991
+ # Use localhost - SSH tunnel handles routing to VM
1992
+ # See docs/waa_network_architecture.md
1993
+ "vnc_url": "http://localhost:8006",
1994
+ "windows_username": "Docker",
1995
+ "windows_password": "admin",
1996
+ "recent_logs": logs[-500:]
1997
+ if logs
1998
+ else "",
1999
+ # Enhanced VM details
2000
+ "disk_usage_gb": vm_metadata[
2001
+ "disk_usage_gb"
2002
+ ],
2003
+ "memory_usage_mb": vm_metadata[
2004
+ "memory_usage_mb"
2005
+ ],
2006
+ "setup_script_phase": vm_metadata[
2007
+ "setup_script_phase"
2008
+ ],
2009
+ "probe_response": vm_metadata[
2010
+ "probe_response"
2011
+ ],
2012
+ "qmp_connected": vm_metadata[
2013
+ "qmp_connected"
2014
+ ],
2015
+ "dependencies": vm_metadata[
2016
+ "dependencies"
2017
+ ],
2018
+ },
1252
2019
  }
1253
- })
1254
- except Exception as e:
2020
+ )
2021
+ except Exception:
1255
2022
  # SSH failed, VM might still be starting
1256
2023
  pass
1257
2024
 
@@ -1261,38 +2028,93 @@ def cmd_serve(args: argparse.Namespace) -> int:
1261
2028
  try:
1262
2029
  progress = json.loads(progress_file.read_text())
1263
2030
  if progress.get("status") == "running":
1264
- tasks.append({
1265
- "task_id": "benchmark-local",
1266
- "task_type": "benchmark_run",
1267
- "status": "running",
1268
- "title": f"{progress.get('provider', 'API').upper()} Benchmark",
1269
- "description": progress.get("message", "Running benchmark..."),
1270
- "progress_percent": (progress.get("tasks_complete", 0) / max(progress.get("tasks_total", 1), 1)) * 100,
1271
- "elapsed_seconds": 0,
1272
- "metadata": progress
1273
- })
2031
+ tasks.append(
2032
+ {
2033
+ "task_id": "benchmark-local",
2034
+ "task_type": "benchmark_run",
2035
+ "status": "running",
2036
+ "title": f"{progress.get('provider', 'API').upper()} Benchmark",
2037
+ "description": progress.get(
2038
+ "message", "Running benchmark..."
2039
+ ),
2040
+ "progress_percent": (
2041
+ progress.get("tasks_complete", 0)
2042
+ / max(progress.get("tasks_total", 1), 1)
2043
+ )
2044
+ * 100,
2045
+ "elapsed_seconds": 0,
2046
+ "metadata": progress,
2047
+ }
2048
+ )
1274
2049
  except Exception:
1275
2050
  pass
1276
2051
 
1277
2052
  return tasks
1278
2053
 
1279
2054
  def _fetch_vm_registry(self):
1280
- """Fetch VM registry with live status checks."""
2055
+ """Fetch VM registry with live status checks.
2056
+
2057
+ NOTE: We now fetch the VM IP from Azure CLI at runtime to avoid
2058
+ stale IP issues. The registry file is only used as a fallback.
2059
+ """
1281
2060
  import subprocess
1282
2061
  from datetime import datetime
1283
2062
 
1284
- # Path to VM registry file (relative to project root)
1285
- project_root = Path(__file__).parent.parent.parent
1286
- registry_file = project_root / "benchmark_results" / "vm_registry.json"
2063
+ # Try to get VM IP from Azure CLI (always fresh)
2064
+ vm_ip = None
2065
+ resource_group = "openadapt-agents"
2066
+ vm_name = "azure-waa-vm"
2067
+ try:
2068
+ result = subprocess.run(
2069
+ [
2070
+ "az",
2071
+ "vm",
2072
+ "show",
2073
+ "-d",
2074
+ "-g",
2075
+ resource_group,
2076
+ "-n",
2077
+ vm_name,
2078
+ "--query",
2079
+ "publicIps",
2080
+ "-o",
2081
+ "tsv",
2082
+ ],
2083
+ capture_output=True,
2084
+ text=True,
2085
+ timeout=10,
2086
+ )
2087
+ if result.returncode == 0 and result.stdout.strip():
2088
+ vm_ip = result.stdout.strip()
2089
+ except Exception:
2090
+ pass
1287
2091
 
1288
- if not registry_file.exists():
1289
- return []
2092
+ # If we have a fresh IP from Azure, use it
2093
+ if vm_ip:
2094
+ vms = [
2095
+ {
2096
+ "name": vm_name,
2097
+ "ssh_host": vm_ip,
2098
+ "ssh_user": "azureuser",
2099
+ "vnc_port": 8006,
2100
+ "waa_port": 5000,
2101
+ "docker_container": "winarena",
2102
+ "internal_ip": "localhost",
2103
+ }
2104
+ ]
2105
+ else:
2106
+ # Fallback to registry file
2107
+ project_root = Path(__file__).parent.parent.parent
2108
+ registry_file = project_root / "benchmark_results" / "vm_registry.json"
1290
2109
 
1291
- try:
1292
- with open(registry_file) as f:
1293
- vms = json.load(f)
1294
- except Exception as e:
1295
- return {"error": f"Failed to read VM registry: {e}"}
2110
+ if not registry_file.exists():
2111
+ return []
2112
+
2113
+ try:
2114
+ with open(registry_file) as f:
2115
+ vms = json.load(f)
2116
+ except Exception as e:
2117
+ return {"error": f"Failed to read VM registry: {e}"}
1296
2118
 
1297
2119
  # Check status for each VM
1298
2120
  for vm in vms:
@@ -1306,7 +2128,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
1306
2128
  vnc_url = f"http://{vm['ssh_host']}:{vm['vnc_port']}"
1307
2129
  result = subprocess.run(
1308
2130
  ["curl", "-I", "-s", "--connect-timeout", "3", vnc_url],
1309
- capture_output=True, text=True, timeout=5
2131
+ capture_output=True,
2132
+ text=True,
2133
+ timeout=5,
1310
2134
  )
1311
2135
  if result.returncode == 0 and "200" in result.stdout:
1312
2136
  vm["vnc_reachable"] = True
@@ -1320,13 +2144,25 @@ def cmd_serve(args: argparse.Namespace) -> int:
1320
2144
  waa_port = vm.get("waa_port", 5000)
1321
2145
  ssh_cmd = f"curl -s --connect-timeout 2 http://localhost:{waa_port}/probe 2>/dev/null"
1322
2146
  result = subprocess.run(
1323
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=3",
1324
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1325
- f"{vm['ssh_user']}@{vm['ssh_host']}",
1326
- ssh_cmd],
1327
- capture_output=True, text=True, timeout=5
2147
+ [
2148
+ "ssh",
2149
+ "-o",
2150
+ "StrictHostKeyChecking=no",
2151
+ "-o",
2152
+ "ConnectTimeout=3",
2153
+ "-i",
2154
+ str(Path.home() / ".ssh" / "id_rsa"),
2155
+ f"{vm['ssh_user']}@{vm['ssh_host']}",
2156
+ ssh_cmd,
2157
+ ],
2158
+ capture_output=True,
2159
+ text=True,
2160
+ timeout=5,
2161
+ )
2162
+ probe_success = (
2163
+ result.returncode == 0
2164
+ and "Service is operational" in result.stdout
1328
2165
  )
1329
- probe_success = result.returncode == 0 and "Service is operational" in result.stdout
1330
2166
  if probe_success:
1331
2167
  vm["waa_probe_status"] = "ready"
1332
2168
  vm["status"] = "online"
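Readiness here means the WAA /probe endpoint answers with the literal string "Service is operational" when curled from the host over SSH. That check in isolation, reusing the hypothetical run_ssh helper sketched in an earlier note:

    def waa_ready(vm_ip: str, port: int = 5000) -> bool:
        """True when the WAA server's /probe endpoint reports itself operational."""
        cmd = f"curl -s --connect-timeout 2 http://localhost:{port}/probe 2>/dev/null"
        result = run_ssh(vm_ip, cmd, timeout=5)  # run_ssh: sketch, not package API
        return result.returncode == 0 and "Service is operational" in result.stdout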
@@ -1338,7 +2174,11 @@ def cmd_serve(args: argparse.Namespace) -> int:
1338
2174
  ssh_user=vm.get("ssh_user", "azureuser"),
1339
2175
  )
1340
2176
  vm["tunnels"] = {
1341
- name: {"active": s.active, "local_port": s.local_port, "error": s.error}
2177
+ name: {
2178
+ "active": s.active,
2179
+ "local_port": s.local_port,
2180
+ "error": s.error,
2181
+ }
1342
2182
  for name, s in tunnel_status.items()
1343
2183
  }
1344
2184
  except Exception as e:
@@ -1384,12 +2224,22 @@ def cmd_serve(args: argparse.Namespace) -> int:
1384
2224
  # First get VM IP
1385
2225
  try:
1386
2226
  ip_result = subprocess.run(
1387
- ["az", "vm", "list-ip-addresses",
1388
- "--name", "waa-eval-vm",
1389
- "--resource-group", "openadapt-agents",
1390
- "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
1391
- "-o", "tsv"],
1392
- capture_output=True, text=True, timeout=10
2227
+ [
2228
+ "az",
2229
+ "vm",
2230
+ "list-ip-addresses",
2231
+ "--name",
2232
+ "waa-eval-vm",
2233
+ "--resource-group",
2234
+ "openadapt-agents",
2235
+ "--query",
2236
+ "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
2237
+ "-o",
2238
+ "tsv",
2239
+ ],
2240
+ capture_output=True,
2241
+ text=True,
2242
+ timeout=10,
1393
2243
  )
1394
2244
  if ip_result.returncode == 0 and ip_result.stdout.strip():
1395
2245
  vm_ip = ip_result.stdout.strip()
@@ -1398,11 +2248,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
1398
2248
  # Try to probe WAA server via SSH
1399
2249
  # Use the correct internal IP for the Windows VM inside Docker
1400
2250
  probe_result = subprocess.run(
1401
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1402
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1403
- f"azureuser@{vm_ip}",
1404
- "docker exec waa-container curl -s --connect-timeout 3 http://172.30.0.2:5000/probe 2>/dev/null || echo 'probe_failed'"],
1405
- capture_output=True, text=True, timeout=15
2251
+ [
2252
+ "ssh",
2253
+ "-o",
2254
+ "StrictHostKeyChecking=no",
2255
+ "-o",
2256
+ "ConnectTimeout=5",
2257
+ "-i",
2258
+ str(Path.home() / ".ssh" / "id_rsa"),
2259
+ f"azureuser@{vm_ip}",
2260
+ "docker exec waa-container curl -s --connect-timeout 3 http://172.30.0.2:5000/probe 2>/dev/null || echo 'probe_failed'",
2261
+ ],
2262
+ capture_output=True,
2263
+ text=True,
2264
+ timeout=15,
1406
2265
  )
1407
2266
 
1408
2267
  result["container"] = "waa-container"
@@ -1415,7 +2274,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
1415
2274
  else:
1416
2275
  result["probe_result"] = "WAA server not responding"
1417
2276
  else:
1418
- result["probe_result"] = f"SSH/Docker error: {probe_result.stderr[:200]}"
2277
+ result["probe_result"] = (
2278
+ f"SSH/Docker error: {probe_result.stderr[:200]}"
2279
+ )
1419
2280
  else:
1420
2281
  result["probe_result"] = "Could not get VM IP"
1421
2282
  except subprocess.TimeoutExpired:
@@ -1443,7 +2304,6 @@ def cmd_serve(args: argparse.Namespace) -> int:
1443
2304
  - elapsed_minutes: int
1444
2305
  """
1445
2306
  import subprocess
1446
- from datetime import datetime
1447
2307
  import re
1448
2308
 
1449
2309
  result = {
@@ -1465,8 +2325,12 @@ def cmd_serve(args: argparse.Namespace) -> int:
1465
2325
  result["running"] = True
1466
2326
  result["type"] = "local"
1467
2327
  result["model"] = progress.get("provider", "unknown")
1468
- result["progress"]["tasks_completed"] = progress.get("tasks_complete", 0)
1469
- result["progress"]["total_tasks"] = progress.get("tasks_total", 0)
2328
+ result["progress"]["tasks_completed"] = progress.get(
2329
+ "tasks_complete", 0
2330
+ )
2331
+ result["progress"]["total_tasks"] = progress.get(
2332
+ "tasks_total", 0
2333
+ )
1470
2334
  return result
1471
2335
  except Exception:
1472
2336
  pass
@@ -1475,12 +2339,22 @@ def cmd_serve(args: argparse.Namespace) -> int:
1475
2339
  try:
1476
2340
  # Get VM IP
1477
2341
  ip_result = subprocess.run(
1478
- ["az", "vm", "list-ip-addresses",
1479
- "--name", "waa-eval-vm",
1480
- "--resource-group", "openadapt-agents",
1481
- "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
1482
- "-o", "tsv"],
1483
- capture_output=True, text=True, timeout=10
2342
+ [
2343
+ "az",
2344
+ "vm",
2345
+ "list-ip-addresses",
2346
+ "--name",
2347
+ "waa-eval-vm",
2348
+ "--resource-group",
2349
+ "openadapt-agents",
2350
+ "--query",
2351
+ "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
2352
+ "-o",
2353
+ "tsv",
2354
+ ],
2355
+ capture_output=True,
2356
+ text=True,
2357
+ timeout=10,
1484
2358
  )
1485
2359
 
1486
2360
  if ip_result.returncode == 0 and ip_result.stdout.strip():
@@ -1488,42 +2362,73 @@ def cmd_serve(args: argparse.Namespace) -> int:
1488
2362
 
1489
2363
  # Check if benchmark process is running
1490
2364
  process_check = subprocess.run(
1491
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1492
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1493
- f"azureuser@{vm_ip}",
1494
- "docker exec waa-container pgrep -f 'python.*run.py' 2>/dev/null && echo 'RUNNING' || echo 'NOT_RUNNING'"],
1495
- capture_output=True, text=True, timeout=10
2365
+ [
2366
+ "ssh",
2367
+ "-o",
2368
+ "StrictHostKeyChecking=no",
2369
+ "-o",
2370
+ "ConnectTimeout=5",
2371
+ "-i",
2372
+ str(Path.home() / ".ssh" / "id_rsa"),
2373
+ f"azureuser@{vm_ip}",
2374
+ "docker exec waa-container pgrep -f 'python.*run.py' 2>/dev/null && echo 'RUNNING' || echo 'NOT_RUNNING'",
2375
+ ],
2376
+ capture_output=True,
2377
+ text=True,
2378
+ timeout=10,
1496
2379
  )
1497
2380
 
1498
- if process_check.returncode == 0 and "RUNNING" in process_check.stdout:
2381
+ if (
2382
+ process_check.returncode == 0
2383
+ and "RUNNING" in process_check.stdout
2384
+ ):
1499
2385
  result["running"] = True
1500
2386
  result["type"] = "azure_vm"
1501
2387
 
1502
2388
  # Get log file for more details
1503
2389
  log_check = subprocess.run(
1504
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1505
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1506
- f"azureuser@{vm_ip}",
1507
- "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
1508
- capture_output=True, text=True, timeout=10
2390
+ [
2391
+ "ssh",
2392
+ "-o",
2393
+ "StrictHostKeyChecking=no",
2394
+ "-o",
2395
+ "ConnectTimeout=5",
2396
+ "-i",
2397
+ str(Path.home() / ".ssh" / "id_rsa"),
2398
+ f"azureuser@{vm_ip}",
2399
+ "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
2400
+ ],
2401
+ capture_output=True,
2402
+ text=True,
2403
+ timeout=10,
1509
2404
  )
1510
2405
 
1511
2406
  if log_check.returncode == 0 and log_check.stdout.strip():
1512
2407
  logs = log_check.stdout
1513
2408
 
1514
2409
  # Parse model from logs
1515
- model_match = re.search(r'model[=:\s]+([^\s,]+)', logs, re.IGNORECASE)
2410
+ model_match = re.search(
2411
+ r"model[=:\s]+([^\s,]+)", logs, re.IGNORECASE
2412
+ )
1516
2413
  if model_match:
1517
2414
  result["model"] = model_match.group(1)
1518
2415
 
1519
2416
  # Parse progress
1520
- task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
2417
+ task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
1521
2418
  if task_match:
1522
- result["progress"]["tasks_completed"] = int(task_match.group(1))
1523
- result["progress"]["total_tasks"] = int(task_match.group(2))
2419
+ result["progress"]["tasks_completed"] = int(
2420
+ task_match.group(1)
2421
+ )
2422
+ result["progress"]["total_tasks"] = int(
2423
+ task_match.group(2)
2424
+ )
1524
2425
 
1525
2426
  # Parse current task
1526
- task_id_match = re.search(r'(?:Running|Processing|task)[:\s]+([a-f0-9-]+)', logs, re.IGNORECASE)
2427
+ task_id_match = re.search(
2428
+ r"(?:Running|Processing|task)[:\s]+([a-f0-9-]+)",
2429
+ logs,
2430
+ re.IGNORECASE,
2431
+ )
1527
2432
  if task_id_match:
1528
2433
  result["current_task"] = task_id_match.group(1)
1529
2434
 
@@ -1532,7 +2437,284 @@ def cmd_serve(args: argparse.Namespace) -> int:
1532
2437
 
1533
2438
  return result
1534
2439
 
1535
- async def _detect_running_benchmark(self, vm_ip: str, container_name: str = "winarena") -> dict:
2440
+ def _get_benchmark_status(self) -> dict:
2441
+ """Get current benchmark job status with ETA calculation.
2442
+
2443
+ Returns:
2444
+ dict with job status, progress, ETA, and current task info
2445
+ """
2446
+ import time
2447
+
2448
+ # Check for live evaluation state
2449
+ live_file = Path("benchmark_live.json")
2450
+ if live_file.exists():
2451
+ try:
2452
+ live_state = json.loads(live_file.read_text())
2453
+ if live_state.get("status") == "running":
2454
+ total_tasks = live_state.get("total_tasks", 0)
2455
+ completed_tasks = live_state.get("tasks_completed", 0)
2456
+ current_task = live_state.get("current_task", {})
2457
+
2458
+ # Calculate ETA based on completed tasks
2459
+ eta_seconds = None
2460
+ avg_task_seconds = None
2461
+ if completed_tasks > 0 and total_tasks > 0:
2462
+ # Estimate from live state timestamp or use fallback
2463
+ elapsed = time.time() - live_state.get(
2464
+ "start_time", time.time()
2465
+ )
2466
+ avg_task_seconds = (
2467
+ elapsed / completed_tasks
2468
+ if completed_tasks > 0
2469
+ else 30.0
2470
+ )
2471
+ remaining_tasks = total_tasks - completed_tasks
2472
+ eta_seconds = remaining_tasks * avg_task_seconds
2473
+
2474
+ return {
2475
+ "status": "running",
2476
+ "current_job": {
2477
+ "run_id": live_state.get("run_id", "unknown"),
2478
+ "model_id": live_state.get("model_id", "unknown"),
2479
+ "total_tasks": total_tasks,
2480
+ "completed_tasks": completed_tasks,
2481
+ "current_task": current_task,
2482
+ "eta_seconds": eta_seconds,
2483
+ "avg_task_seconds": avg_task_seconds,
2484
+ },
2485
+ "queue": [], # TODO: implement queue tracking
2486
+ }
2487
+ except Exception as e:
2488
+ return {"status": "error", "error": str(e)}
2489
+
2490
+ # Fallback to current_run check
2491
+ current_run = self._get_current_run()
2492
+ if current_run.get("running"):
2493
+ return {
2494
+ "status": "running",
2495
+ "current_job": {
2496
+ "run_id": "unknown",
2497
+ "model_id": current_run.get("model", "unknown"),
2498
+ "total_tasks": current_run["progress"]["total_tasks"],
2499
+ "completed_tasks": current_run["progress"]["tasks_completed"],
2500
+ "current_task": {"task_id": current_run.get("current_task")},
2501
+ },
2502
+ "queue": [],
2503
+ }
2504
+
2505
+ return {"status": "idle"}
2506
+
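The ETA above is a running average: elapsed time per completed task, multiplied by the tasks remaining. The arithmetic on its own:

    import time

    def estimate_eta(start_time: float, completed: int, total: int) -> float | None:
        """ETA in seconds, assuming future tasks take the observed average time."""
        if completed <= 0 or total <= 0:
            return None
        avg = (time.time() - start_time) / completed
        return (total - completed) * avg

    # With 10 of 40 tasks done in 300s: avg = 30s/task, ETA = 30 remaining * 30s = 900s.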
2507
+ def _get_benchmark_costs(self) -> dict:
2508
+ """Get cost breakdown for current benchmark run.
2509
+
2510
+ Returns:
2511
+ dict with Azure VM, API calls, and GPU costs
2512
+ """
2513
+
2514
+ # Check for cost tracking file
2515
+ cost_file = Path("benchmark_costs.json")
2516
+ if cost_file.exists():
2517
+ try:
2518
+ return json.loads(cost_file.read_text())
2519
+ except Exception:
2520
+ pass
2521
+
2522
+ # Return placeholder structure
2523
+ return {
2524
+ "azure_vm": {
2525
+ "instance_type": "Standard_D4ds_v5",
2526
+ "hourly_rate_usd": 0.192,
2527
+ "hours_elapsed": 0.0,
2528
+ "cost_usd": 0.0,
2529
+ },
2530
+ "api_calls": {
2531
+ "anthropic": {"cost_usd": 0.0},
2532
+ "openai": {"cost_usd": 0.0},
2533
+ },
2534
+ "gpu_time": {
2535
+ "lambda_labs": {"cost_usd": 0.0},
2536
+ },
2537
+ "total_cost_usd": 0.0,
2538
+ }
2539
+
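The placeholder above prices the Azure VM linearly by wall-clock time; the session tracker later in this file applies the same accrual with a 0.422 USD/h default rate. A sketch of that calculation:

    def accrue_vm_cost(elapsed_seconds: float, hourly_rate_usd: float = 0.422) -> float:
        """Linear cost accrual: seconds -> hours -> USD."""
        return (elapsed_seconds / 3600) * hourly_rate_usd

    # accrue_vm_cost(5400) -> 0.633 (1.5 h at $0.422/h)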
2540
+ def _get_benchmark_metrics(self) -> dict:
2541
+ """Get performance metrics for current/completed benchmarks.
2542
+
2543
+ Returns:
2544
+ dict with success rate trends, domain breakdown, episode metrics
2545
+ """
2546
+ # Check for metrics file
2547
+ metrics_file = Path("benchmark_metrics.json")
2548
+ if metrics_file.exists():
2549
+ try:
2550
+ return json.loads(metrics_file.read_text())
2551
+ except Exception:
2552
+ pass
2553
+
2554
+ # Load completed runs from benchmark_results/
2555
+ benchmark_results_dir = Path("benchmark_results")
2556
+ if not benchmark_results_dir.exists():
2557
+ return {"error": "No benchmark results found"}
2558
+
2559
+ # Find most recent run
2560
+ runs = sorted(
2561
+ benchmark_results_dir.iterdir(),
2562
+ key=lambda p: p.stat().st_mtime,
2563
+ reverse=True,
2564
+ )
2565
+ if not runs:
2566
+ return {"error": "No benchmark runs found"}
2567
+
2568
+ recent_run = runs[0]
2569
+ summary_path = recent_run / "summary.json"
2570
+ if not summary_path.exists():
2571
+ return {"error": f"No summary.json in {recent_run.name}"}
2572
+
2573
+ try:
2574
+ summary = json.loads(summary_path.read_text())
2575
+
2576
+ # Build domain breakdown from tasks
2577
+ domain_breakdown = {}
2578
+ tasks_dir = recent_run / "tasks"
2579
+ if tasks_dir.exists():
2580
+ for task_dir in tasks_dir.iterdir():
2581
+ if not task_dir.is_dir():
2582
+ continue
2583
+
2584
+ task_json = task_dir / "task.json"
2585
+ execution_json = task_dir / "execution.json"
2586
+ if not (task_json.exists() and execution_json.exists()):
2587
+ continue
2588
+
2589
+ try:
2590
+ task_def = json.loads(task_json.read_text())
2591
+ execution = json.loads(execution_json.read_text())
2592
+
2593
+ domain = task_def.get("domain", "unknown")
2594
+ if domain not in domain_breakdown:
2595
+ domain_breakdown[domain] = {
2596
+ "total": 0,
2597
+ "success": 0,
2598
+ "rate": 0.0,
2599
+ "avg_steps": 0.0,
2600
+ "total_steps": 0,
2601
+ }
2602
+
2603
+ domain_breakdown[domain]["total"] += 1
2604
+ if execution.get("success"):
2605
+ domain_breakdown[domain]["success"] += 1
2606
+ domain_breakdown[domain]["total_steps"] += execution.get(
2607
+ "num_steps", 0
2608
+ )
2609
+
2610
+ except Exception:
2611
+ continue
2612
+
2613
+ # Calculate averages
2614
+ for domain, stats in domain_breakdown.items():
2615
+ if stats["total"] > 0:
2616
+ stats["rate"] = stats["success"] / stats["total"]
2617
+ stats["avg_steps"] = stats["total_steps"] / stats["total"]
2618
+
2619
+ return {
2620
+ "success_rate_over_time": [], # TODO: implement trend tracking
2621
+ "avg_steps_per_task": [], # TODO: implement trend tracking
2622
+ "domain_breakdown": domain_breakdown,
2623
+ "episode_success_metrics": {
2624
+ "first_action_accuracy": summary.get(
2625
+ "first_action_accuracy", 0.0
2626
+ ),
2627
+ "episode_success_rate": summary.get("success_rate", 0.0),
2628
+ "avg_steps_to_success": summary.get("avg_steps", 0.0),
2629
+ "avg_steps_to_failure": 0.0, # TODO: calculate from failed tasks
2630
+ },
2631
+ }
2632
+ except Exception as e:
2633
+ return {"error": f"Failed to load metrics: {str(e)}"}
2634
+
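The domain rollup above accumulates counts in one pass and computes rates and averages in a second. The same aggregation reduced to its core; field names follow the task.json/execution.json layout read above:

    def summarize_domains(records: list[dict]) -> dict:
        """records: [{'domain': str, 'success': bool, 'num_steps': int}, ...]"""
        out: dict[str, dict] = {}
        for r in records:
            d = out.setdefault(r.get("domain", "unknown"),
                               {"total": 0, "success": 0, "total_steps": 0})
            d["total"] += 1
            d["success"] += bool(r.get("success"))
            d["total_steps"] += r.get("num_steps", 0)
        for d in out.values():  # second pass: derive rates from the raw counts
            d["rate"] = d["success"] / d["total"]
            d["avg_steps"] = d["total_steps"] / d["total"]
        return out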
2635
+ def _get_benchmark_workers(self) -> dict:
2636
+ """Get worker status and utilization.
2637
+
2638
+ Returns:
2639
+ dict with total/active/idle workers and per-worker stats
2640
+ """
2641
+ # Get VM registry
2642
+ vms = self._fetch_vm_registry()
2643
+
2644
+ active_workers = [v for v in vms if v.get("status") == "online"]
2645
+ idle_workers = [v for v in vms if v.get("status") != "online"]
2646
+
2647
+ workers = []
2648
+ for vm in vms:
2649
+ workers.append(
2650
+ {
2651
+ "worker_id": vm.get("name", "unknown"),
2652
+ "status": "running" if vm.get("status") == "online" else "idle",
2653
+ "current_task": vm.get("current_task"),
2654
+ "tasks_completed": vm.get("tasks_completed", 0),
2655
+ "uptime_seconds": vm.get("uptime_seconds", 0),
2656
+ "idle_time_seconds": vm.get("idle_time_seconds", 0),
2657
+ }
2658
+ )
2659
+
2660
+ return {
2661
+ "total_workers": len(vms),
2662
+ "active_workers": len(active_workers),
2663
+ "idle_workers": len(idle_workers),
2664
+ "workers": workers,
2665
+ }
2666
+
2667
+ def _get_benchmark_runs(self) -> list[dict]:
2668
+ """Load all benchmark runs from benchmark_results directory.
2669
+
2670
+ Returns:
2671
+ List of benchmark run summaries sorted by timestamp (newest first)
2672
+ """
2673
+ results_dir = Path("benchmark_results")
2674
+ if not results_dir.exists():
2675
+ return []
2676
+
2677
+ runs = []
2678
+ for run_dir in results_dir.iterdir():
2679
+ if run_dir.is_dir():
2680
+ summary_file = run_dir / "summary.json"
2681
+ if summary_file.exists():
2682
+ try:
2683
+ summary = json.loads(summary_file.read_text())
2684
+ runs.append(summary)
2685
+ except (json.JSONDecodeError, IOError) as e:
2686
+ print(f"Warning: Failed to load {summary_file}: {e}")
2687
+
2688
+ # Sort by run_name descending (newest first)
2689
+ runs.sort(key=lambda r: r.get("run_name", ""), reverse=True)
2690
+ return runs
2691
+
2692
+ def _get_task_execution(self, run_name: str, task_id: str) -> dict:
2693
+ """Load task execution details from execution.json.
2694
+
2695
+ Args:
2696
+ run_name: Name of the benchmark run
2697
+ task_id: Task identifier
2698
+
2699
+ Returns:
2700
+ Task execution data with steps and screenshots
2701
+ """
2702
+ results_dir = Path("benchmark_results")
2703
+ execution_file = (
2704
+ results_dir / run_name / "tasks" / task_id / "execution.json"
2705
+ )
2706
+
2707
+ if not execution_file.exists():
2708
+ raise FileNotFoundError(f"Execution file not found: {execution_file}")
2709
+
2710
+ try:
2711
+ return json.loads(execution_file.read_text())
2712
+ except (json.JSONDecodeError, IOError) as e:
2713
+ raise Exception(f"Failed to load execution data: {e}")
2714
+
2715
+ async def _detect_running_benchmark(
2716
+ self, vm_ip: str, container_name: str = "winarena"
2717
+ ) -> dict:
1536
2718
  """Detect if a benchmark is running on the VM and extract progress.
1537
2719
 
1538
2720
  SSH into VM and check:
@@ -1563,11 +2745,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
1563
2745
  try:
1564
2746
  # Check if benchmark process is running
1565
2747
  process_check = subprocess.run(
1566
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1567
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1568
- f"azureuser@{vm_ip}",
1569
- f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''"],
1570
- capture_output=True, text=True, timeout=10
2748
+ [
2749
+ "ssh",
2750
+ "-o",
2751
+ "StrictHostKeyChecking=no",
2752
+ "-o",
2753
+ "ConnectTimeout=5",
2754
+ "-i",
2755
+ str(Path.home() / ".ssh" / "id_rsa"),
2756
+ f"azureuser@{vm_ip}",
2757
+ f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''",
2758
+ ],
2759
+ capture_output=True,
2760
+ text=True,
2761
+ timeout=10,
1571
2762
  )
1572
2763
 
1573
2764
  if process_check.returncode == 0 and process_check.stdout.strip():
@@ -1575,11 +2766,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
1575
2766
 
1576
2767
  # Get benchmark log
1577
2768
  log_check = subprocess.run(
1578
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1579
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
1580
- f"azureuser@{vm_ip}",
1581
- "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
1582
- capture_output=True, text=True, timeout=10
2769
+ [
2770
+ "ssh",
2771
+ "-o",
2772
+ "StrictHostKeyChecking=no",
2773
+ "-o",
2774
+ "ConnectTimeout=5",
2775
+ "-i",
2776
+ str(Path.home() / ".ssh" / "id_rsa"),
2777
+ f"azureuser@{vm_ip}",
2778
+ "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
2779
+ ],
2780
+ capture_output=True,
2781
+ text=True,
2782
+ timeout=10,
1583
2783
  )
1584
2784
 
1585
2785
  if log_check.returncode == 0 and log_check.stdout.strip():
@@ -1588,22 +2788,28 @@ def cmd_serve(args: argparse.Namespace) -> int:
1588
2788
 
1589
2789
  # Parse progress from logs
1590
2790
  # Look for patterns like "Task 5/30" or "Completed: 5, Remaining: 25"
1591
- task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
2791
+ task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
1592
2792
  if task_match:
1593
- result["progress"]["tasks_completed"] = int(task_match.group(1))
2793
+ result["progress"]["tasks_completed"] = int(
2794
+ task_match.group(1)
2795
+ )
1594
2796
  result["progress"]["total_tasks"] = int(task_match.group(2))
1595
2797
 
1596
2798
  # Extract current task ID
1597
- task_id_match = re.search(r'(?:Running|Processing) task:\s*(\S+)', logs)
2799
+ task_id_match = re.search(
2800
+ r"(?:Running|Processing) task:\s*(\S+)", logs
2801
+ )
1598
2802
  if task_id_match:
1599
2803
  result["current_task"] = task_id_match.group(1)
1600
2804
 
1601
2805
  # Extract step info
1602
- step_match = re.search(r'Step\s+(\d+)', logs)
2806
+ step_match = re.search(r"Step\s+(\d+)", logs)
1603
2807
  if step_match:
1604
- result["progress"]["current_step"] = int(step_match.group(1))
2808
+ result["progress"]["current_step"] = int(
2809
+ step_match.group(1)
2810
+ )
1605
2811
 
1606
- except Exception as e:
2812
+ except Exception:
1607
2813
  # SSH or parsing failed - leave defaults
1608
2814
  pass
1609
2815
 
@@ -1627,13 +2833,13 @@ def cmd_serve(args: argparse.Namespace) -> int:
1627
2833
  # Search backwards from most recent
1628
2834
  for line in reversed(log_lines):
1629
2835
  # Check for explicit result
1630
- if 'Result: PASS' in line or 'completed successfully' in line:
2836
+ if "Result: PASS" in line or "completed successfully" in line:
1631
2837
  success = True
1632
- elif 'Result: FAIL' in line or 'failed' in line.lower():
2838
+ elif "Result: FAIL" in line or "failed" in line.lower():
1633
2839
  success = False
1634
2840
 
1635
2841
  # Check for score
1636
- score_match = re.search(r'Score:\s*([\d.]+)', line)
2842
+ score_match = re.search(r"Score:\s*([\d.]+)", line)
1637
2843
  if score_match:
1638
2844
  try:
1639
2845
  score = float(score_match.group(1))
@@ -1642,9 +2848,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
1642
2848
 
1643
2849
  # Check for task-specific completion
1644
2850
  if task_id in line:
1645
- if 'success' in line.lower() or 'pass' in line.lower():
2851
+ if "success" in line.lower() or "pass" in line.lower():
1646
2852
  success = True
1647
- elif 'fail' in line.lower() or 'error' in line.lower():
2853
+ elif "fail" in line.lower() or "error" in line.lower():
1648
2854
  success = False
1649
2855
 
1650
2856
  # Default to True if no explicit failure found (backwards compatible)
@@ -1674,11 +2880,11 @@ def cmd_serve(args: argparse.Namespace) -> int:
1674
2880
 
1675
2881
  # Set SSE headers
1676
2882
  self.send_response(200)
1677
- self.send_header('Content-Type', 'text/event-stream')
1678
- self.send_header('Cache-Control', 'no-cache')
1679
- self.send_header('Access-Control-Allow-Origin', '*')
1680
- self.send_header('Connection', 'keep-alive')
1681
- self.send_header('X-Accel-Buffering', 'no') # Disable nginx buffering
2883
+ self.send_header("Content-Type", "text/event-stream")
2884
+ self.send_header("Cache-Control", "no-cache")
2885
+ self.send_header("Access-Control-Allow-Origin", "*")
2886
+ self.send_header("Connection", "keep-alive")
2887
+ self.send_header("X-Accel-Buffering", "no") # Disable nginx buffering
1682
2888
  self.end_headers()
1683
2889
 
1684
2890
  # Track connection state
@@ -1691,7 +2897,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
1691
2897
  return False
1692
2898
  try:
1693
2899
  event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
1694
- self.wfile.write(event_str.encode('utf-8'))
2900
+ self.wfile.write(event_str.encode("utf-8"))
1695
2901
  self.wfile.flush()
1696
2902
  return True
1697
2903
  except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
@@ -1734,11 +2940,10 @@ def cmd_serve(args: argparse.Namespace) -> int:
1734
2940
  recent_log_lines = []
1735
2941
 
1736
2942
  # Send initial connected event
1737
- if not send_event("connected", {
1738
- "timestamp": time.time(),
1739
- "interval": interval,
1740
- "version": "1.0"
1741
- }):
2943
+ if not send_event(
2944
+ "connected",
2945
+ {"timestamp": time.time(), "interval": interval, "version": "1.0"},
2946
+ ):
1742
2947
  return
1743
2948
 
1744
2949
  try:
@@ -1763,17 +2968,25 @@ def cmd_serve(args: argparse.Namespace) -> int:
1763
2968
  tasks = self._fetch_background_tasks()
1764
2969
 
1765
2970
  # Send VM status event
1766
- vm_task = next((t for t in tasks if t.get("task_type") == "docker_container"), None)
2971
+ vm_task = next(
2972
+ (t for t in tasks if t.get("task_type") == "docker_container"),
2973
+ None,
2974
+ )
1767
2975
  if vm_task:
1768
2976
  vm_data = {
1769
2977
  "type": "vm_status",
1770
- "connected": vm_task.get("status") in ["running", "completed"],
2978
+ "connected": vm_task.get("status")
2979
+ in ["running", "completed"],
1771
2980
  "phase": vm_task.get("phase", "unknown"),
1772
- "waa_ready": vm_task.get("metadata", {}).get("waa_server_ready", False),
2981
+ "waa_ready": vm_task.get("metadata", {}).get(
2982
+ "waa_server_ready", False
2983
+ ),
1773
2984
  "probe": {
1774
- "status": vm_task.get("metadata", {}).get("probe_response", "unknown"),
2985
+ "status": vm_task.get("metadata", {}).get(
2986
+ "probe_response", "unknown"
2987
+ ),
1775
2988
  "vnc_url": vm_task.get("metadata", {}).get("vnc_url"),
1776
- }
2989
+ },
1777
2990
  }
1778
2991
 
1779
2992
  if not send_event("status", vm_data):
@@ -1783,27 +2996,47 @@ def cmd_serve(args: argparse.Namespace) -> int:
1783
2996
  if vm_data["waa_ready"]:
1784
2997
  # Get VM IP from tasks
1785
2998
  vm_ip = None
1786
- azure_vm = next((t for t in tasks if t.get("task_type") == "vm_provision"), None)
2999
+ azure_vm = next(
3000
+ (
3001
+ t
3002
+ for t in tasks
3003
+ if t.get("task_type") == "vm_provision"
3004
+ ),
3005
+ None,
3006
+ )
1787
3007
  if azure_vm:
1788
3008
  vm_ip = azure_vm.get("metadata", {}).get("ip_address")
1789
3009
 
1790
3010
  if vm_ip:
1791
3011
  # Detect running benchmark using sync version
1792
3012
  benchmark_status = self._detect_running_benchmark_sync(
1793
- vm_ip, vm_task.get("metadata", {}).get("container", "winarena")
3013
+ vm_ip,
3014
+ vm_task.get("metadata", {}).get(
3015
+ "container", "winarena"
3016
+ ),
1794
3017
  )
1795
3018
 
1796
3019
  if benchmark_status["running"]:
1797
3020
  # Store log lines for result parsing
1798
3021
  if benchmark_status.get("recent_logs"):
1799
- recent_log_lines = benchmark_status["recent_logs"].split('\n')
3022
+ recent_log_lines = benchmark_status[
3023
+ "recent_logs"
3024
+ ].split("\n")
1800
3025
 
1801
3026
  # Send progress event
1802
3027
  progress_data = {
1803
- "tasks_completed": benchmark_status["progress"]["tasks_completed"],
1804
- "total_tasks": benchmark_status["progress"]["total_tasks"],
1805
- "current_task": benchmark_status["current_task"],
1806
- "current_step": benchmark_status["progress"]["current_step"],
3028
+ "tasks_completed": benchmark_status["progress"][
3029
+ "tasks_completed"
3030
+ ],
3031
+ "total_tasks": benchmark_status["progress"][
3032
+ "total_tasks"
3033
+ ],
3034
+ "current_task": benchmark_status[
3035
+ "current_task"
3036
+ ],
3037
+ "current_step": benchmark_status["progress"][
3038
+ "current_step"
3039
+ ],
1807
3040
  }
1808
3041
 
1809
3042
  if not send_event("progress", progress_data):
@@ -1814,13 +3047,17 @@ def cmd_serve(args: argparse.Namespace) -> int:
1814
3047
  if current_task and current_task != last_task:
1815
3048
  if last_task is not None:
1816
3049
  # Previous task completed - parse result from logs
1817
- result = self._parse_task_result(recent_log_lines, last_task)
3050
+ result = self._parse_task_result(
3051
+ recent_log_lines, last_task
3052
+ )
1818
3053
  complete_data = {
1819
3054
  "task_id": last_task,
1820
3055
  "success": result["success"],
1821
3056
  "score": result["score"],
1822
3057
  }
1823
- if not send_event("task_complete", complete_data):
3058
+ if not send_event(
3059
+ "task_complete", complete_data
3060
+ ):
1824
3061
  break
1825
3062
 
1826
3063
  last_task = current_task
@@ -1832,7 +3069,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
1832
3069
  progress = json.loads(progress_file.read_text())
1833
3070
  if progress.get("status") == "running":
1834
3071
  progress_data = {
1835
- "tasks_completed": progress.get("tasks_complete", 0),
3072
+ "tasks_completed": progress.get(
3073
+ "tasks_complete", 0
3074
+ ),
1836
3075
  "total_tasks": progress.get("tasks_total", 0),
1837
3076
  "current_task": progress.get("provider", "unknown"),
1838
3077
  }
@@ -1858,7 +3097,244 @@ def cmd_serve(args: argparse.Namespace) -> int:
1858
3097
  # Cleanup - connection is ending
1859
3098
  client_connected = False
1860
3099
 
1861
- def _detect_running_benchmark_sync(self, vm_ip: str, container_name: str = "winarena") -> dict:
3100
+ def _stream_azure_ops_updates(self):
3101
+ """Stream Server-Sent Events for Azure operations status updates.
3102
+
3103
+ Monitors azure_ops_status.json for changes and streams updates.
3104
+ Uses file modification time to detect changes efficiently.
3105
+
3106
+ Streams events:
3107
+ - connected: Initial connection event
3108
+ - status: Azure ops status update when file changes
3109
+ - heartbeat: Keep-alive signal every 30 seconds
3110
+ - error: Error messages
3111
+ """
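The handler emits the standard Server-Sent Events wire format: an event: line, a data: line carrying JSON, and a blank line terminating the frame. What one frame looks like on the wire, with a made-up payload:

    import json

    def sse_frame(event_type: str, data: dict) -> bytes:
        """Encode one SSE frame exactly as the send_event helper below does."""
        return f"event: {event_type}\ndata: {json.dumps(data)}\n\n".encode("utf-8")

    # sse_frame("heartbeat", {"timestamp": 1700000000.0})
    # -> b'event: heartbeat\ndata: {"timestamp": 1700000000.0}\n\n'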
3112
+ import time
3113
+ import select
3114
+ from pathlib import Path
3115
+
3116
+ HEARTBEAT_INTERVAL = 30 # seconds
3117
+ CHECK_INTERVAL = 1 # Check file every second
3118
+
3119
+ # Set SSE headers
+ self.send_response(200)
+ self.send_header("Content-Type", "text/event-stream")
+ self.send_header("Cache-Control", "no-cache")
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Connection", "keep-alive")
+ self.send_header("X-Accel-Buffering", "no") # Disable nginx buffering
+ self.end_headers()
+
+ # Track connection state
+ client_connected = True
+ last_mtime = 0.0
+ last_session_mtime = 0.0
+ last_heartbeat = time.time()
+
+ def send_event(event_type: str, data: dict) -> bool:
+ """Send an SSE event. Returns False if client disconnected."""
+ nonlocal client_connected
+ if not client_connected:
+ return False
+ try:
+ event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+ self.wfile.write(event_str.encode("utf-8"))
+ self.wfile.flush()
+ return True
+ except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+ client_connected = False
+ return False
+ except Exception as e:
+ print(f"Azure ops SSE send error: {e}")
+ client_connected = False
+ return False
+
+ def check_client_connected() -> bool:
+ """Check if client is still connected using socket select."""
+ nonlocal client_connected
+ if not client_connected:
+ return False
+ try:
+ rlist, _, xlist = select.select([self.rfile], [], [self.rfile], 0)
+ if xlist:
+ client_connected = False
+ return False
+ if rlist:
+ data = self.rfile.read(1)
+ if not data:
+ client_connected = False
+ return False
+ return True
+ except Exception:
+ client_connected = False
+ return False
+
+ # Status file path
+ from openadapt_ml.benchmarks.azure_ops_tracker import (
+ DEFAULT_OUTPUT_FILE,
+ read_status,
+ )
+ from openadapt_ml.benchmarks.session_tracker import (
+ get_session,
+ update_session_vm_state,
+ DEFAULT_SESSION_FILE,
+ )
+
+ status_file = Path(DEFAULT_OUTPUT_FILE)
+ session_file = Path(DEFAULT_SESSION_FILE)
+
+ def compute_server_side_values(status: dict) -> dict:
+ """Get elapsed_seconds and cost_usd from session tracker for persistence."""
+ # Get session data (persistent across refreshes)
+ session = get_session()
+
+ # Update session based on VM state if we have VM info
+ # IMPORTANT: Only pass vm_ip if it's truthy to avoid
+ # overwriting session's stable vm_ip with None
+ if status.get("vm_state") and status.get("vm_state") != "unknown":
+ status_vm_ip = status.get("vm_ip")
+ # Build update kwargs - only include vm_ip if present
+ update_kwargs = {
+ "vm_state": status["vm_state"],
+ "vm_size": status.get("vm_size"),
+ }
+ if status_vm_ip: # Only include if truthy
+ update_kwargs["vm_ip"] = status_vm_ip
+ session = update_session_vm_state(**update_kwargs)
+
+ # Use session's vm_ip as authoritative source
+ # This prevents IP flickering when status file has stale/None values
+ if session.get("vm_ip"):
+ status["vm_ip"] = session["vm_ip"]
+
+ # Use session's elapsed_seconds and cost_usd for persistence
+ if (
+ session.get("is_active")
+ or session.get("accumulated_seconds", 0) > 0
+ ):
+ status["elapsed_seconds"] = session.get("elapsed_seconds", 0.0)
+ status["cost_usd"] = session.get("cost_usd", 0.0)
+ status["started_at"] = session.get("started_at")
+ status["session_id"] = session.get("session_id")
+ status["session_is_active"] = session.get("is_active", False)
+ # Include accumulated time from previous sessions for hybrid display
+ status["accumulated_seconds"] = session.get(
+ "accumulated_seconds", 0.0
+ )
+ # Calculate current session time (total - accumulated)
+ current_session_seconds = max(
+ 0, status["elapsed_seconds"] - status["accumulated_seconds"]
+ )
+ status["current_session_seconds"] = current_session_seconds
+ hourly_rate = session.get("hourly_rate_usd", 0.422)
+ status["current_session_cost_usd"] = (
+ current_session_seconds / 3600
+ ) * hourly_rate
+
+ try:
+ tunnel_mgr = get_tunnel_manager()
+ tunnel_status = tunnel_mgr.get_tunnel_status()
+ status["tunnels"] = {
+ name: {
+ "active": s.active,
+ "local_port": s.local_port,
+ "remote_endpoint": s.remote_endpoint,
+ "pid": s.pid,
+ "error": s.error,
+ }
+ for name, s in tunnel_status.items()
+ }
+ except Exception as e:
+ status["tunnels"] = {"error": str(e)}
+
+ return status
+
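The session cost logic above is plain clock arithmetic; a minimal sketch of the same computation in isolation, using the default hourly rate that appears in the code (the sample durations are illustrative):

    # Same arithmetic as compute_server_side_values, stripped of the session plumbing.
    accumulated_seconds = 600.0   # carried over from earlier sessions (illustrative)
    elapsed_seconds = 2400.0      # total across all sessions (illustrative)
    hourly_rate = 0.422           # default hourly_rate_usd used above

    # Current-session time is total minus what previous sessions accumulated.
    current_session_seconds = max(0, elapsed_seconds - accumulated_seconds)  # 1800.0
    current_session_cost_usd = (current_session_seconds / 3600) * hourly_rate
    assert round(current_session_cost_usd, 3) == 0.211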
+ # Send initial connected event
+ if not send_event(
+ "connected",
+ {"timestamp": time.time(), "version": "1.0"},
+ ):
+ return
+
+ # Send initial status immediately
+ try:
+ status = compute_server_side_values(read_status())
+ if not send_event("status", status):
+ return
+ if status_file.exists():
+ last_mtime = status_file.stat().st_mtime
+ except Exception as e:
+ send_event("error", {"message": str(e)})
+
+ try:
+ iteration_count = 0
+ max_iterations = 3600 # Max 1 hour of streaming
+ last_status_send = 0.0
+ STATUS_SEND_INTERVAL = 2 # Send status every 2 seconds for live updates
+
+ while client_connected and iteration_count < max_iterations:
+ iteration_count += 1
+ current_time = time.time()
+
+ # Check client connection
+ if not check_client_connected():
+ break
+
+ # Send heartbeat every 30 seconds
+ if current_time - last_heartbeat >= HEARTBEAT_INTERVAL:
+ if not send_event("heartbeat", {"timestamp": current_time}):
+ break
+ last_heartbeat = current_time
+
+ # Check if status or session file changed OR if enough time passed
+ try:
+ status_changed = False
+ session_changed = False
+ time_to_send = (
+ current_time - last_status_send >= STATUS_SEND_INTERVAL
+ )
+
+ if status_file.exists():
+ current_mtime = status_file.stat().st_mtime
+ if current_mtime > last_mtime:
+ status_changed = True
+ last_mtime = current_mtime
+
+ if session_file.exists():
+ current_session_mtime = session_file.stat().st_mtime
+ if current_session_mtime > last_session_mtime:
+ session_changed = True
+ last_session_mtime = current_session_mtime
+
+ # Send status if file changed OR periodic timer expired
+ # This ensures live elapsed time/cost updates even without file changes
+ if status_changed or session_changed or time_to_send:
+ # File changed or time to send - send update with session values
+ status = compute_server_side_values(read_status())
+ if not send_event("status", status):
+ break
+ last_status_send = current_time
+ except Exception as e:
+ # File access error - log but continue
+ print(f"Azure ops SSE file check error: {e}")
+
+ # Sleep briefly before next check
+ try:
+ select.select([self.rfile], [], [], CHECK_INTERVAL)
+ except Exception:
+ break
+
+ except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+ # Client disconnected - normal
+ pass
+ except Exception as e:
+ send_event("error", {"message": str(e)})
+ finally:
+ client_connected = False
+
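A client of this stream only needs to split on blank lines and parse the `data:` payloads that follow each `event:` marker. A minimal stdlib-only consumer sketch (the URL is an assumption for illustration; the actual route this handler is mounted on is defined elsewhere in the file):

    # Minimal SSE consumer for the stream above (stdlib only).
    import json
    import urllib.request

    # Hypothetical endpoint; substitute the real serve port and route.
    with urllib.request.urlopen("http://localhost:8765/api/azure-ops/stream") as resp:
        event_type = None
        for raw_line in resp:
            line = raw_line.decode("utf-8").rstrip("\n")
            if line.startswith("event: "):
                event_type = line[len("event: "):]
            elif line.startswith("data: ") and event_type == "status":
                status = json.loads(line[len("data: "):])
                print(status.get("vm_state"), status.get("elapsed_seconds"))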
+ def _detect_running_benchmark_sync(
+ self, vm_ip: str, container_name: str = "winarena"
+ ) -> dict:
  """Synchronous version of _detect_running_benchmark.

  Avoids creating a new event loop on each call which causes issues
@@ -1881,11 +3357,20 @@ def cmd_serve(args: argparse.Namespace) -> int:
  try:
  # Check if benchmark process is running
  process_check = subprocess.run(
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
- f"azureuser@{vm_ip}",
- f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''"],
- capture_output=True, text=True, timeout=10
+ [
+ "ssh",
+ "-o",
+ "StrictHostKeyChecking=no",
+ "-o",
+ "ConnectTimeout=5",
+ "-i",
+ str(Path.home() / ".ssh" / "id_rsa"),
+ f"azureuser@{vm_ip}",
+ f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''",
+ ],
+ capture_output=True,
+ text=True,
+ timeout=10,
  )

  if process_check.returncode == 0 and process_check.stdout.strip():
@@ -1893,11 +3378,20 @@ def cmd_serve(args: argparse.Namespace) -> int:

  # Get benchmark log
  log_check = subprocess.run(
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
- "-i", str(Path.home() / ".ssh" / "id_rsa"),
- f"azureuser@{vm_ip}",
- "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
- capture_output=True, text=True, timeout=10
+ [
+ "ssh",
+ "-o",
+ "StrictHostKeyChecking=no",
+ "-o",
+ "ConnectTimeout=5",
+ "-i",
+ str(Path.home() / ".ssh" / "id_rsa"),
+ f"azureuser@{vm_ip}",
+ "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''",
+ ],
+ capture_output=True,
+ text=True,
+ timeout=10,
  )

  if log_check.returncode == 0 and log_check.stdout.strip():
@@ -1905,20 +3399,26 @@ def cmd_serve(args: argparse.Namespace) -> int:
  result["recent_logs"] = logs[-500:] # Last 500 chars

  # Parse progress from logs
- task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
+ task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
  if task_match:
- result["progress"]["tasks_completed"] = int(task_match.group(1))
+ result["progress"]["tasks_completed"] = int(
+ task_match.group(1)
+ )
  result["progress"]["total_tasks"] = int(task_match.group(2))

  # Extract current task ID
- task_id_match = re.search(r'(?:Running|Processing) task:\s*(\S+)', logs)
+ task_id_match = re.search(
+ r"(?:Running|Processing) task:\s*(\S+)", logs
+ )
  if task_id_match:
  result["current_task"] = task_id_match.group(1)

  # Extract step info
- step_match = re.search(r'Step\s+(\d+)', logs)
+ step_match = re.search(r"Step\s+(\d+)", logs)
  if step_match:
- result["progress"]["current_step"] = int(
+ result["progress"]["current_step"] = int(
+ step_match.group(1)
+ )

  except Exception:
  # SSH or parsing failed - leave defaults
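The three log-parsing regexes above can be sanity-checked against a representative log line (the sample text is illustrative, not taken from a real run):

    import re

    logs = "Running task: notepad_basic ... Task 3/10 ... Step 7"
    # Task progress: "Task <done>/<total>"
    assert re.search(r"Task\s+(\d+)/(\d+)", logs).groups() == ("3", "10")
    # Current task ID: first non-space token after "Running task:" or "Processing task:"
    assert re.search(r"(?:Running|Processing) task:\s*(\S+)", logs).group(1) == "notepad_basic"
    # Current step number
    assert re.search(r"Step\s+(\d+)", logs).group(1) == "7"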
@@ -1949,7 +3449,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
  "vnc_port": vm_data.get("vnc_port", 8006),
  "waa_port": vm_data.get("waa_port", 5000),
  "docker_container": vm_data.get("docker_container", "win11-waa"),
- "internal_ip": vm_data.get("internal_ip", "20.20.20.21")
+ "internal_ip": vm_data.get("internal_ip", "20.20.20.21"),
  }

  vms.append(new_vm)
@@ -1957,7 +3457,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
  # Save registry
  try:
  registry_file.parent.mkdir(parents=True, exist_ok=True)
- with open(registry_file, 'w') as f:
+ with open(registry_file, "w") as f:
  json.dump(vms, f, indent=2)
  return {"status": "success", "vm": new_vm}
  except Exception as e:
@@ -1990,12 +3490,20 @@ def cmd_serve(args: argparse.Namespace) -> int:

  # Build CLI command
  cmd = [
- "uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli",
- "vm", "run-waa",
- "--num-tasks", str(params.get("num_tasks", 5)),
- "--model", params.get("model", "gpt-4o"),
- "--agent", params.get("agent", "navi"),
- "--no-open" # Don't open viewer (already open)
+ "uv",
+ "run",
+ "python",
+ "-m",
+ "openadapt_ml.benchmarks.cli",
+ "vm",
+ "run-waa",
+ "--num-tasks",
+ str(params.get("num_tasks", 5)),
+ "--model",
+ params.get("model", "gpt-4o"),
+ "--agent",
+ params.get("agent", "navi"),
+ "--no-open", # Don't open viewer (already open)
  ]

  # Add task selection args
@@ -2016,7 +3524,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
  num_tasks = params.get("num_tasks", 5)
  agent = params.get("agent", "navi")

- print(f"\n[Benchmark] Starting WAA benchmark: model={model}, tasks={num_tasks}, agent={agent}")
+ print(
+ f"\n[Benchmark] Starting WAA benchmark: model={model}, tasks={num_tasks}, agent={agent}"
+ )
  print(f"[Benchmark] Task selection: {task_selection}")
  if task_selection == "domain":
  print(f"[Benchmark] Domain: {params.get('domain', 'general')}")
@@ -2024,15 +3534,19 @@ def cmd_serve(args: argparse.Namespace) -> int:
  print(f"[Benchmark] Task IDs: {params.get('task_ids', [])}")
  print(f"[Benchmark] Command: {' '.join(cmd)}")

- progress_file.write_text(json.dumps({
- "status": "running",
- "model": model,
- "num_tasks": num_tasks,
- "agent": agent,
- "task_selection": task_selection,
- "tasks_complete": 0,
- "message": f"Starting {model} benchmark with {num_tasks} tasks..."
- }))
+ progress_file.write_text(
+ json.dumps(
+ {
+ "status": "running",
+ "model": model,
+ "num_tasks": num_tasks,
+ "agent": agent,
+ "task_selection": task_selection,
+ "tasks_complete": 0,
+ "message": f"Starting {model} benchmark with {num_tasks} tasks...",
+ }
+ )
+ )

  # Copy environment with loaded vars
  env = os.environ.copy()
@@ -2040,11 +3554,7 @@ def cmd_serve(args: argparse.Namespace) -> int:
  # Run in background thread
  def run():
  result = subprocess.run(
- cmd,
- capture_output=True,
- text=True,
- cwd=str(project_root),
- env=env
+ cmd, capture_output=True, text=True, cwd=str(project_root), env=env
  )

  print(f"\n[Benchmark] Output:\n{result.stdout}")
@@ -2052,24 +3562,34 @@ def cmd_serve(args: argparse.Namespace) -> int:
  print(f"[Benchmark] Stderr: {result.stderr}")

  if result.returncode == 0:
- print(f"[Benchmark] Complete. Regenerating viewer...")
- progress_file.write_text(json.dumps({
- "status": "complete",
- "model": model,
- "num_tasks": num_tasks,
- "message": "Benchmark complete. Refresh to see results."
- }))
+ print("[Benchmark] Complete. Regenerating viewer...")
+ progress_file.write_text(
+ json.dumps(
+ {
+ "status": "complete",
+ "model": model,
+ "num_tasks": num_tasks,
+ "message": "Benchmark complete. Refresh to see results.",
+ }
+ )
+ )
  # Regenerate benchmark viewer
  _regenerate_benchmark_viewer_if_available(serve_dir)
  else:
- error_msg = result.stderr[:200] if result.stderr else "Unknown error"
+ error_msg = (
+ result.stderr[:200] if result.stderr else "Unknown error"
+ )
  print(f"[Benchmark] Failed: {error_msg}")
- progress_file.write_text(json.dumps({
- "status": "error",
- "model": model,
- "num_tasks": num_tasks,
- "message": f"Benchmark failed: {error_msg}"
- }))
+ progress_file.write_text(
+ json.dumps(
+ {
+ "status": "error",
+ "model": model,
+ "num_tasks": num_tasks,
+ "message": f"Benchmark failed: {error_msg}",
+ }
+ )
+ )

  threading.Thread(target=run, daemon=True).start()
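Because the background thread reports through a plain JSON file, any poller can follow the run without touching the thread itself. A sketch of a reader (the file name here is an assumption; use the path bound to progress_file earlier in this function):

    # Polling the progress JSON written by the background thread.
    import json
    from pathlib import Path

    # "benchmark_progress.json" is a hypothetical name for illustration.
    progress = json.loads(Path("benchmark_progress.json").read_text())
    print(progress["status"], "-", progress.get("message", ""))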

@@ -2078,9 +3598,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
  def do_OPTIONS(self):
  # Handle CORS preflight
  self.send_response(200)
- self.send_header('Access-Control-Allow-Origin', '*')
- self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
- self.send_header('Access-Control-Allow-Headers', 'Content-Type')
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Access-Control-Allow-Methods", "POST, OPTIONS")
+ self.send_header("Access-Control-Allow-Headers", "Content-Type")
  self.end_headers()

  class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
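The do_OPTIONS handler above is what lets a browser's CORS preflight succeed before it issues a POST. It can be exercised directly from Python; host, port, and path here are assumptions for illustration:

    import http.client

    conn = http.client.HTTPConnection("localhost", 8765)
    conn.request("OPTIONS", "/api/benchmark/start")  # path is hypothetical
    resp = conn.getresponse()
    # Expect 200 with "POST, OPTIONS" if the preflight handler is wired up.
    print(resp.status, resp.getheader("Access-Control-Allow-Methods"))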
@@ -2134,7 +3654,9 @@ def cmd_viewer(args: argparse.Namespace) -> int:
  state.learning_rate = data.get("learning_rate", 0)
  state.losses = data.get("losses", [])
  state.status = data.get("status", "completed")
- state.elapsed_time = data.get("elapsed_time", 0.0) # Load elapsed time for completed training
+ state.elapsed_time = data.get(
+ "elapsed_time", 0.0
+ ) # Load elapsed time for completed training
  state.goal = data.get("goal", "")
  state.config_path = data.get("config_path", "")
  state.capture_path = data.get("capture_path", "")
@@ -2149,6 +3671,7 @@ def cmd_viewer(args: argparse.Namespace) -> int:
  if not state.model_name and state.config_path:
  try:
  import yaml
+
  # Try relative to project root first, then as absolute path
  project_root = Path(__file__).parent.parent.parent
  config_file = project_root / state.config_path
@@ -2173,14 +3696,16 @@ def cmd_viewer(args: argparse.Namespace) -> int:

  dashboard_html = generate_training_dashboard(state, config)
  (current_dir / "dashboard.html").write_text(dashboard_html)
- print(f" Regenerated: dashboard.html")
+ print(" Regenerated: dashboard.html")

  # Generate unified viewer using consolidated function
  viewer_path = generate_unified_viewer_from_output_dir(current_dir)
  if viewer_path:
  print(f"\nGenerated: {viewer_path}")
  else:
- print("\nNo comparison data found. Run comparison first or copy from capture directory.")
+ print(
+ "\nNo comparison data found. Run comparison first or copy from capture directory."
+ )

  # Also regenerate benchmark viewer from latest benchmark results
  _regenerate_benchmark_viewer_if_available(current_dir)
@@ -2200,9 +3725,9 @@ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
  print(f"Error: Benchmark directory not found: {benchmark_dir}")
  return 1

- print(f"\n{'='*50}")
+ print(f"\n{'=' * 50}")
  print("GENERATING BENCHMARK VIEWER")
- print(f"{'='*50}")
+ print(f"{'=' * 50}")
  print(f"Benchmark dir: {benchmark_dir}")
  print()

@@ -2217,6 +3742,7 @@ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
  except Exception as e:
  print(f"Error generating benchmark viewer: {e}")
  import traceback
+
  traceback.print_exc()
  return 1

@@ -2233,16 +3759,19 @@ def cmd_compare(args: argparse.Namespace) -> int:
  print(f"Error: Checkpoint not found: {checkpoint}")
  return 1

- print(f"\n{'='*50}")
+ print(f"\n{'=' * 50}")
  print("RUNNING COMPARISON")
- print(f"{'='*50}")
+ print(f"{'=' * 50}")
  print(f"Capture: {capture_path}")
  print(f"Checkpoint: {checkpoint or 'None (capture only)'}")
  print()

  cmd = [
- sys.executable, "-m", "openadapt_ml.scripts.compare",
- "--capture", str(capture_path),
+ sys.executable,
+ "-m",
+ "openadapt_ml.scripts.compare",
+ "--capture",
+ str(capture_path),
  ]

  if checkpoint:
@@ -2281,7 +3810,7 @@ Examples:

  # Run comparison
  uv run python -m openadapt_ml.cloud.local compare --capture ~/captures/my-workflow --checkpoint checkpoints/model
- """
+ """,
  )

  subparsers = parser.add_subparsers(dest="command", help="Command")
@@ -2293,9 +3822,15 @@ Examples:
  # train
  p_train = subparsers.add_parser("train", help="Run training locally")
  p_train.add_argument("--capture", required=True, help="Path to capture directory")
- p_train.add_argument("--goal", help="Task goal (default: derived from capture name)")
- p_train.add_argument("--config", help="Config file (default: auto-select based on device)")
- p_train.add_argument("--open", action="store_true", help="Open dashboard in browser")
+ p_train.add_argument(
+ "--goal", help="Task goal (default: derived from capture name)"
+ )
+ p_train.add_argument(
+ "--config", help="Config file (default: auto-select based on device)"
+ )
+ p_train.add_argument(
+ "--open", action="store_true", help="Open dashboard in browser"
+ )
  p_train.set_defaults(func=cmd_train)

  # check
@@ -2306,11 +3841,21 @@ Examples:
  p_serve = subparsers.add_parser("serve", help="Start web server for dashboard")
  p_serve.add_argument("--port", type=int, default=8765, help="Port number")
  p_serve.add_argument("--open", action="store_true", help="Open in browser")
- p_serve.add_argument("--quiet", "-q", action="store_true", help="Suppress request logging")
- p_serve.add_argument("--no-regenerate", action="store_true",
- help="Skip regenerating dashboard/viewer (serve existing files)")
- p_serve.add_argument("--benchmark", help="Serve benchmark results directory instead of training output")
- p_serve.add_argument("--start-page", help="Override default start page (e.g., benchmark.html)")
+ p_serve.add_argument(
+ "--quiet", "-q", action="store_true", help="Suppress request logging"
+ )
+ p_serve.add_argument(
+ "--no-regenerate",
+ action="store_true",
+ help="Skip regenerating dashboard/viewer (serve existing files)",
+ )
+ p_serve.add_argument(
+ "--benchmark",
+ help="Serve benchmark results directory instead of training output",
+ )
+ p_serve.add_argument(
+ "--start-page", help="Override default start page (e.g., benchmark.html)"
+ )
  p_serve.set_defaults(func=cmd_serve)
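A quick way to confirm the flag wiring above is to drive the parser directly; this fragment assumes the `parser` and `cmd_serve` names defined in this module and is illustrative only:

    # argparse maps "--no-regenerate" to args.no_regenerate, and
    # set_defaults(func=...) routes each subcommand to its handler.
    args = parser.parse_args(["serve", "--port", "9000", "--quiet", "--no-regenerate"])
    assert args.port == 9000 and args.quiet and args.no_regenerate
    assert args.func is cmd_serve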

  # viewer
@@ -2319,9 +3864,15 @@ Examples:
  p_viewer.set_defaults(func=cmd_viewer)

  # benchmark_viewer
- p_benchmark = subparsers.add_parser("benchmark-viewer", help="Generate benchmark viewer")
- p_benchmark.add_argument("benchmark_dir", help="Path to benchmark results directory")
- p_benchmark.add_argument("--open", action="store_true", help="Open viewer in browser")
+ p_benchmark = subparsers.add_parser(
+ "benchmark-viewer", help="Generate benchmark viewer"
+ )
+ p_benchmark.add_argument(
+ "benchmark_dir", help="Path to benchmark results directory"
+ )
+ p_benchmark.add_argument(
+ "--open", action="store_true", help="Open viewer in browser"
+ )
  p_benchmark.set_defaults(func=cmd_benchmark_viewer)

  # compare