openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
Files changed (112)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -27,7 +27,6 @@ import http.server
  import json
  import os
  import shutil
- import signal
  import socketserver
  import subprocess
  import sys
@@ -36,6 +35,8 @@ import webbrowser
  from pathlib import Path
  from typing import Any

+ from openadapt_ml.cloud.ssh_tunnel import get_tunnel_manager
+
  # Training output directory
  TRAINING_OUTPUT = Path("training_output")

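Note: the new `get_tunnel_manager` import is consumed further down in this diff by the `/api/tunnels` and `/api/azure-ops-status` handlers. A minimal sketch of that usage, assuming only what those handlers themselves rely on (a `get_tunnel_status()` mapping of tunnel name to a status object with `active`, `local_port`, `remote_endpoint`, `pid`, and `error` attributes):

    from openadapt_ml.cloud.ssh_tunnel import get_tunnel_manager

    # One line per tunnel, mirroring the fields the dashboard
    # handlers serialize to JSON.
    for name, s in get_tunnel_manager().get_tunnel_status().items():
        print(f"{name}: active={s.active} port={s.local_port} pid={s.pid} err={s.error}")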
@@ -107,7 +108,10 @@ def _is_mock_benchmark(benchmark_dir: Path) -> bool:

      # Check for test runs (but allow waa-mock evaluations with real API models)
      # Only filter out purely synthetic test data directories
-     if any(term in benchmark_dir.name.lower() for term in ["test_run", "test_cli", "quick_demo"]):
+     if any(
+         term in benchmark_dir.name.lower()
+         for term in ["test_run", "test_cli", "quick_demo"]
+     ):
          return True

      return False
@@ -143,6 +147,16 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
      # No real benchmark data - generate empty state viewer
      try:
          generate_empty_benchmark_viewer(benchmark_html_path)
+
+         # Still create symlink for azure_jobs.json access (even without real benchmarks)
+         if benchmark_results_dir.exists():
+             benchmark_results_link = output_dir / "benchmark_results"
+             if benchmark_results_link.is_symlink():
+                 benchmark_results_link.unlink()
+             elif benchmark_results_link.exists():
+                 shutil.rmtree(benchmark_results_link)
+             benchmark_results_link.symlink_to(benchmark_results_dir.absolute())
+
          print(" Generated benchmark viewer: No real evaluation data yet")
          return True
      except Exception as e:
@@ -168,11 +182,20 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
              tasks_dst = benchmark_tasks_dir / benchmark_dir.name
              shutil.copytree(tasks_src, tasks_dst)

+         # Create symlink for benchmark_results directory (for azure_jobs.json access)
+         benchmark_results_link = output_dir / "benchmark_results"
+         if benchmark_results_link.is_symlink():
+             benchmark_results_link.unlink()
+         elif benchmark_results_link.exists():
+             shutil.rmtree(benchmark_results_link)
+         benchmark_results_link.symlink_to(benchmark_results_dir.absolute())
+
          print(f" Regenerated benchmark viewer with {len(real_benchmarks)} run(s)")
          return True
      except Exception as e:
          print(f" Could not regenerate benchmark viewer: {e}")
          import traceback
+
          traceback.print_exc()
          return False

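Note: the two hunks above repeat the same replace-then-link idiom. A self-contained sketch of the pattern, with a hypothetical helper name that is not part of the package:

    import shutil
    from pathlib import Path

    def replace_with_symlink(link: Path, target: Path) -> None:
        # A stale symlink is unlinked; a real directory left behind by an
        # older run is removed; either way symlink_to() then succeeds.
        if link.is_symlink():
            link.unlink()
        elif link.exists():
            shutil.rmtree(link)
        link.symlink_to(target.absolute())

    # Illustrative paths matching the diff's usage:
    replace_with_symlink(Path("training_output/benchmark_results"), Path("benchmark_results"))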
@@ -181,6 +204,7 @@ def detect_device() -> str:
      """Detect available compute device."""
      try:
          import torch
+
          if torch.cuda.is_available():
              device_name = torch.cuda.get_device_name(0)
              return f"cuda ({device_name})"
@@ -231,10 +255,13 @@ def get_training_status() -> dict[str, Any]:
      # Find checkpoints
      checkpoints_dir = Path("checkpoints")
      if checkpoints_dir.exists():
-         status["checkpoints"] = sorted([
-             d.name for d in checkpoints_dir.iterdir()
-             if d.is_dir() and (d / "adapter_config.json").exists()
-         ])
+         status["checkpoints"] = sorted(
+             [
+                 d.name
+                 for d in checkpoints_dir.iterdir()
+                 if d.is_dir() and (d / "adapter_config.json").exists()
+             ]
+         )

      return status

@@ -244,9 +271,9 @@ def cmd_status(args: argparse.Namespace) -> int:
      status = get_training_status()
      current_dir = get_current_output_dir()

-     print(f"\n{'='*50}")
+     print(f"\n{'=' * 50}")
      print("LOCAL TRAINING STATUS")
-     print(f"{'='*50}")
+     print(f"{'=' * 50}")
      print(f"Device: {status['device']}")
      print(f"Status: {'RUNNING' if status['running'] else 'IDLE'}")
      if status.get("job_id"):
@@ -254,7 +281,7 @@ def cmd_status(args: argparse.Namespace) -> int:
      print(f"Output: {current_dir}")

      if status.get("epoch"):
-         print(f"\nProgress:")
+         print("\nProgress:")
          print(f" Epoch: {status['epoch']}")
          print(f" Step: {status['step']}")
          if status.get("loss"):
@@ -267,7 +294,9 @@ def cmd_status(args: argparse.Namespace) -> int:
          for cp in status["checkpoints"][-5:]:  # Show last 5
              print(f" - {cp}")

-     print(f"\nDashboard: {'✓' if status['has_dashboard'] else '✗'} {current_dir}/dashboard.html")
+     print(
+         f"\nDashboard: {'✓' if status['has_dashboard'] else '✗'} {current_dir}/dashboard.html"
+     )
      print(f"Viewer: {'✓' if status['has_viewer'] else '✗'} {current_dir}/viewer.html")
      print()

@@ -300,9 +329,9 @@ def cmd_train(args: argparse.Namespace) -> int:
          print(f"Error: Config not found: {config_path}")
          return 1

-     print(f"\n{'='*50}")
+     print(f"\n{'=' * 50}")
      print("STARTING LOCAL TRAINING")
-     print(f"{'='*50}")
+     print(f"{'=' * 50}")
      print(f"Capture: {capture_path}")
      print(f"Goal: {goal}")
      print(f"Config: {config}")
@@ -311,10 +340,15 @@ def cmd_train(args: argparse.Namespace) -> int:

      # Build command
      cmd = [
-         sys.executable, "-m", "openadapt_ml.scripts.train",
-         "--config", str(config_path),
-         "--capture", str(capture_path),
-         "--goal", goal,
+         sys.executable,
+         "-m",
+         "openadapt_ml.scripts.train",
+         "--config",
+         str(config_path),
+         "--capture",
+         str(capture_path),
+         "--goal",
+         goal,
      ]

      if args.open:
@@ -333,14 +367,16 @@ def cmd_check(args: argparse.Namespace) -> int:
      """Check training health and early stopping analysis."""
      status = get_training_status()

-     print(f"\n{'='*50}")
+     print(f"\n{'=' * 50}")
      print("TRAINING HEALTH CHECK")
-     print(f"{'='*50}")
+     print(f"{'=' * 50}")

      raw_losses = status.get("losses", [])
      if not raw_losses:
          print("No training data found.")
-         print("Run training first with: uv run python -m openadapt_ml.cloud.local train --capture <path>")
+         print(
+             "Run training first with: uv run python -m openadapt_ml.cloud.local train --capture <path>"
+         )
          return 1

      # Extract loss values (handle both dict and float formats)
@@ -361,7 +397,7 @@ def cmd_check(args: argparse.Namespace) -> int:
      min_loss = min(losses)
      max_loss = max(losses)

-     print(f"\nLoss progression:")
+     print("\nLoss progression:")
      print(f" First: {first_loss:.4f}")
      print(f" Last: {last_loss:.4f}")
      print(f" Min: {min_loss:.4f}")
@@ -372,9 +408,11 @@ def cmd_check(args: argparse.Namespace) -> int:
      if len(losses) >= 10:
          recent = losses[-10:]
          recent_avg = sum(recent) / len(recent)
-         recent_std = (sum((x - recent_avg) ** 2 for x in recent) / len(recent)) ** 0.5
+         recent_std = (
+             sum((x - recent_avg) ** 2 for x in recent) / len(recent)
+         ) ** 0.5

-         print(f"\nRecent stability (last 10 steps):")
+         print("\nRecent stability (last 10 steps):")
          print(f" Avg loss: {recent_avg:.4f}")
          print(f" Std dev: {recent_std:.4f}")

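Note: the stability check above is the population standard deviation of the last 10 losses. A worked sketch of the same arithmetic with illustrative values (not from a real run):

    losses = [0.52, 0.50, 0.51, 0.49, 0.50, 0.50, 0.51, 0.50, 0.49, 0.50]
    recent_avg = sum(losses) / len(losses)  # 0.502
    recent_std = (sum((x - recent_avg) ** 2 for x in losses) / len(losses)) ** 0.5
    # A std dev that is tiny relative to the mean suggests the loss has
    # plateaued, which is what the health check reports.
    print(f"avg={recent_avg:.4f} std={recent_std:.4f}")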
@@ -393,14 +431,18 @@ def cmd_serve(args: argparse.Namespace) -> int:
      """Start local web server for dashboard.

      Automatically regenerates dashboard and viewer before serving to ensure
-     the latest code and data are reflected.
+     the latest code and data are reflected. Also ensures the 'current' symlink
+     points to the most recent training run.
      """
-     from openadapt_ml.training.trainer import regenerate_local_dashboard
+     from openadapt_ml.training.trainer import (
+         regenerate_local_dashboard,
+         update_current_symlink_to_latest,
+     )

      port = args.port

      # Determine what to serve: benchmark directory or training output
-     if hasattr(args, 'benchmark') and args.benchmark:
+     if hasattr(args, "benchmark") and args.benchmark:
          serve_dir = Path(args.benchmark).expanduser().resolve()
          if not serve_dir.exists():
              print(f"Error: Benchmark directory not found: {serve_dir}")
@@ -410,7 +452,10 @@ def cmd_serve(args: argparse.Namespace) -> int:
          if not args.no_regenerate:
              print("Regenerating benchmark viewer...")
              try:
-                 from openadapt_ml.training.benchmark_viewer import generate_benchmark_viewer
+                 from openadapt_ml.training.benchmark_viewer import (
+                     generate_benchmark_viewer,
+                 )
+
                  generate_benchmark_viewer(serve_dir)
              except Exception as e:
                  print(f"Warning: Could not regenerate benchmark viewer: {e}")
@@ -419,6 +464,17 @@ def cmd_serve(args: argparse.Namespace) -> int:
      else:
          serve_dir = get_current_output_dir()

+         # If current symlink doesn't exist or is broken, update to latest run
+         if not serve_dir.exists() or not serve_dir.is_dir():
+             print("Updating 'current' symlink to latest training run...")
+             latest = update_current_symlink_to_latest()
+             if latest:
+                 serve_dir = get_current_output_dir()
+                 print(f" Updated to: {latest.name}")
+             else:
+                 print(f"Error: {serve_dir} not found. Run training first.")
+                 return 1
+
      if not serve_dir.exists():
          print(f"Error: {serve_dir} not found. Run training first.")
          return 1
@@ -427,7 +483,9 @@ def cmd_serve(args: argparse.Namespace) -> int:
      if not args.no_regenerate:
          print("Regenerating dashboard and viewer...")
          try:
-             regenerate_local_dashboard(str(serve_dir))
+             # Use keep_polling=True so JavaScript fetches live data from training_log.json
+             # This ensures the dashboard shows current data instead of stale embedded data
+             regenerate_local_dashboard(str(serve_dir), keep_polling=True)
              # Also regenerate viewer if comparison data exists
              _regenerate_viewer_if_possible(serve_dir)
          except Exception as e:
@@ -436,8 +494,23 @@ def cmd_serve(args: argparse.Namespace) -> int:
      # Also regenerate benchmark viewer from latest benchmark results
      _regenerate_benchmark_viewer_if_available(serve_dir)

+     # Generate Azure ops dashboard
+     try:
+         from openadapt_ml.training.azure_ops_viewer import (
+             generate_azure_ops_dashboard,
+         )
+
+         generate_azure_ops_dashboard(serve_dir / "azure_ops.html")
+         print(" Generated Azure ops dashboard")
+     except Exception as e:
+         print(f" Warning: Could not generate Azure ops dashboard: {e}")
+
      start_page = "dashboard.html"

+     # Override start page if specified
+     if hasattr(args, "start_page") and args.start_page:
+         start_page = args.start_page
+
      # Serve from the specified directory
      os.chdir(serve_dir)

@@ -452,33 +525,41 @@ def cmd_serve(args: argparse.Namespace) -> int:
                  super().log_message(format, *log_args)

          def do_POST(self):
-             if self.path == '/api/stop':
+             if self.path == "/api/stop":
                  # Create stop signal file
                  stop_file = serve_dir / "STOP_TRAINING"
                  stop_file.touch()
                  self.send_response(200)
-                 self.send_header('Content-Type', 'application/json')
-                 self.send_header('Access-Control-Allow-Origin', '*')
+                 self.send_header("Content-Type", "application/json")
+                 self.send_header("Access-Control-Allow-Origin", "*")
                  self.end_headers()
                  self.wfile.write(b'{"status": "stop_signal_created"}')
                  print(f"\n⏹ Stop signal created: {stop_file}")
-             elif self.path == '/api/run-benchmark':
+             elif self.path == "/api/run-benchmark":
                  # Parse request body for provider
-                 content_length = int(self.headers.get('Content-Length', 0))
-                 body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+                 content_length = int(self.headers.get("Content-Length", 0))
+                 body = (
+                     self.rfile.read(content_length).decode("utf-8")
+                     if content_length
+                     else "{}"
+                 )
                  try:
                      params = json.loads(body)
                  except json.JSONDecodeError:
                      params = {}

-                 provider = params.get('provider', 'anthropic')
-                 tasks = params.get('tasks', 5)
+                 provider = params.get("provider", "anthropic")
+                 tasks = params.get("tasks", 5)

                  self.send_response(200)
-                 self.send_header('Content-Type', 'application/json')
-                 self.send_header('Access-Control-Allow-Origin', '*')
+                 self.send_header("Content-Type", "application/json")
+                 self.send_header("Access-Control-Allow-Origin", "*")
                  self.end_headers()
-                 self.wfile.write(json.dumps({"status": "started", "provider": provider, "tasks": tasks}).encode())
+                 self.wfile.write(
+                     json.dumps(
+                         {"status": "started", "provider": provider, "tasks": tasks}
+                     ).encode()
+                 )

                  # Run benchmark in background thread with progress logging
                  def run_benchmark():
@@ -492,25 +573,45 @@ def cmd_serve(args: argparse.Namespace) -> int:
                      # Create progress log file (in cwd which is serve_dir)
                      progress_file = Path("benchmark_progress.json")

-                     print(f"\n🚀 Starting {provider} benchmark evaluation ({tasks} tasks)...")
+                     print(
+                         f"\n🚀 Starting {provider} benchmark evaluation ({tasks} tasks)..."
+                     )

                      # Write initial progress
-                     progress_file.write_text(json.dumps({
-                         "status": "running",
-                         "provider": provider,
-                         "tasks_total": tasks,
-                         "tasks_complete": 0,
-                         "message": f"Starting {provider} evaluation..."
-                     }))
+                     progress_file.write_text(
+                         json.dumps(
+                             {
+                                 "status": "running",
+                                 "provider": provider,
+                                 "tasks_total": tasks,
+                                 "tasks_complete": 0,
+                                 "message": f"Starting {provider} evaluation...",
+                             }
+                         )
+                     )

                      # Copy environment with loaded vars
                      env = os.environ.copy()

                      result = subprocess.run(
-                         ["uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli", "run-api",
-                          "--provider", provider, "--tasks", str(tasks),
-                          "--model-id", f"{provider}-api"],
-                         capture_output=True, text=True, cwd=str(project_root), env=env
+                         [
+                             "uv",
+                             "run",
+                             "python",
+                             "-m",
+                             "openadapt_ml.benchmarks.cli",
+                             "run-api",
+                             "--provider",
+                             provider,
+                             "--tasks",
+                             str(tasks),
+                             "--model-id",
+                             f"{provider}-api",
+                         ],
+                         capture_output=True,
+                         text=True,
+                         cwd=str(project_root),
+                         env=env,
                      )

                      print(f"\n📋 Benchmark output:\n{result.stdout}")
@@ -518,53 +619,2995 @@ def cmd_serve(args: argparse.Namespace) -> int:
                          print(f"Stderr: {result.stderr}")

                      if result.returncode == 0:
-                         print(f"✅ Benchmark complete. Regenerating viewer...")
-                         progress_file.write_text(json.dumps({
-                             "status": "complete",
-                             "provider": provider,
-                             "message": "Evaluation complete! Refreshing results..."
-                         }))
+                         print("✅ Benchmark complete. Regenerating viewer...")
+                         progress_file.write_text(
+                             json.dumps(
+                                 {
+                                     "status": "complete",
+                                     "provider": provider,
+                                     "message": "Evaluation complete! Refreshing results...",
+                                 }
+                             )
+                         )
                          # Regenerate benchmark viewer
                          _regenerate_benchmark_viewer_if_available(serve_dir)
                      else:
                          print(f"❌ Benchmark failed: {result.stderr}")
-                         progress_file.write_text(json.dumps({
-                             "status": "error",
-                             "provider": provider,
-                             "message": f"Evaluation failed: {result.stderr[:200]}"
-                         }))
+                         progress_file.write_text(
+                             json.dumps(
+                                 {
+                                     "status": "error",
+                                     "provider": provider,
+                                     "message": f"Evaluation failed: {result.stderr[:200]}",
+                                 }
+                             )
+                         )

                  threading.Thread(target=run_benchmark, daemon=True).start()
+             elif self.path == "/api/vms/register":
+                 # Register a new VM
+                 content_length = int(self.headers.get("Content-Length", 0))
+                 body = (
+                     self.rfile.read(content_length).decode("utf-8")
+                     if content_length
+                     else "{}"
+                 )
+                 try:
+                     vm_data = json.loads(body)
+                     result = self._register_vm(vm_data)
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(result).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path == "/api/benchmark/start":
+                 # Start a benchmark run with configurable parameters
+                 content_length = int(self.headers.get("Content-Length", 0))
+                 body = (
+                     self.rfile.read(content_length).decode("utf-8")
+                     if content_length
+                     else "{}"
+                 )
+                 try:
+                     params = json.loads(body)
+                     result = self._start_benchmark_run(params)
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(result).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
              else:
                  self.send_error(404, "Not found")

          def do_GET(self):
-             if self.path.startswith('/api/benchmark-progress'):
+             if self.path.startswith("/api/benchmark-progress"):
                  # Return benchmark progress
-                 progress_file = Path("benchmark_progress.json") # Relative to serve_dir (cwd)
+                 progress_file = Path(
+                     "benchmark_progress.json"
+                 )  # Relative to serve_dir (cwd)
                  if progress_file.exists():
                      progress = progress_file.read_text()
                  else:
                      progress = json.dumps({"status": "idle"})

                  self.send_response(200)
-                 self.send_header('Content-Type', 'application/json')
-                 self.send_header('Access-Control-Allow-Origin', '*')
+                 self.send_header("Content-Type", "application/json")
+                 self.send_header("Access-Control-Allow-Origin", "*")
                  self.end_headers()
                  self.wfile.write(progress.encode())
+             elif self.path.startswith("/api/benchmark-live"):
+                 # Return live evaluation state
+                 live_file = Path("benchmark_live.json")  # Relative to serve_dir (cwd)
+                 if live_file.exists():
+                     live_state = live_file.read_text()
+                 else:
+                     live_state = json.dumps({"status": "idle"})
+
+                 self.send_response(200)
+                 self.send_header("Content-Type", "application/json")
+                 self.send_header("Access-Control-Allow-Origin", "*")
+                 self.end_headers()
+                 self.wfile.write(live_state.encode())
+             elif self.path.startswith("/api/tasks"):
+                 # Return background task status (VM, Docker, benchmarks)
+                 try:
+                     tasks = self._fetch_background_tasks()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(tasks).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/azure-jobs"):
+                 # Return LIVE Azure job status from Azure ML
+                 # Supports ?force=true parameter for manual refresh (always fetches live)
+                 try:
+                     from urllib.parse import urlparse, parse_qs
+
+                     query = parse_qs(urlparse(self.path).query)
+                     force_refresh = query.get("force", ["false"])[0].lower() == "true"
+
+                     # Always fetch live data (force just indicates manual refresh for logging)
+                     if force_refresh:
+                         print("Azure Jobs: Manual refresh requested")
+
+                     jobs = self._fetch_live_azure_jobs()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(jobs).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/benchmark-sse"):
+                 # Server-Sent Events endpoint for real-time benchmark updates
+                 try:
+                     from urllib.parse import urlparse, parse_qs
+
+                     query = parse_qs(urlparse(self.path).query)
+                     interval = int(query.get("interval", [5])[0])
+
+                     # Validate interval (min 1s, max 60s)
+                     interval = max(1, min(60, interval))
+
+                     self._stream_benchmark_updates(interval)
+                 except Exception as e:
+                     self.send_error(500, f"SSE error: {e}")
+             elif self.path.startswith("/api/vms"):
+                 # Return VM registry with live status
+                 try:
+                     vms = self._fetch_vm_registry()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(vms).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/azure-job-logs"):
+                 # Return live logs for running Azure job
+                 try:
+                     # Parse job_id from query string
+                     from urllib.parse import urlparse, parse_qs
+
+                     query = parse_qs(urlparse(self.path).query)
+                     job_id = query.get("job_id", [None])[0]
+
+                     logs = self._fetch_azure_job_logs(job_id)
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(logs).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/probe-vm"):
+                 # Probe the VM to check if WAA server is responding
+                 try:
+                     result = self._probe_vm()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(result).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(
+                         json.dumps({"error": str(e), "responding": False}).encode()
+                     )
+             elif self.path.startswith("/api/tunnels"):
+                 # Return SSH tunnel status
+                 try:
+                     tunnel_mgr = get_tunnel_manager()
+                     status = tunnel_mgr.get_tunnel_status()
+                     result = {
+                         name: {
+                             "active": s.active,
+                             "local_port": s.local_port,
+                             "remote_endpoint": s.remote_endpoint,
+                             "pid": s.pid,
+                             "error": s.error,
+                         }
+                         for name, s in status.items()
+                     }
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(result).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/current-run"):
+                 # Return currently running benchmark info
+                 try:
+                     result = self._get_current_run()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(result).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(
+                         json.dumps({"error": str(e), "running": False}).encode()
+                     )
+             elif self.path.startswith("/api/background-tasks"):
+                 # Alias for /api/tasks - background task status
+                 try:
+                     tasks = self._fetch_background_tasks()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(tasks).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/benchmark/status"):
+                 # Return current benchmark job status with ETA
+                 try:
+                     status = self._get_benchmark_status()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(status).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(
+                         json.dumps({"error": str(e), "status": "error"}).encode()
+                     )
+             elif self.path.startswith("/api/benchmark/costs"):
+                 # Return cost breakdown (Azure VM, API calls, GPU)
+                 try:
+                     costs = self._get_benchmark_costs()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(costs).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/benchmark/metrics"):
+                 # Return performance metrics (success rate, domain breakdown)
+                 try:
+                     metrics = self._get_benchmark_metrics()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(metrics).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/benchmark/workers"):
+                 # Return worker status and utilization
+                 try:
+                     workers = self._get_benchmark_workers()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(workers).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/benchmark/runs"):
+                 # Return list of all benchmark runs
+                 try:
+                     runs = self._get_benchmark_runs()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(runs).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/benchmark/tasks/"):
+                 # Return task execution details
+                 # URL format: /api/benchmark/tasks/{run_name}/{task_id}
+                 try:
+                     parts = self.path.split("/")
+                     if len(parts) >= 6:
+                         run_name = parts[4]
+                         task_id = parts[5]
+                         execution = self._get_task_execution(run_name, task_id)
+                         self.send_response(200)
+                         self.send_header("Content-Type", "application/json")
+                         self.send_header("Access-Control-Allow-Origin", "*")
+                         self.end_headers()
+                         self.wfile.write(json.dumps(execution).encode())
+                     else:
+                         self.send_error(400, "Invalid path format")
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/benchmark/screenshots/"):
+                 # Serve screenshot files
+                 # URL format: /api/benchmark/screenshots/{run_name}/{task_id}/screenshots/{filename}
+                 try:
+                     # Remove /api/benchmark/screenshots/ prefix
+                     path_parts = self.path.replace(
+                         "/api/benchmark/screenshots/", ""
+                     ).split("/")
+                     if len(path_parts) >= 4:
+                         run_name = path_parts[0]
+                         task_id = path_parts[1]
+                         # path_parts[2] should be 'screenshots'
+                         filename = path_parts[3]
+
+                         results_dir = Path("benchmark_results")
+                         screenshot_path = (
+                             results_dir
+                             / run_name
+                             / "tasks"
+                             / task_id
+                             / "screenshots"
+                             / filename
+                         )
+
+                         if screenshot_path.exists():
+                             self.send_response(200)
+                             self.send_header("Content-Type", "image/png")
+                             self.send_header("Access-Control-Allow-Origin", "*")
+                             self.end_headers()
+                             with open(screenshot_path, "rb") as f:
+                                 self.wfile.write(f.read())
+                         else:
+                             self.send_error(
+                                 404, f"Screenshot not found: {screenshot_path}"
+                             )
+                     else:
+                         self.send_error(400, "Invalid screenshot path format")
+                 except Exception as e:
+                     self.send_error(500, f"Error serving screenshot: {e}")
+             elif self.path.startswith("/api/azure-ops-sse"):
+                 # Server-Sent Events endpoint for Azure operations status
+                 try:
+                     self._stream_azure_ops_updates()
+                 except Exception as e:
+                     self.send_error(500, f"SSE error: {e}")
+             elif self.path.startswith("/api/azure-ops-status"):
+                 # Return Azure operations status from JSON file
+                 # Session tracker provides elapsed_seconds and cost_usd for
+                 # persistence across page refreshes
+                 try:
+                     from openadapt_ml.benchmarks.azure_ops_tracker import read_status
+                     from openadapt_ml.benchmarks.session_tracker import (
+                         get_session,
+                         update_session_vm_state,
+                     )
+
+                     # Get operation status (current task)
+                     status = read_status()
+
+                     # Get session data (persistent across refreshes)
+                     session = get_session()
+
+                     # Update session based on VM state if we have VM info
+                     # IMPORTANT: Only pass vm_ip if it's truthy to avoid
+                     # overwriting session's stable vm_ip with None
+                     if status.get("vm_state") and status.get("vm_state") != "unknown":
+                         status_vm_ip = status.get("vm_ip")
+                         # Build update kwargs - only include vm_ip if present
+                         update_kwargs = {
+                             "vm_state": status["vm_state"],
+                             "vm_size": status.get("vm_size"),
+                         }
+                         if status_vm_ip:  # Only include if truthy
+                             update_kwargs["vm_ip"] = status_vm_ip
+                         session = update_session_vm_state(**update_kwargs)
+
+                     # Use session's vm_ip as authoritative source
+                     # This prevents IP flickering when status file has stale/None values
+                     if session.get("vm_ip"):
+                         status["vm_ip"] = session["vm_ip"]
+
+                     # Use session's elapsed_seconds and cost_usd for persistence
+                     # These survive page refreshes and track total VM runtime
+                     if (
+                         session.get("is_active")
+                         or session.get("accumulated_seconds", 0) > 0
+                     ):
+                         status["elapsed_seconds"] = session.get("elapsed_seconds", 0.0)
+                         status["cost_usd"] = session.get("cost_usd", 0.0)
+                         status["started_at"] = session.get("started_at")
+                         # Include session metadata for debugging
+                         status["session_id"] = session.get("session_id")
+                         status["session_is_active"] = session.get("is_active", False)
+                         # Include accumulated time from previous sessions for hybrid display
+                         status["accumulated_seconds"] = session.get(
+                             "accumulated_seconds", 0.0
+                         )
+                         # Calculate current session time (total - accumulated)
+                         current_session_seconds = max(
+                             0, status["elapsed_seconds"] - status["accumulated_seconds"]
+                         )
+                         status["current_session_seconds"] = current_session_seconds
+                         status["current_session_cost_usd"] = (
+                             current_session_seconds / 3600
+                         ) * session.get("hourly_rate_usd", 0.422)
+
+                     try:
+                         tunnel_mgr = get_tunnel_manager()
+                         tunnel_status = tunnel_mgr.get_tunnel_status()
+                         status["tunnels"] = {
+                             name: {
+                                 "active": s.active,
+                                 "local_port": s.local_port,
+                                 "remote_endpoint": s.remote_endpoint,
+                                 "pid": s.pid,
+                                 "error": s.error,
+                             }
+                             for name, s in tunnel_status.items()
+                         }
+                     except Exception as e:
+                         status["tunnels"] = {"error": str(e)}
+
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(status).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
+             elif self.path.startswith("/api/vm-diagnostics"):
+                 # Return VM diagnostics: disk usage, Docker stats, memory usage
+                 try:
+                     diagnostics = self._get_vm_diagnostics()
+                     self.send_response(200)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps(diagnostics).encode())
+                 except Exception as e:
+                     self.send_response(500)
+                     self.send_header("Content-Type", "application/json")
+                     self.send_header("Access-Control-Allow-Origin", "*")
+                     self.end_headers()
+                     self.wfile.write(json.dumps({"error": str(e)}).encode())
              else:
                  # Default file serving
                  super().do_GET()

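Note: a minimal client-side sketch of the JSON endpoints defined above, using only the standard library. The endpoint paths and payload keys come from the handler code; the port is an assumption (whatever `serve` was started with):

    import json
    from urllib.request import Request, urlopen

    base = "http://localhost:8000"  # assumed serve port

    # POST /api/run-benchmark kicks off an API-provider evaluation.
    req = Request(
        f"{base}/api/run-benchmark",
        data=json.dumps({"provider": "anthropic", "tasks": 5}).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    print(json.load(urlopen(req)))  # {"status": "started", ...}

    # GET /api/benchmark-progress polls the progress file written by the run.
    print(json.load(urlopen(f"{base}/api/benchmark-progress")))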
+         def _fetch_live_azure_jobs(self):
+             """Fetch live job status from Azure ML."""
+             import subprocess
+
+             result = subprocess.run(
+                 [
+                     "az",
+                     "ml",
+                     "job",
+                     "list",
+                     "--resource-group",
+                     "openadapt-agents",
+                     "--workspace-name",
+                     "openadapt-ml",
+                     "--query",
+                     "[].{name:name,display_name:display_name,status:status,creation_context:creation_context.created_at}",
+                     "-o",
+                     "json",
+                 ],
+                 capture_output=True,
+                 text=True,
+                 timeout=30,
+             )
+             if result.returncode != 0:
+                 raise Exception(f"Azure CLI error: {result.stderr}")
+
+             jobs = json.loads(result.stdout)
+             # Format for frontend
+             experiment_id = "ad29082c-0607-4fda-8cc7-38944eb5a518"
+             wsid = "/subscriptions/78add6c6-c92a-4a53-b751-eb644ac77e59/resourceGroups/openadapt-agents/providers/Microsoft.MachineLearningServices/workspaces/openadapt-ml"
+
+             formatted = []
+             for job in jobs[:10]:  # Limit to 10 most recent
+                 formatted.append(
+                     {
+                         "job_id": job.get("name", "unknown"),
+                         "display_name": job.get("display_name", ""),
+                         "status": job.get("status", "unknown").lower(),
+                         "started_at": job.get("creation_context", ""),
+                         "azure_dashboard_url": f"https://ml.azure.com/experiments/id/{experiment_id}/runs/{job.get('name', '')}?wsid={wsid}",
+                         "is_live": True,  # Flag to indicate this is live data
+                     }
+                 )
+             return formatted
+
+         def _fetch_azure_job_logs(self, job_id: str | None):
+             """Fetch logs for an Azure ML job (streaming for running jobs)."""
+             import subprocess
+
+             if not job_id:
+                 # Get the most recent running job
+                 jobs = self._fetch_live_azure_jobs()
+                 running = [j for j in jobs if j["status"] == "running"]
+                 if running:
+                     job_id = running[0]["job_id"]
+                 else:
+                     return {
+                         "logs": "No running jobs found",
+                         "job_id": None,
+                         "status": "idle",
+                     }
+
+             # Try to stream logs for running job using az ml job stream
+             try:
+                 result = subprocess.run(
+                     [
+                         "az",
+                         "ml",
+                         "job",
+                         "stream",
+                         "--name",
+                         job_id,
+                         "--resource-group",
+                         "openadapt-agents",
+                         "--workspace-name",
+                         "openadapt-ml",
+                     ],
+                     capture_output=True,
+                     text=True,
+                     timeout=3,  # Short timeout
+                 )
+                 if result.returncode == 0 and result.stdout.strip():
+                     return {
+                         "logs": result.stdout[-5000:],
+                         "job_id": job_id,
+                         "status": "streaming",
+                     }
+             except subprocess.TimeoutExpired:
+                 pass  # Fall through to job show
+
+             # Get job details instead
+             result = subprocess.run(
+                 [
+                     "az",
+                     "ml",
+                     "job",
+                     "show",
+                     "--name",
+                     job_id,
+                     "--resource-group",
+                     "openadapt-agents",
+                     "--workspace-name",
+                     "openadapt-ml",
+                     "-o",
+                     "json",
+                 ],
+                 capture_output=True,
+                 text=True,
+                 timeout=10,
+             )
+
+             if result.returncode == 0:
+                 job_info = json.loads(result.stdout)
+                 return {
+                     "logs": f"Job {job_id} is {job_info.get('status', 'unknown')}\\n\\nCommand: {job_info.get('command', 'N/A')}",
+                     "job_id": job_id,
+                     "status": job_info.get("status", "unknown").lower(),
+                     "command": job_info.get("command", ""),
+                 }
+
+             return {
+                 "logs": f"Could not fetch logs: {result.stderr}",
+                 "job_id": job_id,
+                 "status": "error",
+             }
+
+         def _get_vm_detailed_metadata(
+             self, vm_ip: str, container_name: str, logs: str, phase: str
+         ) -> dict:
+             """Get detailed VM metadata for the VM Details panel.
+
+             Returns:
+                 dict with disk_usage_gb, memory_usage_mb, setup_script_phase, probe_response, qmp_connected, dependencies
+             """
+             import subprocess
+
+             metadata = {
+                 "disk_usage_gb": None,
+                 "memory_usage_mb": None,
+                 "setup_script_phase": None,
+                 "probe_response": None,
+                 "qmp_connected": False,
+                 "dependencies": [],
+             }
+
+             # 1. Get disk usage from docker stats
+             try:
+                 disk_result = subprocess.run(
+                     [
+                         "ssh",
+                         "-o",
+                         "StrictHostKeyChecking=no",
+                         "-o",
+                         "ConnectTimeout=5",
+                         "-i",
+                         str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         f"docker exec {container_name} df -h /storage 2>/dev/null | tail -1",
+                     ],
+                     capture_output=True,
+                     text=True,
+                     timeout=10,
+                 )
+                 if disk_result.returncode == 0 and disk_result.stdout.strip():
+                     # Parse: "Filesystem Size Used Avail Use% Mounted on"
+                     # Example: "/dev/sda1 30G 9.2G 20G 31% /storage"
+                     parts = disk_result.stdout.split()
+                     if len(parts) >= 3:
+                         used_str = parts[2]  # e.g., "9.2G"
+                         total_str = parts[1]  # e.g., "30G"
+
+                         # Convert to GB (handle M/G suffixes)
+                         def to_gb(s):
+                             if s.endswith("G"):
+                                 return float(s[:-1])
+                             elif s.endswith("M"):
+                                 return float(s[:-1]) / 1024
+                             elif s.endswith("K"):
+                                 return float(s[:-1]) / (1024 * 1024)
+                             return 0
+
+                         metadata["disk_usage_gb"] = (
+                             f"{to_gb(used_str):.1f} GB / {to_gb(total_str):.0f} GB used"
+                         )
+             except Exception:
+                 pass
+
+             # 2. Get memory usage from docker stats
+             try:
+                 mem_result = subprocess.run(
+                     [
+                         "ssh",
+                         "-o",
+                         "StrictHostKeyChecking=no",
+                         "-o",
+                         "ConnectTimeout=5",
+                         "-i",
+                         str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         f"docker stats {container_name} --no-stream --format '{{{{.MemUsage}}}}'",
+                     ],
+                     capture_output=True,
+                     text=True,
+                     timeout=10,
+                 )
+                 if mem_result.returncode == 0 and mem_result.stdout.strip():
+                     # Example: "1.5GiB / 4GiB"
+                     metadata["memory_usage_mb"] = mem_result.stdout.strip()
+             except Exception:
+                 pass
+
+             # 3. Parse setup script phase from logs
+             metadata["setup_script_phase"] = self._parse_setup_phase_from_logs(
+                 logs, phase
+             )
+
+             # 4. Check /probe endpoint
+             try:
+                 probe_result = subprocess.run(
+                     [
+                         "ssh",
+                         "-o",
+                         "StrictHostKeyChecking=no",
+                         "-o",
+                         "ConnectTimeout=5",
+                         "-i",
+                         str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         "curl -s --connect-timeout 2 http://20.20.20.21:5000/probe 2>/dev/null",
+                     ],
+                     capture_output=True,
+                     text=True,
+                     timeout=10,
+                 )
+                 if probe_result.returncode == 0 and probe_result.stdout.strip():
+                     metadata["probe_response"] = probe_result.stdout.strip()
+                 else:
+                     metadata["probe_response"] = "Not responding"
+             except Exception:
+                 metadata["probe_response"] = "Connection failed"
+
+             # 5. Check QMP connection (port 7200)
+             try:
+                 qmp_result = subprocess.run(
+                     [
+                         "ssh",
+                         "-o",
+                         "StrictHostKeyChecking=no",
+                         "-o",
+                         "ConnectTimeout=5",
+                         "-i",
+                         str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         "nc -z -w2 localhost 7200 2>&1",
+                     ],
+                     capture_output=True,
+                     text=True,
+                     timeout=10,
+                 )
+                 metadata["qmp_connected"] = qmp_result.returncode == 0
+             except Exception:
+                 pass
+
+             # 6. Parse dependencies from logs
+             metadata["dependencies"] = self._parse_dependencies_from_logs(logs, phase)
+
+             return metadata
+
+         def _parse_setup_phase_from_logs(self, logs: str, current_phase: str) -> str:
+             """Parse the current setup script phase from logs.
+
+             Looks for patterns indicating which script is running:
+             - install.bat
+             - setup.ps1
+             - on-logon.ps1
+             """
+             if current_phase == "ready":
+                 return "Setup complete"
+             elif current_phase == "oobe":
+                 # Check for specific script patterns
+                 if "on-logon.ps1" in logs.lower():
+                     return "Running on-logon.ps1"
+                 elif "setup.ps1" in logs.lower():
+                     return "Running setup.ps1"
+                 elif "install.bat" in logs.lower():
+                     return "Running install.bat"
+                 else:
+                     return "Windows installation in progress"
+             elif current_phase == "booting":
+                 return "Booting Windows"
+             elif current_phase in [
+                 "downloading",
+                 "extracting",
+                 "configuring",
+                 "building",
+             ]:
+                 return "Preparing Windows VM"
+             else:
+                 return "Initializing..."
+
+         def _parse_dependencies_from_logs(self, logs: str, phase: str) -> list[dict]:
+             """Parse dependency installation status from logs.
+
+             Returns list of dependencies with their installation status:
+             - Python
+             - Chrome
+             - LibreOffice
+             - VSCode
+             - etc.
+             """
+             dependencies = [
+                 {"name": "Python", "icon": "🐍", "status": "pending"},
+                 {"name": "Chrome", "icon": "🌐", "status": "pending"},
+                 {"name": "LibreOffice", "icon": "📝", "status": "pending"},
+                 {"name": "VSCode", "icon": "💻", "status": "pending"},
+                 {"name": "WAA Server", "icon": "🔧", "status": "pending"},
+             ]
+
+             if phase not in ["oobe", "ready"]:
+                 # Not yet at Windows setup phase
+                 return dependencies
+
+             logs_lower = logs.lower()
+
+             # Check for installation patterns
+             if "python" in logs_lower and (
+                 "installing python" in logs_lower or "python.exe" in logs_lower
+             ):
+                 dependencies[0]["status"] = "installing"
+             elif "python" in logs_lower and "installed" in logs_lower:
+                 dependencies[0]["status"] = "complete"
+
+             if "chrome" in logs_lower and (
+                 "downloading" in logs_lower or "installing" in logs_lower
+             ):
+                 dependencies[1]["status"] = "installing"
+             elif "chrome" in logs_lower and "installed" in logs_lower:
+                 dependencies[1]["status"] = "complete"
+
+             if "libreoffice" in logs_lower and (
+                 "downloading" in logs_lower or "installing" in logs_lower
+             ):
+                 dependencies[2]["status"] = "installing"
+             elif "libreoffice" in logs_lower and "installed" in logs_lower:
+                 dependencies[2]["status"] = "complete"
+
+             if "vscode" in logs_lower or "visual studio code" in logs_lower:
+                 if "installing" in logs_lower:
+                     dependencies[3]["status"] = "installing"
+                 elif "installed" in logs_lower:
+                     dependencies[3]["status"] = "complete"
+
+             if "waa" in logs_lower or "flask" in logs_lower:
+                 if "starting" in logs_lower or "running" in logs_lower:
+                     dependencies[4]["status"] = "installing"
+                 elif phase == "ready":
+                     dependencies[4]["status"] = "complete"
+
+             return dependencies
+
1499
+    def _get_vm_diagnostics(self) -> dict:
+        """Get VM diagnostics: disk usage, Docker stats, memory usage.
+
+        Returns a dictionary with:
+        - vm_online: bool - whether VM is reachable
+        - disk_usage: list of disk partitions with usage stats
+        - docker_stats: list of container stats (CPU, memory)
+        - memory_usage: VM host memory stats
+        - docker_system: Docker system disk usage
+        - docker_images: list of Docker images on the VM
+        - error: str if any error occurred
+        """
+        import subprocess
+
+        from openadapt_ml.benchmarks.session_tracker import get_session
+
+        diagnostics = {
+            "vm_online": False,
+            "disk_usage": [],
+            "docker_stats": [],
+            "memory_usage": {},
+            "docker_system": {},
+            "docker_images": [],
+            "error": None,
+        }
+
+        # Get VM IP from session
+        session = get_session()
+        vm_ip = session.get("vm_ip")
+
+        if not vm_ip:
+            diagnostics["error"] = "VM IP not found in session. VM may not be running."
+            return diagnostics
+
+        # SSH options for Azure VM
+        ssh_opts = [
+            "-o", "StrictHostKeyChecking=no",
+            "-o", "UserKnownHostsFile=/dev/null",
+            "-o", "ConnectTimeout=10",
+            "-o", "ServerAliveInterval=30",
+        ]
+
+        # Test VM connectivity
+        try:
+            test_result = subprocess.run(
+                ["ssh", *ssh_opts, f"azureuser@{vm_ip}", "echo 'online'"],
+                capture_output=True, text=True, timeout=15,
+            )
+            if test_result.returncode != 0:
+                diagnostics["error"] = f"Cannot connect to VM at {vm_ip}"
+                return diagnostics
+            diagnostics["vm_online"] = True
+        except subprocess.TimeoutExpired:
+            diagnostics["error"] = f"Connection to VM at {vm_ip} timed out"
+            return diagnostics
+        except Exception as e:
+            diagnostics["error"] = f"SSH error: {str(e)}"
+            return diagnostics
+
+        # 1. Disk usage (df -h)
+        try:
+            df_result = subprocess.run(
+                ["ssh", *ssh_opts, f"azureuser@{vm_ip}",
+                 "df -h / /mnt 2>/dev/null | tail -n +2"],
+                capture_output=True, text=True, timeout=15,
+            )
+            if df_result.returncode == 0 and df_result.stdout.strip():
+                for line in df_result.stdout.strip().split("\n"):
+                    parts = line.split()
+                    if len(parts) >= 6:
+                        diagnostics["disk_usage"].append(
+                            {
+                                "filesystem": parts[0],
+                                "size": parts[1],
+                                "used": parts[2],
+                                "available": parts[3],
+                                "use_percent": parts[4],
+                                "mount_point": parts[5],
+                            }
+                        )
+        except Exception as e:
+            diagnostics["disk_usage"] = [{"error": str(e)}]
+
+        # 2. Docker container stats
+        try:
+            stats_result = subprocess.run(
+                ["ssh", *ssh_opts, f"azureuser@{vm_ip}",
+                 "docker stats --no-stream --format '{{.Name}}|{{.CPUPerc}}|{{.MemUsage}}|{{.MemPerc}}|{{.NetIO}}|{{.BlockIO}}' 2>/dev/null || echo ''"],
+                capture_output=True, text=True, timeout=30,
+            )
+            if stats_result.returncode == 0 and stats_result.stdout.strip():
+                for line in stats_result.stdout.strip().split("\n"):
+                    if "|" in line:
+                        parts = line.split("|")
+                        if len(parts) >= 6:
+                            diagnostics["docker_stats"].append(
+                                {
+                                    "container": parts[0],
+                                    "cpu_percent": parts[1],
+                                    "memory_usage": parts[2],
+                                    "memory_percent": parts[3],
+                                    "net_io": parts[4],
+                                    "block_io": parts[5],
+                                }
+                            )
+        except Exception as e:
+            diagnostics["docker_stats"] = [{"error": str(e)}]
+
+        # 3. VM host memory usage (free -h)
+        try:
+            mem_result = subprocess.run(
+                ["ssh", *ssh_opts, f"azureuser@{vm_ip}", "free -h | head -2 | tail -1"],
+                capture_output=True, text=True, timeout=15,
+            )
+            if mem_result.returncode == 0 and mem_result.stdout.strip():
+                parts = mem_result.stdout.strip().split()
+                if len(parts) >= 7:
+                    diagnostics["memory_usage"] = {
+                        "total": parts[1],
+                        "used": parts[2],
+                        "free": parts[3],
+                        "shared": parts[4],
+                        "buff_cache": parts[5],
+                        "available": parts[6],
+                    }
+        except Exception as e:
+            diagnostics["memory_usage"] = {"error": str(e)}
+
+        # 4. Docker system disk usage
+        try:
+            docker_df_result = subprocess.run(
+                ["ssh", *ssh_opts, f"azureuser@{vm_ip}",
+                 "docker system df 2>/dev/null || echo ''"],
+                capture_output=True, text=True, timeout=15,
+            )
+            if docker_df_result.returncode == 0 and docker_df_result.stdout.strip():
+                lines = docker_df_result.stdout.strip().split("\n")
+                # Parse the table: TYPE, TOTAL, ACTIVE, SIZE, RECLAIMABLE
+                for line in lines[1:]:  # Skip header
+                    parts = line.split()
+                    if len(parts) >= 5:
+                        dtype = parts[0]
+                        diagnostics["docker_system"][dtype.lower()] = {
+                            "total": parts[1],
+                            "active": parts[2],
+                            "size": parts[3],
+                            "reclaimable": " ".join(parts[4:]),
+                        }
+        except Exception as e:
+            diagnostics["docker_system"] = {"error": str(e)}
+
+        # 5. Docker images
+        try:
+            images_result = subprocess.run(
+                ["ssh", *ssh_opts, f"azureuser@{vm_ip}",
+                 "docker images --format '{{.Repository}}:{{.Tag}}|{{.Size}}|{{.CreatedSince}}' 2>/dev/null || echo ''"],
+                capture_output=True, text=True, timeout=15,
+            )
+            if images_result.returncode == 0 and images_result.stdout.strip():
+                for line in images_result.stdout.strip().split("\n"):
+                    if "|" in line:
+                        parts = line.split("|")
+                        if len(parts) >= 3:
+                            diagnostics["docker_images"].append(
+                                {
+                                    "image": parts[0],
+                                    "size": parts[1],
+                                    "created": parts[2],
+                                }
+                            )
+        except Exception as e:
+            diagnostics["docker_images"] = [{"error": str(e)}]
+
+        return diagnostics
+
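A quick sketch of how the `df -h` parsing above maps a raw output line onto a diagnostics entry (the sample line is illustrative):

line = "/dev/sda1  124G  38G  86G  31% /"
parts = line.split()
entry = {
    "filesystem": parts[0],   # /dev/sda1
    "size": parts[1],         # 124G
    "used": parts[2],         # 38G
    "available": parts[3],    # 86G
    "use_percent": parts[4],  # 31%
    "mount_point": parts[5],  # /
}
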
+    def _fetch_background_tasks(self):
+        """Fetch status of all background tasks: Azure VM, Docker containers, benchmarks."""
+        import subprocess
+
+        tasks = []
+
+        # Check for VM IP from environment (set by CLI when auto-launching viewer)
+        env_vm_ip = os.environ.get("WAA_VM_IP")
+        env_internal_ip = os.environ.get("WAA_INTERNAL_IP", "172.30.0.2")
+
+        # 1. Check Azure WAA VM status
+        vm_ip = None
+        if env_vm_ip:
+            # Use environment variable - VM IP was provided directly
+            vm_ip = env_vm_ip
+            tasks.append(
+                {
+                    "task_id": "azure-vm-waa",
+                    "task_type": "vm_provision",
+                    "status": "completed",
+                    "phase": "ready",  # Match status to prevent "Starting" + "completed" conflict
+                    "title": "Azure VM Host",
+                    "description": f"Linux host running at {vm_ip}",
+                    "progress_percent": 100.0,
+                    "elapsed_seconds": 0,
+                    "metadata": {
+                        "vm_name": "waa-eval-vm",
+                        "ip_address": vm_ip,
+                        "internal_ip": env_internal_ip,
+                    },
+                }
+            )
+        else:
+            # Query Azure CLI for VM status
+            try:
+                result = subprocess.run(
+                    ["az", "vm", "get-instance-view",
+                     "--name", "waa-eval-vm",
+                     "--resource-group", "openadapt-agents",
+                     "--query", "instanceView.statuses",
+                     "-o", "json"],
+                    capture_output=True, text=True, timeout=10,
+                )
+                if result.returncode == 0:
+                    statuses = json.loads(result.stdout)
+                    power_state = "unknown"
+                    for s in statuses:
+                        if s.get("code", "").startswith("PowerState/"):
+                            power_state = s["code"].replace("PowerState/", "")
+
+                    # Get VM IP
+                    ip_result = subprocess.run(
+                        ["az", "vm", "list-ip-addresses",
+                         "--name", "waa-eval-vm",
+                         "--resource-group", "openadapt-agents",
+                         "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+                         "-o", "tsv"],
+                        capture_output=True, text=True, timeout=10,
+                    )
+                    vm_ip = ip_result.stdout.strip() if ip_result.returncode == 0 else None
+
+                    if power_state == "running":
+                        tasks.append(
+                            {
+                                "task_id": "azure-vm-waa",
+                                "task_type": "vm_provision",
+                                "status": "completed",
+                                "phase": "ready",  # Match status to prevent "Starting" + "completed" conflict
+                                "title": "Azure VM Host",
+                                "description": f"Linux host running at {vm_ip}" if vm_ip else "Linux host running",
+                                "progress_percent": 100.0,
+                                "elapsed_seconds": 0,
+                                "metadata": {
+                                    "vm_name": "waa-eval-vm",
+                                    "ip_address": vm_ip,
+                                    # No VNC link - that's for the Windows container
+                                },
+                            }
+                        )
+            except subprocess.TimeoutExpired:
+                pass
+            except Exception:
+                pass
+
+        # 2. Check Docker container status on VM (if we have an IP)
+        if vm_ip:
+            try:
+                docker_result = subprocess.run(
+                    ["ssh",
+                     "-o", "StrictHostKeyChecking=no",
+                     "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     "docker ps --format '{{.Names}}|{{.Status}}|{{.Image}}'"],
+                    capture_output=True, text=True, timeout=15,
+                )
+                if docker_result.returncode == 0 and docker_result.stdout.strip():
+                    for line in docker_result.stdout.strip().split("\n"):
+                        parts = line.split("|")
+                        if len(parts) >= 3:
+                            container_name, status, image = parts[0], parts[1], parts[2]
+                            # Parse "Up X minutes" to determine if healthy
+
+                            # Check for Windows VM specifically
+                            if "windows" in image.lower() or container_name == "winarena":
+                                # Get detailed progress from docker logs
+                                log_check = subprocess.run(
+                                    ["ssh",
+                                     "-o", "StrictHostKeyChecking=no",
+                                     "-o", "ConnectTimeout=5",
+                                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                                     f"azureuser@{vm_ip}",
+                                     f"docker logs {container_name} 2>&1 | tail -30"],
+                                    capture_output=True, text=True, timeout=10,
+                                )
+                                logs = log_check.stdout if log_check.returncode == 0 else ""
+
+                                # Parse progress from logs
+                                phase = "unknown"
+                                progress = 0.0
+                                description = "Starting..."
+
+                                if "Windows started successfully" in logs:
+                                    # Check if WAA server is ready via Docker port forwarding
+                                    # See docs/waa_network_architecture.md - always use localhost
+                                    server_check = subprocess.run(
+                                        ["ssh",
+                                         "-o", "StrictHostKeyChecking=no",
+                                         "-o", "ConnectTimeout=5",
+                                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                                         f"azureuser@{vm_ip}",
+                                         "curl -s --connect-timeout 2 http://localhost:5000/probe 2>/dev/null"],
+                                        capture_output=True, text=True, timeout=10,
+                                    )
+                                    waa_ready = (
+                                        server_check.returncode == 0
+                                        and "Service is operational" in server_check.stdout
+                                    )
+                                    if waa_ready:
+                                        phase = "ready"
+                                        progress = 100.0
+                                        description = "WAA Server ready - benchmarks can run"
+                                    else:
+                                        phase = "oobe"
+                                        progress = 80.0  # Phase 5/6 - VM install in progress
+                                        description = "Phase 5/6: Windows installing (check VNC for %). OEM scripts will run after."
+                                elif "Booting Windows" in logs:
+                                    phase = "booting"
+                                    progress = 70.0  # Phase 4/6
+                                    description = "Phase 4/6: Booting Windows from installer..."
+                                elif "Building Windows" in logs or "Creating a" in logs:
+                                    phase = "building"
+                                    progress = 60.0  # Phase 3/6
+                                    description = "Phase 3/6: Building Windows VM disk..."
+                                elif "Adding" in logs and "image" in logs:
+                                    phase = "configuring"
+                                    progress = 50.0  # Phase 2/6
+                                    description = "Phase 2/6: Configuring Windows image with WAA scripts..."
+                                elif "Extracting" in logs:
+                                    phase = "extracting"
+                                    progress = 35.0  # Phase 1/6 (after download)
+                                    description = "Phase 1/6: Extracting Windows ISO..."
+                                else:
+                                    # Check for download progress (e.g., "1234K ........ 45% 80M 30s")
+                                    import re
+
+                                    download_match = re.search(r"(\d+)%\s+[\d.]+[KMG]\s+(\d+)s", logs)
+                                    if download_match:
+                                        phase = "downloading"
+                                        dl_pct = float(download_match.group(1))
+                                        progress = dl_pct * 0.30  # 0-30% for download phase
+                                        eta = download_match.group(2)
+                                        description = f"Phase 0/6: Downloading Windows 11... {download_match.group(1)}% ({eta}s left)"
+
+                                # Improve phase detection - if Windows is booted but WAA not ready,
+                                # it might be at login screen waiting for OEM scripts or running install.bat
+                                if phase == "oobe" and "Boot0004" in logs:
+                                    # Windows finished installing, at login/desktop
+                                    # install.bat should auto-run from FirstLogonCommands (see Dockerfile)
+                                    description = "Phase 5/6: Windows at desktop, OEM scripts running... (WAA server starting)"
+                                    progress = 90.0
+
+                                # Get detailed metadata for VM Details panel
+                                vm_metadata = self._get_vm_detailed_metadata(
+                                    vm_ip, container_name, logs, phase
+                                )
+
+                                tasks.append(
+                                    {
+                                        "task_id": f"docker-{container_name}",
+                                        "task_type": "docker_container",
+                                        "status": "completed" if phase == "ready" else "running",
+                                        "title": "Windows 11 + WAA Server",
+                                        "description": description,
+                                        "progress_percent": progress,
+                                        "elapsed_seconds": 0,
+                                        "phase": phase,
+                                        "metadata": {
+                                            "container": container_name,
+                                            "image": image,
+                                            "status": status,
+                                            "phase": phase,
+                                            "windows_ready": phase in ["oobe", "ready"],
+                                            "waa_server_ready": phase == "ready",
+                                            # Use localhost - SSH tunnel handles routing to VM
+                                            # See docs/waa_network_architecture.md
+                                            "vnc_url": "http://localhost:8006",
+                                            "windows_username": "Docker",
+                                            "windows_password": "admin",
+                                            "recent_logs": logs[-500:] if logs else "",
+                                            # Enhanced VM details
+                                            "disk_usage_gb": vm_metadata["disk_usage_gb"],
+                                            "memory_usage_mb": vm_metadata["memory_usage_mb"],
+                                            "setup_script_phase": vm_metadata["setup_script_phase"],
+                                            "probe_response": vm_metadata["probe_response"],
+                                            "qmp_connected": vm_metadata["qmp_connected"],
+                                            "dependencies": vm_metadata["dependencies"],
+                                        },
+                                    }
+                                )
+            except Exception:
+                # SSH failed, VM might still be starting
+                pass
+
+        # 3. Check local benchmark progress
+        progress_file = Path("benchmark_progress.json")
+        if progress_file.exists():
+            try:
+                progress = json.loads(progress_file.read_text())
+                if progress.get("status") == "running":
+                    tasks.append(
+                        {
+                            "task_id": "benchmark-local",
+                            "task_type": "benchmark_run",
+                            "status": "running",
+                            "title": f"{progress.get('provider', 'API').upper()} Benchmark",
+                            "description": progress.get("message", "Running benchmark..."),
+                            "progress_percent": 100.0 * progress.get("tasks_complete", 0) / max(progress.get("tasks_total", 1), 1),
+                            "elapsed_seconds": 0,
+                            "metadata": progress,
+                        }
+                    )
+            except Exception:
+                pass
+
+        return tasks
+
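The log-driven phase detection above implies a fixed progress schedule; summarized as data (a sketch, with thresholds copied from the branches above):

# Progress checkpoints inferred from the docker logs, as used above.
PHASE_PROGRESS = {
    "downloading": None,   # 0-30%, scaled from the wget percentage in the log
    "extracting": 35.0,    # Phase 1/6
    "configuring": 50.0,   # Phase 2/6
    "building": 60.0,      # Phase 3/6
    "booting": 70.0,       # Phase 4/6
    "oobe": 80.0,          # Phase 5/6 (bumped to 90.0 once "Boot0004" appears)
    "ready": 100.0,        # WAA /probe answered "Service is operational"
}
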
+    def _fetch_vm_registry(self):
+        """Fetch VM registry with live status checks.
+
+        NOTE: We now fetch the VM IP from Azure CLI at runtime to avoid
+        stale IP issues. The registry file is only used as a fallback.
+        """
+        import subprocess
+        from datetime import datetime
+
+        # Try to get VM IP from Azure CLI (always fresh)
+        vm_ip = None
+        resource_group = "openadapt-agents"
+        vm_name = "azure-waa-vm"
+        try:
+            result = subprocess.run(
+                ["az", "vm", "show", "-d",
+                 "-g", resource_group,
+                 "-n", vm_name,
+                 "--query", "publicIps",
+                 "-o", "tsv"],
+                capture_output=True, text=True, timeout=10,
+            )
+            if result.returncode == 0 and result.stdout.strip():
+                vm_ip = result.stdout.strip()
+        except Exception:
+            pass
+
+        # If we have a fresh IP from Azure, use it
+        if vm_ip:
+            vms = [
+                {
+                    "name": vm_name,
+                    "ssh_host": vm_ip,
+                    "ssh_user": "azureuser",
+                    "vnc_port": 8006,
+                    "waa_port": 5000,
+                    "docker_container": "winarena",
+                    "internal_ip": "localhost",
+                }
+            ]
+        else:
+            # Fallback to registry file
+            project_root = Path(__file__).parent.parent.parent
+            registry_file = project_root / "benchmark_results" / "vm_registry.json"
+
+            if not registry_file.exists():
+                return []
+
+            try:
+                with open(registry_file) as f:
+                    vms = json.load(f)
+            except Exception as e:
+                # Keep the return type consistent: callers iterate over a list of VMs
+                print(f"Warning: Failed to read VM registry: {e}")
+                return []
+
+        # Check status for each VM
+        for vm in vms:
+            vm["status"] = "unknown"
+            vm["last_checked"] = datetime.now().isoformat()
+            vm["vnc_reachable"] = False
+            vm["waa_probe_status"] = "unknown"
+
+            # Check VNC (HTTP HEAD request)
+            try:
+                vnc_url = f"http://{vm['ssh_host']}:{vm['vnc_port']}"
+                result = subprocess.run(
+                    ["curl", "-I", "-s", "--connect-timeout", "3", vnc_url],
+                    capture_output=True, text=True, timeout=5,
+                )
+                if result.returncode == 0 and "200" in result.stdout:
+                    vm["vnc_reachable"] = True
+            except Exception:
+                pass
+
+            # Check WAA probe via SSH, against localhost (Docker port forwarding
+            # handles routing). See docs/waa_network_architecture.md for details.
+            try:
+                waa_port = vm.get("waa_port", 5000)
+                ssh_cmd = f"curl -s --connect-timeout 2 http://localhost:{waa_port}/probe 2>/dev/null"
+                result = subprocess.run(
+                    ["ssh",
+                     "-o", "StrictHostKeyChecking=no",
+                     "-o", "ConnectTimeout=3",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"{vm['ssh_user']}@{vm['ssh_host']}",
+                     ssh_cmd],
+                    capture_output=True, text=True, timeout=5,
+                )
+                probe_success = (
+                    result.returncode == 0
+                    and "Service is operational" in result.stdout
+                )
+                if probe_success:
+                    vm["waa_probe_status"] = "ready"
+                    vm["status"] = "online"
+                    # Auto-start SSH tunnels for VNC and WAA
+                    try:
+                        tunnel_mgr = get_tunnel_manager()
+                        tunnel_status = tunnel_mgr.ensure_tunnels_for_vm(
+                            vm_ip=vm["ssh_host"],
+                            ssh_user=vm.get("ssh_user", "azureuser"),
+                        )
+                        vm["tunnels"] = {
+                            name: {
+                                "active": s.active,
+                                "local_port": s.local_port,
+                                "error": s.error,
+                            }
+                            for name, s in tunnel_status.items()
+                        }
+                    except Exception as e:
+                        vm["tunnels"] = {"error": str(e)}
+                else:
+                    vm["waa_probe_status"] = "not responding"
+                    vm["status"] = "offline"
+                    # Stop tunnels when VM goes offline
+                    try:
+                        tunnel_mgr = get_tunnel_manager()
+                        tunnel_mgr.stop_all_tunnels()
+                        vm["tunnels"] = {}
+                    except Exception:
+                        pass
+            except Exception:
+                vm["waa_probe_status"] = "ssh failed"
+                vm["status"] = "offline"
+
+        return vms
+
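For reference, a registry entry in benchmark_results/vm_registry.json would look like this; field names are taken from the fallback path above, values are illustrative:

# Example content of benchmark_results/vm_registry.json (values illustrative):
vm_registry = [
    {
        "name": "azure-waa-vm",
        "ssh_host": "20.115.30.7",     # public IP; refreshed from `az vm show -d` when possible
        "ssh_user": "azureuser",
        "vnc_port": 8006,
        "waa_port": 5000,
        "docker_container": "winarena",
        "internal_ip": "localhost",    # WAA is probed via the localhost port-forward
    },
]
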
+    def _probe_vm(self) -> dict:
+        """Probe the Azure VM to check if WAA server is responding.
+
+        Returns:
+            dict with:
+            - responding: bool - whether the WAA server is responding
+            - vm_ip: str - the VM's IP address
+            - container: str - the container name
+            - probe_result: str - the raw probe response or error message
+            - last_checked: str - ISO timestamp
+        """
+        import subprocess
+        from datetime import datetime
+
+        result = {
+            "responding": False,
+            "vm_ip": None,
+            "container": None,
+            "probe_result": None,
+            "last_checked": datetime.now().isoformat(),
+        }
+
+        # First get VM IP
+        try:
+            ip_result = subprocess.run(
+                ["az", "vm", "list-ip-addresses",
+                 "--name", "waa-eval-vm",
+                 "--resource-group", "openadapt-agents",
+                 "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+                 "-o", "tsv"],
+                capture_output=True, text=True, timeout=10,
+            )
+            if ip_result.returncode == 0 and ip_result.stdout.strip():
+                vm_ip = ip_result.stdout.strip()
+                result["vm_ip"] = vm_ip
+
+                # Try to probe WAA server via SSH
+                # Use the correct internal IP for the Windows VM inside Docker
+                probe_result = subprocess.run(
+                    ["ssh",
+                     "-o", "StrictHostKeyChecking=no",
+                     "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     "docker exec waa-container curl -s --connect-timeout 3 http://172.30.0.2:5000/probe 2>/dev/null || echo 'probe_failed'"],
+                    capture_output=True, text=True, timeout=15,
+                )
+
+                result["container"] = "waa-container"
+
+                if probe_result.returncode == 0:
+                    probe_output = probe_result.stdout.strip()
+                    if probe_output and "probe_failed" not in probe_output:
+                        result["responding"] = True
+                        result["probe_result"] = probe_output
+                    else:
+                        result["probe_result"] = "WAA server not responding"
+                else:
+                    result["probe_result"] = f"SSH/Docker error: {probe_result.stderr[:200]}"
+            else:
+                result["probe_result"] = "Could not get VM IP"
+        except subprocess.TimeoutExpired:
+            result["probe_result"] = "Connection timeout"
+        except Exception as e:
+            result["probe_result"] = f"Error: {str(e)}"
+
+        return result
+
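Equivalently, the probe chain above can be reproduced by hand; a minimal sketch (the VM IP is illustrative, the container name and internal address mirror the code above):

import subprocess

# Manual equivalent of _probe_vm's SSH + docker-exec probe chain.
cmd = [
    "ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
    "azureuser@20.115.30.7",
    "docker exec waa-container curl -s --connect-timeout 3 http://172.30.0.2:5000/probe",
]
out = subprocess.run(cmd, capture_output=True, text=True, timeout=15)
print("responding:", out.returncode == 0 and bool(out.stdout.strip()))
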
+    def _get_current_run(self) -> dict:
+        """Get info about any currently running benchmark.
+
+        Checks:
+        1. Local benchmark_progress.json for API benchmarks
+        2. Azure VM for WAA benchmarks running via SSH
+
+        Returns:
+            dict with:
+            - running: bool - whether a benchmark is running
+            - type: str - 'local' or 'azure_vm'
+            - model: str - model being evaluated
+            - progress: dict with tasks_completed, total_tasks
+            - current_task: str - current task ID
+            - started_at: str - ISO timestamp
+            - elapsed_minutes: int
+        """
+        import subprocess
+        import re
+
+        result = {
+            "running": False,
+            "type": None,
+            "model": None,
+            "progress": {"tasks_completed": 0, "total_tasks": 0},
+            "current_task": None,
+            "started_at": None,
+            "elapsed_minutes": 0,
+        }
+
+        # Check local benchmark progress first
+        progress_file = Path("benchmark_progress.json")
+        if progress_file.exists():
+            try:
+                progress = json.loads(progress_file.read_text())
+                if progress.get("status") == "running":
+                    result["running"] = True
+                    result["type"] = "local"
+                    result["model"] = progress.get("provider", "unknown")
+                    result["progress"]["tasks_completed"] = progress.get("tasks_complete", 0)
+                    result["progress"]["total_tasks"] = progress.get("tasks_total", 0)
+                    return result
+            except Exception:
+                pass
+
+        # Check Azure VM for running benchmark
+        try:
+            # Get VM IP
+            ip_result = subprocess.run(
+                ["az", "vm", "list-ip-addresses",
+                 "--name", "waa-eval-vm",
+                 "--resource-group", "openadapt-agents",
+                 "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+                 "-o", "tsv"],
+                capture_output=True, text=True, timeout=10,
+            )
+
+            if ip_result.returncode == 0 and ip_result.stdout.strip():
+                vm_ip = ip_result.stdout.strip()
+
+                # Check if benchmark process is running
+                process_check = subprocess.run(
+                    ["ssh",
+                     "-o", "StrictHostKeyChecking=no",
+                     "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     "docker exec waa-container pgrep -f 'python.*run.py' 2>/dev/null && echo 'RUNNING' || echo 'NOT_RUNNING'"],
+                    capture_output=True, text=True, timeout=10,
+                )
+
+                # Note: "NOT_RUNNING" contains "RUNNING" as a substring, so test
+                # for the negative marker instead of the positive one.
+                if process_check.returncode == 0 and "NOT_RUNNING" not in process_check.stdout:
+                    result["running"] = True
+                    result["type"] = "azure_vm"
+
+                    # Get log file for more details
+                    log_check = subprocess.run(
+                        ["ssh",
+                         "-o", "StrictHostKeyChecking=no",
+                         "-o", "ConnectTimeout=5",
+                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
+                        capture_output=True, text=True, timeout=10,
+                    )
+
+                    if log_check.returncode == 0 and log_check.stdout.strip():
+                        logs = log_check.stdout
+
+                        # Parse model from logs
+                        model_match = re.search(r"model[=:\s]+([^\s,]+)", logs, re.IGNORECASE)
+                        if model_match:
+                            result["model"] = model_match.group(1)
+
+                        # Parse progress
+                        task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
+                        if task_match:
+                            result["progress"]["tasks_completed"] = int(task_match.group(1))
+                            result["progress"]["total_tasks"] = int(task_match.group(2))
+
+                        # Parse current task
+                        task_id_match = re.search(
+                            r"(?:Running|Processing|task)[:\s]+([a-f0-9-]+)",
+                            logs,
+                            re.IGNORECASE,
+                        )
+                        if task_id_match:
+                            result["current_task"] = task_id_match.group(1)
+
+        except Exception:
+            pass
+
+        return result
+
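A sketch of the dict this returns while a WAA run is active on the VM (values illustrative):

current_run = {
    "running": True,
    "type": "azure_vm",
    "model": "gpt-4o",
    "progress": {"tasks_completed": 5, "total_tasks": 30},
    "current_task": "a1b2c3d4-0001",
    "started_at": None,      # not populated by the log parser above
    "elapsed_minutes": 0,
}
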
+    def _get_benchmark_status(self) -> dict:
+        """Get current benchmark job status with ETA calculation.
+
+        Returns:
+            dict with job status, progress, ETA, and current task info
+        """
+        import time
+
+        # Check for live evaluation state
+        live_file = Path("benchmark_live.json")
+        if live_file.exists():
+            try:
+                live_state = json.loads(live_file.read_text())
+                if live_state.get("status") == "running":
+                    total_tasks = live_state.get("total_tasks", 0)
+                    completed_tasks = live_state.get("tasks_completed", 0)
+                    current_task = live_state.get("current_task", {})
+
+                    # Calculate ETA based on completed tasks
+                    eta_seconds = None
+                    avg_task_seconds = None
+                    if completed_tasks > 0 and total_tasks > 0:
+                        # Estimate from live state timestamp or use fallback
+                        elapsed = time.time() - live_state.get("start_time", time.time())
+                        avg_task_seconds = elapsed / completed_tasks
+                        remaining_tasks = total_tasks - completed_tasks
+                        eta_seconds = remaining_tasks * avg_task_seconds
+
+                    return {
+                        "status": "running",
+                        "current_job": {
+                            "run_id": live_state.get("run_id", "unknown"),
+                            "model_id": live_state.get("model_id", "unknown"),
+                            "total_tasks": total_tasks,
+                            "completed_tasks": completed_tasks,
+                            "current_task": current_task,
+                            "eta_seconds": eta_seconds,
+                            "avg_task_seconds": avg_task_seconds,
+                        },
+                        "queue": [],  # TODO: implement queue tracking
+                    }
+            except Exception as e:
+                return {"status": "error", "error": str(e)}
+
+        # Fallback to current_run check
+        current_run = self._get_current_run()
+        if current_run.get("running"):
+            return {
+                "status": "running",
+                "current_job": {
+                    "run_id": "unknown",
+                    "model_id": current_run.get("model", "unknown"),
+                    "total_tasks": current_run["progress"]["total_tasks"],
+                    "completed_tasks": current_run["progress"]["tasks_completed"],
+                    "current_task": {"task_id": current_run.get("current_task")},
+                },
+                "queue": [],
+            }
+
+        return {"status": "idle"}
+
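The ETA arithmetic above as a worked example (values illustrative):

start_time = 0.0        # from benchmark_live.json
elapsed = 360.0         # time.time() - start_time
completed_tasks, total_tasks = 12, 30
avg_task_seconds = elapsed / completed_tasks                       # 30.0
eta_seconds = (total_tasks - completed_tasks) * avg_task_seconds   # 540.0
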
+    def _get_benchmark_costs(self) -> dict:
+        """Get cost breakdown for current benchmark run.
+
+        Returns:
+            dict with Azure VM, API calls, and GPU costs
+        """
+        # Check for cost tracking file
+        cost_file = Path("benchmark_costs.json")
+        if cost_file.exists():
+            try:
+                return json.loads(cost_file.read_text())
+            except Exception:
+                pass
+
+        # Return placeholder structure
+        return {
+            "azure_vm": {
+                "instance_type": "Standard_D4ds_v5",
+                "hourly_rate_usd": 0.192,
+                "hours_elapsed": 0.0,
+                "cost_usd": 0.0,
+            },
+            "api_calls": {
+                "anthropic": {"cost_usd": 0.0},
+                "openai": {"cost_usd": 0.0},
+            },
+            "gpu_time": {
+                "lambda_labs": {"cost_usd": 0.0},
+            },
+            "total_cost_usd": 0.0,
+        }
+
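The placeholder structure implies a simple accrual model; a sketch of how cost_usd would be derived for the VM line item (rate from the placeholder above, hours illustrative):

hourly_rate_usd = 0.192   # Standard_D4ds_v5
hours_elapsed = 3.5
cost_usd = round(hourly_rate_usd * hours_elapsed, 4)  # 0.672
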
+    def _get_benchmark_metrics(self) -> dict:
+        """Get performance metrics for current/completed benchmarks.
+
+        Returns:
+            dict with success rate trends, domain breakdown, episode metrics
+        """
+        # Check for metrics file
+        metrics_file = Path("benchmark_metrics.json")
+        if metrics_file.exists():
+            try:
+                return json.loads(metrics_file.read_text())
+            except Exception:
+                pass
+
+        # Load completed runs from benchmark_results/
+        benchmark_results_dir = Path("benchmark_results")
+        if not benchmark_results_dir.exists():
+            return {"error": "No benchmark results found"}
+
+        # Find most recent run
+        runs = sorted(
+            benchmark_results_dir.iterdir(),
+            key=lambda p: p.stat().st_mtime,
+            reverse=True,
+        )
+        if not runs:
+            return {"error": "No benchmark runs found"}
+
+        recent_run = runs[0]
+        summary_path = recent_run / "summary.json"
+        if not summary_path.exists():
+            return {"error": f"No summary.json in {recent_run.name}"}
+
+        try:
+            summary = json.loads(summary_path.read_text())
+
+            # Build domain breakdown from tasks
+            domain_breakdown = {}
+            tasks_dir = recent_run / "tasks"
+            if tasks_dir.exists():
+                for task_dir in tasks_dir.iterdir():
+                    if not task_dir.is_dir():
+                        continue
+
+                    task_json = task_dir / "task.json"
+                    execution_json = task_dir / "execution.json"
+                    if not (task_json.exists() and execution_json.exists()):
+                        continue
+
+                    try:
+                        task_def = json.loads(task_json.read_text())
+                        execution = json.loads(execution_json.read_text())
+
+                        domain = task_def.get("domain", "unknown")
+                        if domain not in domain_breakdown:
+                            domain_breakdown[domain] = {
+                                "total": 0,
+                                "success": 0,
+                                "rate": 0.0,
+                                "avg_steps": 0.0,
+                                "total_steps": 0,
+                            }
+
+                        domain_breakdown[domain]["total"] += 1
+                        if execution.get("success"):
+                            domain_breakdown[domain]["success"] += 1
+                        domain_breakdown[domain]["total_steps"] += execution.get("num_steps", 0)
+
+                    except Exception:
+                        continue
+
+            # Calculate averages
+            for domain, stats in domain_breakdown.items():
+                if stats["total"] > 0:
+                    stats["rate"] = stats["success"] / stats["total"]
+                    stats["avg_steps"] = stats["total_steps"] / stats["total"]
+
+            return {
+                "success_rate_over_time": [],  # TODO: implement trend tracking
+                "avg_steps_per_task": [],  # TODO: implement trend tracking
+                "domain_breakdown": domain_breakdown,
+                "episode_success_metrics": {
+                    "first_action_accuracy": summary.get("first_action_accuracy", 0.0),
+                    "episode_success_rate": summary.get("success_rate", 0.0),
+                    "avg_steps_to_success": summary.get("avg_steps", 0.0),
+                    "avg_steps_to_failure": 0.0,  # TODO: calculate from failed tasks
+                },
+            }
+        except Exception as e:
+            return {"error": f"Failed to load metrics: {str(e)}"}
+
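For reference, one domain_breakdown entry after the averaging pass above (values illustrative):

domain_breakdown = {
    "browser": {
        "total": 10,
        "success": 6,
        "rate": 0.6,        # success / total
        "avg_steps": 7.2,   # total_steps / total
        "total_steps": 72,
    },
}
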
+    def _get_benchmark_workers(self) -> dict:
+        """Get worker status and utilization.
+
+        Returns:
+            dict with total/active/idle workers and per-worker stats
+        """
+        # Get VM registry
+        vms = self._fetch_vm_registry()
+
+        active_workers = [v for v in vms if v.get("status") == "online"]
+        idle_workers = [v for v in vms if v.get("status") != "online"]
+
+        workers = []
+        for vm in vms:
+            workers.append(
+                {
+                    "worker_id": vm.get("name", "unknown"),
+                    "status": "running" if vm.get("status") == "online" else "idle",
+                    "current_task": vm.get("current_task"),
+                    "tasks_completed": vm.get("tasks_completed", 0),
+                    "uptime_seconds": vm.get("uptime_seconds", 0),
+                    "idle_time_seconds": vm.get("idle_time_seconds", 0),
+                }
+            )
+
+        return {
+            "total_workers": len(vms),
+            "active_workers": len(active_workers),
+            "idle_workers": len(idle_workers),
+            "workers": workers,
+        }
+
+    def _get_benchmark_runs(self) -> list[dict]:
+        """Load all benchmark runs from benchmark_results directory.
+
+        Returns:
+            List of benchmark run summaries sorted by timestamp (newest first)
+        """
+        results_dir = Path("benchmark_results")
+        if not results_dir.exists():
+            return []
+
+        runs = []
+        for run_dir in results_dir.iterdir():
+            if run_dir.is_dir():
+                summary_file = run_dir / "summary.json"
+                if summary_file.exists():
+                    try:
+                        summary = json.loads(summary_file.read_text())
+                        runs.append(summary)
+                    except (json.JSONDecodeError, IOError) as e:
+                        print(f"Warning: Failed to load {summary_file}: {e}")
+
+        # Sort by run_name descending (newest first)
+        runs.sort(key=lambda r: r.get("run_name", ""), reverse=True)
+        return runs
+
+    def _get_task_execution(self, run_name: str, task_id: str) -> dict:
+        """Load task execution details from execution.json.
+
+        Args:
+            run_name: Name of the benchmark run
+            task_id: Task identifier
+
+        Returns:
+            Task execution data with steps and screenshots
+        """
+        results_dir = Path("benchmark_results")
+        execution_file = results_dir / run_name / "tasks" / task_id / "execution.json"
+
+        if not execution_file.exists():
+            raise FileNotFoundError(f"Execution file not found: {execution_file}")
+
+        try:
+            return json.loads(execution_file.read_text())
+        except (json.JSONDecodeError, IOError) as e:
+            raise RuntimeError(f"Failed to load execution data: {e}") from e
+
+    async def _detect_running_benchmark(
+        self, vm_ip: str, container_name: str = "winarena"
+    ) -> dict:
+        """Detect if a benchmark is running on the VM and extract progress.
+
+        SSH into the VM and check:
+        1. Process running: docker exec {container} pgrep -f 'python.*run.py'
+        2. Log progress: tail /tmp/waa_benchmark.log
+
+        Returns:
+            dict with:
+            - running: bool
+            - current_task: str (task ID or description)
+            - progress: dict with tasks_completed, total_tasks, current_step
+            - recent_logs: str (last few log lines)
+        """
+        import subprocess
+        import re
+
+        result = {
+            "running": False,
+            "current_task": None,
+            "progress": {
+                "tasks_completed": 0,
+                "total_tasks": 0,
+                "current_step": 0,
+            },
+            "recent_logs": "",
+        }
+
+        try:
+            # Check if benchmark process is running
+            process_check = subprocess.run(
+                ["ssh",
+                 "-o", "StrictHostKeyChecking=no",
+                 "-o", "ConnectTimeout=5",
+                 "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                 f"azureuser@{vm_ip}",
+                 f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''"],
+                capture_output=True, text=True, timeout=10,
+            )
+
+            if process_check.returncode == 0 and process_check.stdout.strip():
+                result["running"] = True
+
+                # Get benchmark log
+                log_check = subprocess.run(
+                    ["ssh",
+                     "-o", "StrictHostKeyChecking=no",
+                     "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
+                    capture_output=True, text=True, timeout=10,
+                )
+
+                if log_check.returncode == 0 and log_check.stdout.strip():
+                    logs = log_check.stdout
+                    result["recent_logs"] = logs[-500:]  # Last 500 chars
+
+                    # Parse progress from logs
+                    # Look for patterns like "Task 5/30" or "Completed: 5, Remaining: 25"
+                    task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
+                    if task_match:
+                        result["progress"]["tasks_completed"] = int(task_match.group(1))
+                        result["progress"]["total_tasks"] = int(task_match.group(2))
+
+                    # Extract current task ID
+                    task_id_match = re.search(r"(?:Running|Processing) task:\s*(\S+)", logs)
+                    if task_id_match:
+                        result["current_task"] = task_id_match.group(1)
+
+                    # Extract step info
+                    step_match = re.search(r"Step\s+(\d+)", logs)
+                    if step_match:
+                        result["progress"]["current_step"] = int(step_match.group(1))
+
+        except Exception:
+            # SSH or parsing failed - leave defaults
+            pass
+
+        return result
+
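The regexes above against a representative log excerpt (the log lines are illustrative, matching the documented patterns):

import re

logs = "Task 5/30\nRunning task: notepad_easy_003\nStep 12"
print(re.search(r"Task\s+(\d+)/(\d+)", logs).groups())                    # ('5', '30')
print(re.search(r"(?:Running|Processing) task:\s*(\S+)", logs).group(1))  # notepad_easy_003
print(re.search(r"Step\s+(\d+)", logs).group(1))                          # 12
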
+    def _parse_task_result(self, log_lines: list[str], task_id: str) -> dict:
+        """Parse task success/failure from log output.
+
+        WAA log patterns:
+        - Success: "Task task_001 completed successfully"
+        - Success: "Result: PASS"
+        - Failure: "Task task_001 failed"
+        - Failure: "Result: FAIL"
+        - Score: "Score: 0.85"
+        """
+        import re
+
+        success = None
+        score = None
+
+        # Search backwards from most recent; the first match wins for each field
+        for line in reversed(log_lines):
+            # Check for explicit result
+            if success is None:
+                if "Result: PASS" in line or "completed successfully" in line:
+                    success = True
+                elif "Result: FAIL" in line or "failed" in line.lower():
+                    success = False
+
+            # Check for score
+            if score is None:
+                score_match = re.search(r"Score:\s*([\d.]+)", line)
+                if score_match:
+                    try:
+                        score = float(score_match.group(1))
+                    except ValueError:
+                        pass
+
+            # Check for task-specific completion
+            if success is None and task_id in line:
+                if "success" in line.lower() or "pass" in line.lower():
+                    success = True
+                elif "fail" in line.lower() or "error" in line.lower():
+                    success = False
+
+        # Default to True if no explicit failure found (backwards compatible)
+        if success is None:
+            success = True
+
+        return {"success": success, "score": score}
+
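A quick check of the parser against sample lines (illustrative; `viewer` stands in for a handler instance):

log_lines = [
    "Running task: task_001",
    "Score: 0.85",
    "Task task_001 completed successfully",
]
# Scanning newest-first, the parser above yields:
#   viewer._parse_task_result(log_lines, "task_001")
#   -> {"success": True, "score": 0.85}
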
+    def _stream_benchmark_updates(self, interval: int):
+        """Stream Server-Sent Events for benchmark status updates.
+
+        Streams events:
+        - connected: Initial connection event
+        - status: VM status and probe results
+        - progress: Benchmark progress (tasks completed, current task)
+        - task_complete: When a task finishes
+        - heartbeat: Keep-alive signal every 30 seconds
+        - error: Error messages
+
+        Uses a generator-based approach to avoid blocking the main thread
+        and properly handles client disconnection.
+        """
+        import time
+        import select
+
+        HEARTBEAT_INTERVAL = 30  # seconds
+
+        # Set SSE headers
+        self.send_response(200)
+        self.send_header("Content-Type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.send_header("Access-Control-Allow-Origin", "*")
+        self.send_header("Connection", "keep-alive")
+        self.send_header("X-Accel-Buffering", "no")  # Disable nginx buffering
+        self.end_headers()
+
+        # Track connection state
+        client_connected = True
+
+        def send_event(event_type: str, data: dict) -> bool:
+            """Send an SSE event. Returns False if client disconnected."""
+            nonlocal client_connected
+            if not client_connected:
+                return False
+            try:
+                event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+                self.wfile.write(event_str.encode("utf-8"))
+                self.wfile.flush()
+                return True
+            except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+                # Client disconnected
+                client_connected = False
+                return False
+            except Exception as e:
+                # Other error - log and assume disconnected
+                print(f"SSE send error: {e}")
+                client_connected = False
+                return False
+
+        def check_client_connected() -> bool:
+            """Check if client is still connected using socket select."""
+            nonlocal client_connected
+            if not client_connected:
+                return False
+            try:
+                # Check if the socket has data (client sent something or closed).
+                # Use a non-blocking check with 0 timeout.
+                rlist, _, xlist = select.select([self.rfile], [], [self.rfile], 0)
+                if xlist:
+                    # Error condition on socket
+                    client_connected = False
+                    return False
+                if rlist:
+                    # Client sent data - for SSE this usually means disconnect
+                    # (SSE is server-push only, the client doesn't send data)
+                    data = self.rfile.read(1)
+                    if not data:
+                        client_connected = False
+                        return False
+                return True
+            except Exception:
+                client_connected = False
+                return False
+
+        last_task = None
+        last_heartbeat = time.time()
+        recent_log_lines = []
+
+        # Send initial connected event
+        if not send_event(
+            "connected",
+            {"timestamp": time.time(), "interval": interval, "version": "1.0"},
+        ):
+            return
+
+        try:
+            iteration_count = 0
+            max_iterations = 3600 // interval  # Max 1 hour of streaming
+
+            while client_connected and iteration_count < max_iterations:
+                iteration_count += 1
+                current_time = time.time()
+
+                # Check client connection before doing work
+                if not check_client_connected():
+                    break
+
+                # Send heartbeat every 30 seconds to prevent proxy/LB timeouts
+                if current_time - last_heartbeat >= HEARTBEAT_INTERVAL:
+                    if not send_event("heartbeat", {"timestamp": current_time}):
+                        break
+                    last_heartbeat = current_time
+
+                # Fetch background tasks (includes VM status)
+                tasks = self._fetch_background_tasks()
+
+                # Send VM status event
+                vm_task = next(
+                    (t for t in tasks if t.get("task_type") == "docker_container"),
+                    None,
+                )
+                if vm_task:
+                    vm_data = {
+                        "type": "vm_status",
+                        "connected": vm_task.get("status") in ["running", "completed"],
+                        "phase": vm_task.get("phase", "unknown"),
+                        "waa_ready": vm_task.get("metadata", {}).get("waa_server_ready", False),
+                        "probe": {
+                            "status": vm_task.get("metadata", {}).get("probe_response", "unknown"),
+                            "vnc_url": vm_task.get("metadata", {}).get("vnc_url"),
+                        },
+                    }
+
+                    if not send_event("status", vm_data):
+                        break
+
+                    # If VM is ready, check for running benchmark
+                    if vm_data["waa_ready"]:
+                        # Get VM IP from tasks
+                        vm_ip = None
+                        azure_vm = next(
+                            (t for t in tasks if t.get("task_type") == "vm_provision"),
+                            None,
+                        )
+                        if azure_vm:
+                            vm_ip = azure_vm.get("metadata", {}).get("ip_address")
+
+                        if vm_ip:
+                            # Detect running benchmark using the sync version
+                            benchmark_status = self._detect_running_benchmark_sync(
+                                vm_ip,
+                                vm_task.get("metadata", {}).get("container", "winarena"),
+                            )
+
+                            if benchmark_status["running"]:
+                                # Store log lines for result parsing
+                                if benchmark_status.get("recent_logs"):
+                                    recent_log_lines = benchmark_status["recent_logs"].split("\n")
+
+                                # Send progress event
+                                progress_data = {
+                                    "tasks_completed": benchmark_status["progress"]["tasks_completed"],
+                                    "total_tasks": benchmark_status["progress"]["total_tasks"],
+                                    "current_task": benchmark_status["current_task"],
+                                    "current_step": benchmark_status["progress"]["current_step"],
+                                }
+
+                                if not send_event("progress", progress_data):
+                                    break
+
+                                # Check if task completed
+                                current_task = benchmark_status["current_task"]
+                                if current_task and current_task != last_task:
+                                    if last_task is not None:
+                                        # Previous task completed - parse result from logs
+                                        result = self._parse_task_result(recent_log_lines, last_task)
+                                        complete_data = {
+                                            "task_id": last_task,
+                                            "success": result["success"],
+                                            "score": result["score"],
+                                        }
+                                        if not send_event("task_complete", complete_data):
+                                            break
+
+                                    last_task = current_task
+
+                # Check local benchmark progress file
+                progress_file = Path("benchmark_progress.json")
+                if progress_file.exists():
+                    try:
+                        progress = json.loads(progress_file.read_text())
+                        if progress.get("status") == "running":
+                            progress_data = {
+                                "tasks_completed": progress.get("tasks_complete", 0),
+                                "total_tasks": progress.get("tasks_total", 0),
+                                "current_task": progress.get("provider", "unknown"),
+                            }
+                            if not send_event("progress", progress_data):
+                                break
+                    except Exception:
+                        pass
+
+                # Non-blocking sleep using select with timeout
+                # This allows checking for client disconnect during sleep
+                try:
+                    select.select([self.rfile], [], [], interval)
+                except Exception:
+                    break
+
+        except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+            # Client disconnected - this is normal, don't log as error
+            pass
+        except Exception as e:
+            # Send error event if still connected
+            send_event("error", {"message": str(e)})
+        finally:
+            # Cleanup - connection is ending
+            client_connected = False
+
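A minimal client for this stream; the port and URL path are assumptions (adjust to the route the viewer actually registers):

import json
import urllib.request

# Minimal SSE reader: prints each event emitted by _stream_benchmark_updates.
with urllib.request.urlopen("http://localhost:8080/api/benchmark/stream") as resp:
    event = None
    for raw in resp:
        line = raw.decode("utf-8").rstrip("\n")
        if line.startswith("event: "):
            event = line[len("event: "):]
        elif line.startswith("data: "):
            payload = json.loads(line[len("data: "):])
            print(event, payload)
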
+    def _stream_azure_ops_updates(self):
+        """Stream Server-Sent Events for Azure operations status updates.
+
+        Monitors azure_ops_status.json for changes and streams updates.
+        Uses file modification time to detect changes efficiently.
+
+        Streams events:
+        - connected: Initial connection event
+        - status: Azure ops status update when file changes
+        - heartbeat: Keep-alive signal every 30 seconds
+        - error: Error messages
+        """
+        import time
+        import select
+        from pathlib import Path
+
+        HEARTBEAT_INTERVAL = 30  # seconds
+        CHECK_INTERVAL = 1  # Check file every second
+
+        # Set SSE headers
+        self.send_response(200)
+        self.send_header("Content-Type", "text/event-stream")
+        self.send_header("Cache-Control", "no-cache")
+        self.send_header("Access-Control-Allow-Origin", "*")
+        self.send_header("Connection", "keep-alive")
+        self.send_header("X-Accel-Buffering", "no")  # Disable nginx buffering
+        self.end_headers()
+
+        # Track connection state
+        client_connected = True
+        last_mtime = 0.0
+        last_session_mtime = 0.0
+        last_heartbeat = time.time()
+
+        def send_event(event_type: str, data: dict) -> bool:
+            """Send an SSE event. Returns False if client disconnected."""
+            nonlocal client_connected
+            if not client_connected:
+                return False
+            try:
+                event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+                self.wfile.write(event_str.encode("utf-8"))
+                self.wfile.flush()
+                return True
+            except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+                client_connected = False
+                return False
+            except Exception as e:
+                print(f"Azure ops SSE send error: {e}")
+                client_connected = False
+                return False
+
+        def check_client_connected() -> bool:
+            """Check if client is still connected using socket select."""
+            nonlocal client_connected
+            if not client_connected:
+                return False
+            try:
+                rlist, _, xlist = select.select([self.rfile], [], [self.rfile], 0)
+                if xlist:
+                    client_connected = False
+                    return False
+                if rlist:
+                    data = self.rfile.read(1)
+                    if not data:
+                        client_connected = False
+                        return False
+                return True
+            except Exception:
+                client_connected = False
+                return False
+
+        # Status file path
+        from openadapt_ml.benchmarks.azure_ops_tracker import (
+            DEFAULT_OUTPUT_FILE,
+            read_status,
+        )
+        from openadapt_ml.benchmarks.session_tracker import (
+            get_session,
+            update_session_vm_state,
+            DEFAULT_SESSION_FILE,
+        )
+
+        status_file = Path(DEFAULT_OUTPUT_FILE)
+        session_file = Path(DEFAULT_SESSION_FILE)
+
+        def compute_server_side_values(status: dict) -> dict:
+            """Get elapsed_seconds and cost_usd from session tracker for persistence."""
+            # Get session data (persistent across refreshes)
+            session = get_session()
+
+            # Update session based on VM state if we have VM info
+            # IMPORTANT: Only pass vm_ip if it's truthy to avoid
+            # overwriting session's stable vm_ip with None
+            if status.get("vm_state") and status.get("vm_state") != "unknown":
+                status_vm_ip = status.get("vm_ip")
+                # Build update kwargs - only include vm_ip if present
+                update_kwargs = {
+                    "vm_state": status["vm_state"],
+                    "vm_size": status.get("vm_size"),
+                }
+                if status_vm_ip:  # Only include if truthy
+                    update_kwargs["vm_ip"] = status_vm_ip
+                session = update_session_vm_state(**update_kwargs)
+
+            # Use session's vm_ip as authoritative source
+            # This prevents IP flickering when status file has stale/None values
+            if session.get("vm_ip"):
+                status["vm_ip"] = session["vm_ip"]
+
+            # Use session's elapsed_seconds and cost_usd for persistence
+            if session.get("is_active") or session.get("accumulated_seconds", 0) > 0:
+                status["elapsed_seconds"] = session.get("elapsed_seconds", 0.0)
+                status["cost_usd"] = session.get("cost_usd", 0.0)
+                status["started_at"] = session.get("started_at")
+                status["session_id"] = session.get("session_id")
+                status["session_is_active"] = session.get("is_active", False)
+                # Include accumulated time from previous sessions for hybrid display
+                status["accumulated_seconds"] = session.get("accumulated_seconds", 0.0)
+                # Calculate current session time (total - accumulated)
+                current_session_seconds = max(
+                    0, status["elapsed_seconds"] - status["accumulated_seconds"]
+                )
+                status["current_session_seconds"] = current_session_seconds
+                hourly_rate = session.get("hourly_rate_usd", 0.422)
+                status["current_session_cost_usd"] = (
+                    current_session_seconds / 3600
+                ) * hourly_rate
+
+            try:
+                tunnel_mgr = get_tunnel_manager()
+                tunnel_status = tunnel_mgr.get_tunnel_status()
+                status["tunnels"] = {
+                    name: {
+                        "active": s.active,
+                        "local_port": s.local_port,
+                        "remote_endpoint": s.remote_endpoint,
+                        "pid": s.pid,
+                        "error": s.error,
+                    }
+                    for name, s in tunnel_status.items()
+                }
+            except Exception as e:
+                status["tunnels"] = {"error": str(e)}
+
+            return status
+
+        # Send initial connected event
+        if not send_event(
+            "connected",
+            {"timestamp": time.time(), "version": "1.0"},
+        ):
+            return
+
+        # Send initial status immediately
+        try:
+            status = compute_server_side_values(read_status())
+            if not send_event("status", status):
+                return
+            if status_file.exists():
+                last_mtime = status_file.stat().st_mtime
+        except Exception as e:
+            send_event("error", {"message": str(e)})
+
+        try:
+            iteration_count = 0
+            max_iterations = 3600  # Max 1 hour of streaming
+            last_status_send = 0.0
+            STATUS_SEND_INTERVAL = 2  # Send status every 2 seconds for live updates
+
+            while client_connected and iteration_count < max_iterations:
+                iteration_count += 1
+                current_time = time.time()
+
+                # Check client connection
+                if not check_client_connected():
+                    break
+
+                # Send heartbeat every 30 seconds
+                if current_time - last_heartbeat >= HEARTBEAT_INTERVAL:
+                    if not send_event("heartbeat", {"timestamp": current_time}):
+                        break
+                    last_heartbeat = current_time
+
+                # Check if status or session file changed OR if enough time passed
+                try:
+                    status_changed = False
+                    session_changed = False
+                    time_to_send = current_time - last_status_send >= STATUS_SEND_INTERVAL
+
+                    if status_file.exists():
+                        current_mtime = status_file.stat().st_mtime
+                        if current_mtime > last_mtime:
+                            status_changed = True
+                            last_mtime = current_mtime
+
+                    if session_file.exists():
+                        current_session_mtime = session_file.stat().st_mtime
+                        if current_session_mtime > last_session_mtime:
+                            session_changed = True
+                            last_session_mtime = current_session_mtime
+
+                    # Send status if a file changed OR the periodic timer expired
+                    # This ensures live elapsed time/cost updates even without file changes
+                    if status_changed or session_changed or time_to_send:
+                        status = compute_server_side_values(read_status())
+                        if not send_event("status", status):
+                            break
+                        last_status_send = current_time
+                except Exception as e:
+                    # File access error - log but continue
+                    print(f"Azure ops SSE file check error: {e}")
+
+                # Sleep briefly before next check
+                try:
+                    select.select([self.rfile], [], [], CHECK_INTERVAL)
+                except Exception:
+                    break
+
+        except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+            # Client disconnected - normal
+            pass
+        except Exception as e:
+            send_event("error", {"message": str(e)})
+        finally:
+            client_connected = False
+
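The hybrid session accounting in compute_server_side_values, as arithmetic (the rate is the default from the code above; other values illustrative):

elapsed_seconds = 5400.0        # total across sessions, from the session tracker
accumulated_seconds = 3600.0    # carried over from previous sessions
current_session_seconds = max(0, elapsed_seconds - accumulated_seconds)       # 1800.0
hourly_rate_usd = 0.422
current_session_cost_usd = (current_session_seconds / 3600) * hourly_rate_usd  # 0.211
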
+    def _detect_running_benchmark_sync(
+        self, vm_ip: str, container_name: str = "winarena"
+    ) -> dict:
+        """Synchronous version of _detect_running_benchmark.
+
+        Avoids creating a new event loop on each call, which causes issues
+        when called from a synchronous context.
+        """
+        import subprocess
+        import re
+
+        result = {
+            "running": False,
+            "current_task": None,
+            "progress": {
+                "tasks_completed": 0,
+                "total_tasks": 0,
+                "current_step": 0,
+            },
+            "recent_logs": "",
+        }
+
+        try:
+            # Check if benchmark process is running
+            process_check = subprocess.run(
+                ["ssh",
+                 "-o", "StrictHostKeyChecking=no",
+                 "-o", "ConnectTimeout=5",
+                 "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                 f"azureuser@{vm_ip}",
+                 f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''"],
+                capture_output=True, text=True, timeout=10,
+            )
+
+            if process_check.returncode == 0 and process_check.stdout.strip():
+                result["running"] = True
+
+                # Get benchmark log
+                log_check = subprocess.run(
+                    ["ssh",
+                     "-o", "StrictHostKeyChecking=no",
+                     "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
+                    capture_output=True, text=True, timeout=10,
+                )
+
+                if log_check.returncode == 0 and log_check.stdout.strip():
+                    logs = log_check.stdout
+                    result["recent_logs"] = logs[-500:]  # Last 500 chars
+
+                    # Parse progress from logs
+                    task_match = re.search(r"Task\s+(\d+)/(\d+)", logs)
+                    if task_match:
+                        result["progress"]["tasks_completed"] = int(task_match.group(1))
+                        result["progress"]["total_tasks"] = int(task_match.group(2))
+
+                    # Extract current task ID
+                    task_id_match = re.search(r"(?:Running|Processing) task:\s*(\S+)", logs)
+                    if task_id_match:
+                        result["current_task"] = task_id_match.group(1)
+
+                    # Extract step info
+                    step_match = re.search(r"Step\s+(\d+)", logs)
+                    if step_match:
+                        result["progress"]["current_step"] = int(step_match.group(1))
+
+        except Exception:
+            # SSH or parsing failed - leave defaults
+            pass
+
+        return result
+
+    def _register_vm(self, vm_data):
+        """Register a new VM in the registry."""
+        # Path to VM registry file (relative to project root)
+        project_root = Path(__file__).parent.parent.parent
+        registry_file = project_root / "benchmark_results" / "vm_registry.json"
+
+        # Load existing registry
+        vms = []
+        if registry_file.exists():
+            try:
+                with open(registry_file) as f:
+                    vms = json.load(f)
+            except Exception:
+                pass
+
+        # Add new VM
+        new_vm = {
+            "name": vm_data.get("name", "unnamed-vm"),
+            "ssh_host": vm_data.get("ssh_host", ""),
+            "ssh_user": vm_data.get("ssh_user", "azureuser"),
+            "vnc_port": vm_data.get("vnc_port", 8006),
+            "waa_port": vm_data.get("waa_port", 5000),
+            "docker_container": vm_data.get("docker_container", "win11-waa"),
+            "internal_ip": vm_data.get("internal_ip", "20.20.20.21"),
+        }
+
+        vms.append(new_vm)
+
+        # Save registry
+        try:
+            registry_file.parent.mkdir(parents=True, exist_ok=True)
+            with open(registry_file, "w") as f:
+                json.dump(vms, f, indent=2)
+            return {"status": "success", "vm": new_vm}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
+
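A sketch of registering a VM through this helper; payload fields mirror the defaults above, values are illustrative, and `viewer` stands in for a handler instance:

payload = {
    "name": "win11-waa-01",
    "ssh_host": "20.115.30.7",
    "ssh_user": "azureuser",
    "vnc_port": 8006,
    "waa_port": 5000,
    "docker_container": "win11-waa",
    "internal_ip": "20.20.20.21",
}
# viewer._register_vm(payload)
# -> {"status": "success", "vm": {...}} on success,
#    {"status": "error", "message": "..."} if the registry file can't be written.
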
+    def _start_benchmark_run(self, params: dict) -> dict:
+        """Start a benchmark run with the given parameters.
+
+        Runs the benchmark in a background thread and returns immediately.
+        Progress is tracked via benchmark_progress.json.
+
+        Expected params:
+            {
+                "model": "gpt-4o",
+                "num_tasks": 5,
+                "agent": "navi",
+                "task_selection": "all" | "domain" | "task_ids",
+                "domain": "general",  # if task_selection == "domain"
+                "task_ids": ["task_001", "task_015"]  # if task_selection == "task_ids"
+            }
+
+        Returns:
+            dict with status and params
+        """
+        import subprocess  # used by the background run() thread below
+
+        from dotenv import load_dotenv
+
+        # Load .env file for API keys
+        project_root = Path(__file__).parent.parent.parent
+        load_dotenv(project_root / ".env")
+
+        # Build CLI command
+        cmd = [
+            "uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli",
+            "vm", "run-waa",
+            "--num-tasks", str(params.get("num_tasks", 5)),
+            "--model", params.get("model", "gpt-4o"),
+            "--agent", params.get("agent", "navi"),
+            "--no-open",  # Don't open viewer (already open)
+        ]
+
+        # Add task selection args
+        task_selection = params.get("task_selection", "all")
+        if task_selection == "domain":
+            domain = params.get("domain", "general")
+            cmd.extend(["--domain", domain])
+        elif task_selection == "task_ids":
+            task_ids = params.get("task_ids", [])
+            if task_ids:
+                cmd.extend(["--task-ids", ",".join(task_ids)])
+
+        # Create progress log file (in cwd, which is serve_dir)
+        progress_file = Path("benchmark_progress.json")
+
+        # Write initial progress
+        model = params.get("model", "gpt-4o")
+        num_tasks = params.get("num_tasks", 5)
+        agent = params.get("agent", "navi")
+
+        print(f"\n[Benchmark] Starting WAA benchmark: model={model}, tasks={num_tasks}, agent={agent}")
+        print(f"[Benchmark] Task selection: {task_selection}")
+        if task_selection == "domain":
+            print(f"[Benchmark] Domain: {params.get('domain', 'general')}")
+        elif task_selection == "task_ids":
+            print(f"[Benchmark] Task IDs: {params.get('task_ids', [])}")
+        print(f"[Benchmark] Command: {' '.join(cmd)}")
+
+        progress_file.write_text(
+            json.dumps(
+                {
+                    "status": "running",
+                    "model": model,
+                    "num_tasks": num_tasks,
+                    "agent": agent,
+                    "task_selection": task_selection,
+                    "tasks_complete": 0,
+                    "message": f"Starting {model} benchmark with {num_tasks} tasks...",
+                }
+            )
+        )
+
+        # Copy environment with loaded vars
+        env = os.environ.copy()
+
+        # Run in background thread
+        def run():
+            result = subprocess.run(
+                cmd, capture_output=True, text=True, cwd=str(project_root), env=env
+            )
+
+            print(f"\n[Benchmark] Output:\n{result.stdout}")
+            if result.stderr:
+                print(f"[Benchmark] Stderr: {result.stderr}")
+
+            if result.returncode == 0:
+                print("[Benchmark] Complete. Regenerating viewer...")
+                progress_file.write_text(
+                    json.dumps(
+                        {
+                            "status": "complete",
+                            "model": model,
+                            "num_tasks": num_tasks,
+                            "message": "Benchmark complete. Refresh to see results.",
+                        }
+                    )
+                )
+                # Regenerate benchmark viewer
+                _regenerate_benchmark_viewer_if_available(serve_dir)
+            else:
+                error_msg = result.stderr[:200] if result.stderr else "Unknown error"
+                print(f"[Benchmark] Failed: {error_msg}")
+                progress_file.write_text(
+                    json.dumps(
3585
+ {
3586
+ "status": "error",
3587
+ "model": model,
3588
+ "num_tasks": num_tasks,
3589
+ "message": f"Benchmark failed: {error_msg}",
3590
+ }
3591
+ )
3592
+ )
3593
+
3594
+ threading.Thread(target=run, daemon=True).start()
3595
+
3596
+ return {"status": "started", "params": params}
+
         def do_OPTIONS(self):
             # Handle CORS preflight
             self.send_response(200)
-            self.send_header('Access-Control-Allow-Origin', '*')
-            self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
-            self.send_header('Access-Control-Allow-Headers', 'Content-Type')
+            self.send_header("Access-Control-Allow-Origin", "*")
+            self.send_header("Access-Control-Allow-Methods", "POST, OPTIONS")
+            self.send_header("Access-Control-Allow-Headers", "Content-Type")
             self.end_headers()
 
-    with socketserver.TCPServer(("", port), StopHandler) as httpd:
+    class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
+        allow_reuse_address = True
+        daemon_threads = True  # Don't block shutdown
+
+    with ThreadedTCPServer(("", port), StopHandler) as httpd:
         url = f"http://localhost:{port}/{start_page}"
         print(f"\nServing at: {url}")
         print(f"Directory: {serve_dir}")
@@ -611,7 +3654,40 @@ def cmd_viewer(args: argparse.Namespace) -> int:
         state.learning_rate = data.get("learning_rate", 0)
         state.losses = data.get("losses", [])
         state.status = data.get("status", "completed")
-        state.elapsed_time = data.get("elapsed_time", 0.0)  # Load elapsed time for completed training
+        state.elapsed_time = data.get(
+            "elapsed_time", 0.0
+        )  # Load elapsed time for completed training
+        state.goal = data.get("goal", "")
+        state.config_path = data.get("config_path", "")
+        state.capture_path = data.get("capture_path", "")
+
+        # Load model config from training_log.json or fall back to reading config file
+        state.model_name = data.get("model_name", "")
+        state.lora_r = data.get("lora_r", 0)
+        state.lora_alpha = data.get("lora_alpha", 0)
+        state.load_in_4bit = data.get("load_in_4bit", False)
+
+        # If model config not in JSON, try to read from config file
+        if not state.model_name and state.config_path:
+            try:
+                import yaml
+
+                # Try relative to project root first, then as absolute path
+                project_root = Path(__file__).parent.parent.parent
+                config_file = project_root / state.config_path
+                if not config_file.exists():
+                    config_file = Path(state.config_path)
+                if config_file.exists():
+                    with open(config_file) as cf:
+                        cfg = yaml.safe_load(cf)
+                    if cfg and "model" in cfg:
+                        state.model_name = cfg["model"].get("name", "")
+                        state.load_in_4bit = cfg["model"].get("load_in_4bit", False)
+                    if cfg and "lora" in cfg:
+                        state.lora_r = cfg["lora"].get("r", 0)
+                        state.lora_alpha = cfg["lora"].get("lora_alpha", 0)
+            except Exception as e:
+                print(f" Warning: Could not read config file: {e}")
 
         config = TrainingConfig(
             num_train_epochs=data.get("total_epochs", 5),
@@ -620,14 +3696,16 @@ def cmd_viewer(args: argparse.Namespace) -> int:
 
     dashboard_html = generate_training_dashboard(state, config)
     (current_dir / "dashboard.html").write_text(dashboard_html)
-    print(f" Regenerated: dashboard.html")
+    print(" Regenerated: dashboard.html")
 
     # Generate unified viewer using consolidated function
     viewer_path = generate_unified_viewer_from_output_dir(current_dir)
     if viewer_path:
         print(f"\nGenerated: {viewer_path}")
     else:
-        print("\nNo comparison data found. Run comparison first or copy from capture directory.")
+        print(
+            "\nNo comparison data found. Run comparison first or copy from capture directory."
+        )
 
     # Also regenerate benchmark viewer from latest benchmark results
     _regenerate_benchmark_viewer_if_available(current_dir)
@@ -647,9 +3725,9 @@ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
         print(f"Error: Benchmark directory not found: {benchmark_dir}")
         return 1
 
-    print(f"\n{'='*50}")
+    print(f"\n{'=' * 50}")
     print("GENERATING BENCHMARK VIEWER")
-    print(f"{'='*50}")
+    print(f"{'=' * 50}")
     print(f"Benchmark dir: {benchmark_dir}")
     print()
 
@@ -664,6 +3742,7 @@ def cmd_benchmark_viewer(args: argparse.Namespace) -> int:
     except Exception as e:
         print(f"Error generating benchmark viewer: {e}")
         import traceback
+
         traceback.print_exc()
         return 1
 
@@ -680,16 +3759,19 @@ def cmd_compare(args: argparse.Namespace) -> int:
         print(f"Error: Checkpoint not found: {checkpoint}")
         return 1
 
-    print(f"\n{'='*50}")
+    print(f"\n{'=' * 50}")
     print("RUNNING COMPARISON")
-    print(f"{'='*50}")
+    print(f"{'=' * 50}")
     print(f"Capture: {capture_path}")
     print(f"Checkpoint: {checkpoint or 'None (capture only)'}")
     print()
 
     cmd = [
-        sys.executable, "-m", "openadapt_ml.scripts.compare",
-        "--capture", str(capture_path),
+        sys.executable,
+        "-m",
+        "openadapt_ml.scripts.compare",
+        "--capture",
+        str(capture_path),
     ]
 
     if checkpoint:
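Building cmd as an argv list rather than a shell string means every element reaches the child process verbatim, so capture paths containing spaces need no quoting. A tiny demonstration of that property (the child program here simply echoes back its argv):

    import subprocess
    import sys

    # "/tmp/my capture" stays one argument; no shell is involved.
    cmd = [
        sys.executable,
        "-c",
        "import sys; print(sys.argv[1:])",
        "--capture",
        "/tmp/my capture",
    ]
    subprocess.run(cmd)  # prints: ['--capture', '/tmp/my capture']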
@@ -728,7 +3810,7 @@ Examples:
 
     # Run comparison
     uv run python -m openadapt_ml.cloud.local compare --capture ~/captures/my-workflow --checkpoint checkpoints/model
-    """
+    """,
     )
     subparsers = parser.add_subparsers(dest="command", help="Command")
 
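Each subcommand registered below follows the set_defaults(func=...) dispatch idiom: the parser stores the handler on the parsed namespace, and the caller invokes it. The main() that performs the final dispatch is not part of this hunk, so the args.func(args) call in this sketch is the conventional completion of the pattern, not a quote of the package's code:

    import argparse


    def cmd_hello(args: argparse.Namespace) -> int:
        print(f"hello {args.name}")
        return 0


    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="command", help="Command")
    p_hello = subparsers.add_parser("hello", help="Say hello")
    p_hello.add_argument("--name", default="world")
    p_hello.set_defaults(func=cmd_hello)

    args = parser.parse_args(["hello", "--name", "openadapt"])
    raise SystemExit(args.func(args))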
@@ -740,9 +3822,15 @@ Examples:
     # train
     p_train = subparsers.add_parser("train", help="Run training locally")
     p_train.add_argument("--capture", required=True, help="Path to capture directory")
-    p_train.add_argument("--goal", help="Task goal (default: derived from capture name)")
-    p_train.add_argument("--config", help="Config file (default: auto-select based on device)")
-    p_train.add_argument("--open", action="store_true", help="Open dashboard in browser")
+    p_train.add_argument(
+        "--goal", help="Task goal (default: derived from capture name)"
+    )
+    p_train.add_argument(
+        "--config", help="Config file (default: auto-select based on device)"
+    )
+    p_train.add_argument(
+        "--open", action="store_true", help="Open dashboard in browser"
+    )
     p_train.set_defaults(func=cmd_train)
 
     # check
@@ -753,10 +3841,21 @@ Examples:
     p_serve = subparsers.add_parser("serve", help="Start web server for dashboard")
     p_serve.add_argument("--port", type=int, default=8765, help="Port number")
     p_serve.add_argument("--open", action="store_true", help="Open in browser")
-    p_serve.add_argument("--quiet", "-q", action="store_true", help="Suppress request logging")
-    p_serve.add_argument("--no-regenerate", action="store_true",
-                         help="Skip regenerating dashboard/viewer (serve existing files)")
-    p_serve.add_argument("--benchmark", help="Serve benchmark results directory instead of training output")
+    p_serve.add_argument(
+        "--quiet", "-q", action="store_true", help="Suppress request logging"
+    )
+    p_serve.add_argument(
+        "--no-regenerate",
+        action="store_true",
+        help="Skip regenerating dashboard/viewer (serve existing files)",
+    )
+    p_serve.add_argument(
+        "--benchmark",
+        help="Serve benchmark results directory instead of training output",
+    )
+    p_serve.add_argument(
+        "--start-page", help="Override default start page (e.g., benchmark.html)"
+    )
     p_serve.set_defaults(func=cmd_serve)
 
     # viewer
@@ -765,9 +3864,15 @@ Examples:
     p_viewer.set_defaults(func=cmd_viewer)
 
     # benchmark_viewer
-    p_benchmark = subparsers.add_parser("benchmark-viewer", help="Generate benchmark viewer")
-    p_benchmark.add_argument("benchmark_dir", help="Path to benchmark results directory")
-    p_benchmark.add_argument("--open", action="store_true", help="Open viewer in browser")
+    p_benchmark = subparsers.add_parser(
+        "benchmark-viewer", help="Generate benchmark viewer"
+    )
+    p_benchmark.add_argument(
+        "benchmark_dir", help="Path to benchmark results directory"
+    )
+    p_benchmark.add_argument(
+        "--open", action="store_true", help="Open viewer in browser"
+    )
     p_benchmark.set_defaults(func=cmd_benchmark_viewer)
 
     # compare