openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -44,7 +44,9 @@ API_BASE = "https://cloud.lambdalabs.com/api/v1"
|
|
|
44
44
|
DEFAULT_SERVER_PORT = 8765
|
|
45
45
|
|
|
46
46
|
|
|
47
|
-
def start_dashboard_server(
|
|
47
|
+
def start_dashboard_server(
|
|
48
|
+
output_dir: Path, port: int = DEFAULT_SERVER_PORT
|
|
49
|
+
) -> tuple[subprocess.Popen, str]:
|
|
48
50
|
"""Start a background HTTP server for the dashboard.
|
|
49
51
|
|
|
50
52
|
Args:
|
|
@@ -54,8 +56,6 @@ def start_dashboard_server(output_dir: Path, port: int = DEFAULT_SERVER_PORT) ->
|
|
|
54
56
|
Returns:
|
|
55
57
|
(process, url): The server process and the dashboard URL
|
|
56
58
|
"""
|
|
57
|
-
import webbrowser
|
|
58
|
-
import threading
|
|
59
59
|
|
|
60
60
|
# Start simple HTTP server in background thread
|
|
61
61
|
server_proc = subprocess.Popen(
|
|
@@ -96,7 +96,9 @@ def open_dashboard_in_browser(output_dir: Path, port: int = DEFAULT_SERVER_PORT)
|
|
|
96
96
|
return None
|
|
97
97
|
|
|
98
98
|
|
|
99
|
-
def setup_capture_screenshots_symlink(
|
|
99
|
+
def setup_capture_screenshots_symlink(
|
|
100
|
+
output_dir: Path, capture_path: str | Path
|
|
101
|
+
) -> bool:
|
|
100
102
|
"""Create symlink from output_dir/screenshots to capture's screenshots folder.
|
|
101
103
|
|
|
102
104
|
This allows the dashboard to serve screenshots via relative paths.
|
|
@@ -128,7 +130,9 @@ def setup_capture_screenshots_symlink(output_dir: Path, capture_path: str | Path
|
|
|
128
130
|
return False
|
|
129
131
|
|
|
130
132
|
|
|
131
|
-
def rewrite_evaluation_paths(
|
|
133
|
+
def rewrite_evaluation_paths(
|
|
134
|
+
evaluations: list[dict], remote_prefix: str = "/home/ubuntu/capture/"
|
|
135
|
+
) -> list[dict]:
|
|
132
136
|
"""Rewrite Lambda paths in evaluations to relative paths.
|
|
133
137
|
|
|
134
138
|
Converts: /home/ubuntu/capture/screenshots/foo.png -> screenshots/foo.png
|
|
@@ -146,7 +150,9 @@ def rewrite_evaluation_paths(evaluations: list[dict], remote_prefix: str = "/hom
|
|
|
146
150
|
return evaluations
|
|
147
151
|
|
|
148
152
|
|
|
149
|
-
def download_checkpoints_from_instance(
|
|
153
|
+
def download_checkpoints_from_instance(
|
|
154
|
+
instance_ip: str, output_dir: Path, ssh_key: str | None = None
|
|
155
|
+
) -> bool:
|
|
150
156
|
"""Download checkpoints from Lambda instance.
|
|
151
157
|
|
|
152
158
|
Args:
|
|
@@ -161,7 +167,9 @@ def download_checkpoints_from_instance(instance_ip: str, output_dir: Path, ssh_k
|
|
|
161
167
|
checkpoints_dir.mkdir(parents=True, exist_ok=True)
|
|
162
168
|
|
|
163
169
|
ssh_key = ssh_key or str(Path.home() / ".ssh" / "lambda_id_ed25519")
|
|
164
|
-
ssh_opts =
|
|
170
|
+
ssh_opts = (
|
|
171
|
+
f"-o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i {ssh_key}"
|
|
172
|
+
)
|
|
165
173
|
|
|
166
174
|
# Download checkpoints from remote
|
|
167
175
|
remote_path = f"ubuntu@{instance_ip}:~/openadapt-ml/checkpoints/"
|
|
@@ -187,6 +195,7 @@ def check_stop_signal(output_dir: Path) -> bool:
|
|
|
187
195
|
@dataclass
|
|
188
196
|
class InstanceType:
|
|
189
197
|
"""Lambda Labs instance type."""
|
|
198
|
+
|
|
190
199
|
name: str
|
|
191
200
|
price_cents_per_hour: int
|
|
192
201
|
description: str
|
|
@@ -216,6 +225,7 @@ class InstanceType:
|
|
|
216
225
|
@dataclass
|
|
217
226
|
class Instance:
|
|
218
227
|
"""Running Lambda Labs instance."""
|
|
228
|
+
|
|
219
229
|
id: str
|
|
220
230
|
name: str
|
|
221
231
|
instance_type: str
|
|
@@ -236,6 +246,7 @@ class LambdaLabsClient:
|
|
|
236
246
|
# Try provided key, then settings, then env var
|
|
237
247
|
if not api_key:
|
|
238
248
|
from openadapt_ml.config import settings
|
|
249
|
+
|
|
239
250
|
api_key = settings.lambda_api_key or os.environ.get("LAMBDA_API_KEY")
|
|
240
251
|
|
|
241
252
|
self.api_key = api_key
|
|
@@ -268,19 +279,25 @@ class LambdaLabsClient:
|
|
|
268
279
|
|
|
269
280
|
for name, info in data.get("data", {}).items():
|
|
270
281
|
specs = info.get("instance_type", {}).get("specs", {})
|
|
271
|
-
regions = [
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
282
|
+
regions = [
|
|
283
|
+
r["name"] for r in info.get("regions_with_capacity_available", [])
|
|
284
|
+
]
|
|
285
|
+
|
|
286
|
+
types.append(
|
|
287
|
+
InstanceType(
|
|
288
|
+
name=name,
|
|
289
|
+
price_cents_per_hour=info.get("instance_type", {}).get(
|
|
290
|
+
"price_cents_per_hour", 0
|
|
291
|
+
),
|
|
292
|
+
description=info.get("instance_type", {}).get("description", ""),
|
|
293
|
+
gpu_count=specs.get("gpus", 0),
|
|
294
|
+
gpu_type=info.get("instance_type", {}).get("gpu_description", ""),
|
|
295
|
+
vcpus=specs.get("vcpus", 0),
|
|
296
|
+
memory_gb=specs.get("memory_gib", 0),
|
|
297
|
+
storage_gb=specs.get("storage_gib", 0),
|
|
298
|
+
available_regions=regions,
|
|
299
|
+
)
|
|
300
|
+
)
|
|
284
301
|
|
|
285
302
|
# Sort by price
|
|
286
303
|
types.sort(key=lambda t: t.price_cents_per_hour)
|
|
@@ -309,15 +326,17 @@ class LambdaLabsClient:
|
|
|
309
326
|
else:
|
|
310
327
|
ssh_key_names = ssh_keys # Already list of strings
|
|
311
328
|
|
|
312
|
-
instances.append(
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
329
|
+
instances.append(
|
|
330
|
+
Instance(
|
|
331
|
+
id=inst["id"],
|
|
332
|
+
name=inst.get("name", ""),
|
|
333
|
+
instance_type=inst.get("instance_type", {}).get("name", "unknown"),
|
|
334
|
+
status=inst.get("status", "unknown"),
|
|
335
|
+
ip=inst.get("ip"),
|
|
336
|
+
region=inst.get("region", {}).get("name", "unknown"),
|
|
337
|
+
ssh_key_names=ssh_key_names,
|
|
338
|
+
)
|
|
339
|
+
)
|
|
321
340
|
|
|
322
341
|
return instances
|
|
323
342
|
|
|
@@ -393,9 +412,18 @@ class LambdaLabsClient:
|
|
|
393
412
|
for attempt in range(60): # Wait up to 5 minutes for SSH
|
|
394
413
|
try:
|
|
395
414
|
result = subprocess.run(
|
|
396
|
-
[
|
|
397
|
-
|
|
398
|
-
|
|
415
|
+
[
|
|
416
|
+
"ssh",
|
|
417
|
+
"-o",
|
|
418
|
+
"StrictHostKeyChecking=no",
|
|
419
|
+
"-o",
|
|
420
|
+
"ConnectTimeout=10",
|
|
421
|
+
f"ubuntu@{instance.ip}",
|
|
422
|
+
"echo ready",
|
|
423
|
+
],
|
|
424
|
+
capture_output=True,
|
|
425
|
+
text=True,
|
|
426
|
+
timeout=20,
|
|
399
427
|
)
|
|
400
428
|
if result.returncode == 0:
|
|
401
429
|
print("SSH ready!")
|
|
@@ -403,7 +431,7 @@ class LambdaLabsClient:
|
|
|
403
431
|
except subprocess.TimeoutExpired:
|
|
404
432
|
pass
|
|
405
433
|
if attempt % 6 == 5: # Log progress every 30 seconds
|
|
406
|
-
print(f" Still waiting for SSH ({(attempt+1)*5}s elapsed)...")
|
|
434
|
+
print(f" Still waiting for SSH ({(attempt + 1) * 5}s elapsed)...")
|
|
407
435
|
time.sleep(5)
|
|
408
436
|
|
|
409
437
|
print("Warning: SSH may not be ready yet, continuing anyway...")
|
|
@@ -411,7 +439,9 @@ class LambdaLabsClient:
|
|
|
411
439
|
|
|
412
440
|
def terminate_instance(self, instance_id: str) -> bool:
|
|
413
441
|
"""Terminate an instance."""
|
|
414
|
-
data = self._post(
|
|
442
|
+
data = self._post(
|
|
443
|
+
"/instance-operations/terminate", {"instance_ids": [instance_id]}
|
|
444
|
+
)
|
|
415
445
|
terminated = data.get("data", {}).get("terminated_instances", [])
|
|
416
446
|
return any(t.get("id") == instance_id for t in terminated)
|
|
417
447
|
|
|
@@ -421,7 +451,13 @@ class LambdaLabsClient:
|
|
|
421
451
|
return "# Instance IP not yet available"
|
|
422
452
|
return f"ssh {user}@{instance.ip}"
|
|
423
453
|
|
|
424
|
-
def ssh_run(
|
|
454
|
+
def ssh_run(
|
|
455
|
+
self,
|
|
456
|
+
instance: Instance,
|
|
457
|
+
command: str,
|
|
458
|
+
timeout: int | None = None,
|
|
459
|
+
retries: int = 3,
|
|
460
|
+
) -> subprocess.CompletedProcess:
|
|
425
461
|
"""Run a command on an instance via SSH.
|
|
426
462
|
|
|
427
463
|
Args:
|
|
@@ -437,12 +473,17 @@ class LambdaLabsClient:
|
|
|
437
473
|
raise RuntimeError("Instance has no IP address")
|
|
438
474
|
|
|
439
475
|
ssh_cmd = [
|
|
440
|
-
"ssh",
|
|
441
|
-
"-o",
|
|
442
|
-
"
|
|
443
|
-
"-o",
|
|
476
|
+
"ssh",
|
|
477
|
+
"-o",
|
|
478
|
+
"StrictHostKeyChecking=no",
|
|
479
|
+
"-o",
|
|
480
|
+
"ConnectTimeout=30", # Increased from 10
|
|
481
|
+
"-o",
|
|
482
|
+
"ServerAliveInterval=60", # Keep connection alive
|
|
483
|
+
"-o",
|
|
484
|
+
"ServerAliveCountMax=3",
|
|
444
485
|
f"ubuntu@{instance.ip}",
|
|
445
|
-
command
|
|
486
|
+
command,
|
|
446
487
|
]
|
|
447
488
|
|
|
448
489
|
last_error = None
|
|
@@ -462,7 +503,12 @@ class LambdaLabsClient:
|
|
|
462
503
|
|
|
463
504
|
raise last_error if last_error else RuntimeError("SSH failed")
|
|
464
505
|
|
|
465
|
-
def setup_instance(
|
|
506
|
+
def setup_instance(
|
|
507
|
+
self,
|
|
508
|
+
instance: Instance,
|
|
509
|
+
repo_url: str = "https://github.com/OpenAdaptAI/openadapt-ml.git",
|
|
510
|
+
clean_gpu: bool = True,
|
|
511
|
+
) -> bool:
|
|
466
512
|
"""Set up training environment on instance.
|
|
467
513
|
|
|
468
514
|
Clones repo, installs uv, syncs dependencies.
|
|
@@ -475,7 +521,9 @@ class LambdaLabsClient:
|
|
|
475
521
|
if clean_gpu:
|
|
476
522
|
print(" Clearing GPU memory...")
|
|
477
523
|
try:
|
|
478
|
-
self.ssh_run(
|
|
524
|
+
self.ssh_run(
|
|
525
|
+
instance,
|
|
526
|
+
"""
|
|
479
527
|
python3 -c "
|
|
480
528
|
import torch
|
|
481
529
|
if torch.cuda.is_available():
|
|
@@ -485,11 +533,13 @@ if torch.cuda.is_available():
|
|
|
485
533
|
" 2>/dev/null || true
|
|
486
534
|
# Kill any stale python processes using GPU
|
|
487
535
|
pkill -f "python.*train" 2>/dev/null || true
|
|
488
|
-
|
|
536
|
+
""",
|
|
537
|
+
timeout=60,
|
|
538
|
+
)
|
|
489
539
|
except Exception as e:
|
|
490
540
|
print(f" GPU cleanup skipped: {e}")
|
|
491
541
|
|
|
492
|
-
setup_script = f
|
|
542
|
+
setup_script = f"""
|
|
493
543
|
set -e
|
|
494
544
|
cd ~
|
|
495
545
|
|
|
@@ -509,10 +559,12 @@ fi
|
|
|
509
559
|
cd openadapt-ml
|
|
510
560
|
uv sync
|
|
511
561
|
echo "SETUP_COMPLETE"
|
|
512
|
-
|
|
562
|
+
"""
|
|
513
563
|
|
|
514
564
|
try:
|
|
515
|
-
result = self.ssh_run(
|
|
565
|
+
result = self.ssh_run(
|
|
566
|
+
instance, setup_script, timeout=900
|
|
567
|
+
) # 15 min timeout for setup
|
|
516
568
|
|
|
517
569
|
if "SETUP_COMPLETE" in result.stdout:
|
|
518
570
|
print(" Environment ready")
|
|
@@ -528,7 +580,9 @@ echo "SETUP_COMPLETE"
|
|
|
528
580
|
print(f" Setup failed: {e}")
|
|
529
581
|
return False
|
|
530
582
|
|
|
531
|
-
def sync_local_code(
|
|
583
|
+
def sync_local_code(
|
|
584
|
+
self, instance: Instance, local_repo_path: str = ".", retries: int = 3
|
|
585
|
+
) -> bool:
|
|
532
586
|
"""Sync local code changes to remote instance.
|
|
533
587
|
|
|
534
588
|
Uses rsync to push local code, excluding .venv, .git, etc.
|
|
@@ -551,19 +605,30 @@ echo "SETUP_COMPLETE"
|
|
|
551
605
|
ssh_opts = "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -o ServerAliveInterval=60"
|
|
552
606
|
|
|
553
607
|
rsync_cmd = [
|
|
554
|
-
"rsync",
|
|
608
|
+
"rsync",
|
|
609
|
+
"-avz",
|
|
610
|
+
"--progress",
|
|
555
611
|
"--timeout=120", # 2 minute timeout per file
|
|
556
|
-
"--exclude",
|
|
557
|
-
"
|
|
558
|
-
"--exclude",
|
|
559
|
-
"
|
|
560
|
-
"--exclude",
|
|
561
|
-
"
|
|
562
|
-
"--exclude",
|
|
563
|
-
"
|
|
564
|
-
"
|
|
612
|
+
"--exclude",
|
|
613
|
+
".venv",
|
|
614
|
+
"--exclude",
|
|
615
|
+
".git",
|
|
616
|
+
"--exclude",
|
|
617
|
+
"__pycache__",
|
|
618
|
+
"--exclude",
|
|
619
|
+
"*.pyc",
|
|
620
|
+
"--exclude",
|
|
621
|
+
".env",
|
|
622
|
+
"--exclude",
|
|
623
|
+
"training_output",
|
|
624
|
+
"--exclude",
|
|
625
|
+
"checkpoints",
|
|
626
|
+
"--exclude",
|
|
627
|
+
"synthetic*",
|
|
628
|
+
"-e",
|
|
629
|
+
ssh_opts,
|
|
565
630
|
f"{local_repo_path}/",
|
|
566
|
-
f"ubuntu@{instance.ip}:~/openadapt-ml/"
|
|
631
|
+
f"ubuntu@{instance.ip}:~/openadapt-ml/",
|
|
567
632
|
]
|
|
568
633
|
|
|
569
634
|
for attempt in range(retries):
|
|
@@ -577,7 +642,13 @@ echo "SETUP_COMPLETE"
|
|
|
577
642
|
|
|
578
643
|
return False
|
|
579
644
|
|
|
580
|
-
def upload_capture(
|
|
645
|
+
def upload_capture(
|
|
646
|
+
self,
|
|
647
|
+
instance: Instance,
|
|
648
|
+
local_path: str,
|
|
649
|
+
remote_path: str = "~/capture",
|
|
650
|
+
retries: int = 3,
|
|
651
|
+
) -> bool:
|
|
581
652
|
"""Upload a capture directory to instance via rsync.
|
|
582
653
|
|
|
583
654
|
Args:
|
|
@@ -598,11 +669,14 @@ echo "SETUP_COMPLETE"
|
|
|
598
669
|
ssh_opts = "ssh -o StrictHostKeyChecking=no -o ConnectTimeout=30 -o ServerAliveInterval=60"
|
|
599
670
|
|
|
600
671
|
rsync_cmd = [
|
|
601
|
-
"rsync",
|
|
672
|
+
"rsync",
|
|
673
|
+
"-avz",
|
|
674
|
+
"--progress",
|
|
602
675
|
"--timeout=120", # 2 minute timeout per file
|
|
603
|
-
"-e",
|
|
676
|
+
"-e",
|
|
677
|
+
ssh_opts,
|
|
604
678
|
f"{local_path}/",
|
|
605
|
-
f"ubuntu@{instance.ip}:{remote_path}/"
|
|
679
|
+
f"ubuntu@{instance.ip}:{remote_path}/",
|
|
606
680
|
]
|
|
607
681
|
|
|
608
682
|
for attempt in range(retries):
|
|
@@ -646,16 +720,18 @@ echo "SETUP_COMPLETE"
|
|
|
646
720
|
train_cmd += f' --goal "{goal}"'
|
|
647
721
|
|
|
648
722
|
# Full script with environment setup
|
|
649
|
-
script = f
|
|
723
|
+
script = f"""
|
|
650
724
|
cd ~/openadapt-ml
|
|
651
725
|
export PATH="$HOME/.local/bin:$PATH"
|
|
652
726
|
{train_cmd}
|
|
653
|
-
|
|
727
|
+
"""
|
|
654
728
|
|
|
655
729
|
ssh_cmd = [
|
|
656
|
-
"ssh",
|
|
730
|
+
"ssh",
|
|
731
|
+
"-o",
|
|
732
|
+
"StrictHostKeyChecking=no",
|
|
657
733
|
f"ubuntu@{instance.ip}",
|
|
658
|
-
script
|
|
734
|
+
script,
|
|
659
735
|
]
|
|
660
736
|
|
|
661
737
|
print(f"Running training on {instance.ip}...")
|
|
@@ -705,37 +781,42 @@ export PATH="$HOME/.local/bin:$PATH"
|
|
|
705
781
|
if include_logs:
|
|
706
782
|
print(" Downloading training logs...")
|
|
707
783
|
rsync_cmd = [
|
|
708
|
-
"rsync",
|
|
709
|
-
"-
|
|
784
|
+
"rsync",
|
|
785
|
+
"-avz",
|
|
786
|
+
"-e",
|
|
787
|
+
"ssh -o StrictHostKeyChecking=no",
|
|
710
788
|
f"ubuntu@{instance.ip}:{remote_path}/training_output/",
|
|
711
|
-
f"{local_path}/training_output_lambda/"
|
|
789
|
+
f"{local_path}/training_output_lambda/",
|
|
712
790
|
]
|
|
713
791
|
result = subprocess.run(rsync_cmd, capture_output=True)
|
|
714
792
|
if result.returncode == 0:
|
|
715
793
|
print(" Training logs downloaded to training_output_lambda/")
|
|
716
794
|
else:
|
|
717
|
-
print(
|
|
795
|
+
print(" Warning: Failed to download logs")
|
|
718
796
|
success = False
|
|
719
797
|
|
|
720
798
|
# Download checkpoint
|
|
721
799
|
if include_checkpoint:
|
|
722
800
|
print(" Downloading checkpoint...")
|
|
723
801
|
rsync_cmd = [
|
|
724
|
-
"rsync",
|
|
725
|
-
"-
|
|
802
|
+
"rsync",
|
|
803
|
+
"-avz",
|
|
804
|
+
"-e",
|
|
805
|
+
"ssh -o StrictHostKeyChecking=no",
|
|
726
806
|
f"ubuntu@{instance.ip}:{remote_path}/checkpoints/",
|
|
727
|
-
f"{local_path}/checkpoints_lambda/"
|
|
807
|
+
f"{local_path}/checkpoints_lambda/",
|
|
728
808
|
]
|
|
729
809
|
result = subprocess.run(rsync_cmd, capture_output=True)
|
|
730
810
|
if result.returncode == 0:
|
|
731
811
|
print(" Checkpoint downloaded to checkpoints_lambda/")
|
|
732
812
|
else:
|
|
733
|
-
print(
|
|
813
|
+
print(" Warning: Failed to download checkpoint (may not exist yet)")
|
|
734
814
|
|
|
735
815
|
# Regenerate all dashboards with static navigation and correct status
|
|
736
816
|
if include_logs:
|
|
737
817
|
try:
|
|
738
818
|
from openadapt_ml.training.trainer import regenerate_all_dashboards
|
|
819
|
+
|
|
739
820
|
output_dir = Path(local_path) / "training_output_lambda"
|
|
740
821
|
if output_dir.exists():
|
|
741
822
|
print(" Regenerating dashboards with static navigation...")
|
|
@@ -754,8 +835,9 @@ export PATH="$HOME/.local/bin:$PATH"
|
|
|
754
835
|
)
|
|
755
836
|
try:
|
|
756
837
|
import json
|
|
838
|
+
|
|
757
839
|
return json.loads(result.stdout.strip())
|
|
758
|
-
except:
|
|
840
|
+
except Exception:
|
|
759
841
|
return {}
|
|
760
842
|
|
|
761
843
|
|
|
@@ -797,19 +879,22 @@ def main():
|
|
|
797
879
|
subparsers = parser.add_subparsers(dest="command", help="Command")
|
|
798
880
|
|
|
799
881
|
# List instances command
|
|
800
|
-
|
|
882
|
+
subparsers.add_parser("list", help="List available instance types")
|
|
801
883
|
|
|
802
884
|
# Status command
|
|
803
|
-
|
|
885
|
+
subparsers.add_parser("status", help="Show running instances")
|
|
804
886
|
|
|
805
887
|
# Launch command
|
|
806
888
|
launch_parser = subparsers.add_parser("launch", help="Launch a GPU instance")
|
|
807
889
|
launch_parser.add_argument(
|
|
808
|
-
"--type",
|
|
890
|
+
"--type",
|
|
891
|
+
"-t",
|
|
809
892
|
default="gpu_1x_a100",
|
|
810
893
|
help="Instance type (default: gpu_1x_a100)",
|
|
811
894
|
)
|
|
812
|
-
launch_parser.add_argument(
|
|
895
|
+
launch_parser.add_argument(
|
|
896
|
+
"--region", "-r", help="Region (auto-selects if not specified)"
|
|
897
|
+
)
|
|
813
898
|
launch_parser.add_argument("--name", "-n", help="Instance name")
|
|
814
899
|
|
|
815
900
|
# Terminate command
|
|
@@ -817,112 +902,269 @@ def main():
|
|
|
817
902
|
term_parser.add_argument("instance_id", help="Instance ID to terminate")
|
|
818
903
|
|
|
819
904
|
# SSH command - run commands or get interactive shell
|
|
820
|
-
ssh_parser = subparsers.add_parser(
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
ssh_parser.add_argument(
|
|
905
|
+
ssh_parser = subparsers.add_parser(
|
|
906
|
+
"ssh", help="SSH into Lambda instance or run command"
|
|
907
|
+
)
|
|
908
|
+
ssh_parser.add_argument(
|
|
909
|
+
"instance_id", nargs="?", help="Instance ID (uses first if not specified)"
|
|
910
|
+
)
|
|
911
|
+
ssh_parser.add_argument(
|
|
912
|
+
"--cmd", "-c", help="Command to run (opens shell if not specified)"
|
|
913
|
+
)
|
|
914
|
+
ssh_parser.add_argument(
|
|
915
|
+
"--timeout", "-t", type=int, default=60, help="Command timeout in seconds"
|
|
916
|
+
)
|
|
824
917
|
|
|
825
918
|
# Serve command - start dashboard server with stop button support
|
|
826
|
-
serve_parser = subparsers.add_parser(
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
serve_parser.add_argument(
|
|
919
|
+
serve_parser = subparsers.add_parser(
|
|
920
|
+
"serve", help="Start dashboard server with stop button support"
|
|
921
|
+
)
|
|
922
|
+
serve_parser.add_argument(
|
|
923
|
+
"--output",
|
|
924
|
+
"-o",
|
|
925
|
+
default="training_output",
|
|
926
|
+
help="Output directory (default: training_output)",
|
|
927
|
+
)
|
|
928
|
+
serve_parser.add_argument(
|
|
929
|
+
"--port", "-p", type=int, default=8765, help="Port (default: 8765)"
|
|
930
|
+
)
|
|
931
|
+
serve_parser.add_argument(
|
|
932
|
+
"--open", action="store_true", help="Open dashboard in browser"
|
|
933
|
+
)
|
|
830
934
|
|
|
831
935
|
# Rsync command - copy files to/from Lambda instance
|
|
832
|
-
rsync_parser = subparsers.add_parser(
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
rsync_parser.add_argument(
|
|
836
|
-
|
|
936
|
+
rsync_parser = subparsers.add_parser(
|
|
937
|
+
"rsync", help="Rsync files to/from Lambda instance"
|
|
938
|
+
)
|
|
939
|
+
rsync_parser.add_argument(
|
|
940
|
+
"source", help="Source path (prefix with 'remote:' for remote paths)"
|
|
941
|
+
)
|
|
942
|
+
rsync_parser.add_argument(
|
|
943
|
+
"dest", help="Destination path (prefix with 'remote:' for remote paths)"
|
|
944
|
+
)
|
|
945
|
+
rsync_parser.add_argument(
|
|
946
|
+
"instance_id", nargs="?", help="Instance ID (uses first if not specified)"
|
|
947
|
+
)
|
|
948
|
+
rsync_parser.add_argument(
|
|
949
|
+
"--delete", action="store_true", help="Delete extraneous files from dest"
|
|
950
|
+
)
|
|
837
951
|
|
|
838
952
|
# Setup command
|
|
839
|
-
|
|
953
|
+
subparsers.add_parser("setup", help="Set up SSH key for Lambda Labs")
|
|
840
954
|
|
|
841
955
|
# Train command - full automated training pipeline
|
|
842
956
|
train_parser = subparsers.add_parser("train", help="Run training on Lambda GPU")
|
|
843
957
|
train_parser.add_argument("--capture", "-c", help="Local path to capture directory")
|
|
844
958
|
train_parser.add_argument("--goal", "-g", help="Task goal description")
|
|
845
|
-
train_parser.add_argument(
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
train_parser.add_argument(
|
|
959
|
+
train_parser.add_argument(
|
|
960
|
+
"--config",
|
|
961
|
+
default="configs/qwen3vl_capture_4bit.yaml",
|
|
962
|
+
help="Config file (default: 4bit for memory efficiency)",
|
|
963
|
+
)
|
|
964
|
+
train_parser.add_argument(
|
|
965
|
+
"--type", "-t", default="gpu_1x_a10", help="Instance type"
|
|
966
|
+
)
|
|
967
|
+
train_parser.add_argument(
|
|
968
|
+
"--instance", "-i", help="Use existing instance ID instead of launching new"
|
|
969
|
+
)
|
|
970
|
+
train_parser.add_argument(
|
|
971
|
+
"--no-terminate",
|
|
972
|
+
action="store_true",
|
|
973
|
+
help="Don't terminate instance after training",
|
|
974
|
+
)
|
|
975
|
+
train_parser.add_argument(
|
|
976
|
+
"--max-runtime",
|
|
977
|
+
type=int,
|
|
978
|
+
default=60,
|
|
979
|
+
help="Max runtime in minutes before auto-terminate (default: 60)",
|
|
980
|
+
)
|
|
981
|
+
train_parser.add_argument(
|
|
982
|
+
"--open",
|
|
983
|
+
action="store_true",
|
|
984
|
+
help="Open dashboard in browser when training starts",
|
|
985
|
+
)
|
|
851
986
|
|
|
852
987
|
# Training status command
|
|
853
|
-
train_status_parser = subparsers.add_parser(
|
|
988
|
+
train_status_parser = subparsers.add_parser(
|
|
989
|
+
"train-status", help="Check training status on instance"
|
|
990
|
+
)
|
|
854
991
|
train_status_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
855
992
|
|
|
856
993
|
# Monitor command - live dashboard for Lambda training
|
|
857
|
-
monitor_parser = subparsers.add_parser(
|
|
994
|
+
monitor_parser = subparsers.add_parser(
|
|
995
|
+
"monitor", help="Monitor Lambda training with live dashboard"
|
|
996
|
+
)
|
|
858
997
|
monitor_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
859
|
-
monitor_parser.add_argument(
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
monitor_parser.add_argument(
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
monitor_parser.add_argument(
|
|
998
|
+
monitor_parser.add_argument(
|
|
999
|
+
"--open", action="store_true", help="Open dashboard in browser"
|
|
1000
|
+
)
|
|
1001
|
+
monitor_parser.add_argument(
|
|
1002
|
+
"--interval", type=int, default=5, help="Poll interval in seconds (default: 5)"
|
|
1003
|
+
)
|
|
1004
|
+
monitor_parser.add_argument(
|
|
1005
|
+
"--capture", type=str, help="Local capture path for screenshot symlink"
|
|
1006
|
+
)
|
|
1007
|
+
monitor_parser.add_argument(
|
|
1008
|
+
"--auto-stop-loss",
|
|
1009
|
+
type=float,
|
|
1010
|
+
default=0.5,
|
|
1011
|
+
help="Auto-terminate when loss drops below this (default: 0.5)",
|
|
1012
|
+
)
|
|
1013
|
+
monitor_parser.add_argument(
|
|
1014
|
+
"--download-checkpoints",
|
|
1015
|
+
action="store_true",
|
|
1016
|
+
default=True,
|
|
1017
|
+
help="Auto-download checkpoints each epoch",
|
|
1018
|
+
)
|
|
1019
|
+
monitor_parser.add_argument(
|
|
1020
|
+
"--no-download-checkpoints",
|
|
1021
|
+
action="store_false",
|
|
1022
|
+
dest="download_checkpoints",
|
|
1023
|
+
help="Disable checkpoint download",
|
|
1024
|
+
)
|
|
1025
|
+
monitor_parser.add_argument(
|
|
1026
|
+
"--stub",
|
|
1027
|
+
action="store_true",
|
|
1028
|
+
help="Use stub training provider (no GPU, instant simulation)",
|
|
1029
|
+
)
|
|
866
1030
|
|
|
867
1031
|
# Refresh command - one-shot dashboard update
|
|
868
|
-
refresh_parser = subparsers.add_parser(
|
|
1032
|
+
refresh_parser = subparsers.add_parser(
|
|
1033
|
+
"refresh", help="One-shot refresh of training dashboard"
|
|
1034
|
+
)
|
|
869
1035
|
refresh_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
870
|
-
refresh_parser.add_argument(
|
|
871
|
-
|
|
1036
|
+
refresh_parser.add_argument(
|
|
1037
|
+
"--open", action="store_true", help="Open dashboard in browser"
|
|
1038
|
+
)
|
|
1039
|
+
refresh_parser.add_argument(
|
|
1040
|
+
"--capture", type=str, help="Local capture path for screenshot preview"
|
|
1041
|
+
)
|
|
872
1042
|
|
|
873
1043
|
# Checkpoints command - list remote checkpoints
|
|
874
|
-
checkpoints_parser = subparsers.add_parser(
|
|
1044
|
+
checkpoints_parser = subparsers.add_parser(
|
|
1045
|
+
"checkpoints", help="List checkpoints on remote instance"
|
|
1046
|
+
)
|
|
875
1047
|
checkpoints_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
876
1048
|
|
|
877
1049
|
# Download results command
|
|
878
|
-
download_parser = subparsers.add_parser(
|
|
1050
|
+
download_parser = subparsers.add_parser(
|
|
1051
|
+
"download", help="Download training results from instance"
|
|
1052
|
+
)
|
|
879
1053
|
download_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
880
|
-
download_parser.add_argument(
|
|
1054
|
+
download_parser.add_argument(
|
|
1055
|
+
"--output", "-o", default=".", help="Local output directory"
|
|
1056
|
+
)
|
|
881
1057
|
|
|
882
1058
|
# Check files on instance
|
|
883
|
-
files_parser = subparsers.add_parser(
|
|
1059
|
+
files_parser = subparsers.add_parser(
|
|
1060
|
+
"files", help="List training files on instance"
|
|
1061
|
+
)
|
|
884
1062
|
files_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
885
|
-
files_parser.add_argument(
|
|
1063
|
+
files_parser.add_argument(
|
|
1064
|
+
"--path", "-p", default="~/openadapt-ml", help="Path to check"
|
|
1065
|
+
)
|
|
886
1066
|
|
|
887
1067
|
# Kill command - terminate training processes
|
|
888
|
-
kill_parser = subparsers.add_parser(
|
|
1068
|
+
kill_parser = subparsers.add_parser(
|
|
1069
|
+
"kill", help="Kill training/inference processes on instance"
|
|
1070
|
+
)
|
|
889
1071
|
kill_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
890
|
-
kill_parser.add_argument(
|
|
891
|
-
|
|
1072
|
+
kill_parser.add_argument(
|
|
1073
|
+
"--local", action="store_true", help="Also kill local Lambda-related processes"
|
|
1074
|
+
)
|
|
1075
|
+
kill_parser.add_argument(
|
|
1076
|
+
"--all",
|
|
1077
|
+
action="store_true",
|
|
1078
|
+
help="Kill all Python processes on instance (careful!)",
|
|
1079
|
+
)
|
|
892
1080
|
|
|
893
1081
|
# Check command - analyze training status and early stopping
|
|
894
|
-
check_parser = subparsers.add_parser(
|
|
1082
|
+
check_parser = subparsers.add_parser(
|
|
1083
|
+
"check", help="Check training health and early stopping status"
|
|
1084
|
+
)
|
|
895
1085
|
check_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
896
|
-
check_parser.add_argument(
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
1086
|
+
check_parser.add_argument(
|
|
1087
|
+
"--threshold",
|
|
1088
|
+
"-t",
|
|
1089
|
+
type=float,
|
|
1090
|
+
default=0.01,
|
|
1091
|
+
help="Early stopping threshold (loss improvement over last N steps)",
|
|
1092
|
+
)
|
|
1093
|
+
check_parser.add_argument(
|
|
1094
|
+
"--window",
|
|
1095
|
+
"-w",
|
|
1096
|
+
type=int,
|
|
1097
|
+
default=10,
|
|
1098
|
+
help="Number of recent steps to check for improvement",
|
|
1099
|
+
)
|
|
900
1100
|
|
|
901
1101
|
# Compare command - run comparison on Lambda and sync back
|
|
902
|
-
compare_parser = subparsers.add_parser(
|
|
1102
|
+
compare_parser = subparsers.add_parser(
|
|
1103
|
+
"compare", help="Run human vs AI comparison on Lambda"
|
|
1104
|
+
)
|
|
903
1105
|
compare_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
904
|
-
compare_parser.add_argument(
|
|
905
|
-
|
|
906
|
-
|
|
1106
|
+
compare_parser.add_argument(
|
|
1107
|
+
"--checkpoint", "-c", help="Checkpoint to use (default: latest)"
|
|
1108
|
+
)
|
|
1109
|
+
compare_parser.add_argument(
|
|
1110
|
+
"--epoch", "-e", type=int, help="Use checkpoint from specific epoch"
|
|
1111
|
+
)
|
|
1112
|
+
compare_parser.add_argument(
|
|
1113
|
+
"--open", action="store_true", help="Open viewer after generation"
|
|
1114
|
+
)
|
|
907
1115
|
|
|
908
1116
|
# Results viewer command - downloads and generates comparison viewer
|
|
909
|
-
results_parser = subparsers.add_parser(
|
|
910
|
-
|
|
1117
|
+
results_parser = subparsers.add_parser(
|
|
1118
|
+
"results", help="Download results and generate comparison viewer"
|
|
1119
|
+
)
|
|
1120
|
+
results_parser.add_argument(
|
|
1121
|
+
"--capture",
|
|
1122
|
+
"-c",
|
|
1123
|
+
required=True,
|
|
1124
|
+
help="Local capture directory (for comparison)",
|
|
1125
|
+
)
|
|
911
1126
|
results_parser.add_argument("--goal", "-g", help="Task goal description")
|
|
912
|
-
results_parser.add_argument(
|
|
1127
|
+
results_parser.add_argument(
|
|
1128
|
+
"--open", action="store_true", help="Open viewer in browser"
|
|
1129
|
+
)
|
|
913
1130
|
results_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
914
1131
|
|
|
915
1132
|
# Sync command - sync training output and regenerate navigation for file:// protocol
|
|
916
|
-
sync_parser = subparsers.add_parser(
|
|
1133
|
+
sync_parser = subparsers.add_parser(
|
|
1134
|
+
"sync", help="Sync training output from Lambda and regenerate navigation"
|
|
1135
|
+
)
|
|
917
1136
|
sync_parser.add_argument("instance_id", nargs="?", help="Instance ID")
|
|
918
|
-
sync_parser.add_argument(
|
|
919
|
-
|
|
1137
|
+
sync_parser.add_argument(
|
|
1138
|
+
"--output",
|
|
1139
|
+
"-o",
|
|
1140
|
+
default="training_output",
|
|
1141
|
+
help="Local output directory (default: training_output)",
|
|
1142
|
+
)
|
|
1143
|
+
sync_parser.add_argument(
|
|
1144
|
+
"--open", action="store_true", help="Open dashboard in browser after sync"
|
|
1145
|
+
)
|
|
920
1146
|
|
|
921
1147
|
# Viewer command - regenerate local viewer (no Lambda required)
|
|
922
|
-
viewer_parser = subparsers.add_parser(
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
viewer_parser.add_argument(
|
|
1148
|
+
viewer_parser = subparsers.add_parser(
|
|
1149
|
+
"viewer", help="Regenerate local viewer (no Lambda required)"
|
|
1150
|
+
)
|
|
1151
|
+
viewer_parser.add_argument(
|
|
1152
|
+
"--output",
|
|
1153
|
+
"-o",
|
|
1154
|
+
default="training_output",
|
|
1155
|
+
help="Training output directory (default: training_output)",
|
|
1156
|
+
)
|
|
1157
|
+
viewer_parser.add_argument(
|
|
1158
|
+
"--dashboard",
|
|
1159
|
+
"-d",
|
|
1160
|
+
action="store_true",
|
|
1161
|
+
help="Regenerate dashboard instead of viewer",
|
|
1162
|
+
)
|
|
1163
|
+
viewer_parser.add_argument(
|
|
1164
|
+
"--open",
|
|
1165
|
+
action="store_true",
|
|
1166
|
+
help="Open in browser (use 'serve' instead for better experience)",
|
|
1167
|
+
)
|
|
926
1168
|
|
|
927
1169
|
args = parser.parse_args()
|
|
928
1170
|
|
|
@@ -942,10 +1184,11 @@ def main():
|
|
|
942
1184
|
print("Available GPU instances:\n")
|
|
943
1185
|
types = client.list_instance_types()
|
|
944
1186
|
for t in types:
|
|
945
|
-
avail = "available" if t.available_regions else "no capacity"
|
|
946
1187
|
print(f" {t}")
|
|
947
1188
|
print(f"\nTotal: {len(types)} instance types")
|
|
948
|
-
print(
|
|
1189
|
+
print(
|
|
1190
|
+
"\nLaunch with: python -m openadapt_ml.cloud.lambda_labs launch --type <name>"
|
|
1191
|
+
)
|
|
949
1192
|
|
|
950
1193
|
elif args.command == "status":
|
|
951
1194
|
instances = client.list_instances()
|
|
@@ -968,13 +1211,15 @@ def main():
|
|
|
968
1211
|
ssh_key_names=[ssh_key],
|
|
969
1212
|
name=args.name,
|
|
970
1213
|
)
|
|
971
|
-
print(
|
|
1214
|
+
print("\nInstance launched!")
|
|
972
1215
|
print(f" ID: {instance.id}")
|
|
973
1216
|
print(f" IP: {instance.ip}")
|
|
974
1217
|
print(f" Type: {instance.instance_type}")
|
|
975
1218
|
print(f" Region: {instance.region}")
|
|
976
1219
|
print(f"\nConnect with: ssh ubuntu@{instance.ip}")
|
|
977
|
-
print(
|
|
1220
|
+
print(
|
|
1221
|
+
f"\nTerminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}"
|
|
1222
|
+
)
|
|
978
1223
|
|
|
979
1224
|
elif args.command == "terminate":
|
|
980
1225
|
if client.terminate_instance(args.instance_id):
|
|
@@ -989,14 +1234,16 @@ def main():
|
|
|
989
1234
|
return
|
|
990
1235
|
|
|
991
1236
|
if args.instance_id:
|
|
992
|
-
instance = next(
|
|
1237
|
+
instance = next(
|
|
1238
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
1239
|
+
)
|
|
993
1240
|
if not instance:
|
|
994
1241
|
print(f"Instance {args.instance_id} not found.")
|
|
995
1242
|
return
|
|
996
1243
|
else:
|
|
997
1244
|
instance = instances[0]
|
|
998
1245
|
|
|
999
|
-
if hasattr(args,
|
|
1246
|
+
if hasattr(args, "cmd") and args.cmd:
|
|
1000
1247
|
# Run single command
|
|
1001
1248
|
print(f"Running on {instance.ip}: {args.cmd}")
|
|
1002
1249
|
result = client.ssh_run(instance, args.cmd, timeout=args.timeout)
|
|
@@ -1018,7 +1265,9 @@ def main():
|
|
|
1018
1265
|
return
|
|
1019
1266
|
|
|
1020
1267
|
if args.instance_id:
|
|
1021
|
-
instance = next(
|
|
1268
|
+
instance = next(
|
|
1269
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
1270
|
+
)
|
|
1022
1271
|
if not instance:
|
|
1023
1272
|
print(f"Instance {args.instance_id} not found.")
|
|
1024
1273
|
return
|
|
@@ -1035,8 +1284,11 @@ def main():
|
|
|
1035
1284
|
dest = f"ubuntu@{instance.ip}:{dest[7:]}"
|
|
1036
1285
|
|
|
1037
1286
|
rsync_cmd = [
|
|
1038
|
-
"rsync",
|
|
1039
|
-
"-
|
|
1287
|
+
"rsync",
|
|
1288
|
+
"-avz",
|
|
1289
|
+
"--progress",
|
|
1290
|
+
"-e",
|
|
1291
|
+
"ssh -o StrictHostKeyChecking=no",
|
|
1040
1292
|
]
|
|
1041
1293
|
if args.delete:
|
|
1042
1294
|
rsync_cmd.append("--delete")
|
|
@@ -1056,7 +1308,6 @@ def main():
|
|
|
1056
1308
|
|
|
1057
1309
|
instance = None
|
|
1058
1310
|
start_time = time_module.time()
|
|
1059
|
-
launched_new = False
|
|
1060
1311
|
training_completed = False # Track if training actually finished
|
|
1061
1312
|
|
|
1062
1313
|
# Instance pricing (approximate $/hr)
|
|
@@ -1071,7 +1322,9 @@ def main():
|
|
|
1071
1322
|
# Get or launch instance
|
|
1072
1323
|
if args.instance:
|
|
1073
1324
|
instances = client.list_instances()
|
|
1074
|
-
instance = next(
|
|
1325
|
+
instance = next(
|
|
1326
|
+
(i for i in instances if i.id.startswith(args.instance)), None
|
|
1327
|
+
)
|
|
1075
1328
|
if not instance:
|
|
1076
1329
|
print(f"Error: Instance {args.instance} not found")
|
|
1077
1330
|
return
|
|
@@ -1091,7 +1344,6 @@ def main():
|
|
|
1091
1344
|
name="openadapt-training",
|
|
1092
1345
|
)
|
|
1093
1346
|
print(f"Instance launched: {instance.id[:8]}... at {instance.ip}")
|
|
1094
|
-
launched_new = True
|
|
1095
1347
|
|
|
1096
1348
|
price_per_hour = INSTANCE_PRICES.get(instance.instance_type, 1.00)
|
|
1097
1349
|
print(f" Instance type: {instance.instance_type} (~${price_per_hour:.2f}/hr)")
|
|
@@ -1100,16 +1352,21 @@ def main():
|
|
|
1100
1352
|
# Generate initial dashboard with setup status
|
|
1101
1353
|
from pathlib import Path
|
|
1102
1354
|
from openadapt_ml.training.trainer import (
|
|
1103
|
-
TrainingState,
|
|
1104
|
-
|
|
1355
|
+
TrainingState,
|
|
1356
|
+
TrainingConfig,
|
|
1357
|
+
generate_training_dashboard,
|
|
1358
|
+
setup_job_directory,
|
|
1105
1359
|
)
|
|
1106
1360
|
import time as time_module
|
|
1361
|
+
|
|
1107
1362
|
job_id = time_module.strftime("%Y%m%d_%H%M%S")
|
|
1108
1363
|
output_dir = setup_job_directory("training_output", job_id)
|
|
1109
1364
|
dashboard_path = output_dir / "dashboard.html"
|
|
1110
1365
|
log_path = output_dir / "training_log.json"
|
|
1111
1366
|
|
|
1112
|
-
def update_dashboard(
|
|
1367
|
+
def update_dashboard(
|
|
1368
|
+
status: str, logs: list, step: int = 0, loss: float = 0.0, epoch: int = 0
|
|
1369
|
+
):
|
|
1113
1370
|
"""Update dashboard with current setup/training status."""
|
|
1114
1371
|
state = TrainingState(job_id=job_id)
|
|
1115
1372
|
state.cloud_provider = "lambda"
|
|
@@ -1156,9 +1413,13 @@ def main():
|
|
|
1156
1413
|
update_dashboard("installing", setup_logs)
|
|
1157
1414
|
break
|
|
1158
1415
|
if setup_attempt < 2:
|
|
1159
|
-
setup_logs.append(
|
|
1416
|
+
setup_logs.append(
|
|
1417
|
+
f"Setup attempt {setup_attempt + 1} failed, retrying in 30s..."
|
|
1418
|
+
)
|
|
1160
1419
|
update_dashboard("booting", setup_logs)
|
|
1161
|
-
print(
|
|
1420
|
+
print(
|
|
1421
|
+
f" Setup attempt {setup_attempt + 1} failed, retrying in 30s..."
|
|
1422
|
+
)
|
|
1162
1423
|
time_module.sleep(30)
|
|
1163
1424
|
|
|
1164
1425
|
if not setup_success:
|
|
@@ -1167,14 +1428,18 @@ def main():
|
|
|
1167
1428
|
print("\nError: Failed to set up instance after 3 attempts")
|
|
1168
1429
|
print(f"Instance still running: {instance.ip}")
|
|
1169
1430
|
print("Debug via: ssh ubuntu@" + instance.ip)
|
|
1170
|
-
print(
|
|
1431
|
+
print(
|
|
1432
|
+
f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}"
|
|
1433
|
+
)
|
|
1171
1434
|
return # Don't terminate - let user debug
|
|
1172
1435
|
|
|
1173
1436
|
# Sync local code to ensure remote has latest changes
|
|
1174
1437
|
setup_logs.append("Syncing local code to instance...")
|
|
1175
1438
|
update_dashboard("installing", setup_logs)
|
|
1176
1439
|
if not client.sync_local_code(instance):
|
|
1177
|
-
setup_logs.append(
|
|
1440
|
+
setup_logs.append(
|
|
1441
|
+
"Warning: Failed to sync local code, using remote repo version"
|
|
1442
|
+
)
|
|
1178
1443
|
update_dashboard("installing", setup_logs)
|
|
1179
1444
|
print("Warning: Failed to sync local code, using remote repo version")
|
|
1180
1445
|
else:
|
|
@@ -1184,7 +1449,7 @@ def main():
|
|
|
1184
1449
|
# Upload capture if provided
|
|
1185
1450
|
remote_capture = None
|
|
1186
1451
|
if args.capture:
|
|
1187
|
-
setup_logs.append(
|
|
1452
|
+
setup_logs.append("Uploading capture data...")
|
|
1188
1453
|
update_dashboard("installing", setup_logs)
|
|
1189
1454
|
if client.upload_capture(instance, args.capture, "~/capture"):
|
|
1190
1455
|
remote_capture = "~/capture"
|
|
@@ -1197,7 +1462,9 @@ def main():
|
|
|
1197
1462
|
print("\nError: Failed to upload capture after retries")
|
|
1198
1463
|
print(f"Instance still running: {instance.ip}")
|
|
1199
1464
|
print("Debug via: ssh ubuntu@" + instance.ip)
|
|
1200
|
-
print(
|
|
1465
|
+
print(
|
|
1466
|
+
f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}"
|
|
1467
|
+
)
|
|
1201
1468
|
return # Don't terminate - let user debug
|
|
1202
1469
|
|
|
1203
1470
|
# Run training in background and poll for status
|
|
@@ -1207,7 +1474,7 @@ def main():
|
|
|
1207
1474
|
print("Starting training...")
|
|
1208
1475
|
print("=" * 50 + "\n")
|
|
1209
1476
|
|
|
1210
|
-
|
|
1477
|
+
client.run_training(
|
|
1211
1478
|
instance,
|
|
1212
1479
|
config=args.config,
|
|
1213
1480
|
capture=remote_capture,
|
|
@@ -1219,7 +1486,9 @@ def main():
|
|
|
1219
1486
|
poll_interval = 10 # seconds
|
|
1220
1487
|
last_step = 0
|
|
1221
1488
|
last_epoch = 0
|
|
1222
|
-
print(
|
|
1489
|
+
print(
|
|
1490
|
+
f"Polling training status every {poll_interval}s (Ctrl+C to stop)...\n"
|
|
1491
|
+
)
|
|
1223
1492
|
|
|
1224
1493
|
while True:
|
|
1225
1494
|
try:
|
|
@@ -1234,7 +1503,9 @@ def main():
|
|
|
1234
1503
|
|
|
1235
1504
|
# Print progress when step changes
|
|
1236
1505
|
if step > last_step or epoch > last_epoch:
|
|
1237
|
-
print(
|
|
1506
|
+
print(
|
|
1507
|
+
f" Epoch {epoch + 1}/{total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed_training:.0f}s"
|
|
1508
|
+
)
|
|
1238
1509
|
last_step = step
|
|
1239
1510
|
last_epoch = epoch
|
|
1240
1511
|
|
|
@@ -1246,7 +1517,9 @@ def main():
|
|
|
1246
1517
|
status["instance_type"] = instance.instance_type
|
|
1247
1518
|
# Add cloud provider info
|
|
1248
1519
|
status["cloud_provider"] = "lambda"
|
|
1249
|
-
status["cloud_dashboard_url"] =
|
|
1520
|
+
status["cloud_dashboard_url"] = (
|
|
1521
|
+
"https://cloud.lambda.ai/instances"
|
|
1522
|
+
)
|
|
1250
1523
|
status["cloud_instance_id"] = instance.id
|
|
1251
1524
|
status["setup_status"] = "training"
|
|
1252
1525
|
status["setup_logs"] = setup_logs
|
|
@@ -1274,9 +1547,11 @@ def main():
|
|
|
1274
1547
|
|
|
1275
1548
|
config = TrainingConfig(
|
|
1276
1549
|
num_train_epochs=total_epochs,
|
|
1277
|
-
learning_rate=status.get("learning_rate", 5e-5)
|
|
1550
|
+
learning_rate=status.get("learning_rate", 5e-5),
|
|
1551
|
+
)
|
|
1552
|
+
dashboard_path.write_text(
|
|
1553
|
+
generate_training_dashboard(state, config)
|
|
1278
1554
|
)
|
|
1279
|
-
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
1280
1555
|
|
|
1281
1556
|
# Check if training is complete (all epochs done)
|
|
1282
1557
|
if epoch >= total_epochs - 1:
|
|
@@ -1318,13 +1593,15 @@ def main():
|
|
|
1318
1593
|
print("=" * 50)
|
|
1319
1594
|
|
|
1320
1595
|
# Determine the final checkpoint path (main checkpoint after training)
|
|
1321
|
-
checkpoint_path =
|
|
1596
|
+
checkpoint_path = (
|
|
1597
|
+
"/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
|
|
1598
|
+
)
|
|
1322
1599
|
|
|
1323
1600
|
# Check if checkpoint exists
|
|
1324
1601
|
result = client.ssh_run(
|
|
1325
1602
|
instance,
|
|
1326
1603
|
f"ls {checkpoint_path}/adapter_config.json 2>/dev/null && echo 'exists'",
|
|
1327
|
-
timeout=30
|
|
1604
|
+
timeout=30,
|
|
1328
1605
|
)
|
|
1329
1606
|
|
|
1330
1607
|
if "exists" in result.stdout:
|
|
@@ -1336,13 +1613,15 @@ def main():
|
|
|
1336
1613
|
--checkpoint {checkpoint_path} \
|
|
1337
1614
|
--output training_output/{output_name} 2>&1"""
|
|
1338
1615
|
|
|
1339
|
-
print(
|
|
1616
|
+
print(
|
|
1617
|
+
" Generating comparison viewer (this may take a few minutes)..."
|
|
1618
|
+
)
|
|
1340
1619
|
result = client.ssh_run(instance, cmd, timeout=600)
|
|
1341
1620
|
|
|
1342
1621
|
if result.returncode == 0:
|
|
1343
1622
|
print(f" Comparison generated: {output_name}")
|
|
1344
1623
|
else:
|
|
1345
|
-
print(
|
|
1624
|
+
print(" Warning: Comparison generation failed")
|
|
1346
1625
|
if result.stderr:
|
|
1347
1626
|
print(f" Error: {result.stderr}")
|
|
1348
1627
|
else:
|
|
@@ -1357,13 +1636,15 @@ def main():
|
|
|
1357
1636
|
print(f"\nTerminating instance {instance.id[:8]}...")
|
|
1358
1637
|
client.terminate_instance(instance.id)
|
|
1359
1638
|
print("Instance terminated.")
|
|
1360
|
-
print(f"\nFinal cost: ~${cost:.2f} ({elapsed/60:.1f} minutes)")
|
|
1639
|
+
print(f"\nFinal cost: ~${cost:.2f} ({elapsed / 60:.1f} minutes)")
|
|
1361
1640
|
else:
|
|
1362
1641
|
print(f"\nInstance still running: {instance.ip}")
|
|
1363
1642
|
print(f" Current cost: ~${cost:.2f}")
|
|
1364
1643
|
if not training_completed:
|
|
1365
|
-
print(
|
|
1366
|
-
print(
|
|
1644
|
+
print(" (Not terminating - training did not complete successfully)")
|
|
1645
|
+
print(
|
|
1646
|
+
f"Terminate with: python -m openadapt_ml.cloud.lambda_labs terminate {instance.id}"
|
|
1647
|
+
)
|
|
1367
1648
|
|
|
1368
1649
|
elif args.command == "train-status":
|
|
1369
1650
|
instances = client.list_instances()
|
|
@@ -1372,7 +1653,9 @@ def main():
|
|
|
1372
1653
|
return
|
|
1373
1654
|
|
|
1374
1655
|
if args.instance_id:
|
|
1375
|
-
instance = next(
|
|
1656
|
+
instance = next(
|
|
1657
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
1658
|
+
)
|
|
1376
1659
|
if not instance:
|
|
1377
1660
|
print(f"Instance {args.instance_id} not found.")
|
|
1378
1661
|
return
|
|
@@ -1398,7 +1681,9 @@ def main():
|
|
|
1398
1681
|
return
|
|
1399
1682
|
|
|
1400
1683
|
if args.instance_id:
|
|
1401
|
-
instance = next(
|
|
1684
|
+
instance = next(
|
|
1685
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
1686
|
+
)
|
|
1402
1687
|
if not instance:
|
|
1403
1688
|
print(f"Instance {args.instance_id} not found.")
|
|
1404
1689
|
return
|
|
@@ -1408,10 +1693,14 @@ def main():
|
|
|
1408
1693
|
print(f"Checking checkpoints on {instance.ip}...")
|
|
1409
1694
|
|
|
1410
1695
|
ssh_cmd = [
|
|
1411
|
-
"ssh",
|
|
1696
|
+
"ssh",
|
|
1697
|
+
"-o",
|
|
1698
|
+
"StrictHostKeyChecking=no",
|
|
1699
|
+
"-o",
|
|
1700
|
+
"ConnectTimeout=10",
|
|
1412
1701
|
f"ubuntu@{instance.ip}",
|
|
1413
1702
|
"ls -la ~/openadapt-ml/checkpoints/ 2>/dev/null && "
|
|
1414
|
-
"du -sh ~/openadapt-ml/checkpoints/ 2>/dev/null || echo 'No checkpoints directory found'"
|
|
1703
|
+
"du -sh ~/openadapt-ml/checkpoints/ 2>/dev/null || echo 'No checkpoints directory found'",
|
|
1415
1704
|
]
|
|
1416
1705
|
|
|
1417
1706
|
result = subprocess.run(ssh_cmd, capture_output=True, text=True)
|
|
@@ -1426,7 +1715,11 @@ def main():
|
|
|
1426
1715
|
# One-shot dashboard refresh
|
|
1427
1716
|
import time as time_module
|
|
1428
1717
|
from pathlib import Path
|
|
1429
|
-
from openadapt_ml.training.trainer import
|
|
1718
|
+
from openadapt_ml.training.trainer import (
|
|
1719
|
+
TrainingState,
|
|
1720
|
+
TrainingConfig,
|
|
1721
|
+
generate_training_dashboard,
|
|
1722
|
+
)
|
|
1430
1723
|
|
|
1431
1724
|
instances = client.list_instances()
|
|
1432
1725
|
if not instances:
|
|
@@ -1434,7 +1727,9 @@ def main():
|
|
|
1434
1727
|
return
|
|
1435
1728
|
|
|
1436
1729
|
if args.instance_id:
|
|
1437
|
-
instance = next(
|
|
1730
|
+
instance = next(
|
|
1731
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
1732
|
+
)
|
|
1438
1733
|
if not instance:
|
|
1439
1734
|
print(f"Instance {args.instance_id} not found.")
|
|
1440
1735
|
return
|
|
@@ -1442,7 +1737,11 @@ def main():
|
|
|
1442
1737
|
instance = instances[0]
|
|
1443
1738
|
|
|
1444
1739
|
# Use current job directory via symlink
|
|
1445
|
-
from openadapt_ml.training.trainer import
|
|
1740
|
+
from openadapt_ml.training.trainer import (
|
|
1741
|
+
get_current_job_directory,
|
|
1742
|
+
setup_job_directory,
|
|
1743
|
+
)
|
|
1744
|
+
|
|
1446
1745
|
base_dir = Path("training_output")
|
|
1447
1746
|
base_dir.mkdir(exist_ok=True)
|
|
1448
1747
|
|
|
@@ -1459,7 +1758,9 @@ def main():
|
|
|
1459
1758
|
log_path = output_dir / "training_log.json"
|
|
1460
1759
|
|
|
1461
1760
|
# Setup screenshots symlink if local capture path provided
|
|
1462
|
-
local_capture =
|
|
1761
|
+
local_capture = (
|
|
1762
|
+
args.capture if hasattr(args, "capture") and args.capture else None
|
|
1763
|
+
)
|
|
1463
1764
|
if local_capture:
|
|
1464
1765
|
setup_capture_screenshots_symlink(output_dir, local_capture)
|
|
1465
1766
|
|
|
@@ -1483,7 +1784,9 @@ def main():
|
|
|
1483
1784
|
state.instance_type = instance.instance_type
|
|
1484
1785
|
state.config_path = status.get("config_path", "")
|
|
1485
1786
|
# Use local capture path for screenshots if provided, else remote path
|
|
1486
|
-
state.capture_path =
|
|
1787
|
+
state.capture_path = (
|
|
1788
|
+
args.capture if args.capture else status.get("capture_path", "")
|
|
1789
|
+
)
|
|
1487
1790
|
state.epoch = status.get("epoch", 0)
|
|
1488
1791
|
state.step = status.get("step", 0)
|
|
1489
1792
|
state.loss = status.get("loss", 0)
|
|
@@ -1501,7 +1804,7 @@ def main():
|
|
|
1501
1804
|
|
|
1502
1805
|
config = TrainingConfig(
|
|
1503
1806
|
num_train_epochs=status.get("total_epochs", 5),
|
|
1504
|
-
learning_rate=status.get("learning_rate", 5e-5)
|
|
1807
|
+
learning_rate=status.get("learning_rate", 5e-5),
|
|
1505
1808
|
)
|
|
1506
1809
|
|
|
1507
1810
|
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
@@ -1509,6 +1812,7 @@ def main():
|
|
|
1509
1812
|
# Regenerate navigation for file:// protocol
|
|
1510
1813
|
try:
|
|
1511
1814
|
from openadapt_ml.training.trainer import regenerate_all_dashboards
|
|
1815
|
+
|
|
1512
1816
|
regenerate_all_dashboards(output_dir)
|
|
1513
1817
|
except Exception:
|
|
1514
1818
|
pass # Silent fail for navigation
|
|
@@ -1517,11 +1821,14 @@ def main():
|
|
|
1517
1821
|
step = status.get("step", 0)
|
|
1518
1822
|
loss = status.get("loss", 0)
|
|
1519
1823
|
elapsed = status.get("elapsed_time", 0)
|
|
1520
|
-
print(
|
|
1824
|
+
print(
|
|
1825
|
+
f"Epoch {epoch + 1}/{state.total_epochs} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s"
|
|
1826
|
+
)
|
|
1521
1827
|
print(f"Dashboard: {dashboard_path.absolute()}")
|
|
1522
1828
|
|
|
1523
1829
|
if args.open:
|
|
1524
1830
|
import subprocess as sp
|
|
1831
|
+
|
|
1525
1832
|
sp.run(["open", str(dashboard_path)], capture_output=True)
|
|
1526
1833
|
else:
|
|
1527
1834
|
print("No training data yet")
|
|
@@ -1533,10 +1840,12 @@ def main():
|
|
|
1533
1840
|
from pathlib import Path
|
|
1534
1841
|
|
|
1535
1842
|
# Stub mode - simulate training without actual GPU
|
|
1536
|
-
if getattr(args,
|
|
1843
|
+
if getattr(args, "stub", False):
|
|
1537
1844
|
from openadapt_ml.training.stub_provider import StubTrainingProvider
|
|
1538
1845
|
from openadapt_ml.training.trainer import (
|
|
1539
|
-
TrainingState,
|
|
1846
|
+
TrainingState,
|
|
1847
|
+
TrainingConfig,
|
|
1848
|
+
generate_training_dashboard,
|
|
1540
1849
|
)
|
|
1541
1850
|
|
|
1542
1851
|
print("\n[Stub Mode] Simulating training without GPU...")
|
|
@@ -1574,7 +1883,7 @@ def main():
|
|
|
1574
1883
|
|
|
1575
1884
|
config = TrainingConfig(
|
|
1576
1885
|
num_train_epochs=status.get("total_epochs", 5),
|
|
1577
|
-
learning_rate=state.learning_rate
|
|
1886
|
+
learning_rate=state.learning_rate,
|
|
1578
1887
|
)
|
|
1579
1888
|
|
|
1580
1889
|
dashboard_path = output_dir / "dashboard.html"
|
|
@@ -1598,7 +1907,9 @@ def main():
|
|
|
1598
1907
|
return
|
|
1599
1908
|
|
|
1600
1909
|
if args.instance_id:
|
|
1601
|
-
instance = next(
|
|
1910
|
+
instance = next(
|
|
1911
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
1912
|
+
)
|
|
1602
1913
|
if not instance:
|
|
1603
1914
|
print(f"Instance {args.instance_id} not found.")
|
|
1604
1915
|
return
|
|
@@ -1621,9 +1932,13 @@ def main():
|
|
|
1621
1932
|
|
|
1622
1933
|
# Use job-scoped directory structure
|
|
1623
1934
|
from openadapt_ml.training.trainer import (
|
|
1624
|
-
TrainingState,
|
|
1625
|
-
|
|
1935
|
+
TrainingState,
|
|
1936
|
+
TrainingConfig,
|
|
1937
|
+
generate_training_dashboard,
|
|
1938
|
+
setup_job_directory,
|
|
1939
|
+
get_current_job_directory,
|
|
1626
1940
|
)
|
|
1941
|
+
|
|
1627
1942
|
base_dir = Path("training_output")
|
|
1628
1943
|
base_dir.mkdir(exist_ok=True)
|
|
1629
1944
|
|
|
@@ -1654,7 +1969,11 @@ def main():
|
|
|
1654
1969
|
state.instance_ip = instance.ip or ""
|
|
1655
1970
|
state.instance_type = instance.instance_type
|
|
1656
1971
|
state.setup_status = "booting"
|
|
1657
|
-
state.setup_logs = [
|
|
1972
|
+
state.setup_logs = [
|
|
1973
|
+
"Starting Lambda Cloud instance...",
|
|
1974
|
+
f"Instance ID: {instance.id[:8]}...",
|
|
1975
|
+
f"Instance type: {instance.instance_type}",
|
|
1976
|
+
]
|
|
1658
1977
|
config = TrainingConfig(num_train_epochs=5, learning_rate=5e-5)
|
|
1659
1978
|
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
1660
1979
|
|
|
@@ -1665,12 +1984,14 @@ def main():
|
|
|
1665
1984
|
|
|
1666
1985
|
last_step = 0
|
|
1667
1986
|
last_epoch = -1
|
|
1668
|
-
auto_stop_loss = getattr(args,
|
|
1669
|
-
download_checkpoints = getattr(args,
|
|
1987
|
+
auto_stop_loss = getattr(args, "auto_stop_loss", 0.5)
|
|
1988
|
+
download_checkpoints = getattr(args, "download_checkpoints", True)
|
|
1670
1989
|
step_stall_count = 0 # Track how many times step hasn't increased
|
|
1671
1990
|
|
|
1672
1991
|
print(f" Auto-stop loss threshold: {auto_stop_loss}")
|
|
1673
|
-
print(
|
|
1992
|
+
print(
|
|
1993
|
+
f" Checkpoint download: {'enabled' if download_checkpoints else 'disabled'}"
|
|
1994
|
+
)
|
|
1674
1995
|
|
|
1675
1996
|
try:
|
|
1676
1997
|
while True:
|
|
@@ -1684,10 +2005,11 @@ def main():
|
|
|
1684
2005
|
# Update status with termination info before terminating
|
|
1685
2006
|
termination_status = {
|
|
1686
2007
|
"termination_status": "user_stop",
|
|
1687
|
-
"termination_message": "Training stopped by user via dashboard"
|
|
2008
|
+
"termination_message": "Training stopped by user via dashboard",
|
|
1688
2009
|
}
|
|
1689
2010
|
current_log = log_path.read_text() if log_path.exists() else "{}"
|
|
1690
2011
|
import json as json_module
|
|
2012
|
+
|
|
1691
2013
|
current_data = json_module.loads(current_log)
|
|
1692
2014
|
current_data.update(termination_status)
|
|
1693
2015
|
log_path.write_text(json_module.dumps(current_data, indent=2))
|
|
@@ -1711,8 +2033,14 @@ def main():
|
|
|
1711
2033
|
remote_job_id = status.get("job_id")
|
|
1712
2034
|
|
|
1713
2035
|
# Detect job_id change - clear old data if new job started
|
|
1714
|
-
if
|
|
1715
|
-
|
|
2036
|
+
if (
|
|
2037
|
+
remote_job_id
|
|
2038
|
+
and current_job_id
|
|
2039
|
+
and remote_job_id != current_job_id
|
|
2040
|
+
):
|
|
2041
|
+
print(
|
|
2042
|
+
f"\n New job detected: {remote_job_id} (was: {current_job_id})"
|
|
2043
|
+
)
|
|
1716
2044
|
print(" Clearing old job data...")
|
|
1717
2045
|
last_step = 0 # Reset step tracking
|
|
1718
2046
|
current_job_id = remote_job_id
|
|
@@ -1727,25 +2055,37 @@ def main():
|
|
|
1727
2055
|
status["instance_type"] = instance.instance_type
|
|
1728
2056
|
# Add cloud provider info
|
|
1729
2057
|
status["cloud_provider"] = "lambda"
|
|
1730
|
-
status["cloud_dashboard_url"] =
|
|
2058
|
+
status["cloud_dashboard_url"] = (
|
|
2059
|
+
"https://cloud.lambda.ai/instances"
|
|
2060
|
+
)
|
|
1731
2061
|
status["cloud_instance_id"] = instance.id
|
|
1732
2062
|
status["setup_status"] = status.get("setup_status", "training")
|
|
1733
2063
|
|
|
1734
2064
|
# Setup screenshots symlink if local capture path provided
|
|
1735
|
-
local_capture =
|
|
2065
|
+
local_capture = (
|
|
2066
|
+
args.capture
|
|
2067
|
+
if hasattr(args, "capture") and args.capture
|
|
2068
|
+
else None
|
|
2069
|
+
)
|
|
1736
2070
|
if local_capture:
|
|
1737
2071
|
setup_capture_screenshots_symlink(output_dir, local_capture)
|
|
1738
2072
|
|
|
1739
2073
|
# Rewrite evaluation paths from Lambda to relative
|
|
1740
2074
|
if "evaluations" in status:
|
|
1741
|
-
status["evaluations"] = rewrite_evaluation_paths(
|
|
2075
|
+
status["evaluations"] = rewrite_evaluation_paths(
|
|
2076
|
+
status["evaluations"]
|
|
2077
|
+
)
|
|
1742
2078
|
|
|
1743
2079
|
log_path.write_text(json.dumps(status, indent=2))
|
|
1744
2080
|
|
|
1745
2081
|
if step > last_step:
|
|
1746
|
-
print(
|
|
2082
|
+
print(
|
|
2083
|
+
f" Epoch {epoch + 1} | Step {step} | Loss: {loss:.4f} | Elapsed: {elapsed:.0f}s"
|
|
2084
|
+
)
|
|
1747
2085
|
last_step = step
|
|
1748
|
-
step_stall_count =
|
|
2086
|
+
step_stall_count = (
|
|
2087
|
+
0 # Reset stall counter when step increases
|
|
2088
|
+
)
|
|
1749
2089
|
if not current_job_id:
|
|
1750
2090
|
current_job_id = remote_job_id
|
|
1751
2091
|
|
|
@@ -1764,39 +2104,59 @@ def main():
|
|
|
1764
2104
|
state.start_time = time_module.time() - elapsed
|
|
1765
2105
|
# Cloud provider info
|
|
1766
2106
|
state.cloud_provider = "lambda"
|
|
1767
|
-
state.cloud_dashboard_url =
|
|
2107
|
+
state.cloud_dashboard_url = (
|
|
2108
|
+
"https://cloud.lambda.ai/instances"
|
|
2109
|
+
)
|
|
1768
2110
|
state.cloud_instance_id = instance.id
|
|
1769
2111
|
state.setup_status = status.get("setup_status", "training")
|
|
1770
2112
|
state.setup_logs = status.get("setup_logs", [])
|
|
1771
|
-
state.termination_status = status.get(
|
|
1772
|
-
|
|
2113
|
+
state.termination_status = status.get(
|
|
2114
|
+
"termination_status", ""
|
|
2115
|
+
)
|
|
2116
|
+
state.termination_message = status.get(
|
|
2117
|
+
"termination_message", ""
|
|
2118
|
+
)
|
|
1773
2119
|
|
|
1774
2120
|
config = TrainingConfig(
|
|
1775
2121
|
num_train_epochs=status.get("total_epochs", 5),
|
|
1776
|
-
learning_rate=status.get("learning_rate", 5e-5)
|
|
2122
|
+
learning_rate=status.get("learning_rate", 5e-5),
|
|
1777
2123
|
)
|
|
1778
2124
|
|
|
1779
|
-
dashboard_path.write_text(
|
|
2125
|
+
dashboard_path.write_text(
|
|
2126
|
+
generate_training_dashboard(state, config)
|
|
2127
|
+
)
|
|
1780
2128
|
|
|
1781
2129
|
# Download checkpoints on epoch change
|
|
1782
2130
|
if download_checkpoints and epoch > last_epoch:
|
|
1783
|
-
print(
|
|
1784
|
-
|
|
1785
|
-
|
|
2131
|
+
print(
|
|
2132
|
+
f" Epoch {epoch + 1} completed - downloading checkpoints..."
|
|
2133
|
+
)
|
|
2134
|
+
if download_checkpoints_from_instance(
|
|
2135
|
+
instance.ip, output_dir
|
|
2136
|
+
):
|
|
2137
|
+
print(
|
|
2138
|
+
f" Checkpoints saved to {output_dir}/checkpoints/"
|
|
2139
|
+
)
|
|
1786
2140
|
else:
|
|
1787
2141
|
print(" Warning: checkpoint download failed")
|
|
1788
2142
|
last_epoch = epoch
|
|
1789
2143
|
|
|
1790
2144
|
# Auto-terminate when loss is low enough
|
|
1791
2145
|
if loss < auto_stop_loss and loss > 0:
|
|
1792
|
-
print(
|
|
2146
|
+
print(
|
|
2147
|
+
f"\n Loss {loss:.4f} < threshold {auto_stop_loss}"
|
|
2148
|
+
)
|
|
1793
2149
|
print(" Downloading final checkpoints...")
|
|
1794
2150
|
if download_checkpoints:
|
|
1795
|
-
download_checkpoints_from_instance(
|
|
2151
|
+
download_checkpoints_from_instance(
|
|
2152
|
+
instance.ip, output_dir
|
|
2153
|
+
)
|
|
1796
2154
|
|
|
1797
2155
|
# Update status with termination info
|
|
1798
2156
|
status["termination_status"] = "auto_low_loss"
|
|
1799
|
-
status["termination_message"] =
|
|
2157
|
+
status["termination_message"] = (
|
|
2158
|
+
f"Training auto-stopped: loss {loss:.4f} < threshold {auto_stop_loss}"
|
|
2159
|
+
)
|
|
1800
2160
|
log_path.write_text(json.dumps(status, indent=2))
|
|
1801
2161
|
|
|
1802
2162
|
print(f" Auto-terminating instance {instance.id}...")
|
|
@@ -1810,14 +2170,20 @@ def main():
|
|
|
1810
2170
|
|
|
1811
2171
|
# If on last epoch and step hasn't increased for 3 polls, training is complete
|
|
1812
2172
|
if epoch >= total_epochs - 1 and step_stall_count >= 3:
|
|
1813
|
-
print(
|
|
2173
|
+
print(
|
|
2174
|
+
f"\n Training complete (epoch {epoch + 1}/{total_epochs}, step stopped increasing)"
|
|
2175
|
+
)
|
|
1814
2176
|
print(" Downloading final checkpoints...")
|
|
1815
2177
|
if download_checkpoints:
|
|
1816
|
-
download_checkpoints_from_instance(
|
|
2178
|
+
download_checkpoints_from_instance(
|
|
2179
|
+
instance.ip, output_dir
|
|
2180
|
+
)
|
|
1817
2181
|
|
|
1818
2182
|
# Update status with termination info
|
|
1819
2183
|
status["termination_status"] = "auto_complete"
|
|
1820
|
-
status["termination_message"] =
|
|
2184
|
+
status["termination_message"] = (
|
|
2185
|
+
f"Training completed successfully ({epoch + 1}/{total_epochs} epochs)"
|
|
2186
|
+
)
|
|
1821
2187
|
log_path.write_text(json.dumps(status, indent=2))
|
|
1822
2188
|
|
|
1823
2189
|
print(f" Terminating instance {instance.id}...")
|
|
@@ -1849,7 +2215,9 @@ def main():
|
|
|
1849
2215
|
return
|
|
1850
2216
|
|
|
1851
2217
|
if args.instance_id:
|
|
1852
|
-
instance = next(
|
|
2218
|
+
instance = next(
|
|
2219
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
2220
|
+
)
|
|
1853
2221
|
if not instance:
|
|
1854
2222
|
print(f"Instance {args.instance_id} not found.")
|
|
1855
2223
|
return
|
|
@@ -1857,9 +2225,13 @@ def main():
|
|
|
1857
2225
|
instance = instances[0]
|
|
1858
2226
|
|
|
1859
2227
|
print(f"Files on {instance.ip} at {args.path}:")
|
|
1860
|
-
result = client.ssh_run(
|
|
2228
|
+
result = client.ssh_run(
|
|
2229
|
+
instance,
|
|
2230
|
+
f"find {args.path} -type f -name '*.pt' -o -name '*.json' -o -name '*.bin' 2>/dev/null | head -20",
|
|
2231
|
+
timeout=30,
|
|
2232
|
+
)
|
|
1861
2233
|
if result.stdout:
|
|
1862
|
-
for line in result.stdout.strip().split(
|
|
2234
|
+
for line in result.stdout.strip().split("\n"):
|
|
1863
2235
|
print(f" {line}")
|
|
1864
2236
|
else:
|
|
1865
2237
|
print(" (no checkpoint files found)")
|
|
@@ -1872,18 +2244,16 @@ def main():
|
|
|
1872
2244
|
if args.local:
|
|
1873
2245
|
print("\nKilling local Lambda-related processes...")
|
|
1874
2246
|
subprocess.run(
|
|
1875
|
-
["pkill", "-f", "ssh.*ubuntu@.*openadapt"],
|
|
1876
|
-
capture_output=True
|
|
1877
|
-
)
|
|
1878
|
-
subprocess.run(
|
|
1879
|
-
["pkill", "-f", "lambda_labs"],
|
|
1880
|
-
capture_output=True
|
|
2247
|
+
["pkill", "-f", "ssh.*ubuntu@.*openadapt"], capture_output=True
|
|
1881
2248
|
)
|
|
2249
|
+
subprocess.run(["pkill", "-f", "lambda_labs"], capture_output=True)
|
|
1882
2250
|
print("Done.")
|
|
1883
2251
|
return
|
|
1884
2252
|
|
|
1885
2253
|
if args.instance_id:
|
|
1886
|
-
instance = next(
|
|
2254
|
+
instance = next(
|
|
2255
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
2256
|
+
)
|
|
1887
2257
|
if not instance:
|
|
1888
2258
|
print(f"Instance {args.instance_id} not found.")
|
|
1889
2259
|
return
|
|
@@ -1896,11 +2266,11 @@ def main():
|
|
|
1896
2266
|
result = client.ssh_run(
|
|
1897
2267
|
instance,
|
|
1898
2268
|
"ps aux | grep python | grep -v grep | grep -v jupyter",
|
|
1899
|
-
timeout=30
|
|
2269
|
+
timeout=30,
|
|
1900
2270
|
)
|
|
1901
2271
|
if result.stdout.strip():
|
|
1902
2272
|
print("Found Python processes:")
|
|
1903
|
-
for line in result.stdout.strip().split(
|
|
2273
|
+
for line in result.stdout.strip().split("\n"):
|
|
1904
2274
|
print(f" {line[:100]}...")
|
|
1905
2275
|
else:
|
|
1906
2276
|
print("No training/inference Python processes found.")
|
|
@@ -1908,7 +2278,9 @@ def main():
|
|
|
1908
2278
|
|
|
1909
2279
|
if args.all:
|
|
1910
2280
|
print("\nKilling ALL Python processes (except jupyter)...")
|
|
1911
|
-
cmd =
|
|
2281
|
+
cmd = (
|
|
2282
|
+
"pkill -f 'python.*train\\|python.*compare\\|python.*openadapt' || true"
|
|
2283
|
+
)
|
|
1912
2284
|
else:
|
|
1913
2285
|
print("\nKilling training and inference processes...")
|
|
1914
2286
|
cmd = "pkill -f 'python.*train' ; pkill -f 'python.*compare' || true"
|
|
@@ -1919,20 +2291,16 @@ def main():
|
|
|
1919
2291
|
if args.local:
|
|
1920
2292
|
print("\nKilling local Lambda-related processes...")
|
|
1921
2293
|
subprocess.run(
|
|
1922
|
-
["pkill", "-f", "ssh.*ubuntu@.*openadapt"],
|
|
1923
|
-
capture_output=True
|
|
1924
|
-
)
|
|
1925
|
-
subprocess.run(
|
|
1926
|
-
["pkill", "-f", "lambda_labs.*train"],
|
|
1927
|
-
capture_output=True
|
|
2294
|
+
["pkill", "-f", "ssh.*ubuntu@.*openadapt"], capture_output=True
|
|
1928
2295
|
)
|
|
2296
|
+
subprocess.run(["pkill", "-f", "lambda_labs.*train"], capture_output=True)
|
|
1929
2297
|
print("Local processes killed.")
|
|
1930
2298
|
|
|
1931
2299
|
print("\nDone. Current status:")
|
|
1932
2300
|
result = client.ssh_run(
|
|
1933
2301
|
instance,
|
|
1934
2302
|
"ps aux | grep python | grep -v grep | grep -v jupyter | wc -l",
|
|
1935
|
-
timeout=30
|
|
2303
|
+
timeout=30,
|
|
1936
2304
|
)
|
|
1937
2305
|
count = result.stdout.strip()
|
|
1938
2306
|
print(f" {count} Python processes remaining on instance")
|
|
@@ -1945,7 +2313,9 @@ def main():
|
|
|
1945
2313
|
return
|
|
1946
2314
|
|
|
1947
2315
|
if args.instance_id:
|
|
1948
|
-
instance = next(
|
|
2316
|
+
instance = next(
|
|
2317
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
2318
|
+
)
|
|
1949
2319
|
if not instance:
|
|
1950
2320
|
print(f"Instance {args.instance_id} not found.")
|
|
1951
2321
|
return
|
|
@@ -1958,7 +2328,7 @@ def main():
|
|
|
1958
2328
|
result = client.ssh_run(
|
|
1959
2329
|
instance,
|
|
1960
2330
|
"cat ~/openadapt-ml/training_output/training_log.json 2>/dev/null",
|
|
1961
|
-
timeout=30
|
|
2331
|
+
timeout=30,
|
|
1962
2332
|
)
|
|
1963
2333
|
|
|
1964
2334
|
if not result.stdout.strip():
|
|
@@ -1977,77 +2347,87 @@ def main():
|
|
|
1977
2347
|
return
|
|
1978
2348
|
|
|
1979
2349
|
total_steps = len(losses)
|
|
1980
|
-
epochs = sorted(set(
|
|
2350
|
+
epochs = sorted(set(loss["epoch"] for loss in losses))
|
|
1981
2351
|
total_epochs = data.get("total_epochs", 5)
|
|
1982
|
-
min_loss = min(
|
|
2352
|
+
min_loss = min(loss["loss"] for loss in losses)
|
|
1983
2353
|
current_loss = losses[-1]["loss"]
|
|
1984
2354
|
|
|
1985
|
-
print(f"\n{'='*50}")
|
|
1986
|
-
print(
|
|
1987
|
-
print(f"{'='*50}")
|
|
2355
|
+
print(f"\n{'=' * 50}")
|
|
2356
|
+
print("TRAINING STATUS")
|
|
2357
|
+
print(f"{'=' * 50}")
|
|
1988
2358
|
print(f"Steps: {total_steps}")
|
|
1989
|
-
print(f"Epochs: {max(epochs)+1}/{total_epochs}")
|
|
2359
|
+
print(f"Epochs: {max(epochs) + 1}/{total_epochs}")
|
|
1990
2360
|
print(f"Current loss: {current_loss:.4f}")
|
|
1991
2361
|
print(f"Min loss: {min_loss:.4f}")
|
|
1992
2362
|
|
|
1993
2363
|
# Check if training is running
|
|
1994
2364
|
proc_result = client.ssh_run(
|
|
1995
|
-
instance,
|
|
1996
|
-
"ps aux | grep 'python.*train' | grep -v grep | wc -l",
|
|
1997
|
-
timeout=30
|
|
2365
|
+
instance, "ps aux | grep 'python.*train' | grep -v grep | wc -l", timeout=30
|
|
1998
2366
|
)
|
|
1999
2367
|
is_running = int(proc_result.stdout.strip()) > 0
|
|
2000
2368
|
|
|
2001
2369
|
if is_running:
|
|
2002
|
-
print(
|
|
2370
|
+
print("Status: RUNNING")
|
|
2003
2371
|
else:
|
|
2004
|
-
print(
|
|
2372
|
+
print("Status: STOPPED")
|
|
2005
2373
|
|
|
2006
2374
|
# Early stopping analysis
|
|
2007
2375
|
window = min(args.window, len(losses))
|
|
2008
2376
|
if window < 2:
|
|
2009
2377
|
print("\nNot enough data for early stopping analysis.")
|
|
2010
2378
|
else:
|
|
2011
|
-
recent_losses = [
|
|
2012
|
-
older_losses =
|
|
2379
|
+
recent_losses = [loss["loss"] for loss in losses[-window:]]
|
|
2380
|
+
older_losses = (
|
|
2381
|
+
[loss["loss"] for loss in losses[-window * 2 : -window]]
|
|
2382
|
+
if len(losses) >= window * 2
|
|
2383
|
+
else [loss["loss"] for loss in losses[:window]]
|
|
2384
|
+
)
|
|
2013
2385
|
|
|
2014
2386
|
recent_avg = sum(recent_losses) / len(recent_losses)
|
|
2015
|
-
older_avg =
|
|
2387
|
+
older_avg = (
|
|
2388
|
+
sum(older_losses) / len(older_losses) if older_losses else recent_avg
|
|
2389
|
+
)
|
|
2016
2390
|
|
|
2017
2391
|
improvement = (older_avg - recent_avg) / older_avg if older_avg > 0 else 0
|
|
2018
2392
|
loss_variance = max(recent_losses) - min(recent_losses)
|
|
2019
2393
|
|
|
2020
|
-
print(f"\n{'='*50}")
|
|
2394
|
+
print(f"\n{'=' * 50}")
|
|
2021
2395
|
print(f"EARLY STOPPING ANALYSIS (window={window})")
|
|
2022
|
-
print(f"{'='*50}")
|
|
2396
|
+
print(f"{'=' * 50}")
|
|
2023
2397
|
print(f"Recent avg loss: {recent_avg:.4f}")
|
|
2024
2398
|
print(f"Prior avg loss: {older_avg:.4f}")
|
|
2025
|
-
print(f"Improvement: {improvement*100:.2f}%")
|
|
2399
|
+
print(f"Improvement: {improvement * 100:.2f}%")
|
|
2026
2400
|
print(f"Loss variance: {loss_variance:.4f}")
|
|
2027
2401
|
|
|
2028
2402
|
should_stop = improvement < args.threshold and loss_variance < 0.1
|
|
2029
2403
|
if should_stop:
|
|
2030
|
-
print(
|
|
2031
|
-
print(f" Loss has plateaued (improvement < {args.threshold*100}%)")
|
|
2404
|
+
print("\n⚠️ EARLY STOPPING RECOMMENDED")
|
|
2405
|
+
print(f" Loss has plateaued (improvement < {args.threshold * 100}%)")
|
|
2032
2406
|
if not is_running:
|
|
2033
|
-
print(
|
|
2407
|
+
print(" (Training already stopped)")
|
|
2034
2408
|
else:
|
|
2035
|
-
print(
|
|
2409
|
+
print(
|
|
2410
|
+
"\n To stop: uv run python -m openadapt_ml.cloud.lambda_labs kill"
|
|
2411
|
+
)
|
|
2036
2412
|
else:
|
|
2037
|
-
print(
|
|
2413
|
+
print("\n✓ Training still improving, continue.")
|
|
2038
2414
|
|
|
2039
2415
|
# Time estimate
|
|
2040
2416
|
if is_running and len(losses) >= 2:
|
|
2041
|
-
avg_time_per_step =
|
|
2417
|
+
avg_time_per_step = (
|
|
2418
|
+
losses[-1].get("time", 0) / len(losses)
|
|
2419
|
+
if losses[-1].get("time")
|
|
2420
|
+
else 50
|
|
2421
|
+
)
|
|
2042
2422
|
steps_per_epoch = len(losses) / (max(epochs) + 1)
|
|
2043
2423
|
remaining_epochs = total_epochs - max(epochs) - 1
|
|
2044
2424
|
remaining_steps = remaining_epochs * steps_per_epoch
|
|
2045
2425
|
eta_seconds = remaining_steps * avg_time_per_step
|
|
2046
2426
|
eta_mins = eta_seconds / 60
|
|
2047
2427
|
|
|
2048
|
-
print(f"\n{'='*50}")
|
|
2049
|
-
print(
|
|
2050
|
-
print(f"{'='*50}")
|
|
2428
|
+
print(f"\n{'=' * 50}")
|
|
2429
|
+
print("TIME ESTIMATE")
|
|
2430
|
+
print(f"{'=' * 50}")
|
|
2051
2431
|
print(f"Remaining epochs: {remaining_epochs}")
|
|
2052
2432
|
print(f"Est. remaining steps: {remaining_steps:.0f}")
|
|
2053
2433
|
print(f"ETA: {eta_mins:.1f} minutes")
|
|
@@ -2060,7 +2440,9 @@ def main():
|
|
|
2060
2440
|
return
|
|
2061
2441
|
|
|
2062
2442
|
if args.instance_id:
|
|
2063
|
-
instance = next(
|
|
2443
|
+
instance = next(
|
|
2444
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
2445
|
+
)
|
|
2064
2446
|
if not instance:
|
|
2065
2447
|
print(f"Instance {args.instance_id} not found.")
|
|
2066
2448
|
return
|
|
@@ -2071,24 +2453,26 @@ def main():
|
|
|
2071
2453
|
if args.checkpoint:
|
|
2072
2454
|
checkpoint_path = args.checkpoint
|
|
2073
2455
|
elif args.epoch is not None:
|
|
2074
|
-
checkpoint_path =
|
|
2456
|
+
checkpoint_path = (
|
|
2457
|
+
f"/home/ubuntu/openadapt-ml/checkpoints/epoch_{args.epoch}"
|
|
2458
|
+
)
|
|
2075
2459
|
else:
|
|
2076
2460
|
# Use latest (main checkpoint)
|
|
2077
|
-
checkpoint_path =
|
|
2461
|
+
checkpoint_path = (
|
|
2462
|
+
"/home/ubuntu/openadapt-ml/checkpoints/qwen3vl2b_capture_lora"
|
|
2463
|
+
)
|
|
2078
2464
|
|
|
2079
2465
|
# Check if checkpoint exists
|
|
2080
2466
|
result = client.ssh_run(
|
|
2081
2467
|
instance,
|
|
2082
2468
|
f"ls {checkpoint_path}/adapter_config.json 2>/dev/null && echo 'exists'",
|
|
2083
|
-
timeout=30
|
|
2469
|
+
timeout=30,
|
|
2084
2470
|
)
|
|
2085
2471
|
if "exists" not in result.stdout:
|
|
2086
2472
|
print(f"Checkpoint not found at {checkpoint_path}")
|
|
2087
2473
|
# List available checkpoints
|
|
2088
2474
|
result = client.ssh_run(
|
|
2089
|
-
instance,
|
|
2090
|
-
"ls -la ~/openadapt-ml/checkpoints/",
|
|
2091
|
-
timeout=30
|
|
2475
|
+
instance, "ls -la ~/openadapt-ml/checkpoints/", timeout=30
|
|
2092
2476
|
)
|
|
2093
2477
|
print(f"Available checkpoints:\n{result.stdout}")
|
|
2094
2478
|
return
|
|
@@ -2113,9 +2497,7 @@ def main():
|
|
|
2113
2497
|
|
|
2114
2498
|
# Check if file was created
|
|
2115
2499
|
result = client.ssh_run(
|
|
2116
|
-
instance,
|
|
2117
|
-
f"ls -la ~/openadapt-ml/training_output/{output_name}",
|
|
2118
|
-
timeout=30
|
|
2500
|
+
instance, f"ls -la ~/openadapt-ml/training_output/{output_name}", timeout=30
|
|
2119
2501
|
)
|
|
2120
2502
|
if result.returncode != 0:
|
|
2121
2503
|
print("Comparison file not created.")
|
|
@@ -2128,11 +2510,15 @@ def main():
|
|
|
2128
2510
|
local_output.parent.mkdir(parents=True, exist_ok=True)
|
|
2129
2511
|
|
|
2130
2512
|
print(f"Syncing to {local_output}...")
|
|
2131
|
-
subprocess.run(
|
|
2132
|
-
|
|
2133
|
-
|
|
2134
|
-
|
|
2135
|
-
|
|
2513
|
+
subprocess.run(
|
|
2514
|
+
[
|
|
2515
|
+
"rsync",
|
|
2516
|
+
"-avz",
|
|
2517
|
+
f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/{output_name}",
|
|
2518
|
+
str(local_output),
|
|
2519
|
+
],
|
|
2520
|
+
capture_output=True,
|
|
2521
|
+
)
|
|
2136
2522
|
|
|
2137
2523
|
print(f"Done! Comparison saved to: {local_output}")
|
|
2138
2524
|
|
|
@@ -2147,7 +2533,9 @@ def main():
|
|
|
2147
2533
|
return
|
|
2148
2534
|
|
|
2149
2535
|
if args.instance_id:
|
|
2150
|
-
instance = next(
|
|
2536
|
+
instance = next(
|
|
2537
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
2538
|
+
)
|
|
2151
2539
|
if not instance:
|
|
2152
2540
|
print(f"Instance {args.instance_id} not found.")
|
|
2153
2541
|
return
|
|
@@ -2164,7 +2552,9 @@ def main():
|
|
|
2164
2552
|
return
|
|
2165
2553
|
|
|
2166
2554
|
if args.instance_id:
|
|
2167
|
-
instance = next(
|
|
2555
|
+
instance = next(
|
|
2556
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
2557
|
+
)
|
|
2168
2558
|
if not instance:
|
|
2169
2559
|
print(f"Instance {args.instance_id} not found.")
|
|
2170
2560
|
return
|
|
@@ -2180,10 +2570,17 @@ def main():
|
|
|
2180
2570
|
checkpoint_path = "checkpoints_lambda/qwen3vl2b_capture_lora"
|
|
2181
2571
|
|
|
2182
2572
|
import subprocess as sp
|
|
2573
|
+
|
|
2183
2574
|
cmd = [
|
|
2184
|
-
"uv",
|
|
2185
|
-
"
|
|
2186
|
-
"
|
|
2575
|
+
"uv",
|
|
2576
|
+
"run",
|
|
2577
|
+
"python",
|
|
2578
|
+
"-m",
|
|
2579
|
+
"openadapt_ml.scripts.compare",
|
|
2580
|
+
"--capture",
|
|
2581
|
+
args.capture,
|
|
2582
|
+
"--checkpoint",
|
|
2583
|
+
checkpoint_path,
|
|
2187
2584
|
]
|
|
2188
2585
|
if args.goal:
|
|
2189
2586
|
cmd.extend(["--goal", args.goal])
|
|
@@ -2202,11 +2599,12 @@ def main():
|
|
|
2202
2599
|
# Start web server for live dashboard with stop button support
|
|
2203
2600
|
import http.server
|
|
2204
2601
|
import socketserver
|
|
2205
|
-
import threading
|
|
2206
2602
|
import time as time_module
|
|
2207
2603
|
from pathlib import Path
|
|
2208
2604
|
|
|
2209
|
-
output_dir =
|
|
2605
|
+
output_dir = (
|
|
2606
|
+
Path(args.output) if hasattr(args, "output") else Path("training_output")
|
|
2607
|
+
)
|
|
2210
2608
|
port = args.port
|
|
2211
2609
|
|
|
2212
2610
|
if not output_dir.exists():
|
|
@@ -2219,13 +2617,13 @@ def main():
|
|
|
2219
2617
|
super().__init__(*args, directory=str(output_dir), **kwargs)
|
|
2220
2618
|
|
|
2221
2619
|
def do_POST(self):
|
|
2222
|
-
if self.path ==
|
|
2620
|
+
if self.path == "/api/stop":
|
|
2223
2621
|
# Create stop signal file
|
|
2224
2622
|
stop_file = output_dir / "STOP_TRAINING"
|
|
2225
2623
|
stop_file.touch()
|
|
2226
2624
|
self.send_response(200)
|
|
2227
|
-
self.send_header(
|
|
2228
|
-
self.send_header(
|
|
2625
|
+
self.send_header("Content-Type", "application/json")
|
|
2626
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
2229
2627
|
self.end_headers()
|
|
2230
2628
|
self.wfile.write(b'{"status": "stop signal created"}')
|
|
2231
2629
|
print(f" Stop signal created: {stop_file}")
|
|
@@ -2235,15 +2633,14 @@ def main():
|
|
|
2235
2633
|
def do_OPTIONS(self):
|
|
2236
2634
|
# Handle CORS preflight
|
|
2237
2635
|
self.send_response(200)
|
|
2238
|
-
self.send_header(
|
|
2239
|
-
self.send_header(
|
|
2240
|
-
self.send_header(
|
|
2636
|
+
self.send_header("Access-Control-Allow-Origin", "*")
|
|
2637
|
+
self.send_header("Access-Control-Allow-Methods", "POST, OPTIONS")
|
|
2638
|
+
self.send_header("Access-Control-Allow-Headers", "Content-Type")
|
|
2241
2639
|
self.end_headers()
|
|
2242
2640
|
|
|
2243
2641
|
def log_message(self, format, *args):
|
|
2244
2642
|
pass # Suppress log messages
|
|
2245
2643
|
|
|
2246
|
-
|
|
2247
2644
|
# Start web server
|
|
2248
2645
|
with socketserver.TCPServer(("", port), Handler) as httpd:
|
|
2249
2646
|
url = f"http://localhost:{port}/dashboard.html"
|
|
@@ -2262,8 +2659,10 @@ def main():
|
|
|
2262
2659
|
# Sync training output from Lambda and regenerate navigation for file:// protocol
|
|
2263
2660
|
from pathlib import Path
|
|
2264
2661
|
from openadapt_ml.training.trainer import (
|
|
2265
|
-
TrainingState,
|
|
2266
|
-
|
|
2662
|
+
TrainingState,
|
|
2663
|
+
TrainingConfig,
|
|
2664
|
+
generate_training_dashboard,
|
|
2665
|
+
regenerate_all_dashboards,
|
|
2267
2666
|
)
|
|
2268
2667
|
|
|
2269
2668
|
instances = client.list_instances()
|
|
@@ -2272,7 +2671,9 @@ def main():
|
|
|
2272
2671
|
return
|
|
2273
2672
|
|
|
2274
2673
|
if args.instance_id:
|
|
2275
|
-
instance = next(
|
|
2674
|
+
instance = next(
|
|
2675
|
+
(i for i in instances if i.id.startswith(args.instance_id)), None
|
|
2676
|
+
)
|
|
2276
2677
|
if not instance:
|
|
2277
2678
|
print(f"Instance {args.instance_id} not found.")
|
|
2278
2679
|
return
|
|
@@ -2286,10 +2687,13 @@ def main():
|
|
|
2286
2687
|
|
|
2287
2688
|
# Sync all training output files
|
|
2288
2689
|
rsync_cmd = [
|
|
2289
|
-
"rsync",
|
|
2290
|
-
"-
|
|
2690
|
+
"rsync",
|
|
2691
|
+
"-avz",
|
|
2692
|
+
"--progress",
|
|
2693
|
+
"-e",
|
|
2694
|
+
"ssh -o StrictHostKeyChecking=no",
|
|
2291
2695
|
f"ubuntu@{instance.ip}:~/openadapt-ml/training_output/",
|
|
2292
|
-
str(output_dir) + "/"
|
|
2696
|
+
str(output_dir) + "/",
|
|
2293
2697
|
]
|
|
2294
2698
|
result = subprocess.run(rsync_cmd, capture_output=False)
|
|
2295
2699
|
|
|
@@ -2303,6 +2707,7 @@ def main():
|
|
|
2303
2707
|
if log_path.exists():
|
|
2304
2708
|
try:
|
|
2305
2709
|
import time as time_module
|
|
2710
|
+
|
|
2306
2711
|
status = json.loads(log_path.read_text())
|
|
2307
2712
|
|
|
2308
2713
|
# Update with instance info
|
|
@@ -2336,7 +2741,7 @@ def main():
|
|
|
2336
2741
|
|
|
2337
2742
|
config = TrainingConfig(
|
|
2338
2743
|
num_train_epochs=status.get("total_epochs", 5),
|
|
2339
|
-
learning_rate=status.get("learning_rate", 5e-5)
|
|
2744
|
+
learning_rate=status.get("learning_rate", 5e-5),
|
|
2340
2745
|
)
|
|
2341
2746
|
|
|
2342
2747
|
dashboard_path.write_text(generate_training_dashboard(state, config))
|
|
@@ -2390,7 +2795,7 @@ def main():
|
|
|
2390
2795
|
# First try training log
|
|
2391
2796
|
log_data = json.loads((output_dir / "training_log.json").read_text())
|
|
2392
2797
|
capture_path = log_data.get("capture_path", "")
|
|
2393
|
-
capture_match = re.search(r
|
|
2798
|
+
capture_match = re.search(r"capture_(\d+)", capture_path)
|
|
2394
2799
|
if capture_match:
|
|
2395
2800
|
capture_id = capture_match.group(1)
|
|
2396
2801
|
|
|
@@ -2401,27 +2806,37 @@ def main():
|
|
|
2401
2806
|
base_data = pred_data.get("base_data", [])
|
|
2402
2807
|
if base_data:
|
|
2403
2808
|
image_path = base_data[0].get("image_path", "")
|
|
2404
|
-
capture_match = re.search(r
|
|
2809
|
+
capture_match = re.search(r"capture_(\d+)", image_path)
|
|
2405
2810
|
if capture_match:
|
|
2406
2811
|
capture_id = capture_match.group(1)
|
|
2407
2812
|
break
|
|
2408
2813
|
|
|
2409
2814
|
if capture_id:
|
|
2410
2815
|
# Search for local screenshots in openadapt-capture
|
|
2411
|
-
openadapt_capture_dir =
|
|
2816
|
+
openadapt_capture_dir = (
|
|
2817
|
+
Path.home() / "oa" / "src" / "openadapt-capture"
|
|
2818
|
+
)
|
|
2412
2819
|
if openadapt_capture_dir.exists():
|
|
2413
2820
|
for capture_dir in openadapt_capture_dir.iterdir():
|
|
2414
2821
|
if capture_dir.is_dir():
|
|
2415
2822
|
screenshots_dir = capture_dir / "screenshots"
|
|
2416
2823
|
if screenshots_dir.exists():
|
|
2417
2824
|
# Check if this capture has our screenshots
|
|
2418
|
-
sample_file = list(
|
|
2825
|
+
sample_file = list(
|
|
2826
|
+
screenshots_dir.glob(
|
|
2827
|
+
f"capture_{capture_id}_step_*.png"
|
|
2828
|
+
)
|
|
2829
|
+
)
|
|
2419
2830
|
if sample_file:
|
|
2420
|
-
print(
|
|
2831
|
+
print(
|
|
2832
|
+
f"Found local screenshots in {screenshots_dir}"
|
|
2833
|
+
)
|
|
2421
2834
|
screenshots_link.symlink_to(screenshots_dir)
|
|
2422
|
-
print(
|
|
2835
|
+
print(
|
|
2836
|
+
f" Linked: {screenshots_link} -> {screenshots_dir}"
|
|
2837
|
+
)
|
|
2423
2838
|
break
|
|
2424
|
-
except Exception
|
|
2839
|
+
except Exception:
|
|
2425
2840
|
pass # Silently continue if auto-link fails
|
|
2426
2841
|
|
|
2427
2842
|
print(f"Regenerating viewer from {output_dir}...")
|
|
@@ -2435,7 +2850,7 @@ def main():
|
|
|
2435
2850
|
target = output_dir / "viewer.html"
|
|
2436
2851
|
|
|
2437
2852
|
print(f"\nGenerated: {target.absolute()}")
|
|
2438
|
-
print(
|
|
2853
|
+
print("View with: uv run python -m openadapt_ml.cloud.lambda_labs serve --open")
|
|
2439
2854
|
|
|
2440
2855
|
if args.open:
|
|
2441
2856
|
subprocess.run(["open", str(target)], capture_output=True)
|