openadapt_ml-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/__init__.py +0 -0
- openadapt_ml/benchmarks/__init__.py +125 -0
- openadapt_ml/benchmarks/agent.py +825 -0
- openadapt_ml/benchmarks/azure.py +761 -0
- openadapt_ml/benchmarks/base.py +366 -0
- openadapt_ml/benchmarks/cli.py +884 -0
- openadapt_ml/benchmarks/data_collection.py +432 -0
- openadapt_ml/benchmarks/runner.py +381 -0
- openadapt_ml/benchmarks/waa.py +704 -0
- openadapt_ml/cloud/__init__.py +5 -0
- openadapt_ml/cloud/azure_inference.py +441 -0
- openadapt_ml/cloud/lambda_labs.py +2445 -0
- openadapt_ml/cloud/local.py +790 -0
- openadapt_ml/config.py +56 -0
- openadapt_ml/datasets/__init__.py +0 -0
- openadapt_ml/datasets/next_action.py +507 -0
- openadapt_ml/evals/__init__.py +23 -0
- openadapt_ml/evals/grounding.py +241 -0
- openadapt_ml/evals/plot_eval_metrics.py +174 -0
- openadapt_ml/evals/trajectory_matching.py +486 -0
- openadapt_ml/grounding/__init__.py +45 -0
- openadapt_ml/grounding/base.py +236 -0
- openadapt_ml/grounding/detector.py +570 -0
- openadapt_ml/ingest/__init__.py +43 -0
- openadapt_ml/ingest/capture.py +312 -0
- openadapt_ml/ingest/loader.py +232 -0
- openadapt_ml/ingest/synthetic.py +1102 -0
- openadapt_ml/models/__init__.py +0 -0
- openadapt_ml/models/api_adapter.py +171 -0
- openadapt_ml/models/base_adapter.py +59 -0
- openadapt_ml/models/dummy_adapter.py +42 -0
- openadapt_ml/models/qwen_vl.py +426 -0
- openadapt_ml/runtime/__init__.py +0 -0
- openadapt_ml/runtime/policy.py +182 -0
- openadapt_ml/schemas/__init__.py +53 -0
- openadapt_ml/schemas/sessions.py +122 -0
- openadapt_ml/schemas/validation.py +252 -0
- openadapt_ml/scripts/__init__.py +0 -0
- openadapt_ml/scripts/compare.py +1490 -0
- openadapt_ml/scripts/demo_policy.py +62 -0
- openadapt_ml/scripts/eval_policy.py +287 -0
- openadapt_ml/scripts/make_gif.py +153 -0
- openadapt_ml/scripts/prepare_synthetic.py +43 -0
- openadapt_ml/scripts/run_qwen_login_benchmark.py +192 -0
- openadapt_ml/scripts/train.py +174 -0
- openadapt_ml/training/__init__.py +0 -0
- openadapt_ml/training/benchmark_viewer.py +1538 -0
- openadapt_ml/training/shared_ui.py +157 -0
- openadapt_ml/training/stub_provider.py +276 -0
- openadapt_ml/training/trainer.py +2446 -0
- openadapt_ml/training/viewer.py +2970 -0
- openadapt_ml-0.1.0.dist-info/METADATA +818 -0
- openadapt_ml-0.1.0.dist-info/RECORD +55 -0
- openadapt_ml-0.1.0.dist-info/WHEEL +4 -0
- openadapt_ml-0.1.0.dist-info/licenses/LICENSE +21 -0
openadapt_ml/cloud/lambda_labs.py
@@ -0,0 +1,2445 @@
"""Lambda Labs cloud GPU integration.

Lambda Labs provides affordable GPU instances for training:
- A100 40GB: ~$1.10/hour
- H100: ~$2.00/hour
- A10: ~$0.60/hour

API docs: https://cloud.lambdalabs.com/api/v1/docs

Usage:
    # Set API key
    export LAMBDA_API_KEY=your_key_here

    # List available instances
    python -m openadapt_ml.cloud.lambda_labs list

    # Launch instance for training
    python -m openadapt_ml.cloud.lambda_labs launch --type gpu_1x_a100

    # Check running instances
    python -m openadapt_ml.cloud.lambda_labs status

    # Terminate instance
    python -m openadapt_ml.cloud.lambda_labs terminate <instance_id>
"""

from __future__ import annotations

import json
import os
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import requests


API_BASE = "https://cloud.lambdalabs.com/api/v1"

# Default port for HTTP server
DEFAULT_SERVER_PORT = 8765


def start_dashboard_server(output_dir: Path, port: int = DEFAULT_SERVER_PORT) -> tuple[subprocess.Popen, str]:
    """Start a background HTTP server for the dashboard.

    Args:
        output_dir: Directory containing dashboard files
        port: Port to serve on

    Returns:
        (process, url): The server process and the dashboard URL
    """
    import webbrowser
    import threading

    # Start simple HTTP server in background thread
    server_proc = subprocess.Popen(
        [sys.executable, "-m", "http.server", str(port)],
        cwd=str(output_dir),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    url = f"http://localhost:{port}/dashboard.html"

    # Give server time to start
    time.sleep(0.5)

    return server_proc, url
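
# Illustrative sketch (the job directory name is hypothetical): serve an existing
# run and clean up the server process when finished.
#
#     proc, url = start_dashboard_server(Path("training_output/20250101_120000"))
#     print(url)  # -> http://localhost:8765/dashboard.html
#     proc.terminate()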


def open_dashboard_in_browser(output_dir: Path, port: int = DEFAULT_SERVER_PORT):
    """Start HTTP server and open dashboard in browser.

    Args:
        output_dir: Directory containing dashboard files
        port: Port to serve on

    Returns:
        Server process (caller should call .terminate() when done), or None if failed
    """
    import webbrowser

    try:
        server_proc, url = start_dashboard_server(output_dir, port)
        webbrowser.open(url)
        print(f"Dashboard: {url}")
        print("  Stop Training button enabled in dashboard")
        return server_proc
    except Exception as e:
        print(f"Warning: Could not start dashboard server: {e}")
        return None


def setup_capture_screenshots_symlink(output_dir: Path, capture_path: str | Path) -> bool:
    """Create symlink from output_dir/screenshots to capture's screenshots folder.

    This allows the dashboard to serve screenshots via relative paths.

    Args:
        output_dir: Training output directory (e.g., training_output/job_id/)
        capture_path: Path to capture directory (local)

    Returns:
        True if symlink created successfully
    """
    capture_path = Path(capture_path)
    screenshots_src = capture_path / "screenshots"
    screenshots_dst = output_dir / "screenshots"

    if not screenshots_src.exists():
        return False

    # Remove existing symlink or directory
    if screenshots_dst.is_symlink():
        screenshots_dst.unlink()
    elif screenshots_dst.exists():
        return False  # Don't overwrite real directory

    try:
        screenshots_dst.symlink_to(screenshots_src.resolve())
        return True
    except Exception:
        return False


def rewrite_evaluation_paths(evaluations: list[dict], remote_prefix: str = "/home/ubuntu/capture/") -> list[dict]:
    """Rewrite Lambda paths in evaluations to relative paths.

    Converts: /home/ubuntu/capture/screenshots/foo.png -> screenshots/foo.png

    Args:
        evaluations: List of evaluation dicts with image_path
        remote_prefix: The Lambda path prefix to replace

    Returns:
        Evaluations with rewritten paths
    """
    for ev in evaluations:
        if "image_path" in ev and ev["image_path"].startswith(remote_prefix):
            ev["image_path"] = ev["image_path"].replace(remote_prefix, "")
    return evaluations
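
# Example of the rewrite (path taken from the docstring above):
#
#     rewrite_evaluation_paths([{"image_path": "/home/ubuntu/capture/screenshots/foo.png"}])
#     # -> [{"image_path": "screenshots/foo.png"}]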


def download_checkpoints_from_instance(instance_ip: str, output_dir: Path, ssh_key: str | None = None) -> bool:
    """Download checkpoints from Lambda instance.

    Args:
        instance_ip: IP address of Lambda instance
        output_dir: Local directory to save checkpoints
        ssh_key: Path to SSH key (uses default if not provided)

    Returns:
        True if download succeeded
    """
    checkpoints_dir = output_dir / "checkpoints"
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    ssh_key = ssh_key or str(Path.home() / ".ssh" / "lambda_id_ed25519")
    ssh_opts = f"-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i {ssh_key}"

    # Download checkpoints from remote
    remote_path = f"ubuntu@{instance_ip}:~/openadapt-ml/checkpoints/"
    local_path = str(checkpoints_dir) + "/"

    cmd = f"rsync -avz --progress -e 'ssh {ssh_opts}' {remote_path} {local_path}"
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)

    if result.returncode == 0:
        return True
    return False


def check_stop_signal(output_dir: Path) -> bool:
    """Check if stop signal file exists.

    The dashboard can create this file to signal training should stop.
    """
    stop_file = output_dir / "STOP_TRAINING"
    return stop_file.exists()
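
# Illustrative polling sketch (not part of this module): a training loop can call
# check_stop_signal() between steps and stop cleanly once the dashboard writes the
# STOP_TRAINING file.
#
#     if check_stop_signal(output_dir):
#         print("Stop requested via dashboard; stopping after current step")
#         break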


@dataclass
class InstanceType:
    """Lambda Labs instance type."""
    name: str
    price_cents_per_hour: int
    description: str
    gpu_count: int
    gpu_type: str
    vcpus: int
    memory_gb: int
    storage_gb: int
    available_regions: list[str]

    @property
    def price_per_hour(self) -> float:
        return self.price_cents_per_hour / 100

    def __str__(self) -> str:
        regions = ", ".join(self.available_regions[:3])
        if len(self.available_regions) > 3:
            regions += f" (+{len(self.available_regions) - 3} more)"
        return (
            f"{self.name}: ${self.price_per_hour:.2f}/hr | "
            f"{self.gpu_count}x {self.gpu_type} | {self.vcpus} vCPUs | "
            f"{self.memory_gb}GB RAM | {self.storage_gb}GB SSD | "
            f"Regions: {regions}"
        )


@dataclass
class Instance:
    """Running Lambda Labs instance."""
    id: str
    name: str
    instance_type: str
    status: str
    ip: str | None
    region: str
    ssh_key_names: list[str]

    def __str__(self) -> str:
        ip_str = self.ip or "pending"
        return f"{self.id[:8]}... | {self.instance_type} | {self.status} | IP: {ip_str} | {self.region}"


class LambdaLabsClient:
    """Client for Lambda Labs API."""

    def __init__(self, api_key: str | None = None):
        # Try provided key, then settings, then env var
        if not api_key:
            from openadapt_ml.config import settings
            api_key = settings.lambda_api_key or os.environ.get("LAMBDA_API_KEY")

        self.api_key = api_key
        if not self.api_key:
            raise ValueError(
                "Lambda Labs API key required. Set LAMBDA_API_KEY in .env file "
                "or get one at https://cloud.lambdalabs.com/api-keys"
            )
        self.session = requests.Session()
        self.session.headers["Authorization"] = f"Bearer {self.api_key}"

    def _get(self, endpoint: str) -> dict[str, Any]:
        """Make GET request to API."""
        resp = self.session.get(f"{API_BASE}{endpoint}")
        resp.raise_for_status()
        return resp.json()

    def _post(self, endpoint: str, data: dict[str, Any]) -> dict[str, Any]:
        """Make POST request to API."""
        resp = self.session.post(f"{API_BASE}{endpoint}", json=data)
        if not resp.ok:
            error = resp.json().get("error", {})
            raise RuntimeError(f"API error: {error.get('message', resp.text)}")
        return resp.json()

    def list_instance_types(self) -> list[InstanceType]:
        """List available GPU instance types."""
        data = self._get("/instance-types")
        types = []

        for name, info in data.get("data", {}).items():
            specs = info.get("instance_type", {}).get("specs", {})
            regions = [r["name"] for r in info.get("regions_with_capacity_available", [])]

            types.append(InstanceType(
                name=name,
                price_cents_per_hour=info.get("instance_type", {}).get("price_cents_per_hour", 0),
                description=info.get("instance_type", {}).get("description", ""),
                gpu_count=specs.get("gpus", 0),
                gpu_type=info.get("instance_type", {}).get("gpu_description", ""),
                vcpus=specs.get("vcpus", 0),
                memory_gb=specs.get("memory_gib", 0),
                storage_gb=specs.get("storage_gib", 0),
                available_regions=regions,
            ))

        # Sort by price
        types.sort(key=lambda t: t.price_cents_per_hour)
        return types

    def list_ssh_keys(self) -> list[dict[str, str]]:
        """List registered SSH keys."""
        data = self._get("/ssh-keys")
        return data.get("data", [])

    def add_ssh_key(self, name: str, public_key: str) -> dict[str, str]:
        """Add an SSH key."""
        data = self._post("/ssh-keys", {"name": name, "public_key": public_key})
        return data.get("data", {})

    def list_instances(self) -> list[Instance]:
        """List running instances."""
        data = self._get("/instances")
        instances = []

        for inst in data.get("data", []):
            # ssh_key_names can be list of strings or list of dicts
            ssh_keys = inst.get("ssh_key_names", [])
            if ssh_keys and isinstance(ssh_keys[0], dict):
                ssh_key_names = [k["name"] for k in ssh_keys]
            else:
                ssh_key_names = ssh_keys  # Already list of strings

            instances.append(Instance(
                id=inst["id"],
                name=inst.get("name", ""),
                instance_type=inst.get("instance_type", {}).get("name", "unknown"),
                status=inst.get("status", "unknown"),
                ip=inst.get("ip"),
                region=inst.get("region", {}).get("name", "unknown"),
                ssh_key_names=ssh_key_names,
            ))

        return instances

    def launch_instance(
        self,
        instance_type: str,
        region: str | None = None,
        ssh_key_names: list[str] | None = None,
        name: str | None = None,
    ) -> Instance:
        """Launch a new GPU instance.

        Args:
            instance_type: Instance type name (e.g., 'gpu_1x_a100')
            region: Region name (auto-selects if None)
            ssh_key_names: SSH key names to use
            name: Optional instance name

        Returns:
            Launched instance
        """
        # If no region specified, find one with capacity
        if not region:
            types = self.list_instance_types()
            for t in types:
                if t.name == instance_type and t.available_regions:
                    region = t.available_regions[0]
                    break
            if not region:
                raise RuntimeError(f"No regions available for {instance_type}")

        # If no SSH key specified, use first available
        if not ssh_key_names:
            keys = self.list_ssh_keys()
            if not keys:
                raise RuntimeError(
                    "No SSH keys found. Add one at https://cloud.lambdalabs.com/ssh-keys"
                )
            ssh_key_names = [keys[0]["name"]]

        payload = {
            "region_name": region,
            "instance_type_name": instance_type,
            "ssh_key_names": ssh_key_names,
        }
        if name:
            payload["name"] = name

        data = self._post("/instance-operations/launch", payload)
        instance_ids = data.get("data", {}).get("instance_ids", [])

        if not instance_ids:
            raise RuntimeError("Failed to launch instance")

        # Wait for instance to be ready
        print(f"Instance {instance_ids[0]} launched, waiting for IP...")
        instance = None
        for _ in range(60):  # Wait up to 5 minutes for IP
            instances = self.list_instances()
            for inst in instances:
                if inst.id == instance_ids[0] and inst.ip:
                    instance = inst
                    break
            if instance:
                break
            time.sleep(5)

        if not instance:
            raise RuntimeError("Timed out waiting for instance IP")

        # Wait for SSH to be ready - be patient, instances can take a while to boot
        print(f"Instance IP: {instance.ip}, waiting for SSH...")
        for attempt in range(60):  # Wait up to 5 minutes for SSH
            try:
                result = subprocess.run(
                    ["ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
                     f"ubuntu@{instance.ip}", "echo ready"],
                    capture_output=True, text=True, timeout=20
                )
                if result.returncode == 0:
                    print("SSH ready!")
                    return instance
            except subprocess.TimeoutExpired:
                pass
            if attempt % 6 == 5:  # Log progress every 30 seconds
                print(f"  Still waiting for SSH ({(attempt+1)*5}s elapsed)...")
            time.sleep(5)

        print("Warning: SSH may not be ready yet, continuing anyway...")
        return instance

    def terminate_instance(self, instance_id: str) -> bool:
        """Terminate an instance."""
        data = self._post("/instance-operations/terminate", {"instance_ids": [instance_id]})
        terminated = data.get("data", {}).get("terminated_instances", [])
        return any(t.get("id") == instance_id for t in terminated)

    def get_ssh_command(self, instance: Instance, user: str = "ubuntu") -> str:
        """Get SSH command for an instance."""
        if not instance.ip:
            return "# Instance IP not yet available"
        return f"ssh {user}@{instance.ip}"

    def ssh_run(self, instance: Instance, command: str, timeout: int | None = None, retries: int = 3) -> subprocess.CompletedProcess:
        """Run a command on an instance via SSH.

        Args:
            instance: Instance to run on
            command: Shell command to run
            timeout: Optional timeout in seconds
            retries: Number of retries on connection failure

        Returns:
            CompletedProcess with stdout/stderr
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        ssh_cmd = [
            "ssh", "-o", "StrictHostKeyChecking=no",
            "-o", "ConnectTimeout=30",  # Increased from 10
            "-o", "ServerAliveInterval=60",  # Keep connection alive
            "-o", "ServerAliveCountMax=3",
            f"ubuntu@{instance.ip}",
            command
        ]

        last_error = None
        for attempt in range(retries):
            try:
                return subprocess.run(
                    ssh_cmd,
                    capture_output=True,
                    text=True,
                    timeout=timeout,
                )
            except subprocess.TimeoutExpired as e:
                last_error = e
                if attempt < retries - 1:
                    print(f"  SSH timeout, retrying ({attempt + 1}/{retries})...")
                    time.sleep(5)

        raise last_error if last_error else RuntimeError("SSH failed")
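
    # Illustrative usage sketch (assumes a booted instance with an IP): run a quick
    # remote check and inspect the captured output.
    #
    #     result = client.ssh_run(instance, "nvidia-smi", timeout=30)
    #     print(result.stdout)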

    def setup_instance(self, instance: Instance, repo_url: str = "https://github.com/OpenAdaptAI/openadapt-ml.git", clean_gpu: bool = True) -> bool:
        """Set up training environment on instance.

        Clones repo, installs uv, syncs dependencies.
        Optionally clears GPU memory from previous runs.
        Returns True if successful.
        """
        print(f"Setting up instance {instance.ip}...")

        # Clean GPU memory if requested (don't fail if this doesn't work)
        if clean_gpu:
            print("  Clearing GPU memory...")
            try:
                self.ssh_run(instance, '''
python3 -c "
import torch
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print('GPU memory cleared')
" 2>/dev/null || true
# Kill any stale python processes using GPU
pkill -f "python.*train" 2>/dev/null || true
''', timeout=60)
            except Exception as e:
                print(f"  GPU cleanup skipped: {e}")

        setup_script = f'''
set -e
cd ~

# Install uv via official installer (most robust)
if ! command -v uv &> /dev/null; then
    curl -LsSf https://astral.sh/uv/install.sh | sh
fi
export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"

# Clone or update repo
if [ ! -d "openadapt-ml" ]; then
    git clone {repo_url}
else
    cd openadapt-ml && git pull origin main && cd ~
fi

cd openadapt-ml
uv sync
echo "SETUP_COMPLETE"
'''

        try:
            result = self.ssh_run(instance, setup_script, timeout=900)  # 15 min timeout for setup

            if "SETUP_COMPLETE" in result.stdout:
                print("  Environment ready")
                return True
            else:
                stderr_preview = result.stderr[:500] if result.stderr else "(no stderr)"
                print(f"  Setup failed: {stderr_preview}")
                return False
        except subprocess.TimeoutExpired:
            print("  Setup timed out after 15 minutes")
            return False
        except Exception as e:
            print(f"  Setup failed: {e}")
            return False

    def sync_local_code(self, instance: Instance, local_repo_path: str = ".", retries: int = 3) -> bool:
        """Sync local code changes to remote instance.

        Uses rsync to push local code, excluding .venv, .git, etc.
        This ensures the remote has the same code as local.

        Args:
            instance: Instance to sync to
            local_repo_path: Local repository path
            retries: Number of retry attempts

        Returns:
            True if successful
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        print(f"Syncing local code to {instance.ip}...")

        # SSH options for more robust connection
        ssh_opts = "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -o ServerAliveInterval=60"

        rsync_cmd = [
            "rsync", "-avz", "--progress",
            "--timeout=120",  # 2 minute timeout per file
            "--exclude", ".venv",
            "--exclude", ".git",
            "--exclude", "__pycache__",
            "--exclude", "*.pyc",
            "--exclude", ".env",
            "--exclude", "training_output",
            "--exclude", "checkpoints",
            "--exclude", "synthetic*",
            "-e", ssh_opts,
            f"{local_repo_path}/",
            f"ubuntu@{instance.ip}:~/openadapt-ml/"
        ]

        for attempt in range(retries):
            result = subprocess.run(rsync_cmd)
            if result.returncode == 0:
                print("  Code synced")
                return True
            if attempt < retries - 1:
                print(f"  Sync failed, retrying ({attempt + 1}/{retries})...")
                time.sleep(5)

        return False

    def upload_capture(self, instance: Instance, local_path: str, remote_path: str = "~/capture", retries: int = 3) -> bool:
        """Upload a capture directory to instance via rsync.

        Args:
            instance: Instance to upload to
            local_path: Local path to capture directory
            remote_path: Remote path (default: ~/capture)
            retries: Number of retry attempts

        Returns:
            True if successful
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        print(f"Uploading capture to {instance.ip}:{remote_path}...")

        # SSH options for more robust connection
        ssh_opts = "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -o ServerAliveInterval=60"

        rsync_cmd = [
            "rsync", "-avz", "--progress",
            "--timeout=120",  # 2 minute timeout per file
            "-e", ssh_opts,
            f"{local_path}/",
            f"ubuntu@{instance.ip}:{remote_path}/"
        ]

        for attempt in range(retries):
            result = subprocess.run(rsync_cmd)
            if result.returncode == 0:
                return True
            if attempt < retries - 1:
                print(f"  Upload failed, retrying ({attempt + 1}/{retries})...")
                time.sleep(5)

        return False

    def run_training(
        self,
        instance: Instance,
        config: str = "configs/qwen3vl_capture.yaml",
        capture: str | None = None,
        goal: str | None = None,
        background: bool = True,
    ) -> subprocess.Popen | subprocess.CompletedProcess:
        """Run training on instance.

        Args:
            instance: Instance to train on
            config: Config file path (relative to repo)
            capture: Remote capture path (if uploaded)
            goal: Task goal description
            background: Run in background (returns Popen) or foreground

        Returns:
            Popen if background=True, CompletedProcess if background=False
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        # Build training command
        train_cmd = f"uv run python -m openadapt_ml.scripts.train --config {config}"
        if capture:
            train_cmd += f" --capture {capture}"
        if goal:
            train_cmd += f' --goal "{goal}"'

        # Full script with environment setup
        script = f'''
cd ~/openadapt-ml
export PATH="$HOME/.local/bin:$PATH"
{train_cmd}
'''

        ssh_cmd = [
            "ssh", "-o", "StrictHostKeyChecking=no",
            f"ubuntu@{instance.ip}",
            script
        ]

        print(f"Running training on {instance.ip}...")
        print(f"  Config: {config}")
        if capture:
            print(f"  Capture: {capture}")

        if background:
            # Run in background, return Popen for monitoring
            return subprocess.Popen(
                ssh_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            )
        else:
            # Run in foreground, stream output
            return subprocess.run(ssh_cmd)
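
    # Illustrative sketch: stream output from a background run. The Popen returned
    # by run_training() is created with text=True and stdout piped, so iterating it
    # yields decoded lines (the goal string here is hypothetical).
    #
    #     proc = client.run_training(instance, capture="~/capture", goal="Complete the login flow")
    #     for line in proc.stdout:
    #         print(line, end="")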

    def download_results(
        self,
        instance: Instance,
        remote_path: str = "~/openadapt-ml",
        local_path: str = ".",
        include_checkpoint: bool = True,
        include_logs: bool = True,
    ) -> bool:
        """Download training results from instance.

        Args:
            instance: Instance to download from
            remote_path: Remote openadapt-ml directory
            local_path: Local directory to download to
            include_checkpoint: Download checkpoint weights
            include_logs: Download training logs and dashboard

        Returns:
            True if successful
        """
        if not instance.ip:
            raise RuntimeError("Instance has no IP address")

        print(f"Downloading results from {instance.ip}...")
        success = True

        # Download training output (logs, dashboard)
        if include_logs:
            print("  Downloading training logs...")
            rsync_cmd = [
                "rsync", "-avz",
                "-e", "ssh -o StrictHostKeyChecking=no",
                f"ubuntu@{instance.ip}:{remote_path}/training_output/",
                f"{local_path}/training_output_lambda/"
            ]
            result = subprocess.run(rsync_cmd, capture_output=True)
            if result.returncode == 0:
                print("  Training logs downloaded to training_output_lambda/")
            else:
                print(f"  Warning: Failed to download logs")
                success = False

        # Download checkpoint
        if include_checkpoint:
            print("  Downloading checkpoint...")
            rsync_cmd = [
                "rsync", "-avz",
                "-e", "ssh -o StrictHostKeyChecking=no",
                f"ubuntu@{instance.ip}:{remote_path}/checkpoints/",
                f"{local_path}/checkpoints_lambda/"
            ]
            result = subprocess.run(rsync_cmd, capture_output=True)
            if result.returncode == 0:
                print("  Checkpoint downloaded to checkpoints_lambda/")
            else:
                print(f"  Warning: Failed to download checkpoint (may not exist yet)")

        # Regenerate all dashboards with static navigation and correct status
        if include_logs:
            try:
                from openadapt_ml.training.trainer import regenerate_all_dashboards
                output_dir = Path(local_path) / "training_output_lambda"
                if output_dir.exists():
                    print("  Regenerating dashboards with static navigation...")
                    regenerate_all_dashboards(output_dir)
            except Exception as e:
                print(f"  Warning: Failed to regenerate dashboards: {e}")

        return success

    def get_training_status(self, instance: Instance) -> dict:
        """Check training status by reading training_log.json on instance."""
        result = self.ssh_run(
            instance,
            "cat ~/openadapt-ml/training_output/training_log.json 2>/dev/null || echo '{}'",
            timeout=10,
        )
        try:
            import json
            return json.loads(result.stdout.strip())
        except:
            return {}
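
    # Illustrative shape of the status dict returned above (keys are the ones the
    # "train" CLI path below consumes; values are made-up examples):
    #
    #     {"job_id": "20250101_120000", "epoch": 1, "step": 42, "loss": 0.83,
    #      "elapsed_time": 120.0, "total_epochs": 5, "learning_rate": 5e-05,
    #      "losses": [...], "evaluations": [...]}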
|
|
760
|
+
|
|
761
|
+
|
|
762
|
+
def setup_lambda_ssh_key(client: LambdaLabsClient) -> str:
|
|
763
|
+
"""Set up SSH key for Lambda Labs if not already done.
|
|
764
|
+
|
|
765
|
+
Returns the SSH key name that was added/found.
|
|
766
|
+
"""
|
|
767
|
+
# Check if we already have keys
|
|
768
|
+
keys = client.list_ssh_keys()
|
|
769
|
+
if keys:
|
|
770
|
+
print(f"Using existing SSH key: {keys[0]['name']}")
|
|
771
|
+
return keys[0]["name"]
|
|
772
|
+
|
|
773
|
+
# Look for local SSH key
|
|
774
|
+
ssh_key_path = Path.home() / ".ssh" / "id_rsa.pub"
|
|
775
|
+
if not ssh_key_path.exists():
|
|
776
|
+
ssh_key_path = Path.home() / ".ssh" / "id_ed25519.pub"
|
|
777
|
+
|
|
778
|
+
if not ssh_key_path.exists():
|
|
779
|
+
raise RuntimeError(
|
|
780
|
+
"No SSH key found at ~/.ssh/id_rsa.pub or ~/.ssh/id_ed25519.pub\n"
|
|
781
|
+
"Generate one with: ssh-keygen -t ed25519"
|
|
782
|
+
)
|
|
783
|
+
|
|
784
|
+
public_key = ssh_key_path.read_text().strip()
|
|
785
|
+
key_name = f"openadapt-{os.environ.get('USER', 'user')}"
|
|
786
|
+
|
|
787
|
+
print(f"Adding SSH key '{key_name}' to Lambda Labs...")
|
|
788
|
+
client.add_ssh_key(key_name, public_key)
|
|
789
|
+
return key_name
|
|
790
|
+
|
|
791
|
+
|
|
792
|
+
def main():
|
|
793
|
+
"""CLI for Lambda Labs."""
|
|
794
|
+
import argparse
|
|
795
|
+
|
|
796
|
+
parser = argparse.ArgumentParser(description="Lambda Labs GPU management")
|
|
797
|
+
subparsers = parser.add_subparsers(dest="command", help="Command")
|
|
798
|
+
|
|
799
|
+
# List instances command
|
|
800
|
+
list_parser = subparsers.add_parser("list", help="List available instance types")
|
|
801
|
+
|
|
802
|
+
# Status command
|
|
803
|
+
status_parser = subparsers.add_parser("status", help="Show running instances")
|
|
804
|
+
|
|
805
|
+
# Launch command
|
|
806
|
+
launch_parser = subparsers.add_parser("launch", help="Launch a GPU instance")
|
|
807
|
+
launch_parser.add_argument(
|
|
808
|
+
"--type", "-t",
|
|
809
|
+
default="gpu_1x_a100",
|
|
810
|
+
help="Instance type (default: gpu_1x_a100)",
|
|
811
|
+
)
|
|
812
|
+
launch_parser.add_argument("--region", "-r", help="Region (auto-selects if not specified)")
|
|
813
|
+
launch_parser.add_argument("--name", "-n", help="Instance name")
|
|
814
|
+
|
|
815
|
+
# Terminate command
|
|
816
|
+
term_parser = subparsers.add_parser("terminate", help="Terminate an instance")
|
|
817
|
+
term_parser.add_argument("instance_id", help="Instance ID to terminate")
|
|
818
|
+
|
|
819
|
+
# SSH command - run commands or get interactive shell
|
|
820
|
+
ssh_parser = subparsers.add_parser("ssh", help="SSH into Lambda instance or run command")
|
|
821
|
+
ssh_parser.add_argument("instance_id", nargs="?", help="Instance ID (uses first if not specified)")
|
|
822
|
+
ssh_parser.add_argument("--cmd", "-c", help="Command to run (opens shell if not specified)")
|
|
823
|
+
ssh_parser.add_argument("--timeout", "-t", type=int, default=60, help="Command timeout in seconds")
|
|
824
|
+
|
|
825
|
+
# Serve command - start dashboard server with stop button support
|
|
826
|
+
serve_parser = subparsers.add_parser("serve", help="Start dashboard server with stop button support")
|
|
827
|
+
serve_parser.add_argument("--output", "-o", default="training_output", help="Output directory (default: training_output)")
|
|
828
|
+
serve_parser.add_argument("--port", "-p", type=int, default=8765, help="Port (default: 8765)")
|
|
829
|
+
serve_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
|
|
830
|
+
|
|
831
|
+
# Rsync command - copy files to/from Lambda instance
|
|
832
|
+
rsync_parser = subparsers.add_parser("rsync", help="Rsync files to/from Lambda instance")
|
|
833
|
+
rsync_parser.add_argument("source", help="Source path (prefix with 'remote:' for remote paths)")
|
|
834
|
+
rsync_parser.add_argument("dest", help="Destination path (prefix with 'remote:' for remote paths)")
|
|
835
|
+
rsync_parser.add_argument("instance_id", nargs="?", help="Instance ID (uses first if not specified)")
|
|
836
|
+
rsync_parser.add_argument("--delete", action="store_true", help="Delete extraneous files from dest")
|
|
837
|
+
|
|
838
|
+
# Setup command
|
|
839
|
+
setup_parser = subparsers.add_parser("setup", help="Set up SSH key for Lambda Labs")
|
|
840
|
+
|
|
841
|
+
# Train command - full automated training pipeline
|
|
842
|
+
train_parser = subparsers.add_parser("train", help="Run training on Lambda GPU")
|
|
843
|
+
train_parser.add_argument("--capture", "-c", help="Local path to capture directory")
|
|
844
|
+
train_parser.add_argument("--goal", "-g", help="Task goal description")
|
|
845
|
+
train_parser.add_argument("--config", default="configs/qwen3vl_capture_4bit.yaml", help="Config file (default: 4bit for memory efficiency)")
|
|
846
|
+
train_parser.add_argument("--type", "-t", default="gpu_1x_a10", help="Instance type")
|
|
847
|
+
train_parser.add_argument("--instance", "-i", help="Use existing instance ID instead of launching new")
|
|
848
|
+
train_parser.add_argument("--no-terminate", action="store_true", help="Don't terminate instance after training")
|
|
849
|
+
train_parser.add_argument("--max-runtime", type=int, default=60, help="Max runtime in minutes before auto-terminate (default: 60)")
|
|
850
|
+
train_parser.add_argument("--open", action="store_true", help="Open dashboard in browser when training starts")
|
|
851
|
+
|
|
852
|
+
# Training status command
|
|
853
|
+
train_status_parser = subparsers.add_parser("train-status", help="Check training status on instance")
|
|
854
|
+
train_status_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
855
|
+
|
|
856
|
+
# Monitor command - live dashboard for Lambda training
|
|
857
|
+
monitor_parser = subparsers.add_parser("monitor", help="Monitor Lambda training with live dashboard")
|
|
858
|
+
monitor_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
859
|
+
monitor_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
|
|
860
|
+
monitor_parser.add_argument("--interval", type=int, default=5, help="Poll interval in seconds (default: 5)")
|
|
861
|
+
monitor_parser.add_argument("--capture", type=str, help="Local capture path for screenshot symlink")
|
|
862
|
+
monitor_parser.add_argument("--auto-stop-loss", type=float, default=0.5, help="Auto-terminate when loss drops below this (default: 0.5)")
|
|
863
|
+
monitor_parser.add_argument("--download-checkpoints", action="store_true", default=True, help="Auto-download checkpoints each epoch")
|
|
864
|
+
monitor_parser.add_argument("--no-download-checkpoints", action="store_false", dest="download_checkpoints", help="Disable checkpoint download")
|
|
865
|
+
monitor_parser.add_argument("--stub", action="store_true", help="Use stub training provider (no GPU, instant simulation)")
|
|
866
|
+
|
|
867
|
+
# Refresh command - one-shot dashboard update
|
|
868
|
+
refresh_parser = subparsers.add_parser("refresh", help="One-shot refresh of training dashboard")
|
|
869
|
+
refresh_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
870
|
+
refresh_parser.add_argument("--open", action="store_true", help="Open dashboard in browser")
|
|
871
|
+
refresh_parser.add_argument("--capture", type=str, help="Local capture path for screenshot preview")
|
|
872
|
+
|
|
873
|
+
# Checkpoints command - list remote checkpoints
|
|
874
|
+
checkpoints_parser = subparsers.add_parser("checkpoints", help="List checkpoints on remote instance")
|
|
875
|
+
checkpoints_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
876
|
+
|
|
877
|
+
# Download results command
|
|
878
|
+
download_parser = subparsers.add_parser("download", help="Download training results from instance")
|
|
879
|
+
download_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
880
|
+
download_parser.add_argument("--output", "-o", default=".", help="Local output directory")
|
|
881
|
+
|
|
882
|
+
# Check files on instance
|
|
883
|
+
files_parser = subparsers.add_parser("files", help="List training files on instance")
|
|
884
|
+
files_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
885
|
+
files_parser.add_argument("--path", "-p", default="~/openadapt-ml", help="Path to check")
|
|
886
|
+
|
|
887
|
+
# Kill command - terminate training processes
|
|
888
|
+
kill_parser = subparsers.add_parser("kill", help="Kill training/inference processes on instance")
|
|
889
|
+
kill_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
890
|
+
kill_parser.add_argument("--local", action="store_true", help="Also kill local Lambda-related processes")
|
|
891
|
+
kill_parser.add_argument("--all", action="store_true", help="Kill all Python processes on instance (careful!)")
|
|
892
|
+
|
|
893
|
+
# Check command - analyze training status and early stopping
|
|
894
|
+
check_parser = subparsers.add_parser("check", help="Check training health and early stopping status")
|
|
895
|
+
check_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
896
|
+
check_parser.add_argument("--threshold", "-t", type=float, default=0.01,
|
|
897
|
+
help="Early stopping threshold (loss improvement over last N steps)")
|
|
898
|
+
check_parser.add_argument("--window", "-w", type=int, default=10,
|
|
899
|
+
help="Number of recent steps to check for improvement")
|
|
900
|
+
|
|
901
|
+
# Compare command - run comparison on Lambda and sync back
|
|
902
|
+
compare_parser = subparsers.add_parser("compare", help="Run human vs AI comparison on Lambda")
|
|
903
|
+
compare_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
904
|
+
compare_parser.add_argument("--checkpoint", "-c", help="Checkpoint to use (default: latest)")
|
|
905
|
+
compare_parser.add_argument("--epoch", "-e", type=int, help="Use checkpoint from specific epoch")
|
|
906
|
+
compare_parser.add_argument("--open", action="store_true", help="Open viewer after generation")
|
|
907
|
+
|
|
908
|
+
# Results viewer command - downloads and generates comparison viewer
|
|
909
|
+
results_parser = subparsers.add_parser("results", help="Download results and generate comparison viewer")
|
|
910
|
+
results_parser.add_argument("--capture", "-c", required=True, help="Local capture directory (for comparison)")
|
|
911
|
+
results_parser.add_argument("--goal", "-g", help="Task goal description")
|
|
912
|
+
results_parser.add_argument("--open", action="store_true", help="Open viewer in browser")
|
|
913
|
+
results_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
914
|
+
|
|
915
|
+
# Sync command - sync training output and regenerate navigation for file:// protocol
|
|
916
|
+
sync_parser = subparsers.add_parser("sync", help="Sync training output from Lambda and regenerate navigation")
|
|
917
|
+
sync_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
918
|
+
sync_parser.add_argument("--output", "-o", default="training_output", help="Local output directory (default: training_output)")
|
|
919
|
+
sync_parser.add_argument("--open", action="store_true", help="Open dashboard in browser after sync")
|
|
920
|
+
|
|
921
|
+
# Viewer command - regenerate local viewer (no Lambda required)
|
|
922
|
+
viewer_parser = subparsers.add_parser("viewer", help="Regenerate local viewer (no Lambda required)")
|
|
923
|
+
viewer_parser.add_argument("--output", "-o", default="training_output", help="Training output directory (default: training_output)")
|
|
924
|
+
viewer_parser.add_argument("--dashboard", "-d", action="store_true", help="Regenerate dashboard instead of viewer")
|
|
925
|
+
viewer_parser.add_argument("--open", action="store_true", help="Open in browser (use 'serve' instead for better experience)")
|
|
926
|
+
|
|
927
|
+
args = parser.parse_args()
|
|
928
|
+
|
|
929
|
+
if not args.command:
|
|
930
|
+
parser.print_help()
|
|
931
|
+
return
|
|
932
|
+
|
|
933
|
+
try:
|
|
934
|
+
client = LambdaLabsClient()
|
|
935
|
+
except ValueError as e:
|
|
936
|
+
print(f"Error: {e}")
|
|
937
|
+
print("\nGet your API key at https://cloud.lambdalabs.com/api-keys")
|
|
938
|
+
print("Then set it: export LAMBDA_API_KEY=your_key_here")
|
|
939
|
+
return
|
|
940
|
+
|
|
941
|
+
if args.command == "list":
|
|
942
|
+
print("Available GPU instances:\n")
|
|
943
|
+
types = client.list_instance_types()
|
|
944
|
+
for t in types:
|
|
945
|
+
avail = "available" if t.available_regions else "no capacity"
|
|
946
|
+
print(f" {t}")
|
|
947
|
+
print(f"\nTotal: {len(types)} instance types")
|
|
948
|
+
print("\nLaunch with: python -m openadapt_ml.cloud.lambda_labs launch --type <name>")
|
|
949
|
+
|
|
950
|
+
elif args.command == "status":
|
|
951
|
+
instances = client.list_instances()
|
|
952
|
+
if not instances:
|
|
953
|
+
print("No running instances.")
|
|
954
|
+
else:
|
|
955
|
+
print("Running instances:\n")
|
|
956
|
+
for inst in instances:
|
|
957
|
+
print(f" {inst}")
|
|
958
|
+
print(f"\nTotal: {len(instances)} instances")
|
|
959
|
+
|
|
960
|
+
elif args.command == "launch":
|
|
961
|
+
# Ensure SSH key is set up
|
|
962
|
+
ssh_key = setup_lambda_ssh_key(client)
|
|
963
|
+
|
|
964
|
+
print(f"Launching {args.type}...")
|
|
965
|
+
instance = client.launch_instance(
|
|
966
|
+
instance_type=args.type,
|
|
967
|
+
region=args.region,
|
|
968
|
+
ssh_key_names=[ssh_key],
|
|
969
|
+
name=args.name,
|
|
970
|
+
)
|
|
971
|
+
print(f"\nInstance launched!")
|
|
972
|
+
print(f" ID: {instance.id}")
|
|
973
|
+
print(f" IP: {instance.ip}")
|
|
974
|
+
print(f" Type: {instance.instance_type}")
|
|
975
|
+
print(f" Region: {instance.region}")
|
|
976
|
+
print(f"\nConnect with: ssh ubuntu@{instance.ip}")
|
|
977
|
+
print(f"\nTerminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
|
|
978
|
+
|
|
979
|
+
elif args.command == "terminate":
|
|
980
|
+
if client.terminate_instance(args.instance_id):
|
|
981
|
+
print(f"Instance {args.instance_id} terminated.")
|
|
982
|
+
else:
|
|
983
|
+
print(f"Failed to terminate instance {args.instance_id}")
|
|
984
|
+
|
|
985
|
+
elif args.command == "ssh":
|
|
986
|
+
instances = client.list_instances()
|
|
987
|
+
if not instances:
|
|
988
|
+
print("No running instances.")
|
|
989
|
+
return
|
|
990
|
+
|
|
991
|
+
if args.instance_id:
|
|
992
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
993
|
+
if not instance:
|
|
994
|
+
print(f"Instance {args.instance_id} not found.")
|
|
995
|
+
return
|
|
996
|
+
else:
|
|
997
|
+
instance = instances[0]
|
|
998
|
+
|
|
999
|
+
if hasattr(args, 'cmd') and args.cmd:
|
|
1000
|
+
# Run single command
|
|
1001
|
+
print(f"Running on {instance.ip}: {args.cmd}")
|
|
1002
|
+
result = client.ssh_run(instance, args.cmd, timeout=args.timeout)
|
|
1003
|
+
if result.stdout:
|
|
1004
|
+
print(result.stdout)
|
|
1005
|
+
if result.stderr:
|
|
1006
|
+
print(f"[stderr] {result.stderr}", file=sys.stderr)
|
|
1007
|
+
if result.returncode != 0:
|
|
1008
|
+
sys.exit(result.returncode)
|
|
1009
|
+
else:
|
|
1010
|
+
# Print SSH command for interactive use
|
|
1011
|
+
print(client.get_ssh_command(instance))
|
|
1012
|
+
|
|
1013
|
+
elif args.command == "rsync":
|
|
1014
|
+
# Rsync files to/from Lambda instance
|
|
1015
|
+
instances = client.list_instances()
|
|
1016
|
+
if not instances:
|
|
1017
|
+
print("No running instances.")
|
|
1018
|
+
return
|
|
1019
|
+
|
|
1020
|
+
if args.instance_id:
|
|
1021
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
1022
|
+
if not instance:
|
|
1023
|
+
print(f"Instance {args.instance_id} not found.")
|
|
1024
|
+
return
|
|
1025
|
+
else:
|
|
1026
|
+
instance = instances[0]
|
|
1027
|
+
|
|
1028
|
+
# Parse source and dest - 'remote:' prefix indicates remote path
|
|
1029
|
+
source = args.source
|
|
1030
|
+
dest = args.dest
|
|
1031
|
+
|
|
1032
|
+
if source.startswith("remote:"):
|
|
1033
|
+
source = f"ubuntu@{instance.ip}:{source[7:]}"
|
|
1034
|
+
if dest.startswith("remote:"):
|
|
1035
|
+
dest = f"ubuntu@{instance.ip}:{dest[7:]}"
|
|
1036
|
+
|
|
1037
|
+
rsync_cmd = [
|
|
1038
|
+
"rsync", "-avz", "--progress",
|
|
1039
|
+
"-e", "ssh -o StrictHostKeyChecking=no",
|
|
1040
|
+
]
|
|
1041
|
+
if args.delete:
|
|
1042
|
+
rsync_cmd.append("--delete")
|
|
1043
|
+
rsync_cmd.extend([source, dest])
|
|
1044
|
+
|
|
1045
|
+
print(f"Running: {' '.join(rsync_cmd)}")
|
|
1046
|
+
result = subprocess.run(rsync_cmd)
|
|
1047
|
+
sys.exit(result.returncode)
|
|
1048
|
+
|
|
1049
|
+
elif args.command == "setup":
|
|
1050
|
+
ssh_key = setup_lambda_ssh_key(client)
|
|
1051
|
+
print(f"SSH key '{ssh_key}' is configured.")
|
|
1052
|
+
|
|
1053
|
+
elif args.command == "train":
|
|
1054
|
+
# Full automated training pipeline
|
|
1055
|
+
import time as time_module
|
|
1056
|
+
|
|
1057
|
+
instance = None
|
|
1058
|
+
start_time = time_module.time()
|
|
1059
|
+
launched_new = False
|
|
1060
|
+
training_completed = False # Track if training actually finished
|
|
1061
|
+
|
|
1062
|
+
# Instance pricing (approximate $/hr)
|
|
1063
|
+
INSTANCE_PRICES = {
|
|
1064
|
+
"gpu_1x_a10": 0.75,
|
|
1065
|
+
"gpu_1x_a100": 1.29,
|
|
1066
|
+
"gpu_1x_a100_sxm4": 1.29,
|
|
1067
|
+
"gpu_1x_h100_pcie": 2.49,
|
|
1068
|
+
"gpu_1x_h100_sxm5": 3.29,
|
|
1069
|
+
}
|
|
1070
|
+
|
|
1071
|
+
# Get or launch instance
|
|
1072
|
+
if args.instance:
|
|
1073
|
+
instances = client.list_instances()
|
|
1074
|
+
instance = next((i for i in instances if i.id.startswith(args.instance)), None)
|
|
1075
|
+
if not instance:
|
|
1076
|
+
print(f"Error: Instance {args.instance} not found")
|
|
1077
|
+
return
|
|
1078
|
+
else:
|
|
1079
|
+
# Check for existing instances
|
|
1080
|
+
instances = client.list_instances()
|
|
1081
|
+
if instances:
|
|
1082
|
+
print(f"Using existing instance: {instances[0].id[:8]}...")
|
|
1083
|
+
instance = instances[0]
|
|
1084
|
+
else:
|
|
1085
|
+
# Launch new instance
|
|
1086
|
+
ssh_key = setup_lambda_ssh_key(client)
|
|
1087
|
+
print(f"Launching {args.type}...")
|
|
1088
|
+
instance = client.launch_instance(
|
|
1089
|
+
instance_type=args.type,
|
|
1090
|
+
ssh_key_names=[ssh_key],
|
|
1091
|
+
name="openadapt-training",
|
|
1092
|
+
)
|
|
1093
|
+
print(f"Instance launched: {instance.id[:8]}... at {instance.ip}")
|
|
1094
|
+
launched_new = True
|
|
1095
|
+
|
|
1096
|
+
price_per_hour = INSTANCE_PRICES.get(instance.instance_type, 1.00)
|
|
1097
|
+
print(f" Instance type: {instance.instance_type} (~${price_per_hour:.2f}/hr)")
|
|
1098
|
+
print(f" Max runtime: {args.max_runtime} minutes")
|
|
1099
|
+
|
|
1100
|
+
# Generate initial dashboard with setup status
|
|
1101
|
+
from pathlib import Path
|
|
1102
|
+
from openadapt_ml.training.trainer import (
|
|
1103
|
+
TrainingState, TrainingConfig, generate_training_dashboard,
|
|
1104
|
+
setup_job_directory
|
|
1105
|
+
)
|
|
1106
|
+
import time as time_module
|
|
1107
|
+
job_id = time_module.strftime("%Y%m%d_%H%M%S")
|
|
1108
|
+
output_dir = setup_job_directory("training_output", job_id)
|
|
1109
|
+
dashboard_path = output_dir / "dashboard.html"
|
|
1110
|
+
log_path = output_dir / "training_log.json"
|
|
1111
|
+
|
|
1112
|
+
def update_dashboard(status: str, logs: list, step: int = 0, loss: float = 0.0, epoch: int = 0):
|
|
1113
|
+
"""Update dashboard with current setup/training status."""
|
|
1114
|
+
state = TrainingState(job_id=job_id)
|
|
1115
|
+
state.cloud_provider = "lambda"
|
|
1116
|
+
state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
|
|
1117
|
+
state.cloud_instance_id = instance.id
|
|
1118
|
+
state.instance_ip = instance.ip or ""
|
|
1119
|
+
state.instance_type = instance.instance_type
|
|
1120
|
+
state.setup_status = status
|
|
1121
|
+
state.setup_logs = logs
|
|
1122
|
+
state.epoch = epoch
|
|
1123
|
+
state.step = step
|
|
1124
|
+
state.loss = loss
|
|
1125
|
+
state.start_time = start_time
|
|
1126
|
+
config = TrainingConfig(num_train_epochs=5, learning_rate=5e-5)
|
|
1127
|
+
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
1128
|
+
# Also write log for polling
|
|
1129
|
+
log_path.write_text(json.dumps(state.to_dict(), indent=2))
|
|
1130
|
+
|
|
1131
|
+
# Initial dashboard
|
|
1132
|
+
setup_logs = [
|
|
1133
|
+
f"Lambda Cloud instance: {instance.id[:8]}...",
|
|
1134
|
+
f"Instance type: {instance.instance_type} (~${price_per_hour:.2f}/hr)",
|
|
1135
|
+
f"IP address: {instance.ip or 'pending...'}",
|
|
1136
|
+
]
|
|
1137
|
+
update_dashboard("booting", setup_logs)
|
|
1138
|
+
|
|
1139
|
+
# Open dashboard in browser via HTTP server
|
|
1140
|
+
server_proc = None
|
|
1141
|
+
if args.open:
|
|
1142
|
+
server_proc = open_dashboard_in_browser(output_dir)
|
|
1143
|
+
|
|
1144
|
+
try:
|
|
1145
|
+
# Set up environment with retries at the command level
|
|
1146
|
+
setup_logs.append("Connecting to instance...")
|
|
1147
|
+
update_dashboard("booting", setup_logs)
|
|
1148
|
+
|
|
1149
|
+
setup_success = False
|
|
1150
|
+
for setup_attempt in range(3):
|
|
1151
|
+
setup_logs.append(f"Setup attempt {setup_attempt + 1}/3...")
|
|
1152
|
+
update_dashboard("installing", setup_logs)
|
|
1153
|
+
if client.setup_instance(instance):
|
|
1154
|
+
setup_success = True
|
|
1155
|
+
setup_logs.append("Instance setup complete!")
|
|
1156
|
+
update_dashboard("installing", setup_logs)
|
|
1157
|
+
break
|
|
1158
|
+
if setup_attempt < 2:
|
|
1159
|
+
setup_logs.append(f"Setup attempt {setup_attempt + 1} failed, retrying in 30s...")
|
|
1160
|
+
update_dashboard("booting", setup_logs)
|
|
1161
|
+
print(f" Setup attempt {setup_attempt + 1} failed, retrying in 30s...")
|
|
1162
|
+
time_module.sleep(30)
|
|
1163
|
+
|
|
1164
|
+
if not setup_success:
|
|
1165
|
+
setup_logs.append("ERROR: Failed to set up instance after 3 attempts")
|
|
1166
|
+
update_dashboard("booting", setup_logs)
|
|
1167
|
+
print("\nError: Failed to set up instance after 3 attempts")
|
|
1168
|
+
print(f"Instance still running: {instance.ip}")
|
|
1169
|
+
print("Debug via: ssh ubuntu@" + instance.ip)
|
|
1170
|
+
print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
|
|
1171
|
+
return # Don't terminate - let user debug
|
|
1172
|
+
|
|
1173
|
+
# Sync local code to ensure remote has latest changes
|
|
1174
|
+
setup_logs.append("Syncing local code to instance...")
|
|
1175
|
+
update_dashboard("installing", setup_logs)
|
|
1176
|
+
if not client.sync_local_code(instance):
|
|
1177
|
+
setup_logs.append("Warning: Failed to sync local code, using remote repo version")
|
|
1178
|
+
update_dashboard("installing", setup_logs)
|
|
1179
|
+
print("Warning: Failed to sync local code, using remote repo version")
|
|
1180
|
+
else:
|
|
1181
|
+
setup_logs.append("Code synced successfully")
|
|
1182
|
+
update_dashboard("installing", setup_logs)
|
|
1183
|
+
|
|
1184
|
+
# Upload capture if provided
|
|
1185
|
+
remote_capture = None
|
|
1186
|
+
if args.capture:
|
|
1187
|
+
setup_logs.append(f"Uploading capture data...")
|
|
1188
|
+
update_dashboard("installing", setup_logs)
|
|
1189
|
+
if client.upload_capture(instance, args.capture, "~/capture"):
|
|
1190
|
+
remote_capture = "~/capture"
|
|
1191
|
+
setup_logs.append(f"Capture uploaded to {instance.ip}:~/capture")
|
|
1192
|
+
update_dashboard("installing", setup_logs)
|
|
1193
|
+
print(f"Capture uploaded to {instance.ip}:~/capture")
|
|
1194
|
+
else:
|
|
1195
|
+
setup_logs.append("ERROR: Failed to upload capture after retries")
|
|
1196
|
+
update_dashboard("installing", setup_logs)
|
|
1197
|
+
print("\nError: Failed to upload capture after retries")
|
|
1198
|
+
print(f"Instance still running: {instance.ip}")
|
|
1199
|
+
print("Debug via: ssh ubuntu@" + instance.ip)
|
|
1200
|
+
print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
|
|
1201
|
+
return # Don't terminate - let user debug
|
|
1202
|
+
|
|
1203
|
+
# Run training in background and poll for status
|
|
1204
|
+
setup_logs.append("Installing dependencies and starting training...")
|
|
1205
|
+
update_dashboard("training", setup_logs)
|
|
1206
|
+
print("\n" + "=" * 50)
|
|
1207
|
+
print("Starting training...")
|
|
1208
|
+
print("=" * 50 + "\n")
|
|
1209
|
+
|
|
1210
|
+
proc = client.run_training(
|
|
1211
|
+
instance,
|
|
1212
|
+
config=args.config,
|
|
1213
|
+
capture=remote_capture,
|
|
1214
|
+
goal=args.goal,
|
|
1215
|
+
background=True, # Run in background so we can poll
|
|
1216
|
+
)
|
|
1217
|
+
|
|
1218
|
+
# Poll for training status and update dashboard
|
|
1219
|
+
poll_interval = 10 # seconds
|
|
1220
|
+
last_step = 0
|
|
1221
|
+
last_epoch = 0
|
|
1222
|
+
print(f"Polling training status every {poll_interval}s (Ctrl+C to stop)...\n")
|
|
1223
|
+
|
|
1224
|
+
while True:
|
|
1225
|
+
try:
|
|
1226
|
+
status = client.get_training_status(instance)
|
|
1227
|
+
|
|
1228
|
+
if status and status.get("step", 0) > 0:
|
|
1229
|
+
step = status.get("step", 0)
|
|
1230
|
+
epoch = status.get("epoch", 0)
|
|
1231
|
+
loss = status.get("loss", 0)
|
|
1232
|
+
elapsed_training = status.get("elapsed_time", 0)
|
|
1233
|
+
total_epochs = status.get("total_epochs", 5)
|
|
1234
|
+
|
|
1235
|
+
# Print progress when step changes
|
|
1236
|
+
if step > last_step or epoch > last_epoch:
|
|
1237
|
+
print(f" Epoch {epoch+1}/{total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed_training:.0f}s")
|
|
1238
|
+
last_step = step
|
|
1239
|
+
last_epoch = epoch
|
|
1240
|
+
|
|
1241
|
+
# Update local training_log.json (dashboard polls this)
|
|
1242
|
+
status["total_epochs"] = total_epochs
|
|
1243
|
+
if not status.get("instance_ip"):
|
|
1244
|
+
status["instance_ip"] = instance.ip
|
|
1245
|
+
if not status.get("instance_type"):
|
|
1246
|
+
status["instance_type"] = instance.instance_type
|
|
1247
|
+
# Add cloud provider info
|
|
1248
|
+
status["cloud_provider"] = "lambda"
|
|
1249
|
+
status["cloud_dashboard_url"] = "https://cloud.lambda.ai/instances"
|
|
1250
|
+
status["cloud_instance_id"] = instance.id
|
|
1251
|
+
status["setup_status"] = "training"
|
|
1252
|
+
status["setup_logs"] = setup_logs
|
|
1253
|
+
log_path.write_text(json.dumps(status, indent=2))
|
|
1254
|
+
|
|
1255
|
+
# Regenerate dashboard with updated data
|
|
1256
|
+
state = TrainingState()
|
|
1257
|
+
state.job_id = status.get("job_id", "")
|
|
1258
|
+
state.hostname = status.get("hostname", "lambda")
|
|
1259
|
+
state.instance_ip = instance.ip or ""
|
|
1260
|
+
state.instance_type = instance.instance_type
|
|
1261
|
+
state.epoch = epoch
|
|
1262
|
+
state.step = step
|
|
1263
|
+
state.total_epochs = total_epochs
|
|
1264
|
+
state.loss = loss
|
|
1265
|
+
state.learning_rate = status.get("learning_rate", 5e-5)
|
|
1266
|
+
state.losses = status.get("losses", [])
|
|
1267
|
+
state.evaluations = status.get("evaluations", [])
|
|
1268
|
+
state.start_time = time_module.time() - elapsed_training
|
|
1269
|
+
state.cloud_provider = "lambda"
|
|
1270
|
+
state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
|
|
1271
|
+
state.cloud_instance_id = instance.id
|
|
1272
|
+
state.setup_status = "training"
|
|
1273
|
+
state.setup_logs = setup_logs
|
|
1274
|
+
|
|
1275
|
+
config = TrainingConfig(
|
|
1276
|
+
num_train_epochs=total_epochs,
|
|
1277
|
+
learning_rate=status.get("learning_rate", 5e-5)
|
|
1278
|
+
)
|
|
1279
|
+
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
1280
|
+
|
|
1281
|
+
# Check if training is complete (all epochs done)
|
|
1282
|
+
if epoch >= total_epochs - 1:
|
|
1283
|
+
# Check if step count stopped increasing
|
|
1284
|
+
time_module.sleep(poll_interval)
|
|
1285
|
+
new_status = client.get_training_status(instance)
|
|
1286
|
+
if new_status and new_status.get("step", 0) == step:
|
|
1287
|
+
print("\n" + "=" * 50)
|
|
1288
|
+
print("Training complete!")
|
|
1289
|
+
print("=" * 50)
|
|
1290
|
+
training_completed = True
|
|
1291
|
+
break
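The completion check above boils down to: on the final epoch, if the step counter does not advance between two consecutive polls, treat training as finished. As a one-line predicate:

def looks_complete(prev_step, curr_step, epoch, total_epochs):
    # Final epoch reached and the step counter has stalled across one poll interval.
    return epoch >= total_epochs - 1 and curr_step == prev_step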
|
|
1292
|
+
else:
|
|
1293
|
+
# Training not started yet, show setup status
|
|
1294
|
+
print(" Waiting for training to start...")
|
|
1295
|
+
|
|
1296
|
+
except Exception as e:
|
|
1297
|
+
print(f" Poll error: {e}")
|
|
1298
|
+
|
|
1299
|
+
time_module.sleep(poll_interval)
|
|
1300
|
+
|
|
1301
|
+
except KeyboardInterrupt:
|
|
1302
|
+
print("\n\nTraining interrupted by user")
|
|
1303
|
+
finally:
|
|
1304
|
+
# Clean up HTTP server if running
|
|
1305
|
+
if server_proc:
|
|
1306
|
+
server_proc.terminate()
|
|
1307
|
+
print("Dashboard server stopped.")
|
|
1308
|
+
|
|
1309
|
+
# Only auto-terminate if training completed successfully or user requested it
|
|
1310
|
+
elapsed = time_module.time() - start_time
|
|
1311
|
+
cost = (elapsed / 3600) * price_per_hour
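The cost figure is simply elapsed hours times the hourly rate. A quick worked example with assumed numbers (90 minutes at an example rate of $1.10/hour):

elapsed = 90 * 60                      # seconds
price_per_hour = 1.10                  # example hourly rate
cost = (elapsed / 3600) * price_per_hour
print(f"~${cost:.2f}")                 # ~$1.65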
|
|
1312
|
+
|
|
1313
|
+
if training_completed and not args.no_terminate:
|
|
1314
|
+
# Run comparison on Lambda before downloading and terminating (if capture was provided)
|
|
1315
|
+
if args.capture:
|
|
1316
|
+
print("\n" + "=" * 50)
|
|
1317
|
+
print("Running comparison on Lambda instance...")
|
|
1318
|
+
print("=" * 50)
|
|
1319
|
+
|
|
1320
|
+
# Determine the final checkpoint path (main checkpoint after training)
|
|
1321
|
+
checkpoint_path = "/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
|
|
1322
|
+
|
|
1323
|
+
# Check if checkpoint exists
|
|
1324
|
+
result = client.ssh_run(
|
|
1325
|
+
instance,
|
|
1326
|
+
f"ls {checkpoint_path}/adapter_config.json 2>/dev/null && echo 'exists'",
|
|
1327
|
+
timeout=30
|
|
1328
|
+
)
|
|
1329
|
+
|
|
1330
|
+
if "exists" in result.stdout:
|
|
1331
|
+
# Run comparison on Lambda
|
|
1332
|
+
output_name = f"comparison_{time_module.strftime('%H%M%S')}.html"
|
|
1333
|
+
cmd = f"""cd ~/openadapt-ml && source .venv/bin/activate && \
|
|
1334
|
+
python -m openadapt_ml.scripts.compare \
|
|
1335
|
+
--capture ~/capture \
|
|
1336
|
+
--checkpoint {checkpoint_path} \
|
|
1337
|
+
--output training_output/{output_name} 2>&1"""
|
|
1338
|
+
|
|
1339
|
+
print(" Generating comparison viewer (this may take a few minutes)...")
|
|
1340
|
+
result = client.ssh_run(instance, cmd, timeout=600)
|
|
1341
|
+
|
|
1342
|
+
if result.returncode == 0:
|
|
1343
|
+
print(f" Comparison generated: {output_name}")
|
|
1344
|
+
else:
|
|
1345
|
+
print(f" Warning: Comparison generation failed")
|
|
1346
|
+
if result.stderr:
|
|
1347
|
+
print(f" Error: {result.stderr}")
|
|
1348
|
+
else:
|
|
1349
|
+
print(" Warning: Final checkpoint not found, skipping comparison")
|
|
1350
|
+
|
|
1351
|
+
# Download results (including comparison if generated)
|
|
1352
|
+
print("\n" + "=" * 50)
|
|
1353
|
+
print("Downloading results...")
|
|
1354
|
+
print("=" * 50)
|
|
1355
|
+
client.download_results(instance)
|
|
1356
|
+
|
|
1357
|
+
print(f"\nTerminating instance {instance.id[:8]}...")
|
|
1358
|
+
client.terminate_instance(instance.id)
|
|
1359
|
+
print("Instance terminated.")
|
|
1360
|
+
print(f"\nFinal cost: ~${cost:.2f} ({elapsed/60:.1f} minutes)")
|
|
1361
|
+
else:
|
|
1362
|
+
print(f"\nInstance still running: {instance.ip}")
|
|
1363
|
+
print(f" Current cost: ~${cost:.2f}")
|
|
1364
|
+
if not training_completed:
|
|
1365
|
+
print(f" (Not terminating - training did not complete successfully)")
|
|
1366
|
+
print(f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}")
|
|
1367
|
+
|
|
1368
|
+
elif args.command == "train-status":
|
|
1369
|
+
instances = client.list_instances()
|
|
1370
|
+
if not instances:
|
|
1371
|
+
print("No running instances.")
|
|
1372
|
+
return
|
|
1373
|
+
|
|
1374
|
+
if args.instance_id:
|
|
1375
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
1376
|
+
if not instance:
|
|
1377
|
+
print(f"Instance {args.instance_id} not found.")
|
|
1378
|
+
return
|
|
1379
|
+
else:
|
|
1380
|
+
instance = instances[0]
|
|
1381
|
+
|
|
1382
|
+
print(f"Checking training status on {instance.ip}...")
|
|
1383
|
+
status = client.get_training_status(instance)
|
|
1384
|
+
|
|
1385
|
+
if status:
|
|
1386
|
+
print(f" Epoch: {status.get('epoch', 'N/A')}")
|
|
1387
|
+
print(f" Step: {status.get('step', 'N/A')}")
|
|
1388
|
+
print(f" Loss: {status.get('loss', 'N/A')}")
|
|
1389
|
+
print(f" Elapsed: {status.get('elapsed_time', 0):.1f}s")
|
|
1390
|
+
else:
|
|
1391
|
+
print(" No training log found (training may not have started yet)")
|
|
1392
|
+
|
|
1393
|
+
elif args.command == "checkpoints":
|
|
1394
|
+
# List checkpoints on remote instance
|
|
1395
|
+
instances = client.list_instances()
|
|
1396
|
+
if not instances:
|
|
1397
|
+
print("No running instances.")
|
|
1398
|
+
return
|
|
1399
|
+
|
|
1400
|
+
if args.instance_id:
|
|
1401
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
1402
|
+
if not instance:
|
|
1403
|
+
print(f"Instance {args.instance_id} not found.")
|
|
1404
|
+
return
|
|
1405
|
+
else:
|
|
1406
|
+
instance = instances[0]
|
|
1407
|
+
|
|
1408
|
+
print(f"Checking checkpoints on {instance.ip}...")
|
|
1409
|
+
|
|
1410
|
+
ssh_cmd = [
|
|
1411
|
+
"ssh", "-o", "StrictHostKeyChecking=no", "-o", "ConnectTimeout=10",
|
|
1412
|
+
f"ubuntu@{instance.ip}",
|
|
1413
|
+
"ls -la ~/openadapt-ml/checkpoints/ 2>/dev/null && "
|
|
1414
|
+
"du -sh ~/openadapt-ml/checkpoints/ 2>/dev/null || echo 'No checkpoints directory found'"
|
|
1415
|
+
]
|
|
1416
|
+
|
|
1417
|
+
result = subprocess.run(ssh_cmd, capture_output=True, text=True)
|
|
1418
|
+
if result.returncode == 0:
|
|
1419
|
+
print(result.stdout)
|
|
1420
|
+
else:
|
|
1421
|
+
print("No checkpoints found yet")
|
|
1422
|
+
if result.stderr:
|
|
1423
|
+
print(f" Error: {result.stderr}")
|
|
1424
|
+
|
|
1425
|
+
elif args.command == "refresh":
|
|
1426
|
+
# One-shot dashboard refresh
|
|
1427
|
+
import time as time_module
|
|
1428
|
+
from pathlib import Path
|
|
1429
|
+
from openadapt_ml.training.trainer import TrainingState, TrainingConfig, generate_training_dashboard
|
|
1430
|
+
|
|
1431
|
+
instances = client.list_instances()
|
|
1432
|
+
if not instances:
|
|
1433
|
+
print("No running instances.")
|
|
1434
|
+
return
|
|
1435
|
+
|
|
1436
|
+
if args.instance_id:
|
|
1437
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
1438
|
+
if not instance:
|
|
1439
|
+
print(f"Instance {args.instance_id} not found.")
|
|
1440
|
+
return
|
|
1441
|
+
else:
|
|
1442
|
+
instance = instances[0]
|
|
1443
|
+
|
|
1444
|
+
# Use current job directory via symlink
|
|
1445
|
+
from openadapt_ml.training.trainer import get_current_job_directory, setup_job_directory
|
|
1446
|
+
base_dir = Path("training_output")
|
|
1447
|
+
base_dir.mkdir(exist_ok=True)
|
|
1448
|
+
|
|
1449
|
+
status = client.get_training_status(instance)
|
|
1450
|
+
|
|
1451
|
+
if status and status.get("step", 0) > 0:
|
|
1452
|
+
# Get or create job directory based on remote job_id
|
|
1453
|
+
remote_job_id = status.get("job_id", "")
|
|
1454
|
+
if remote_job_id:
|
|
1455
|
+
output_dir = setup_job_directory(base_dir, remote_job_id)
|
|
1456
|
+
else:
|
|
1457
|
+
output_dir = get_current_job_directory(base_dir) or base_dir
|
|
1458
|
+
dashboard_path = output_dir / "dashboard.html"
|
|
1459
|
+
log_path = output_dir / "training_log.json"
|
|
1460
|
+
|
|
1461
|
+
# Setup screenshots symlink if local capture path provided
|
|
1462
|
+
local_capture = args.capture if hasattr(args, 'capture') and args.capture else None
|
|
1463
|
+
if local_capture:
|
|
1464
|
+
setup_capture_screenshots_symlink(output_dir, local_capture)
|
|
1465
|
+
|
|
1466
|
+
# Rewrite evaluation paths from Lambda to relative
|
|
1467
|
+
if "evaluations" in status:
|
|
1468
|
+
status["evaluations"] = rewrite_evaluation_paths(status["evaluations"])
|
|
1469
|
+
|
|
1470
|
+
# Ensure instance metadata is present
|
|
1471
|
+
status["instance_ip"] = instance.ip
|
|
1472
|
+
status["instance_type"] = instance.instance_type
|
|
1473
|
+
status["total_epochs"] = status.get("total_epochs", 5)
|
|
1474
|
+
|
|
1475
|
+
# Save log
|
|
1476
|
+
log_path.write_text(json.dumps(status, indent=2))
|
|
1477
|
+
|
|
1478
|
+
# Generate dashboard
|
|
1479
|
+
state = TrainingState(job_id=remote_job_id)
|
|
1480
|
+
state.job_id = remote_job_id
|
|
1481
|
+
state.hostname = status.get("hostname", "lambda")
|
|
1482
|
+
state.instance_ip = instance.ip or ""
|
|
1483
|
+
state.instance_type = instance.instance_type
|
|
1484
|
+
state.config_path = status.get("config_path", "")
|
|
1485
|
+
# Use local capture path for screenshots if provided, else remote path
|
|
1486
|
+
state.capture_path = args.capture if args.capture else status.get("capture_path", "")
|
|
1487
|
+
state.epoch = status.get("epoch", 0)
|
|
1488
|
+
state.step = status.get("step", 0)
|
|
1489
|
+
state.loss = status.get("loss", 0)
|
|
1490
|
+
state.learning_rate = status.get("learning_rate", 5e-5)
|
|
1491
|
+
state.losses = status.get("losses", [])
|
|
1492
|
+
state.evaluations = status.get("evaluations", [])
|
|
1493
|
+
state.total_epochs = status.get("total_epochs", 5)
|
|
1494
|
+
state.start_time = time_module.time() - status.get("elapsed_time", 0)
|
|
1495
|
+
# Cloud provider info
|
|
1496
|
+
state.cloud_provider = "lambda"
|
|
1497
|
+
state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
|
|
1498
|
+
state.cloud_instance_id = instance.id
|
|
1499
|
+
state.setup_status = status.get("setup_status", "training")
|
|
1500
|
+
state.setup_logs = status.get("setup_logs", [])
|
|
1501
|
+
|
|
1502
|
+
config = TrainingConfig(
|
|
1503
|
+
num_train_epochs=status.get("total_epochs", 5),
|
|
1504
|
+
learning_rate=status.get("learning_rate", 5e-5)
|
|
1505
|
+
)
|
|
1506
|
+
|
|
1507
|
+
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
1508
|
+
|
|
1509
|
+
# Regenerate navigation for file:// protocol
|
|
1510
|
+
try:
|
|
1511
|
+
from openadapt_ml.training.trainer import regenerate_all_dashboards
|
|
1512
|
+
regenerate_all_dashboards(output_dir)
|
|
1513
|
+
except Exception:
|
|
1514
|
+
pass # Silent fail for navigation
|
|
1515
|
+
|
|
1516
|
+
epoch = status.get("epoch", 0)
|
|
1517
|
+
step = status.get("step", 0)
|
|
1518
|
+
loss = status.get("loss", 0)
|
|
1519
|
+
elapsed = status.get("elapsed_time", 0)
|
|
1520
|
+
print(f"Epoch {epoch+1}/{state.total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s")
|
|
1521
|
+
print(f"Dashboard: {dashboard_path.absolute()}")
|
|
1522
|
+
|
|
1523
|
+
if args.open:
|
|
1524
|
+
import subprocess as sp
|
|
1525
|
+
sp.run(["open", str(dashboard_path)], capture_output=True)
|
|
1526
|
+
else:
|
|
1527
|
+
print("No training data yet")
|
|
1528
|
+
|
|
1529
|
+
elif args.command == "monitor":
|
|
1530
|
+
# Live dashboard monitoring for Lambda training
|
|
1531
|
+
# Updates training_output/training_log.json so the existing dashboard auto-refreshes
|
|
1532
|
+
import time as time_module
|
|
1533
|
+
from pathlib import Path
|
|
1534
|
+
|
|
1535
|
+
# Stub mode - simulate training without actual GPU
|
|
1536
|
+
if getattr(args, 'stub', False):
|
|
1537
|
+
from openadapt_ml.training.stub_provider import StubTrainingProvider
|
|
1538
|
+
from openadapt_ml.training.trainer import (
|
|
1539
|
+
TrainingState, TrainingConfig, generate_training_dashboard
|
|
1540
|
+
)
|
|
1541
|
+
|
|
1542
|
+
print("\n[Stub Mode] Simulating training without GPU...")
|
|
1543
|
+
output_dir = Path("training_output")
|
|
1544
|
+
output_dir.mkdir(exist_ok=True)
|
|
1545
|
+
|
|
1546
|
+
# Start dashboard server if requested
|
|
1547
|
+
server_proc = None
|
|
1548
|
+
if args.open:
|
|
1549
|
+
server_proc = open_dashboard_in_browser(output_dir)
|
|
1550
|
+
|
|
1551
|
+
# Run stub training
|
|
1552
|
+
stub = StubTrainingProvider(
|
|
1553
|
+
output_dir=output_dir,
|
|
1554
|
+
epochs=5,
|
|
1555
|
+
steps_per_epoch=10,
|
|
1556
|
+
step_delay=0.3, # Fast simulation
|
|
1557
|
+
)
|
|
1558
|
+
|
|
1559
|
+
def update_dashboard(status):
|
|
1560
|
+
"""Regenerate dashboard after each step."""
|
|
1561
|
+
state = TrainingState()
|
|
1562
|
+
state.job_id = status.get("job_id", "")
|
|
1563
|
+
state.hostname = status.get("hostname", "stub")
|
|
1564
|
+
state.instance_ip = "127.0.0.1"
|
|
1565
|
+
state.instance_type = "stub"
|
|
1566
|
+
state.epoch = status.get("epoch", 0)
|
|
1567
|
+
state.step = status.get("step", 0)
|
|
1568
|
+
state.loss = status.get("loss", 0)
|
|
1569
|
+
state.learning_rate = status.get("learning_rate", 5e-5)
|
|
1570
|
+
state.losses = status.get("losses", [])
|
|
1571
|
+
state.evaluations = status.get("evaluations", [])
|
|
1572
|
+
state.cloud_provider = "stub"
|
|
1573
|
+
state.setup_status = "training"
|
|
1574
|
+
|
|
1575
|
+
config = TrainingConfig(
|
|
1576
|
+
num_train_epochs=status.get("total_epochs", 5),
|
|
1577
|
+
learning_rate=state.learning_rate
|
|
1578
|
+
)
|
|
1579
|
+
|
|
1580
|
+
dashboard_path = output_dir / "dashboard.html"
|
|
1581
|
+
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
1582
|
+
|
|
1583
|
+
try:
|
|
1584
|
+
stub.run(callback=update_dashboard)
|
|
1585
|
+
except KeyboardInterrupt:
|
|
1586
|
+
print("\n[Stub] Interrupted by user.")
|
|
1587
|
+
finally:
|
|
1588
|
+
if server_proc:
|
|
1589
|
+
server_proc.terminate()
|
|
1590
|
+
print("[Stub] Dashboard server stopped.")
|
|
1591
|
+
|
|
1592
|
+
print(f"\n[Stub] Results in: {output_dir}")
|
|
1593
|
+
return
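StubTrainingProvider.run(callback=...) is assumed to invoke the callback with a status dict once per simulated step. This toy loop shows the shape of that contract; the synthetic loss curve and field names are assumptions for illustration:

import math
import time

def run_stub_sketch(callback, epochs=5, steps_per_epoch=10, step_delay=0.3):
    losses, step = [], 0
    for epoch in range(epochs):
        for _ in range(steps_per_epoch):
            step += 1
            loss = 2.5 * math.exp(-0.05 * step)  # synthetic decaying loss
            losses.append({"epoch": epoch, "step": step, "loss": loss})
            callback({"epoch": epoch, "step": step, "loss": loss,
                      "losses": losses, "total_epochs": epochs})
            time.sleep(step_delay)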
|
|
1594
|
+
|
|
1595
|
+
instances = client.list_instances()
|
|
1596
|
+
if not instances:
|
|
1597
|
+
print("No running instances.")
|
|
1598
|
+
return
|
|
1599
|
+
|
|
1600
|
+
if args.instance_id:
|
|
1601
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
1602
|
+
if not instance:
|
|
1603
|
+
print(f"Instance {args.instance_id} not found.")
|
|
1604
|
+
return
|
|
1605
|
+
else:
|
|
1606
|
+
instance = instances[0]
|
|
1607
|
+
|
|
1608
|
+
if instance.status == "booting" or not instance.ip:
|
|
1609
|
+
print(f"Instance {instance.id[:8]} is still booting, waiting for IP...")
|
|
1610
|
+
while True:
|
|
1611
|
+
time_module.sleep(5)
|
|
1612
|
+
instances = client.list_instances()
|
|
1613
|
+
instance = next((i for i in instances if i.id == instance.id), None)
|
|
1614
|
+
if not instance:
|
|
1615
|
+
print("Instance terminated or not found.")
|
|
1616
|
+
return
|
|
1617
|
+
if instance.ip and instance.status == "active":
|
|
1618
|
+
print(f"Instance ready at {instance.ip}")
|
|
1619
|
+
break
|
|
1620
|
+
print(f" Status: {instance.status}...")
|
|
1621
|
+
|
|
1622
|
+
# Use job-scoped directory structure
|
|
1623
|
+
from openadapt_ml.training.trainer import (
|
|
1624
|
+
TrainingState, TrainingConfig, generate_training_dashboard,
|
|
1625
|
+
setup_job_directory, get_current_job_directory
|
|
1626
|
+
)
|
|
1627
|
+
base_dir = Path("training_output")
|
|
1628
|
+
base_dir.mkdir(exist_ok=True)
|
|
1629
|
+
|
|
1630
|
+
# Get current job directory or wait for first status to determine job_id
|
|
1631
|
+
output_dir = get_current_job_directory(base_dir) or base_dir
|
|
1632
|
+
dashboard_path = output_dir / "dashboard.html"
|
|
1633
|
+
log_path = output_dir / "training_log.json"
|
|
1634
|
+
|
|
1635
|
+
# Check for existing log with job_id
|
|
1636
|
+
current_job_id = None
|
|
1637
|
+
if log_path.exists():
|
|
1638
|
+
try:
|
|
1639
|
+
existing_log = json.loads(log_path.read_text())
|
|
1640
|
+
current_job_id = existing_log.get("job_id")
|
|
1641
|
+
except (json.JSONDecodeError, IOError):
|
|
1642
|
+
pass
|
|
1643
|
+
|
|
1644
|
+
print(f"\nMonitoring Lambda training on {instance.ip}")
|
|
1645
|
+
print(f"Dashboard: {dashboard_path.absolute()}")
|
|
1646
|
+
print(f"Polling every {args.interval}s (Ctrl+C to stop)\n")
|
|
1647
|
+
|
|
1648
|
+
# Generate initial dashboard if it doesn't exist
|
|
1649
|
+
if not dashboard_path.exists():
|
|
1650
|
+
state = TrainingState(job_id=current_job_id or "")
|
|
1651
|
+
state.cloud_provider = "lambda"
|
|
1652
|
+
state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
|
|
1653
|
+
state.cloud_instance_id = instance.id
|
|
1654
|
+
state.instance_ip = instance.ip or ""
|
|
1655
|
+
state.instance_type = instance.instance_type
|
|
1656
|
+
state.setup_status = "booting"
|
|
1657
|
+
state.setup_logs = ["Starting Lambda Cloud instance...", f"Instance ID: {instance.id[:8]}...", f"Instance type: {instance.instance_type}"]
|
|
1658
|
+
config = TrainingConfig(num_train_epochs=5, learning_rate=5e-5)
|
|
1659
|
+
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
1660
|
+
|
|
1661
|
+
# Open dashboard if requested via HTTP server
|
|
1662
|
+
server_proc = None
|
|
1663
|
+
if args.open:
|
|
1664
|
+
server_proc = open_dashboard_in_browser(output_dir)
|
|
1665
|
+
|
|
1666
|
+
last_step = 0
|
|
1667
|
+
last_epoch = -1
|
|
1668
|
+
auto_stop_loss = getattr(args, 'auto_stop_loss', 0.5)
|
|
1669
|
+
download_checkpoints = getattr(args, 'download_checkpoints', True)
|
|
1670
|
+
step_stall_count = 0 # Track how many times step hasn't increased
|
|
1671
|
+
|
|
1672
|
+
print(f" Auto-stop loss threshold: {auto_stop_loss}")
|
|
1673
|
+
print(f" Checkpoint download: {'enabled' if download_checkpoints else 'disabled'}")
|
|
1674
|
+
|
|
1675
|
+
try:
|
|
1676
|
+
while True:
|
|
1677
|
+
# Check for stop signal from dashboard
|
|
1678
|
+
if check_stop_signal(output_dir):
|
|
1679
|
+
print("\n Stop signal received from dashboard!")
|
|
1680
|
+
print(" Downloading final checkpoints...")
|
|
1681
|
+
if download_checkpoints:
|
|
1682
|
+
download_checkpoints_from_instance(instance.ip, output_dir)
|
|
1683
|
+
|
|
1684
|
+
# Update status with termination info before terminating
|
|
1685
|
+
termination_status = {
|
|
1686
|
+
"termination_status": "user_stop",
|
|
1687
|
+
"termination_message": "Training stopped by user via dashboard"
|
|
1688
|
+
}
|
|
1689
|
+
current_log = log_path.read_text() if log_path.exists() else "{}"
|
|
1690
|
+
import json as json_module
|
|
1691
|
+
current_data = json_module.loads(current_log)
|
|
1692
|
+
current_data.update(termination_status)
|
|
1693
|
+
log_path.write_text(json_module.dumps(current_data, indent=2))
|
|
1694
|
+
|
|
1695
|
+
print(f" Terminating instance {instance.id}...")
|
|
1696
|
+
client.terminate_instance(instance.id)
|
|
1697
|
+
# Remove stop signal
|
|
1698
|
+
(output_dir / "STOP_TRAINING").unlink(missing_ok=True)
|
|
1699
|
+
print(" Training stopped by user.")
|
|
1700
|
+
break
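check_stop_signal() pairs with the serve command further down: the dashboard's /api/stop handler touches a STOP_TRAINING file in the job directory, so the monitor loop only needs to test for its existence. A minimal sketch:

from pathlib import Path

def check_stop_signal_sketch(output_dir: Path) -> bool:
    return (output_dir / "STOP_TRAINING").exists()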
|
|
1701
|
+
|
|
1702
|
+
try:
|
|
1703
|
+
# Fetch training log from remote
|
|
1704
|
+
status = client.get_training_status(instance)
|
|
1705
|
+
|
|
1706
|
+
if status and status.get("step", 0) > 0:
|
|
1707
|
+
step = status.get("step", 0)
|
|
1708
|
+
epoch = status.get("epoch", 0)
|
|
1709
|
+
loss = status.get("loss", 0)
|
|
1710
|
+
elapsed = status.get("elapsed_time", 0)
|
|
1711
|
+
remote_job_id = status.get("job_id")
|
|
1712
|
+
|
|
1713
|
+
# Detect job_id change - clear old data if new job started
|
|
1714
|
+
if remote_job_id and current_job_id and remote_job_id != current_job_id:
|
|
1715
|
+
print(f"\n New job detected: {remote_job_id} (was: {current_job_id})")
|
|
1716
|
+
print(" Clearing old job data...")
|
|
1717
|
+
last_step = 0 # Reset step tracking
|
|
1718
|
+
current_job_id = remote_job_id
|
|
1719
|
+
|
|
1720
|
+
# Update local training log (dashboard polls this file)
|
|
1721
|
+
# Add total_epochs to status for dashboard
|
|
1722
|
+
status["total_epochs"] = status.get("total_epochs", 5)
|
|
1723
|
+
# Ensure instance metadata is present
|
|
1724
|
+
if not status.get("instance_ip"):
|
|
1725
|
+
status["instance_ip"] = instance.ip
|
|
1726
|
+
if not status.get("instance_type"):
|
|
1727
|
+
status["instance_type"] = instance.instance_type
|
|
1728
|
+
# Add cloud provider info
|
|
1729
|
+
status["cloud_provider"] = "lambda"
|
|
1730
|
+
status["cloud_dashboard_url"] = "https://cloud.lambda.ai/instances"
|
|
1731
|
+
status["cloud_instance_id"] = instance.id
|
|
1732
|
+
status["setup_status"] = status.get("setup_status", "training")
|
|
1733
|
+
|
|
1734
|
+
# Setup screenshots symlink if local capture path provided
|
|
1735
|
+
local_capture = args.capture if hasattr(args, 'capture') and args.capture else None
|
|
1736
|
+
if local_capture:
|
|
1737
|
+
setup_capture_screenshots_symlink(output_dir, local_capture)
|
|
1738
|
+
|
|
1739
|
+
# Rewrite evaluation paths from Lambda to relative
|
|
1740
|
+
if "evaluations" in status:
|
|
1741
|
+
status["evaluations"] = rewrite_evaluation_paths(status["evaluations"])
|
|
1742
|
+
|
|
1743
|
+
log_path.write_text(json.dumps(status, indent=2))
|
|
1744
|
+
|
|
1745
|
+
if step > last_step:
|
|
1746
|
+
print(f" Epoch {epoch+1} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s")
|
|
1747
|
+
last_step = step
|
|
1748
|
+
step_stall_count = 0 # Reset stall counter when step increases
|
|
1749
|
+
if not current_job_id:
|
|
1750
|
+
current_job_id = remote_job_id
|
|
1751
|
+
|
|
1752
|
+
# Regenerate dashboard with updated data
|
|
1753
|
+
state = TrainingState()
|
|
1754
|
+
state.job_id = status.get("job_id", "")
|
|
1755
|
+
state.hostname = status.get("hostname", "lambda")
|
|
1756
|
+
state.instance_ip = instance.ip or ""
|
|
1757
|
+
state.instance_type = instance.instance_type
|
|
1758
|
+
state.epoch = epoch
|
|
1759
|
+
state.step = step
|
|
1760
|
+
state.loss = loss
|
|
1761
|
+
state.learning_rate = status.get("learning_rate", 5e-5)
|
|
1762
|
+
state.losses = status.get("losses", [])
|
|
1763
|
+
state.evaluations = status.get("evaluations", [])
|
|
1764
|
+
state.start_time = time_module.time() - elapsed
|
|
1765
|
+
# Cloud provider info
|
|
1766
|
+
state.cloud_provider = "lambda"
|
|
1767
|
+
state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
|
|
1768
|
+
state.cloud_instance_id = instance.id
|
|
1769
|
+
state.setup_status = status.get("setup_status", "training")
|
|
1770
|
+
state.setup_logs = status.get("setup_logs", [])
|
|
1771
|
+
state.termination_status = status.get("termination_status", "")
|
|
1772
|
+
state.termination_message = status.get("termination_message", "")
|
|
1773
|
+
|
|
1774
|
+
config = TrainingConfig(
|
|
1775
|
+
num_train_epochs=status.get("total_epochs", 5),
|
|
1776
|
+
learning_rate=status.get("learning_rate", 5e-5)
|
|
1777
|
+
)
|
|
1778
|
+
|
|
1779
|
+
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
1780
|
+
|
|
1781
|
+
# Download checkpoints on epoch change
|
|
1782
|
+
if download_checkpoints and epoch > last_epoch:
|
|
1783
|
+
print(f" Epoch {epoch+1} completed - downloading checkpoints...")
|
|
1784
|
+
if download_checkpoints_from_instance(instance.ip, output_dir):
|
|
1785
|
+
print(f" Checkpoints saved to {output_dir}/checkpoints/")
|
|
1786
|
+
else:
|
|
1787
|
+
print(" Warning: checkpoint download failed")
|
|
1788
|
+
last_epoch = epoch
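download_checkpoints_from_instance() is invoked at every epoch boundary and before termination; a plausible sketch is an rsync of the remote checkpoints directory into the local job directory (paths and flags here are assumptions, not the function's actual implementation):

import subprocess
from pathlib import Path

def download_checkpoints_sketch(ip, output_dir):
    local = Path(output_dir) / "checkpoints"
    local.mkdir(parents=True, exist_ok=True)
    result = subprocess.run(
        ["rsync", "-az", "-e", "ssh -o StrictHostKeyChecking=no",
         f"ubuntu@{ip}:~/openadapt-ml/checkpoints/", f"{local}/"],
        capture_output=True,
    )
    return result.returncode == 0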
|
|
1789
|
+
|
|
1790
|
+
# Auto-terminate when loss is low enough
|
|
1791
|
+
if loss < auto_stop_loss and loss > 0:
|
|
1792
|
+
print(f"\n Loss {loss:.4f} < threshold {auto_stop_loss}")
|
|
1793
|
+
print(" Downloading final checkpoints...")
|
|
1794
|
+
if download_checkpoints:
|
|
1795
|
+
download_checkpoints_from_instance(instance.ip, output_dir)
|
|
1796
|
+
|
|
1797
|
+
# Update status with termination info
|
|
1798
|
+
status["termination_status"] = "auto_low_loss"
|
|
1799
|
+
status["termination_message"] = f"Training auto-stopped: loss {loss:.4f} < threshold {auto_stop_loss}"
|
|
1800
|
+
log_path.write_text(json.dumps(status, indent=2))
|
|
1801
|
+
|
|
1802
|
+
print(f" Auto-terminating instance {instance.id}...")
|
|
1803
|
+
client.terminate_instance(instance.id)
|
|
1804
|
+
print(" Training completed (auto-stopped)!")
|
|
1805
|
+
break
|
|
1806
|
+
else:
|
|
1807
|
+
# Step didn't increase - check if training is complete
|
|
1808
|
+
step_stall_count += 1
|
|
1809
|
+
total_epochs = status.get("total_epochs", 5)
|
|
1810
|
+
|
|
1811
|
+
# If on last epoch and step hasn't increased for 3 polls, training is complete
|
|
1812
|
+
if epoch >= total_epochs - 1 and step_stall_count >= 3:
|
|
1813
|
+
print(f"\n Training complete (epoch {epoch+1}/{total_epochs}, step stopped increasing)")
|
|
1814
|
+
print(" Downloading final checkpoints...")
|
|
1815
|
+
if download_checkpoints:
|
|
1816
|
+
download_checkpoints_from_instance(instance.ip, output_dir)
|
|
1817
|
+
|
|
1818
|
+
# Update status with termination info
|
|
1819
|
+
status["termination_status"] = "auto_complete"
|
|
1820
|
+
status["termination_message"] = f"Training completed successfully ({epoch+1}/{total_epochs} epochs)"
|
|
1821
|
+
log_path.write_text(json.dumps(status, indent=2))
|
|
1822
|
+
|
|
1823
|
+
print(f" Terminating instance {instance.id}...")
|
|
1824
|
+
client.terminate_instance(instance.id)
|
|
1825
|
+
print(" Instance terminated.")
|
|
1826
|
+
break
|
|
1827
|
+
|
|
1828
|
+
else:
|
|
1829
|
+
print(" Waiting for training to start...")
|
|
1830
|
+
|
|
1831
|
+
except Exception as e:
|
|
1832
|
+
print(f" Poll error: {e}")
|
|
1833
|
+
|
|
1834
|
+
time_module.sleep(args.interval)
|
|
1835
|
+
|
|
1836
|
+
except KeyboardInterrupt:
|
|
1837
|
+
print("\n\nMonitoring stopped.")
|
|
1838
|
+
print(f"Dashboard: {dashboard_path.absolute()}")
|
|
1839
|
+
finally:
|
|
1840
|
+
# Clean up HTTP server if running
|
|
1841
|
+
if server_proc:
|
|
1842
|
+
server_proc.terminate()
|
|
1843
|
+
print("Dashboard server stopped.")
|
|
1844
|
+
|
|
1845
|
+
elif args.command == "files":
|
|
1846
|
+
instances = client.list_instances()
|
|
1847
|
+
if not instances:
|
|
1848
|
+
print("No running instances.")
|
|
1849
|
+
return
|
|
1850
|
+
|
|
1851
|
+
if args.instance_id:
|
|
1852
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
1853
|
+
if not instance:
|
|
1854
|
+
print(f"Instance {args.instance_id} not found.")
|
|
1855
|
+
return
|
|
1856
|
+
else:
|
|
1857
|
+
instance = instances[0]
|
|
1858
|
+
|
|
1859
|
+
print(f"Files on {instance.ip} at {args.path}:")
|
|
1860
|
+
result = client.ssh_run(instance, f"find {args.path} -type f -name '*.pt' -o -name '*.json' -o -name '*.bin' 2>/dev/null | head -20", timeout=30)
|
|
1861
|
+
if result.stdout:
|
|
1862
|
+
for line in result.stdout.strip().split('\n'):
|
|
1863
|
+
print(f" {line}")
|
|
1864
|
+
else:
|
|
1865
|
+
print(" (no checkpoint files found)")
|
|
1866
|
+
|
|
1867
|
+
elif args.command == "kill":
|
|
1868
|
+
# Kill training/inference processes
|
|
1869
|
+
instances = client.list_instances()
|
|
1870
|
+
if not instances:
|
|
1871
|
+
print("No running instances.")
|
|
1872
|
+
if args.local:
|
|
1873
|
+
print("\nKilling local Lambda-related processes...")
|
|
1874
|
+
subprocess.run(
|
|
1875
|
+
["pkill", "-f", "ssh.*ubuntu@.*openadapt"],
|
|
1876
|
+
capture_output=True
|
|
1877
|
+
)
|
|
1878
|
+
subprocess.run(
|
|
1879
|
+
["pkill", "-f", "lambda_labs"],
|
|
1880
|
+
capture_output=True
|
|
1881
|
+
)
|
|
1882
|
+
print("Done.")
|
|
1883
|
+
return
|
|
1884
|
+
|
|
1885
|
+
if args.instance_id:
|
|
1886
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
1887
|
+
if not instance:
|
|
1888
|
+
print(f"Instance {args.instance_id} not found.")
|
|
1889
|
+
return
|
|
1890
|
+
else:
|
|
1891
|
+
instance = instances[0]
|
|
1892
|
+
|
|
1893
|
+
print(f"Checking processes on {instance.ip}...")
|
|
1894
|
+
|
|
1895
|
+
# List Python processes first
|
|
1896
|
+
result = client.ssh_run(
|
|
1897
|
+
instance,
|
|
1898
|
+
"ps aux | grep python | grep -v grep | grep -v jupyter",
|
|
1899
|
+
timeout=30
|
|
1900
|
+
)
|
|
1901
|
+
if result.stdout.strip():
|
|
1902
|
+
print("Found Python processes:")
|
|
1903
|
+
for line in result.stdout.strip().split('\n'):
|
|
1904
|
+
print(f" {line[:100]}...")
|
|
1905
|
+
else:
|
|
1906
|
+
print("No training/inference Python processes found.")
|
|
1907
|
+
return
|
|
1908
|
+
|
|
1909
|
+
if args.all:
|
|
1910
|
+
print("\nKilling ALL Python processes (except jupyter)...")
|
|
1911
|
+
cmd = "pkill -f 'python.*train\\|python.*compare\\|python.*openadapt' || true"
|
|
1912
|
+
else:
|
|
1913
|
+
print("\nKilling training and inference processes...")
|
|
1914
|
+
cmd = "pkill -f 'python.*train' ; pkill -f 'python.*compare' || true"
|
|
1915
|
+
|
|
1916
|
+
result = client.ssh_run(instance, cmd, timeout=30)
|
|
1917
|
+
print("Remote processes killed.")
|
|
1918
|
+
|
|
1919
|
+
if args.local:
|
|
1920
|
+
print("\nKilling local Lambda-related processes...")
|
|
1921
|
+
subprocess.run(
|
|
1922
|
+
["pkill", "-f", "ssh.*ubuntu@.*openadapt"],
|
|
1923
|
+
capture_output=True
|
|
1924
|
+
)
|
|
1925
|
+
subprocess.run(
|
|
1926
|
+
["pkill", "-f", "lambda_labs.*train"],
|
|
1927
|
+
capture_output=True
|
|
1928
|
+
)
|
|
1929
|
+
print("Local processes killed.")
|
|
1930
|
+
|
|
1931
|
+
print("\nDone. Current status:")
|
|
1932
|
+
result = client.ssh_run(
|
|
1933
|
+
instance,
|
|
1934
|
+
"ps aux | grep python | grep -v grep | grep -v jupyter | wc -l",
|
|
1935
|
+
timeout=30
|
|
1936
|
+
)
|
|
1937
|
+
count = result.stdout.strip()
|
|
1938
|
+
print(f" {count} Python processes remaining on instance")
|
|
1939
|
+
|
|
1940
|
+
elif args.command == "check":
|
|
1941
|
+
# Analyze training status and early stopping
|
|
1942
|
+
instances = client.list_instances()
|
|
1943
|
+
if not instances:
|
|
1944
|
+
print("No running instances.")
|
|
1945
|
+
return
|
|
1946
|
+
|
|
1947
|
+
if args.instance_id:
|
|
1948
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
1949
|
+
if not instance:
|
|
1950
|
+
print(f"Instance {args.instance_id} not found.")
|
|
1951
|
+
return
|
|
1952
|
+
else:
|
|
1953
|
+
instance = instances[0]
|
|
1954
|
+
|
|
1955
|
+
print(f"Checking training on {instance.ip}...")
|
|
1956
|
+
|
|
1957
|
+
# Get training log
|
|
1958
|
+
result = client.ssh_run(
|
|
1959
|
+
instance,
|
|
1960
|
+
"cat ~/openadapt-ml/training_output/training_log.json 2>/dev/null",
|
|
1961
|
+
timeout=30
|
|
1962
|
+
)
|
|
1963
|
+
|
|
1964
|
+
if not result.stdout.strip():
|
|
1965
|
+
print("No training log found.")
|
|
1966
|
+
return
|
|
1967
|
+
|
|
1968
|
+
try:
|
|
1969
|
+
data = json.loads(result.stdout)
|
|
1970
|
+
losses = data.get("losses", [])
|
|
1971
|
+
except json.JSONDecodeError:
|
|
1972
|
+
print("Could not parse training log.")
|
|
1973
|
+
return
|
|
1974
|
+
|
|
1975
|
+
if not losses:
|
|
1976
|
+
print("No training data yet.")
|
|
1977
|
+
return
|
|
1978
|
+
|
|
1979
|
+
total_steps = len(losses)
|
|
1980
|
+
epochs = sorted(set(l["epoch"] for l in losses))
|
|
1981
|
+
total_epochs = data.get("total_epochs", 5)
|
|
1982
|
+
min_loss = min(l["loss"] for l in losses)
|
|
1983
|
+
current_loss = losses[-1]["loss"]
|
|
1984
|
+
|
|
1985
|
+
print(f"\n{'='*50}")
|
|
1986
|
+
print(f"TRAINING STATUS")
|
|
1987
|
+
print(f"{'='*50}")
|
|
1988
|
+
print(f"Steps: {total_steps}")
|
|
1989
|
+
print(f"Epochs: {max(epochs)+1}/{total_epochs}")
|
|
1990
|
+
print(f"Current loss: {current_loss:.4f}")
|
|
1991
|
+
print(f"Min loss: {min_loss:.4f}")
|
|
1992
|
+
|
|
1993
|
+
# Check if training is running
|
|
1994
|
+
proc_result = client.ssh_run(
|
|
1995
|
+
instance,
|
|
1996
|
+
"ps aux | grep 'python.*train' | grep -v grep | wc -l",
|
|
1997
|
+
timeout=30
|
|
1998
|
+
)
|
|
1999
|
+
is_running = int(proc_result.stdout.strip()) > 0
|
|
2000
|
+
|
|
2001
|
+
if is_running:
|
|
2002
|
+
print(f"Status: RUNNING")
|
|
2003
|
+
else:
|
|
2004
|
+
print(f"Status: STOPPED")
|
|
2005
|
+
|
|
2006
|
+
# Early stopping analysis
|
|
2007
|
+
window = min(args.window, len(losses))
|
|
2008
|
+
if window < 2:
|
|
2009
|
+
print("\nNot enough data for early stopping analysis.")
|
|
2010
|
+
else:
|
|
2011
|
+
recent_losses = [l["loss"] for l in losses[-window:]]
|
|
2012
|
+
older_losses = [l["loss"] for l in losses[-window*2:-window]] if len(losses) >= window*2 else [l["loss"] for l in losses[:window]]
|
|
2013
|
+
|
|
2014
|
+
recent_avg = sum(recent_losses) / len(recent_losses)
|
|
2015
|
+
older_avg = sum(older_losses) / len(older_losses) if older_losses else recent_avg
|
|
2016
|
+
|
|
2017
|
+
improvement = (older_avg - recent_avg) / older_avg if older_avg > 0 else 0
|
|
2018
|
+
loss_variance = max(recent_losses) - min(recent_losses)
|
|
2019
|
+
|
|
2020
|
+
print(f"\n{'='*50}")
|
|
2021
|
+
print(f"EARLY STOPPING ANALYSIS (window={window})")
|
|
2022
|
+
print(f"{'='*50}")
|
|
2023
|
+
print(f"Recent avg loss: {recent_avg:.4f}")
|
|
2024
|
+
print(f"Prior avg loss: {older_avg:.4f}")
|
|
2025
|
+
print(f"Improvement: {improvement*100:.2f}%")
|
|
2026
|
+
print(f"Loss variance: {loss_variance:.4f}")
|
|
2027
|
+
|
|
2028
|
+
should_stop = improvement < args.threshold and loss_variance < 0.1
|
|
2029
|
+
if should_stop:
|
|
2030
|
+
print(f"\n⚠️ EARLY STOPPING RECOMMENDED")
|
|
2031
|
+
print(f" Loss has plateaued (improvement < {args.threshold*100}%)")
|
|
2032
|
+
if not is_running:
|
|
2033
|
+
print(f" (Training already stopped)")
|
|
2034
|
+
else:
|
|
2035
|
+
print(f"\n To stop: uv run python -m openadapt_ml.cloud.lambda_labs kill")
|
|
2036
|
+
else:
|
|
2037
|
+
print(f"\n✓ Training still improving, continue.")
|
|
2038
|
+
|
|
2039
|
+
# Time estimate
|
|
2040
|
+
if is_running and len(losses) >= 2:
|
|
2041
|
+
avg_time_per_step = losses[-1].get("time", 0) / len(losses) if losses[-1].get("time") else 50
|
|
2042
|
+
steps_per_epoch = len(losses) / (max(epochs) + 1)
|
|
2043
|
+
remaining_epochs = total_epochs - max(epochs) - 1
|
|
2044
|
+
remaining_steps = remaining_epochs * steps_per_epoch
|
|
2045
|
+
eta_seconds = remaining_steps * avg_time_per_step
|
|
2046
|
+
eta_mins = eta_seconds / 60
|
|
2047
|
+
|
|
2048
|
+
print(f"\n{'='*50}")
|
|
2049
|
+
print(f"TIME ESTIMATE")
|
|
2050
|
+
print(f"{'='*50}")
|
|
2051
|
+
print(f"Remaining epochs: {remaining_epochs}")
|
|
2052
|
+
print(f"Est. remaining steps: {remaining_steps:.0f}")
|
|
2053
|
+
print(f"ETA: {eta_mins:.1f} minutes")
|
|
2054
|
+
|
|
2055
|
+
elif args.command == "compare":
|
|
2056
|
+
# Run comparison on Lambda and sync back
|
|
2057
|
+
instances = client.list_instances()
|
|
2058
|
+
if not instances:
|
|
2059
|
+
print("No running instances.")
|
|
2060
|
+
return
|
|
2061
|
+
|
|
2062
|
+
if args.instance_id:
|
|
2063
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
2064
|
+
if not instance:
|
|
2065
|
+
print(f"Instance {args.instance_id} not found.")
|
|
2066
|
+
return
|
|
2067
|
+
else:
|
|
2068
|
+
instance = instances[0]
|
|
2069
|
+
|
|
2070
|
+
# Determine checkpoint to use
|
|
2071
|
+
if args.checkpoint:
|
|
2072
|
+
checkpoint_path = args.checkpoint
|
|
2073
|
+
elif args.epoch is not None:
|
|
2074
|
+
checkpoint_path = f"/home/ubuntu/openadapt-ml/checkpoints/epoch_{args.epoch}"
|
|
2075
|
+
else:
|
|
2076
|
+
# Use latest (main checkpoint)
|
|
2077
|
+
checkpoint_path = "/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
|
|
2078
|
+
|
|
2079
|
+
# Check if checkpoint exists
|
|
2080
|
+
result = client.ssh_run(
|
|
2081
|
+
instance,
|
|
2082
|
+
f"ls {checkpoint_path}/adapter_config.json 2>/dev/null && echo 'exists'",
|
|
2083
|
+
timeout=30
|
|
2084
|
+
)
|
|
2085
|
+
if "exists" not in result.stdout:
|
|
2086
|
+
print(f"Checkpoint not found at {checkpoint_path}")
|
|
2087
|
+
# List available checkpoints
|
|
2088
|
+
result = client.ssh_run(
|
|
2089
|
+
instance,
|
|
2090
|
+
"ls -la ~/openadapt-ml/checkpoints/",
|
|
2091
|
+
timeout=30
|
|
2092
|
+
)
|
|
2093
|
+
print(f"Available checkpoints:\n{result.stdout}")
|
|
2094
|
+
return
|
|
2095
|
+
|
|
2096
|
+
print(f"Running comparison on {instance.ip}...")
|
|
2097
|
+
print(f"Using checkpoint: {checkpoint_path}")
|
|
2098
|
+
|
|
2099
|
+
# Run comparison on Lambda
|
|
2100
|
+
output_name = f"comparison_{time.strftime('%H%M%S')}.html"
|
|
2101
|
+
cmd = f"""cd ~/openadapt-ml && source .venv/bin/activate && \
|
|
2102
|
+
python -m openadapt_ml.scripts.compare \
|
|
2103
|
+
--capture ~/capture \
|
|
2104
|
+
--checkpoint {checkpoint_path} \
|
|
2105
|
+
--output training_output/{output_name} 2>&1"""
|
|
2106
|
+
|
|
2107
|
+
print("Generating predictions (this may take a few minutes)...")
|
|
2108
|
+
result = client.ssh_run(instance, cmd, timeout=600)
|
|
2109
|
+
|
|
2110
|
+
if result.returncode != 0:
|
|
2111
|
+
print(f"Comparison failed:\n{result.stderr}")
|
|
2112
|
+
return
|
|
2113
|
+
|
|
2114
|
+
# Check if file was created
|
|
2115
|
+
result = client.ssh_run(
|
|
2116
|
+
instance,
|
|
2117
|
+
f"ls -la ~/openadapt-ml/training_output/{output_name}",
|
|
2118
|
+
timeout=30
|
|
2119
|
+
)
|
|
2120
|
+
if result.returncode != 0:
|
|
2121
|
+
print("Comparison file not created.")
|
|
2122
|
+
return
|
|
2123
|
+
|
|
2124
|
+
print(f"Comparison generated: {output_name}")
|
|
2125
|
+
|
|
2126
|
+
# Sync back to local
|
|
2127
|
+
local_output = Path("training_output") / output_name
|
|
2128
|
+
local_output.parent.mkdir(parents=True, exist_ok=True)
|
|
2129
|
+
|
|
2130
|
+
print(f"Syncing to {local_output}...")
|
|
2131
|
+
subprocess.run([
|
|
2132
|
+
"rsync", "-avz",
|
|
2133
|
+
f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/{output_name}",
|
|
2134
|
+
str(local_output)
|
|
2135
|
+
], capture_output=True)
|
|
2136
|
+
|
|
2137
|
+
print(f"Done! Comparison saved to: {local_output}")
|
|
2138
|
+
|
|
2139
|
+
if args.open:
|
|
2140
|
+
subprocess.run(["open", str(local_output)], capture_output=True)
|
|
2141
|
+
print("Opened in browser.")
|
|
2142
|
+
|
|
2143
|
+
elif args.command == "download":
|
|
2144
|
+
instances = client.list_instances()
|
|
2145
|
+
if not instances:
|
|
2146
|
+
print("No running instances.")
|
|
2147
|
+
return
|
|
2148
|
+
|
|
2149
|
+
if args.instance_id:
|
|
2150
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
2151
|
+
if not instance:
|
|
2152
|
+
print(f"Instance {args.instance_id} not found.")
|
|
2153
|
+
return
|
|
2154
|
+
else:
|
|
2155
|
+
instance = instances[0]
|
|
2156
|
+
|
|
2157
|
+
client.download_results(instance, local_path=args.output)
|
|
2158
|
+
|
|
2159
|
+
elif args.command == "results":
|
|
2160
|
+
# Download results and generate comparison viewer
|
|
2161
|
+
instances = client.list_instances()
|
|
2162
|
+
if not instances:
|
|
2163
|
+
print("No running instances.")
|
|
2164
|
+
return
|
|
2165
|
+
|
|
2166
|
+
if args.instance_id:
|
|
2167
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
2168
|
+
if not instance:
|
|
2169
|
+
print(f"Instance {args.instance_id} not found.")
|
|
2170
|
+
return
|
|
2171
|
+
else:
|
|
2172
|
+
instance = instances[0]
|
|
2173
|
+
|
|
2174
|
+
# Download results
|
|
2175
|
+
print("Step 1: Downloading training results...")
|
|
2176
|
+
client.download_results(instance)
|
|
2177
|
+
|
|
2178
|
+
# Generate comparison viewer
|
|
2179
|
+
print("\nStep 2: Generating comparison viewer...")
|
|
2180
|
+
checkpoint_path = "checkpoints_lambda/qwen3vl2b_capture_lora"
|
|
2181
|
+
|
|
2182
|
+
import subprocess as sp
|
|
2183
|
+
cmd = [
|
|
2184
|
+
"uv", "run", "python", "-m", "openadapt_ml.scripts.compare",
|
|
2185
|
+
"--capture", args.capture,
|
|
2186
|
+
"--checkpoint", checkpoint_path,
|
|
2187
|
+
]
|
|
2188
|
+
if args.goal:
|
|
2189
|
+
cmd.extend(["--goal", args.goal])
|
|
2190
|
+
if args.open:
|
|
2191
|
+
cmd.append("--open")
|
|
2192
|
+
|
|
2193
|
+
result = sp.run(cmd)
|
|
2194
|
+
if result.returncode == 0:
|
|
2195
|
+
print("\nComparison viewer generated!")
|
|
2196
|
+
if not args.open:
|
|
2197
|
+
print(f"Open with: open {args.capture}/comparison.html")
|
|
2198
|
+
else:
|
|
2199
|
+
print("Warning: Failed to generate comparison viewer")
|
|
2200
|
+
|
|
2201
|
+
elif args.command == "serve":
|
|
2202
|
+
# Start web server for live dashboard with stop button support
|
|
2203
|
+
import http.server
|
|
2204
|
+
import socketserver
|
|
2205
|
+
import threading
|
|
2206
|
+
import time as time_module
|
|
2207
|
+
from pathlib import Path
|
|
2208
|
+
|
|
2209
|
+
output_dir = Path(args.output) if hasattr(args, 'output') else Path("training_output")
|
|
2210
|
+
port = args.port
|
|
2211
|
+
|
|
2212
|
+
if not output_dir.exists():
|
|
2213
|
+
print(f"No {output_dir} directory. Run 'refresh' first.")
|
|
2214
|
+
return
|
|
2215
|
+
|
|
2216
|
+
# Define handler with /api/stop support
|
|
2217
|
+
class Handler(http.server.SimpleHTTPRequestHandler):
|
|
2218
|
+
def __init__(self, *args, **kwargs):
|
|
2219
|
+
super().__init__(*args, directory=str(output_dir), **kwargs)
|
|
2220
|
+
|
|
2221
|
+
def do_POST(self):
|
|
2222
|
+
if self.path == '/api/stop':
|
|
2223
|
+
# Create stop signal file
|
|
2224
|
+
stop_file = output_dir / "STOP_TRAINING"
|
|
2225
|
+
stop_file.touch()
|
|
2226
|
+
self.send_response(200)
|
|
2227
|
+
self.send_header('Content-Type', 'application/json')
|
|
2228
|
+
self.send_header('Access-Control-Allow-Origin', '*')
|
|
2229
|
+
self.end_headers()
|
|
2230
|
+
self.wfile.write(b'{"status": "stop signal created"}')
|
|
2231
|
+
print(f" Stop signal created: {stop_file}")
|
|
2232
|
+
else:
|
|
2233
|
+
self.send_error(404)
|
|
2234
|
+
|
|
2235
|
+
def do_OPTIONS(self):
|
|
2236
|
+
# Handle CORS preflight
|
|
2237
|
+
self.send_response(200)
|
|
2238
|
+
self.send_header('Access-Control-Allow-Origin', '*')
|
|
2239
|
+
self.send_header('Access-Control-Allow-Methods', 'POST, OPTIONS')
|
|
2240
|
+
self.send_header('Access-Control-Allow-Headers', 'Content-Type')
|
|
2241
|
+
self.end_headers()
|
|
2242
|
+
|
|
2243
|
+
def log_message(self, format, *args):
|
|
2244
|
+
pass # Suppress log messages
|
|
2245
|
+
|
|
2246
|
+
|
|
2247
|
+
# Start web server
|
|
2248
|
+
with socketserver.TCPServer(("", port), Handler) as httpd:
|
|
2249
|
+
url = f"http://localhost:{port}/dashboard.html"
|
|
2250
|
+
print(f"\nDashboard server started at {url}")
|
|
2251
|
+
print("Press Ctrl+C to stop\n")
|
|
2252
|
+
|
|
2253
|
+
if args.open:
|
|
2254
|
+
subprocess.run(["open", url], capture_output=True)
|
|
2255
|
+
|
|
2256
|
+
try:
|
|
2257
|
+
httpd.serve_forever()
|
|
2258
|
+
except KeyboardInterrupt:
|
|
2259
|
+
print("\nServer stopped.")
|
|
2260
|
+
|
|
2261
|
+
elif args.command == "sync":
|
|
2262
|
+
# Sync training output from Lambda and regenerate navigation for file:// protocol
|
|
2263
|
+
from pathlib import Path
|
|
2264
|
+
from openadapt_ml.training.trainer import (
|
|
2265
|
+
TrainingState, TrainingConfig, generate_training_dashboard,
|
|
2266
|
+
regenerate_all_dashboards
|
|
2267
|
+
)
|
|
2268
|
+
|
|
2269
|
+
instances = client.list_instances()
|
|
2270
|
+
if not instances:
|
|
2271
|
+
print("No running instances.")
|
|
2272
|
+
return
|
|
2273
|
+
|
|
2274
|
+
if args.instance_id:
|
|
2275
|
+
instance = next((i for i in instances if i.id.startswith(args.instance_id)), None)
|
|
2276
|
+
if not instance:
|
|
2277
|
+
print(f"Instance {args.instance_id} not found.")
|
|
2278
|
+
return
|
|
2279
|
+
else:
|
|
2280
|
+
instance = instances[0]
|
|
2281
|
+
|
|
2282
|
+
output_dir = Path(args.output)
|
|
2283
|
+
output_dir.mkdir(exist_ok=True)
|
|
2284
|
+
|
|
2285
|
+
print(f"Syncing training output from {instance.ip}...")
|
|
2286
|
+
|
|
2287
|
+
# Sync all training output files
|
|
2288
|
+
rsync_cmd = [
|
|
2289
|
+
"rsync", "-avz", "--progress",
|
|
2290
|
+
"-e", "ssh -o StrictHostKeyChecking=no",
|
|
2291
|
+
f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/",
|
|
2292
|
+
str(output_dir) + "/"
|
|
2293
|
+
]
|
|
2294
|
+
result = subprocess.run(rsync_cmd, capture_output=False)
|
|
2295
|
+
|
|
2296
|
+
if result.returncode != 0:
|
|
2297
|
+
print("Warning: rsync may have had issues")
|
|
2298
|
+
|
|
2299
|
+
# Update dashboard with instance metadata
|
|
2300
|
+
log_path = output_dir / "training_log.json"
|
|
2301
|
+
dashboard_path = output_dir / "dashboard.html"
|
|
2302
|
+
|
|
2303
|
+
if log_path.exists():
|
|
2304
|
+
try:
|
|
2305
|
+
import time as time_module
|
|
2306
|
+
status = json.loads(log_path.read_text())
|
|
2307
|
+
|
|
2308
|
+
# Update with instance info
|
|
2309
|
+
status["instance_ip"] = instance.ip
|
|
2310
|
+
status["instance_type"] = instance.instance_type
|
|
2311
|
+
status["cloud_provider"] = "lambda"
|
|
2312
|
+
status["cloud_dashboard_url"] = "https://cloud.lambda.ai/instances"
|
|
2313
|
+
status["cloud_instance_id"] = instance.id
|
|
2314
|
+
|
|
2315
|
+
log_path.write_text(json.dumps(status, indent=2))
|
|
2316
|
+
|
|
2317
|
+
# Generate updated dashboard
|
|
2318
|
+
state = TrainingState()
|
|
2319
|
+
state.job_id = status.get("job_id", "")
|
|
2320
|
+
state.hostname = status.get("hostname", "lambda")
|
|
2321
|
+
state.instance_ip = instance.ip or ""
|
|
2322
|
+
state.instance_type = instance.instance_type
|
|
2323
|
+
state.config_path = status.get("config_path", "")
|
|
2324
|
+
state.capture_path = status.get("capture_path", "")
|
|
2325
|
+
state.epoch = status.get("epoch", 0)
|
|
2326
|
+
state.step = status.get("step", 0)
|
|
2327
|
+
state.loss = status.get("loss", 0)
|
|
2328
|
+
state.learning_rate = status.get("learning_rate", 5e-5)
|
|
2329
|
+
state.losses = status.get("losses", [])
|
|
2330
|
+
state.evaluations = status.get("evaluations", [])
|
|
2331
|
+
state.total_epochs = status.get("total_epochs", 5)
|
|
2332
|
+
state.start_time = time_module.time() - status.get("elapsed_time", 0)
|
|
2333
|
+
state.cloud_provider = "lambda"
|
|
2334
|
+
state.cloud_dashboard_url = "https://cloud.lambda.ai/instances"
|
|
2335
|
+
state.cloud_instance_id = instance.id
|
|
2336
|
+
|
|
2337
|
+
config = TrainingConfig(
|
|
2338
|
+
num_train_epochs=status.get("total_epochs", 5),
|
|
2339
|
+
learning_rate=status.get("learning_rate", 5e-5)
|
|
2340
|
+
)
|
|
2341
|
+
|
|
2342
|
+
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
2343
|
+
except Exception as e:
|
|
2344
|
+
print(f"Warning: Could not update dashboard: {e}")
|
|
2345
|
+
|
|
2346
|
+
# Regenerate ALL dashboards with static navigation (for file:// protocol)
|
|
2347
|
+
print("Regenerating navigation links...")
|
|
2348
|
+
try:
|
|
2349
|
+
regenerated = regenerate_all_dashboards(output_dir)
|
|
2350
|
+
print(f" Updated {len(regenerated)} files with static navigation")
|
|
2351
|
+
except Exception as e:
|
|
2352
|
+
print(f"Warning: Navigation regeneration failed: {e}")
|
|
2353
|
+
|
|
2354
|
+
# Summary
|
|
2355
|
+
files = list(output_dir.glob("*.html"))
|
|
2356
|
+
print(f"\nSynced {len(files)} HTML files to {output_dir}/")
|
|
2357
|
+
for f in sorted(files):
|
|
2358
|
+
print(f" - {f.name}")
|
|
2359
|
+
|
|
2360
|
+
print(f"\nDashboard: {dashboard_path.absolute()}")
|
|
2361
|
+
|
|
2362
|
+
if args.open:
|
|
2363
|
+
subprocess.run(["open", str(dashboard_path)], capture_output=True)
|
|
2364
|
+
|
|
2365
|
+
elif args.command == "viewer":
|
|
2366
|
+
# Regenerate and open local viewer (no Lambda required)
|
|
2367
|
+
from pathlib import Path
|
|
2368
|
+
from openadapt_ml.training.trainer import regenerate_all_dashboards
|
|
2369
|
+
import re
|
|
2370
|
+
|
|
2371
|
+
output_dir = Path(args.output)
|
|
2372
|
+
|
|
2373
|
+
if not output_dir.exists():
|
|
2374
|
+
print(f"Error: {output_dir} does not exist")
|
|
2375
|
+
print("Run training or sync first to populate the directory.")
|
|
2376
|
+
return
|
|
2377
|
+
|
|
2378
|
+
if not (output_dir / "training_log.json").exists():
|
|
2379
|
+
print(f"Error: No training_log.json found in {output_dir}")
|
|
2380
|
+
print("This directory doesn't contain training results.")
|
|
2381
|
+
return
|
|
2382
|
+
|
|
2383
|
+
# Auto-link local screenshots if available
|
|
2384
|
+
screenshots_link = output_dir / "screenshots"
|
|
2385
|
+
if not screenshots_link.exists():
|
|
2386
|
+
# Try to find capture ID from training log or predictions
|
|
2387
|
+
try:
|
|
2388
|
+
capture_id = None
|
|
2389
|
+
|
|
2390
|
+
# First try training log
|
|
2391
|
+
log_data = json.loads((output_dir / "training_log.json").read_text())
|
|
2392
|
+
capture_path = log_data.get("capture_path", "")
|
|
2393
|
+
capture_match = re.search(r'capture_(\d+)', capture_path)
|
|
2394
|
+
if capture_match:
|
|
2395
|
+
capture_id = capture_match.group(1)
|
|
2396
|
+
|
|
2397
|
+
# If not found, try predictions JSON files
|
|
2398
|
+
if not capture_id:
|
|
2399
|
+
for pred_file in output_dir.glob("predictions_*.json"):
|
|
2400
|
+
pred_data = json.loads(pred_file.read_text())
|
|
2401
|
+
base_data = pred_data.get("base_data", [])
|
|
2402
|
+
if base_data:
|
|
2403
|
+
image_path = base_data[0].get("image_path", "")
|
|
2404
|
+
capture_match = re.search(r'capture_(\d+)', image_path)
|
|
2405
|
+
if capture_match:
|
|
2406
|
+
capture_id = capture_match.group(1)
|
|
2407
|
+
break
|
|
2408
|
+
|
|
2409
|
+
if capture_id:
|
|
2410
|
+
# Search for local screenshots in openadapt-capture
|
|
2411
|
+
openadapt_capture_dir = Path.home() / "oa" / "src" / "openadapt-capture"
|
|
2412
|
+
if openadapt_capture_dir.exists():
|
|
2413
|
+
for capture_dir in openadapt_capture_dir.iterdir():
|
|
2414
|
+
if capture_dir.is_dir():
|
|
2415
|
+
screenshots_dir = capture_dir / "screenshots"
|
|
2416
|
+
if screenshots_dir.exists():
|
|
2417
|
+
# Check if this capture has our screenshots
|
|
2418
|
+
sample_file = list(screenshots_dir.glob(f"capture_{capture_id}_step_*.png"))
|
|
2419
|
+
if sample_file:
|
|
2420
|
+
print(f"Found local screenshots in {screenshots_dir}")
|
|
2421
|
+
screenshots_link.symlink_to(screenshots_dir)
|
|
2422
|
+
print(f" Linked: {screenshots_link} -> {screenshots_dir}")
|
|
2423
|
+
break
|
|
2424
|
+
except Exception:
|
|
2425
|
+
pass # Silently continue if auto-link fails
|
|
2426
|
+
|
|
2427
|
+
print(f"Regenerating viewer from {output_dir}...")
|
|
2428
|
+
regenerated = regenerate_all_dashboards(output_dir)
|
|
2429
|
+
print(f" Updated {len(regenerated)} files")
|
|
2430
|
+
|
|
2431
|
+
# Show path info
|
|
2432
|
+
if args.dashboard:
|
|
2433
|
+
target = output_dir / "dashboard.html"
|
|
2434
|
+
else:
|
|
2435
|
+
target = output_dir / "viewer.html"
|
|
2436
|
+
|
|
2437
|
+
print(f"\nGenerated: {target.absolute()}")
|
|
2438
|
+
print(f"View with: uv run python -m openadapt_ml.cloud.lambda_labs serve --open")
|
|
2439
|
+
|
|
2440
|
+
if args.open:
|
|
2441
|
+
subprocess.run(["open", str(target)], capture_output=True)
|
|
2442
|
+
|
|
2443
|
+
|
|
2444
|
+
if __name__ == "__main__":
|
|
2445
|
+
main()
|