openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff shows the content changes between package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- openadapt_ml/benchmarks/__init__.py +8 -0
- openadapt_ml/benchmarks/agent.py +90 -11
- openadapt_ml/benchmarks/azure.py +35 -6
- openadapt_ml/benchmarks/cli.py +4449 -201
- openadapt_ml/benchmarks/live_tracker.py +180 -0
- openadapt_ml/benchmarks/runner.py +41 -4
- openadapt_ml/benchmarks/viewer.py +1219 -0
- openadapt_ml/benchmarks/vm_monitor.py +610 -0
- openadapt_ml/benchmarks/waa.py +61 -4
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/benchmarks/waa_live.py +619 -0
- openadapt_ml/cloud/local.py +1555 -1
- openadapt_ml/cloud/ssh_tunnel.py +553 -0
- openadapt_ml/datasets/next_action.py +87 -68
- openadapt_ml/evals/grounding.py +26 -8
- openadapt_ml/evals/trajectory_matching.py +84 -36
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +717 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +265 -0
- openadapt_ml/ingest/__init__.py +3 -4
- openadapt_ml/ingest/capture.py +89 -81
- openadapt_ml/ingest/loader.py +116 -68
- openadapt_ml/ingest/synthetic.py +221 -159
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +817 -0
- openadapt_ml/retrieval/embeddings.py +629 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +160 -0
- openadapt_ml/runtime/policy.py +10 -10
- openadapt_ml/schema/__init__.py +104 -0
- openadapt_ml/schema/converters.py +541 -0
- openadapt_ml/schema/episode.py +457 -0
- openadapt_ml/scripts/compare.py +26 -16
- openadapt_ml/scripts/eval_policy.py +4 -5
- openadapt_ml/scripts/prepare_synthetic.py +14 -17
- openadapt_ml/scripts/train.py +81 -70
- openadapt_ml/training/benchmark_viewer.py +3225 -0
- openadapt_ml/training/trainer.py +120 -363
- openadapt_ml/training/trl_trainer.py +354 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
- openadapt_ml-0.2.0.dist-info/RECORD +86 -0
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/cloud/local.py
CHANGED
@@ -36,6 +36,8 @@ import webbrowser
 from pathlib import Path
 from typing import Any
 
+from openadapt_ml.cloud.ssh_tunnel import get_tunnel_manager
+
 # Training output directory
 TRAINING_OUTPUT = Path("training_output")
 
@@ -143,6 +145,16 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
     # No real benchmark data - generate empty state viewer
     try:
         generate_empty_benchmark_viewer(benchmark_html_path)
+
+        # Still create symlink for azure_jobs.json access (even without real benchmarks)
+        if benchmark_results_dir.exists():
+            benchmark_results_link = output_dir / "benchmark_results"
+            if benchmark_results_link.is_symlink():
+                benchmark_results_link.unlink()
+            elif benchmark_results_link.exists():
+                shutil.rmtree(benchmark_results_link)
+            benchmark_results_link.symlink_to(benchmark_results_dir.absolute())
+
         print(" Generated benchmark viewer: No real evaluation data yet")
         return True
     except Exception as e:
@@ -168,6 +180,14 @@ def _regenerate_benchmark_viewer_if_available(output_dir: Path) -> bool:
             tasks_dst = benchmark_tasks_dir / benchmark_dir.name
             shutil.copytree(tasks_src, tasks_dst)
 
+        # Create symlink for benchmark_results directory (for azure_jobs.json access)
+        benchmark_results_link = output_dir / "benchmark_results"
+        if benchmark_results_link.is_symlink():
+            benchmark_results_link.unlink()
+        elif benchmark_results_link.exists():
+            shutil.rmtree(benchmark_results_link)
+        benchmark_results_link.symlink_to(benchmark_results_dir.absolute())
+
         print(f" Regenerated benchmark viewer with {len(real_benchmarks)} run(s)")
         return True
     except Exception as e:
@@ -438,6 +458,10 @@ def cmd_serve(args: argparse.Namespace) -> int:
 
     start_page = "dashboard.html"
 
+    # Override start page if specified
+    if hasattr(args, 'start_page') and args.start_page:
+        start_page = args.start_page
+
     # Serve from the specified directory
     os.chdir(serve_dir)
 
@@ -535,6 +559,42 @@ def cmd_serve(args: argparse.Namespace) -> int:
                }))
 
                threading.Thread(target=run_benchmark, daemon=True).start()
+            elif self.path == '/api/vms/register':
+                # Register a new VM
+                content_length = int(self.headers.get('Content-Length', 0))
+                body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+                try:
+                    vm_data = json.loads(body)
+                    result = self._register_vm(vm_data)
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(result).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e)}).encode())
+            elif self.path == '/api/benchmark/start':
+                # Start a benchmark run with configurable parameters
+                content_length = int(self.headers.get('Content-Length', 0))
+                body = self.rfile.read(content_length).decode('utf-8') if content_length else '{}'
+                try:
+                    params = json.loads(body)
+                    result = self._start_benchmark_run(params)
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(result).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e)}).encode())
             else:
                 self.send_error(404, "Not found")
 
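Usage note (not part of the package): the two POST routes above parse the request body as JSON and hand it to _register_vm / _start_benchmark_run, which are not shown in this hunk, so the accepted fields below are assumptions. A minimal client sketch, assuming the dashboard server is listening on localhost:8000 (the serve port is not specified in this excerpt):

import json
from urllib import request

BASE = "http://localhost:8000"  # assumed serve port, not confirmed by this diff

def post_json(path: str, payload: dict) -> dict:
    # POST a JSON body and decode the JSON response, mirroring the handlers above.
    req = request.Request(
        BASE + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with request.urlopen(req, timeout=10) as resp:
        return json.loads(resp.read().decode("utf-8"))

# Hypothetical payloads; the accepted keys depend on _register_vm / _start_benchmark_run.
print(post_json("/api/vms/register", {"ssh_host": "1.2.3.4", "ssh_user": "azureuser"}))
print(post_json("/api/benchmark/start", {"provider": "openai", "num_tasks": 5}))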
@@ -552,10 +612,1469 @@ def cmd_serve(args: argparse.Namespace) -> int:
                self.send_header('Access-Control-Allow-Origin', '*')
                self.end_headers()
                self.wfile.write(progress.encode())
+            elif self.path.startswith('/api/benchmark-live'):
+                # Return live evaluation state
+                live_file = Path("benchmark_live.json")  # Relative to serve_dir (cwd)
+                if live_file.exists():
+                    live_state = live_file.read_text()
+                else:
+                    live_state = json.dumps({"status": "idle"})
+
+                self.send_response(200)
+                self.send_header('Content-Type', 'application/json')
+                self.send_header('Access-Control-Allow-Origin', '*')
+                self.end_headers()
+                self.wfile.write(live_state.encode())
+            elif self.path.startswith('/api/tasks'):
+                # Return background task status (VM, Docker, benchmarks)
+                try:
+                    tasks = self._fetch_background_tasks()
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(tasks).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e)}).encode())
+            elif self.path.startswith('/api/azure-jobs'):
+                # Return LIVE Azure job status from Azure ML
+                # Supports ?force=true parameter for manual refresh (always fetches live)
+                try:
+                    from urllib.parse import urlparse, parse_qs
+                    query = parse_qs(urlparse(self.path).query)
+                    force_refresh = query.get('force', ['false'])[0].lower() == 'true'
+
+                    # Always fetch live data (force just indicates manual refresh for logging)
+                    if force_refresh:
+                        print("Azure Jobs: Manual refresh requested")
+
+                    jobs = self._fetch_live_azure_jobs()
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(jobs).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e)}).encode())
+            elif self.path.startswith('/api/benchmark-sse'):
+                # Server-Sent Events endpoint for real-time benchmark updates
+                try:
+                    from urllib.parse import urlparse, parse_qs
+                    query = parse_qs(urlparse(self.path).query)
+                    interval = int(query.get('interval', [5])[0])
+
+                    # Validate interval (min 1s, max 60s)
+                    interval = max(1, min(60, interval))
+
+                    self._stream_benchmark_updates(interval)
+                except Exception as e:
+                    self.send_error(500, f"SSE error: {e}")
+            elif self.path.startswith('/api/vms'):
+                # Return VM registry with live status
+                try:
+                    vms = self._fetch_vm_registry()
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(vms).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e)}).encode())
+            elif self.path.startswith('/api/azure-job-logs'):
+                # Return live logs for running Azure job
+                try:
+                    # Parse job_id from query string
+                    from urllib.parse import urlparse, parse_qs
+                    query = parse_qs(urlparse(self.path).query)
+                    job_id = query.get('job_id', [None])[0]
+
+                    logs = self._fetch_azure_job_logs(job_id)
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(logs).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e)}).encode())
+            elif self.path.startswith('/api/probe-vm'):
+                # Probe the VM to check if WAA server is responding
+                try:
+                    result = self._probe_vm()
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(result).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e), "responding": False}).encode())
+            elif self.path.startswith('/api/tunnels'):
+                # Return SSH tunnel status
+                try:
+                    tunnel_mgr = get_tunnel_manager()
+                    status = tunnel_mgr.get_tunnel_status()
+                    result = {
+                        name: {
+                            "active": s.active,
+                            "local_port": s.local_port,
+                            "remote_endpoint": s.remote_endpoint,
+                            "pid": s.pid,
+                            "error": s.error,
+                        }
+                        for name, s in status.items()
+                    }
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(result).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e)}).encode())
+            elif self.path.startswith('/api/current-run'):
+                # Return currently running benchmark info
+                try:
+                    result = self._get_current_run()
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(result).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e), "running": False}).encode())
+            elif self.path.startswith('/api/background-tasks'):
+                # Alias for /api/tasks - background task status
+                try:
+                    tasks = self._fetch_background_tasks()
+                    self.send_response(200)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps(tasks).encode())
+                except Exception as e:
+                    self.send_response(500)
+                    self.send_header('Content-Type', 'application/json')
+                    self.send_header('Access-Control-Allow-Origin', '*')
+                    self.end_headers()
+                    self.wfile.write(json.dumps({"error": str(e)}).encode())
             else:
                 # Default file serving
                 super().do_GET()
 
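Usage note (not part of the package): the GET routes added above return JSON, except /api/benchmark-sse, which streams Server-Sent Events whose event names match the _stream_benchmark_updates docstring later in this diff. A polling/streaming sketch under the same localhost:8000 assumption as the earlier example:

import json
from urllib import request

BASE = "http://localhost:8000"  # assumed serve port, not confirmed by this diff

# Poll a JSON endpoint.
with request.urlopen(BASE + "/api/tasks", timeout=10) as resp:
    for task in json.loads(resp.read().decode("utf-8")):
        print(task.get("task_id"), task.get("status"), task.get("description"))

# Consume the SSE stream; each event is an "event:" line followed by a "data:" line.
with request.urlopen(BASE + "/api/benchmark-sse?interval=5", timeout=60) as stream:
    for raw in stream:
        line = raw.decode("utf-8").strip()
        if line.startswith("data:"):
            print(json.loads(line[len("data:"):].strip()))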
+        def _fetch_live_azure_jobs(self):
+            """Fetch live job status from Azure ML."""
+            import subprocess
+            result = subprocess.run(
+                ["az", "ml", "job", "list",
+                 "--resource-group", "openadapt-agents",
+                 "--workspace-name", "openadapt-ml",
+                 "--query", "[].{name:name,display_name:display_name,status:status,creation_context:creation_context.created_at}",
+                 "-o", "json"],
+                capture_output=True, text=True, timeout=30
+            )
+            if result.returncode != 0:
+                raise Exception(f"Azure CLI error: {result.stderr}")
+
+            jobs = json.loads(result.stdout)
+            # Format for frontend
+            experiment_id = "ad29082c-0607-4fda-8cc7-38944eb5a518"
+            wsid = "/subscriptions/78add6c6-c92a-4a53-b751-eb644ac77e59/resourceGroups/openadapt-agents/providers/Microsoft.MachineLearningServices/workspaces/openadapt-ml"
+
+            formatted = []
+            for job in jobs[:10]:  # Limit to 10 most recent
+                formatted.append({
+                    "job_id": job.get("name", "unknown"),
+                    "display_name": job.get("display_name", ""),
+                    "status": job.get("status", "unknown").lower(),
+                    "started_at": job.get("creation_context", ""),
+                    "azure_dashboard_url": f"https://ml.azure.com/experiments/id/{experiment_id}/runs/{job.get('name', '')}?wsid={wsid}",
+                    "is_live": True  # Flag to indicate this is live data
+                })
+            return formatted
+
+        def _fetch_azure_job_logs(self, job_id: str | None):
+            """Fetch logs for an Azure ML job (streaming for running jobs)."""
+            import subprocess
+
+            if not job_id:
+                # Get the most recent running job
+                jobs = self._fetch_live_azure_jobs()
+                running = [j for j in jobs if j['status'] == 'running']
+                if running:
+                    job_id = running[0]['job_id']
+                else:
+                    return {"logs": "No running jobs found", "job_id": None, "status": "idle"}
+
+            # Try to stream logs for running job using az ml job stream
+            try:
+                result = subprocess.run(
+                    ["az", "ml", "job", "stream",
+                     "--name", job_id,
+                     "--resource-group", "openadapt-agents",
+                     "--workspace-name", "openadapt-ml"],
+                    capture_output=True, text=True, timeout=3  # Short timeout
+                )
+                if result.returncode == 0 and result.stdout.strip():
+                    return {"logs": result.stdout[-5000:], "job_id": job_id, "status": "streaming"}
+            except subprocess.TimeoutExpired:
+                pass  # Fall through to job show
+
+            # Get job details instead
+            result = subprocess.run(
+                ["az", "ml", "job", "show",
+                 "--name", job_id,
+                 "--resource-group", "openadapt-agents",
+                 "--workspace-name", "openadapt-ml",
+                 "-o", "json"],
+                capture_output=True, text=True, timeout=10
+            )
+
+            if result.returncode == 0:
+                job_info = json.loads(result.stdout)
+                return {
+                    "logs": f"Job {job_id} is {job_info.get('status', 'unknown')}\\n\\nCommand: {job_info.get('command', 'N/A')}",
+                    "job_id": job_id,
+                    "status": job_info.get('status', 'unknown').lower(),
+                    "command": job_info.get('command', '')
+                }
+
+            return {"logs": f"Could not fetch logs: {result.stderr}", "job_id": job_id, "status": "error"}
+
+        def _get_vm_detailed_metadata(self, vm_ip: str, container_name: str, logs: str, phase: str) -> dict:
+            """Get detailed VM metadata for the VM Details panel.
+
+            Returns:
+                dict with disk_usage_gb, memory_usage_mb, setup_script_phase, probe_response, qmp_connected, dependencies
+            """
+            import subprocess
+            import re
+
+            metadata = {
+                "disk_usage_gb": None,
+                "memory_usage_mb": None,
+                "setup_script_phase": None,
+                "probe_response": None,
+                "qmp_connected": False,
+                "dependencies": []
+            }
+
+            # 1. Get disk usage from docker stats
+            try:
+                disk_result = subprocess.run(
+                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     f"docker exec {container_name} df -h /storage 2>/dev/null | tail -1"],
+                    capture_output=True, text=True, timeout=10
+                )
+                if disk_result.returncode == 0 and disk_result.stdout.strip():
+                    # Parse: "Filesystem Size Used Avail Use% Mounted on"
+                    # Example: "/dev/sda1 30G 9.2G 20G 31% /storage"
+                    parts = disk_result.stdout.split()
+                    if len(parts) >= 3:
+                        used_str = parts[2]  # e.g., "9.2G"
+                        total_str = parts[1]  # e.g., "30G"
+                        # Convert to GB (handle M/G suffixes)
+                        def to_gb(s):
+                            if s.endswith('G'):
+                                return float(s[:-1])
+                            elif s.endswith('M'):
+                                return float(s[:-1]) / 1024
+                            elif s.endswith('K'):
+                                return float(s[:-1]) / (1024 * 1024)
+                            return 0
+                        metadata["disk_usage_gb"] = f"{to_gb(used_str):.1f} GB / {to_gb(total_str):.0f} GB used"
+            except Exception:
+                pass
+
+            # 2. Get memory usage from docker stats
+            try:
+                mem_result = subprocess.run(
+                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     f"docker stats {container_name} --no-stream --format '{{{{.MemUsage}}}}'"],
+                    capture_output=True, text=True, timeout=10
+                )
+                if mem_result.returncode == 0 and mem_result.stdout.strip():
+                    # Example: "1.5GiB / 4GiB"
+                    metadata["memory_usage_mb"] = mem_result.stdout.strip()
+            except Exception:
+                pass
+
+            # 3. Parse setup script phase from logs
+            metadata["setup_script_phase"] = self._parse_setup_phase_from_logs(logs, phase)
+
+            # 4. Check /probe endpoint
+            try:
+                probe_result = subprocess.run(
+                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     "curl -s --connect-timeout 2 http://20.20.20.21:5000/probe 2>/dev/null"],
+                    capture_output=True, text=True, timeout=10
+                )
+                if probe_result.returncode == 0 and probe_result.stdout.strip():
+                    metadata["probe_response"] = probe_result.stdout.strip()
+                else:
+                    metadata["probe_response"] = "Not responding"
+            except Exception:
+                metadata["probe_response"] = "Connection failed"
+
+            # 5. Check QMP connection (port 7200)
+            try:
+                qmp_result = subprocess.run(
+                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     "nc -z -w2 localhost 7200 2>&1"],
+                    capture_output=True, text=True, timeout=10
+                )
+                metadata["qmp_connected"] = qmp_result.returncode == 0
+            except Exception:
+                pass
+
+            # 6. Parse dependencies from logs
+            metadata["dependencies"] = self._parse_dependencies_from_logs(logs, phase)
+
+            return metadata
+
+        def _parse_setup_phase_from_logs(self, logs: str, current_phase: str) -> str:
+            """Parse the current setup script phase from logs.
+
+            Looks for patterns indicating which script is running:
+            - install.bat
+            - setup.ps1
+            - on-logon.ps1
+            """
+            if current_phase == "ready":
+                return "Setup complete"
+            elif current_phase == "oobe":
+                # Check for specific script patterns
+                if "on-logon.ps1" in logs.lower():
+                    return "Running on-logon.ps1"
+                elif "setup.ps1" in logs.lower():
+                    return "Running setup.ps1"
+                elif "install.bat" in logs.lower():
+                    return "Running install.bat"
+                else:
+                    return "Windows installation in progress"
+            elif current_phase == "booting":
+                return "Booting Windows"
+            elif current_phase in ["downloading", "extracting", "configuring", "building"]:
+                return "Preparing Windows VM"
+            else:
+                return "Initializing..."
+
+        def _parse_dependencies_from_logs(self, logs: str, phase: str) -> list[dict]:
+            """Parse dependency installation status from logs.
+
+            Returns list of dependencies with their installation status:
+            - Python
+            - Chrome
+            - LibreOffice
+            - VSCode
+            - etc.
+            """
+            dependencies = [
+                {"name": "Python", "icon": "🐍", "status": "pending"},
+                {"name": "Chrome", "icon": "🌐", "status": "pending"},
+                {"name": "LibreOffice", "icon": "📝", "status": "pending"},
+                {"name": "VSCode", "icon": "💻", "status": "pending"},
+                {"name": "WAA Server", "icon": "🔧", "status": "pending"},
+            ]
+
+            if phase not in ["oobe", "ready"]:
+                # Not yet at Windows setup phase
+                return dependencies
+
+            logs_lower = logs.lower()
+
+            # Check for installation patterns
+            if "python" in logs_lower and ("installing python" in logs_lower or "python.exe" in logs_lower):
+                dependencies[0]["status"] = "installing"
+            elif "python" in logs_lower and "installed" in logs_lower:
+                dependencies[0]["status"] = "complete"
+
+            if "chrome" in logs_lower and ("downloading" in logs_lower or "installing" in logs_lower):
+                dependencies[1]["status"] = "installing"
+            elif "chrome" in logs_lower and "installed" in logs_lower:
+                dependencies[1]["status"] = "complete"
+
+            if "libreoffice" in logs_lower and ("downloading" in logs_lower or "installing" in logs_lower):
+                dependencies[2]["status"] = "installing"
+            elif "libreoffice" in logs_lower and "installed" in logs_lower:
+                dependencies[2]["status"] = "complete"
+
+            if "vscode" in logs_lower or "visual studio code" in logs_lower:
+                if "installing" in logs_lower:
+                    dependencies[3]["status"] = "installing"
+                elif "installed" in logs_lower:
+                    dependencies[3]["status"] = "complete"
+
+            if "waa" in logs_lower or "flask" in logs_lower:
+                if "starting" in logs_lower or "running" in logs_lower:
+                    dependencies[4]["status"] = "installing"
+                elif phase == "ready":
+                    dependencies[4]["status"] = "complete"
+
+            return dependencies
+
+        def _fetch_background_tasks(self):
+            """Fetch status of all background tasks: Azure VM, Docker containers, benchmarks."""
+            import subprocess
+            from datetime import datetime
+            import time
+
+            tasks = []
+
+            # Check for VM IP from environment (set by CLI when auto-launching viewer)
+            env_vm_ip = os.environ.get("WAA_VM_IP")
+            env_internal_ip = os.environ.get("WAA_INTERNAL_IP", "172.30.0.2")
+
+            # 1. Check Azure WAA VM status
+            vm_ip = None
+            if env_vm_ip:
+                # Use environment variable - VM IP was provided directly
+                vm_ip = env_vm_ip
+                tasks.append({
+                    "task_id": "azure-vm-waa",
+                    "task_type": "vm_provision",
+                    "status": "completed",
+                    "phase": "ready",  # Match status to prevent "Starting" + "completed" conflict
+                    "title": "Azure VM Host",
+                    "description": f"Linux host running at {vm_ip}",
+                    "progress_percent": 100.0,
+                    "elapsed_seconds": 0,
+                    "metadata": {
+                        "vm_name": "waa-eval-vm",
+                        "ip_address": vm_ip,
+                        "internal_ip": env_internal_ip
+                    }
+                })
+            else:
+                # Query Azure CLI for VM status
+                try:
+                    result = subprocess.run(
+                        ["az", "vm", "get-instance-view",
+                         "--name", "waa-eval-vm",
+                         "--resource-group", "openadapt-agents",
+                         "--query", "instanceView.statuses",
+                         "-o", "json"],
+                        capture_output=True, text=True, timeout=10
+                    )
+                    if result.returncode == 0:
+                        statuses = json.loads(result.stdout)
+                        power_state = "unknown"
+                        for s in statuses:
+                            if s.get("code", "").startswith("PowerState/"):
+                                power_state = s["code"].replace("PowerState/", "")
+
+                        # Get VM IP
+                        ip_result = subprocess.run(
+                            ["az", "vm", "list-ip-addresses",
+                             "--name", "waa-eval-vm",
+                             "--resource-group", "openadapt-agents",
+                             "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+                             "-o", "tsv"],
+                            capture_output=True, text=True, timeout=10
+                        )
+                        vm_ip = ip_result.stdout.strip() if ip_result.returncode == 0 else None
+
+                        if power_state == "running":
+                            tasks.append({
+                                "task_id": "azure-vm-waa",
+                                "task_type": "vm_provision",
+                                "status": "completed",
+                                "phase": "ready",  # Match status to prevent "Starting" + "completed" conflict
+                                "title": "Azure VM Host",
+                                "description": f"Linux host running at {vm_ip}" if vm_ip else "Linux host running",
+                                "progress_percent": 100.0,
+                                "elapsed_seconds": 0,
+                                "metadata": {
+                                    "vm_name": "waa-eval-vm",
+                                    "ip_address": vm_ip
+                                    # No VNC link - that's for the Windows container
+                                }
+                            })
+                except subprocess.TimeoutExpired:
+                    pass
+                except Exception:
+                    pass
+
+            # 2. Check Docker container status on VM (if we have an IP)
+            if vm_ip:
+                try:
+                    docker_result = subprocess.run(
+                        ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         "docker ps --format '{{.Names}}|{{.Status}}|{{.Image}}'"],
+                        capture_output=True, text=True, timeout=15
+                    )
+                    if docker_result.returncode == 0 and docker_result.stdout.strip():
+                        for line in docker_result.stdout.strip().split('\n'):
+                            parts = line.split('|')
+                            if len(parts) >= 3:
+                                container_name, status, image = parts[0], parts[1], parts[2]
+                                # Parse "Up X minutes" to determine if healthy
+                                is_healthy = "Up" in status
+
+                                # Check for Windows VM specifically
+                                if "windows" in image.lower() or container_name == "winarena":
+                                    # Get detailed progress from docker logs
+                                    log_check = subprocess.run(
+                                        ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                                         f"azureuser@{vm_ip}",
+                                         f"docker logs {container_name} 2>&1 | tail -30"],
+                                        capture_output=True, text=True, timeout=10
+                                    )
+                                    logs = log_check.stdout if log_check.returncode == 0 else ""
+
+                                    # Parse progress from logs
+                                    phase = "unknown"
+                                    progress = 0.0
+                                    description = "Starting..."
+
+                                    if "Windows started successfully" in logs:
+                                        # Check if WAA server is ready via Docker port forwarding
+                                        # See docs/waa_network_architecture.md - always use localhost
+                                        server_check = subprocess.run(
+                                            ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                                             "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                                             f"azureuser@{vm_ip}",
+                                             "curl -s --connect-timeout 2 http://localhost:5000/probe 2>/dev/null"],
+                                            capture_output=True, text=True, timeout=10
+                                        )
+                                        waa_ready = server_check.returncode == 0 and "Service is operational" in server_check.stdout
+                                        if waa_ready:
+                                            phase = "ready"
+                                            progress = 100.0
+                                            description = "WAA Server ready - benchmarks can run"
+                                        else:
+                                            phase = "oobe"
+                                            progress = 80.0  # Phase 5/6 - VM install in progress
+                                            description = "Phase 5/6: Windows installing (check VNC for %). OEM scripts will run after."
+                                    elif "Booting Windows" in logs:
+                                        phase = "booting"
+                                        progress = 70.0  # Phase 4/6
+                                        description = "Phase 4/6: Booting Windows from installer..."
+                                    elif "Building Windows" in logs or "Creating a" in logs:
+                                        phase = "building"
+                                        progress = 60.0  # Phase 3/6
+                                        description = "Phase 3/6: Building Windows VM disk..."
+                                    elif "Adding" in logs and "image" in logs:
+                                        phase = "configuring"
+                                        progress = 50.0  # Phase 2/6
+                                        description = "Phase 2/6: Configuring Windows image with WAA scripts..."
+                                    elif "Extracting" in logs:
+                                        phase = "extracting"
+                                        progress = 35.0  # Phase 1/6 (after download)
+                                        description = "Phase 1/6: Extracting Windows ISO..."
+                                    else:
+                                        # Check for download progress (e.g., "1234K ........ 45% 80M 30s")
+                                        import re
+                                        download_match = re.search(r'(\d+)%\s+[\d.]+[KMG]\s+(\d+)s', logs)
+                                        if download_match:
+                                            phase = "downloading"
+                                            dl_pct = float(download_match.group(1))
+                                            progress = dl_pct * 0.30  # 0-30% for download phase
+                                            eta = download_match.group(2)
+                                            description = f"Phase 0/6: Downloading Windows 11... {download_match.group(1)}% ({eta}s left)"
+
+                                    # Improve phase detection - if Windows is booted but WAA not ready,
+                                    # it might be at login screen waiting for OEM scripts or running install.bat
+                                    if phase == "oobe" and "Boot0004" in logs:
+                                        # Windows finished installing, at login/desktop
+                                        # install.bat should auto-run from FirstLogonCommands (see Dockerfile)
+                                        description = "Phase 5/6: Windows at desktop, OEM scripts running... (WAA server starting)"
+                                        progress = 90.0
+
+                                    # Get detailed metadata for VM Details panel
+                                    vm_metadata = self._get_vm_detailed_metadata(vm_ip, container_name, logs, phase)
+
+                                    tasks.append({
+                                        "task_id": f"docker-{container_name}",
+                                        "task_type": "docker_container",
+                                        "status": "completed" if phase == "ready" else "running",
+                                        "title": "Windows 11 + WAA Server",
+                                        "description": description,
+                                        "progress_percent": progress,
+                                        "elapsed_seconds": 0,
+                                        "phase": phase,
+                                        "metadata": {
+                                            "container": container_name,
+                                            "image": image,
+                                            "status": status,
+                                            "phase": phase,
+                                            "windows_ready": phase in ["oobe", "ready"],
+                                            "waa_server_ready": phase == "ready",
+                                            # Use localhost - SSH tunnel handles routing to VM
+                                            # See docs/waa_network_architecture.md
+                                            "vnc_url": "http://localhost:8006",
+                                            "windows_username": "Docker",
+                                            "windows_password": "admin",
+                                            "recent_logs": logs[-500:] if logs else "",
+                                            # Enhanced VM details
+                                            "disk_usage_gb": vm_metadata["disk_usage_gb"],
+                                            "memory_usage_mb": vm_metadata["memory_usage_mb"],
+                                            "setup_script_phase": vm_metadata["setup_script_phase"],
+                                            "probe_response": vm_metadata["probe_response"],
+                                            "qmp_connected": vm_metadata["qmp_connected"],
+                                            "dependencies": vm_metadata["dependencies"],
+                                        }
+                                    })
+                except Exception as e:
+                    # SSH failed, VM might still be starting
+                    pass
+
+            # 3. Check local benchmark progress
+            progress_file = Path("benchmark_progress.json")
+            if progress_file.exists():
+                try:
+                    progress = json.loads(progress_file.read_text())
+                    if progress.get("status") == "running":
+                        tasks.append({
+                            "task_id": "benchmark-local",
+                            "task_type": "benchmark_run",
+                            "status": "running",
+                            "title": f"{progress.get('provider', 'API').upper()} Benchmark",
+                            "description": progress.get("message", "Running benchmark..."),
+                            "progress_percent": (progress.get("tasks_complete", 0) / max(progress.get("tasks_total", 1), 1)) * 100,
+                            "elapsed_seconds": 0,
+                            "metadata": progress
+                        })
+                except Exception:
+                    pass
+
+            return tasks
+
+        def _fetch_vm_registry(self):
+            """Fetch VM registry with live status checks."""
+            import subprocess
+            from datetime import datetime
+
+            # Path to VM registry file (relative to project root)
+            project_root = Path(__file__).parent.parent.parent
+            registry_file = project_root / "benchmark_results" / "vm_registry.json"
+
+            if not registry_file.exists():
+                return []
+
+            try:
+                with open(registry_file) as f:
+                    vms = json.load(f)
+            except Exception as e:
+                return {"error": f"Failed to read VM registry: {e}"}
+
+            # Check status for each VM
+            for vm in vms:
+                vm["status"] = "unknown"
+                vm["last_checked"] = datetime.now().isoformat()
+                vm["vnc_reachable"] = False
+                vm["waa_probe_status"] = "unknown"
+
+                # Check VNC (HTTP HEAD request)
+                try:
+                    vnc_url = f"http://{vm['ssh_host']}:{vm['vnc_port']}"
+                    result = subprocess.run(
+                        ["curl", "-I", "-s", "--connect-timeout", "3", vnc_url],
+                        capture_output=True, text=True, timeout=5
+                    )
+                    if result.returncode == 0 and "200" in result.stdout:
+                        vm["vnc_reachable"] = True
+                except Exception:
+                    pass
+
+                # Check WAA probe via SSH
+                # Probe WAA via localhost (Docker port forwarding handles routing)
+                # See docs/waa_network_architecture.md for architecture details
+                try:
+                    waa_port = vm.get("waa_port", 5000)
+                    ssh_cmd = f"curl -s --connect-timeout 2 http://localhost:{waa_port}/probe 2>/dev/null"
+                    result = subprocess.run(
+                        ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=3",
+                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                         f"{vm['ssh_user']}@{vm['ssh_host']}",
+                         ssh_cmd],
+                        capture_output=True, text=True, timeout=5
+                    )
+                    probe_success = result.returncode == 0 and "Service is operational" in result.stdout
+                    if probe_success:
+                        vm["waa_probe_status"] = "ready"
+                        vm["status"] = "online"
+                        # Auto-start SSH tunnels for VNC and WAA
+                        try:
+                            tunnel_mgr = get_tunnel_manager()
+                            tunnel_status = tunnel_mgr.ensure_tunnels_for_vm(
+                                vm_ip=vm["ssh_host"],
+                                ssh_user=vm.get("ssh_user", "azureuser"),
+                            )
+                            vm["tunnels"] = {
+                                name: {"active": s.active, "local_port": s.local_port, "error": s.error}
+                                for name, s in tunnel_status.items()
+                            }
+                        except Exception as e:
+                            vm["tunnels"] = {"error": str(e)}
+                    else:
+                        vm["waa_probe_status"] = "not responding"
+                        vm["status"] = "offline"
+                        # Stop tunnels when VM goes offline
+                        try:
+                            tunnel_mgr = get_tunnel_manager()
+                            tunnel_mgr.stop_all_tunnels()
+                            vm["tunnels"] = {}
+                        except Exception:
+                            pass
+                except Exception:
+                    vm["waa_probe_status"] = "ssh failed"
+                    vm["status"] = "offline"
+
+            return vms
+
+        def _probe_vm(self) -> dict:
+            """Probe the Azure VM to check if WAA server is responding.
+
+            Returns:
+                dict with:
+                - responding: bool - whether the WAA server is responding
+                - vm_ip: str - the VM's IP address
+                - container: str - the container name
+                - probe_result: str - the raw probe response or error message
+                - last_checked: str - ISO timestamp
+            """
+            import subprocess
+            from datetime import datetime
+
+            result = {
+                "responding": False,
+                "vm_ip": None,
+                "container": None,
+                "probe_result": None,
+                "last_checked": datetime.now().isoformat(),
+            }
+
+            # First get VM IP
+            try:
+                ip_result = subprocess.run(
+                    ["az", "vm", "list-ip-addresses",
+                     "--name", "waa-eval-vm",
+                     "--resource-group", "openadapt-agents",
+                     "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+                     "-o", "tsv"],
+                    capture_output=True, text=True, timeout=10
+                )
+                if ip_result.returncode == 0 and ip_result.stdout.strip():
+                    vm_ip = ip_result.stdout.strip()
+                    result["vm_ip"] = vm_ip
+
+                    # Try to probe WAA server via SSH
+                    # Use the correct internal IP for the Windows VM inside Docker
+                    probe_result = subprocess.run(
+                        ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         "docker exec waa-container curl -s --connect-timeout 3 http://172.30.0.2:5000/probe 2>/dev/null || echo 'probe_failed'"],
+                        capture_output=True, text=True, timeout=15
+                    )
+
+                    result["container"] = "waa-container"
+
+                    if probe_result.returncode == 0:
+                        probe_output = probe_result.stdout.strip()
+                        if probe_output and "probe_failed" not in probe_output:
+                            result["responding"] = True
+                            result["probe_result"] = probe_output
+                        else:
+                            result["probe_result"] = "WAA server not responding"
+                    else:
+                        result["probe_result"] = f"SSH/Docker error: {probe_result.stderr[:200]}"
+                else:
+                    result["probe_result"] = "Could not get VM IP"
+            except subprocess.TimeoutExpired:
+                result["probe_result"] = "Connection timeout"
+            except Exception as e:
+                result["probe_result"] = f"Error: {str(e)}"
+
+            return result
+
+        def _get_current_run(self) -> dict:
+            """Get info about any currently running benchmark.
+
+            Checks:
+            1. Local benchmark_progress.json for API benchmarks
+            2. Azure VM for WAA benchmarks running via SSH
+
+            Returns:
+                dict with:
+                - running: bool - whether a benchmark is running
+                - type: str - 'local' or 'azure_vm'
+                - model: str - model being evaluated
+                - progress: dict with tasks_completed, total_tasks
+                - current_task: str - current task ID
+                - started_at: str - ISO timestamp
+                - elapsed_minutes: int
+            """
+            import subprocess
+            from datetime import datetime
+            import re
+
+            result = {
+                "running": False,
+                "type": None,
+                "model": None,
+                "progress": {"tasks_completed": 0, "total_tasks": 0},
+                "current_task": None,
+                "started_at": None,
+                "elapsed_minutes": 0,
+            }
+
+            # Check local benchmark progress first
+            progress_file = Path("benchmark_progress.json")
+            if progress_file.exists():
+                try:
+                    progress = json.loads(progress_file.read_text())
+                    if progress.get("status") == "running":
+                        result["running"] = True
+                        result["type"] = "local"
+                        result["model"] = progress.get("provider", "unknown")
+                        result["progress"]["tasks_completed"] = progress.get("tasks_complete", 0)
+                        result["progress"]["total_tasks"] = progress.get("tasks_total", 0)
+                        return result
+                except Exception:
+                    pass
+
+            # Check Azure VM for running benchmark
+            try:
+                # Get VM IP
+                ip_result = subprocess.run(
+                    ["az", "vm", "list-ip-addresses",
+                     "--name", "waa-eval-vm",
+                     "--resource-group", "openadapt-agents",
+                     "--query", "[0].virtualMachine.network.publicIpAddresses[0].ipAddress",
+                     "-o", "tsv"],
+                    capture_output=True, text=True, timeout=10
+                )
+
+                if ip_result.returncode == 0 and ip_result.stdout.strip():
+                    vm_ip = ip_result.stdout.strip()
+
+                    # Check if benchmark process is running
+                    process_check = subprocess.run(
+                        ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         "docker exec waa-container pgrep -f 'python.*run.py' 2>/dev/null && echo 'RUNNING' || echo 'NOT_RUNNING'"],
+                        capture_output=True, text=True, timeout=10
+                    )
+
+                    if process_check.returncode == 0 and "RUNNING" in process_check.stdout:
+                        result["running"] = True
+                        result["type"] = "azure_vm"
+
+                        # Get log file for more details
+                        log_check = subprocess.run(
+                            ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                             "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                             f"azureuser@{vm_ip}",
+                             "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
+                            capture_output=True, text=True, timeout=10
+                        )
+
+                        if log_check.returncode == 0 and log_check.stdout.strip():
+                            logs = log_check.stdout
+
+                            # Parse model from logs
+                            model_match = re.search(r'model[=:\s]+([^\s,]+)', logs, re.IGNORECASE)
+                            if model_match:
+                                result["model"] = model_match.group(1)
+
+                            # Parse progress
+                            task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
+                            if task_match:
+                                result["progress"]["tasks_completed"] = int(task_match.group(1))
+                                result["progress"]["total_tasks"] = int(task_match.group(2))
+
+                            # Parse current task
+                            task_id_match = re.search(r'(?:Running|Processing|task)[:\s]+([a-f0-9-]+)', logs, re.IGNORECASE)
+                            if task_id_match:
+                                result["current_task"] = task_id_match.group(1)
+
+            except Exception:
+                pass
+
+            return result
+
+        async def _detect_running_benchmark(self, vm_ip: str, container_name: str = "winarena") -> dict:
+            """Detect if a benchmark is running on the VM and extract progress.
+
+            SSH into VM and check:
+            1. Process running: docker exec {container} pgrep -f 'python.*run.py'
+            2. Log progress: tail /tmp/waa_benchmark.log
+
+            Returns:
+                dict with:
+                - running: bool
+                - current_task: str (task ID or description)
+                - progress: dict with tasks_completed, total_tasks, current_step
+                - recent_logs: str (last few log lines)
+            """
+            import subprocess
+            import re
+
+            result = {
+                "running": False,
+                "current_task": None,
+                "progress": {
+                    "tasks_completed": 0,
+                    "total_tasks": 0,
+                    "current_step": 0,
+                },
+                "recent_logs": "",
+            }
+
+            try:
+                # Check if benchmark process is running
+                process_check = subprocess.run(
+                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''"],
+                    capture_output=True, text=True, timeout=10
+                )
+
+                if process_check.returncode == 0 and process_check.stdout.strip():
+                    result["running"] = True
+
+                    # Get benchmark log
+                    log_check = subprocess.run(
+                        ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
+                        capture_output=True, text=True, timeout=10
+                    )
+
+                    if log_check.returncode == 0 and log_check.stdout.strip():
+                        logs = log_check.stdout
+                        result["recent_logs"] = logs[-500:]  # Last 500 chars
+
+                        # Parse progress from logs
+                        # Look for patterns like "Task 5/30" or "Completed: 5, Remaining: 25"
+                        task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
+                        if task_match:
+                            result["progress"]["tasks_completed"] = int(task_match.group(1))
+                            result["progress"]["total_tasks"] = int(task_match.group(2))
+
+                        # Extract current task ID
+                        task_id_match = re.search(r'(?:Running|Processing) task:\s*(\S+)', logs)
+                        if task_id_match:
+                            result["current_task"] = task_id_match.group(1)
+
+                        # Extract step info
+                        step_match = re.search(r'Step\s+(\d+)', logs)
+                        if step_match:
+                            result["progress"]["current_step"] = int(step_match.group(1))
+
+            except Exception as e:
+                # SSH or parsing failed - leave defaults
+                pass
+
+            return result
+
+        def _parse_task_result(self, log_lines: list[str], task_id: str) -> dict:
+            """Parse task success/failure from log output.
+
+            WAA log patterns:
+            - Success: "Task task_001 completed successfully"
+            - Success: "Result: PASS"
+            - Failure: "Task task_001 failed"
+            - Failure: "Result: FAIL"
+            - Score: "Score: 0.85"
+            """
+            import re
+
+            success = None
+            score = None
+
+            # Search backwards from most recent
+            for line in reversed(log_lines):
+                # Check for explicit result
+                if 'Result: PASS' in line or 'completed successfully' in line:
+                    success = True
+                elif 'Result: FAIL' in line or 'failed' in line.lower():
+                    success = False
+
+                # Check for score
+                score_match = re.search(r'Score:\s*([\d.]+)', line)
+                if score_match:
+                    try:
+                        score = float(score_match.group(1))
+                    except ValueError:
+                        pass
+
+                # Check for task-specific completion
+                if task_id in line:
+                    if 'success' in line.lower() or 'pass' in line.lower():
+                        success = True
+                    elif 'fail' in line.lower() or 'error' in line.lower():
+                        success = False
+
+            # Default to True if no explicit failure found (backwards compatible)
+            if success is None:
+                success = True
+
+            return {"success": success, "score": score}
+
1656
|
+
+        def _stream_benchmark_updates(self, interval: int):
+            """Stream Server-Sent Events for benchmark status updates.
+
+            Streams events:
+            - connected: Initial connection event
+            - status: VM status and probe results
+            - progress: Benchmark progress (tasks completed, current task)
+            - task_complete: When a task finishes
+            - heartbeat: Keep-alive signal every 30 seconds
+            - error: Error messages
+
+            Uses a generator-based approach to avoid blocking the main thread
+            and properly handles client disconnection.
+            """
+            import time
+            import select
+
+            HEARTBEAT_INTERVAL = 30  # seconds
+
+            # Set SSE headers
+            self.send_response(200)
+            self.send_header('Content-Type', 'text/event-stream')
+            self.send_header('Cache-Control', 'no-cache')
+            self.send_header('Access-Control-Allow-Origin', '*')
+            self.send_header('Connection', 'keep-alive')
+            self.send_header('X-Accel-Buffering', 'no')  # Disable nginx buffering
+            self.end_headers()
+
+            # Track connection state
+            client_connected = True
+
+            def send_event(event_type: str, data: dict) -> bool:
+                """Send an SSE event. Returns False if client disconnected."""
+                nonlocal client_connected
+                if not client_connected:
+                    return False
+                try:
+                    event_str = f"event: {event_type}\ndata: {json.dumps(data)}\n\n"
+                    self.wfile.write(event_str.encode('utf-8'))
+                    self.wfile.flush()
+                    return True
+                except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+                    # Client disconnected
+                    client_connected = False
+                    return False
+                except Exception as e:
+                    # Other error - log and assume disconnected
+                    print(f"SSE send error: {e}")
+                    client_connected = False
+                    return False
+
+            def check_client_connected() -> bool:
+                """Check if client is still connected using socket select."""
+                nonlocal client_connected
+                if not client_connected:
+                    return False
+                try:
+                    # Check if socket has data (would indicate client sent something or closed)
+                    # Use non-blocking check with 0 timeout
+                    rlist, _, xlist = select.select([self.rfile], [], [self.rfile], 0)
+                    if xlist:
+                        # Error condition on socket
+                        client_connected = False
+                        return False
+                    if rlist:
+                        # Client sent data - for SSE this usually means disconnect
+                        # (SSE is server-push only, client doesn't send data)
+                        data = self.rfile.read(1)
+                        if not data:
+                            client_connected = False
+                            return False
+                    return True
+                except Exception:
+                    client_connected = False
+                    return False
+
+            last_task = None
+            last_heartbeat = time.time()
+            recent_log_lines = []
+
+            # Send initial connected event
+            if not send_event("connected", {
+                "timestamp": time.time(),
+                "interval": interval,
+                "version": "1.0"
+            }):
+                return
+
+            try:
+                iteration_count = 0
+                max_iterations = 3600 // interval  # Max 1 hour of streaming
+
+                while client_connected and iteration_count < max_iterations:
+                    iteration_count += 1
+                    current_time = time.time()
+
+                    # Check client connection before doing work
+                    if not check_client_connected():
+                        break
+
+                    # Send heartbeat every 30 seconds to prevent proxy/LB timeouts
+                    if current_time - last_heartbeat >= HEARTBEAT_INTERVAL:
+                        if not send_event("heartbeat", {"timestamp": current_time}):
+                            break
+                        last_heartbeat = current_time
+
+                    # Fetch background tasks (includes VM status)
+                    tasks = self._fetch_background_tasks()
+
+                    # Send VM status event
+                    vm_task = next((t for t in tasks if t.get("task_type") == "docker_container"), None)
+                    if vm_task:
+                        vm_data = {
+                            "type": "vm_status",
+                            "connected": vm_task.get("status") in ["running", "completed"],
+                            "phase": vm_task.get("phase", "unknown"),
+                            "waa_ready": vm_task.get("metadata", {}).get("waa_server_ready", False),
+                            "probe": {
+                                "status": vm_task.get("metadata", {}).get("probe_response", "unknown"),
+                                "vnc_url": vm_task.get("metadata", {}).get("vnc_url"),
+                            }
+                        }
+
+                        if not send_event("status", vm_data):
+                            break
+
+                        # If VM is ready, check for running benchmark
+                        if vm_data["waa_ready"]:
+                            # Get VM IP from tasks
+                            vm_ip = None
+                            azure_vm = next((t for t in tasks if t.get("task_type") == "vm_provision"), None)
+                            if azure_vm:
+                                vm_ip = azure_vm.get("metadata", {}).get("ip_address")
+
+                            if vm_ip:
+                                # Detect running benchmark using sync version
+                                benchmark_status = self._detect_running_benchmark_sync(
+                                    vm_ip, vm_task.get("metadata", {}).get("container", "winarena")
+                                )
+
+                                if benchmark_status["running"]:
+                                    # Store log lines for result parsing
+                                    if benchmark_status.get("recent_logs"):
+                                        recent_log_lines = benchmark_status["recent_logs"].split('\n')
+
+                                    # Send progress event
+                                    progress_data = {
+                                        "tasks_completed": benchmark_status["progress"]["tasks_completed"],
+                                        "total_tasks": benchmark_status["progress"]["total_tasks"],
+                                        "current_task": benchmark_status["current_task"],
+                                        "current_step": benchmark_status["progress"]["current_step"],
+                                    }
+
+                                    if not send_event("progress", progress_data):
+                                        break
+
+                                    # Check if task completed
+                                    current_task = benchmark_status["current_task"]
+                                    if current_task and current_task != last_task:
+                                        if last_task is not None:
+                                            # Previous task completed - parse result from logs
+                                            result = self._parse_task_result(recent_log_lines, last_task)
+                                            complete_data = {
+                                                "task_id": last_task,
+                                                "success": result["success"],
+                                                "score": result["score"],
+                                            }
+                                            if not send_event("task_complete", complete_data):
+                                                break
+
+                                        last_task = current_task
+
+                    # Check local benchmark progress file
+                    progress_file = Path("benchmark_progress.json")
+                    if progress_file.exists():
+                        try:
+                            progress = json.loads(progress_file.read_text())
+                            if progress.get("status") == "running":
+                                progress_data = {
+                                    "tasks_completed": progress.get("tasks_complete", 0),
+                                    "total_tasks": progress.get("tasks_total", 0),
+                                    "current_task": progress.get("provider", "unknown"),
+                                }
+                                if not send_event("progress", progress_data):
+                                    break
+                        except Exception:
+                            pass
+
+                    # Non-blocking sleep using select with timeout
+                    # This allows checking for client disconnect during sleep
+                    try:
+                        select.select([self.rfile], [], [], interval)
+                    except Exception:
+                        break
+
+            except (BrokenPipeError, ConnectionResetError, ConnectionAbortedError):
+                # Client disconnected - this is normal, don't log as error
+                pass
+            except Exception as e:
+                # Send error event if still connected
+                send_event("error", {"message": str(e)})
+            finally:
+                # Cleanup - connection is ending
+                client_connected = False
+
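
A minimal client sketch for consuming this stream. The URL is a placeholder (the route that dispatches to this handler is registered elsewhere in the file and is not shown in this section); event names and payload keys follow the send_event calls above.

    import json
    import requests  # assumed available in the client environment

    def watch_benchmark(url="http://localhost:8080/api/benchmark/stream"):
        """Print progress and task_complete events as they arrive (illustrative only)."""
        with requests.get(url, stream=True) as resp:
            event = None
            for raw in resp.iter_lines(decode_unicode=True):
                if raw.startswith("event: "):
                    event = raw[7:]
                elif raw.startswith("data: ") and event:
                    payload = json.loads(raw[6:])
                    if event == "progress":
                        print(f"{payload['tasks_completed']}/{payload['total_tasks']}",
                              "current:", payload.get("current_task"))
                    elif event == "task_complete":
                        print(payload["task_id"], "success:", payload["success"])
                    event = None
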
+        def _detect_running_benchmark_sync(self, vm_ip: str, container_name: str = "winarena") -> dict:
+            """Synchronous version of _detect_running_benchmark.
+
+            Avoids creating a new event loop on each call which causes issues
+            when called from a synchronous context.
+            """
+            import subprocess
+            import re
+
+            result = {
+                "running": False,
+                "current_task": None,
+                "progress": {
+                    "tasks_completed": 0,
+                    "total_tasks": 0,
+                    "current_step": 0,
+                },
+                "recent_logs": "",
+            }
+
+            try:
+                # Check if benchmark process is running
+                process_check = subprocess.run(
+                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                     "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                     f"azureuser@{vm_ip}",
+                     f"docker exec {container_name} pgrep -f 'python.*run.py' 2>/dev/null || echo ''"],
+                    capture_output=True, text=True, timeout=10
+                )
+
+                if process_check.returncode == 0 and process_check.stdout.strip():
+                    result["running"] = True
+
+                    # Get benchmark log
+                    log_check = subprocess.run(
+                        ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=5",
+                         "-i", str(Path.home() / ".ssh" / "id_rsa"),
+                         f"azureuser@{vm_ip}",
+                         "tail -100 /tmp/waa_benchmark.log 2>/dev/null || echo ''"],
+                        capture_output=True, text=True, timeout=10
+                    )
+
+                    if log_check.returncode == 0 and log_check.stdout.strip():
+                        logs = log_check.stdout
+                        result["recent_logs"] = logs[-500:]  # Last 500 chars
+
+                        # Parse progress from logs
+                        task_match = re.search(r'Task\s+(\d+)/(\d+)', logs)
+                        if task_match:
+                            result["progress"]["tasks_completed"] = int(task_match.group(1))
+                            result["progress"]["total_tasks"] = int(task_match.group(2))
+
+                        # Extract current task ID
+                        task_id_match = re.search(r'(?:Running|Processing) task:\s*(\S+)', logs)
+                        if task_id_match:
+                            result["current_task"] = task_id_match.group(1)
+
+                        # Extract step info
+                        step_match = re.search(r'Step\s+(\d+)', logs)
+                        if step_match:
+                            result["progress"]["current_step"] = int(step_match.group(1))
+
+            except Exception:
+                # SSH or parsing failed - leave defaults
+                pass
+
+            return result
+
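
The log parsing above is regex-driven. A small, self-contained illustration of what those exact patterns extract from invented log text (not real benchmark output):

    import re

    logs = "Task 3/10\nRunning task: browser_002\nStep 7: click submit"
    print(re.search(r'Task\s+(\d+)/(\d+)', logs).groups())                    # ('3', '10')
    print(re.search(r'(?:Running|Processing) task:\s*(\S+)', logs).group(1))  # browser_002
    print(re.search(r'Step\s+(\d+)', logs).group(1))                          # 7
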
+        def _register_vm(self, vm_data):
+            """Register a new VM in the registry."""
+            # Path to VM registry file (relative to project root)
+            project_root = Path(__file__).parent.parent.parent
+            registry_file = project_root / "benchmark_results" / "vm_registry.json"
+
+            # Load existing registry
+            vms = []
+            if registry_file.exists():
+                try:
+                    with open(registry_file) as f:
+                        vms = json.load(f)
+                except Exception:
+                    pass
+
+            # Add new VM
+            new_vm = {
+                "name": vm_data.get("name", "unnamed-vm"),
+                "ssh_host": vm_data.get("ssh_host", ""),
+                "ssh_user": vm_data.get("ssh_user", "azureuser"),
+                "vnc_port": vm_data.get("vnc_port", 8006),
+                "waa_port": vm_data.get("waa_port", 5000),
+                "docker_container": vm_data.get("docker_container", "win11-waa"),
+                "internal_ip": vm_data.get("internal_ip", "20.20.20.21")
+            }
+
+            vms.append(new_vm)
+
+            # Save registry
+            try:
+                registry_file.parent.mkdir(parents=True, exist_ok=True)
+                with open(registry_file, 'w') as f:
+                    json.dump(vms, f, indent=2)
+                return {"status": "success", "vm": new_vm}
+            except Exception as e:
+                return {"status": "error", "message": str(e)}
+
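
For reference, the entry that would be appended to benchmark_results/vm_registry.json when the request body supplies only a host. The ssh_host value here is a placeholder; the remaining values are the defaults visible above.

    new_vm = {
        "name": "unnamed-vm",
        "ssh_host": "203.0.113.7",       # placeholder address, not a real host
        "ssh_user": "azureuser",
        "vnc_port": 8006,
        "waa_port": 5000,
        "docker_container": "win11-waa",
        "internal_ip": "20.20.20.21",
    }
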
+        def _start_benchmark_run(self, params: dict) -> dict:
+            """Start a benchmark run with the given parameters.
+
+            Runs the benchmark in a background thread and returns immediately.
+            Progress is tracked via benchmark_progress.json.
+
+            Expected params:
+                {
+                    "model": "gpt-4o",
+                    "num_tasks": 5,
+                    "agent": "navi",
+                    "task_selection": "all" | "domain" | "task_ids",
+                    "domain": "general",  // if task_selection == "domain"
+                    "task_ids": ["task_001", "task_015"]  // if task_selection == "task_ids"
+                }
+
+            Returns:
+                dict with status and params
+            """
+            from dotenv import load_dotenv
+
+            # Load .env file for API keys
+            project_root = Path(__file__).parent.parent.parent
+            load_dotenv(project_root / ".env")
+
+            # Build CLI command
+            cmd = [
+                "uv", "run", "python", "-m", "openadapt_ml.benchmarks.cli",
+                "vm", "run-waa",
+                "--num-tasks", str(params.get("num_tasks", 5)),
+                "--model", params.get("model", "gpt-4o"),
+                "--agent", params.get("agent", "navi"),
+                "--no-open"  # Don't open viewer (already open)
+            ]
+
+            # Add task selection args
+            task_selection = params.get("task_selection", "all")
+            if task_selection == "domain":
+                domain = params.get("domain", "general")
+                cmd.extend(["--domain", domain])
+            elif task_selection == "task_ids":
+                task_ids = params.get("task_ids", [])
+                if task_ids:
+                    cmd.extend(["--task-ids", ",".join(task_ids)])
+
+            # Create progress log file (in cwd which is serve_dir)
+            progress_file = Path("benchmark_progress.json")
+
+            # Write initial progress
+            model = params.get("model", "gpt-4o")
+            num_tasks = params.get("num_tasks", 5)
+            agent = params.get("agent", "navi")
+
+            print(f"\n[Benchmark] Starting WAA benchmark: model={model}, tasks={num_tasks}, agent={agent}")
+            print(f"[Benchmark] Task selection: {task_selection}")
+            if task_selection == "domain":
+                print(f"[Benchmark] Domain: {params.get('domain', 'general')}")
+            elif task_selection == "task_ids":
+                print(f"[Benchmark] Task IDs: {params.get('task_ids', [])}")
+            print(f"[Benchmark] Command: {' '.join(cmd)}")
+
+            progress_file.write_text(json.dumps({
+                "status": "running",
+                "model": model,
+                "num_tasks": num_tasks,
+                "agent": agent,
+                "task_selection": task_selection,
+                "tasks_complete": 0,
+                "message": f"Starting {model} benchmark with {num_tasks} tasks..."
+            }))
+
+            # Copy environment with loaded vars
+            env = os.environ.copy()
+
+            # Run in background thread
+            def run():
+                result = subprocess.run(
+                    cmd,
+                    capture_output=True,
+                    text=True,
+                    cwd=str(project_root),
+                    env=env
+                )
+
+                print(f"\n[Benchmark] Output:\n{result.stdout}")
+                if result.stderr:
+                    print(f"[Benchmark] Stderr: {result.stderr}")
+
+                if result.returncode == 0:
+                    print(f"[Benchmark] Complete. Regenerating viewer...")
+                    progress_file.write_text(json.dumps({
+                        "status": "complete",
+                        "model": model,
+                        "num_tasks": num_tasks,
+                        "message": "Benchmark complete. Refresh to see results."
+                    }))
+                    # Regenerate benchmark viewer
+                    _regenerate_benchmark_viewer_if_available(serve_dir)
+                else:
+                    error_msg = result.stderr[:200] if result.stderr else "Unknown error"
+                    print(f"[Benchmark] Failed: {error_msg}")
+                    progress_file.write_text(json.dumps({
+                        "status": "error",
+                        "model": model,
+                        "num_tasks": num_tasks,
+                        "message": f"Benchmark failed: {error_msg}"
+                    }))
+
+            threading.Thread(target=run, daemon=True).start()
+
+            return {"status": "started", "params": params}
+
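
Putting the docstring and the command construction together: a request body like the one below (the example values come from the docstring, not defaults) would be translated into the following CLI invocation, run from the project root in a daemon thread.

    params = {
        "model": "gpt-4o",
        "num_tasks": 5,
        "agent": "navi",
        "task_selection": "task_ids",
        "task_ids": ["task_001", "task_015"],
    }
    # Equivalent command built by _start_benchmark_run:
    #   uv run python -m openadapt_ml.benchmarks.cli vm run-waa \
    #       --num-tasks 5 --model gpt-4o --agent navi --no-open --task-ids task_001,task_015
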
         def do_OPTIONS(self):
             # Handle CORS preflight
             self.send_response(200)
@@ -564,7 +2083,11 @@ def cmd_serve(args: argparse.Namespace) -> int:
             self.send_header('Access-Control-Allow-Headers', 'Content-Type')
             self.end_headers()

-
+    class ThreadedTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
+        allow_reuse_address = True
+        daemon_threads = True  # Don't block shutdown
+
+    with ThreadedTCPServer(("", port), StopHandler) as httpd:
         url = f"http://localhost:{port}/{start_page}"
         print(f"\nServing at: {url}")
         print(f"Directory: {serve_dir}")
@@ -612,6 +2135,36 @@ def cmd_viewer(args: argparse.Namespace) -> int:
         state.losses = data.get("losses", [])
         state.status = data.get("status", "completed")
         state.elapsed_time = data.get("elapsed_time", 0.0)  # Load elapsed time for completed training
+        state.goal = data.get("goal", "")
+        state.config_path = data.get("config_path", "")
+        state.capture_path = data.get("capture_path", "")
+
+        # Load model config from training_log.json or fall back to reading config file
+        state.model_name = data.get("model_name", "")
+        state.lora_r = data.get("lora_r", 0)
+        state.lora_alpha = data.get("lora_alpha", 0)
+        state.load_in_4bit = data.get("load_in_4bit", False)
+
+        # If model config not in JSON, try to read from config file
+        if not state.model_name and state.config_path:
+            try:
+                import yaml
+                # Try relative to project root first, then as absolute path
+                project_root = Path(__file__).parent.parent.parent
+                config_file = project_root / state.config_path
+                if not config_file.exists():
+                    config_file = Path(state.config_path)
+                if config_file.exists():
+                    with open(config_file) as cf:
+                        cfg = yaml.safe_load(cf)
+                    if cfg and "model" in cfg:
+                        state.model_name = cfg["model"].get("name", "")
+                        state.load_in_4bit = cfg["model"].get("load_in_4bit", False)
+                    if cfg and "lora" in cfg:
+                        state.lora_r = cfg["lora"].get("r", 0)
+                        state.lora_alpha = cfg["lora"].get("lora_alpha", 0)
+            except Exception as e:
+                print(f" Warning: Could not read config file: {e}")

         config = TrainingConfig(
             num_train_epochs=data.get("total_epochs", 5),
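
The YAML fallback above reads only two top-level sections, so a config of roughly this shape would populate the viewer state (the model name and numeric values are placeholders, not project defaults):

    cfg = {
        "model": {"name": "example-org/example-vlm", "load_in_4bit": True},
        "lora": {"r": 16, "lora_alpha": 32},
    }
    # state.model_name <- cfg["model"]["name"], state.lora_r <- cfg["lora"]["r"], etc.
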
@@ -757,6 +2310,7 @@ Examples:
     p_serve.add_argument("--no-regenerate", action="store_true",
                          help="Skip regenerating dashboard/viewer (serve existing files)")
     p_serve.add_argument("--benchmark", help="Serve benchmark results directory instead of training output")
+    p_serve.add_argument("--start-page", help="Override default start page (e.g., benchmark.html)")
     p_serve.set_defaults(func=cmd_serve)

     # viewer