openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. openadapt_ml/benchmarks/__init__.py +8 -0
  2. openadapt_ml/benchmarks/agent.py +90 -11
  3. openadapt_ml/benchmarks/azure.py +35 -6
  4. openadapt_ml/benchmarks/cli.py +4449 -201
  5. openadapt_ml/benchmarks/live_tracker.py +180 -0
  6. openadapt_ml/benchmarks/runner.py +41 -4
  7. openadapt_ml/benchmarks/viewer.py +1219 -0
  8. openadapt_ml/benchmarks/vm_monitor.py +610 -0
  9. openadapt_ml/benchmarks/waa.py +61 -4
  10. openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
  11. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  12. openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
  13. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  14. openadapt_ml/benchmarks/waa_live.py +619 -0
  15. openadapt_ml/cloud/local.py +1555 -1
  16. openadapt_ml/cloud/ssh_tunnel.py +553 -0
  17. openadapt_ml/datasets/next_action.py +87 -68
  18. openadapt_ml/evals/grounding.py +26 -8
  19. openadapt_ml/evals/trajectory_matching.py +84 -36
  20. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  21. openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
  22. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  23. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  24. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  25. openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
  26. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  27. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  28. openadapt_ml/experiments/waa_demo/runner.py +717 -0
  29. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  30. openadapt_ml/export/__init__.py +9 -0
  31. openadapt_ml/export/__main__.py +6 -0
  32. openadapt_ml/export/cli.py +89 -0
  33. openadapt_ml/export/parquet.py +265 -0
  34. openadapt_ml/ingest/__init__.py +3 -4
  35. openadapt_ml/ingest/capture.py +89 -81
  36. openadapt_ml/ingest/loader.py +116 -68
  37. openadapt_ml/ingest/synthetic.py +221 -159
  38. openadapt_ml/retrieval/README.md +226 -0
  39. openadapt_ml/retrieval/USAGE.md +391 -0
  40. openadapt_ml/retrieval/__init__.py +91 -0
  41. openadapt_ml/retrieval/demo_retriever.py +817 -0
  42. openadapt_ml/retrieval/embeddings.py +629 -0
  43. openadapt_ml/retrieval/index.py +194 -0
  44. openadapt_ml/retrieval/retriever.py +160 -0
  45. openadapt_ml/runtime/policy.py +10 -10
  46. openadapt_ml/schema/__init__.py +104 -0
  47. openadapt_ml/schema/converters.py +541 -0
  48. openadapt_ml/schema/episode.py +457 -0
  49. openadapt_ml/scripts/compare.py +26 -16
  50. openadapt_ml/scripts/eval_policy.py +4 -5
  51. openadapt_ml/scripts/prepare_synthetic.py +14 -17
  52. openadapt_ml/scripts/train.py +81 -70
  53. openadapt_ml/training/benchmark_viewer.py +3225 -0
  54. openadapt_ml/training/trainer.py +120 -363
  55. openadapt_ml/training/trl_trainer.py +354 -0
  56. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
  57. openadapt_ml-0.2.0.dist-info/RECORD +86 -0
  58. openadapt_ml/schemas/__init__.py +0 -53
  59. openadapt_ml/schemas/sessions.py +0 -122
  60. openadapt_ml/schemas/validation.py +0 -252
  61. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  62. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
  63. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
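
The table above lists the 63 files that differ between the two wheels. For reference, a comparable per-file comparison can be reproduced locally with only the Python standard library; the sketch below assumes both wheels have already been downloaded into the working directory (for example with: pip download openadapt-ml==0.1.0 --no-deps).

import difflib
import zipfile

def wheel_diff(old_whl: str, new_whl: str) -> None:
    # Print a unified diff for every file present in both wheels.
    with zipfile.ZipFile(old_whl) as old, zipfile.ZipFile(new_whl) as new:
        shared = sorted(set(old.namelist()) & set(new.namelist()))
        for name in shared:
            a = old.read(name).decode("utf-8", errors="replace").splitlines()
            b = new.read(name).decode("utf-8", errors="replace").splitlines()
            for line in difflib.unified_diff(a, b, f"0.1.0/{name}", f"0.2.0/{name}", lineterm=""):
                print(line)

wheel_diff("openadapt_ml-0.1.0-py3-none-any.whl", "openadapt_ml-0.2.0-py3-none-any.whl")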
@@ -36,6 +36,8 @@ import webbrowser
  from pathlib import Path
  from typing import Any
 
+ from openadapt_ml.cloud.ssh_tunnel import get_tunnel_manager
+
  # Training output directory
  TRAINING_OUTPUT = Path("training_output")
 
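
The import added in this hunk is exercised later in this diff by the /api/tunnels and VM-registry handlers in cli.py. Pieced together only from those call sites, the tunnel-manager API looks roughly like the following sketch (the IP address and user below are placeholder values, not part of the package):

from openadapt_ml.cloud.ssh_tunnel import get_tunnel_manager

tunnel_mgr = get_tunnel_manager()

# Ensure local tunnels (VNC, WAA) exist for a given VM, as done when a VM probe succeeds.
tunnel_mgr.ensure_tunnels_for_vm(vm_ip="203.0.113.10", ssh_user="azureuser")

# Report tunnel state, as done by the /api/tunnels endpoint.
for name, s in tunnel_mgr.get_tunnel_status().items():
    print(name, s.active, s.local_port, s.remote_endpoint, s.pid, s.error)

# Tear all tunnels down, as done when a VM goes offline.
tunnel_mgr.stop_all_tunnels()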
@@ -143,6 +145,16 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
  # No real benchmark data - generate empty state viewer
  try:
  generate_empty_benchmark_viewer(benchmark_html_path)
+
+ # Still create symlink for azure_jobs.json access (even without real benchmarks)
+ if benchmark_results_dir.exists():
+ benchmark_results_link = output_dir / "benchmark_results"
+ if benchmark_results_link.is_symlink():
+ benchmark_results_link.unlink()
+ elif benchmark_results_link.exists():
+ shutil.rmtree(benchmark_results_link)
+ benchmark_results_link.symlink_to(benchmark_results_dir.absolute())
+
  print(" Generated benchmark viewer: No real evaluation data yet")
  return True
  except Exception as e:
@@ -168,6 +180,14 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
  tasks_dst = benchmark_tasks_dir / benchmark_dir.name
  shutil.copytree(tasks_src, tasks_dst)
 
+ # Create symlink for benchmark_results directory (for azure_jobs.json access)
+ benchmark_results_link = output_dir / "benchmark_results"
+ if benchmark_results_link.is_symlink():
+ benchmark_results_link.unlink()
+ elif benchmark_results_link.exists():
+ shutil.rmtree(benchmark_results_link)
+ benchmark_results_link.symlink_to(benchmark_results_dir.absolute())
+
  print(f" Regenerated benchmark viewer with {len(real_benchmarks)} run(s)")
  return True
  except Exception as e:
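
The two hunks above add the same replace-then-symlink sequence in two places. A hypothetical helper (not part of the released package) that captures the repeated pattern, shown here only to make the intent explicit:

import shutil
from pathlib import Path

def refresh_symlink(link: Path, target: Path) -> None:
    # Replace whatever currently occupies `link` with a symlink to `target`.
    if link.is_symlink():
        link.unlink()
    elif link.exists():
        shutil.rmtree(link)
    link.symlink_to(target.absolute())

# e.g. refresh_symlink(output_dir / "benchmark_results", benchmark_results_dir)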
@@ -438,6 +458,10 @@ def cmd_serve(args: argparse.Namespace) -> int:
 
  start_page = "dashboard.html"
 
+ # Override start page if specified
+ if hasattr(args, 'start_page') and args.start_page:
+ start_page = args.start_page
+
  # Serve from the specified directory
  os.chdir(serve_dir)
 
@@ -535,6 +559,42 @@ def cmd_serve(args: argparse.Namespace) -> int:
  }))
 
  threading.Thread(target=run_benchmark, daemon=True).start()
+ elif self.path == '/api/vms/register':
+ # Register a new VM
+ content_length = int(self.headers.get('Content-Length', 0))
+ body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+ try:
+ vm_data = json.loads(body)
+ result = self._register_vm(vm_data)
+ self.send_response(200)
+ self.send_header('Content-Type', 'application/json')
+ self.send_header('Access-Control-Allow-Origin', '*')
+ self.end_headers()
+ self.wfile.write(json.dumps(result).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header('Content-Type', 'application/json')
+ self.send_header('Access-Control-Allow-Origin', '*')
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
+ elif self.path == '/api/benchmark/start':
+ # Start a benchmark run with configurable parameters
+ content_length = int(self.headers.get('Content-Length', 0))
+ body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+ try:
+ params = json.loads(body)
+ result = self._start_benchmark_run(params)
+ self.send_response(200)
+ self.send_header('Content-Type', 'application/json')
+ self.send_header('Access-Control-Allow-Origin', '*')
+ self.end_headers()
+ self.wfile.write(json.dumps(result).encode())
+ except Exception as e:
+ self.send_response(500)
+ self.send_header('Content-Type', 'application/json')
+ self.send_header('Access-Control-Allow-Origin', '*')
+ self.end_headers()
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
  else:
  self.send_error(404, "Not found")
 
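
The hunk above adds two POST endpoints to the viewer's HTTP handler. Based on that handler code and the fields read by _register_vm later in this diff, a client call might look like the following sketch (host, port, and field values are assumptions; the serving port depends on how the serve command was started):

import json
import urllib.request

payload = {
    "name": "waa-eval-vm",
    "ssh_host": "203.0.113.10",
    "ssh_user": "azureuser",
    "vnc_port": 8006,
    "waa_port": 5000,
}
req = urllib.request.Request(
    "http://localhost:8000/api/vms/register",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))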
@@ -552,10 +612,1469 @@ def cmd_serve(args: argparse.Namespace) -> int:
552
612
  self.send_header('Access-Control-Allow-Origin', '*')
553
613
  self.end_headers()
554
614
  self.wfile.write(progress.encode())
615
+ elif self.path.startswith('/api/benchmark-live'):
616
+ # Return live evaluation state
617
+ live_file = Path("benchmark_live.json") # Relative to serve_dir (cwd)
618
+ if live_file.exists():
619
+ live_state = live_file.read_text()
620
+ else:
621
+ live_state = json.dumps({"status": "idle"})
622
+
623
+ self.send_response(200)
624
+ self.send_header('Content-Type', 'application/json')
625
+ self.send_header('Access-Control-Allow-Origin', '*')
626
+ self.end_headers()
627
+ self.wfile.write(live_state.encode())
628
+ elif self.path.startswith('/api/tasks'):
629
+ # Return background task status (VM, Docker, benchmarks)
630
+ try:
631
+ tasks = self._fetch_background_tasks()
632
+ self.send_response(200)
633
+ self.send_header('Content-Type', 'application/json')
634
+ self.send_header('Access-Control-Allow-Origin', '*')
635
+ self.end_headers()
636
+ self.wfile.write(json.dumps(tasks).encode())
637
+ except Exception as e:
638
+ self.send_response(500)
639
+ self.send_header('Content-Type', 'application/json')
640
+ self.send_header('Access-Control-Allow-Origin', '*')
641
+ self.end_headers()
642
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
643
+ elif self.path.startswith('/api/azure-jobs'):
644
+ # Return LIVE Azure job status from Azure ML
645
+ # Supports ?force=true parameter for manual refresh (always fetches live)
646
+ try:
647
+ from urllib.parse import urlparse, parse_qs
648
+ query = parse_qs(urlparse(self.path).query)
649
+ force_refresh = query.get('force', ['false'])[0].lower() == 'true'
650
+
651
+ # Always fetch live data (force just indicates manual refresh for logging)
652
+ if force_refresh:
653
+ print("Azure Jobs: Manual refresh requested")
654
+
655
+ jobs = self._fetch_live_azure_jobs()
656
+ self.send_response(200)
657
+ self.send_header('Content-Type', 'application/json')
658
+ self.send_header('Access-Control-Allow-Origin', '*')
659
+ self.end_headers()
660
+ self.wfile.write(json.dumps(jobs).encode())
661
+ except Exception as e:
662
+ self.send_response(500)
663
+ self.send_header('Content-Type', 'application/json')
664
+ self.send_header('Access-Control-Allow-Origin', '*')
665
+ self.end_headers()
666
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
667
+ elif self.path.startswith('/api/benchmark-sse'):
668
+ # Server-Sent Events endpoint for real-time benchmark updates
669
+ try:
670
+ from urllib.parse import urlparse, parse_qs
671
+ query = parse_qs(urlparse(self.path).query)
672
+ interval = int(query.get('interval', [5])[0])
673
+
674
+ # Validate interval (min 1s, max 60s)
675
+ interval = max(1, min(60, interval))
676
+
677
+ self._stream_benchmark_updates(interval)
678
+ except Exception as e:
679
+ self.send_error(500, f"SSE error: {e}")
680
+ elif self.path.startswith('/api/vms'):
681
+ # Return VM registry with live status
682
+ try:
683
+ vms = self._fetch_vm_registry()
684
+ self.send_response(200)
685
+ self.send_header('Content-Type', 'application/json')
686
+ self.send_header('Access-Control-Allow-Origin', '*')
687
+ self.end_headers()
688
+ self.wfile.write(json.dumps(vms).encode())
689
+ except Exception as e:
690
+ self.send_response(500)
691
+ self.send_header('Content-Type', 'application/json')
692
+ self.send_header('Access-Control-Allow-Origin', '*')
693
+ self.end_headers()
694
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
695
+ elif self.path.startswith('/api/azure-job-logs'):
696
+ # Return live logs for running Azure job
697
+ try:
698
+ # Parse job_id from query string
699
+ from urllib.parse import urlparse, parse_qs
700
+ query = parse_qs(urlparse(self.path).query)
701
+ job_id = query.get('job_id', [None])[0]
702
+
703
+ logs = self._fetch_azure_job_logs(job_id)
704
+ self.send_response(200)
705
+ self.send_header('Content-Type', 'application/json')
706
+ self.send_header('Access-Control-Allow-Origin', '*')
707
+ self.end_headers()
708
+ self.wfile.write(json.dumps(logs).encode())
709
+ except Exception as e:
710
+ self.send_response(500)
711
+ self.send_header('Content-Type', 'application/json')
712
+ self.send_header('Access-Control-Allow-Origin', '*')
713
+ self.end_headers()
714
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
715
+ elif self.path.startswith('/api/probe-vm'):
716
+ # Probe the VM to check if WAA server is responding
717
+ try:
718
+ result = self._probe_vm()
719
+ self.send_response(200)
720
+ self.send_header('Content-Type', 'application/json')
721
+ self.send_header('Access-Control-Allow-Origin', '*')
722
+ self.end_headers()
723
+ self.wfile.write(json.dumps(result).encode())
724
+ except Exception as e:
725
+ self.send_response(500)
726
+ self.send_header('Content-Type', 'application/json')
727
+ self.send_header('Access-Control-Allow-Origin', '*')
728
+ self.end_headers()
729
+ self.wfile.write(json.dumps({"error": str(e), "responding": False}).encode())
730
+ elif self.path.startswith('/api/tunnels'):
731
+ # Return SSH tunnel status
732
+ try:
733
+ tunnel_mgr = get_tunnel_manager()
734
+ status = tunnel_mgr.get_tunnel_status()
735
+ result = {
736
+ name: {
737
+ "active": s.active,
738
+ "local_port": s.local_port,
739
+ "remote_endpoint": s.remote_endpoint,
740
+ "pid": s.pid,
741
+ "error": s.error,
742
+ }
743
+ for name, s in status.items()
744
+ }
745
+ self.send_response(200)
746
+ self.send_header('Content-Type', 'application/json')
747
+ self.send_header('Access-Control-Allow-Origin', '*')
748
+ self.end_headers()
749
+ self.wfile.write(json.dumps(result).encode())
750
+ except Exception as e:
751
+ self.send_response(500)
752
+ self.send_header('Content-Type', 'application/json')
753
+ self.send_header('Access-Control-Allow-Origin', '*')
754
+ self.end_headers()
755
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
756
+ elif self.path.startswith('/api/current-run'):
757
+ # Return currently running benchmark info
758
+ try:
759
+ result = self._get_current_run()
760
+ self.send_response(200)
761
+ self.send_header('Content-Type', 'application/json')
762
+ self.send_header('Access-Control-Allow-Origin', '*')
763
+ self.end_headers()
764
+ self.wfile.write(json.dumps(result).encode())
765
+ except Exception as e:
766
+ self.send_response(500)
767
+ self.send_header('Content-Type', 'application/json')
768
+ self.send_header('Access-Control-Allow-Origin', '*')
769
+ self.end_headers()
770
+ self.wfile.write(json.dumps({"error": str(e), "running": False}).encode())
771
+ elif self.path.startswith('/api/background-tasks'):
772
+ # Alias for /api/tasks - background task status
773
+ try:
774
+ tasks = self._fetch_background_tasks()
775
+ self.send_response(200)
776
+ self.send_header('Content-Type', 'application/json')
777
+ self.send_header('Access-Control-Allow-Origin', '*')
778
+ self.end_headers()
779
+ self.wfile.write(json.dumps(tasks).encode())
780
+ except Exception as e:
781
+ self.send_response(500)
782
+ self.send_header('Content-Type', 'application/json')
783
+ self.send_header('Access-Control-Allow-Origin', '*')
784
+ self.end_headers()
785
+ self.wfile.write(json.dumps({"error": str(e)}).encode())
555
786
  else:
556
787
  # Default file serving
557
788
  super().do_GET()
558
789
 
790
+ def _fetch_live_azure_jobs(self):
791
+ """Fetch live job status from Azure ML."""
792
+ import subprocess
793
+ result = subprocess.run(
794
+ ["az", "ml", "job", "list",
795
+ "--resource-group", "openadapt-agents",
796
+ "--workspace-name", "openadapt-ml",
797
+ "--query", "[].{name:name,display_name:display_name,status:status,creation_context:creation_context.created_at}",
798
+ "-o", "json"],
799
+ capture_output=True, text=True, timeout=30
800
+ )
801
+ if result.returncode != 0:
802
+ raise Exception(f"Azure CLI error: {result.stderr}")
803
+
804
+ jobs = json.loads(result.stdout)
805
+ # Format for frontend
806
+ experiment_id = "ad29082c-0607-4fda-8cc7-38944eb5a518"
807
+ wsid = "/subscriptions/78add6c6-c92a-4a53-b751-eb644ac77e59/resourceGroups/openadapt-agents/providers/Microsoft.MachineLearningServices/workspaces/openadapt-ml"
808
+
809
+ formatted = []
810
+ for job in jobs[:10]: # Limit to 10 most recent
811
+ formatted.append({
812
+ "job_id": job.get("name", "unknown"),
813
+ "display_name": job.get("display_name", ""),
814
+ "status": job.get("status", "unknown").lower(),
815
+ "started_at": job.get("creation_context", ""),
816
+ "azure_dashboard_url": f"https://ml.azure.com/experiments/id/{experiment_id}/runs/{job.get('name', '')}?wsid={wsid}",
817
+ "is_live": True # Flag to indicate this is live data
818
+ })
819
+ return formatted
820
+
821
+ def _fetch_azure_job_logs(self, job_id: str | None):
822
+ """Fetch logs for an Azure ML job (streaming for running jobs)."""
823
+ import subprocess
824
+
825
+ if not job_id:
826
+ # Get the most recent running job
827
+ jobs = self._fetch_live_azure_jobs()
828
+ running = [j for j in jobs if j['status'] == 'running']
829
+ if running:
830
+ job_id = running[0]['job_id']
831
+ else:
832
+ return {"logs": "No running jobs found", "job_id": None, "status": "idle"}
833
+
834
+ # Try to stream logs for running job using az ml job stream
835
+ try:
836
+ result = subprocess.run(
837
+ ["az", "ml", "job", "stream",
838
+ "--name", job_id,
839
+ "--resource-group", "openadapt-agents",
840
+ "--workspace-name", "openadapt-ml"],
841
+ capture_output=True, text=True, timeout=3 # Short timeout
842
+ )
843
+ if result.returncode == 0 and result.stdout.strip():
844
+ return {"logs": result.stdout[-5000:], "job_id": job_id, "status": "streaming"}
845
+ except subprocess.TimeoutExpired:
846
+ pass # Fall through to job show
847
+
848
+ # Get job details instead
849
+ result = subprocess.run(
850
+ ["az", "ml", "job", "show",
851
+ "--name", job_id,
852
+ "--resource-group", "openadapt-agents",
853
+ "--workspace-name", "openadapt-ml",
854
+ "-o", "json"],
855
+ capture_output=True, text=True, timeout=10
856
+ )
857
+
858
+ if result.returncode == 0:
859
+ job_info = json.loads(result.stdout)
860
+ return {
861
+ "logs": f"Job {job_id} is {job_info.get('status', 'unknown')}\\n\\nCommand: {job_info.get('command', 'N/A')}",
862
+ "job_id": job_id,
863
+ "status": job_info.get('status', 'unknown').lower(),
864
+ "command": job_info.get('command', '')
865
+ }
866
+
867
+ return {"logs": f"Could not fetch logs: {result.stderr}", "job_id": job_id, "status": "error"}
868
+
869
+ def _get_vm_detailed_metadata(self, vm_ip: str, container_name: str, logs: str, phase: str) -> dict:
870
+ """Get detailed VM metadata for the VM Details panel.
871
+
872
+ Returns:
873
+ dict with disk_usage_gb, memory_usage_mb, setup_script_phase, probe_response, qmp_connected, dependencies
874
+ """
875
+ import subprocess
876
+ import re
877
+
878
+ metadata = {
879
+ "disk_usage_gb": None,
880
+ "memory_usage_mb": None,
881
+ "setup_script_phase": None,
882
+ "probe_response": None,
883
+ "qmp_connected": False,
884
+ "dependencies": []
885
+ }
886
+
887
+ # 1. Get disk usage from docker stats
888
+ try:
889
+ disk_result = subprocess.run(
890
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
891
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
892
+ f"azureuser@{vm_ip}",
893
+ f"docker exec {container_name} df -h /storage 2>/dev/null | tail -1"],
894
+ capture_output=True, text=True, timeout=10
895
+ )
896
+ if disk_result.returncode == 0 and disk_result.stdout.strip():
897
+ # Parse: "Filesystem Size Used Avail Use% Mounted on"
898
+ # Example: "/dev/sda1 30G 9.2G 20G 31% /storage"
899
+ parts = disk_result.stdout.split()
900
+ if len(parts) >= 3:
901
+ used_str = parts[2] # e.g., "9.2G"
902
+ total_str = parts[1] # e.g., "30G"
903
+ # Convert to GB (handle M/G suffixes)
904
+ def to_gb(s):
905
+ if s.endswith('G'):
906
+ return float(s[:-1])
907
+ elif s.endswith('M'):
908
+ return float(s[:-1]) / 1024
909
+ elif s.endswith('K'):
910
+ return float(s[:-1]) / (1024 * 1024)
911
+ return 0
912
+ metadata["disk_usage_gb"] = f"{to_gb(used_str):.1f} GB / {to_gb(total_str):.0f} GB used"
913
+ except Exception:
914
+ pass
915
+
916
+ # 2. Get memory usage from docker stats
917
+ try:
918
+ mem_result = subprocess.run(
919
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
920
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
921
+ f"azureuser@{vm_ip}",
922
+ f"docker stats {container_name} --no-stream --format '{{{{.MemUsage}}}}'"],
923
+ capture_output=True, text=True, timeout=10
924
+ )
925
+ if mem_result.returncode == 0 and mem_result.stdout.strip():
926
+ # Example: "1.5GiB / 4GiB"
927
+ metadata["memory_usage_mb"] = mem_result.stdout.strip()
928
+ except Exception:
929
+ pass
930
+
931
+ # 3. Parse setup script phase from logs
932
+ metadata["setup_script_phase"] = self._parse_setup_phase_from_logs(logs, phase)
933
+
934
+ # 4. Check /probe endpoint
935
+ try:
936
+ probe_result = subprocess.run(
937
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
938
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
939
+ f"azureuser@{vm_ip}",
940
+ "curl -s --connect-timeout 2 http://20.20.20.21:5000/probe 2>/dev/null"],
941
+ capture_output=True, text=True, timeout=10
942
+ )
943
+ if probe_result.returncode == 0 and probe_result.stdout.strip():
944
+ metadata["probe_response"] = probe_result.stdout.strip()
945
+ else:
946
+ metadata["probe_response"] = "Not responding"
947
+ except Exception:
948
+ metadata["probe_response"] = "Connection failed"
949
+
950
+ # 5. Check QMP connection (port 7200)
951
+ try:
952
+ qmp_result = subprocess.run(
953
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
954
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
955
+ f"azureuser@{vm_ip}",
956
+ "nc -z -w2 localhost 7200 2>&1"],
957
+ capture_output=True, text=True, timeout=10
958
+ )
959
+ metadata["qmp_connected"] = qmp_result.returncode == 0
960
+ except Exception:
961
+ pass
962
+
963
+ # 6. Parse dependencies from logs
964
+ metadata["dependencies"] = self._parse_dependencies_from_logs(logs, phase)
965
+
966
+ return metadata
967
+
968
+ def _parse_setup_phase_from_logs(self, logs: str, current_phase: str) -> str:
969
+ """Parse the current setup script phase from logs.
970
+
971
+ Looks for patterns indicating which script is running:
972
+ - install.bat
973
+ - setup.ps1
974
+ - on-logon.ps1
975
+ """
976
+ if current_phase == "ready":
977
+ return "Setup complete"
978
+ elif current_phase == "oobe":
979
+ # Check for specific script patterns
980
+ if "on-logon.ps1" in logs.lower():
981
+ return "Running on-logon.ps1"
982
+ elif "setup.ps1" in logs.lower():
983
+ return "Running setup.ps1"
984
+ elif "install.bat" in logs.lower():
985
+ return "Running install.bat"
986
+ else:
987
+ return "Windows installation in progress"
988
+ elif current_phase == "booting":
989
+ return "Booting Windows"
990
+ elif current_phase in ["downloading", "extracting", "configuring", "building"]:
991
+ return "Preparing Windows VM"
992
+ else:
993
+ return "Initializing..."
994
+
995
+ def _parse_dependencies_from_logs(self, logs: str, phase: str) -> list[dict]:
996
+ """Parse dependency installation status from logs.
997
+
998
+ Returns list of dependencies with their installation status:
999
+ - Python
1000
+ - Chrome
1001
+ - LibreOffice
1002
+ - VSCode
1003
+ - etc.
1004
+ """
1005
+ dependencies = [
1006
+ {"name": "Python", "icon": "🐍", "status": "pending"},
1007
+ {"name": "Chrome", "icon": "🌐", "status": "pending"},
1008
+ {"name": "LibreOffice", "icon": "📝", "status": "pending"},
1009
+ {"name": "VSCode", "icon": "💻", "status": "pending"},
1010
+ {"name": "WAA Server", "icon": "🔧", "status": "pending"},
1011
+ ]
1012
+
1013
+ if phase not in ["oobe", "ready"]:
1014
+ # Not yet at Windows setup phase
1015
+ return dependencies
1016
+
1017
+ logs_lower = logs.lower()
1018
+
1019
+ # Check for installation patterns
1020
+ if "python" in logs_lower and ("installing python" in logs_lower or "python.exe" in logs_lower):
1021
+ dependencies[0]["status"] = "installing"
1022
+ elif "python" in logs_lower and "installed" in logs_lower:
1023
+ dependencies[0]["status"] = "complete"
1024
+
1025
+ if "chrome" in logs_lower and ("downloading" in logs_lower or "installing" in logs_lower):
1026
+ dependencies[1]["status"] = "installing"
1027
+ elif "chrome" in logs_lower and "installed" in logs_lower:
1028
+ dependencies[1]["status"] = "complete"
1029
+
1030
+ if "libreoffice" in logs_lower and ("downloading" in logs_lower or "installing" in logs_lower):
1031
+ dependencies[2]["status"] = "installing"
1032
+ elif "libreoffice" in logs_lower and "installed" in logs_lower:
1033
+ dependencies[2]["status"] = "complete"
1034
+
1035
+ if "vscode" in logs_lower or "visual studio code" in logs_lower:
1036
+ if "installing" in logs_lower:
1037
+ dependencies[3]["status"] = "installing"
1038
+ elif "installed" in logs_lower:
1039
+ dependencies[3]["status"] = "complete"
1040
+
1041
+ if "waa" in logs_lower or "flask" in logs_lower:
1042
+ if "starting" in logs_lower or "running" in logs_lower:
1043
+ dependencies[4]["status"] = "installing"
1044
+ elif phase == "ready":
1045
+ dependencies[4]["status"] = "complete"
1046
+
1047
+ return dependencies
1048
+
1049
+ def _fetch_background_tasks(self):
1050
+ """Fetch status of all background tasks: Azure VM, Docker containers, benchmarks."""
1051
+ import subprocess
1052
+ from datetime import datetime
1053
+ import time
1054
+
1055
+ tasks = []
1056
+
1057
+ # Check for VM IP from environment (set by CLI when auto-launching viewer)
1058
+ env_vm_ip = os.environ.get("WAA_VM_IP")
1059
+ env_internal_ip = os.environ.get("WAA_INTERNAL_IP", "172.30.0.2")
1060
+
1061
+ # 1. Check Azure WAA VM status
1062
+ vm_ip = None
1063
+ if env_vm_ip:
1064
+ # Use environment variable - VM IP was provided directly
1065
+ vm_ip = env_vm_ip
1066
+ tasks.append({
1067
+ "task_id": "azure-vm-waa",
1068
+ "task_type": "vm_provision",
1069
+ "status": "completed",
1070
+ "phase": "ready", # Match status to prevent "Starting" + "completed" conflict
1071
+ "title": "Azure VM Host",
1072
+ "description": f"Linux host running at {vm_ip}",
1073
+ "progress_percent": 100.0,
1074
+ "elapsed_seconds": 0,
1075
+ "metadata": {
1076
+ "vm_name": "waa-eval-vm",
1077
+ "ip_address": vm_ip,
1078
+ "internal_ip": env_internal_ip
1079
+ }
1080
+ })
1081
+ else:
1082
+ # Query Azure CLI for VM status
1083
+ try:
1084
+ result = subprocess.run(
1085
+ ["az", "vm", "get-instance-view",
1086
+ "--name", "waa-eval-vm",
1087
+ "--resource-group", "openadapt-agents",
1088
+ "--query", "instanceView.statuses",
1089
+ "-o", "json"],
1090
+ capture_output=True, text=True, timeout=10
1091
+ )
1092
+ if result.returncode == 0:
1093
+ statuses = json.loads(result.stdout)
1094
+ power_state = "unknown"
1095
+ for s in statuses:
1096
+ if s.get("code", "").startswith("PowerState/"):
1097
+ power_state = s["code"].replace("PowerState/", "")
1098
+
1099
+ # Get VM IP
1100
+ ip_result = subprocess.run(
1101
+ ["az", "vm", "list-ip-addresses",
1102
+ "--name", "waa-eval-vm",
1103
+ "--resource-group", "openadapt-agents",
1104
+ "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
1105
+ "-o", "tsv"],
1106
+ capture_output=True, text=True, timeout=10
1107
+ )
1108
+ vm_ip = ip_result.stdout.strip() if ip_result.returncode == 0 else None
1109
+
1110
+ if power_state == "running":
1111
+ tasks.append({
1112
+ "task_id": "azure-vm-waa",
1113
+ "task_type": "vm_provision",
1114
+ "status": "completed",
1115
+ "phase": "ready", # Match status to prevent "Starting" + "completed" conflict
1116
+ "title": "Azure VM Host",
1117
+ "description": f"Linux host running at {vm_ip}" if vm_ip else "Linux host running",
1118
+ "progress_percent": 100.0,
1119
+ "elapsed_seconds": 0,
1120
+ "metadata": {
1121
+ "vm_name": "waa-eval-vm",
1122
+ "ip_address": vm_ip
1123
+ # No VNC link - that's for the Windows container
1124
+ }
1125
+ })
1126
+ except subprocess.TimeoutExpired:
1127
+ pass
1128
+ except Exception:
1129
+ pass
1130
+
1131
+ # 2. Check Docker container status on VM (if we have an IP)
1132
+ if vm_ip:
1133
+ try:
1134
+ docker_result = subprocess.run(
1135
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1136
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1137
+ f"azureuser@{vm_ip}",
1138
+ "docker ps --format '{{.Names}}|{{.Status}}|{{.Image}}'"],
1139
+ capture_output=True, text=True, timeout=15
1140
+ )
1141
+ if docker_result.returncode == 0 and docker_result.stdout.strip():
1142
+ for line in docker_result.stdout.strip().split('\n'):
1143
+ parts = line.split('|')
1144
+ if len(parts) >= 3:
1145
+ container_name, status, image = parts[0], parts[1], parts[2]
1146
+ # Parse "Up X minutes" to determine if healthy
1147
+ is_healthy = "Up" in status
1148
+
1149
+ # Check for Windows VM specifically
1150
+ if "windows" in image.lower() or container_name == "winarena":
1151
+ # Get detailed progress from docker logs
1152
+ log_check = subprocess.run(
1153
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1154
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1155
+ f"azureuser@{vm_ip}",
1156
+ f"docker logs {container_name} 2>&1 | tail -30"],
1157
+ capture_output=True, text=True, timeout=10
1158
+ )
1159
+ logs = log_check.stdout if log_check.returncode == 0 else ""
1160
+
1161
+ # Parse progress from logs
1162
+ phase = "unknown"
1163
+ progress = 0.0
1164
+ description = "Starting..."
1165
+
1166
+ if "Windows started successfully" in logs:
1167
+ # Check if WAA server is ready via Docker port forwarding
1168
+ # See docs/waa_network_architecture.md - always use localhost
1169
+ server_check = subprocess.run(
1170
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1171
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1172
+ f"azureuser@{vm_ip}",
1173
+ "curl -s --connect-timeout 2 http://localhost:5000/probe 2>/dev/null"],
1174
+ capture_output=True, text=True, timeout=10
1175
+ )
1176
+ waa_ready = server_check.returncode == 0 and "Service is operational" in server_check.stdout
1177
+ if waa_ready:
1178
+ phase = "ready"
1179
+ progress = 100.0
1180
+ description = "WAA Server ready - benchmarks can run"
1181
+ else:
1182
+ phase = "oobe"
1183
+ progress = 80.0 # Phase 5/6 - VM install in progress
1184
+ description = "Phase 5/6: Windows installing (check VNC for %). OEM scripts will run after."
1185
+ elif "Booting Windows" in logs:
1186
+ phase = "booting"
1187
+ progress = 70.0 # Phase 4/6
1188
+ description = "Phase 4/6: Booting Windows from installer..."
1189
+ elif "Building Windows" in logs or "Creating a" in logs:
1190
+ phase = "building"
1191
+ progress = 60.0 # Phase 3/6
1192
+ description = "Phase 3/6: Building Windows VM disk..."
1193
+ elif "Adding" in logs and "image" in logs:
1194
+ phase = "configuring"
1195
+ progress = 50.0 # Phase 2/6
1196
+ description = "Phase 2/6: Configuring Windows image with WAA scripts..."
1197
+ elif "Extracting" in logs:
1198
+ phase = "extracting"
1199
+ progress = 35.0 # Phase 1/6 (after download)
1200
+ description = "Phase 1/6: Extracting Windows ISO..."
1201
+ else:
1202
+ # Check for download progress (e.g., "1234K ........ 45% 80M 30s")
1203
+ import re
1204
+ download_match = re.search(r'(\d+)%\s+[\d.]+[KMG]\s+(\d+)s', logs)
1205
+ if download_match:
1206
+ phase = "downloading"
1207
+ dl_pct = float(download_match.group(1))
1208
+ progress = dl_pct * 0.30 # 0-30% for download phase
1209
+ eta = download_match.group(2)
1210
+ description = f"Phase 0/6: Downloading Windows 11... {download_match.group(1)}% ({eta}s left)"
1211
+
1212
+ # Improve phase detection - if Windows is booted but WAA not ready,
1213
+ # it might be at login screen waiting for OEM scripts or running install.bat
1214
+ if phase == "oobe" and "Boot0004" in logs:
1215
+ # Windows finished installing, at login/desktop
1216
+ # install.bat should auto-run from FirstLogonCommands (see Dockerfile)
1217
+ description = "Phase 5/6: Windows at desktop, OEM scripts running... (WAA server starting)"
1218
+ progress = 90.0
1219
+
1220
+ # Get detailed metadata for VM Details panel
1221
+ vm_metadata = self._get_vm_detailed_metadata(vm_ip, container_name, logs, phase)
1222
+
1223
+ tasks.append({
1224
+ "task_id": f"docker-{container_name}",
1225
+ "task_type": "docker_container",
1226
+ "status": "completed" if phase == "ready" else "running",
1227
+ "title": "Windows 11 + WAA Server",
1228
+ "description": description,
1229
+ "progress_percent": progress,
1230
+ "elapsed_seconds": 0,
1231
+ "phase": phase,
1232
+ "metadata": {
1233
+ "container": container_name,
1234
+ "image": image,
1235
+ "status": status,
1236
+ "phase": phase,
1237
+ "windows_ready": phase in ["oobe", "ready"],
1238
+ "waa_server_ready": phase == "ready",
1239
+ # Use localhost - SSH tunnel handles routing to VM
1240
+ # See docs/waa_network_architecture.md
1241
+ "vnc_url": "http://localhost:8006",
1242
+ "windows_username": "Docker",
1243
+ "windows_password": "admin",
1244
+ "recent_logs": logs[-500:] if logs else "",
1245
+ # Enhanced VM details
1246
+ "disk_usage_gb": vm_metadata["disk_usage_gb"],
1247
+ "memory_usage_mb": vm_metadata["memory_usage_mb"],
1248
+ "setup_script_phase": vm_metadata["setup_script_phase"],
1249
+ "probe_response": vm_metadata["probe_response"],
1250
+ "qmp_connected": vm_metadata["qmp_connected"],
1251
+ "dependencies": vm_metadata["dependencies"],
1252
+ }
1253
+ })
1254
+ except Exception as e:
1255
+ # SSH failed, VM might still be starting
1256
+ pass
1257
+
1258
+ # 3. Check local benchmark progress
1259
+ progress_file = Path("benchmark_progress.json")
1260
+ if progress_file.exists():
1261
+ try:
1262
+ progress = json.loads(progress_file.read_text())
1263
+ if progress.get("status") == "running":
1264
+ tasks.append({
1265
+ "task_id": "benchmark-local",
1266
+ "task_type": "benchmark_run",
1267
+ "status": "running",
1268
+ "title": f"{progress.get('provider', 'API').upper()} Benchmark",
1269
+ "description": progress.get("message", "Running benchmark..."),
1270
+ "progress_percent": (progress.get("tasks_complete", 0) / max(progress.get("tasks_total", 1), 1)) * 100,
1271
+ "elapsed_seconds": 0,
1272
+ "metadata": progress
1273
+ })
1274
+ except Exception:
1275
+ pass
1276
+
1277
+ return tasks
1278
+
1279
+ def _fetch_vm_registry(self):
1280
+ """Fetch VM registry with live status checks."""
1281
+ import subprocess
1282
+ from datetime import datetime
1283
+
1284
+ # Path to VM registry file (relative to project root)
1285
+ project_root = Path(__file__).parent.parent.parent
1286
+ registry_file = project_root / "benchmark_results" / "vm_registry.json"
1287
+
1288
+ if not registry_file.exists():
1289
+ return []
1290
+
1291
+ try:
1292
+ with open(registry_file) as f:
1293
+ vms = json.load(f)
1294
+ except Exception as e:
1295
+ return {"error": f"Failed to read VM registry: {e}"}
1296
+
1297
+ # Check status for each VM
1298
+ for vm in vms:
1299
+ vm["status"] = "unknown"
1300
+ vm["last_checked"] = datetime.now().isoformat()
1301
+ vm["vnc_reachable"] = False
1302
+ vm["waa_probe_status"] = "unknown"
1303
+
1304
+ # Check VNC (HTTP HEAD request)
1305
+ try:
1306
+ vnc_url = f"http://{vm['ssh_host']}:{vm['vnc_port']}"
1307
+ result = subprocess.run(
1308
+ ["curl", "-I", "-s", "--connect-timeout", "3", vnc_url],
1309
+ capture_output=True, text=True, timeout=5
1310
+ )
1311
+ if result.returncode == 0 and "200" in result.stdout:
1312
+ vm["vnc_reachable"] = True
1313
+ except Exception:
1314
+ pass
1315
+
1316
+ # Check WAA probe via SSH
1317
+ # Probe WAA via localhost (Docker port forwarding handles routing)
1318
+ # See docs/waa_network_architecture.md for architecture details
1319
+ try:
1320
+ waa_port = vm.get("waa_port", 5000)
1321
+ ssh_cmd = f"curl -s --connect-timeout 2 http://localhost:{waa_port}/probe 2>/dev/null"
1322
+ result = subprocess.run(
1323
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=3",
1324
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1325
+ f"{vm['ssh_user']}@{vm['ssh_host']}",
1326
+ ssh_cmd],
1327
+ capture_output=True, text=True, timeout=5
1328
+ )
1329
+ probe_success = result.returncode == 0 and "Service is operational" in result.stdout
1330
+ if probe_success:
1331
+ vm["waa_probe_status"] = "ready"
1332
+ vm["status"] = "online"
1333
+ # Auto-start SSH tunnels for VNC and WAA
1334
+ try:
1335
+ tunnel_mgr = get_tunnel_manager()
1336
+ tunnel_status = tunnel_mgr.ensure_tunnels_for_vm(
1337
+ vm_ip=vm["ssh_host"],
1338
+ ssh_user=vm.get("ssh_user", "azureuser"),
1339
+ )
1340
+ vm["tunnels"] = {
1341
+ name: {"active": s.active, "local_port": s.local_port, "error": s.error}
1342
+ for name, s in tunnel_status.items()
1343
+ }
1344
+ except Exception as e:
1345
+ vm["tunnels"] = {"error": str(e)}
1346
+ else:
1347
+ vm["waa_probe_status"] = "not responding"
1348
+ vm["status"] = "offline"
1349
+ # Stop tunnels when VM goes offline
1350
+ try:
1351
+ tunnel_mgr = get_tunnel_manager()
1352
+ tunnel_mgr.stop_all_tunnels()
1353
+ vm["tunnels"] = {}
1354
+ except Exception:
1355
+ pass
1356
+ except Exception:
1357
+ vm["waa_probe_status"] = "ssh failed"
1358
+ vm["status"] = "offline"
1359
+
1360
+ return vms
1361
+
1362
+ def _probe_vm(self) -> dict:
1363
+ """Probe the Azure VM to check if WAA server is responding.
1364
+
1365
+ Returns:
1366
+ dict with:
1367
+ - responding: bool - whether the WAA server is responding
1368
+ - vm_ip: str - the VM's IP address
1369
+ - container: str - the container name
1370
+ - probe_result: str - the raw probe response or error message
1371
+ - last_checked: str - ISO timestamp
1372
+ """
1373
+ import subprocess
1374
+ from datetime import datetime
1375
+
1376
+ result = {
1377
+ "responding": False,
1378
+ "vm_ip": None,
1379
+ "container": None,
1380
+ "probe_result": None,
1381
+ "last_checked": datetime.now().isoformat(),
1382
+ }
1383
+
1384
+ # First get VM IP
1385
+ try:
1386
+ ip_result = subprocess.run(
1387
+ ["az", "vm", "list-ip-addresses",
1388
+ "--name", "waa-eval-vm",
1389
+ "--resource-group", "openadapt-agents",
1390
+ "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
1391
+ "-o", "tsv"],
1392
+ capture_output=True, text=True, timeout=10
1393
+ )
1394
+ if ip_result.returncode == 0 and ip_result.stdout.strip():
1395
+ vm_ip = ip_result.stdout.strip()
1396
+ result["vm_ip"] = vm_ip
1397
+
1398
+ # Try to probe WAA server via SSH
1399
+ # Use the correct internal IP for the Windows VM inside Docker
1400
+ probe_result = subprocess.run(
1401
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1402
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1403
+ f"azureuser@{vm_ip}",
1404
+ "docker exec waa-container curl -s --connect-timeout 3 http://172.30.0.2:5000/probe 2>/dev/null || echo 'probe_failed'"],
1405
+ capture_output=True, text=True, timeout=15
1406
+ )
1407
+
1408
+ result["container"] = "waa-container"
1409
+
1410
+ if probe_result.returncode == 0:
1411
+ probe_output = probe_result.stdout.strip()
1412
+ if probe_output and "probe_failed" not in probe_output:
1413
+ result["responding"] = True
1414
+ result["probe_result"] = probe_output
1415
+ else:
1416
+ result["probe_result"] = "WAA server not responding"
1417
+ else:
1418
+ result["probe_result"] = f"SSH/Docker error: {probe_result.stderr[:200]}"
1419
+ else:
1420
+ result["probe_result"] = "Could not get VM IP"
1421
+ except subprocess.TimeoutExpired:
1422
+ result["probe_result"] = "Connection timeout"
1423
+ except Exception as e:
1424
+ result["probe_result"] = f"Error: {str(e)}"
1425
+
1426
+ return result
1427
+
1428
+ def _get_current_run(self) -> dict:
1429
+ """Get info about any currently running benchmark.
1430
+
1431
+ Checks:
1432
+ 1. Local benchmark_progress.json for API benchmarks
1433
+ 2. Azure VM for WAA benchmarks running via SSH
1434
+
1435
+ Returns:
1436
+ dict with:
1437
+ - running: bool - whether a benchmark is running
1438
+ - type: str - 'local' or 'azure_vm'
1439
+ - model: str - model being evaluated
1440
+ - progress: dict with tasks_completed, total_tasks
1441
+ - current_task: str - current task ID
1442
+ - started_at: str - ISO timestamp
1443
+ - elapsed_minutes: int
1444
+ """
1445
+ import subprocess
1446
+ from datetime import datetime
1447
+ import re
1448
+
1449
+ result = {
1450
+ "running": False,
1451
+ "type": None,
1452
+ "model": None,
1453
+ "progress": {"tasks_completed": 0, "total_tasks": 0},
1454
+ "current_task": None,
1455
+ "started_at": None,
1456
+ "elapsed_minutes": 0,
1457
+ }
1458
+
1459
+ # Check local benchmark progress first
1460
+ progress_file = Path("benchmark_progress.json")
1461
+ if progress_file.exists():
1462
+ try:
1463
+ progress = json.loads(progress_file.read_text())
1464
+ if progress.get("status") == "running":
1465
+ result["running"] = True
1466
+ result["type"] = "local"
1467
+ result["model"] = progress.get("provider", "unknown")
1468
+ result["progress"]["tasks_completed"] = progress.get("tasks_complete", 0)
1469
+ result["progress"]["total_tasks"] = progress.get("tasks_total", 0)
1470
+ return result
1471
+ except Exception:
1472
+ pass
1473
+
1474
+ # Check Azure VM for running benchmark
1475
+ try:
1476
+ # Get VM IP
1477
+ ip_result = subprocess.run(
1478
+ ["az", "vm", "list-ip-addresses",
1479
+ "--name", "waa-eval-vm",
1480
+ "--resource-group", "openadapt-agents",
1481
+ "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
1482
+ "-o", "tsv"],
1483
+ capture_output=True, text=True, timeout=10
1484
+ )
1485
+
1486
+ if ip_result.returncode == 0 and ip_result.stdout.strip():
1487
+ vm_ip = ip_result.stdout.strip()
1488
+
1489
+ # Check if benchmark process is running
1490
+ process_check = subprocess.run(
1491
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1492
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1493
+ f"azureuser@{vm_ip}",
1494
+ "docker exec waa-container pgrep -f 'python.*run.py' 2>/dev/null && echo 'RUNNING' || echo 'NOT_RUNNING'"],
1495
+ capture_output=True, text=True, timeout=10
1496
+ )
1497
+
1498
+ if process_check.returncode == 0 and "RUNNING" in process_check.stdout:
1499
+ result["running"] = True
1500
+ result["type"] = "azure_vm"
1501
+
1502
+ # Get log file for more details
1503
+ log_check = subprocess.run(
1504
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1505
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1506
+ f"azureuser@{vm_ip}",
1507
+ "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
1508
+ capture_output=True, text=True, timeout=10
1509
+ )
1510
+
1511
+ if log_check.returncode == 0 and log_check.stdout.strip():
1512
+ logs = log_check.stdout
1513
+
1514
+ # Parse model from logs
1515
+ model_match = re.search(r'model[=:\s]+([^\s,]+)', logs, re.IGNORECASE)
1516
+ if model_match:
1517
+ result["model"] = model_match.group(1)
1518
+
1519
+ # Parse progress
1520
+ task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
1521
+ if task_match:
1522
+ result["progress"]["tasks_completed"] = int(task_match.group(1))
1523
+ result["progress"]["total_tasks"] = int(task_match.group(2))
1524
+
1525
+ # Parse current task
1526
+ task_id_match = re.search(r'(?:Running|Processing|task)[:\s]+([a-f0-9-]+)', logs, re.IGNORECASE)
1527
+ if task_id_match:
1528
+ result["current_task"] = task_id_match.group(1)
1529
+
1530
+ except Exception:
1531
+ pass
1532
+
1533
+ return result
1534
+
1535
+ async def _detect_running_benchmark(self, vm_ip: str, container_name: str = "winarena") -> dict:
1536
+ """Detect if a benchmark is running on the VM and extract progress.
1537
+
1538
+ SSH into VM and check:
1539
+ 1. Process running: docker exec {container} pgrep -f 'python.*run.py'
1540
+ 2. Log progress: tail /tmp/waa_benchmark.log
1541
+
1542
+ Returns:
1543
+ dict with:
1544
+ - running: bool
1545
+ - current_task: str (task ID or description)
1546
+ - progress: dict with tasks_completed, total_tasks, current_step
1547
+ - recent_logs: str (last few log lines)
1548
+ """
1549
+ import subprocess
1550
+ import re
1551
+
1552
+ result = {
1553
+ "running": False,
1554
+ "current_task": None,
1555
+ "progress": {
1556
+ "tasks_completed": 0,
1557
+ "total_tasks": 0,
1558
+ "current_step": 0,
1559
+ },
1560
+ "recent_logs": "",
1561
+ }
1562
+
1563
+ try:
1564
+ # Check if benchmark process is running
1565
+ process_check = subprocess.run(
1566
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1567
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1568
+ f"azureuser@{vm_ip}",
1569
+ f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''"],
1570
+ capture_output=True, text=True, timeout=10
1571
+ )
1572
+
1573
+ if process_check.returncode == 0 and process_check.stdout.strip():
1574
+ result["running"] = True
1575
+
1576
+ # Get benchmark log
1577
+ log_check = subprocess.run(
1578
+ ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
1579
+ "-i", str(Path.home() / ".ssh" / "id_rsa"),
1580
+ f"azureuser@{vm_ip}",
1581
+ "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
1582
+ capture_output=True, text=True, timeout=10
1583
+ )
1584
+
1585
+ if log_check.returncode == 0 and log_check.stdout.strip():
1586
+ logs = log_check.stdout
1587
+ result["recent_logs"] = logs[-500:] # Last 500 chars
1588
+
1589
+ # Parse progress from logs
1590
+ # Look for patterns like "Task 5/30" or "Completed: 5, Remaining: 25"
1591
+ task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
1592
+ if task_match:
1593
+ result["progress"]["tasks_completed"] = int(task_match.group(1))
1594
+ result["progress"]["total_tasks"] = int(task_match.group(2))
1595
+
1596
+ # Extract current task ID
1597
+ task_id_match = re.search(r'(?:Running|Processing) task:\s*(\S+)', logs)
1598
+ if task_id_match:
1599
+ result["current_task"] = task_id_match.group(1)
1600
+
1601
+ # Extract step info
1602
+ step_match = re.search(r'Step\s+(\d+)', logs)
1603
+ if step_match:
1604
+ result["progress"]["current_step"] = int(step_match.group(1))
1605
+
1606
+ except Exception as e:
1607
+ # SSH or parsing failed - leave defaults
1608
+ pass
1609
+
1610
+ return result
1611
+
1612
+ def _parse_task_result(self, log_lines: list[str], task_id: str) -> dict:
1613
+ """Parse task success/failure from log output.
1614
+
1615
+ WAA log patterns:
1616
+ - Success: "Task task_001 completed successfully"
1617
+ - Success: "Result: PASS"
1618
+ - Failure: "Task task_001 failed"
1619
+ - Failure: "Result: FAIL"
1620
+ - Score: "Score: 0.85"
1621
+ """
1622
+ import re
1623
+
1624
+ success = None
1625
+ score = None
1626
+
1627
+ # Search backwards from most recent
1628
+ for line in reversed(log_lines):
1629
+ # Check for explicit result
1630
+ if 'Result: PASS' in line or 'completed successfully' in line:
1631
+ success = True
1632
+ elif 'Result: FAIL' in line or 'failed' in line.lower():
1633
+ success = False
1634
+
1635
+ # Check for score
1636
+ score_match = re.search(r'Score:\s*([\d.]+)', line)
1637
+ if score_match:
1638
+ try:
1639
+ score = float(score_match.group(1))
1640
+ except ValueError:
1641
+ pass
1642
+
1643
+ # Check for task-specific completion
1644
+ if task_id in line:
1645
+ if 'success' in line.lower() or 'pass' in line.lower():
1646
+ success = True
1647
+ elif 'fail' in line.lower() or 'error' in line.lower():
1648
+ success = False
1649
+
1650
+ # Default to True if no explicit failure found (backwards compatible)
1651
+ if success is None:
1652
+ success = True
1653
+
1654
+ return {"success": success, "score": score}
1655
+
1656
+ def _stream_benchmark_updates(self, interval: int):
1657
+ """Stream Server-Sent Events for benchmark status updates.
1658
+
1659
+ Streams events:
1660
+ - connected: Initial connection event
1661
+ - status: VM status and probe results
1662
+ - progress: Benchmark progress (tasks completed, current task)
1663
+ - task_complete: When a task finishes
1664
+ - heartbeat: Keep-alive signal every 30 seconds
1665
+ - error: Error messages
1666
+
1667
+ Uses a generator-based approach to avoid blocking the main thread
1668
+ and properly handles client disconnection.
1669
+ """
1670
+ import time
1671
+ import select
1672
+
1673
+ HEARTBEAT_INTERVAL = 30 # seconds
1674
+
1675
+ # Set SSE headers
1676
+ self.send_response(200)
1677
+ self.send_header('Content-Type', 'text/event-stream')
1678
+ self.send_header('Cache-Control', 'no-cache')
1679
+ self.send_header('Access-Control-Allow-Origin', '*')
1680
+ self.send_header('Connection', 'keep-alive')
1681
+ self.send_header('X-Accel-Buffering', 'no') # Disable nginx buffering
1682
+ self.end_headers()
1683
+
1684
+ # Track connection state
1685
+ client_connected = True
1686
+
1687
+ def send_event(event_type: str, data: dict) -> bool:
1688
+ """Send an SSE event. Returns False if client disconnected."""
1689
+ nonlocal client_connected
1690
+ if not client_connected:
1691
+ return False
1692
+ try:
1693
+ event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
1694
+ self.wfile.write(event_str.encode('utf-8'))
1695
+ self.wfile.flush()
1696
+ return True
1697
+ except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
1698
+ # Client disconnected
1699
+ client_connected = False
1700
+ return False
1701
+ except Exception as e:
1702
+ # Other error - log and assume disconnected
1703
+ print(f"SSE send error: {e}")
1704
+ client_connected = False
1705
+ return False
1706
+
1707
+ def check_client_connected() -> bool:
1708
+ """Check if client is still connected using socket select."""
1709
+ nonlocal client_connected
1710
+ if not client_connected:
1711
+ return False
1712
+ try:
1713
+ # Check if socket has data (would indicate client sent something or closed)
1714
+ # Use non-blocking check with 0 timeout
1715
+ rlist, _, xlist = select.select([self.rfile], [], [self.rfile], 0)
1716
+ if xlist:
1717
+ # Error condition on socket
1718
+ client_connected = False
1719
+ return False
1720
+ if rlist:
1721
+ # Client sent data - for SSE this usually means disconnect
1722
+ # (SSE is server-push only, client doesn't send data)
1723
+ data = self.rfile.read(1)
1724
+ if not data:
1725
+ client_connected = False
1726
+ return False
1727
+ return True
1728
+ except Exception:
1729
+ client_connected = False
1730
+ return False
1731
+
1732
+ last_task = None
1733
+ last_heartbeat = time.time()
1734
+ recent_log_lines = []
1735
+
1736
+ # Send initial connected event
1737
+ if not send_event("connected", {
1738
+ "timestamp": time.time(),
1739
+ "interval": interval,
1740
+ "version": "1.0"
1741
+ }):
1742
+ return
1743
+
1744
+ try:
1745
+ iteration_count = 0
1746
+ max_iterations = 3600 // interval # Max 1 hour of streaming
1747
+
1748
+ while client_connected and iteration_count < max_iterations:
1749
+ iteration_count += 1
1750
+ current_time = time.time()
1751
+
1752
+ # Check client connection before doing work
1753
+ if not check_client_connected():
1754
+ break
1755
+
1756
+ # Send heartbeat every 30 seconds to prevent proxy/LB timeouts
1757
+ if current_time - last_heartbeat >= HEARTBEAT_INTERVAL:
1758
+ if not send_event("heartbeat", {"timestamp": current_time}):
1759
+ break
1760
+ last_heartbeat = current_time
1761
+
1762
+ # Fetch background tasks (includes VM status)
1763
+ tasks = self._fetch_background_tasks()
1764
+
1765
+ # Send VM status event
1766
+ vm_task = next((t for t in tasks if t.get("task_type") == "docker_container"), None)
1767
+ if vm_task:
1768
+ vm_data = {
1769
+ "type": "vm_status",
1770
+ "connected": vm_task.get("status") in ["running", "completed"],
1771
+ "phase": vm_task.get("phase", "unknown"),
1772
+ "waa_ready": vm_task.get("metadata", {}).get("waa_server_ready", False),
1773
+ "probe": {
1774
+ "status": vm_task.get("metadata", {}).get("probe_response", "unknown"),
1775
+ "vnc_url": vm_task.get("metadata", {}).get("vnc_url"),
1776
+ }
1777
+ }
1778
+
1779
+ if not send_event("status", vm_data):
1780
+ break
1781
+
1782
+ # If VM is ready, check for running benchmark
1783
+ if vm_data["waa_ready"]:
1784
+ # Get VM IP from tasks
1785
+ vm_ip = None
1786
+ azure_vm = next((t for t in tasks if t.get("task_type") == "vm_provision"), None)
1787
+ if azure_vm:
1788
+ vm_ip = azure_vm.get("metadata", {}).get("ip_address")
1789
+
1790
+ if vm_ip:
1791
+ # Detect running benchmark using sync version
1792
+ benchmark_status = self._detect_running_benchmark_sync(
1793
+ vm_ip, vm_task.get("metadata", {}).get("container", "winarena")
1794
+ )
1795
+
1796
+ if benchmark_status["running"]:
1797
+ # Store log lines for result parsing
1798
+ if benchmark_status.get("recent_logs"):
1799
+ recent_log_lines = benchmark_status["recent_logs"].split('\n')
1800
+
1801
+ # Send progress event
1802
+ progress_data = {
1803
+ "tasks_completed": benchmark_status["progress"]["tasks_completed"],
1804
+ "total_tasks": benchmark_status["progress"]["total_tasks"],
1805
+ "current_task": benchmark_status["current_task"],
1806
+ "current_step": benchmark_status["progress"]["current_step"],
1807
+ }
1808
+
1809
+ if not send_event("progress", progress_data):
1810
+ break
1811
+
1812
+ # Check if task completed
1813
+ current_task = benchmark_status["current_task"]
1814
+ if current_task and current_task != last_task:
1815
+ if last_task is not None:
1816
+ # Previous task completed - parse result from logs
1817
+ result = self._parse_task_result(recent_log_lines, last_task)
1818
+ complete_data = {
1819
+ "task_id": last_task,
1820
+ "success": result["success"],
1821
+ "score": result["score"],
1822
+ }
1823
+ if not send_event("task_complete", complete_data):
1824
+ break
1825
+
1826
+ last_task = current_task
1827
+
1828
+ # Check local benchmark progress file
1829
+ progress_file = Path("benchmark_progress.json")
1830
+ if progress_file.exists():
1831
+ try:
1832
+ progress = json.loads(progress_file.read_text())
1833
+ if progress.get("status") == "running":
1834
+ progress_data = {
1835
+ "tasks_completed": progress.get("tasks_complete", 0),
1836
+ "total_tasks": progress.get("tasks_total", 0),
1837
+ "current_task": progress.get("provider", "unknown"),
1838
+ }
1839
+ if not send_event("progress", progress_data):
1840
+ break
1841
+ except Exception:
1842
+ pass
1843
+
1844
+ # Non-blocking sleep using select with timeout
1845
+ # This allows checking for client disconnect during sleep
1846
+ try:
1847
+ select.select([self.rfile], [], [], interval)
1848
+ except Exception:
1849
+ break
1850
+
1851
+ except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
1852
+ # Client disconnected - this is normal, don't log as error
1853
+ pass
1854
+ except Exception as e:
1855
+ # Send error event if still connected
1856
+ send_event("error", {"message": str(e)})
1857
+ finally:
1858
+ # Cleanup - connection is ending
1859
+ client_connected = False
1860
+
1861
+    def _detect_running_benchmark_sync(self, vm_ip: str, container_name: str = "winarena") -> dict:
+        """Synchronous version of _detect_running_benchmark.
+
+        Avoids creating a new event loop on each call which causes issues
+        when called from a synchronous context.
+        """
+        import subprocess
+        import re
+
+        result = {
+            "running": False,
+            "current_task": None,
+            "progress": {
+                "tasks_completed": 0,
+                "total_tasks": 0,
+                "current_step": 0,
+            },
+            "recent_logs": "",
+        }
+
+        try:
+            # Check if benchmark process is running
+            process_check = subprocess.run(
+                ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                 "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                 f"azureuser@{vm_ip}",
+                 f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''"],
+                capture_output=True, text=True, timeout=10
+            )
+
+            if process_check.returncode == 0 and process_check.stdout.strip():
+                result["running"] = True
+
+                # Get benchmark log
+                log_check = subprocess.run(
+                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
+                    capture_output=True, text=True, timeout=10
+                )
+
+                if log_check.returncode == 0 and log_check.stdout.strip():
+                    logs = log_check.stdout
+                    result["recent_logs"] = logs[-500:]  # Last 500 chars
+
+                    # Parse progress from logs
+                    task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
+                    if task_match:
+                        result["progress"]["tasks_completed"] = int(task_match.group(1))
+                        result["progress"]["total_tasks"] = int(task_match.group(2))
+
+                    # Extract current task ID
+                    task_id_match = re.search(r'(?:Running|Processing) task:\s*(\S+)', logs)
+                    if task_id_match:
+                        result["current_task"] = task_id_match.group(1)
+
+                    # Extract step info
+                    step_match = re.search(r'Step\s+(\d+)', logs)
+                    if step_match:
+                        result["progress"]["current_step"] = int(step_match.group(1))
+
+        except Exception:
+            # SSH or parsing failed - leave defaults
+            pass
+
+        return result
+
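The progress detection above hinges on three regexes over the tail of /tmp/waa_benchmark.log. A standalone sketch of the same patterns against a made-up log excerpt (the log lines here are illustrative, not verbatim WAA output):

# Illustrative log excerpt shaped the way the regexes above expect.
import re

sample = """\
Running task: notepad_save_file_001
Task 3/10
Step 7: clicked File menu
"""

task = re.search(r'Task\s+(\d+)/(\d+)', sample)
task_id = re.search(r'(?:Running|Processing) task:\s*(\S+)', sample)
step = re.search(r'Step\s+(\d+)', sample)

print(task.group(1), task.group(2))  # -> 3 10
print(task_id.group(1))              # -> notepad_save_file_001
print(step.group(1))                 # -> 7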
+    def _register_vm(self, vm_data):
+        """Register a new VM in the registry."""
+        # Path to VM registry file (relative to project root)
+        project_root = Path(__file__).parent.parent.parent
+        registry_file = project_root / "benchmark_results" / "vm_registry.json"
+
+        # Load existing registry
+        vms = []
+        if registry_file.exists():
+            try:
+                with open(registry_file) as f:
+                    vms = json.load(f)
+            except Exception:
+                pass
+
+        # Add new VM
+        new_vm = {
+            "name": vm_data.get("name", "unnamed-vm"),
+            "ssh_host": vm_data.get("ssh_host", ""),
+            "ssh_user": vm_data.get("ssh_user", "azureuser"),
+            "vnc_port": vm_data.get("vnc_port", 8006),
+            "waa_port": vm_data.get("waa_port", 5000),
+            "docker_container": vm_data.get("docker_container", "win11-waa"),
+            "internal_ip": vm_data.get("internal_ip", "20.20.20.21")
+        }
+
+        vms.append(new_vm)
+
+        # Save registry
+        try:
+            registry_file.parent.mkdir(parents=True, exist_ok=True)
+            with open(registry_file, 'w') as f:
+                json.dump(vms, f, indent=2)
+            return {"status": "success", "vm": new_vm}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
+
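For reference, _register_vm appends one entry per registration to benchmark_results/vm_registry.json. A sketch of the resulting file contents; the name and ssh_host values below are placeholders, the rest are the defaults shown above:

# Sketch of the registry produced after a single registration.
import json

example_registry = [
    {
        "name": "waa-vm-1",              # placeholder name
        "ssh_host": "203.0.113.10",      # placeholder public IP
        "ssh_user": "azureuser",
        "vnc_port": 8006,
        "waa_port": 5000,
        "docker_container": "win11-waa",
        "internal_ip": "20.20.20.21",
    }
]
print(json.dumps(example_registry, indent=2))  # same layout _register_vm writes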
+    def _start_benchmark_run(self, params: dict) -> dict:
+        """Start a benchmark run with the given parameters.
+
+        Runs the benchmark in a background thread and returns immediately.
+        Progress is tracked via benchmark_progress.json.
+
+        Expected params:
+        {
+            "model": "gpt-4o",
+            "num_tasks": 5,
+            "agent": "navi",
+            "task_selection": "all" | "domain" | "task_ids",
+            "domain": "general", // if task_selection == "domain"
+            "task_ids": ["task_001", "task_015"] // if task_selection == "task_ids"
+        }
+
+        Returns:
+            dict with status and params
+        """
+        from dotenv import load_dotenv
+
+        # Load .env file for API keys
+        project_root = Path(__file__).parent.parent.parent
+        load_dotenv(project_root / ".env")
+
+        # Build CLI command
+        cmd = [
+            "uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli",
+            "vm", "run-waa",
+            "--num-tasks", str(params.get("num_tasks", 5)),
+            "--model", params.get("model", "gpt-4o"),
+            "--agent", params.get("agent", "navi"),
+            "--no-open"  # Don't open viewer (already open)
+        ]
+
+        # Add task selection args
+        task_selection = params.get("task_selection", "all")
+        if task_selection == "domain":
+            domain = params.get("domain", "general")
+            cmd.extend(["--domain", domain])
+        elif task_selection == "task_ids":
+            task_ids = params.get("task_ids", [])
+            if task_ids:
+                cmd.extend(["--task-ids", ",".join(task_ids)])
+
+        # Create progress log file (in cwd which is serve_dir)
+        progress_file = Path("benchmark_progress.json")
+
+        # Write initial progress
+        model = params.get("model", "gpt-4o")
+        num_tasks = params.get("num_tasks", 5)
+        agent = params.get("agent", "navi")
+
+        print(f"\n[Benchmark] Starting WAA benchmark: model={model}, tasks={num_tasks}, agent={agent}")
+        print(f"[Benchmark] Task selection: {task_selection}")
+        if task_selection == "domain":
+            print(f"[Benchmark] Domain: {params.get('domain', 'general')}")
+        elif task_selection == "task_ids":
+            print(f"[Benchmark] Task IDs: {params.get('task_ids', [])}")
+        print(f"[Benchmark] Command: {' '.join(cmd)}")
+
+        progress_file.write_text(json.dumps({
+            "status": "running",
+            "model": model,
+            "num_tasks": num_tasks,
+            "agent": agent,
+            "task_selection": task_selection,
+            "tasks_complete": 0,
+            "message": f"Starting {model} benchmark with {num_tasks} tasks..."
+        }))
+
+        # Copy environment with loaded vars
+        env = os.environ.copy()
+
+        # Run in background thread
+        def run():
+            result = subprocess.run(
+                cmd,
+                capture_output=True,
+                text=True,
+                cwd=str(project_root),
+                env=env
+            )
+
+            print(f"\n[Benchmark] Output:\n{result.stdout}")
+            if result.stderr:
+                print(f"[Benchmark] Stderr: {result.stderr}")
+
+            if result.returncode == 0:
+                print(f"[Benchmark] Complete. Regenerating viewer...")
+                progress_file.write_text(json.dumps({
+                    "status": "complete",
+                    "model": model,
+                    "num_tasks": num_tasks,
+                    "message": "Benchmark complete. Refresh to see results."
+                }))
+                # Regenerate benchmark viewer
+                _regenerate_benchmark_viewer_if_available(serve_dir)
+            else:
+                error_msg = result.stderr[:200] if result.stderr else "Unknown error"
+                print(f"[Benchmark] Failed: {error_msg}")
+                progress_file.write_text(json.dumps({
+                    "status": "error",
+                    "model": model,
+                    "num_tasks": num_tasks,
+                    "message": f"Benchmark failed: {error_msg}"
+                }))
+
+        threading.Thread(target=run, daemon=True).start()
+
+        return {"status": "started", "params": params}
+
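As a concrete illustration of the command _start_benchmark_run assembles, the params/argv pairing below is a sketch derived from the code above; the two task IDs are placeholders, not real WAA task names:

# Example params and the argv _start_benchmark_run would build for a
# "task_ids" selection (task IDs are placeholders).
params = {
    "model": "gpt-4o",
    "num_tasks": 2,
    "agent": "navi",
    "task_selection": "task_ids",
    "task_ids": ["task_001", "task_015"],
}

expected_cmd = [
    "uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli",
    "vm", "run-waa",
    "--num-tasks", "2",
    "--model", "gpt-4o",
    "--agent", "navi",
    "--no-open",
    "--task-ids", "task_001,task_015",
]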
         def do_OPTIONS(self):
             # Handle CORS preflight
             self.send_response(200)
@@ -564,7 +2083,11 @@ def cmd_serve(args: argparse.Namespace) -> int:
             self.send_header('Access-Control-Allow-Headers', 'Content-Type')
             self.end_headers()
 
-    with socketserver.TCPServer(("", port), StopHandler) as httpd:
+    class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
+        allow_reuse_address = True
+        daemon_threads = True  # Don't block shutdown
+
+    with ThreadedTCPServer(("", port), StopHandler) as httpd:
         url = f"http://localhost:{port}/{start_page}"
         print(f"\nServing at: {url}")
         print(f"Directory: {serve_dir}")
@@ -612,6 +2135,36 @@ def cmd_viewer(args: argparse.Namespace) -> int:
         state.losses = data.get("losses", [])
         state.status = data.get("status", "completed")
         state.elapsed_time = data.get("elapsed_time", 0.0)  # Load elapsed time for completed training
+        state.goal = data.get("goal", "")
+        state.config_path = data.get("config_path", "")
+        state.capture_path = data.get("capture_path", "")
+
+        # Load model config from training_log.json or fall back to reading config file
+        state.model_name = data.get("model_name", "")
+        state.lora_r = data.get("lora_r", 0)
+        state.lora_alpha = data.get("lora_alpha", 0)
+        state.load_in_4bit = data.get("load_in_4bit", False)
+
+        # If model config not in JSON, try to read from config file
+        if not state.model_name and state.config_path:
+            try:
+                import yaml
+                # Try relative to project root first, then as absolute path
+                project_root = Path(__file__).parent.parent.parent
+                config_file = project_root / state.config_path
+                if not config_file.exists():
+                    config_file = Path(state.config_path)
+                if config_file.exists():
+                    with open(config_file) as cf:
+                        cfg = yaml.safe_load(cf)
+                    if cfg and "model" in cfg:
+                        state.model_name = cfg["model"].get("name", "")
+                        state.load_in_4bit = cfg["model"].get("load_in_4bit", False)
+                    if cfg and "lora" in cfg:
+                        state.lora_r = cfg["lora"].get("r", 0)
+                        state.lora_alpha = cfg["lora"].get("lora_alpha", 0)
+            except Exception as e:
+                print(f" Warning: Could not read config file: {e}")
 
         config = TrainingConfig(
             num_train_epochs=data.get("total_epochs", 5),
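The fallback reader above expects a training config with "model" and "lora" sections. A minimal sketch of a YAML file it could parse follows; the model name and hyperparameter values are examples for illustration, not a config shipped with the package:

# Parse an example config of the shape the fallback loader reads.
import yaml

example_config = """
model:
  name: Qwen/Qwen2.5-VL-7B-Instruct   # example model name
  load_in_4bit: true
lora:
  r: 16
  lora_alpha: 32
"""

cfg = yaml.safe_load(example_config)
print(cfg["model"]["name"])                          # -> state.model_name
print(cfg["model"]["load_in_4bit"])                  # -> state.load_in_4bit
print(cfg["lora"]["r"], cfg["lora"]["lora_alpha"])   # -> state.lora_r, state.lora_alpha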
@@ -757,6 +2310,7 @@ Examples:
     p_serve.add_argument("--no-regenerate", action="store_true",
                          help="Skip regenerating dashboard/viewer (serve existing files)")
     p_serve.add_argument("--benchmark", help="Serve benchmark results directory instead of training output")
+    p_serve.add_argument("--start-page", help="Override default start page (e.g., benchmark.html)")
     p_serve.set_defaults(func=cmd_serve)
 
     # viewer
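A sketch of invoking the serve command with the new --start-page flag, mirroring how _start_benchmark_run shells out to the CLI above. The subcommand name "serve" is inferred from p_serve/cmd_serve and the results directory is a placeholder, so adjust both to the actual registered names:

# Hypothetical invocation of the serve command with the new flag.
import subprocess

subprocess.run([
    "uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli",
    "serve",                              # inferred subcommand name
    "--benchmark", "benchmark_results",   # placeholder results directory
    "--start-page", "benchmark.html",
])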