openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. openadapt_ml/baselines/__init__.py +121 -0
  2. openadapt_ml/baselines/adapter.py +185 -0
  3. openadapt_ml/baselines/cli.py +314 -0
  4. openadapt_ml/baselines/config.py +448 -0
  5. openadapt_ml/baselines/parser.py +922 -0
  6. openadapt_ml/baselines/prompts.py +787 -0
  7. openadapt_ml/benchmarks/__init__.py +13 -107
  8. openadapt_ml/benchmarks/agent.py +297 -374
  9. openadapt_ml/benchmarks/azure.py +62 -24
  10. openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
  11. openadapt_ml/benchmarks/cli.py +1874 -751
  12. openadapt_ml/benchmarks/trace_export.py +631 -0
  13. openadapt_ml/benchmarks/viewer.py +1236 -0
  14. openadapt_ml/benchmarks/vm_monitor.py +1111 -0
  15. openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
  16. openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
  17. openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
  18. openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
  19. openadapt_ml/cloud/azure_inference.py +3 -5
  20. openadapt_ml/cloud/lambda_labs.py +722 -307
  21. openadapt_ml/cloud/local.py +3194 -89
  22. openadapt_ml/cloud/ssh_tunnel.py +595 -0
  23. openadapt_ml/datasets/next_action.py +125 -96
  24. openadapt_ml/evals/grounding.py +32 -9
  25. openadapt_ml/evals/plot_eval_metrics.py +15 -13
  26. openadapt_ml/evals/trajectory_matching.py +120 -57
  27. openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
  28. openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
  29. openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
  30. openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
  31. openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
  32. openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
  33. openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
  34. openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
  35. openadapt_ml/experiments/representation_shootout/config.py +390 -0
  36. openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
  37. openadapt_ml/experiments/representation_shootout/runner.py +687 -0
  38. openadapt_ml/experiments/waa_demo/__init__.py +10 -0
  39. openadapt_ml/experiments/waa_demo/demos.py +357 -0
  40. openadapt_ml/experiments/waa_demo/runner.py +732 -0
  41. openadapt_ml/experiments/waa_demo/tasks.py +151 -0
  42. openadapt_ml/export/__init__.py +9 -0
  43. openadapt_ml/export/__main__.py +6 -0
  44. openadapt_ml/export/cli.py +89 -0
  45. openadapt_ml/export/parquet.py +277 -0
  46. openadapt_ml/grounding/detector.py +18 -14
  47. openadapt_ml/ingest/__init__.py +11 -10
  48. openadapt_ml/ingest/capture.py +97 -86
  49. openadapt_ml/ingest/loader.py +120 -69
  50. openadapt_ml/ingest/synthetic.py +344 -193
  51. openadapt_ml/models/api_adapter.py +14 -4
  52. openadapt_ml/models/base_adapter.py +10 -2
  53. openadapt_ml/models/providers/__init__.py +288 -0
  54. openadapt_ml/models/providers/anthropic.py +266 -0
  55. openadapt_ml/models/providers/base.py +299 -0
  56. openadapt_ml/models/providers/google.py +376 -0
  57. openadapt_ml/models/providers/openai.py +342 -0
  58. openadapt_ml/models/qwen_vl.py +46 -19
  59. openadapt_ml/perception/__init__.py +35 -0
  60. openadapt_ml/perception/integration.py +399 -0
  61. openadapt_ml/retrieval/README.md +226 -0
  62. openadapt_ml/retrieval/USAGE.md +391 -0
  63. openadapt_ml/retrieval/__init__.py +91 -0
  64. openadapt_ml/retrieval/demo_retriever.py +843 -0
  65. openadapt_ml/retrieval/embeddings.py +630 -0
  66. openadapt_ml/retrieval/index.py +194 -0
  67. openadapt_ml/retrieval/retriever.py +162 -0
  68. openadapt_ml/runtime/__init__.py +50 -0
  69. openadapt_ml/runtime/policy.py +27 -14
  70. openadapt_ml/runtime/safety_gate.py +471 -0
  71. openadapt_ml/schema/__init__.py +113 -0
  72. openadapt_ml/schema/converters.py +588 -0
  73. openadapt_ml/schema/episode.py +470 -0
  74. openadapt_ml/scripts/capture_screenshots.py +530 -0
  75. openadapt_ml/scripts/compare.py +102 -61
  76. openadapt_ml/scripts/demo_policy.py +4 -1
  77. openadapt_ml/scripts/eval_policy.py +19 -14
  78. openadapt_ml/scripts/make_gif.py +1 -1
  79. openadapt_ml/scripts/prepare_synthetic.py +16 -17
  80. openadapt_ml/scripts/train.py +98 -75
  81. openadapt_ml/segmentation/README.md +920 -0
  82. openadapt_ml/segmentation/__init__.py +97 -0
  83. openadapt_ml/segmentation/adapters/__init__.py +5 -0
  84. openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
  85. openadapt_ml/segmentation/annotator.py +610 -0
  86. openadapt_ml/segmentation/cache.py +290 -0
  87. openadapt_ml/segmentation/cli.py +674 -0
  88. openadapt_ml/segmentation/deduplicator.py +656 -0
  89. openadapt_ml/segmentation/frame_describer.py +788 -0
  90. openadapt_ml/segmentation/pipeline.py +340 -0
  91. openadapt_ml/segmentation/schemas.py +622 -0
  92. openadapt_ml/segmentation/segment_extractor.py +634 -0
  93. openadapt_ml/training/azure_ops_viewer.py +1097 -0
  94. openadapt_ml/training/benchmark_viewer.py +3255 -19
  95. openadapt_ml/training/shared_ui.py +7 -7
  96. openadapt_ml/training/stub_provider.py +57 -35
  97. openadapt_ml/training/trainer.py +255 -441
  98. openadapt_ml/training/trl_trainer.py +403 -0
  99. openadapt_ml/training/viewer.py +323 -108
  100. openadapt_ml/training/viewer_components.py +180 -0
  101. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
  102. openadapt_ml-0.2.1.dist-info/RECORD +116 -0
  103. openadapt_ml/benchmarks/base.py +0 -366
  104. openadapt_ml/benchmarks/data_collection.py +0 -432
  105. openadapt_ml/benchmarks/runner.py +0 -381
  106. openadapt_ml/benchmarks/waa.py +0 -704
  107. openadapt_ml/schemas/__init__.py +0 -53
  108. openadapt_ml/schemas/sessions.py +0 -122
  109. openadapt_ml/schemas/validation.py +0 -252
  110. openadapt_ml-0.1.0.dist-info/RECORD +0 -55
  111. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
  112. {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
@@ -44,7 +44,9 @@ API_BASE = "https://cloud.lambdalabs.com/api/v1"
44
44
  DEFAULT_SERVER_PORT = 8765
45
45
 
46
46
 
47
- def start_dashboard_server(output_dir: Path, port: int = DEFAULT_SERVER_PORT) -> tuple[subprocess.Popen, str]:
47
+ def start_dashboard_server(
48
+ output_dir: Path, port: int = DEFAULT_SERVER_PORT
49
+ ) -> tuple[subprocess.Popen, str]:
48
50
  """Start a background HTTP server for the dashboard.
49
51
 
50
52
  Args:
@@ -54,8 +56,6 @@ def start_dashboard_server(output_dir: Path, port: int = DEFAULT_SERVER_PORT) ->
54
56
  Returns:
55
57
  (process, url): The server process and the dashboard URL
56
58
  """
57
- import webbrowser
58
- import threading
59
59
 
60
60
  # Start simple HTTP server in background thread
61
61
  server_proc = subprocess.Popen(
@@ -96,7 +96,9 @@ def open_dashboard_in_browser(output_dir: Path, port: int = DEFAULT_SERVER_PORT)
96
96
  return None
97
97
 
98
98
 
99
- def setup_capture_screenshots_symlink(output_dir: Path, capture_path: str | Path) -> bool:
99
+ def setup_capture_screenshots_symlink(
100
+ output_dir: Path, capture_path: str | Path
101
+ ) -> bool:
100
102
  """Create symlink from output_dir/screenshots to capture's screenshots folder.
101
103
 
102
104
  This allows the dashboard to serve screenshots via relative paths.
@@ -128,7 +130,9 @@ def setup_capture_screenshots_symlink(output_dir: Path, capture_path: str | Path
128
130
  return False
129
131
 
130
132
 
131
- def rewrite_evaluation_paths(evaluations: list[dict], remote_prefix: str = "/home/ubuntu/capture/") -> list[dict]:
133
+ def rewrite_evaluation_paths(
134
+ evaluations: list[dict], remote_prefix: str = "/home/ubuntu/capture/"
135
+ ) -> list[dict]:
132
136
  """Rewrite Lambda paths in evaluations to relative paths.
133
137
 
134
138
  Converts: /home/ubuntu/capture/screenshots/foo.png -> screenshots/foo.png
@@ -146,7 +150,9 @@ def rewrite_evaluation_paths(evaluations: list[dict], remote_prefix: str = "/hom
146
150
  return evaluations
147
151
 
148
152
 
149
- def download_checkpoints_from_instance(instance_ip: str, output_dir: Path, ssh_key: str | None = None) -> bool:
153
+ def download_checkpoints_from_instance(
154
+ instance_ip: str, output_dir: Path, ssh_key: str | None = None
155
+ ) -> bool:
150
156
  """Download checkpoints from Lambda instance.
151
157
 
152
158
  Args:
@@ -161,7 +167,9 @@ def download_checkpoints_from_instance(instance_ip: str, output_dir: Path, ssh_k
161
167
  checkpoints_dir.mkdir(parents=True, exist_ok=True)
162
168
 
163
169
  ssh_key = ssh_key or str(Path.home() / ".ssh" / "lambda_id_ed25519")
164
- ssh_opts = f"-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i {ssh_key}"
170
+ ssh_opts = (
171
+ f"-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i {ssh_key}"
172
+ )
165
173
 
166
174
  # Download checkpoints from remote
167
175
  remote_path = f"ubuntu@{instance_ip}:~/openadapt-ml/checkpoints/"
@@ -187,6 +195,7 @@ def check_stop_signal(output_dir: Path) -> bool:
187
195
  @dataclass
188
196
  class InstanceType:
189
197
  """Lambda Labs instance type."""
198
+
190
199
  name: str
191
200
  price_cents_per_hour: int
192
201
  description: str
@@ -216,6 +225,7 @@ class InstanceType:
216
225
  @dataclass
217
226
  class Instance:
218
227
  """Running Lambda Labs instance."""
228
+
219
229
  id: str
220
230
  name: str
221
231
  instance_type: str
@@ -236,6 +246,7 @@ class LambdaLabsClient:
236
246
  # Try provided key, then settings, then env var
237
247
  if not api_key:
238
248
  from openadapt_ml.config import settings
249
+
239
250
  api_key = settings.lambda_api_key or os.environ.get("LAMBDA_API_KEY")
240
251
 
241
252
  self.api_key = api_key
@@ -268,19 +279,25 @@ class LambdaLabsClient:
268
279
 
269
280
  for name, info in data.get("data", {}).items():
270
281
  specs = info.get("instance_type", {}).get("specs", {})
271
- regions = [r["name"] for r in info.get("regions_with_capacity_available", [])]
272
-
273
- types.append(InstanceType(
274
- name=name,
275
- price_cents_per_hour=info.get("instance_type", {}).get("price_cents_per_hour", 0),
276
- description=info.get("instance_type", {}).get("description", ""),
277
- gpu_count=specs.get("gpus", 0),
278
- gpu_type=info.get("instance_type", {}).get("gpu_description", ""),
279
- vcpus=specs.get("vcpus", 0),
280
- memory_gb=specs.get("memory_gib", 0),
281
- storage_gb=specs.get("storage_gib", 0),
282
- available_regions=regions,
283
- ))
282
+ regions = [
283
+ r["name"] for r in info.get("regions_with_capacity_available", [])
284
+ ]
285
+
286
+ types.append(
287
+ InstanceType(
288
+ name=name,
289
+ price_cents_per_hour=info.get("instance_type", {}).get(
290
+ "price_cents_per_hour", 0
291
+ ),
292
+ description=info.get("instance_type", {}).get("description", ""),
293
+ gpu_count=specs.get("gpus", 0),
294
+ gpu_type=info.get("instance_type", {}).get("gpu_description", ""),
295
+ vcpus=specs.get("vcpus", 0),
296
+ memory_gb=specs.get("memory_gib", 0),
297
+ storage_gb=specs.get("storage_gib", 0),
298
+ available_regions=regions,
299
+ )
300
+ )
284
301
 
285
302
  # Sort by price
286
303
  types.sort(key=lambda t: t.price_cents_per_hour)
@@ -309,15 +326,17 @@ class LambdaLabsClient:
309
326
  else:
310
327
  ssh_key_names = ssh_keys # Already list of strings
311
328
 
312
- instances.append(Instance(
313
- id=inst["id"],
314
- name=inst.get("name", ""),
315
- instance_type=inst.get("instance_type", {}).get("name", "unknown"),
316
- status=inst.get("status", "unknown"),
317
- ip=inst.get("ip"),
318
- region=inst.get("region", {}).get("name", "unknown"),
319
- ssh_key_names=ssh_key_names,
320
- ))
329
+ instances.append(
330
+ Instance(
331
+ id=inst["id"],
332
+ name=inst.get("name", ""),
333
+ instance_type=inst.get("instance_type", {}).get("name", "unknown"),
334
+ status=inst.get("status", "unknown"),
335
+ ip=inst.get("ip"),
336
+ region=inst.get("region", {}).get("name", "unknown"),
337
+ ssh_key_names=ssh_key_names,
338
+ )
339
+ )
321
340
 
322
341
  return instances
323
342
 
@@ -393,9 +412,18 @@ class LambdaLabsClient:
393
412
  for attempt in range(60): # Wait up to 5 minutes for SSH
394
413
  try:
395
414
  result = subprocess.run(
396
- ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
397
- f"ubuntu@{instance.ip}", "echo ready"],
398
- capture_output=True, text=True, timeout=20
415
+ [
416
+ "ssh",
417
+ "-o",
418
+ "StrictHostKeyChecking=no",
419
+ "-o",
420
+ "ConnectTimeout=10",
421
+ f"ubuntu@{instance.ip}",
422
+ "echo ready",
423
+ ],
424
+ capture_output=True,
425
+ text=True,
426
+ timeout=20,
399
427
  )
400
428
  if result.returncode == 0:
401
429
  print("SSH ready!")
@@ -403,7 +431,7 @@ class LambdaLabsClient:
403
431
  except subprocess.TimeoutExpired:
404
432
  pass
405
433
  if attempt % 6 == 5: # Log progress every 30 seconds
406
- print(f" Still waiting for SSH ({(attempt+1)*5}s elapsed)...")
434
+ print(f" Still waiting for SSH ({(attempt + 1) * 5}s elapsed)...")
407
435
  time.sleep(5)
408
436
 
409
437
  print("Warning: SSH may not be ready yet, continuing anyway...")
@@ -411,7 +439,9 @@ class LambdaLabsClient:
411
439
 
412
440
  def terminate_instance(self, instance_id: str) -> bool:
413
441
  """Terminate an instance."""
414
- data = self._post("/instance-operations/terminate", {"instance_ids": [instance_id]})
442
+ data = self._post(
443
+ "/instance-operations/terminate", {"instance_ids": [instance_id]}
444
+ )
415
445
  terminated = data.get("data", {}).get("terminated_instances", [])
416
446
  return any(t.get("id") == instance_id for t in terminated)
417
447
 
@@ -421,7 +451,13 @@ class LambdaLabsClient:
421
451
  return "# Instance IP not yet available"
422
452
  return f"ssh {user}@{instance.ip}"
423
453
 
424
- def ssh_run(self, instance: Instance, command: str, timeout: int | None = None, retries: int = 3) -> subprocess.CompletedProcess:
454
+ def ssh_run(
455
+ self,
456
+ instance: Instance,
457
+ command: str,
458
+ timeout: int | None = None,
459
+ retries: int = 3,
460
+ ) -> subprocess.CompletedProcess:
425
461
  """Run a command on an instance via SSH.
426
462
 
427
463
  Args:
@@ -437,12 +473,17 @@ class LambdaLabsClient:
437
473
  raise RuntimeError("Instance has no IP address")
438
474
 
439
475
  ssh_cmd = [
440
- "ssh", "-o", "StrictHostKeyChecking=no",
441
- "-o", "ConnectTimeout=30", # Increased from 10
442
- "-o", "ServerAliveInterval=60", # Keep connection alive
443
- "-o", "ServerAliveCountMax=3",
476
+ "ssh",
477
+ "-o",
478
+ "StrictHostKeyChecking=no",
479
+ "-o",
480
+ "ConnectTimeout=30", # Increased from 10
481
+ "-o",
482
+ "ServerAliveInterval=60", # Keep connection alive
483
+ "-o",
484
+ "ServerAliveCountMax=3",
444
485
  f"ubuntu@{instance.ip}",
445
- command
486
+ command,
446
487
  ]
447
488
 
448
489
  last_error = None
@@ -462,7 +503,12 @@ class LambdaLabsClient:
462
503
 
463
504
  raise last_error if last_error else RuntimeError("SSH failed")
464
505
 
465
- def setup_instance(self, instance: Instance, repo_url: str = "https://github.com/OpenAdaptAI/openadapt-ml.git", clean_gpu: bool = True) -> bool:
506
+ def setup_instance(
507
+ self,
508
+ instance: Instance,
509
+ repo_url: str = "https://github.com/OpenAdaptAI/openadapt-ml.git",
510
+ clean_gpu: bool = True,
511
+ ) -> bool:
466
512
  """Set up training environment on instance.
467
513
 
468
514
  Clones repo, installs uv, syncs dependencies.
@@ -475,7 +521,9 @@ class LambdaLabsClient:
475
521
  if clean_gpu:
476
522
  print(" Clearing GPU memory...")
477
523
  try:
478
- self.ssh_run(instance, '''
524
+ self.ssh_run(
525
+ instance,
526
+ """
479
527
  python3 -c "
480
528
  import torch
481
529
  if torch.cuda.is_available():
@@ -485,11 +533,13 @@ if torch.cuda.is_available():
485
533
  " 2>/dev/null || true
486
534
  # Kill any stale python processes using GPU
487
535
  pkill -f "python.*train" 2>/dev/null || true
488
- ''', timeout=60)
536
+ """,
537
+ timeout=60,
538
+ )
489
539
  except Exception as e:
490
540
  print(f" GPU cleanup skipped: {e}")
491
541
 
492
- setup_script = f'''
542
+ setup_script = f"""
493
543
  set -e
494
544
  cd ~
495
545
 
@@ -509,10 +559,12 @@ fi
509
559
  cd openadapt-ml
510
560
  uv sync
511
561
  echo "SETUP_COMPLETE"
512
- '''
562
+ """
513
563
 
514
564
  try:
515
- result = self.ssh_run(instance, setup_script, timeout=900) # 15 min timeout for setup
565
+ result = self.ssh_run(
566
+ instance, setup_script, timeout=900
567
+ ) # 15 min timeout for setup
516
568
 
517
569
  if "SETUP_COMPLETE" in result.stdout:
518
570
  print(" Environment ready")
@@ -528,7 +580,9 @@ echo "SETUP_COMPLETE"
528
580
  print(f" Setup failed: {e}")
529
581
  return False
530
582
 
531
- def sync_local_code(self, instance: Instance, local_repo_path: str = ".", retries: int = 3) -> bool:
583
+ def sync_local_code(
584
+ self, instance: Instance, local_repo_path: str = ".", retries: int = 3
585
+ ) -> bool:
532
586
  """Sync local code changes to remote instance.
533
587
 
534
588
  Uses rsync to push local code, excluding .venv, .git, etc.
@@ -551,19 +605,30 @@ echo "SETUP_COMPLETE"
551
605
  ssh_opts = "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -o ServerAliveInterval=60"
552
606
 
553
607
  rsync_cmd = [
554
- "rsync", "-avz", "--progress",
608
+ "rsync",
609
+ "-avz",
610
+ "--progress",
555
611
  "--timeout=120", # 2 minute timeout per file
556
- "--exclude", ".venv",
557
- "--exclude", ".git",
558
- "--exclude", "__pycache__",
559
- "--exclude", "*.pyc",
560
- "--exclude", ".env",
561
- "--exclude", "training_output",
562
- "--exclude", "checkpoints",
563
- "--exclude", "synthetic*",
564
- "-e", ssh_opts,
612
+ "--exclude",
613
+ ".venv",
614
+ "--exclude",
615
+ ".git",
616
+ "--exclude",
617
+ "__pycache__",
618
+ "--exclude",
619
+ "*.pyc",
620
+ "--exclude",
621
+ ".env",
622
+ "--exclude",
623
+ "training_output",
624
+ "--exclude",
625
+ "checkpoints",
626
+ "--exclude",
627
+ "synthetic*",
628
+ "-e",
629
+ ssh_opts,
565
630
  f"{local_repo_path}/",
566
- f"ubuntu@{instance.ip}:~/openadapt-ml/"
631
+ f"ubuntu@{instance.ip}:~/openadapt-ml/",
567
632
  ]
568
633
 
569
634
  for attempt in range(retries):
@@ -577,7 +642,13 @@ echo "SETUP_COMPLETE"
577
642
 
578
643
  return False
579
644
 
580
- def upload_capture(self, instance: Instance, local_path: str, remote_path: str = "~/capture", retries: int = 3) -> bool:
645
+ def upload_capture(
646
+ self,
647
+ instance: Instance,
648
+ local_path: str,
649
+ remote_path: str = "~/capture",
650
+ retries: int = 3,
651
+ ) -> bool:
581
652
  """Upload a capture directory to instance via rsync.
582
653
 
583
654
  Args:
@@ -598,11 +669,14 @@ echo "SETUP_COMPLETE"
598
669
  ssh_opts = "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -o ServerAliveInterval=60"
599
670
 
600
671
  rsync_cmd = [
601
- "rsync", "-avz", "--progress",
672
+ "rsync",
673
+ "-avz",
674
+ "--progress",
602
675
  "--timeout=120", # 2 minute timeout per file
603
- "-e", ssh_opts,
676
+ "-e",
677
+ ssh_opts,
604
678
  f"{local_path}/",
605
- f"ubuntu@{instance.ip}:{remote_path}/"
679
+ f"ubuntu@{instance.ip}:{remote_path}/",
606
680
  ]
607
681
 
608
682
  for attempt in range(retries):
@@ -646,16 +720,18 @@ echo "SETUP_COMPLETE"
646
720
  train_cmd += f' --goal "{goal}"'
647
721
 
648
722
  # Full script with environment setup
649
- script = f'''
723
+ script = f"""
650
724
  cd ~/openadapt-ml
651
725
  export PATH="$HOME/.local/bin:$PATH"
652
726
  {train_cmd}
653
- '''
727
+ """
654
728
 
655
729
  ssh_cmd = [
656
- "ssh", "-o", "StrictHostKeyChecking=no",
730
+ "ssh",
731
+ "-o",
732
+ "StrictHostKeyChecking=no",
657
733
  f"ubuntu@{instance.ip}",
658
- script
734
+ script,
659
735
  ]
660
736
 
661
737
  print(f"Running training on {instance.ip}...")
@@ -705,37 +781,42 @@ export PATH="$HOME/.local/bin:$PATH"
705
781
  if include_logs:
706
782
  print(" Downloading training logs...")
707
783
  rsync_cmd = [
708
- "rsync", "-avz",
709
- "-e", "ssh -o StrictHostKeyChecking=no",
784
+ "rsync",
785
+ "-avz",
786
+ "-e",
787
+ "ssh -o StrictHostKeyChecking=no",
710
788
  f"ubuntu@{instance.ip}:{remote_path}/training_output/",
711
- f"{local_path}/training_output_lambda/"
789
+ f"{local_path}/training_output_lambda/",
712
790
  ]
713
791
  result = subprocess.run(rsync_cmd, capture_output=True)
714
792
  if result.returncode == 0:
715
793
  print(" Training logs downloaded to training_output_lambda/")
716
794
  else:
717
- print(f" Warning: Failed to download logs")
795
+ print(" Warning: Failed to download logs")
718
796
  success = False
719
797
 
720
798
  # Download checkpoint
721
799
  if include_checkpoint:
722
800
  print(" Downloading checkpoint...")
723
801
  rsync_cmd = [
724
- "rsync", "-avz",
725
- "-e", "ssh -o StrictHostKeyChecking=no",
802
+ "rsync",
803
+ "-avz",
804
+ "-e",
805
+ "ssh -o StrictHostKeyChecking=no",
726
806
  f"ubuntu@{instance.ip}:{remote_path}/checkpoints/",
727
- f"{local_path}/checkpoints_lambda/"
807
+ f"{local_path}/checkpoints_lambda/",
728
808
  ]
729
809
  result = subprocess.run(rsync_cmd, capture_output=True)
730
810
  if result.returncode == 0:
731
811
  print(" Checkpoint downloaded to checkpoints_lambda/")
732
812
  else:
733
- print(f" Warning: Failed to download checkpoint (may not exist yet)")
813
+ print(" Warning: Failed to download checkpoint (may not exist yet)")
734
814
 
735
815
  # Regenerate all dashboards with static navigation and correct status
736
816
  if include_logs:
737
817
  try:
738
818
  from openadapt_ml.training.trainer import regenerate_all_dashboards
819
+
739
820
  output_dir = Path(local_path) / "training_output_lambda"
740
821
  if output_dir.exists():
741
822
  print(" Regenerating dashboards with static navigation...")
@@ -754,8 +835,9 @@ export PATH="$HOME/.local/bin:$PATH"
754
835
  )
755
836
  try:
756
837
  import json
838
+
757
839
  return json.loads(result.stdout.strip())
758
- except:
840
+ except Exception:
759
841
  return {}
760
842
 
761
843
 
@@ -797,19 +879,22 @@ def main():
797
879
  subparsers = parser.add_subparsers(dest="command", help="Command")
798
880
 
799
881
  # List instances command
800
- list_parser = subparsers.add_parser("list", help="List available instance types")
882
+ subparsers.add_parser("list", help="List available instance types")
801
883
 
802
884
  # Status command
803
- status_parser = subparsers.add_parser("status", help="Show running instances")
885
+ subparsers.add_parser("status", help="Show running instances")
804
886
 
805
887
  # Launch command
806
888
  launch_parser = subparsers.add_parser("launch", help="Launch a GPU instance")
807
889
  launch_parser.add_argument(
808
- "--type", "-t",
890
+ "--type",
891
+ "-t",
809
892
  default="gpu_1x_a100",
810
893
  help="Instance type (default: gpu_1x_a100)",
811
894
  )
812
- launch_parser.add_argument("--region", "-r", help="Region (auto-selects if not specified)")
895
+ launch_parser.add_argument(
896
+ "--region", "-r", help="Region (auto-selects if not specified)"
897
+ )
813
898
  launch_parser.add_argument("--name", "-n", help="Instance name")
814
899
 
815
900
  # Terminate command
@@ -817,112 +902,269 @@ def main():
817
902
  term_parser.add_argument("instance_id", help="Instance ID to terminate")
818
903
 
819
904
  # SSH command - run commands or get interactive shell
820
- ssh_parser = subparsers.add_parser("ssh", help="SSH into Lambda instance or run command")
821
- ssh_parser.add_argument("instance_id", nargs="?", help="Instance ID (uses first if not specified)")
822
- ssh_parser.add_argument("--cmd", "-c", help="Command to run (opens shell if not specified)")
823
- ssh_parser.add_argument("--timeout", "-t", type=int, default=60, help="Command timeout in seconds")
905
+ ssh_parser = subparsers.add_parser(
906
+ "ssh", help="SSH into Lambda instance or run command"
907
+ )
908
+ ssh_parser.add_argument(
909
+ "instance_id", nargs="?", help="Instance ID (uses first if not specified)"
910
+ )
911
+ ssh_parser.add_argument(
912
+ "--cmd", "-c", help="Command to run (opens shell if not specified)"
913
+ )
914
+ ssh_parser.add_argument(
915
+ "--timeout", "-t", type=int, default=60, help="Command timeout in seconds"
916
+ )
824
917
 
825
918
  # Serve command - start dashboard server with stop button support
826
- serve_parser = subparsers.add_parser("serve", help="Start dashboard server with stop button support")
827
- serve_parser.add_argument("--output", "-o", default="training_output", help="Output directory (default: training_output)")
828
- serve_parser.add_argument("--port", "-p", type=int, default=8765, help="Port (default: 8765)")
829
- serve_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
919
+ serve_parser = subparsers.add_parser(
920
+ "serve", help="Start dashboard server with stop button support"
921
+ )
922
+ serve_parser.add_argument(
923
+ "--output",
924
+ "-o",
925
+ default="training_output",
926
+ help="Output directory (default: training_output)",
927
+ )
928
+ serve_parser.add_argument(
929
+ "--port", "-p", type=int, default=8765, help="Port (default: 8765)"
930
+ )
931
+ serve_parser.add_argument(
932
+ "--open", action="store_true", help="Open dashboard in browser"
933
+ )
830
934
 
831
935
  # Rsync command - copy files to/from Lambda instance
832
- rsync_parser = subparsers.add_parser("rsync", help="Rsync files to/from Lambda instance")
833
- rsync_parser.add_argument("source", help="Source path (prefix with 'remote:' for remote paths)")
834
- rsync_parser.add_argument("dest", help="Destination path (prefix with 'remote:' for remote paths)")
835
- rsync_parser.add_argument("instance_id", nargs="?", help="Instance ID (uses first if not specified)")
836
- rsync_parser.add_argument("--delete", action="store_true", help="Delete extraneous files from dest")
936
+ rsync_parser = subparsers.add_parser(
937
+ "rsync", help="Rsync files to/from Lambda instance"
938
+ )
939
+ rsync_parser.add_argument(
940
+ "source", help="Source path (prefix with 'remote:' for remote paths)"
941
+ )
942
+ rsync_parser.add_argument(
943
+ "dest", help="Destination path (prefix with 'remote:' for remote paths)"
944
+ )
945
+ rsync_parser.add_argument(
946
+ "instance_id", nargs="?", help="Instance ID (uses first if not specified)"
947
+ )
948
+ rsync_parser.add_argument(
949
+ "--delete", action="store_true", help="Delete extraneous files from dest"
950
+ )
837
951
 
838
952
  # Setup command
839
- setup_parser = subparsers.add_parser("setup", help="Set up SSH key for Lambda Labs")
953
+ subparsers.add_parser("setup", help="Set up SSH key for Lambda Labs")
840
954
 
841
955
  # Train command - full automated training pipeline
842
956
  train_parser = subparsers.add_parser("train", help="Run training on Lambda GPU")
843
957
  train_parser.add_argument("--capture", "-c", help="Local path to capture directory")
844
958
  train_parser.add_argument("--goal", "-g", help="Task goal description")
845
- train_parser.add_argument("--config", default="configs/qwen3vl_capture_4bit.yaml", help="Config file (default: 4bit for memory efficiency)")
846
- train_parser.add_argument("--type", "-t", default="gpu_1x_a10", help="Instance type")
847
- train_parser.add_argument("--instance", "-i", help="Use existing instance ID instead of launching new")
848
- train_parser.add_argument("--no-terminate", action="store_true", help="Don't terminate instance after training")
849
- train_parser.add_argument("--max-runtime", type=int, default=60, help="Max runtime in minutes before auto-terminate (default: 60)")
850
- train_parser.add_argument("--open", action="store_true", help="Open dashboard in browser when training starts")
959
+ train_parser.add_argument(
960
+ "--config",
961
+ default="configs/qwen3vl_capture_4bit.yaml",
962
+ help="Config file (default: 4bit for memory efficiency)",
963
+ )
964
+ train_parser.add_argument(
965
+ "--type", "-t", default="gpu_1x_a10", help="Instance type"
966
+ )
967
+ train_parser.add_argument(
968
+ "--instance", "-i", help="Use existing instance ID instead of launching new"
969
+ )
970
+ train_parser.add_argument(
971
+ "--no-terminate",
972
+ action="store_true",
973
+ help="Don't terminate instance after training",
974
+ )
975
+ train_parser.add_argument(
976
+ "--max-runtime",
977
+ type=int,
978
+ default=60,
979
+ help="Max runtime in minutes before auto-terminate (default: 60)",
980
+ )
981
+ train_parser.add_argument(
982
+ "--open",
983
+ action="store_true",
984
+ help="Open dashboard in browser when training starts",
985
+ )
851
986
 
852
987
  # Training status command
853
- train_status_parser = subparsers.add_parser("train-status", help="Check training status on instance")
988
+ train_status_parser = subparsers.add_parser(
989
+ "train-status", help="Check training status on instance"
990
+ )
854
991
  train_status_parser.add_argument("instance_id", nargs="?", help="Instance ID")
855
992
 
856
993
  # Monitor command - live dashboard for Lambda training
857
- monitor_parser = subparsers.add_parser("monitor", help="Monitor Lambda training with live dashboard")
994
+ monitor_parser = subparsers.add_parser(
995
+ "monitor", help="Monitor Lambda training with live dashboard"
996
+ )
858
997
  monitor_parser.add_argument("instance_id", nargs="?", help="Instance ID")
859
- monitor_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
860
- monitor_parser.add_argument("--interval", type=int, default=5, help="Poll interval in seconds (default: 5)")
861
- monitor_parser.add_argument("--capture", type=str, help="Local capture path for screenshot symlink")
862
- monitor_parser.add_argument("--auto-stop-loss", type=float, default=0.5, help="Auto-terminate when loss drops below this (default: 0.5)")
863
- monitor_parser.add_argument("--download-checkpoints", action="store_true", default=True, help="Auto-download checkpoints each epoch")
864
- monitor_parser.add_argument("--no-download-checkpoints", action="store_false", dest="download_checkpoints", help="Disable checkpoint download")
865
- monitor_parser.add_argument("--stub", action="store_true", help="Use stub training provider (no GPU, instant simulation)")
998
+ monitor_parser.add_argument(
999
+ "--open", action="store_true", help="Open dashboard in browser"
1000
+ )
1001
+ monitor_parser.add_argument(
1002
+ "--interval", type=int, default=5, help="Poll interval in seconds (default: 5)"
1003
+ )
1004
+ monitor_parser.add_argument(
1005
+ "--capture", type=str, help="Local capture path for screenshot symlink"
1006
+ )
1007
+ monitor_parser.add_argument(
1008
+ "--auto-stop-loss",
1009
+ type=float,
1010
+ default=0.5,
1011
+ help="Auto-terminate when loss drops below this (default: 0.5)",
1012
+ )
1013
+ monitor_parser.add_argument(
1014
+ "--download-checkpoints",
1015
+ action="store_true",
1016
+ default=True,
1017
+ help="Auto-download checkpoints each epoch",
1018
+ )
1019
+ monitor_parser.add_argument(
1020
+ "--no-download-checkpoints",
1021
+ action="store_false",
1022
+ dest="download_checkpoints",
1023
+ help="Disable checkpoint download",
1024
+ )
1025
+ monitor_parser.add_argument(
1026
+ "--stub",
1027
+ action="store_true",
1028
+ help="Use stub training provider (no GPU, instant simulation)",
1029
+ )
866
1030
 
867
1031
  # Refresh command - one-shot dashboard update
868
- refresh_parser = subparsers.add_parser("refresh", help="One-shot refresh of training dashboard")
1032
+ refresh_parser = subparsers.add_parser(
1033
+ "refresh", help="One-shot refresh of training dashboard"
1034
+ )
869
1035
  refresh_parser.add_argument("instance_id", nargs="?", help="Instance ID")
870
- refresh_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
871
- refresh_parser.add_argument("--capture", type=str, help="Local capture path for screenshot preview")
1036
+ refresh_parser.add_argument(
1037
+ "--open", action="store_true", help="Open dashboard in browser"
1038
+ )
1039
+ refresh_parser.add_argument(
1040
+ "--capture", type=str, help="Local capture path for screenshot preview"
1041
+ )
872
1042
 
873
1043
  # Checkpoints command - list remote checkpoints
874
- checkpoints_parser = subparsers.add_parser("checkpoints", help="List checkpoints on remote instance")
1044
+ checkpoints_parser = subparsers.add_parser(
1045
+ "checkpoints", help="List checkpoints on remote instance"
1046
+ )
875
1047
  checkpoints_parser.add_argument("instance_id", nargs="?", help="Instance ID")
876
1048
 
877
1049
  # Download results command
878
- download_parser = subparsers.add_parser("download", help="Download training results from instance")
1050
+ download_parser = subparsers.add_parser(
1051
+ "download", help="Download training results from instance"
1052
+ )
879
1053
  download_parser.add_argument("instance_id", nargs="?", help="Instance ID")
880
- download_parser.add_argument("--output", "-o", default=".", help="Local output directory")
1054
+ download_parser.add_argument(
1055
+ "--output", "-o", default=".", help="Local output directory"
1056
+ )
881
1057
 
882
1058
  # Check files on instance
883
- files_parser = subparsers.add_parser("files", help="List training files on instance")
1059
+ files_parser = subparsers.add_parser(
1060
+ "files", help="List training files on instance"
1061
+ )
884
1062
  files_parser.add_argument("instance_id", nargs="?", help="Instance ID")
885
- files_parser.add_argument("--path", "-p", default="~/openadapt-ml", help="Path to check")
1063
+ files_parser.add_argument(
1064
+ "--path", "-p", default="~/openadapt-ml", help="Path to check"
1065
+ )
886
1066
 
887
1067
  # Kill command - terminate training processes
888
- kill_parser = subparsers.add_parser("kill", help="Kill training/inference processes on instance")
1068
+ kill_parser = subparsers.add_parser(
1069
+ "kill", help="Kill training/inference processes on instance"
1070
+ )
889
1071
  kill_parser.add_argument("instance_id", nargs="?", help="Instance ID")
890
- kill_parser.add_argument("--local", action="store_true", help="Also kill local Lambda-related processes")
891
- kill_parser.add_argument("--all", action="store_true", help="Kill all Python processes on instance (careful!)")
1072
+ kill_parser.add_argument(
1073
+ "--local", action="store_true", help="Also kill local Lambda-related processes"
1074
+ )
1075
+ kill_parser.add_argument(
1076
+ "--all",
1077
+ action="store_true",
1078
+ help="Kill all Python processes on instance (careful!)",
1079
+ )
892
1080
 
893
1081
  # Check command - analyze training status and early stopping
894
- check_parser = subparsers.add_parser("check", help="Check training health and early stopping status")
1082
+ check_parser = subparsers.add_parser(
1083
+ "check", help="Check training health and early stopping status"
1084
+ )
895
1085
  check_parser.add_argument("instance_id", nargs="?", help="Instance ID")
896
- check_parser.add_argument("--threshold", "-t", type=float, default=0.01,
897
- help="Early stopping threshold (loss improvement over last N steps)")
898
- check_parser.add_argument("--window", "-w", type=int, default=10,
899
- help="Number of recent steps to check for improvement")
1086
+ check_parser.add_argument(
1087
+ "--threshold",
1088
+ "-t",
1089
+ type=float,
1090
+ default=0.01,
1091
+ help="Early stopping threshold (loss improvement over last N steps)",
1092
+ )
1093
+ check_parser.add_argument(
1094
+ "--window",
1095
+ "-w",
1096
+ type=int,
1097
+ default=10,
1098
+ help="Number of recent steps to check for improvement",
1099
+ )
900
1100
 
901
1101
  # Compare command - run comparison on Lambda and sync back
902
- compare_parser = subparsers.add_parser("compare", help="Run human vs AI comparison on Lambda")
1102
+ compare_parser = subparsers.add_parser(
1103
+ "compare", help="Run human vs AI comparison on Lambda"
1104
+ )
903
1105
  compare_parser.add_argument("instance_id", nargs="?", help="Instance ID")
904
- compare_parser.add_argument("--checkpoint", "-c", help="Checkpoint to use (default: latest)")
905
- compare_parser.add_argument("--epoch", "-e", type=int, help="Use checkpoint from specific epoch")
906
- compare_parser.add_argument("--open", action="store_true", help="Open viewer after generation")
1106
+ compare_parser.add_argument(
1107
+ "--checkpoint", "-c", help="Checkpoint to use (default: latest)"
1108
+ )
1109
+ compare_parser.add_argument(
1110
+ "--epoch", "-e", type=int, help="Use checkpoint from specific epoch"
1111
+ )
1112
+ compare_parser.add_argument(
1113
+ "--open", action="store_true", help="Open viewer after generation"
1114
+ )
907
1115
 
908
1116
  # Results viewer command - downloads and generates comparison viewer
909
- results_parser = subparsers.add_parser("results", help="Download results and generate comparison viewer")
910
- results_parser.add_argument("--capture", "-c", required=True, help="Local capture directory (for comparison)")
1117
+ results_parser = subparsers.add_parser(
1118
+ "results", help="Download results and generate comparison viewer"
1119
+ )
1120
+ results_parser.add_argument(
1121
+ "--capture",
1122
+ "-c",
1123
+ required=True,
1124
+ help="Local capture directory (for comparison)",
1125
+ )
911
1126
  results_parser.add_argument("--goal", "-g", help="Task goal description")
912
- results_parser.add_argument("--open", action="store_true", help="Open viewer in browser")
1127
+ results_parser.add_argument(
1128
+ "--open", action="store_true", help="Open viewer in browser"
1129
+ )
913
1130
  results_parser.add_argument("instance_id", nargs="?", help="Instance ID")
914
1131
 
915
1132
  # Sync command - sync training output and regenerate navigation for file:// protocol
916
- sync_parser = subparsers.add_parser("sync", help="Sync training output from Lambda and regenerate navigation")
1133
+ sync_parser = subparsers.add_parser(
1134
+ "sync", help="Sync training output from Lambda and regenerate navigation"
1135
+ )
917
1136
  sync_parser.add_argument("instance_id", nargs="?", help="Instance ID")
918
- sync_parser.add_argument("--output", "-o", default="training_output", help="Local output directory (default: training_output)")
919
- sync_parser.add_argument("--open", action="store_true", help="Open dashboard in browser after sync")
1137
+ sync_parser.add_argument(
1138
+ "--output",
1139
+ "-o",
1140
+ default="training_output",
1141
+ help="Local output directory (default: training_output)",
1142
+ )
1143
+ sync_parser.add_argument(
1144
+ "--open", action="store_true", help="Open dashboard in browser after sync"
1145
+ )
920
1146
 
921
1147
  # Viewer command - regenerate local viewer (no Lambda required)
922
- viewer_parser = subparsers.add_parser("viewer", help="Regenerate local viewer (no Lambda required)")
923
- viewer_parser.add_argument("--output", "-o", default="training_output", help="Training output directory (default: training_output)")
924
- viewer_parser.add_argument("--dashboard", "-d", action="store_true", help="Regenerate dashboard instead of viewer")
925
- viewer_parser.add_argument("--open", action="store_true", help="Open in browser (use 'serve' instead for better experience)")
1148
+ viewer_parser = subparsers.add_parser(
1149
+ "viewer", help="Regenerate local viewer (no Lambda required)"
1150
+ )
1151
+ viewer_parser.add_argument(
1152
+ "--output",
1153
+ "-o",
1154
+ default="training_output",
1155
+ help="Training output directory (default: training_output)",
1156
+ )
1157
+ viewer_parser.add_argument(
1158
+ "--dashboard",
1159
+ "-d",
1160
+ action="store_true",
1161
+ help="Regenerate dashboard instead of viewer",
1162
+ )
1163
+ viewer_parser.add_argument(
1164
+ "--open",
1165
+ action="store_true",
1166
+ help="Open in browser (use 'serve' instead for better experience)",
1167
+ )
926
1168
 
927
1169
  args = parser.parse_args()
928
1170
 
@@ -942,10 +1184,11 @@ def main():
942
1184
  print("Available GPU instances:\n")
943
1185
  types = client.list_instance_types()
944
1186
  for t in types:
945
- avail = "available" if t.available_regions else "no capacity"
946
1187
  print(f" {t}")
947
1188
  print(f"\nTotal: {len(types)} instance types")
948
- print("\nLaunch with: python -m openadapt_ml.cloud.lambda_labs launch --type <name>")
1189
+ print(
1190
+ "\nLaunch with: python -m openadapt_ml.cloud.lambda_labs launch --type <name>"
1191
+ )
949
1192
 
950
1193
  elif args.command == "status":
951
1194
  instances = client.list_instances()
@@ -968,13 +1211,15 @@ def main():
968
1211
  ssh_key_names=[ssh_key],
969
1212
  name=args.name,
970
1213
  )
971
- print(f"\nInstance launched!")
1214
+ print("\nInstance launched!")
972
1215
  print(f" ID: {instance.id}")
973
1216
  print(f" IP: {instance.ip}")
974
1217
  print(f" Type: {instance.instance_type}")
975
1218
  print(f" Region: {instance.region}")
976
1219
  print(f"\nConnect with: ssh ubuntu@{instance.ip}")
977
- print(f"\nTerminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
1220
+ print(
1221
+ f"\nTerminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}"
1222
+ )
978
1223
 
979
1224
  elif args.command == "terminate":
980
1225
  if client.terminate_instance(args.instance_id):
@@ -989,14 +1234,16 @@ def main():
989
1234
  return
990
1235
 
991
1236
  if args.instance_id:
992
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1237
+ instance = next(
1238
+ (i for i in instances if i.id.startswith(args.instance_id)), None
1239
+ )
993
1240
  if not instance:
994
1241
  print(f"Instance {args.instance_id} not found.")
995
1242
  return
996
1243
  else:
997
1244
  instance = instances[0]
998
1245
 
999
- if hasattr(args, 'cmd') and args.cmd:
1246
+ if hasattr(args, "cmd") and args.cmd:
1000
1247
  # Run single command
1001
1248
  print(f"Running on {instance.ip}: {args.cmd}")
1002
1249
  result = client.ssh_run(instance, args.cmd, timeout=args.timeout)
@@ -1018,7 +1265,9 @@ def main():
1018
1265
  return
1019
1266
 
1020
1267
  if args.instance_id:
1021
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1268
+ instance = next(
1269
+ (i for i in instances if i.id.startswith(args.instance_id)), None
1270
+ )
1022
1271
  if not instance:
1023
1272
  print(f"Instance {args.instance_id} not found.")
1024
1273
  return
@@ -1035,8 +1284,11 @@ def main():
1035
1284
  dest = f"ubuntu@{instance.ip}:{dest[7:]}"
1036
1285
 
1037
1286
  rsync_cmd = [
1038
- "rsync", "-avz", "--progress",
1039
- "-e", "ssh -o StrictHostKeyChecking=no",
1287
+ "rsync",
1288
+ "-avz",
1289
+ "--progress",
1290
+ "-e",
1291
+ "ssh -o StrictHostKeyChecking=no",
1040
1292
  ]
1041
1293
  if args.delete:
1042
1294
  rsync_cmd.append("--delete")
@@ -1056,7 +1308,6 @@ def main():
1056
1308
 
1057
1309
  instance = None
1058
1310
  start_time = time_module.time()
1059
- launched_new = False
1060
1311
  training_completed = False # Track if training actually finished
1061
1312
 
1062
1313
  # Instance pricing (approximate $/hr)
@@ -1071,7 +1322,9 @@ def main():
1071
1322
  # Get or launch instance
1072
1323
  if args.instance:
1073
1324
  instances = client.list_instances()
1074
- instance = next((i for i in instances if i.id.startswith(args.instance)), None)
1325
+ instance = next(
1326
+ (i for i in instances if i.id.startswith(args.instance)), None
1327
+ )
1075
1328
  if not instance:
1076
1329
  print(f"Error: Instance {args.instance} not found")
1077
1330
  return
@@ -1091,7 +1344,6 @@ def main():
1091
1344
  name="openadapt-training",
1092
1345
  )
1093
1346
  print(f"Instance launched: {instance.id[:8]}... at {instance.ip}")
1094
- launched_new = True
1095
1347
 
1096
1348
  price_per_hour = INSTANCE_PRICES.get(instance.instance_type, 1.00)
1097
1349
  print(f" Instance type: {instance.instance_type} (~${price_per_hour:.2f}/hr)")
@@ -1100,16 +1352,21 @@ def main():
1100
1352
  # Generate initial dashboard with setup status
1101
1353
  from pathlib import Path
1102
1354
  from openadapt_ml.training.trainer import (
1103
- TrainingState, TrainingConfig, generate_training_dashboard,
1104
- setup_job_directory
1355
+ TrainingState,
1356
+ TrainingConfig,
1357
+ generate_training_dashboard,
1358
+ setup_job_directory,
1105
1359
  )
1106
1360
  import time as time_module
1361
+
1107
1362
  job_id = time_module.strftime("%Y%m%d_%H%M%S")
1108
1363
  output_dir = setup_job_directory("training_output", job_id)
1109
1364
  dashboard_path = output_dir / "dashboard.html"
1110
1365
  log_path = output_dir / "training_log.json"
1111
1366
 
1112
- def update_dashboard(status: str, logs: list, step: int = 0, loss: float = 0.0, epoch: int = 0):
1367
+ def update_dashboard(
1368
+ status: str, logs: list, step: int = 0, loss: float = 0.0, epoch: int = 0
1369
+ ):
1113
1370
  """Update dashboard with current setup/training status."""
1114
1371
  state = TrainingState(job_id=job_id)
1115
1372
  state.cloud_provider = "lambda"
@@ -1156,9 +1413,13 @@ def main():
1156
1413
  update_dashboard("installing", setup_logs)
1157
1414
  break
1158
1415
  if setup_attempt < 2:
1159
- setup_logs.append(f"Setup attempt {setup_attempt + 1} failed, retrying in 30s...")
1416
+ setup_logs.append(
1417
+ f"Setup attempt {setup_attempt + 1} failed, retrying in 30s..."
1418
+ )
1160
1419
  update_dashboard("booting", setup_logs)
1161
- print(f" Setup attempt {setup_attempt + 1} failed, retrying in 30s...")
1420
+ print(
1421
+ f" Setup attempt {setup_attempt + 1} failed, retrying in 30s..."
1422
+ )
1162
1423
  time_module.sleep(30)
1163
1424
 
1164
1425
  if not setup_success:
@@ -1167,14 +1428,18 @@ def main():
1167
1428
  print("\nError: Failed to set up instance after 3 attempts")
1168
1429
  print(f"Instance still running: {instance.ip}")
1169
1430
  print("Debug via: ssh ubuntu@" + instance.ip)
1170
- print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
1431
+ print(
1432
+ f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}"
1433
+ )
1171
1434
  return # Don't terminate - let user debug
1172
1435
 
1173
1436
  # Sync local code to ensure remote has latest changes
1174
1437
  setup_logs.append("Syncing local code to instance...")
1175
1438
  update_dashboard("installing", setup_logs)
1176
1439
  if not client.sync_local_code(instance):
1177
- setup_logs.append("Warning: Failed to sync local code, using remote repo version")
1440
+ setup_logs.append(
1441
+ "Warning: Failed to sync local code, using remote repo version"
1442
+ )
1178
1443
  update_dashboard("installing", setup_logs)
1179
1444
  print("Warning: Failed to sync local code, using remote repo version")
1180
1445
  else:
@@ -1184,7 +1449,7 @@ def main():
1184
1449
  # Upload capture if provided
1185
1450
  remote_capture = None
1186
1451
  if args.capture:
1187
- setup_logs.append(f"Uploading capture data...")
1452
+ setup_logs.append("Uploading capture data...")
1188
1453
  update_dashboard("installing", setup_logs)
1189
1454
  if client.upload_capture(instance, args.capture, "~/capture"):
1190
1455
  remote_capture = "~/capture"
@@ -1197,7 +1462,9 @@ def main():
1197
1462
  print("\nError: Failed to upload capture after retries")
1198
1463
  print(f"Instance still running: {instance.ip}")
1199
1464
  print("Debug via: ssh ubuntu@" + instance.ip)
1200
- print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
1465
+ print(
1466
+ f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}"
1467
+ )
1201
1468
  return # Don't terminate - let user debug
1202
1469
 
1203
1470
  # Run training in background and poll for status
@@ -1207,7 +1474,7 @@ def main():
1207
1474
  print("Starting training...")
1208
1475
  print("=" * 50 + "\n")
1209
1476
 
1210
- proc = client.run_training(
1477
+ client.run_training(
1211
1478
  instance,
1212
1479
  config=args.config,
1213
1480
  capture=remote_capture,
@@ -1219,7 +1486,9 @@ def main():
1219
1486
  poll_interval = 10 # seconds
1220
1487
  last_step = 0
1221
1488
  last_epoch = 0
1222
- print(f"Polling training status every {poll_interval}s (Ctrl+C to stop)...\n")
1489
+ print(
1490
+ f"Polling training status every {poll_interval}s (Ctrl+C to stop)...\n"
1491
+ )
1223
1492
 
1224
1493
  while True:
1225
1494
  try:
@@ -1234,7 +1503,9 @@ def main():
1234
1503
 
1235
1504
  # Print progress when step changes
1236
1505
  if step > last_step or epoch > last_epoch:
1237
- print(f" Epoch {epoch+1}/{total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed_training:.0f}s")
1506
+ print(
1507
+ f" Epoch {epoch + 1}/{total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed_training:.0f}s"
1508
+ )
1238
1509
  last_step = step
1239
1510
  last_epoch = epoch
1240
1511
 
@@ -1246,7 +1517,9 @@ def main():
1246
1517
  status["instance_type"] = instance.instance_type
1247
1518
  # Add cloud provider info
1248
1519
  status["cloud_provider"] = "lambda"
1249
- status["cloud_dashboard_url"] = "https://cloud.lambda.ai/instances"
1520
+ status["cloud_dashboard_url"] = (
1521
+ "https://cloud.lambda.ai/instances"
1522
+ )
1250
1523
  status["cloud_instance_id"] = instance.id
1251
1524
  status["setup_status"] = "training"
1252
1525
  status["setup_logs"] = setup_logs
@@ -1274,9 +1547,11 @@ def main():
1274
1547
 
1275
1548
  config = TrainingConfig(
1276
1549
  num_train_epochs=total_epochs,
1277
- learning_rate=status.get("learning_rate", 5e-5)
1550
+ learning_rate=status.get("learning_rate", 5e-5),
1551
+ )
1552
+ dashboard_path.write_text(
1553
+ generate_training_dashboard(state, config)
1278
1554
  )
1279
- dashboard_path.write_text(generate_training_dashboard(state, config))
1280
1555
 
1281
1556
  # Check if training is complete (all epochs done)
1282
1557
  if epoch >= total_epochs - 1:
@@ -1318,13 +1593,15 @@ def main():
1318
1593
  print("=" * 50)
1319
1594
 
1320
1595
  # Determine the final checkpoint path (main checkpoint after training)
1321
- checkpoint_path = "/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
1596
+ checkpoint_path = (
1597
+ "/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
1598
+ )
1322
1599
 
1323
1600
  # Check if checkpoint exists
1324
1601
  result = client.ssh_run(
1325
1602
  instance,
1326
1603
  f"ls {checkpoint_path}/adapter_config.json 2>/dev/null && echo 'exists'",
1327
- timeout=30
1604
+ timeout=30,
1328
1605
  )
1329
1606
 
1330
1607
  if "exists" in result.stdout:
@@ -1336,13 +1613,15 @@ def main():
1336
1613
  --checkpoint {checkpoint_path} \
1337
1614
  --output training_output/{output_name} 2>&1"""
1338
1615
 
1339
- print(" Generating comparison viewer (this may take a few minutes)...")
1616
+ print(
1617
+ " Generating comparison viewer (this may take a few minutes)..."
1618
+ )
1340
1619
  result = client.ssh_run(instance, cmd, timeout=600)
1341
1620
 
1342
1621
  if result.returncode == 0:
1343
1622
  print(f" Comparison generated: {output_name}")
1344
1623
  else:
1345
- print(f" Warning: Comparison generation failed")
1624
+ print(" Warning: Comparison generation failed")
1346
1625
  if result.stderr:
1347
1626
  print(f" Error: {result.stderr}")
1348
1627
  else:
@@ -1357,13 +1636,15 @@ def main():
1357
1636
  print(f"\nTerminating instance {instance.id[:8]}...")
1358
1637
  client.terminate_instance(instance.id)
1359
1638
  print("Instance terminated.")
1360
- print(f"\nFinal cost: ~${cost:.2f} ({elapsed/60:.1f} minutes)")
1639
+ print(f"\nFinal cost: ~${cost:.2f} ({elapsed / 60:.1f} minutes)")
1361
1640
  else:
1362
1641
  print(f"\nInstance still running: {instance.ip}")
1363
1642
  print(f" Current cost: ~${cost:.2f}")
1364
1643
  if not training_completed:
1365
- print(f" (Not terminating - training did not complete successfully)")
1366
- print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
1644
+ print(" (Not terminating - training did not complete successfully)")
1645
+ print(
1646
+ f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}"
1647
+ )
1367
1648
 
1368
1649
  elif args.command == "train-status":
1369
1650
  instances = client.list_instances()
@@ -1372,7 +1653,9 @@ def main():
1372
1653
  return
1373
1654
 
1374
1655
  if args.instance_id:
1375
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1656
+ instance = next(
1657
+ (i for i in instances if i.id.startswith(args.instance_id)), None
1658
+ )
1376
1659
  if not instance:
1377
1660
  print(f"Instance {args.instance_id} not found.")
1378
1661
  return
@@ -1398,7 +1681,9 @@ def main():
1398
1681
  return
1399
1682
 
1400
1683
  if args.instance_id:
1401
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1684
+ instance = next(
1685
+ (i for i in instances if i.id.startswith(args.instance_id)), None
1686
+ )
1402
1687
  if not instance:
1403
1688
  print(f"Instance {args.instance_id} not found.")
1404
1689
  return
@@ -1408,10 +1693,14 @@ def main():
1408
1693
  print(f"Checking checkpoints on {instance.ip}...")
1409
1694
 
1410
1695
  ssh_cmd = [
1411
- "ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
1696
+ "ssh",
1697
+ "-o",
1698
+ "StrictHostKeyChecking=no",
1699
+ "-o",
1700
+ "ConnectTimeout=10",
1412
1701
  f"ubuntu@{instance.ip}",
1413
1702
  "ls -la ~/openadapt-ml/checkpoints/ 2>/dev/null && "
1414
- "du -sh ~/openadapt-ml/checkpoints/ 2>/dev/null || echo 'No checkpoints directory found'"
1703
+ "du -sh ~/openadapt-ml/checkpoints/ 2>/dev/null || echo 'No checkpoints directory found'",
1415
1704
  ]
1416
1705
 
1417
1706
  result = subprocess.run(ssh_cmd, capture_output=True, text=True)
@@ -1426,7 +1715,11 @@ def main():
1426
1715
  # One-shot dashboard refresh
1427
1716
  import time as time_module
1428
1717
  from pathlib import Path
1429
- from openadapt_ml.training.trainer import TrainingState, TrainingConfig, generate_training_dashboard
1718
+ from openadapt_ml.training.trainer import (
1719
+ TrainingState,
1720
+ TrainingConfig,
1721
+ generate_training_dashboard,
1722
+ )
1430
1723
 
1431
1724
  instances = client.list_instances()
1432
1725
  if not instances:
@@ -1434,7 +1727,9 @@ def main():
1434
1727
  return
1435
1728
 
1436
1729
  if args.instance_id:
1437
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1730
+ instance = next(
1731
+ (i for i in instances if i.id.startswith(args.instance_id)), None
1732
+ )
1438
1733
  if not instance:
1439
1734
  print(f"Instance {args.instance_id} not found.")
1440
1735
  return
@@ -1442,7 +1737,11 @@ def main():
1442
1737
  instance = instances[0]
1443
1738
 
1444
1739
  # Use current job directory via symlink
1445
- from openadapt_ml.training.trainer import get_current_job_directory, setup_job_directory
1740
+ from openadapt_ml.training.trainer import (
1741
+ get_current_job_directory,
1742
+ setup_job_directory,
1743
+ )
1744
+
1446
1745
  base_dir = Path("training_output")
1447
1746
  base_dir.mkdir(exist_ok=True)
1448
1747
 
@@ -1459,7 +1758,9 @@ def main():
1459
1758
  log_path = output_dir / "training_log.json"
1460
1759
 
1461
1760
  # Setup screenshots symlink if local capture path provided
1462
- local_capture = args.capture if hasattr(args, 'capture') and args.capture else None
1761
+ local_capture = (
1762
+ args.capture if hasattr(args, "capture") and args.capture else None
1763
+ )
1463
1764
  if local_capture:
1464
1765
  setup_capture_screenshots_symlink(output_dir, local_capture)
1465
1766
 
@@ -1483,7 +1784,9 @@ def main():
1483
1784
  state.instance_type = instance.instance_type
1484
1785
  state.config_path = status.get("config_path", "")
1485
1786
  # Use local capture path for screenshots if provided, else remote path
1486
- state.capture_path = args.capture if args.capture else status.get("capture_path", "")
1787
+ state.capture_path = (
1788
+ args.capture if args.capture else status.get("capture_path", "")
1789
+ )
1487
1790
  state.epoch = status.get("epoch", 0)
1488
1791
  state.step = status.get("step", 0)
1489
1792
  state.loss = status.get("loss", 0)
@@ -1501,7 +1804,7 @@ def main():
1501
1804
 
1502
1805
  config = TrainingConfig(
1503
1806
  num_train_epochs=status.get("total_epochs", 5),
1504
- learning_rate=status.get("learning_rate", 5e-5)
1807
+ learning_rate=status.get("learning_rate", 5e-5),
1505
1808
  )
1506
1809
 
1507
1810
  dashboard_path.write_text(generate_training_dashboard(state, config))
@@ -1509,6 +1812,7 @@ def main():
1509
1812
  # Regenerate navigation for file:// protocol
1510
1813
  try:
1511
1814
  from openadapt_ml.training.trainer import regenerate_all_dashboards
1815
+
1512
1816
  regenerate_all_dashboards(output_dir)
1513
1817
  except Exception:
1514
1818
  pass # Silent fail for navigation
@@ -1517,11 +1821,14 @@ def main():
1517
1821
  step = status.get("step", 0)
1518
1822
  loss = status.get("loss", 0)
1519
1823
  elapsed = status.get("elapsed_time", 0)
1520
- print(f"Epoch {epoch+1}/{state.total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s")
1824
+ print(
1825
+ f"Epoch {epoch + 1}/{state.total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s"
1826
+ )
1521
1827
  print(f"Dashboard: {dashboard_path.absolute()}")
1522
1828
 
1523
1829
  if args.open:
1524
1830
  import subprocess as sp
1831
+
1525
1832
  sp.run(["open", str(dashboard_path)], capture_output=True)
1526
1833
  else:
1527
1834
  print("No training data yet")
@@ -1533,10 +1840,12 @@ def main():
1533
1840
  from pathlib import Path
1534
1841
 
1535
1842
  # Stub mode - simulate training without actual GPU
1536
- if getattr(args, 'stub', False):
1843
+ if getattr(args, "stub", False):
1537
1844
  from openadapt_ml.training.stub_provider import StubTrainingProvider
1538
1845
  from openadapt_ml.training.trainer import (
1539
- TrainingState, TrainingConfig, generate_training_dashboard
1846
+ TrainingState,
1847
+ TrainingConfig,
1848
+ generate_training_dashboard,
1540
1849
  )
1541
1850
 
1542
1851
  print("\n[Stub Mode] Simulating training without GPU...")
@@ -1574,7 +1883,7 @@ def main():
1574
1883
 
1575
1884
  config = TrainingConfig(
1576
1885
  num_train_epochs=status.get("total_epochs", 5),
1577
- learning_rate=state.learning_rate
1886
+ learning_rate=state.learning_rate,
1578
1887
  )
1579
1888
 
1580
1889
  dashboard_path = output_dir / "dashboard.html"
@@ -1598,7 +1907,9 @@ def main():
1598
1907
  return
1599
1908
 
1600
1909
  if args.instance_id:
1601
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1910
+ instance = next(
1911
+ (i for i in instances if i.id.startswith(args.instance_id)), None
1912
+ )
1602
1913
  if not instance:
1603
1914
  print(f"Instance {args.instance_id} not found.")
1604
1915
  return
@@ -1621,9 +1932,13 @@ def main():
1621
1932
 
1622
1933
  # Use job-scoped directory structure
1623
1934
  from openadapt_ml.training.trainer import (
1624
- TrainingState, TrainingConfig, generate_training_dashboard,
1625
- setup_job_directory, get_current_job_directory
1935
+ TrainingState,
1936
+ TrainingConfig,
1937
+ generate_training_dashboard,
1938
+ setup_job_directory,
1939
+ get_current_job_directory,
1626
1940
  )
1941
+
1627
1942
  base_dir = Path("training_output")
1628
1943
  base_dir.mkdir(exist_ok=True)
1629
1944
 
@@ -1654,7 +1969,11 @@ def main():
1654
1969
  state.instance_ip = instance.ip or ""
1655
1970
  state.instance_type = instance.instance_type
1656
1971
  state.setup_status = "booting"
1657
- state.setup_logs = ["Starting Lambda Cloud instance...", f"Instance ID: {instance.id[:8]}...", f"Instance type: {instance.instance_type}"]
1972
+ state.setup_logs = [
1973
+ "Starting Lambda Cloud instance...",
1974
+ f"Instance ID: {instance.id[:8]}...",
1975
+ f"Instance type: {instance.instance_type}",
1976
+ ]
1658
1977
  config = TrainingConfig(num_train_epochs=5, learning_rate=5e-5)
1659
1978
  dashboard_path.write_text(generate_training_dashboard(state, config))
1660
1979
 
@@ -1665,12 +1984,14 @@ def main():
1665
1984
 
1666
1985
  last_step = 0
1667
1986
  last_epoch = -1
1668
- auto_stop_loss = getattr(args, 'auto_stop_loss', 0.5)
1669
- download_checkpoints = getattr(args, 'download_checkpoints', True)
1987
+ auto_stop_loss = getattr(args, "auto_stop_loss", 0.5)
1988
+ download_checkpoints = getattr(args, "download_checkpoints", True)
1670
1989
  step_stall_count = 0 # Track how many times step hasn't increased
1671
1990
 
1672
1991
  print(f" Auto-stop loss threshold: {auto_stop_loss}")
1673
- print(f" Checkpoint download: {'enabled' if download_checkpoints else 'disabled'}")
1992
+ print(
1993
+ f" Checkpoint download: {'enabled' if download_checkpoints else 'disabled'}"
1994
+ )
1674
1995
 
1675
1996
  try:
1676
1997
  while True:
@@ -1684,10 +2005,11 @@ def main():
1684
2005
  # Update status with termination info before terminating
1685
2006
  termination_status = {
1686
2007
  "termination_status": "user_stop",
1687
- "termination_message": "Training stopped by user via dashboard"
2008
+ "termination_message": "Training stopped by user via dashboard",
1688
2009
  }
1689
2010
  current_log = log_path.read_text() if log_path.exists() else "{}"
1690
2011
  import json as json_module
2012
+
1691
2013
  current_data = json_module.loads(current_log)
1692
2014
  current_data.update(termination_status)
1693
2015
  log_path.write_text(json_module.dumps(current_data, indent=2))
@@ -1711,8 +2033,14 @@ def main():
1711
2033
  remote_job_id = status.get("job_id")
1712
2034
 
1713
2035
  # Detect job_id change - clear old data if new job started
1714
- if remote_job_id and current_job_id and remote_job_id != current_job_id:
1715
- print(f"\n New job detected: {remote_job_id} (was: {current_job_id})")
2036
+ if (
2037
+ remote_job_id
2038
+ and current_job_id
2039
+ and remote_job_id != current_job_id
2040
+ ):
2041
+ print(
2042
+ f"\n New job detected: {remote_job_id} (was: {current_job_id})"
2043
+ )
1716
2044
  print(" Clearing old job data...")
1717
2045
  last_step = 0 # Reset step tracking
1718
2046
  current_job_id = remote_job_id
@@ -1727,25 +2055,37 @@ def main():
1727
2055
  status["instance_type"] = instance.instance_type
1728
2056
  # Add cloud provider info
1729
2057
  status["cloud_provider"] = "lambda"
1730
- status["cloud_dashboard_url"] = "https://cloud.lambda.ai/instances"
2058
+ status["cloud_dashboard_url"] = (
2059
+ "https://cloud.lambda.ai/instances"
2060
+ )
1731
2061
  status["cloud_instance_id"] = instance.id
1732
2062
  status["setup_status"] = status.get("setup_status", "training")
1733
2063
 
1734
2064
  # Setup screenshots symlink if local capture path provided
1735
- local_capture = args.capture if hasattr(args, 'capture') and args.capture else None
2065
+ local_capture = (
2066
+ args.capture
2067
+ if hasattr(args, "capture") and args.capture
2068
+ else None
2069
+ )
1736
2070
  if local_capture:
1737
2071
  setup_capture_screenshots_symlink(output_dir, local_capture)
1738
2072
 
1739
2073
  # Rewrite evaluation paths from Lambda to relative
1740
2074
  if "evaluations" in status:
1741
- status["evaluations"] = rewrite_evaluation_paths(status["evaluations"])
2075
+ status["evaluations"] = rewrite_evaluation_paths(
2076
+ status["evaluations"]
2077
+ )
1742
2078
 
1743
2079
  log_path.write_text(json.dumps(status, indent=2))
1744
2080
 
1745
2081
  if step > last_step:
1746
- print(f" Epoch {epoch+1} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s")
2082
+ print(
2083
+ f" Epoch {epoch + 1} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s"
2084
+ )
1747
2085
  last_step = step
1748
- step_stall_count = 0 # Reset stall counter when step increases
2086
+ step_stall_count = (
2087
+ 0 # Reset stall counter when step increases
2088
+ )
1749
2089
  if not current_job_id:
1750
2090
  current_job_id = remote_job_id
1751
2091
 
@@ -1764,39 +2104,59 @@ def main():
1764
2104
  state.start_time = time_module.time() - elapsed
1765
2105
  # Cloud provider info
1766
2106
  state.cloud_provider = "lambda"
1767
- state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
2107
+ state.cloud_dashboard_url = (
2108
+ "https://cloud.lambda.ai/instances"
2109
+ )
1768
2110
  state.cloud_instance_id = instance.id
1769
2111
  state.setup_status = status.get("setup_status", "training")
1770
2112
  state.setup_logs = status.get("setup_logs", [])
1771
- state.termination_status = status.get("termination_status", "")
1772
- state.termination_message = status.get("termination_message", "")
2113
+ state.termination_status = status.get(
2114
+ "termination_status", ""
2115
+ )
2116
+ state.termination_message = status.get(
2117
+ "termination_message", ""
2118
+ )
1773
2119
 
1774
2120
  config = TrainingConfig(
1775
2121
  num_train_epochs=status.get("total_epochs", 5),
1776
- learning_rate=status.get("learning_rate", 5e-5)
2122
+ learning_rate=status.get("learning_rate", 5e-5),
1777
2123
  )
1778
2124
 
1779
- dashboard_path.write_text(generate_training_dashboard(state, config))
2125
+ dashboard_path.write_text(
2126
+ generate_training_dashboard(state, config)
2127
+ )
1780
2128
 
1781
2129
  # Download checkpoints on epoch change
1782
2130
  if download_checkpoints and epoch > last_epoch:
1783
- print(f" Epoch {epoch+1} completed - downloading checkpoints...")
1784
- if download_checkpoints_from_instance(instance.ip, output_dir):
1785
- print(f" Checkpoints saved to {output_dir}/checkpoints/")
2131
+ print(
2132
+ f" Epoch {epoch + 1} completed - downloading checkpoints..."
2133
+ )
2134
+ if download_checkpoints_from_instance(
2135
+ instance.ip, output_dir
2136
+ ):
2137
+ print(
2138
+ f" Checkpoints saved to {output_dir}/checkpoints/"
2139
+ )
1786
2140
  else:
1787
2141
  print(" Warning: checkpoint download failed")
1788
2142
  last_epoch = epoch
1789
2143
 
1790
2144
  # Auto-terminate when loss is low enough
1791
2145
  if loss < auto_stop_loss and loss > 0:
1792
- print(f"\n Loss {loss:.4f} < threshold {auto_stop_loss}")
2146
+ print(
2147
+ f"\n Loss {loss:.4f} < threshold {auto_stop_loss}"
2148
+ )
1793
2149
  print(" Downloading final checkpoints...")
1794
2150
  if download_checkpoints:
1795
- download_checkpoints_from_instance(instance.ip, output_dir)
2151
+ download_checkpoints_from_instance(
2152
+ instance.ip, output_dir
2153
+ )
1796
2154
 
1797
2155
  # Update status with termination info
1798
2156
  status["termination_status"] = "auto_low_loss"
1799
- status["termination_message"] = f"Training auto-stopped: loss {loss:.4f} < threshold {auto_stop_loss}"
2157
+ status["termination_message"] = (
2158
+ f"Training auto-stopped: loss {loss:.4f} < threshold {auto_stop_loss}"
2159
+ )
1800
2160
  log_path.write_text(json.dumps(status, indent=2))
1801
2161
 
1802
2162
  print(f" Auto-terminating instance {instance.id}...")
@@ -1810,14 +2170,20 @@ def main():
1810
2170
 
1811
2171
  # If on last epoch and step hasn't increased for 3 polls, training is complete
1812
2172
  if epoch >= total_epochs - 1 and step_stall_count >= 3:
1813
- print(f"\n Training complete (epoch {epoch+1}/{total_epochs}, step stopped increasing)")
2173
+ print(
2174
+ f"\n Training complete (epoch {epoch + 1}/{total_epochs}, step stopped increasing)"
2175
+ )
1814
2176
  print(" Downloading final checkpoints...")
1815
2177
  if download_checkpoints:
1816
- download_checkpoints_from_instance(instance.ip, output_dir)
2178
+ download_checkpoints_from_instance(
2179
+ instance.ip, output_dir
2180
+ )
1817
2181
 
1818
2182
  # Update status with termination info
1819
2183
  status["termination_status"] = "auto_complete"
1820
- status["termination_message"] = f"Training completed successfully ({epoch+1}/{total_epochs} epochs)"
2184
+ status["termination_message"] = (
2185
+ f"Training completed successfully ({epoch + 1}/{total_epochs} epochs)"
2186
+ )
1821
2187
  log_path.write_text(json.dumps(status, indent=2))
1822
2188
 
1823
2189
  print(f" Terminating instance {instance.id}...")
@@ -1849,7 +2215,9 @@ def main():
1849
2215
  return
1850
2216
 
1851
2217
  if args.instance_id:
1852
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2218
+ instance = next(
2219
+ (i for i in instances if i.id.startswith(args.instance_id)), None
2220
+ )
1853
2221
  if not instance:
1854
2222
  print(f"Instance {args.instance_id} not found.")
1855
2223
  return
@@ -1857,9 +2225,13 @@ def main():
1857
2225
  instance = instances[0]
1858
2226
 
1859
2227
  print(f"Files on {instance.ip} at {args.path}:")
1860
- result = client.ssh_run(instance, f"find {args.path} -type f -name '*.pt' -o -name '*.json' -o -name '*.bin' 2>/dev/null | head -20", timeout=30)
2228
+ result = client.ssh_run(
2229
+ instance,
2230
+ f"find {args.path} -type f -name '*.pt' -o -name '*.json' -o -name '*.bin' 2>/dev/null | head -20",
2231
+ timeout=30,
2232
+ )
1861
2233
  if result.stdout:
1862
- for line in result.stdout.strip().split('\n'):
2234
+ for line in result.stdout.strip().split("\n"):
1863
2235
  print(f" {line}")
1864
2236
  else:
1865
2237
  print(" (no checkpoint files found)")
@@ -1872,18 +2244,16 @@ def main():
1872
2244
  if args.local:
1873
2245
  print("\nKilling local Lambda-related processes...")
1874
2246
  subprocess.run(
1875
- ["pkill", "-f", "ssh.*ubuntu@.*openadapt"],
1876
- capture_output=True
1877
- )
1878
- subprocess.run(
1879
- ["pkill", "-f", "lambda_labs"],
1880
- capture_output=True
2247
+ ["pkill", "-f", "ssh.*ubuntu@.*openadapt"], capture_output=True
1881
2248
  )
2249
+ subprocess.run(["pkill", "-f", "lambda_labs"], capture_output=True)
1882
2250
  print("Done.")
1883
2251
  return
1884
2252
 
1885
2253
  if args.instance_id:
1886
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2254
+ instance = next(
2255
+ (i for i in instances if i.id.startswith(args.instance_id)), None
2256
+ )
1887
2257
  if not instance:
1888
2258
  print(f"Instance {args.instance_id} not found.")
1889
2259
  return
@@ -1896,11 +2266,11 @@ def main():
1896
2266
  result = client.ssh_run(
1897
2267
  instance,
1898
2268
  "ps aux | grep python | grep -v grep | grep -v jupyter",
1899
- timeout=30
2269
+ timeout=30,
1900
2270
  )
1901
2271
  if result.stdout.strip():
1902
2272
  print("Found Python processes:")
1903
- for line in result.stdout.strip().split('\n'):
2273
+ for line in result.stdout.strip().split("\n"):
1904
2274
  print(f" {line[:100]}...")
1905
2275
  else:
1906
2276
  print("No training/inference Python processes found.")
@@ -1908,7 +2278,9 @@ def main():
1908
2278
 
1909
2279
  if args.all:
1910
2280
  print("\nKilling ALL Python processes (except jupyter)...")
1911
- cmd = "pkill -f 'python.*train\\|python.*compare\\|python.*openadapt' || true"
2281
+ cmd = (
2282
+ "pkill -f 'python.*train\\|python.*compare\\|python.*openadapt' || true"
2283
+ )
1912
2284
  else:
1913
2285
  print("\nKilling training and inference processes...")
1914
2286
  cmd = "pkill -f 'python.*train' ; pkill -f 'python.*compare' || true"
@@ -1919,20 +2291,16 @@ def main():
1919
2291
  if args.local:
1920
2292
  print("\nKilling local Lambda-related processes...")
1921
2293
  subprocess.run(
1922
- ["pkill", "-f", "ssh.*ubuntu@.*openadapt"],
1923
- capture_output=True
1924
- )
1925
- subprocess.run(
1926
- ["pkill", "-f", "lambda_labs.*train"],
1927
- capture_output=True
2294
+ ["pkill", "-f", "ssh.*ubuntu@.*openadapt"], capture_output=True
1928
2295
  )
2296
+ subprocess.run(["pkill", "-f", "lambda_labs.*train"], capture_output=True)
1929
2297
  print("Local processes killed.")
1930
2298
 
1931
2299
  print("\nDone. Current status:")
1932
2300
  result = client.ssh_run(
1933
2301
  instance,
1934
2302
  "ps aux | grep python | grep -v grep | grep -v jupyter | wc -l",
1935
- timeout=30
2303
+ timeout=30,
1936
2304
  )
1937
2305
  count = result.stdout.strip()
1938
2306
  print(f" {count} Python processes remaining on instance")
@@ -1945,7 +2313,9 @@ def main():
1945
2313
  return
1946
2314
 
1947
2315
  if args.instance_id:
1948
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2316
+ instance = next(
2317
+ (i for i in instances if i.id.startswith(args.instance_id)), None
2318
+ )
1949
2319
  if not instance:
1950
2320
  print(f"Instance {args.instance_id} not found.")
1951
2321
  return
@@ -1958,7 +2328,7 @@ def main():
1958
2328
  result = client.ssh_run(
1959
2329
  instance,
1960
2330
  "cat ~/openadapt-ml/training_output/training_log.json 2>/dev/null",
1961
- timeout=30
2331
+ timeout=30,
1962
2332
  )
1963
2333
 
1964
2334
  if not result.stdout.strip():
@@ -1977,77 +2347,87 @@ def main():
1977
2347
  return
1978
2348
 
1979
2349
  total_steps = len(losses)
1980
- epochs = sorted(set(l["epoch"] for l in losses))
2350
+ epochs = sorted(set(loss["epoch"] for loss in losses))
1981
2351
  total_epochs = data.get("total_epochs", 5)
1982
- min_loss = min(l["loss"] for l in losses)
2352
+ min_loss = min(loss["loss"] for loss in losses)
1983
2353
  current_loss = losses[-1]["loss"]
1984
2354
 
1985
- print(f"\n{'='*50}")
1986
- print(f"TRAINING STATUS")
1987
- print(f"{'='*50}")
2355
+ print(f"\n{'=' * 50}")
2356
+ print("TRAINING STATUS")
2357
+ print(f"{'=' * 50}")
1988
2358
  print(f"Steps: {total_steps}")
1989
- print(f"Epochs: {max(epochs)+1}/{total_epochs}")
2359
+ print(f"Epochs: {max(epochs) + 1}/{total_epochs}")
1990
2360
  print(f"Current loss: {current_loss:.4f}")
1991
2361
  print(f"Min loss: {min_loss:.4f}")
1992
2362
 
1993
2363
  # Check if training is running
1994
2364
  proc_result = client.ssh_run(
1995
- instance,
1996
- "ps aux | grep 'python.*train' | grep -v grep | wc -l",
1997
- timeout=30
2365
+ instance, "ps aux | grep 'python.*train' | grep -v grep | wc -l", timeout=30
1998
2366
  )
1999
2367
  is_running = int(proc_result.stdout.strip()) > 0
2000
2368
 
2001
2369
  if is_running:
2002
- print(f"Status: RUNNING")
2370
+ print("Status: RUNNING")
2003
2371
  else:
2004
- print(f"Status: STOPPED")
2372
+ print("Status: STOPPED")
2005
2373
 
2006
2374
  # Early stopping analysis
2007
2375
  window = min(args.window, len(losses))
2008
2376
  if window < 2:
2009
2377
  print("\nNot enough data for early stopping analysis.")
2010
2378
  else:
2011
- recent_losses = [l["loss"] for l in losses[-window:]]
2012
- older_losses = [l["loss"] for l in losses[-window*2:-window]] if len(losses) >= window*2 else [l["loss"] for l in losses[:window]]
2379
+ recent_losses = [loss["loss"] for loss in losses[-window:]]
2380
+ older_losses = (
2381
+ [loss["loss"] for loss in losses[-window * 2 : -window]]
2382
+ if len(losses) >= window * 2
2383
+ else [loss["loss"] for loss in losses[:window]]
2384
+ )
2013
2385
 
2014
2386
  recent_avg = sum(recent_losses) / len(recent_losses)
2015
- older_avg = sum(older_losses) / len(older_losses) if older_losses else recent_avg
2387
+ older_avg = (
2388
+ sum(older_losses) / len(older_losses) if older_losses else recent_avg
2389
+ )
2016
2390
 
2017
2391
  improvement = (older_avg - recent_avg) / older_avg if older_avg > 0 else 0
2018
2392
  loss_variance = max(recent_losses) - min(recent_losses)
2019
2393
 
2020
- print(f"\n{'='*50}")
2394
+ print(f"\n{'=' * 50}")
2021
2395
  print(f"EARLY STOPPING ANALYSIS (window={window})")
2022
- print(f"{'='*50}")
2396
+ print(f"{'=' * 50}")
2023
2397
  print(f"Recent avg loss: {recent_avg:.4f}")
2024
2398
  print(f"Prior avg loss: {older_avg:.4f}")
2025
- print(f"Improvement: {improvement*100:.2f}%")
2399
+ print(f"Improvement: {improvement * 100:.2f}%")
2026
2400
  print(f"Loss variance: {loss_variance:.4f}")
2027
2401
 
2028
2402
  should_stop = improvement < args.threshold and loss_variance < 0.1
2029
2403
  if should_stop:
2030
- print(f"\n⚠️ EARLY STOPPING RECOMMENDED")
2031
- print(f" Loss has plateaued (improvement < {args.threshold*100}%)")
2404
+ print("\n⚠️ EARLY STOPPING RECOMMENDED")
2405
+ print(f" Loss has plateaued (improvement < {args.threshold * 100}%)")
2032
2406
  if not is_running:
2033
- print(f" (Training already stopped)")
2407
+ print(" (Training already stopped)")
2034
2408
  else:
2035
- print(f"\n To stop: uv run python -m openadapt_ml.cloud.lambda_labs kill")
2409
+ print(
2410
+ "\n To stop: uv run python -m openadapt_ml.cloud.lambda_labs kill"
2411
+ )
2036
2412
  else:
2037
- print(f"\n✓ Training still improving, continue.")
2413
+ print("\n✓ Training still improving, continue.")
2038
2414
 
2039
2415
  # Time estimate
2040
2416
  if is_running and len(losses) >= 2:
2041
- avg_time_per_step = losses[-1].get("time", 0) / len(losses) if losses[-1].get("time") else 50
2417
+ avg_time_per_step = (
2418
+ losses[-1].get("time", 0) / len(losses)
2419
+ if losses[-1].get("time")
2420
+ else 50
2421
+ )
2042
2422
  steps_per_epoch = len(losses) / (max(epochs) + 1)
2043
2423
  remaining_epochs = total_epochs - max(epochs) - 1
2044
2424
  remaining_steps = remaining_epochs * steps_per_epoch
2045
2425
  eta_seconds = remaining_steps * avg_time_per_step
2046
2426
  eta_mins = eta_seconds / 60
2047
2427
 
2048
- print(f"\n{'='*50}")
2049
- print(f"TIME ESTIMATE")
2050
- print(f"{'='*50}")
2428
+ print(f"\n{'=' * 50}")
2429
+ print("TIME ESTIMATE")
2430
+ print(f"{'=' * 50}")
2051
2431
  print(f"Remaining epochs: {remaining_epochs}")
2052
2432
  print(f"Est. remaining steps: {remaining_steps:.0f}")
2053
2433
  print(f"ETA: {eta_mins:.1f} minutes")
@@ -2060,7 +2440,9 @@ def main():
2060
2440
  return
2061
2441
 
2062
2442
  if args.instance_id:
2063
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2443
+ instance = next(
2444
+ (i for i in instances if i.id.startswith(args.instance_id)), None
2445
+ )
2064
2446
  if not instance:
2065
2447
  print(f"Instance {args.instance_id} not found.")
2066
2448
  return
@@ -2071,24 +2453,26 @@ def main():
2071
2453
  if args.checkpoint:
2072
2454
  checkpoint_path = args.checkpoint
2073
2455
  elif args.epoch is not None:
2074
- checkpoint_path = f"/home/ubuntu/openadapt-ml/checkpoints/epoch_{args.epoch}"
2456
+ checkpoint_path = (
2457
+ f"/home/ubuntu/openadapt-ml/checkpoints/epoch_{args.epoch}"
2458
+ )
2075
2459
  else:
2076
2460
  # Use latest (main checkpoint)
2077
- checkpoint_path = "/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
2461
+ checkpoint_path = (
2462
+ "/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
2463
+ )
2078
2464
 
2079
2465
  # Check if checkpoint exists
2080
2466
  result = client.ssh_run(
2081
2467
  instance,
2082
2468
  f"ls {checkpoint_path}/adapter_config.json 2>/dev/null && echo 'exists'",
2083
- timeout=30
2469
+ timeout=30,
2084
2470
  )
2085
2471
  if "exists" not in result.stdout:
2086
2472
  print(f"Checkpoint not found at {checkpoint_path}")
2087
2473
  # List available checkpoints
2088
2474
  result = client.ssh_run(
2089
- instance,
2090
- "ls -la ~/openadapt-ml/checkpoints/",
2091
- timeout=30
2475
+ instance, "ls -la ~/openadapt-ml/checkpoints/", timeout=30
2092
2476
  )
2093
2477
  print(f"Available checkpoints:\n{result.stdout}")
2094
2478
  return
@@ -2113,9 +2497,7 @@ def main():
2113
2497
 
2114
2498
  # Check if file was created
2115
2499
  result = client.ssh_run(
2116
- instance,
2117
- f"ls -la ~/openadapt-ml/training_output/{output_name}",
2118
- timeout=30
2500
+ instance, f"ls -la ~/openadapt-ml/training_output/{output_name}", timeout=30
2119
2501
  )
2120
2502
  if result.returncode != 0:
2121
2503
  print("Comparison file not created.")
@@ -2128,11 +2510,15 @@ def main():
2128
2510
  local_output.parent.mkdir(parents=True, exist_ok=True)
2129
2511
 
2130
2512
  print(f"Syncing to {local_output}...")
2131
- subprocess.run([
2132
- "rsync", "-avz",
2133
- f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/{output_name}",
2134
- str(local_output)
2135
- ], capture_output=True)
2513
+ subprocess.run(
2514
+ [
2515
+ "rsync",
2516
+ "-avz",
2517
+ f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/{output_name}",
2518
+ str(local_output),
2519
+ ],
2520
+ capture_output=True,
2521
+ )
2136
2522
 
2137
2523
  print(f"Done! Comparison saved to: {local_output}")
2138
2524
 
@@ -2147,7 +2533,9 @@ def main():
2147
2533
  return
2148
2534
 
2149
2535
  if args.instance_id:
2150
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2536
+ instance = next(
2537
+ (i for i in instances if i.id.startswith(args.instance_id)), None
2538
+ )
2151
2539
  if not instance:
2152
2540
  print(f"Instance {args.instance_id} not found.")
2153
2541
  return
@@ -2164,7 +2552,9 @@ def main():
2164
2552
  return
2165
2553
 
2166
2554
  if args.instance_id:
2167
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2555
+ instance = next(
2556
+ (i for i in instances if i.id.startswith(args.instance_id)), None
2557
+ )
2168
2558
  if not instance:
2169
2559
  print(f"Instance {args.instance_id} not found.")
2170
2560
  return
@@ -2180,10 +2570,17 @@ def main():
2180
2570
  checkpoint_path = "checkpoints_lambda/qwen3vl2b_capture_lora"
2181
2571
 
2182
2572
  import subprocess as sp
2573
+
2183
2574
  cmd = [
2184
- "uv", "run", "python", "-m", "openadapt_ml.scripts.compare",
2185
- "--capture", args.capture,
2186
- "--checkpoint", checkpoint_path,
2575
+ "uv",
2576
+ "run",
2577
+ "python",
2578
+ "-m",
2579
+ "openadapt_ml.scripts.compare",
2580
+ "--capture",
2581
+ args.capture,
2582
+ "--checkpoint",
2583
+ checkpoint_path,
2187
2584
  ]
2188
2585
  if args.goal:
2189
2586
  cmd.extend(["--goal", args.goal])
@@ -2202,11 +2599,12 @@ def main():
2202
2599
  # Start web server for live dashboard with stop button support
2203
2600
  import http.server
2204
2601
  import socketserver
2205
- import threading
2206
2602
  import time as time_module
2207
2603
  from pathlib import Path
2208
2604
 
2209
- output_dir = Path(args.output) if hasattr(args, 'output') else Path("training_output")
2605
+ output_dir = (
2606
+ Path(args.output) if hasattr(args, "output") else Path("training_output")
2607
+ )
2210
2608
  port = args.port
2211
2609
 
2212
2610
  if not output_dir.exists():
@@ -2219,13 +2617,13 @@ def main():
2219
2617
  super().__init__(*args, directory=str(output_dir), **kwargs)
2220
2618
 
2221
2619
  def do_POST(self):
2222
- if self.path == '/api/stop':
2620
+ if self.path == "/api/stop":
2223
2621
  # Create stop signal file
2224
2622
  stop_file = output_dir / "STOP_TRAINING"
2225
2623
  stop_file.touch()
2226
2624
  self.send_response(200)
2227
- self.send_header('Content-Type', 'application/json')
2228
- self.send_header('Access-Control-Allow-Origin', '*')
2625
+ self.send_header("Content-Type", "application/json")
2626
+ self.send_header("Access-Control-Allow-Origin", "*")
2229
2627
  self.end_headers()
2230
2628
  self.wfile.write(b'{"status": "stop signal created"}')
2231
2629
  print(f" Stop signal created: {stop_file}")
@@ -2235,15 +2633,14 @@ def main():
2235
2633
  def do_OPTIONS(self):
2236
2634
  # Handle CORS preflight
2237
2635
  self.send_response(200)
2238
- self.send_header('Access-Control-Allow-Origin', '*')
2239
- self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
2240
- self.send_header('Access-Control-Allow-Headers', 'Content-Type')
2636
+ self.send_header("Access-Control-Allow-Origin", "*")
2637
+ self.send_header("Access-Control-Allow-Methods", "POST, OPTIONS")
2638
+ self.send_header("Access-Control-Allow-Headers", "Content-Type")
2241
2639
  self.end_headers()
2242
2640
 
2243
2641
  def log_message(self, format, *args):
2244
2642
  pass # Suppress log messages
2245
2643
 
2246
-
2247
2644
  # Start web server
2248
2645
  with socketserver.TCPServer(("", port), Handler) as httpd:
2249
2646
  url = f"http://localhost:{port}/dashboard.html"
@@ -2262,8 +2659,10 @@ def main():
2262
2659
  # Sync training output from Lambda and regenerate navigation for file:// protocol
2263
2660
  from pathlib import Path
2264
2661
  from openadapt_ml.training.trainer import (
2265
- TrainingState, TrainingConfig, generate_training_dashboard,
2266
- regenerate_all_dashboards
2662
+ TrainingState,
2663
+ TrainingConfig,
2664
+ generate_training_dashboard,
2665
+ regenerate_all_dashboards,
2267
2666
  )
2268
2667
 
2269
2668
  instances = client.list_instances()
@@ -2272,7 +2671,9 @@ def main():
2272
2671
  return
2273
2672
 
2274
2673
  if args.instance_id:
2275
- instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2674
+ instance = next(
2675
+ (i for i in instances if i.id.startswith(args.instance_id)), None
2676
+ )
2276
2677
  if not instance:
2277
2678
  print(f"Instance {args.instance_id} not found.")
2278
2679
  return
@@ -2286,10 +2687,13 @@ def main():
2286
2687
 
2287
2688
  # Sync all training output files
2288
2689
  rsync_cmd = [
2289
- "rsync", "-avz", "--progress",
2290
- "-e", "ssh -o StrictHostKeyChecking=no",
2690
+ "rsync",
2691
+ "-avz",
2692
+ "--progress",
2693
+ "-e",
2694
+ "ssh -o StrictHostKeyChecking=no",
2291
2695
  f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/",
2292
- str(output_dir) + "/"
2696
+ str(output_dir) + "/",
2293
2697
  ]
2294
2698
  result = subprocess.run(rsync_cmd, capture_output=False)
2295
2699
 
@@ -2303,6 +2707,7 @@ def main():
2303
2707
  if log_path.exists():
2304
2708
  try:
2305
2709
  import time as time_module
2710
+
2306
2711
  status = json.loads(log_path.read_text())
2307
2712
 
2308
2713
  # Update with instance info
@@ -2336,7 +2741,7 @@ def main():
2336
2741
 
2337
2742
  config = TrainingConfig(
2338
2743
  num_train_epochs=status.get("total_epochs", 5),
2339
- learning_rate=status.get("learning_rate", 5e-5)
2744
+ learning_rate=status.get("learning_rate", 5e-5),
2340
2745
  )
2341
2746
 
2342
2747
  dashboard_path.write_text(generate_training_dashboard(state, config))
@@ -2390,7 +2795,7 @@ def main():
2390
2795
  # First try training log
2391
2796
  log_data = json.loads((output_dir / "training_log.json").read_text())
2392
2797
  capture_path = log_data.get("capture_path", "")
2393
- capture_match = re.search(r'capture_(\d+)', capture_path)
2798
+ capture_match = re.search(r"capture_(\d+)", capture_path)
2394
2799
  if capture_match:
2395
2800
  capture_id = capture_match.group(1)
2396
2801
 
@@ -2401,27 +2806,37 @@ def main():
2401
2806
  base_data = pred_data.get("base_data", [])
2402
2807
  if base_data:
2403
2808
  image_path = base_data[0].get("image_path", "")
2404
- capture_match = re.search(r'capture_(\d+)', image_path)
2809
+ capture_match = re.search(r"capture_(\d+)", image_path)
2405
2810
  if capture_match:
2406
2811
  capture_id = capture_match.group(1)
2407
2812
  break
2408
2813
 
2409
2814
  if capture_id:
2410
2815
  # Search for local screenshots in openadapt-capture
2411
- openadapt_capture_dir = Path.home() / "oa" / "src" / "openadapt-capture"
2816
+ openadapt_capture_dir = (
2817
+ Path.home() / "oa" / "src" / "openadapt-capture"
2818
+ )
2412
2819
  if openadapt_capture_dir.exists():
2413
2820
  for capture_dir in openadapt_capture_dir.iterdir():
2414
2821
  if capture_dir.is_dir():
2415
2822
  screenshots_dir = capture_dir / "screenshots"
2416
2823
  if screenshots_dir.exists():
2417
2824
  # Check if this capture has our screenshots
2418
- sample_file = list(screenshots_dir.glob(f"capture_{capture_id}_step_*.png"))
2825
+ sample_file = list(
2826
+ screenshots_dir.glob(
2827
+ f"capture_{capture_id}_step_*.png"
2828
+ )
2829
+ )
2419
2830
  if sample_file:
2420
- print(f"Found local screenshots in {screenshots_dir}")
2831
+ print(
2832
+ f"Found local screenshots in {screenshots_dir}"
2833
+ )
2421
2834
  screenshots_link.symlink_to(screenshots_dir)
2422
- print(f" Linked: {screenshots_link} -> {screenshots_dir}")
2835
+ print(
2836
+ f" Linked: {screenshots_link} -> {screenshots_dir}"
2837
+ )
2423
2838
  break
2424
- except Exception as e:
2839
+ except Exception:
2425
2840
  pass # Silently continue if auto-link fails
2426
2841
 
2427
2842
  print(f"Regenerating viewer from {output_dir}...")
@@ -2435,7 +2850,7 @@ def main():
2435
2850
  target = output_dir / "viewer.html"
2436
2851
 
2437
2852
  print(f"\nGenerated: {target.absolute()}")
2438
- print(f"View with: uv run python -m openadapt_ml.cloud.lambda_labs serve --open")
2853
+ print("View with: uv run python -m openadapt_ml.cloud.lambda_labs serve --open")
2439
2854
 
2440
2855
  if args.open:
2441
2856
  subprocess.run(["open", str(target)], capture_output=True)