openadapt-ml 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. openadapt_ml/__init__.py +0 -0
  2. openadapt_ml/benchmarks/__init__.py +125 -0
  3. openadapt_ml/benchmarks/agent.py +825 -0
  4. openadapt_ml/benchmarks/azure.py +761 -0
  5. openadapt_ml/benchmarks/base.py +366 -0
  6. openadapt_ml/benchmarks/cli.py +884 -0
  7. openadapt_ml/benchmarks/data_collection.py +432 -0
  8. openadapt_ml/benchmarks/runner.py +381 -0
  9. openadapt_ml/benchmarks/waa.py +704 -0
  10. openadapt_ml/cloud/__init__.py +5 -0
  11. openadapt_ml/cloud/azure_inference.py +441 -0
  12. openadapt_ml/cloud/lambda_labs.py +2445 -0
  13. openadapt_ml/cloud/local.py +790 -0
  14. openadapt_ml/config.py +56 -0
  15. openadapt_ml/datasets/__init__.py +0 -0
  16. openadapt_ml/datasets/next_action.py +507 -0
  17. openadapt_ml/evals/__init__.py +23 -0
  18. openadapt_ml/evals/grounding.py +241 -0
  19. openadapt_ml/evals/plot_eval_metrics.py +174 -0
  20. openadapt_ml/evals/trajectory_matching.py +486 -0
  21. openadapt_ml/grounding/__init__.py +45 -0
  22. openadapt_ml/grounding/base.py +236 -0
  23. openadapt_ml/grounding/detector.py +570 -0
  24. openadapt_ml/ingest/__init__.py +43 -0
  25. openadapt_ml/ingest/capture.py +312 -0
  26. openadapt_ml/ingest/loader.py +232 -0
  27. openadapt_ml/ingest/synthetic.py +1102 -0
  28. openadapt_ml/models/__init__.py +0 -0
  29. openadapt_ml/models/api_adapter.py +171 -0
  30. openadapt_ml/models/base_adapter.py +59 -0
  31. openadapt_ml/models/dummy_adapter.py +42 -0
  32. openadapt_ml/models/qwen_vl.py +426 -0
  33. openadapt_ml/runtime/__init__.py +0 -0
  34. openadapt_ml/runtime/policy.py +182 -0
  35. openadapt_ml/schemas/__init__.py +53 -0
  36. openadapt_ml/schemas/sessions.py +122 -0
  37. openadapt_ml/schemas/validation.py +252 -0
  38. openadapt_ml/scripts/__init__.py +0 -0
  39. openadapt_ml/scripts/compare.py +1490 -0
  40. openadapt_ml/scripts/demo_policy.py +62 -0
  41. openadapt_ml/scripts/eval_policy.py +287 -0
  42. openadapt_ml/scripts/make_gif.py +153 -0
  43. openadapt_ml/scripts/prepare_synthetic.py +43 -0
  44. openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
  45. openadapt_ml/scripts/train.py +174 -0
  46. openadapt_ml/training/__init__.py +0 -0
  47. openadapt_ml/training/benchmark_viewer.py +1538 -0
  48. openadapt_ml/training/shared_ui.py +157 -0
  49. openadapt_ml/training/stub_provider.py +276 -0
  50. openadapt_ml/training/trainer.py +2446 -0
  51. openadapt_ml/training/viewer.py +2970 -0
  52. openadapt_ml-0.1.0.dist-info/METADATA +818 -0
  53. openadapt_ml-0.1.0.dist-info/RECORD +55 -0
  54. openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
  55. openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,2445 @@
1
+ """Lambda Labs cloud GPU integration.
2
+
3
+ Lambda Labs provides affordable GPU instances for training:
4
+ - A100 40GB: ~$1.10/hour
5
+ - H100: ~$2.00/hour
6
+ - A10: ~$0.60/hour
7
+
8
+ API docs: https://cloud.lambdalabs.com/api/v1/docs
9
+
10
+ Usage:
11
+ # Set API key
12
+ export LAMBDA_API_KEY=your_key_here
13
+
14
+ # List available instances
15
+ python -m openadapt_ml.cloud.lambda_labs list
16
+
17
+ # Launch instance for training
18
+ python -m openadapt_ml.cloud.lambda_labs launch --type gpu_1x_a100
19
+
20
+ # Check running instances
21
+ python -m openadapt_ml.cloud.lambda_labs status
22
+
23
+ # Terminate instance
24
+ python -m openadapt_ml.cloud.lambda_labs terminate <instance_id>
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ import json
30
+ import os
31
+ import subprocess
32
+ import sys
33
+ import time
34
+ from dataclasses import dataclass
35
+ from pathlib import Path
36
+ from typing import Any
37
+
38
+ import requests
39
+
40
+
41
+ API_BASE = "https://cloud.lambdalabs.com/api/v1"
42
+
43
+ # Default port for HTTP server
44
+ DEFAULT_SERVER_PORT = 8765
45
+
46
+
47
def start_dashboard_server(output_dir: Path, port: int = DEFAULT_SERVER_PORT) -> tuple[subprocess.Popen, str]:
    """Start a background HTTP server for the dashboard.

    Serves *output_dir* with ``python -m http.server`` in a child process.

    Args:
        output_dir: Directory containing dashboard files
        port: Port to serve on

    Returns:
        (process, url): The server process and the dashboard URL
    """
    # Fix: removed unused local imports (webbrowser, threading) that were
    # never referenced in this function.
    # Serve output_dir via the stdlib static server; suppress its request
    # logging since the caller only needs the URL.
    server_proc = subprocess.Popen(
        [sys.executable, "-m", "http.server", str(port)],
        cwd=str(output_dir),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    url = f"http://localhost:{port}/dashboard.html"

    # Give server time to start (bind the port) before callers open the URL.
    time.sleep(0.5)

    return server_proc, url
74
+
75
+
76
def open_dashboard_in_browser(output_dir: Path, port: int = DEFAULT_SERVER_PORT):
    """Start HTTP server and open dashboard in browser.

    Args:
        output_dir: Directory containing dashboard files
        port: Port to serve on

    Returns:
        Server process (caller should call .terminate() when done), or None if failed
    """
    import webbrowser

    try:
        proc, dashboard_url = start_dashboard_server(output_dir, port)
        webbrowser.open(dashboard_url)
        print(f"Dashboard: {dashboard_url}")
        print(" Stop Training button enabled in dashboard")
        return proc
    except Exception as e:
        # Best-effort: a missing browser or busy port should not kill the caller.
        print(f"Warning: Could not start dashboard server: {e}")
        return None
97
+
98
+
99
def setup_capture_screenshots_symlink(output_dir: Path, capture_path: str | Path) -> bool:
    """Create symlink from output_dir/screenshots to capture's screenshots folder.

    This allows the dashboard to serve screenshots via relative paths.

    Args:
        output_dir: Training output directory (e.g., training_output/job_id/)
        capture_path: Path to capture directory (local)

    Returns:
        True if symlink created successfully
    """
    source = Path(capture_path) / "screenshots"
    link = output_dir / "screenshots"

    if not source.exists():
        return False

    # Replace a stale symlink, but never clobber a real file or directory.
    if link.is_symlink():
        link.unlink()
    elif link.exists():
        return False

    try:
        link.symlink_to(source.resolve())
    except Exception:
        return False
    return True
129
+
130
+
131
def rewrite_evaluation_paths(evaluations: list[dict], remote_prefix: str = "/home/ubuntu/capture/") -> list[dict]:
    """Rewrite Lambda paths in evaluations to relative paths.

    Converts: /home/ubuntu/capture/screenshots/foo.png -> screenshots/foo.png

    Args:
        evaluations: List of evaluation dicts with image_path (mutated in place)
        remote_prefix: The Lambda path prefix to strip

    Returns:
        The same list with rewritten paths
    """
    for ev in evaluations:
        path = ev.get("image_path")
        # Fix: str.replace() rewrote EVERY occurrence of the prefix inside
        # the path, not just the leading one; slice off only the prefix.
        # Also guard on type so a non-string image_path no longer raises.
        if isinstance(path, str) and path.startswith(remote_prefix):
            ev["image_path"] = path[len(remote_prefix):]
    return evaluations
147
+
148
+
149
def download_checkpoints_from_instance(instance_ip: str, output_dir: Path, ssh_key: str | None = None) -> bool:
    """Download checkpoints from Lambda instance.

    Args:
        instance_ip: IP address of Lambda instance
        output_dir: Local directory to save checkpoints
        ssh_key: Path to SSH key (uses default if not provided)

    Returns:
        True if download succeeded
    """
    checkpoints_dir = output_dir / "checkpoints"
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    ssh_key = ssh_key or str(Path.home() / ".ssh" / "lambda_id_ed25519")
    ssh_opts = f"-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i {ssh_key}"

    # Download checkpoints from remote
    remote_path = f"ubuntu@{instance_ip}:~/openadapt-ml/checkpoints/"
    local_path = str(checkpoints_dir) + "/"

    # Fix: use an argument list (shell=False) so spaces or shell
    # metacharacters in the key path or output dir cannot break or inject
    # into the command line; rsync's -e still receives the ssh options as
    # a single argument.
    cmd = ["rsync", "-avz", "--progress", "-e", f"ssh {ssh_opts}", remote_path, local_path]
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.returncode == 0
176
+
177
+
178
def check_stop_signal(output_dir: Path) -> bool:
    """Return True if the dashboard has requested that training stop.

    The dashboard signals a stop by creating a ``STOP_TRAINING`` file
    inside *output_dir*.
    """
    return (output_dir / "STOP_TRAINING").exists()
185
+
186
+
187
@dataclass
class InstanceType:
    """Lambda Labs instance type."""
    name: str
    price_cents_per_hour: int
    description: str
    gpu_count: int
    gpu_type: str
    vcpus: int
    memory_gb: int
    storage_gb: int
    available_regions: list[str]

    @property
    def price_per_hour(self) -> float:
        """Price in dollars per hour."""
        return self.price_cents_per_hour / 100

    def __str__(self) -> str:
        # Show at most three regions, summarizing the remainder.
        shown = self.available_regions[:3]
        hidden = len(self.available_regions) - 3
        regions = ", ".join(shown)
        if hidden > 0:
            regions += f" (+{hidden} more)"
        fields = [
            f"{self.name}: ${self.price_per_hour:.2f}/hr",
            f"{self.gpu_count}x {self.gpu_type}",
            f"{self.vcpus} vCPUs",
            f"{self.memory_gb}GB RAM",
            f"{self.storage_gb}GB SSD",
            f"Regions: {regions}",
        ]
        return " | ".join(fields)
214
+
215
+
216
@dataclass
class Instance:
    """Running Lambda Labs instance."""
    id: str
    name: str
    instance_type: str
    status: str
    ip: str | None
    region: str
    ssh_key_names: list[str]

    def __str__(self) -> str:
        # Abbreviate the id; an instance without an IP is still booting.
        shown_ip = self.ip if self.ip else "pending"
        return f"{self.id[:8]}... | {self.instance_type} | {self.status} | IP: {shown_ip} | {self.region}"
230
+
231
+
232
class LambdaLabsClient:
    """Client for Lambda Labs API."""

    def __init__(self, api_key: str | None = None):
        """Create a client.

        Args:
            api_key: Lambda Labs API key. If omitted, falls back to the
                project settings (``settings.lambda_api_key``) and then the
                ``LAMBDA_API_KEY`` environment variable.

        Raises:
            ValueError: If no API key can be found by any of those means.
        """
        # Try provided key, then settings, then env var
        if not api_key:
            from openadapt_ml.config import settings
            api_key = settings.lambda_api_key or os.environ.get("LAMBDA_API_KEY")

        self.api_key = api_key
        if not self.api_key:
            raise ValueError(
                "Lambda Labs API key required. Set LAMBDA_API_KEY in .env file "
                "or get one at https://cloud.lambdalabs.com/api-keys"
            )
        # One session for all requests; the bearer token header is applied
        # to every call made through _get/_post.
        self.session = requests.Session()
        self.session.headers["Authorization"] = f"Bearer {self.api_key}"
249
+
250
+ def _get(self, endpoint: str) -> dict[str, Any]:
251
+ """Make GET request to API."""
252
+ resp = self.session.get(f"{API_BASE}{endpoint}")
253
+ resp.raise_for_status()
254
+ return resp.json()
255
+
256
+ def _post(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]:
257
+ """Make POST request to API."""
258
+ resp = self.session.post(f"{API_BASE}{endpoint}", json=data)
259
+ if not resp.ok:
260
+ error = resp.json().get("error", {})
261
+ raise RuntimeError(f"API error: {error.get('message', resp.text)}")
262
+ return resp.json()
263
+
264
+ def list_instance_types(self) -> list[InstanceType]:
265
+ """List available GPU instance types."""
266
+ data = self._get("/instance-types")
267
+ types = []
268
+
269
+ for name, info in data.get("data", {}).items():
270
+ specs = info.get("instance_type", {}).get("specs", {})
271
+ regions = [r["name"] for r in info.get("regions_with_capacity_available", [])]
272
+
273
+ types.append(InstanceType(
274
+ name=name,
275
+ price_cents_per_hour=info.get("instance_type", {}).get("price_cents_per_hour", 0),
276
+ description=info.get("instance_type", {}).get("description", ""),
277
+ gpu_count=specs.get("gpus", 0),
278
+ gpu_type=info.get("instance_type", {}).get("gpu_description", ""),
279
+ vcpus=specs.get("vcpus", 0),
280
+ memory_gb=specs.get("memory_gib", 0),
281
+ storage_gb=specs.get("storage_gib", 0),
282
+ available_regions=regions,
283
+ ))
284
+
285
+ # Sort by price
286
+ types.sort(key=lambda t: t.price_cents_per_hour)
287
+ return types
288
+
289
+ def list_ssh_keys(self) -> list[dict[str, str]]:
290
+ """List registered SSH keys."""
291
+ data = self._get("/ssh-keys")
292
+ return data.get("data", [])
293
+
294
+ def add_ssh_key(self, name: str, public_key: str) -> dict[str, str]:
295
+ """Add an SSH key."""
296
+ data = self._post("/ssh-keys", {"name": name, "public_key": public_key})
297
+ return data.get("data", {})
298
+
299
+ def list_instances(self) -> list[Instance]:
300
+ """List running instances."""
301
+ data = self._get("/instances")
302
+ instances = []
303
+
304
+ for inst in data.get("data", []):
305
+ # ssh_key_names can be list of strings or list of dicts
306
+ ssh_keys = inst.get("ssh_key_names", [])
307
+ if ssh_keys and isinstance(ssh_keys[0], dict):
308
+ ssh_key_names = [k["name"] for k in ssh_keys]
309
+ else:
310
+ ssh_key_names = ssh_keys # Already list of strings
311
+
312
+ instances.append(Instance(
313
+ id=inst["id"],
314
+ name=inst.get("name", ""),
315
+ instance_type=inst.get("instance_type", {}).get("name", "unknown"),
316
+ status=inst.get("status", "unknown"),
317
+ ip=inst.get("ip"),
318
+ region=inst.get("region", {}).get("name", "unknown"),
319
+ ssh_key_names=ssh_key_names,
320
+ ))
321
+
322
+ return instances
323
+
324
    def launch_instance(
        self,
        instance_type: str,
        region: str | None = None,
        ssh_key_names: list[str] | None = None,
        name: str | None = None,
    ) -> Instance:
        """Launch a new GPU instance.

        Args:
            instance_type: Instance type name (e.g., 'gpu_1x_a100')
            region: Region name (auto-selects if None)
            ssh_key_names: SSH key names to use
            name: Optional instance name

        Returns:
            Launched instance

        Raises:
            RuntimeError: If no region has capacity for the type, no SSH keys
                are registered, the launch API returns no instance id, or the
                instance never receives an IP within the polling window.
        """
        # If no region specified, find one with capacity
        if not region:
            types = self.list_instance_types()
            for t in types:
                if t.name == instance_type and t.available_regions:
                    region = t.available_regions[0]
                    break
            if not region:
                raise RuntimeError(f"No regions available for {instance_type}")

        # If no SSH key specified, use first available
        if not ssh_key_names:
            keys = self.list_ssh_keys()
            if not keys:
                raise RuntimeError(
                    "No SSH keys found. Add one at https://cloud.lambdalabs.com/ssh-keys"
                )
            ssh_key_names = [keys[0]["name"]]

        payload = {
            "region_name": region,
            "instance_type_name": instance_type,
            "ssh_key_names": ssh_key_names,
        }
        if name:
            payload["name"] = name

        data = self._post("/instance-operations/launch", payload)
        instance_ids = data.get("data", {}).get("instance_ids", [])

        if not instance_ids:
            raise RuntimeError("Failed to launch instance")

        # Wait for instance to be ready: poll the listing until the new
        # instance appears with an IP assigned.
        print(f"Instance {instance_ids[0]} launched, waiting for IP...")
        instance = None
        for _ in range(60):  # Wait up to 5 minutes for IP
            instances = self.list_instances()
            for inst in instances:
                if inst.id == instance_ids[0] and inst.ip:
                    instance = inst
                    break
            if instance:
                break
            time.sleep(5)

        if not instance:
            raise RuntimeError("Timed out waiting for instance IP")

        # Wait for SSH to be ready - be patient, instances can take a while to boot
        print(f"Instance IP: {instance.ip}, waiting for SSH...")
        for attempt in range(60):  # Wait up to 5 minutes for SSH
            try:
                # A successful "echo ready" round-trip means sshd is accepting
                # connections with our key.
                result = subprocess.run(
                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
                     f"ubuntu@{instance.ip}", "echo ready"],
                    capture_output=True, text=True, timeout=20
                )
                if result.returncode == 0:
                    print("SSH ready!")
                    return instance
            except subprocess.TimeoutExpired:
                pass
            if attempt % 6 == 5:  # Log progress every 30 seconds
                print(f" Still waiting for SSH ({(attempt+1)*5}s elapsed)...")
            time.sleep(5)

        # The instance exists and has an IP; callers may still need to retry
        # SSH themselves if it was not reachable within the window.
        print("Warning: SSH may not be ready yet, continuing anyway...")
        return instance
411
+
412
+ def terminate_instance(self, instance_id: str) -> bool:
413
+ """Terminate an instance."""
414
+ data = self._post("/instance-operations/terminate", {"instance_ids": [instance_id]})
415
+ terminated = data.get("data", {}).get("terminated_instances", [])
416
+ return any(t.get("id") == instance_id for t in terminated)
417
+
418
+ def get_ssh_command(self, instance: Instance, user: str = "ubuntu") -> str:
419
+ """Get SSH command for an instance."""
420
+ if not instance.ip:
421
+ return "# Instance IP not yet available"
422
+ return f"ssh {user}@{instance.ip}"
423
+
424
    def ssh_run(self, instance: Instance, command: str, timeout: int | None = None, retries: int = 3) -> subprocess.CompletedProcess:
        """Run a command on an instance via SSH.

        Args:
            instance: Instance to run on
            command: Shell command to run
            timeout: Optional timeout in seconds
            retries: Number of retries on connection failure

        Returns:
            CompletedProcess with stdout/stderr

        Raises:
            RuntimeError: If the instance has no IP address.
            subprocess.TimeoutExpired: If every attempt times out.
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        ssh_cmd = [
            "ssh", "-o", "StrictHostKeyChecking=no",
            "-o", "ConnectTimeout=30",  # Increased from 10
            "-o", "ServerAliveInterval=60",  # Keep connection alive
            "-o", "ServerAliveCountMax=3",
            f"ubuntu@{instance.ip}",
            command
        ]

        # Only timeouts are retried; a completed run (even with a non-zero
        # remote exit status) is returned to the caller as-is.
        last_error = None
        for attempt in range(retries):
            try:
                return subprocess.run(
                    ssh_cmd,
                    capture_output=True,
                    text=True,
                    timeout=timeout,
                )
            except subprocess.TimeoutExpired as e:
                last_error = e
                if attempt < retries - 1:
                    print(f" SSH timeout, retrying ({attempt + 1}/{retries})...")
                    time.sleep(5)

        # All attempts timed out (or retries == 0 so the loop never ran).
        raise last_error if last_error else RuntimeError("SSH failed")
464
+
465
    def setup_instance(self, instance: Instance, repo_url: str = "https://github.com/OpenAdaptAI/openadapt-ml.git", clean_gpu: bool = True) -> bool:
        """Set up training environment on instance.

        Clones repo, installs uv, syncs dependencies.
        Optionally clears GPU memory from previous runs.

        Args:
            instance: Instance to set up.
            repo_url: Git repository cloned onto the instance.
            clean_gpu: If True, best-effort GPU memory cleanup and kill of
                stale training processes before setup.

        Returns:
            True if successful.
        """
        print(f"Setting up instance {instance.ip}...")

        # Clean GPU memory if requested (don't fail if this doesn't work)
        if clean_gpu:
            print(" Clearing GPU memory...")
            try:
                self.ssh_run(instance, '''
python3 -c "
import torch
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
print('GPU memory cleared')
" 2>/dev/null || true
# Kill any stale python processes using GPU
pkill -f "python.*train" 2>/dev/null || true
''', timeout=60)
            except Exception as e:
                print(f" GPU cleanup skipped: {e}")

        # The script is idempotent: installer and clone are skipped when
        # already present; an existing checkout is updated instead.
        setup_script = f'''
set -e
cd ~

# Install uv via official installer (most robust)
if ! command -v uv &> /dev/null; then
    curl -LsSf https://astral.sh/uv/install.sh | sh
fi
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"

# Clone or update repo
if [ ! -d "openadapt-ml" ]; then
    git clone {repo_url}
else
    cd openadapt-ml && git pull origin main && cd ~
fi

cd openadapt-ml
uv sync
echo "SETUP_COMPLETE"
'''

        try:
            result = self.ssh_run(instance, setup_script, timeout=900)  # 15 min timeout for setup

            # Success is detected via the sentinel echoed at the end of the
            # script rather than the SSH exit status.
            if "SETUP_COMPLETE" in result.stdout:
                print(" Environment ready")
                return True
            else:
                stderr_preview = result.stderr[:500] if result.stderr else "(no stderr)"
                print(f" Setup failed: {stderr_preview}")
                return False
        except subprocess.TimeoutExpired:
            print(" Setup timed out after 15 minutes")
            return False
        except Exception as e:
            print(f" Setup failed: {e}")
            return False
530
+
531
    def sync_local_code(self, instance: Instance, local_repo_path: str = ".", retries: int = 3) -> bool:
        """Sync local code changes to remote instance.

        Uses rsync to push local code, excluding .venv, .git, etc.
        This ensures the remote has the same code as local.

        Args:
            instance: Instance to sync to
            local_repo_path: Local repository path
            retries: Number of retry attempts

        Returns:
            True if successful

        Raises:
            RuntimeError: If the instance has no IP address.
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        print(f"Syncing local code to {instance.ip}...")

        # SSH options for more robust connection
        ssh_opts = "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -o ServerAliveInterval=60"

        # Exclusions keep the transfer small and avoid clobbering the remote
        # venv / outputs or pushing local secrets (.env).
        rsync_cmd = [
            "rsync", "-avz", "--progress",
            "--timeout=120",  # 2 minute timeout per file
            "--exclude", ".venv",
            "--exclude", ".git",
            "--exclude", "__pycache__",
            "--exclude", "*.pyc",
            "--exclude", ".env",
            "--exclude", "training_output",
            "--exclude", "checkpoints",
            "--exclude", "synthetic*",
            "-e", ssh_opts,
            f"{local_repo_path}/",
            f"ubuntu@{instance.ip}:~/openadapt-ml/"
        ]

        for attempt in range(retries):
            result = subprocess.run(rsync_cmd)
            if result.returncode == 0:
                print(" Code synced")
                return True
            if attempt < retries - 1:
                print(f" Sync failed, retrying ({attempt + 1}/{retries})...")
                time.sleep(5)

        return False
579
+
580
    def upload_capture(self, instance: Instance, local_path: str, remote_path: str = "~/capture", retries: int = 3) -> bool:
        """Upload a capture directory to instance via rsync.

        Args:
            instance: Instance to upload to
            local_path: Local path to capture directory
            remote_path: Remote path (default: ~/capture)
            retries: Number of retry attempts

        Returns:
            True if successful

        Raises:
            RuntimeError: If the instance has no IP address.
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        print(f"Uploading capture to {instance.ip}:{remote_path}...")

        # SSH options for more robust connection
        ssh_opts = "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -o ServerAliveInterval=60"

        # Trailing slashes make rsync copy directory CONTENTS into
        # remote_path rather than nesting the directory inside it.
        rsync_cmd = [
            "rsync", "-avz", "--progress",
            "--timeout=120",  # 2 minute timeout per file
            "-e", ssh_opts,
            f"{local_path}/",
            f"ubuntu@{instance.ip}:{remote_path}/"
        ]

        for attempt in range(retries):
            result = subprocess.run(rsync_cmd)
            if result.returncode == 0:
                return True
            if attempt < retries - 1:
                print(f" Upload failed, retrying ({attempt + 1}/{retries})...")
                time.sleep(5)

        return False
617
+
618
    def run_training(
        self,
        instance: Instance,
        config: str = "configs/qwen3vl_capture.yaml",
        capture: str | None = None,
        goal: str | None = None,
        background: bool = True,
    ) -> subprocess.Popen | subprocess.CompletedProcess:
        """Run training on instance.

        Args:
            instance: Instance to train on
            config: Config file path (relative to repo)
            capture: Remote capture path (if uploaded)
            goal: Task goal description
            background: Run in background (returns Popen) or foreground

        Returns:
            Popen if background=True, CompletedProcess if background=False

        Raises:
            RuntimeError: If the instance has no IP address.
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        # Build training command
        train_cmd = f"uv run python -m openadapt_ml.scripts.train --config {config}"
        if capture:
            train_cmd += f" --capture {capture}"
        if goal:
            train_cmd += f' --goal "{goal}"'

        # Full script with environment setup (uv is installed under
        # ~/.local/bin by setup_instance, so extend PATH first)
        script = f'''
cd ~/openadapt-ml
export PATH="$HOME/.local/bin:$PATH"
{train_cmd}
'''

        ssh_cmd = [
            "ssh", "-o", "StrictHostKeyChecking=no",
            f"ubuntu@{instance.ip}",
            script
        ]

        print(f"Running training on {instance.ip}...")
        print(f" Config: {config}")
        if capture:
            print(f" Capture: {capture}")

        if background:
            # Run in background, return Popen for monitoring; stderr is
            # merged into stdout so one pipe carries all output.
            return subprocess.Popen(
                ssh_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            )
        else:
            # Run in foreground, stream output
            return subprocess.run(ssh_cmd)
677
+
678
+ def download_results(
679
+ self,
680
+ instance: Instance,
681
+ remote_path: str = "~/openadapt-ml",
682
+ local_path: str = ".",
683
+ include_checkpoint: bool = True,
684
+ include_logs: bool = True,
685
+ ) -> bool:
686
+ """Download training results from instance.
687
+
688
+ Args:
689
+ instance: Instance to download from
690
+ remote_path: Remote openadapt-ml directory
691
+ local_path: Local directory to download to
692
+ include_checkpoint: Download checkpoint weights
693
+ include_logs: Download training logs and dashboard
694
+
695
+ Returns:
696
+ True if successful
697
+ """
698
+ if not instance.ip:
699
+ raise RuntimeError("Instance has no IP address")
700
+
701
+ print(f"Downloading results from {instance.ip}...")
702
+ success = True
703
+
704
+ # Download training output (logs, dashboard)
705
+ if include_logs:
706
+ print(" Downloading training logs...")
707
+ rsync_cmd = [
708
+ "rsync", "-avz",
709
+ "-e", "ssh -o StrictHostKeyChecking=no",
710
+ f"ubuntu@{instance.ip}:{remote_path}/training_output/",
711
+ f"{local_path}/training_output_lambda/"
712
+ ]
713
+ result = subprocess.run(rsync_cmd, capture_output=True)
714
+ if result.returncode == 0:
715
+ print(" Training logs downloaded to training_output_lambda/")
716
+ else:
717
+ print(f" Warning: Failed to download logs")
718
+ success = False
719
+
720
+ # Download checkpoint
721
+ if include_checkpoint:
722
+ print(" Downloading checkpoint...")
723
+ rsync_cmd = [
724
+ "rsync", "-avz",
725
+ "-e", "ssh -o StrictHostKeyChecking=no",
726
+ f"ubuntu@{instance.ip}:{remote_path}/checkpoints/",
727
+ f"{local_path}/checkpoints_lambda/"
728
+ ]
729
+ result = subprocess.run(rsync_cmd, capture_output=True)
730
+ if result.returncode == 0:
731
+ print(" Checkpoint downloaded to checkpoints_lambda/")
732
+ else:
733
+ print(f" Warning: Failed to download checkpoint (may not exist yet)")
734
+
735
+ # Regenerate all dashboards with static navigation and correct status
736
+ if include_logs:
737
+ try:
738
+ from openadapt_ml.training.trainer import regenerate_all_dashboards
739
+ output_dir = Path(local_path) / "training_output_lambda"
740
+ if output_dir.exists():
741
+ print(" Regenerating dashboards with static navigation...")
742
+ regenerate_all_dashboards(output_dir)
743
+ except Exception as e:
744
+ print(f" Warning: Failed to regenerate dashboards: {e}")
745
+
746
+ return success
747
+
748
+ def get_training_status(self, instance: Instance) -> dict:
749
+ """Check training status by reading training_log.json on instance."""
750
+ result = self.ssh_run(
751
+ instance,
752
+ "cat ~/openadapt-ml/training_output/training_log.json 2>/dev/null || echo '{}'",
753
+ timeout=10,
754
+ )
755
+ try:
756
+ import json
757
+ return json.loads(result.stdout.strip())
758
+ except:
759
+ return {}
760
+
761
+
762
def setup_lambda_ssh_key(client: LambdaLabsClient) -> str:
    """Set up SSH key for Lambda Labs if not already done.

    Returns the SSH key name that was added/found.
    """
    # Prefer any key already registered with the account.
    existing = client.list_ssh_keys()
    if existing:
        print(f"Using existing SSH key: {existing[0]['name']}")
        return existing[0]["name"]

    # Otherwise look for a local public key to upload (RSA first, then ed25519).
    candidates = [
        Path.home() / ".ssh" / "id_rsa.pub",
        Path.home() / ".ssh" / "id_ed25519.pub",
    ]
    pub_path = next((p for p in candidates if p.exists()), None)
    if pub_path is None:
        raise RuntimeError(
            "No SSH key found at ~/.ssh/id_rsa.pub or ~/.ssh/id_ed25519.pub\n"
            "Generate one with: ssh-keygen -t ed25519"
        )

    public_key = pub_path.read_text().strip()
    key_name = f"openadapt-{os.environ.get('USER', 'user')}"

    print(f"Adding SSH key '{key_name}' to Lambda Labs...")
    client.add_ssh_key(key_name, public_key)
    return key_name
790
+
791
+
792
+ def main():
793
+ """CLI for Lambda Labs."""
794
+ import argparse
795
+
796
+ parser = argparse.ArgumentParser(description="Lambda Labs GPU management")
797
+ subparsers = parser.add_subparsers(dest="command", help="Command")
798
+
799
+ # List instances command
800
+ list_parser = subparsers.add_parser("list", help="List available instance types")
801
+
802
+ # Status command
803
+ status_parser = subparsers.add_parser("status", help="Show running instances")
804
+
805
+ # Launch command
806
+ launch_parser = subparsers.add_parser("launch", help="Launch a GPU instance")
807
+ launch_parser.add_argument(
808
+ "--type", "-t",
809
+ default="gpu_1x_a100",
810
+ help="Instance type (default: gpu_1x_a100)",
811
+ )
812
+ launch_parser.add_argument("--region", "-r", help="Region (auto-selects if not specified)")
813
+ launch_parser.add_argument("--name", "-n", help="Instance name")
814
+
815
+ # Terminate command
816
+ term_parser = subparsers.add_parser("terminate", help="Terminate an instance")
817
+ term_parser.add_argument("instance_id", help="Instance ID to terminate")
818
+
819
+ # SSH command - run commands or get interactive shell
820
+ ssh_parser = subparsers.add_parser("ssh", help="SSH into Lambda instance or run command")
821
+ ssh_parser.add_argument("instance_id", nargs="?", help="Instance ID (uses first if not specified)")
822
+ ssh_parser.add_argument("--cmd", "-c", help="Command to run (opens shell if not specified)")
823
+ ssh_parser.add_argument("--timeout", "-t", type=int, default=60, help="Command timeout in seconds")
824
+
825
+ # Serve command - start dashboard server with stop button support
826
+ serve_parser = subparsers.add_parser("serve", help="Start dashboard server with stop button support")
827
+ serve_parser.add_argument("--output", "-o", default="training_output", help="Output directory (default: training_output)")
828
+ serve_parser.add_argument("--port", "-p", type=int, default=8765, help="Port (default: 8765)")
829
+ serve_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
830
+
831
+ # Rsync command - copy files to/from Lambda instance
832
+ rsync_parser = subparsers.add_parser("rsync", help="Rsync files to/from Lambda instance")
833
+ rsync_parser.add_argument("source", help="Source path (prefix with 'remote:' for remote paths)")
834
+ rsync_parser.add_argument("dest", help="Destination path (prefix with 'remote:' for remote paths)")
835
+ rsync_parser.add_argument("instance_id", nargs="?", help="Instance ID (uses first if not specified)")
836
+ rsync_parser.add_argument("--delete", action="store_true", help="Delete extraneous files from dest")
837
+
838
+ # Setup command
839
+ setup_parser = subparsers.add_parser("setup", help="Set up SSH key for Lambda Labs")
840
+
841
+ # Train command - full automated training pipeline
842
+ train_parser = subparsers.add_parser("train", help="Run training on Lambda GPU")
843
+ train_parser.add_argument("--capture", "-c", help="Local path to capture directory")
844
+ train_parser.add_argument("--goal", "-g", help="Task goal description")
845
+ train_parser.add_argument("--config", default="configs/qwen3vl_capture_4bit.yaml", help="Config file (default: 4bit for memory efficiency)")
846
+ train_parser.add_argument("--type", "-t", default="gpu_1x_a10", help="Instance type")
847
+ train_parser.add_argument("--instance", "-i", help="Use existing instance ID instead of launching new")
848
+ train_parser.add_argument("--no-terminate", action="store_true", help="Don't terminate instance after training")
849
+ train_parser.add_argument("--max-runtime", type=int, default=60, help="Max runtime in minutes before auto-terminate (default: 60)")
850
+ train_parser.add_argument("--open", action="store_true", help="Open dashboard in browser when training starts")
851
+
852
+ # Training status command
853
+ train_status_parser = subparsers.add_parser("train-status", help="Check training status on instance")
854
+ train_status_parser.add_argument("instance_id", nargs="?", help="Instance ID")
855
+
856
+ # Monitor command - live dashboard for Lambda training
857
+ monitor_parser = subparsers.add_parser("monitor", help="Monitor Lambda training with live dashboard")
858
+ monitor_parser.add_argument("instance_id", nargs="?", help="Instance ID")
859
+ monitor_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
860
+ monitor_parser.add_argument("--interval", type=int, default=5, help="Poll interval in seconds (default: 5)")
861
+ monitor_parser.add_argument("--capture", type=str, help="Local capture path for screenshot symlink")
862
+ monitor_parser.add_argument("--auto-stop-loss", type=float, default=0.5, help="Auto-terminate when loss drops below this (default: 0.5)")
863
+ monitor_parser.add_argument("--download-checkpoints", action="store_true", default=True, help="Auto-download checkpoints each epoch")
864
+ monitor_parser.add_argument("--no-download-checkpoints", action="store_false", dest="download_checkpoints", help="Disable checkpoint download")
865
+ monitor_parser.add_argument("--stub", action="store_true", help="Use stub training provider (no GPU, instant simulation)")
866
+
867
+ # Refresh command - one-shot dashboard update
868
+ refresh_parser = subparsers.add_parser("refresh", help="One-shot refresh of training dashboard")
869
+ refresh_parser.add_argument("instance_id", nargs="?", help="Instance ID")
870
+ refresh_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
871
+ refresh_parser.add_argument("--capture", type=str, help="Local capture path for screenshot preview")
872
+
873
+ # Checkpoints command - list remote checkpoints
874
+ checkpoints_parser = subparsers.add_parser("checkpoints", help="List checkpoints on remote instance")
875
+ checkpoints_parser.add_argument("instance_id", nargs="?", help="Instance ID")
876
+
877
+ # Download results command
878
+ download_parser = subparsers.add_parser("download", help="Download training results from instance")
879
+ download_parser.add_argument("instance_id", nargs="?", help="Instance ID")
880
+ download_parser.add_argument("--output", "-o", default=".", help="Local output directory")
881
+
882
+ # Check files on instance
883
+ files_parser = subparsers.add_parser("files", help="List training files on instance")
884
+ files_parser.add_argument("instance_id", nargs="?", help="Instance ID")
885
+ files_parser.add_argument("--path", "-p", default="~/openadapt-ml", help="Path to check")
886
+
887
+ # Kill command - terminate training processes
888
+ kill_parser = subparsers.add_parser("kill", help="Kill training/inference processes on instance")
889
+ kill_parser.add_argument("instance_id", nargs="?", help="Instance ID")
890
+ kill_parser.add_argument("--local", action="store_true", help="Also kill local Lambda-related processes")
891
+ kill_parser.add_argument("--all", action="store_true", help="Kill all Python processes on instance (careful!)")
892
+
893
+ # Check command - analyze training status and early stopping
894
+ check_parser = subparsers.add_parser("check", help="Check training health and early stopping status")
895
+ check_parser.add_argument("instance_id", nargs="?", help="Instance ID")
896
+ check_parser.add_argument("--threshold", "-t", type=float, default=0.01,
897
+ help="Early stopping threshold (loss improvement over last N steps)")
898
+ check_parser.add_argument("--window", "-w", type=int, default=10,
899
+ help="Number of recent steps to check for improvement")
900
+
901
+ # Compare command - run comparison on Lambda and sync back
902
+ compare_parser = subparsers.add_parser("compare", help="Run human vs AI comparison on Lambda")
903
+ compare_parser.add_argument("instance_id", nargs="?", help="Instance ID")
904
+ compare_parser.add_argument("--checkpoint", "-c", help="Checkpoint to use (default: latest)")
905
+ compare_parser.add_argument("--epoch", "-e", type=int, help="Use checkpoint from specific epoch")
906
+ compare_parser.add_argument("--open", action="store_true", help="Open viewer after generation")
907
+
908
+ # Results viewer command - downloads and generates comparison viewer
909
+ results_parser = subparsers.add_parser("results", help="Download results and generate comparison viewer")
910
+ results_parser.add_argument("--capture", "-c", required=True, help="Local capture directory (for comparison)")
911
+ results_parser.add_argument("--goal", "-g", help="Task goal description")
912
+ results_parser.add_argument("--open", action="store_true", help="Open viewer in browser")
913
+ results_parser.add_argument("instance_id", nargs="?", help="Instance ID")
914
+
915
+ # Sync command - sync training output and regenerate navigation for file:// protocol
916
+ sync_parser = subparsers.add_parser("sync", help="Sync training output from Lambda and regenerate navigation")
917
+ sync_parser.add_argument("instance_id", nargs="?", help="Instance ID")
918
+ sync_parser.add_argument("--output", "-o", default="training_output", help="Local output directory (default: training_output)")
919
+ sync_parser.add_argument("--open", action="store_true", help="Open dashboard in browser after sync")
920
+
921
+ # Viewer command - regenerate local viewer (no Lambda required)
922
+ viewer_parser = subparsers.add_parser("viewer", help="Regenerate local viewer (no Lambda required)")
923
+ viewer_parser.add_argument("--output", "-o", default="training_output", help="Training output directory (default: training_output)")
924
+ viewer_parser.add_argument("--dashboard", "-d", action="store_true", help="Regenerate dashboard instead of viewer")
925
+ viewer_parser.add_argument("--open", action="store_true", help="Open in browser (use 'serve' instead for better experience)")
926
+
927
+ args = parser.parse_args()
928
+
929
+ if not args.command:
930
+ parser.print_help()
931
+ return
932
+
933
+ try:
934
+ client = LambdaLabsClient()
935
+ except ValueError as e:
936
+ print(f"Error: {e}")
937
+ print("\nGet your API key at https://cloud.lambdalabs.com/api-keys")
938
+ print("Then set it: export LAMBDA_API_KEY=your_key_here")
939
+ return
940
+
941
+ if args.command == "list":
942
+ print("Available GPU instances:\n")
943
+ types = client.list_instance_types()
944
+ for t in types:
945
+ avail = "available" if t.available_regions else "no capacity"
946
+ print(f" {t}")
947
+ print(f"\nTotal: {len(types)} instance types")
948
+ print("\nLaunch with: python -m openadapt_ml.cloud.lambda_labs launch --type <name>")
949
+
950
+ elif args.command == "status":
951
+ instances = client.list_instances()
952
+ if not instances:
953
+ print("No running instances.")
954
+ else:
955
+ print("Running instances:\n")
956
+ for inst in instances:
957
+ print(f" {inst}")
958
+ print(f"\nTotal: {len(instances)} instances")
959
+
960
+ elif args.command == "launch":
961
+ # Ensure SSH key is set up
962
+ ssh_key = setup_lambda_ssh_key(client)
963
+
964
+ print(f"Launching {args.type}...")
965
+ instance = client.launch_instance(
966
+ instance_type=args.type,
967
+ region=args.region,
968
+ ssh_key_names=[ssh_key],
969
+ name=args.name,
970
+ )
971
+ print(f"\nInstance launched!")
972
+ print(f" ID: {instance.id}")
973
+ print(f" IP: {instance.ip}")
974
+ print(f" Type: {instance.instance_type}")
975
+ print(f" Region: {instance.region}")
976
+ print(f"\nConnect with: ssh ubuntu@{instance.ip}")
977
+ print(f"\nTerminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
978
+
979
+ elif args.command == "terminate":
980
+ if client.terminate_instance(args.instance_id):
981
+ print(f"Instance {args.instance_id} terminated.")
982
+ else:
983
+ print(f"Failed to terminate instance {args.instance_id}")
984
+
985
+ elif args.command == "ssh":
986
+ instances = client.list_instances()
987
+ if not instances:
988
+ print("No running instances.")
989
+ return
990
+
991
+ if args.instance_id:
992
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
993
+ if not instance:
994
+ print(f"Instance {args.instance_id} not found.")
995
+ return
996
+ else:
997
+ instance = instances[0]
998
+
999
+ if hasattr(args, 'cmd') and args.cmd:
1000
+ # Run single command
1001
+ print(f"Running on {instance.ip}: {args.cmd}")
1002
+ result = client.ssh_run(instance, args.cmd, timeout=args.timeout)
1003
+ if result.stdout:
1004
+ print(result.stdout)
1005
+ if result.stderr:
1006
+ print(f"[stderr] {result.stderr}", file=sys.stderr)
1007
+ if result.returncode != 0:
1008
+ sys.exit(result.returncode)
1009
+ else:
1010
+ # Print SSH command for interactive use
1011
+ print(client.get_ssh_command(instance))
1012
+
1013
+ elif args.command == "rsync":
1014
+ # Rsync files to/from Lambda instance
1015
+ instances = client.list_instances()
1016
+ if not instances:
1017
+ print("No running instances.")
1018
+ return
1019
+
1020
+ if args.instance_id:
1021
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1022
+ if not instance:
1023
+ print(f"Instance {args.instance_id} not found.")
1024
+ return
1025
+ else:
1026
+ instance = instances[0]
1027
+
1028
+ # Parse source and dest - 'remote:' prefix indicates remote path
1029
+ source = args.source
1030
+ dest = args.dest
1031
+
1032
+ if source.startswith("remote:"):
1033
+ source = f"ubuntu@{instance.ip}:{source[7:]}"
1034
+ if dest.startswith("remote:"):
1035
+ dest = f"ubuntu@{instance.ip}:{dest[7:]}"
1036
+
1037
+ rsync_cmd = [
1038
+ "rsync", "-avz", "--progress",
1039
+ "-e", "ssh -o StrictHostKeyChecking=no",
1040
+ ]
1041
+ if args.delete:
1042
+ rsync_cmd.append("--delete")
1043
+ rsync_cmd.extend([source, dest])
1044
+
1045
+ print(f"Running: {' '.join(rsync_cmd)}")
1046
+ result = subprocess.run(rsync_cmd)
1047
+ sys.exit(result.returncode)
1048
+
1049
+ elif args.command == "setup":
1050
+ ssh_key = setup_lambda_ssh_key(client)
1051
+ print(f"SSH key '{ssh_key}' is configured.")
1052
+
1053
+ elif args.command == "train":
1054
+ # Full automated training pipeline
1055
+ import time as time_module
1056
+
1057
+ instance = None
1058
+ start_time = time_module.time()
1059
+ launched_new = False
1060
+ training_completed = False # Track if training actually finished
1061
+
1062
+ # Instance pricing (approximate $/hr)
1063
+ INSTANCE_PRICES = {
1064
+ "gpu_1x_a10": 0.75,
1065
+ "gpu_1x_a100": 1.29,
1066
+ "gpu_1x_a100_sxm4": 1.29,
1067
+ "gpu_1x_h100_pcie": 2.49,
1068
+ "gpu_1x_h100_sxm5": 3.29,
1069
+ }
1070
+
1071
+ # Get or launch instance
1072
+ if args.instance:
1073
+ instances = client.list_instances()
1074
+ instance = next((i for i in instances if i.id.startswith(args.instance)), None)
1075
+ if not instance:
1076
+ print(f"Error: Instance {args.instance} not found")
1077
+ return
1078
+ else:
1079
+ # Check for existing instances
1080
+ instances = client.list_instances()
1081
+ if instances:
1082
+ print(f"Using existing instance: {instances[0].id[:8]}...")
1083
+ instance = instances[0]
1084
+ else:
1085
+ # Launch new instance
1086
+ ssh_key = setup_lambda_ssh_key(client)
1087
+ print(f"Launching {args.type}...")
1088
+ instance = client.launch_instance(
1089
+ instance_type=args.type,
1090
+ ssh_key_names=[ssh_key],
1091
+ name="openadapt-training",
1092
+ )
1093
+ print(f"Instance launched: {instance.id[:8]}... at {instance.ip}")
1094
+ launched_new = True
1095
+
1096
+ price_per_hour = INSTANCE_PRICES.get(instance.instance_type, 1.00)
1097
+ print(f" Instance type: {instance.instance_type} (~${price_per_hour:.2f}/hr)")
1098
+ print(f" Max runtime: {args.max_runtime} minutes")
1099
+
1100
+ # Generate initial dashboard with setup status
1101
+ from pathlib import Path
1102
+ from openadapt_ml.training.trainer import (
1103
+ TrainingState, TrainingConfig, generate_training_dashboard,
1104
+ setup_job_directory
1105
+ )
1106
+ import time as time_module
1107
+ job_id = time_module.strftime("%Y%m%d_%H%M%S")
1108
+ output_dir = setup_job_directory("training_output", job_id)
1109
+ dashboard_path = output_dir / "dashboard.html"
1110
+ log_path = output_dir / "training_log.json"
1111
+
1112
+ def update_dashboard(status: str, logs: list, step: int = 0, loss: float = 0.0, epoch: int = 0):
1113
+ """Update dashboard with current setup/training status."""
1114
+ state = TrainingState(job_id=job_id)
1115
+ state.cloud_provider = "lambda"
1116
+ state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
1117
+ state.cloud_instance_id = instance.id
1118
+ state.instance_ip = instance.ip or ""
1119
+ state.instance_type = instance.instance_type
1120
+ state.setup_status = status
1121
+ state.setup_logs = logs
1122
+ state.epoch = epoch
1123
+ state.step = step
1124
+ state.loss = loss
1125
+ state.start_time = start_time
1126
+ config = TrainingConfig(num_train_epochs=5, learning_rate=5e-5)
1127
+ dashboard_path.write_text(generate_training_dashboard(state, config))
1128
+ # Also write log for polling
1129
+ log_path.write_text(json.dumps(state.to_dict(), indent=2))
1130
+
1131
+ # Initial dashboard
1132
+ setup_logs = [
1133
+ f"Lambda Cloud instance: {instance.id[:8]}...",
1134
+ f"Instance type: {instance.instance_type} (~${price_per_hour:.2f}/hr)",
1135
+ f"IP address: {instance.ip or 'pending...'}",
1136
+ ]
1137
+ update_dashboard("booting", setup_logs)
1138
+
1139
+ # Open dashboard in browser via HTTP server
1140
+ server_proc = None
1141
+ if args.open:
1142
+ server_proc = open_dashboard_in_browser(output_dir)
1143
+
1144
+ try:
1145
+ # Set up environment with retries at the command level
1146
+ setup_logs.append("Connecting to instance...")
1147
+ update_dashboard("booting", setup_logs)
1148
+
1149
+ setup_success = False
1150
+ for setup_attempt in range(3):
1151
+ setup_logs.append(f"Setup attempt {setup_attempt + 1}/3...")
1152
+ update_dashboard("installing", setup_logs)
1153
+ if client.setup_instance(instance):
1154
+ setup_success = True
1155
+ setup_logs.append("Instance setup complete!")
1156
+ update_dashboard("installing", setup_logs)
1157
+ break
1158
+ if setup_attempt < 2:
1159
+ setup_logs.append(f"Setup attempt {setup_attempt + 1} failed, retrying in 30s...")
1160
+ update_dashboard("booting", setup_logs)
1161
+ print(f" Setup attempt {setup_attempt + 1} failed, retrying in 30s...")
1162
+ time_module.sleep(30)
1163
+
1164
+ if not setup_success:
1165
+ setup_logs.append("ERROR: Failed to set up instance after 3 attempts")
1166
+ update_dashboard("booting", setup_logs)
1167
+ print("\nError: Failed to set up instance after 3 attempts")
1168
+ print(f"Instance still running: {instance.ip}")
1169
+ print("Debug via: ssh ubuntu@" + instance.ip)
1170
+ print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
1171
+ return # Don't terminate - let user debug
1172
+
1173
+ # Sync local code to ensure remote has latest changes
1174
+ setup_logs.append("Syncing local code to instance...")
1175
+ update_dashboard("installing", setup_logs)
1176
+ if not client.sync_local_code(instance):
1177
+ setup_logs.append("Warning: Failed to sync local code, using remote repo version")
1178
+ update_dashboard("installing", setup_logs)
1179
+ print("Warning: Failed to sync local code, using remote repo version")
1180
+ else:
1181
+ setup_logs.append("Code synced successfully")
1182
+ update_dashboard("installing", setup_logs)
1183
+
1184
+ # Upload capture if provided
1185
+ remote_capture = None
1186
+ if args.capture:
1187
+ setup_logs.append(f"Uploading capture data...")
1188
+ update_dashboard("installing", setup_logs)
1189
+ if client.upload_capture(instance, args.capture, "~/capture"):
1190
+ remote_capture = "~/capture"
1191
+ setup_logs.append(f"Capture uploaded to {instance.ip}:~/capture")
1192
+ update_dashboard("installing", setup_logs)
1193
+ print(f"Capture uploaded to {instance.ip}:~/capture")
1194
+ else:
1195
+ setup_logs.append("ERROR: Failed to upload capture after retries")
1196
+ update_dashboard("installing", setup_logs)
1197
+ print("\nError: Failed to upload capture after retries")
1198
+ print(f"Instance still running: {instance.ip}")
1199
+ print("Debug via: ssh ubuntu@" + instance.ip)
1200
+ print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
1201
+ return # Don't terminate - let user debug
1202
+
1203
+ # Run training in background and poll for status
1204
+ setup_logs.append("Installing dependencies and starting training...")
1205
+ update_dashboard("training", setup_logs)
1206
+ print("\n" + "=" * 50)
1207
+ print("Starting training...")
1208
+ print("=" * 50 + "\n")
1209
+
1210
+ proc = client.run_training(
1211
+ instance,
1212
+ config=args.config,
1213
+ capture=remote_capture,
1214
+ goal=args.goal,
1215
+ background=True, # Run in background so we can poll
1216
+ )
1217
+
1218
+ # Poll for training status and update dashboard
1219
+ poll_interval = 10 # seconds
1220
+ last_step = 0
1221
+ last_epoch = 0
1222
+ print(f"Polling training status every {poll_interval}s (Ctrl+C to stop)...\n")
1223
+
1224
+ while True:
1225
+ try:
1226
+ status = client.get_training_status(instance)
1227
+
1228
+ if status and status.get("step", 0) > 0:
1229
+ step = status.get("step", 0)
1230
+ epoch = status.get("epoch", 0)
1231
+ loss = status.get("loss", 0)
1232
+ elapsed_training = status.get("elapsed_time", 0)
1233
+ total_epochs = status.get("total_epochs", 5)
1234
+
1235
+ # Print progress when step changes
1236
+ if step > last_step or epoch > last_epoch:
1237
+ print(f" Epoch {epoch+1}/{total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed_training:.0f}s")
1238
+ last_step = step
1239
+ last_epoch = epoch
1240
+
1241
+ # Update local training_log.json (dashboard polls this)
1242
+ status["total_epochs"] = total_epochs
1243
+ if not status.get("instance_ip"):
1244
+ status["instance_ip"] = instance.ip
1245
+ if not status.get("instance_type"):
1246
+ status["instance_type"] = instance.instance_type
1247
+ # Add cloud provider info
1248
+ status["cloud_provider"] = "lambda"
1249
+ status["cloud_dashboard_url"] = "https://cloud.lambda.ai/instances"
1250
+ status["cloud_instance_id"] = instance.id
1251
+ status["setup_status"] = "training"
1252
+ status["setup_logs"] = setup_logs
1253
+ log_path.write_text(json.dumps(status, indent=2))
1254
+
1255
+ # Regenerate dashboard with updated data
1256
+ state = TrainingState()
1257
+ state.job_id = status.get("job_id", "")
1258
+ state.hostname = status.get("hostname", "lambda")
1259
+ state.instance_ip = instance.ip or ""
1260
+ state.instance_type = instance.instance_type
1261
+ state.epoch = epoch
1262
+ state.step = step
1263
+ state.total_epochs = total_epochs
1264
+ state.loss = loss
1265
+ state.learning_rate = status.get("learning_rate", 5e-5)
1266
+ state.losses = status.get("losses", [])
1267
+ state.evaluations = status.get("evaluations", [])
1268
+ state.start_time = time_module.time() - elapsed_training
1269
+ state.cloud_provider = "lambda"
1270
+ state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
1271
+ state.cloud_instance_id = instance.id
1272
+ state.setup_status = "training"
1273
+ state.setup_logs = setup_logs
1274
+
1275
+ config = TrainingConfig(
1276
+ num_train_epochs=total_epochs,
1277
+ learning_rate=status.get("learning_rate", 5e-5)
1278
+ )
1279
+ dashboard_path.write_text(generate_training_dashboard(state, config))
1280
+
1281
+ # Check if training is complete (all epochs done)
1282
+ if epoch >= total_epochs - 1:
1283
+ # Check if step count stopped increasing
1284
+ time_module.sleep(poll_interval)
1285
+ new_status = client.get_training_status(instance)
1286
+ if new_status and new_status.get("step", 0) == step:
1287
+ print("\n" + "=" * 50)
1288
+ print("Training complete!")
1289
+ print("=" * 50)
1290
+ training_completed = True
1291
+ break
1292
+ else:
1293
+ # Training not started yet, show setup status
1294
+ print(" Waiting for training to start...")
1295
+
1296
+ except Exception as e:
1297
+ print(f" Poll error: {e}")
1298
+
1299
+ time_module.sleep(poll_interval)
1300
+
1301
+ except KeyboardInterrupt:
1302
+ print("\n\nTraining interrupted by user")
1303
+ finally:
1304
+ # Clean up HTTP server if running
1305
+ if server_proc:
1306
+ server_proc.terminate()
1307
+ print("Dashboard server stopped.")
1308
+
1309
+ # Only auto-terminate if training completed successfully or user requested it
1310
+ elapsed = time_module.time() - start_time
1311
+ cost = (elapsed / 3600) * price_per_hour
1312
+
1313
+ if training_completed and not args.no_terminate:
1314
+ # Run comparison on Lambda before downloading and terminating (if capture was provided)
1315
+ if args.capture:
1316
+ print("\n" + "=" * 50)
1317
+ print("Running comparison on Lambda instance...")
1318
+ print("=" * 50)
1319
+
1320
+ # Determine the final checkpoint path (main checkpoint after training)
1321
+ checkpoint_path = "/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
1322
+
1323
+ # Check if checkpoint exists
1324
+ result = client.ssh_run(
1325
+ instance,
1326
+ f"ls {checkpoint_path}/adapter_config.json 2>/dev/null && echo 'exists'",
1327
+ timeout=30
1328
+ )
1329
+
1330
+ if "exists" in result.stdout:
1331
+ # Run comparison on Lambda
1332
+ output_name = f"comparison_{time_module.strftime('%H%M%S')}.html"
1333
+ cmd = f"""cd ~/openadapt-ml && source .venv/bin/activate && \
1334
+ python -m openadapt_ml.scripts.compare \
1335
+ --capture ~/capture \
1336
+ --checkpoint {checkpoint_path} \
1337
+ --output training_output/{output_name} 2>&1"""
1338
+
1339
+ print(" Generating comparison viewer (this may take a few minutes)...")
1340
+ result = client.ssh_run(instance, cmd, timeout=600)
1341
+
1342
+ if result.returncode == 0:
1343
+ print(f" Comparison generated: {output_name}")
1344
+ else:
1345
+ print(f" Warning: Comparison generation failed")
1346
+ if result.stderr:
1347
+ print(f" Error: {result.stderr}")
1348
+ else:
1349
+ print(" Warning: Final checkpoint not found, skipping comparison")
1350
+
1351
+ # Download results (including comparison if generated)
1352
+ print("\n" + "=" * 50)
1353
+ print("Downloading results...")
1354
+ print("=" * 50)
1355
+ client.download_results(instance)
1356
+
1357
+ print(f"\nTerminating instance {instance.id[:8]}...")
1358
+ client.terminate_instance(instance.id)
1359
+ print("Instance terminated.")
1360
+ print(f"\nFinal cost: ~${cost:.2f} ({elapsed/60:.1f} minutes)")
1361
+ else:
1362
+ print(f"\nInstance still running: {instance.ip}")
1363
+ print(f" Current cost: ~${cost:.2f}")
1364
+ if not training_completed:
1365
+ print(f" (Not terminating - training did not complete successfully)")
1366
+ print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
1367
+
1368
+ elif args.command == "train-status":
1369
+ instances = client.list_instances()
1370
+ if not instances:
1371
+ print("No running instances.")
1372
+ return
1373
+
1374
+ if args.instance_id:
1375
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1376
+ if not instance:
1377
+ print(f"Instance {args.instance_id} not found.")
1378
+ return
1379
+ else:
1380
+ instance = instances[0]
1381
+
1382
+ print(f"Checking training status on {instance.ip}...")
1383
+ status = client.get_training_status(instance)
1384
+
1385
+ if status:
1386
+ print(f" Epoch: {status.get('epoch', 'N/A')}")
1387
+ print(f" Step: {status.get('step', 'N/A')}")
1388
+ print(f" Loss: {status.get('loss', 'N/A')}")
1389
+ print(f" Elapsed: {status.get('elapsed_time', 0):.1f}s")
1390
+ else:
1391
+ print(" No training log found (training may not have started yet)")
1392
+
1393
+ elif args.command == "checkpoints":
1394
+ # List checkpoints on remote instance
1395
+ instances = client.list_instances()
1396
+ if not instances:
1397
+ print("No running instances.")
1398
+ return
1399
+
1400
+ if args.instance_id:
1401
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1402
+ if not instance:
1403
+ print(f"Instance {args.instance_id} not found.")
1404
+ return
1405
+ else:
1406
+ instance = instances[0]
1407
+
1408
+ print(f"Checking checkpoints on {instance.ip}...")
1409
+
1410
+ ssh_cmd = [
1411
+ "ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
1412
+ f"ubuntu@{instance.ip}",
1413
+ "ls -la ~/openadapt-ml/checkpoints/ 2>/dev/null && "
1414
+ "du -sh ~/openadapt-ml/checkpoints/ 2>/dev/null || echo 'No checkpoints directory found'"
1415
+ ]
1416
+
1417
+ result = subprocess.run(ssh_cmd, capture_output=True, text=True)
1418
+ if result.returncode == 0:
1419
+ print(result.stdout)
1420
+ else:
1421
+ print("No checkpoints found yet")
1422
+ if result.stderr:
1423
+ print(f" Error: {result.stderr}")
1424
+
1425
+ elif args.command == "refresh":
1426
+ # One-shot dashboard refresh
1427
+ import time as time_module
1428
+ from pathlib import Path
1429
+ from openadapt_ml.training.trainer import TrainingState, TrainingConfig, generate_training_dashboard
1430
+
1431
+ instances = client.list_instances()
1432
+ if not instances:
1433
+ print("No running instances.")
1434
+ return
1435
+
1436
+ if args.instance_id:
1437
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1438
+ if not instance:
1439
+ print(f"Instance {args.instance_id} not found.")
1440
+ return
1441
+ else:
1442
+ instance = instances[0]
1443
+
1444
+ # Use current job directory via symlink
1445
+ from openadapt_ml.training.trainer import get_current_job_directory, setup_job_directory
1446
+ base_dir = Path("training_output")
1447
+ base_dir.mkdir(exist_ok=True)
1448
+
1449
+ status = client.get_training_status(instance)
1450
+
1451
+ if status and status.get("step", 0) > 0:
1452
+ # Get or create job directory based on remote job_id
1453
+ remote_job_id = status.get("job_id", "")
1454
+ if remote_job_id:
1455
+ output_dir = setup_job_directory(base_dir, remote_job_id)
1456
+ else:
1457
+ output_dir = get_current_job_directory(base_dir) or base_dir
1458
+ dashboard_path = output_dir / "dashboard.html"
1459
+ log_path = output_dir / "training_log.json"
1460
+
1461
+ # Setup screenshots symlink if local capture path provided
1462
+ local_capture = args.capture if hasattr(args, 'capture') and args.capture else None
1463
+ if local_capture:
1464
+ setup_capture_screenshots_symlink(output_dir, local_capture)
1465
+
1466
+ # Rewrite evaluation paths from Lambda to relative
1467
+ if "evaluations" in status:
1468
+ status["evaluations"] = rewrite_evaluation_paths(status["evaluations"])
1469
+
1470
+ # Ensure instance metadata is present
1471
+ status["instance_ip"] = instance.ip
1472
+ status["instance_type"] = instance.instance_type
1473
+ status["total_epochs"] = status.get("total_epochs", 5)
1474
+
1475
+ # Save log
1476
+ log_path.write_text(json.dumps(status, indent=2))
1477
+
1478
+ # Generate dashboard
1479
+ state = TrainingState(job_id=remote_job_id)
1480
+ state.job_id = remote_job_id
1481
+ state.hostname = status.get("hostname", "lambda")
1482
+ state.instance_ip = instance.ip or ""
1483
+ state.instance_type = instance.instance_type
1484
+ state.config_path = status.get("config_path", "")
1485
+ # Use local capture path for screenshots if provided, else remote path
1486
+ state.capture_path = args.capture if args.capture else status.get("capture_path", "")
1487
+ state.epoch = status.get("epoch", 0)
1488
+ state.step = status.get("step", 0)
1489
+ state.loss = status.get("loss", 0)
1490
+ state.learning_rate = status.get("learning_rate", 5e-5)
1491
+ state.losses = status.get("losses", [])
1492
+ state.evaluations = status.get("evaluations", [])
1493
+ state.total_epochs = status.get("total_epochs", 5)
1494
+ state.start_time = time_module.time() - status.get("elapsed_time", 0)
1495
+ # Cloud provider info
1496
+ state.cloud_provider = "lambda"
1497
+ state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
1498
+ state.cloud_instance_id = instance.id
1499
+ state.setup_status = status.get("setup_status", "training")
1500
+ state.setup_logs = status.get("setup_logs", [])
1501
+
1502
+ config = TrainingConfig(
1503
+ num_train_epochs=status.get("total_epochs", 5),
1504
+ learning_rate=status.get("learning_rate", 5e-5)
1505
+ )
1506
+
1507
+ dashboard_path.write_text(generate_training_dashboard(state, config))
1508
+
1509
+ # Regenerate navigation for file:// protocol
1510
+ try:
1511
+ from openadapt_ml.training.trainer import regenerate_all_dashboards
1512
+ regenerate_all_dashboards(output_dir)
1513
+ except Exception:
1514
+ pass # Silent fail for navigation
1515
+
1516
+ epoch = status.get("epoch", 0)
1517
+ step = status.get("step", 0)
1518
+ loss = status.get("loss", 0)
1519
+ elapsed = status.get("elapsed_time", 0)
1520
+ print(f"Epoch {epoch+1}/{state.total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s")
1521
+ print(f"Dashboard: {dashboard_path.absolute()}")
1522
+
1523
+ if args.open:
1524
+ import subprocess as sp
1525
+ sp.run(["open", str(dashboard_path)], capture_output=True)
1526
+ else:
1527
+ print("No training data yet")
1528
+
1529
+ elif args.command == "monitor":
1530
+ # Live dashboard monitoring for Lambda training
1531
+ # Updates training_output/training_log.json so the existing dashboard auto-refreshes
1532
+ import time as time_module
1533
+ from pathlib import Path
1534
+
1535
+ # Stub mode - simulate training without actual GPU
1536
+ if getattr(args, 'stub', False):
1537
+ from openadapt_ml.training.stub_provider import StubTrainingProvider
1538
+ from openadapt_ml.training.trainer import (
1539
+ TrainingState, TrainingConfig, generate_training_dashboard
1540
+ )
1541
+
1542
+ print("\n[Stub Mode] Simulating training without GPU...")
1543
+ output_dir = Path("training_output")
1544
+ output_dir.mkdir(exist_ok=True)
1545
+
1546
+ # Start dashboard server if requested
1547
+ server_proc = None
1548
+ if args.open:
1549
+ server_proc = open_dashboard_in_browser(output_dir)
1550
+
1551
+ # Run stub training
1552
+ stub = StubTrainingProvider(
1553
+ output_dir=output_dir,
1554
+ epochs=5,
1555
+ steps_per_epoch=10,
1556
+ step_delay=0.3, # Fast simulation
1557
+ )
1558
+
1559
+ def update_dashboard(status):
1560
+ """Regenerate dashboard after each step."""
1561
+ state = TrainingState()
1562
+ state.job_id = status.get("job_id", "")
1563
+ state.hostname = status.get("hostname", "stub")
1564
+ state.instance_ip = "127.0.0.1"
1565
+ state.instance_type = "stub"
1566
+ state.epoch = status.get("epoch", 0)
1567
+ state.step = status.get("step", 0)
1568
+ state.loss = status.get("loss", 0)
1569
+ state.learning_rate = status.get("learning_rate", 5e-5)
1570
+ state.losses = status.get("losses", [])
1571
+ state.evaluations = status.get("evaluations", [])
1572
+ state.cloud_provider = "stub"
1573
+ state.setup_status = "training"
1574
+
1575
+ config = TrainingConfig(
1576
+ num_train_epochs=status.get("total_epochs", 5),
1577
+ learning_rate=state.learning_rate
1578
+ )
1579
+
1580
+ dashboard_path = output_dir / "dashboard.html"
1581
+ dashboard_path.write_text(generate_training_dashboard(state, config))
1582
+
1583
+ try:
1584
+ stub.run(callback=update_dashboard)
1585
+ except KeyboardInterrupt:
1586
+ print("\n[Stub] Interrupted by user.")
1587
+ finally:
1588
+ if server_proc:
1589
+ server_proc.terminate()
1590
+ print("[Stub] Dashboard server stopped.")
1591
+
1592
+ print(f"\n[Stub] Results in: {output_dir}")
1593
+ return
1594
+
1595
+ instances = client.list_instances()
1596
+ if not instances:
1597
+ print("No running instances.")
1598
+ return
1599
+
1600
+ if args.instance_id:
1601
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1602
+ if not instance:
1603
+ print(f"Instance {args.instance_id} not found.")
1604
+ return
1605
+ else:
1606
+ instance = instances[0]
1607
+
1608
+ if instance.status == "booting" or not instance.ip:
1609
+ print(f"Instance {instance.id[:8]} is still booting, waiting for IP...")
1610
+ while True:
1611
+ time_module.sleep(5)
1612
+ instances = client.list_instances()
1613
+ instance = next((i for i in instances if i.id == instance.id), None)
1614
+ if not instance:
1615
+ print("Instance terminated or not found.")
1616
+ return
1617
+ if instance.ip and instance.status == "active":
1618
+ print(f"Instance ready at {instance.ip}")
1619
+ break
1620
+ print(f" Status: {instance.status}...")
1621
+
1622
+ # Use job-scoped directory structure
1623
+ from openadapt_ml.training.trainer import (
1624
+ TrainingState, TrainingConfig, generate_training_dashboard,
1625
+ setup_job_directory, get_current_job_directory
1626
+ )
1627
+ base_dir = Path("training_output")
1628
+ base_dir.mkdir(exist_ok=True)
1629
+
1630
+ # Get current job directory or wait for first status to determine job_id
1631
+ output_dir = get_current_job_directory(base_dir) or base_dir
1632
+ dashboard_path = output_dir / "dashboard.html"
1633
+ log_path = output_dir / "training_log.json"
1634
+
1635
+ # Check for existing log with job_id
1636
+ current_job_id = None
1637
+ if log_path.exists():
1638
+ try:
1639
+ existing_log = json.loads(log_path.read_text())
1640
+ current_job_id = existing_log.get("job_id")
1641
+ except (json.JSONDecodeError, IOError):
1642
+ pass
1643
+
1644
+ print(f"\nMonitoring Lambda training on {instance.ip}")
1645
+ print(f"Dashboard: {dashboard_path.absolute()}")
1646
+ print(f"Polling every {args.interval}s (Ctrl+C to stop)\n")
1647
+
1648
+ # Generate initial dashboard if it doesn't exist
1649
+ if not dashboard_path.exists():
1650
+ state = TrainingState(job_id=current_job_id or "")
1651
+ state.cloud_provider = "lambda"
1652
+ state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
1653
+ state.cloud_instance_id = instance.id
1654
+ state.instance_ip = instance.ip or ""
1655
+ state.instance_type = instance.instance_type
1656
+ state.setup_status = "booting"
1657
+ state.setup_logs = ["Starting Lambda Cloud instance...", f"Instance ID: {instance.id[:8]}...", f"Instance type: {instance.instance_type}"]
1658
+ config = TrainingConfig(num_train_epochs=5, learning_rate=5e-5)
1659
+ dashboard_path.write_text(generate_training_dashboard(state, config))
1660
+
1661
+ # Open dashboard if requested via HTTP server
1662
+ server_proc = None
1663
+ if args.open:
1664
+ server_proc = open_dashboard_in_browser(output_dir)
1665
+
1666
+ last_step = 0
1667
+ last_epoch = -1
1668
+ auto_stop_loss = getattr(args, 'auto_stop_loss', 0.5)
1669
+ download_checkpoints = getattr(args, 'download_checkpoints', True)
1670
+ step_stall_count = 0 # Track how many times step hasn't increased
1671
+
1672
+ print(f" Auto-stop loss threshold: {auto_stop_loss}")
1673
+ print(f" Checkpoint download: {'enabled' if download_checkpoints else 'disabled'}")
1674
+
1675
+ try:
1676
+ while True:
1677
+ # Check for stop signal from dashboard
1678
+ if check_stop_signal(output_dir):
1679
+ print("\n Stop signal received from dashboard!")
1680
+ print(" Downloading final checkpoints...")
1681
+ if download_checkpoints:
1682
+ download_checkpoints_from_instance(instance.ip, output_dir)
1683
+
1684
+ # Update status with termination info before terminating
1685
+ termination_status = {
1686
+ "termination_status": "user_stop",
1687
+ "termination_message": "Training stopped by user via dashboard"
1688
+ }
1689
+ current_log = log_path.read_text() if log_path.exists() else "{}"
1690
+ import json as json_module
1691
+ current_data = json_module.loads(current_log)
1692
+ current_data.update(termination_status)
1693
+ log_path.write_text(json_module.dumps(current_data, indent=2))
1694
+
1695
+ print(f" Terminating instance {instance.id}...")
1696
+ client.terminate_instance(instance.id)
1697
+ # Remove stop signal
1698
+ (output_dir / "STOP_TRAINING").unlink(missing_ok=True)
1699
+ print(" Training stopped by user.")
1700
+ break
1701
+
1702
+ try:
1703
+ # Fetch training log from remote
1704
+ status = client.get_training_status(instance)
1705
+
1706
+ if status and status.get("step", 0) > 0:
1707
+ step = status.get("step", 0)
1708
+ epoch = status.get("epoch", 0)
1709
+ loss = status.get("loss", 0)
1710
+ elapsed = status.get("elapsed_time", 0)
1711
+ remote_job_id = status.get("job_id")
1712
+
1713
+ # Detect job_id change - clear old data if new job started
1714
+ if remote_job_id and current_job_id and remote_job_id != current_job_id:
1715
+ print(f"\n New job detected: {remote_job_id} (was: {current_job_id})")
1716
+ print(" Clearing old job data...")
1717
+ last_step = 0 # Reset step tracking
1718
+ current_job_id = remote_job_id
1719
+
1720
+ # Update local training log (dashboard polls this file)
1721
+ # Add total_epochs to status for dashboard
1722
+ status["total_epochs"] = status.get("total_epochs", 5)
1723
+ # Ensure instance metadata is present
1724
+ if not status.get("instance_ip"):
1725
+ status["instance_ip"] = instance.ip
1726
+ if not status.get("instance_type"):
1727
+ status["instance_type"] = instance.instance_type
1728
+ # Add cloud provider info
1729
+ status["cloud_provider"] = "lambda"
1730
+ status["cloud_dashboard_url"] = "https://cloud.lambda.ai/instances"
1731
+ status["cloud_instance_id"] = instance.id
1732
+ status["setup_status"] = status.get("setup_status", "training")
1733
+
1734
+ # Setup screenshots symlink if local capture path provided
1735
+ local_capture = args.capture if hasattr(args, 'capture') and args.capture else None
1736
+ if local_capture:
1737
+ setup_capture_screenshots_symlink(output_dir, local_capture)
1738
+
1739
+ # Rewrite evaluation paths from Lambda to relative
1740
+ if "evaluations" in status:
1741
+ status["evaluations"] = rewrite_evaluation_paths(status["evaluations"])
1742
+
1743
+ log_path.write_text(json.dumps(status, indent=2))
1744
+
1745
+ if step > last_step:
1746
+ print(f" Epoch {epoch+1} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s")
1747
+ last_step = step
1748
+ step_stall_count = 0 # Reset stall counter when step increases
1749
+ if not current_job_id:
1750
+ current_job_id = remote_job_id
1751
+
1752
+ # Regenerate dashboard with updated data
1753
+ state = TrainingState()
1754
+ state.job_id = status.get("job_id", "")
1755
+ state.hostname = status.get("hostname", "lambda")
1756
+ state.instance_ip = instance.ip or ""
1757
+ state.instance_type = instance.instance_type
1758
+ state.epoch = epoch
1759
+ state.step = step
1760
+ state.loss = loss
1761
+ state.learning_rate = status.get("learning_rate", 5e-5)
1762
+ state.losses = status.get("losses", [])
1763
+ state.evaluations = status.get("evaluations", [])
1764
+ state.start_time = time_module.time() - elapsed
1765
+ # Cloud provider info
1766
+ state.cloud_provider = "lambda"
1767
+ state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
1768
+ state.cloud_instance_id = instance.id
1769
+ state.setup_status = status.get("setup_status", "training")
1770
+ state.setup_logs = status.get("setup_logs", [])
1771
+ state.termination_status = status.get("termination_status", "")
1772
+ state.termination_message = status.get("termination_message", "")
1773
+
1774
+ config = TrainingConfig(
1775
+ num_train_epochs=status.get("total_epochs", 5),
1776
+ learning_rate=status.get("learning_rate", 5e-5)
1777
+ )
1778
+
1779
+ dashboard_path.write_text(generate_training_dashboard(state, config))
1780
+
1781
+ # Download checkpoints on epoch change
1782
+ if download_checkpoints and epoch > last_epoch:
1783
+ print(f" Epoch {epoch+1} completed - downloading checkpoints...")
1784
+ if download_checkpoints_from_instance(instance.ip, output_dir):
1785
+ print(f" Checkpoints saved to {output_dir}/checkpoints/")
1786
+ else:
1787
+ print(" Warning: checkpoint download failed")
1788
+ last_epoch = epoch
1789
+
1790
+ # Auto-terminate when loss is low enough
1791
+ if loss < auto_stop_loss and loss > 0:
1792
+ print(f"\n Loss {loss:.4f} < threshold {auto_stop_loss}")
1793
+ print(" Downloading final checkpoints...")
1794
+ if download_checkpoints:
1795
+ download_checkpoints_from_instance(instance.ip, output_dir)
1796
+
1797
+ # Update status with termination info
1798
+ status["termination_status"] = "auto_low_loss"
1799
+ status["termination_message"] = f"Training auto-stopped: loss {loss:.4f} < threshold {auto_stop_loss}"
1800
+ log_path.write_text(json.dumps(status, indent=2))
1801
+
1802
+ print(f" Auto-terminating instance {instance.id}...")
1803
+ client.terminate_instance(instance.id)
1804
+ print(" Training completed (auto-stopped)!")
1805
+ break
1806
+ else:
1807
+ # Step didn't increase - check if training is complete
1808
+ step_stall_count += 1
1809
+ total_epochs = status.get("total_epochs", 5)
1810
+
1811
+ # If on last epoch and step hasn't increased for 3 polls, training is complete
1812
+ if epoch >= total_epochs - 1 and step_stall_count >= 3:
1813
+ print(f"\n Training complete (epoch {epoch+1}/{total_epochs}, step stopped increasing)")
1814
+ print(" Downloading final checkpoints...")
1815
+ if download_checkpoints:
1816
+ download_checkpoints_from_instance(instance.ip, output_dir)
1817
+
1818
+ # Update status with termination info
1819
+ status["termination_status"] = "auto_complete"
1820
+ status["termination_message"] = f"Training completed successfully ({epoch+1}/{total_epochs} epochs)"
1821
+ log_path.write_text(json.dumps(status, indent=2))
1822
+
1823
+ print(f" Terminating instance {instance.id}...")
1824
+ client.terminate_instance(instance.id)
1825
+ print(" Instance terminated.")
1826
+ break
1827
+
1828
+ else:
1829
+ print(" Waiting for training to start...")
1830
+
1831
+ except Exception as e:
1832
+ print(f" Poll error: {e}")
1833
+
1834
+ time_module.sleep(args.interval)
1835
+
1836
+ except KeyboardInterrupt:
1837
+ print("\n\nMonitoring stopped.")
1838
+ print(f"Dashboard: {dashboard_path.absolute()}")
1839
+ finally:
1840
+ # Clean up HTTP server if running
1841
+ if server_proc:
1842
+ server_proc.terminate()
1843
+ print("Dashboard server stopped.")
1844
+
1845
+ elif args.command == "files":
1846
+ instances = client.list_instances()
1847
+ if not instances:
1848
+ print("No running instances.")
1849
+ return
1850
+
1851
+ if args.instance_id:
1852
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1853
+ if not instance:
1854
+ print(f"Instance {args.instance_id} not found.")
1855
+ return
1856
+ else:
1857
+ instance = instances[0]
1858
+
1859
+ print(f"Files on {instance.ip} at {args.path}:")
1860
+ result = client.ssh_run(instance, f"find {args.path} -type f -name '*.pt' -o -name '*.json' -o -name '*.bin' 2>/dev/null | head -20", timeout=30)
1861
+ if result.stdout:
1862
+ for line in result.stdout.strip().split('\n'):
1863
+ print(f" {line}")
1864
+ else:
1865
+ print(" (no checkpoint files found)")
1866
+
1867
+ elif args.command == "kill":
1868
+ # Kill training/inference processes
1869
+ instances = client.list_instances()
1870
+ if not instances:
1871
+ print("No running instances.")
1872
+ if args.local:
1873
+ print("\nKilling local Lambda-related processes...")
1874
+ subprocess.run(
1875
+ ["pkill", "-f", "ssh.*ubuntu@.*openadapt"],
1876
+ capture_output=True
1877
+ )
1878
+ subprocess.run(
1879
+ ["pkill", "-f", "lambda_labs"],
1880
+ capture_output=True
1881
+ )
1882
+ print("Done.")
1883
+ return
1884
+
1885
+ if args.instance_id:
1886
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1887
+ if not instance:
1888
+ print(f"Instance {args.instance_id} not found.")
1889
+ return
1890
+ else:
1891
+ instance = instances[0]
1892
+
1893
+ print(f"Checking processes on {instance.ip}...")
1894
+
1895
+ # List Python processes first
1896
+ result = client.ssh_run(
1897
+ instance,
1898
+ "ps aux | grep python | grep -v grep | grep -v jupyter",
1899
+ timeout=30
1900
+ )
1901
+ if result.stdout.strip():
1902
+ print("Found Python processes:")
1903
+ for line in result.stdout.strip().split('\n'):
1904
+ print(f" {line[:100]}...")
1905
+ else:
1906
+ print("No training/inference Python processes found.")
1907
+ return
1908
+
1909
+ if args.all:
1910
+ print("\nKilling ALL Python processes (except jupyter)...")
1911
+ cmd = "pkill -f 'python.*train\\|python.*compare\\|python.*openadapt' || true"
1912
+ else:
1913
+ print("\nKilling training and inference processes...")
1914
+ cmd = "pkill -f 'python.*train' ; pkill -f 'python.*compare' || true"
1915
+
1916
+ result = client.ssh_run(instance, cmd, timeout=30)
1917
+ print("Remote processes killed.")
1918
+
1919
+ if args.local:
1920
+ print("\nKilling local Lambda-related processes...")
1921
+ subprocess.run(
1922
+ ["pkill", "-f", "ssh.*ubuntu@.*openadapt"],
1923
+ capture_output=True
1924
+ )
1925
+ subprocess.run(
1926
+ ["pkill", "-f", "lambda_labs.*train"],
1927
+ capture_output=True
1928
+ )
1929
+ print("Local processes killed.")
1930
+
1931
+ print("\nDone. Current status:")
1932
+ result = client.ssh_run(
1933
+ instance,
1934
+ "ps aux | grep python | grep -v grep | grep -v jupyter | wc -l",
1935
+ timeout=30
1936
+ )
1937
+ count = result.stdout.strip()
1938
+ print(f" {count} Python processes remaining on instance")
1939
+
1940
+ elif args.command == "check":
1941
+ # Analyze training status and early stopping
1942
+ instances = client.list_instances()
1943
+ if not instances:
1944
+ print("No running instances.")
1945
+ return
1946
+
1947
+ if args.instance_id:
1948
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
1949
+ if not instance:
1950
+ print(f"Instance {args.instance_id} not found.")
1951
+ return
1952
+ else:
1953
+ instance = instances[0]
1954
+
1955
+ print(f"Checking training on {instance.ip}...")
1956
+
1957
+ # Get training log
1958
+ result = client.ssh_run(
1959
+ instance,
1960
+ "cat ~/openadapt-ml/training_output/training_log.json 2>/dev/null",
1961
+ timeout=30
1962
+ )
1963
+
1964
+ if not result.stdout.strip():
1965
+ print("No training log found.")
1966
+ return
1967
+
1968
+ try:
1969
+ data = json.loads(result.stdout)
1970
+ losses = data.get("losses", [])
1971
+ except json.JSONDecodeError:
1972
+ print("Could not parse training log.")
1973
+ return
1974
+
1975
+ if not losses:
1976
+ print("No training data yet.")
1977
+ return
1978
+
1979
+ total_steps = len(losses)
1980
+ epochs = sorted(set(l["epoch"] for l in losses))
1981
+ total_epochs = data.get("total_epochs", 5)
1982
+ min_loss = min(l["loss"] for l in losses)
1983
+ current_loss = losses[-1]["loss"]
1984
+
1985
+ print(f"\n{'='*50}")
1986
+ print(f"TRAINING STATUS")
1987
+ print(f"{'='*50}")
1988
+ print(f"Steps: {total_steps}")
1989
+ print(f"Epochs: {max(epochs)+1}/{total_epochs}")
1990
+ print(f"Current loss: {current_loss:.4f}")
1991
+ print(f"Min loss: {min_loss:.4f}")
1992
+
1993
+ # Check if training is running
1994
+ proc_result = client.ssh_run(
1995
+ instance,
1996
+ "ps aux | grep 'python.*train' | grep -v grep | wc -l",
1997
+ timeout=30
1998
+ )
1999
+ is_running = int(proc_result.stdout.strip()) > 0
2000
+
2001
+ if is_running:
2002
+ print(f"Status: RUNNING")
2003
+ else:
2004
+ print(f"Status: STOPPED")
2005
+
2006
+ # Early stopping analysis
2007
+ window = min(args.window, len(losses))
2008
+ if window < 2:
2009
+ print("\nNot enough data for early stopping analysis.")
2010
+ else:
2011
+ recent_losses = [l["loss"] for l in losses[-window:]]
2012
+ older_losses = [l["loss"] for l in losses[-window*2:-window]] if len(losses) >= window*2 else [l["loss"] for l in losses[:window]]
2013
+
2014
+ recent_avg = sum(recent_losses) / len(recent_losses)
2015
+ older_avg = sum(older_losses) / len(older_losses) if older_losses else recent_avg
2016
+
2017
+ improvement = (older_avg - recent_avg) / older_avg if older_avg > 0 else 0
2018
+ loss_variance = max(recent_losses) - min(recent_losses)
2019
+
2020
+ print(f"\n{'='*50}")
2021
+ print(f"EARLY STOPPING ANALYSIS (window={window})")
2022
+ print(f"{'='*50}")
2023
+ print(f"Recent avg loss: {recent_avg:.4f}")
2024
+ print(f"Prior avg loss: {older_avg:.4f}")
2025
+ print(f"Improvement: {improvement*100:.2f}%")
2026
+ print(f"Loss variance: {loss_variance:.4f}")
2027
+
2028
+ should_stop = improvement < args.threshold and loss_variance < 0.1
2029
+ if should_stop:
2030
+ print(f"\n⚠️ EARLY STOPPING RECOMMENDED")
2031
+ print(f" Loss has plateaued (improvement < {args.threshold*100}%)")
2032
+ if not is_running:
2033
+ print(f" (Training already stopped)")
2034
+ else:
2035
+ print(f"\n To stop: uv run python -m openadapt_ml.cloud.lambda_labs kill")
2036
+ else:
2037
+ print(f"\n✓ Training still improving, continue.")
2038
+
2039
+ # Time estimate
2040
+ if is_running and len(losses) >= 2:
2041
+ avg_time_per_step = losses[-1].get("time", 0) / len(losses) if losses[-1].get("time") else 50
2042
+ steps_per_epoch = len(losses) / (max(epochs) + 1)
2043
+ remaining_epochs = total_epochs - max(epochs) - 1
2044
+ remaining_steps = remaining_epochs * steps_per_epoch
2045
+ eta_seconds = remaining_steps * avg_time_per_step
2046
+ eta_mins = eta_seconds / 60
2047
+
2048
+ print(f"\n{'='*50}")
2049
+ print(f"TIME ESTIMATE")
2050
+ print(f"{'='*50}")
2051
+ print(f"Remaining epochs: {remaining_epochs}")
2052
+ print(f"Est. remaining steps: {remaining_steps:.0f}")
2053
+ print(f"ETA: {eta_mins:.1f} minutes")
2054
+
2055
+ elif args.command == "compare":
2056
+ # Run comparison on Lambda and sync back
2057
+ instances = client.list_instances()
2058
+ if not instances:
2059
+ print("No running instances.")
2060
+ return
2061
+
2062
+ if args.instance_id:
2063
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2064
+ if not instance:
2065
+ print(f"Instance {args.instance_id} not found.")
2066
+ return
2067
+ else:
2068
+ instance = instances[0]
2069
+
2070
+ # Determine checkpoint to use
2071
+ if args.checkpoint:
2072
+ checkpoint_path = args.checkpoint
2073
+ elif args.epoch is not None:
2074
+ checkpoint_path = f"/home/ubuntu/openadapt-ml/checkpoints/epoch_{args.epoch}"
2075
+ else:
2076
+ # Use latest (main checkpoint)
2077
+ checkpoint_path = "/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
2078
+
2079
+ # Check if checkpoint exists
2080
+ result = client.ssh_run(
2081
+ instance,
2082
+ f"ls {checkpoint_path}/adapter_config.json 2>/dev/null && echo 'exists'",
2083
+ timeout=30
2084
+ )
2085
+ if "exists" not in result.stdout:
2086
+ print(f"Checkpoint not found at {checkpoint_path}")
2087
+ # List available checkpoints
2088
+ result = client.ssh_run(
2089
+ instance,
2090
+ "ls -la ~/openadapt-ml/checkpoints/",
2091
+ timeout=30
2092
+ )
2093
+ print(f"Available checkpoints:\n{result.stdout}")
2094
+ return
2095
+
2096
+ print(f"Running comparison on {instance.ip}...")
2097
+ print(f"Using checkpoint: {checkpoint_path}")
2098
+
2099
+ # Run comparison on Lambda
2100
+ output_name = f"comparison_{time.strftime('%H%M%S')}.html"
2101
+ cmd = f"""cd ~/openadapt-ml && source .venv/bin/activate && \
2102
+ python -m openadapt_ml.scripts.compare \
2103
+ --capture ~/capture \
2104
+ --checkpoint {checkpoint_path} \
2105
+ --output training_output/{output_name} 2>&1"""
2106
+
2107
+ print("Generating predictions (this may take a few minutes)...")
2108
+ result = client.ssh_run(instance, cmd, timeout=600)
2109
+
2110
+ if result.returncode != 0:
2111
+ print(f"Comparison failed:\n{result.stderr}")
2112
+ return
2113
+
2114
+ # Check if file was created
2115
+ result = client.ssh_run(
2116
+ instance,
2117
+ f"ls -la ~/openadapt-ml/training_output/{output_name}",
2118
+ timeout=30
2119
+ )
2120
+ if result.returncode != 0:
2121
+ print("Comparison file not created.")
2122
+ return
2123
+
2124
+ print(f"Comparison generated: {output_name}")
2125
+
2126
+ # Sync back to local
2127
+ local_output = Path("training_output") / output_name
2128
+ local_output.parent.mkdir(parents=True, exist_ok=True)
2129
+
2130
+ print(f"Syncing to {local_output}...")
2131
+ subprocess.run([
2132
+ "rsync", "-avz",
2133
+ f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/{output_name}",
2134
+ str(local_output)
2135
+ ], capture_output=True)
2136
+
2137
+ print(f"Done! Comparison saved to: {local_output}")
2138
+
2139
+ if args.open:
2140
+ subprocess.run(["open", str(local_output)], capture_output=True)
2141
+ print("Opened in browser.")
2142
+
2143
+ elif args.command == "download":
2144
+ instances = client.list_instances()
2145
+ if not instances:
2146
+ print("No running instances.")
2147
+ return
2148
+
2149
+ if args.instance_id:
2150
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2151
+ if not instance:
2152
+ print(f"Instance {args.instance_id} not found.")
2153
+ return
2154
+ else:
2155
+ instance = instances[0]
2156
+
2157
+ client.download_results(instance, local_path=args.output)
2158
+
2159
+ elif args.command == "results":
2160
+ # Download results and generate comparison viewer
2161
+ instances = client.list_instances()
2162
+ if not instances:
2163
+ print("No running instances.")
2164
+ return
2165
+
2166
+ if args.instance_id:
2167
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2168
+ if not instance:
2169
+ print(f"Instance {args.instance_id} not found.")
2170
+ return
2171
+ else:
2172
+ instance = instances[0]
2173
+
2174
+ # Download results
2175
+ print("Step 1: Downloading training results...")
2176
+ client.download_results(instance)
2177
+
2178
+ # Generate comparison viewer
2179
+ print("\nStep 2: Generating comparison viewer...")
2180
+ checkpoint_path = "checkpoints_lambda/qwen3vl2b_capture_lora"
2181
+
2182
+ import subprocess as sp
2183
+ cmd = [
2184
+ "uv", "run", "python", "-m", "openadapt_ml.scripts.compare",
2185
+ "--capture", args.capture,
2186
+ "--checkpoint", checkpoint_path,
2187
+ ]
2188
+ if args.goal:
2189
+ cmd.extend(["--goal", args.goal])
2190
+ if args.open:
2191
+ cmd.append("--open")
2192
+
2193
+ result = sp.run(cmd)
2194
+ if result.returncode == 0:
2195
+ print("\nComparison viewer generated!")
2196
+ if not args.open:
2197
+ print(f"Open with: open {args.capture}/comparison.html")
2198
+ else:
2199
+ print("Warning: Failed to generate comparison viewer")
2200
+
2201
+ elif args.command == "serve":
2202
+ # Start web server for live dashboard with stop button support
2203
+ import http.server
2204
+ import socketserver
2205
+ import threading
2206
+ import time as time_module
2207
+ from pathlib import Path
2208
+
2209
+ output_dir = Path(args.output) if hasattr(args, 'output') else Path("training_output")
2210
+ port = args.port
2211
+
2212
+ if not output_dir.exists():
2213
+ print(f"No {output_dir} directory. Run 'refresh' first.")
2214
+ return
2215
+
2216
+ # Define handler with /api/stop support
2217
class Handler(http.server.SimpleHTTPRequestHandler):
    """Static file server for the dashboard plus a POST /api/stop endpoint.

    Serves files from the training output directory and, on POST to
    ``/api/stop``, drops a STOP_TRAINING sentinel file that the monitor
    loop polls for.
    """

    def __init__(self, *args, **kwargs):
        # Root the file server at the training output directory.
        super().__init__(*args, directory=str(output_dir), **kwargs)

    def do_POST(self):
        # Only /api/stop is supported; everything else is a 404.
        if self.path != '/api/stop':
            self.send_error(404)
            return
        # Touch the sentinel file the monitor loop checks each poll.
        stop_file = output_dir / "STOP_TRAINING"
        stop_file.touch()
        self.send_response(200)
        for header, value in (
            ('Content-Type', 'application/json'),
            ('Access-Control-Allow-Origin', '*'),
        ):
            self.send_header(header, value)
        self.end_headers()
        self.wfile.write(b'{"status": "stop signal created"}')
        print(f" Stop signal created: {stop_file}")

    def do_OPTIONS(self):
        # CORS preflight: allow POST from any origin.
        self.send_response(200)
        for header, value in (
            ('Access-Control-Allow-Origin', '*'),
            ('Access-Control-Allow-Methods', 'POST, OPTIONS'),
            ('Access-Control-Allow-Headers', 'Content-Type'),
        ):
            self.send_header(header, value)
        self.end_headers()

    def log_message(self, format, *args):
        # Keep the console quiet while serving the dashboard.
        pass
2245
+
2246
+
2247
+ # Start web server
2248
+ with socketserver.TCPServer(("", port), Handler) as httpd:
2249
+ url = f"http://localhost:{port}/dashboard.html"
2250
+ print(f"\nDashboard server started at {url}")
2251
+ print("Press Ctrl+C to stop\n")
2252
+
2253
+ if args.open:
2254
+ subprocess.run(["open", url], capture_output=True)
2255
+
2256
+ try:
2257
+ httpd.serve_forever()
2258
+ except KeyboardInterrupt:
2259
+ print("\nServer stopped.")
2260
+
2261
+ elif args.command == "sync":
2262
+ # Sync training output from Lambda and regenerate navigation for file:// protocol
2263
+ from pathlib import Path
2264
+ from openadapt_ml.training.trainer import (
2265
+ TrainingState, TrainingConfig, generate_training_dashboard,
2266
+ regenerate_all_dashboards
2267
+ )
2268
+
2269
+ instances = client.list_instances()
2270
+ if not instances:
2271
+ print("No running instances.")
2272
+ return
2273
+
2274
+ if args.instance_id:
2275
+ instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
2276
+ if not instance:
2277
+ print(f"Instance {args.instance_id} not found.")
2278
+ return
2279
+ else:
2280
+ instance = instances[0]
2281
+
2282
+ output_dir = Path(args.output)
2283
+ output_dir.mkdir(exist_ok=True)
2284
+
2285
+ print(f"Syncing training output from {instance.ip}...")
2286
+
2287
+ # Sync all training output files
2288
+ rsync_cmd = [
2289
+ "rsync", "-avz", "--progress",
2290
+ "-e", "ssh -o StrictHostKeyChecking=no",
2291
+ f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/",
2292
+ str(output_dir) + "/"
2293
+ ]
2294
+ result = subprocess.run(rsync_cmd, capture_output=False)
2295
+
2296
+ if result.returncode != 0:
2297
+ print("Warning: rsync may have had issues")
2298
+
2299
+ # Update dashboard with instance metadata
2300
+ log_path = output_dir / "training_log.json"
2301
+ dashboard_path = output_dir / "dashboard.html"
2302
+
2303
+ if log_path.exists():
2304
+ try:
2305
+ import time as time_module
2306
+ status = json.loads(log_path.read_text())
2307
+
2308
+ # Update with instance info
2309
+ status["instance_ip"] = instance.ip
2310
+ status["instance_type"] = instance.instance_type
2311
+ status["cloud_provider"] = "lambda"
2312
+ status["cloud_dashboard_url"] = "https://cloud.lambda.ai/instances"
2313
+ status["cloud_instance_id"] = instance.id
2314
+
2315
+ log_path.write_text(json.dumps(status, indent=2))
2316
+
2317
+ # Generate updated dashboard
2318
+ state = TrainingState()
2319
+ state.job_id = status.get("job_id", "")
2320
+ state.hostname = status.get("hostname", "lambda")
2321
+ state.instance_ip = instance.ip or ""
2322
+ state.instance_type = instance.instance_type
2323
+ state.config_path = status.get("config_path", "")
2324
+ state.capture_path = status.get("capture_path", "")
2325
+ state.epoch = status.get("epoch", 0)
2326
+ state.step = status.get("step", 0)
2327
+ state.loss = status.get("loss", 0)
2328
+ state.learning_rate = status.get("learning_rate", 5e-5)
2329
+ state.losses = status.get("losses", [])
2330
+ state.evaluations = status.get("evaluations", [])
2331
+ state.total_epochs = status.get("total_epochs", 5)
2332
+ state.start_time = time_module.time() - status.get("elapsed_time", 0)
2333
+ state.cloud_provider = "lambda"
2334
+ state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
2335
+ state.cloud_instance_id = instance.id
2336
+
2337
+ config = TrainingConfig(
2338
+ num_train_epochs=status.get("total_epochs", 5),
2339
+ learning_rate=status.get("learning_rate", 5e-5)
2340
+ )
2341
+
2342
+ dashboard_path.write_text(generate_training_dashboard(state, config))
2343
+ except Exception as e:
2344
+ print(f"Warning: Could not update dashboard: {e}")
2345
+
2346
+ # Regenerate ALL dashboards with static navigation (for file:// protocol)
2347
+ print("Regenerating navigation links...")
2348
+ try:
2349
+ regenerated = regenerate_all_dashboards(output_dir)
2350
+ print(f" Updated {len(regenerated)} files with static navigation")
2351
+ except Exception as e:
2352
+ print(f"Warning: Navigation regeneration failed: {e}")
2353
+
2354
+ # Summary
2355
+ files = list(output_dir.glob("*.html"))
2356
+ print(f"\nSynced {len(files)} HTML files to {output_dir}/")
2357
+ for f in sorted(files):
2358
+ print(f" - {f.name}")
2359
+
2360
+ print(f"\nDashboard: {dashboard_path.absolute()}")
2361
+
2362
+ if args.open:
2363
+ subprocess.run(["open", str(dashboard_path)], capture_output=True)
2364
+
2365
+ elif args.command == "viewer":
2366
+ # Regenerate and open local viewer (no Lambda required)
2367
+ from pathlib import Path
2368
+ from openadapt_ml.training.trainer import regenerate_all_dashboards
2369
+ import re
2370
+
2371
+ output_dir = Path(args.output)
2372
+
2373
+ if not output_dir.exists():
2374
+ print(f"Error: {output_dir} does not exist")
2375
+ print("Run training or sync first to populate the directory.")
2376
+ return
2377
+
2378
+ if not (output_dir / "training_log.json").exists():
2379
+ print(f"Error: No training_log.json found in {output_dir}")
2380
+ print("This directory doesn't contain training results.")
2381
+ return
2382
+
2383
+ # Auto-link local screenshots if available
2384
+ screenshots_link = output_dir / "screenshots"
2385
+ if not screenshots_link.exists():
2386
+ # Try to find capture ID from training log or predictions
2387
+ try:
2388
+ capture_id = None
2389
+
2390
+ # First try training log
2391
+ log_data = json.loads((output_dir / "training_log.json").read_text())
2392
+ capture_path = log_data.get("capture_path", "")
2393
+ capture_match = re.search(r'capture_(\d+)', capture_path)
2394
+ if capture_match:
2395
+ capture_id = capture_match.group(1)
2396
+
2397
+ # If not found, try predictions JSON files
2398
+ if not capture_id:
2399
+ for pred_file in output_dir.glob("predictions_*.json"):
2400
+ pred_data = json.loads(pred_file.read_text())
2401
+ base_data = pred_data.get("base_data", [])
2402
+ if base_data:
2403
+ image_path = base_data[0].get("image_path", "")
2404
+ capture_match = re.search(r'capture_(\d+)', image_path)
2405
+ if capture_match:
2406
+ capture_id = capture_match.group(1)
2407
+ break
2408
+
2409
+ if capture_id:
2410
+ # Search for local screenshots in openadapt-capture
2411
+ openadapt_capture_dir = Path.home() / "oa" / "src" / "openadapt-capture"
2412
+ if openadapt_capture_dir.exists():
2413
+ for capture_dir in openadapt_capture_dir.iterdir():
2414
+ if capture_dir.is_dir():
2415
+ screenshots_dir = capture_dir / "screenshots"
2416
+ if screenshots_dir.exists():
2417
+ # Check if this capture has our screenshots
2418
+ sample_file = list(screenshots_dir.glob(f"capture_{capture_id}_step_*.png"))
2419
+ if sample_file:
2420
+ print(f"Found local screenshots in {screenshots_dir}")
2421
+ screenshots_link.symlink_to(screenshots_dir)
2422
+ print(f" Linked: {screenshots_link} -> {screenshots_dir}")
2423
+ break
2424
+ except Exception as e:
2425
+ pass # Silently continue if auto-link fails
2426
+
2427
+ print(f"Regenerating viewer from {output_dir}...")
2428
+ regenerated = regenerate_all_dashboards(output_dir)
2429
+ print(f" Updated {len(regenerated)} files")
2430
+
2431
+ # Show path info
2432
+ if args.dashboard:
2433
+ target = output_dir / "dashboard.html"
2434
+ else:
2435
+ target = output_dir / "viewer.html"
2436
+
2437
+ print(f"\nGenerated: {target.absolute()}")
2438
+ print(f"View with: uv run python -m openadapt_ml.cloud.lambda_labs serve --open")
2439
+
2440
+ if args.open:
2441
+ subprocess.run(["open", str(target)], capture_output=True)
2442
+
2443
+
2444
# Script entry point: dispatch to the Lambda Labs CLI command handler.
if __name__ == "__main__":
    main()