openadapt-ml 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -115
- openadapt_ml/benchmarks/agent.py +265 -421
- openadapt_ml/benchmarks/azure.py +28 -19
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1722 -4847
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +22 -5
- openadapt_ml/benchmarks/vm_monitor.py +530 -29
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +47 -53
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +21 -20
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +2038 -487
- openadapt_ml/cloud/ssh_tunnel.py +68 -26
- openadapt_ml/datasets/next_action.py +40 -30
- openadapt_ml/evals/grounding.py +8 -3
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +41 -26
- openadapt_ml/experiments/demo_prompt/format_demo.py +16 -6
- openadapt_ml/experiments/demo_prompt/run_experiment.py +26 -16
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/runner.py +29 -14
- openadapt_ml/export/parquet.py +36 -24
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +8 -6
- openadapt_ml/ingest/capture.py +25 -22
- openadapt_ml/ingest/loader.py +7 -4
- openadapt_ml/ingest/synthetic.py +189 -100
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/demo_retriever.py +50 -24
- openadapt_ml/retrieval/embeddings.py +9 -8
- openadapt_ml/retrieval/retriever.py +3 -1
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +18 -5
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +9 -0
- openadapt_ml/schema/converters.py +74 -27
- openadapt_ml/schema/episode.py +31 -18
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +85 -54
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +15 -9
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +3 -1
- openadapt_ml/scripts/train.py +21 -9
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +52 -41
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +143 -86
- openadapt_ml/training/trl_trainer.py +70 -21
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +215 -14
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/live_tracker.py +0 -180
- openadapt_ml/benchmarks/runner.py +0 -418
- openadapt_ml/benchmarks/waa.py +0 -761
- openadapt_ml/benchmarks/waa_live.py +0 -619
- openadapt_ml-0.2.0.dist-info/RECORD +0 -86
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.2.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/training/trainer.py
CHANGED

@@ -4,9 +4,9 @@ import json
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any,
+from typing import Any, Dict, List
 
-from openadapt_ml.schema import
+from openadapt_ml.schema import ActionType
 from openadapt_ml.training.shared_ui import (
     get_shared_header_css as _get_shared_header_css,
     generate_shared_header_html as _generate_shared_header_html,

@@ -108,9 +108,10 @@ class TrainingState:
 @dataclass
 class TrainingState:
     """Tracks training progress for visualization."""
+
     # Job identification
     job_id: str = field(default_factory=lambda: time.strftime("%Y%m%d_%H%M%S"))
-    hostname: str = field(default_factory=lambda: __import__(
+    hostname: str = field(default_factory=lambda: __import__("socket").gethostname())
     capture_path: str = ""
     config_path: str = ""
     goal: str = ""  # Task goal/description for the training run
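
A note on the `hostname` default: `__import__("socket")` inside the `default_factory` lambda defers the import to instantiation time instead of module import time. A minimal, self-contained sketch of the idiom (the `JobInfo` class is hypothetical, not part of the package):

```python
import time
from dataclasses import dataclass, field


@dataclass
class JobInfo:
    # Each factory runs when an instance is created, not at import time;
    # __import__("socket") keeps the dependency local to the lambda.
    job_id: str = field(default_factory=lambda: time.strftime("%Y%m%d_%H%M%S"))
    hostname: str = field(default_factory=lambda: __import__("socket").gethostname())


print(JobInfo())  # e.g. JobInfo(job_id='20250101_120000', hostname='build-host')
```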

@@ -142,7 +143,9 @@ class TrainingState:
     setup_status: str = ""  # e.g. "booting", "installing", "training", "complete"
     setup_logs: List[str] = field(default_factory=list)  # Setup progress messages
     # Termination tracking
-    termination_status: str =
+    termination_status: str = (
+        ""  # e.g. "auto_low_loss", "auto_complete", "user_stop", "running"
+    )
     termination_message: str = ""  # Human-readable termination reason
 
     def log_step(self, epoch: int, step: int, loss: float, lr: float = 0.0) -> None:

@@ -151,33 +154,46 @@ class TrainingState:
         self.step = step
         self.loss = loss
         self.learning_rate = lr
-        self.losses.append(
-
-
-
-
-
-
-
-
-
+        self.losses.append(
+            {
+                "epoch": epoch,
+                "step": step,
+                "loss": loss,
+                "lr": lr,
+                "time": time.time() - self.start_time,
+            }
+        )
+
+    def log_evaluation(
+        self,
+        epoch: int,
+        sample_idx: int,
+        image_path: str,
+        human_action: Dict,
+        predicted_action: Dict,
+    ) -> None:
         """Log an evaluation sample."""
         # Calculate distance for click actions
         distance = 0.0
-        if
+        if (
+            human_action.get("type") == "click"
+            and predicted_action.get("type") == "click"
+        ):
             hx, hy = human_action.get("x", 0), human_action.get("y", 0)
             px, py = predicted_action.get("x", 0), predicted_action.get("y", 0)
             distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
 
-        self.evaluations.append(
-
-
-
-
-
-
-
-
+        self.evaluations.append(
+            {
+                "epoch": epoch,
+                "sample_idx": sample_idx,
+                "image_path": image_path,
+                "human_action": human_action,
+                "predicted_action": predicted_action,
+                "distance": distance,
+                "correct": distance < 50,  # Within 50 pixels is "correct"
+            }
+        )
 
     def to_dict(self) -> Dict[str, Any]:
         """Convert state to serializable dict."""
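
The reformatted `log_step`/`log_evaluation` make the record schemas explicit: each loss record carries `epoch`/`step`/`loss`/`lr`/`time`, and each evaluation records the action pair, the Euclidean pixel distance, and a `correct` flag under the strict `distance < 50` rule. A sketch of consuming these records, with hypothetical sample data:

```python
# Sketch: aggregating the evaluation records shaped as above. The list here
# is hypothetical sample data; the 50-pixel rule matches `distance < 50`.
evaluations = [
    {
        "human_action": {"type": "click", "x": 100, "y": 100},
        "predicted_action": {"type": "click", "x": 130, "y": 140},
        "distance": 50.0,
        "correct": False,  # exactly 50 px fails the strict `distance < 50` check
    },
    {
        "human_action": {"type": "click", "x": 10, "y": 10},
        "predicted_action": {"type": "click", "x": 12, "y": 14},
        "distance": 4.47,
        "correct": True,
    },
]
accuracy = sum(e["correct"] for e in evaluations) / len(evaluations)
print(f"click accuracy: {accuracy:.0%}")  # click accuracy: 50%
```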

@@ -195,7 +211,9 @@ class TrainingState:
             "load_in_4bit": self.load_in_4bit,
             "instance_type": self.instance_type,
             "instance_ip": self.instance_ip,
-            "started_at": time.strftime(
+            "started_at": time.strftime(
+                "%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)
+            ),
             # Cloud provider info
             "cloud_provider": self.cloud_provider,
             "cloud_dashboard_url": self.cloud_dashboard_url,
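
`started_at` is now serialized as a sortable, timezone-explicit ISO-8601 UTC string. A quick standalone check of the exact format string used above:

```python
import time

# Same strftime pattern as in the hunk above, applied to the current time.
started_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(time.time()))
print(started_at)  # e.g. 2025-01-01T12:00:00Z
```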

@@ -316,6 +334,7 @@ class TrainingLogger:
     def _save_config_snapshot(self) -> None:
         """Save training config snapshot to JSON."""
         from dataclasses import asdict
+
         config_file = self.output_dir / "config.json"
         config_dict = asdict(self.config)
         with open(config_file, "w") as f:

@@ -333,32 +352,45 @@ class TrainingLogger:
         dashboard_path.write_text(html)
 
 
-def _generate_termination_status_html(
+def _generate_termination_status_html(
+    state: TrainingState, is_training_complete: bool
+) -> str:
     """Generate HTML for termination status section."""
     # Check if we have termination info
     if state.termination_status:
         # Map termination status to colors and icons
         status_styles = {
-            "auto_complete": {
-
+            "auto_complete": {
+                "color": "#22c55e",
+                "icon": "✓",
+                "label": "Training Complete",
+            },
+            "auto_low_loss": {
+                "color": "#22c55e",
+                "icon": "✓",
+                "label": "Auto-Stopped (Low Loss)",
+            },
             "user_stop": {"color": "#f59e0b", "icon": "■", "label": "Stopped by User"},
         }
-        style = status_styles.get(
+        style = status_styles.get(
+            state.termination_status,
+            {"color": "#22c55e", "icon": "✓", "label": "Complete"},
+        )
 
-        return f
-        <div style="display: flex; align-items: center; gap: 8px; color: {style[
-        <span style="font-size: 1.2rem;">{style[
-        <span style="font-weight: 600;">{style[
+        return f"""<div style="display: flex; flex-direction: column; gap: 8px;">
+        <div style="display: flex; align-items: center; gap: 8px; color: {style["color"]};">
+            <span style="font-size: 1.2rem;">{style["icon"]}</span>
+            <span style="font-weight: 600;">{style["label"]}</span>
         </div>
-        {f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else
-        </div>
+        {f'<div style="font-size: 0.85rem; color: var(--text-muted); margin-left: 28px;">{state.termination_message}</div>' if state.termination_message else ""}
+        </div>"""
     elif is_training_complete:
-        return
+        return """<div style="display: flex; align-items: center; gap: 8px; color: #22c55e;">
         <span style="font-size: 1.2rem;">✓</span>
         <span style="font-weight: 600;">Training Complete</span>
-        </div>
+        </div>"""
     else:
-        return
+        return """<button id="stop-training-btn" onclick="stopTraining()" style="
         background: linear-gradient(135deg, #ef4444 0%, #dc2626 100%);
         color: white;
         border: none;

@@ -374,53 +406,65 @@ def _generate_termination_status_html(state: TrainingState, is_training_complete: bool) -> str:
         ">
         <span style="font-size: 1.1rem;">■</span> Stop Training
         </button>
-        <p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>
+        <p id="stop-status" style="margin-top: 8px; font-size: 0.75rem; color: var(--text-muted);"></p>"""
 
 
 def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
     """Generate an HTML dashboard for training visualization."""
     losses_json = json.dumps(state.losses)
     # Use stored elapsed_time if available (historical data), otherwise calculate
-    elapsed =
+    elapsed = (
+        state.elapsed_time if state.elapsed_time > 0 else time.time() - state.start_time
+    )
     elapsed_str = f"{int(elapsed // 60)}m {int(elapsed % 60)}s"
 
     # Calculate stats
     if state.losses:
-        min_loss = min(
-
+        min_loss = min(loss["loss"] for loss in state.losses)
+        sum(loss["loss"] for loss in state.losses) / len(state.losses)
         recent_losses = state.losses[-10:] if len(state.losses) >= 10 else state.losses
-        recent_avg = sum(
+        recent_avg = sum(loss["loss"] for loss in recent_losses) / len(recent_losses)
         # Calculate step times
         step_times = []
         for i in range(1, len(state.losses)):
-            step_times.append(state.losses[i]["time"] - state.losses[i-1]["time"])
+            step_times.append(state.losses[i]["time"] - state.losses[i - 1]["time"])
         avg_step_time = sum(step_times) / len(step_times) if step_times else 0
         # Loss by epoch
         epoch_losses: dict = {}
-        for
-            ep =
+        for loss in state.losses:
+            ep = loss["epoch"]
            if ep not in epoch_losses:
                epoch_losses[ep] = []
-            epoch_losses[ep].append(
-        epoch_avg = {
+            epoch_losses[ep].append(loss["loss"])
+        epoch_avg = {
+            ep: sum(losses) / len(losses) for ep, losses in epoch_losses.items()
+        }
         # Estimate ETA
         # Steps per epoch = steps in completed epochs / completed epochs
         completed_epochs = state.epoch
-        steps_in_completed = sum(
+        steps_in_completed = sum(
+            1 for loss in state.losses if loss["epoch"] < completed_epochs
+        )
         if completed_epochs > 0 and steps_in_completed > 0:
             steps_per_epoch = steps_in_completed / completed_epochs
         else:
             # Estimate from current epoch progress
-            steps_per_epoch =
-
-
+            steps_per_epoch = (
+                len(state.losses) / (state.epoch + 1)
+                if state.epoch >= 0
+                else len(state.losses)
+            )
+
+        total_epochs = (
+            state.total_epochs if state.total_epochs > 0 else config.num_train_epochs
+        )
         total_steps_estimate = steps_per_epoch * total_epochs
         remaining_steps = max(0, total_steps_estimate - len(state.losses))
         eta_seconds = remaining_steps * avg_step_time if avg_step_time > 0 else 0
         # Check if training is complete (all steps done)
         is_training_complete = remaining_steps == 0 and len(state.losses) > 0
     else:
-        min_loss =
+        min_loss = recent_avg = avg_step_time = 0.0
         epoch_avg = {}
         eta_seconds = 0
         steps_per_epoch = 0
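
The ETA block estimates steps-per-epoch from fully completed epochs, falls back to progress within the current epoch, then multiplies the remaining step count by the measured average step time. A worked sketch of the same arithmetic with hypothetical numbers:

```python
# Worked sketch of the ETA estimate above: 40 logged steps in epoch 0
# (complete), 10 steps so far in epoch 1, 3 total epochs planned.
losses = [{"epoch": e, "loss": 1.0} for e in [0] * 40 + [1] * 10]
completed_epochs = 1
steps_in_completed = sum(1 for rec in losses if rec["epoch"] < completed_epochs)  # 40
steps_per_epoch = steps_in_completed / completed_epochs  # 40.0
avg_step_time = 2.5  # seconds, as averaged from consecutive "time" deltas
total_epochs = 3
remaining_steps = max(0, steps_per_epoch * total_epochs - len(losses))  # 120 - 50 = 70
print(f"ETA ~ {remaining_steps * avg_step_time:.0f}s")  # ETA ~ 175s
```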

@@ -431,10 +475,9 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
     epoch_avg_json = json.dumps(list(epoch_avg.items()))
 
     # Generate comparison viewer preview if capture path available
-    comparison_viewer_path = ""
     if state.capture_path:
         try:
-            from openadapt_ml.scripts.compare import generate_comparison_html
+            from openadapt_ml.scripts.compare import generate_comparison_html
             from openadapt_ml.ingest.capture import capture_to_episode
 
             capture_path = Path(state.capture_path)

@@ -454,7 +497,9 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
                 "time": step.step_index,
                 "image_path": step.observation.screenshot_path,
                 "human_action": {
-                    "type": step.action.type.value
+                    "type": step.action.type.value
+                    if isinstance(step.action.type, ActionType)
+                    else step.action.type,
                     "x": action_x,
                     "y": action_y,
                     "text": step.action.text,
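
The new conditional tolerates `step.action.type` arriving either as an `ActionType` enum member or as a plain string. A generic sketch of that normalization pattern (the enum below is a stand-in, not the real `openadapt_ml.schema.ActionType`):

```python
from enum import Enum


class ActionType(Enum):  # stand-in for openadapt_ml.schema.ActionType
    CLICK = "click"


def normalize(action_type) -> str:
    # Accept either the enum member or an already-plain string.
    return action_type.value if isinstance(action_type, ActionType) else action_type


print(normalize(ActionType.CLICK), normalize("scroll"))  # click scroll
```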

@@ -465,15 +510,21 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
             comparison_data.append(step_data)
 
             # Generate comparison HTML
-            output_dir =
+            output_dir = (
+                Path(config.output_dir)
+                if hasattr(config, "output_dir")
+                else Path("training_output")
+            )
             output_dir.mkdir(parents=True, exist_ok=True)
             comparison_output = output_dir / "comparison_preview.html"
-            generate_comparison_html(
-
-
+            generate_comparison_html(
+                capture_path, episode, comparison_data, comparison_output
+            )
+            str(comparison_output.name)  # Relative path
+        except Exception:
             pass  # Fail silently if comparison viewer can't be generated
 
-    html = f
+    html = f"""<!DOCTYPE html>
 <html lang="en">
 <head>
     <meta charset="UTF-8">

@@ -1132,10 +1183,10 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
     <div class="container">
         <header>
             <div>
-                <h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else
+                <h1>Training Dashboard{f' <a href="{state.cloud_dashboard_url}" target="_blank" class="cloud-link cloud-badge"><svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M18 10h-1.26A8 8 0 1 0 9 20h9a5 5 0 0 0 0-10z"/></svg>{state.cloud_provider.title()} Cloud</a>' if state.cloud_dashboard_url else ""}</h1>
                 <div class="job-info" id="job-info">
-                    <span class="job-host">{state.hostname or
-                    {f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else
+                    <span class="job-host">{state.hostname or "stub-local"} @ {state.instance_ip or "127.0.0.1"}</span>
+                    {f'<span class="job-config">{state.instance_type}</span>' if state.instance_type else ""}
                 </div>
             </div>
             <div class="status" id="status">

@@ -1144,13 +1195,13 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
             </div>
         </header>
 
-        <div class="setup-panel{
+        <div class="setup-panel{" hidden" if not state.setup_logs else ""}" id="setup-panel">
             <div class="setup-header">
                 <h2>Setup Progress</h2>
-                <span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or
+                <span class="setup-status-badge {state.setup_status}" id="setup-status-badge">{state.setup_status or "initializing"}</span>
             </div>
             <div class="setup-logs" id="setup-logs">
-                {
+                {"".join(f'<div class="setup-log-line{" current" if i == len(state.setup_logs) - 1 else ""}">{log}</div>' for i, log in enumerate(state.setup_logs)) if state.setup_logs else '<div class="setup-log-line">Waiting for setup logs...</div>'}
             </div>
         </div>
 

@@ -1160,23 +1211,23 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
             <div class="config-grid">
                 <div class="config-item">
                     <span class="config-label">Model</span>
-                    <span class="config-value model" id="config-model">{state.model_name or
+                    <span class="config-value model" id="config-model">{state.model_name or "Not specified"}</span>
                 </div>
                 <div class="config-item">
                     <span class="config-label">Goal</span>
-                    <span class="config-value goal" id="config-goal">{state.goal or
+                    <span class="config-value goal" id="config-goal">{state.goal or "Not specified"}</span>
                 </div>
                 <div class="config-item">
                     <span class="config-label">LoRA</span>
-                    <span class="config-value" id="config-lora">{f
+                    <span class="config-value" id="config-lora">{f"r={state.lora_r}, α={state.lora_alpha}" if state.lora_r else "Not specified"}</span>
                 </div>
                 <div class="config-item">
                     <span class="config-label">Quantization</span>
-                    <span class="config-value" id="config-quant">{
+                    <span class="config-value" id="config-quant">{"4-bit" if state.load_in_4bit else "None"}</span>
                 </div>
                 <div class="config-item">
                     <span class="config-label">Config</span>
-                    <span class="config-value" id="config-path">{state.config_path or
+                    <span class="config-value" id="config-path">{state.config_path or "Not specified"}</span>
                 </div>
             </div>
         </div>

@@ -1523,7 +1574,7 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
         let etaSeconds = {eta_seconds};
         let avgStepTime = {avg_step_time};
         let remainingSteps = {remaining_steps};
-        let isTrainingComplete = {
+        let isTrainingComplete = {"true" if is_training_complete else "false"};
 
         // Auto-stop when loss <= threshold (INVARIANT: training should stop when loss <= 1.0)
         const AUTO_STOP_LOSS_THRESHOLD = 1.0;
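
`is_training_complete` is embedded into the page as a JavaScript literal, and Python's `True`/`False` do not match JavaScript's lowercase booleans, hence the explicit string selection. A standalone check of the pattern:

```python
# Python bool -> JavaScript literal, as in the hunk above.
is_training_complete = True
line = f'let isTrainingComplete = {"true" if is_training_complete else "false"};'
print(line)  # let isTrainingComplete = true;
```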

@@ -2022,7 +2073,7 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) -> str:
         setInterval(updateStatusIndicator, 1000); // Update LIVE/STALE indicator every second
     </script>
 </body>
-</html>
+</html>"""
     return html
 
 

@@ -2060,6 +2111,7 @@ def regenerate_all_dashboards(output_dir: str | Path) -> list[Path]:
     except Exception as e:
         print(f"Warning: Failed to generate unified viewer: {e}")
         import traceback
+
         traceback.print_exc()
 
     return regenerated

@@ -2102,7 +2154,9 @@ def regenerate_local_dashboard(
     state = TrainingState(
         job_id=data.get("job_id", "unknown"),
         hostname=data.get("hostname", ""),
-        capture_path=str(capture_path)
+        capture_path=str(capture_path)
+        if capture_path
+        else data.get("capture_path", ""),
         config_path=data.get("config_path", ""),
         epoch=data.get("epoch", 0),
         step=data.get("step", 0),

@@ -2145,30 +2199,33 @@ def regenerate_local_dashboard(
     if training_status == "COMPLETED":
         html = html.replace(
             '<div class="status" id="status">',
-            '<div class="status complete" id="status">'
+            '<div class="status complete" id="status">',
         )
         html = html.replace(
             '<span id="status-text">Training in progress</span>',
-            '<span id="status-text">COMPLETED</span>'
+            '<span id="status-text">COMPLETED</span>',
         )
     elif training_status == "STOPPED":
         html = html.replace(
-            '<div class="status" id="status">',
-            '<div class="status stale" id="status">'
+            '<div class="status" id="status">', '<div class="status stale" id="status">'
         )
         html = html.replace(
             '<span id="status-text">Training in progress</span>',
-            '<span id="status-text">STOPPED (Epoch {}/{})'.format(
+            '<span id="status-text">STOPPED (Epoch {}/{})'.format(
+                current_epoch + 1, total_epochs
+            )
+            + "</span>",
         )
 
     # Fix ETA display for completed/stopped training
     import re
+
     if training_status in ("COMPLETED", "STOPPED"):
         # Replace "calculating..." with appropriate status
         html = re.sub(
             r'(<div class="stat-value" id="stat-eta">)[^<]*(</div>)',
-            r
-            html
+            r"\1—\2" if training_status == "STOPPED" else r"\1complete\2",
+            html,
         )
 
     # Replace dynamic nav with static unified header
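
The static regeneration path patches already-rendered HTML with `str.replace` and `re.sub`; the ETA substitution uses backreferences so only the cell's inner text changes while the surrounding tags survive. A runnable sketch of that exact pattern:

```python
import re

# \1 and \2 capture the opening and closing tags; only the inner text changes.
html = '<div class="stat-value" id="stat-eta">calculating...</div>'
patched = re.sub(
    r'(<div class="stat-value" id="stat-eta">)[^<]*(</div>)',
    r"\1complete\2",
    html,
)
print(patched)  # <div class="stat-value" id="stat-eta">complete</div>
```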

@@ -2179,20 +2236,20 @@ def regenerate_local_dashboard(
     # This is critical for file:// protocol where fetch() doesn't work
     html = html.replace(
         "setInterval(fetchAndUpdate, 3000);",
-        "// fetchAndUpdate disabled for static dashboard"
+        "// fetchAndUpdate disabled for static dashboard",
     )
     html = html.replace(
         "setInterval(updateElapsedDisplay, 1000);",
-        "// updateElapsedDisplay disabled for static dashboard"
+        "// updateElapsedDisplay disabled for static dashboard",
     )
     html = html.replace(
         "setInterval(updateStatusIndicator, 1000);",
-        "// updateStatusIndicator disabled for static dashboard"
+        "// updateStatusIndicator disabled for static dashboard",
     )
     # CRITICAL: Disable discoverDashboards() - it overwrites static nav on file:// protocol
     html = html.replace(
         "discoverDashboards();",
-        "// discoverDashboards disabled - using static nav for file:// protocol"
+        "// discoverDashboards disabled - using static nav for file:// protocol",
     )
 
     # Write output

openadapt_ml/training/trl_trainer.py
CHANGED

@@ -91,7 +91,9 @@ def _load_unsloth_model(config: TRLTrainingConfig):
         # Enable training mode
         FastVisionModel.for_training(model)
 
-        print(
+        print(
+            f"✓ Loaded {config.model_name} with Unsloth (4-bit: {config.load_in_4bit})"
+        )
         return model, tokenizer, True
 
     except ImportError:
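
The `except ImportError` here implements an optional-dependency fallback: use Unsloth when it imports cleanly, otherwise drop back to standard transformers + peft. A minimal standalone sketch of the pattern (the function name is hypothetical; `unsloth` is the package the trainer prefers):

```python
def pick_backend() -> str:
    # Prefer Unsloth when installed; otherwise fall back to transformers + peft,
    # mirroring the ImportError handling in the loader above.
    try:
        import unsloth  # noqa: F401

        return "unsloth"
    except ImportError:
        return "transformers+peft"


print(pick_backend())
```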

@@ -100,26 +102,70 @@
 
 
 def _load_standard_model(config: TRLTrainingConfig):
-    """Fallback: Load model with standard transformers + peft.
-
+    """Fallback: Load model with standard transformers + peft.
+
+    Automatically detects vision-language models and uses the appropriate
+    model class (Qwen2VLForConditionalGeneration for VL models,
+    AutoModelForCausalLM for text-only models).
+    """
+    from transformers import AutoConfig, AutoProcessor
     from peft import LoraConfig, get_peft_model
     import torch
 
-
-
-
-
-
+    # Check if this is a vision-language model
+    model_config = AutoConfig.from_pretrained(
+        config.model_name, trust_remote_code=True
+    )
+    is_vl_model = (
+        "VL" in config.model_name.upper()
+        or "vision" in config.model_name.lower()
+        or hasattr(model_config, "vision_config")
     )
+
+    if is_vl_model:
+        # Vision-language model - use Qwen2VLForConditionalGeneration or AutoModelForVision2Seq
+        try:
+            from transformers import Qwen2VLForConditionalGeneration
+
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                config.model_name,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+            )
+            print("  Using Qwen2VLForConditionalGeneration for VL model")
+        except (ImportError, ValueError, RuntimeError, TypeError):
+            # Fallback to AutoModelForVision2Seq for other VL models
+            from transformers import AutoModelForVision2Seq
+
+            model = AutoModelForVision2Seq.from_pretrained(
+                config.model_name,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+            )
+            print("  Using AutoModelForVision2Seq for VL model")
+    else:
+        # Text-only model
+        from transformers import AutoModelForCausalLM
+
+        model = AutoModelForCausalLM.from_pretrained(
+            config.model_name,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        print("  Using AutoModelForCausalLM for text-only model")
+
     processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True)
 
-    # Apply LoRA
+    # Apply LoRA - use SEQ_2_SEQ_LM for VL models, CAUSAL_LM for text-only
     peft_config = LoraConfig(
         r=config.lora_r,
         lora_alpha=config.lora_alpha,
         lora_dropout=config.lora_dropout,
         target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
-        task_type="CAUSAL_LM",
+        task_type="SEQ_2_SEQ_LM" if is_vl_model else "CAUSAL_LM",
     )
     model = get_peft_model(model, peft_config)
 
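
The VL detection combines model-name heuristics with a probe for `vision_config` on the Hugging Face config object. A standalone sketch of the same heuristic, with the config download stubbed out so it runs offline:

```python
def looks_like_vl_model(model_name: str, model_config=None) -> bool:
    # Mirrors the heuristic above: name hints first, then a probe for the
    # HF config's vision_config attribute when a config object is supplied.
    return (
        "VL" in model_name.upper()
        or "vision" in model_name.lower()
        or hasattr(model_config, "vision_config")
    )


print(looks_like_vl_model("unsloth/Qwen2.5-VL-7B-Instruct"))  # True
print(looks_like_vl_model("meta-llama/Llama-3.1-8B"))         # False
```

The same flag then selects the LoRA `task_type` (`SEQ_2_SEQ_LM` for VL models, `CAUSAL_LM` otherwise), so the detection and the adapter configuration stay consistent.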

@@ -161,10 +207,12 @@ def _convert_samples_to_trl_format(
         if not pil_images:
             continue  # Skip samples with missing images
 
-        trl_samples.append(
-
-
-
+        trl_samples.append(
+            {
+                "images": pil_images,
+                "messages": sample["messages"],
+            }
+        )
 
     return trl_samples
 
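
`_convert_samples_to_trl_format` emits dicts with `images` (PIL images) and `messages` (chat-format turns). A hypothetical conforming sample follows; the chat-content layout is an assumption based on TRL's multimodal message format, not something this diff specifies:

```python
from PIL import Image

# Hypothetical sample in the {"images": ..., "messages": ...} shape above.
sample = {
    "images": [Image.new("RGB", (64, 64))],
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Click the Submit button."},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "click(x=32, y=48)"}],
        },
    ],
}
print(len(sample["images"]), sample["messages"][1]["role"])  # 1 assistant
```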

@@ -261,7 +309,7 @@ def train_with_trl(
         logging_steps=config.logging_steps,
         save_strategy=config.save_strategy,
         max_length=None,  # Critical for VLMs
-        assistant_only_loss=
+        assistant_only_loss=False,  # Not supported for VL models yet
     )
 
     trainer = SFTTrainer(
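
Two VLM-specific settings are pinned here: `max_length=None` avoids truncating multimodal sequences, and `assistant_only_loss=False` because assistant-only loss masking is not yet supported for VL models. A minimal sketch of the config, assuming a TRL version that exposes both fields (as this diff's code does); the remaining arguments are illustrative:

```python
from trl import SFTConfig

# Minimal sketch around the two VLM-specific settings pinned above.
training_args = SFTConfig(
    output_dir="checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    max_length=None,  # Critical for VLMs: avoid truncating multimodal sequences
    assistant_only_loss=False,  # Not supported for VL models yet
)
print(training_args.max_length, training_args.assistant_only_loss)
```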

@@ -270,15 +318,15 @@ def train_with_trl(
         args=training_args,
     )
 
-    print(f"\n{'='*50}")
-    print(
+    print(f"\n{'=' * 50}")
+    print("Starting training:")
     print(f"  Model: {config.model_name}")
     print(f"  Samples: {len(trl_samples)}")
     print(f"  Epochs: {config.num_epochs}")
     print(f"  Batch size: {config.batch_size}")
     print(f"  Unsloth: {is_unsloth}")
     print(f"  Output: {config.output_dir}")
-    print(f"{'='*50}\n")
+    print(f"{'=' * 50}\n")
 
     trainer.train()
 

@@ -291,8 +339,7 @@ def train_with_trl(
 
     except ImportError as e:
         raise ImportError(
-            f"TRL not installed. Install with: pip install trl\
-            f"Original error: {e}"
+            f"TRL not installed. Install with: pip install trl\nOriginal error: {e}"
         )
 
 

@@ -333,7 +380,9 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train VLM with TRL + Unsloth")
     parser.add_argument("--parquet", required=True, help="Path to parquet file")
     parser.add_argument("--output", default="checkpoints", help="Output directory")
-    parser.add_argument(
+    parser.add_argument(
+        "--model", default="unsloth/Qwen2.5-VL-7B-Instruct", help="Model name"
+    )
     parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
     parser.add_argument("--use-som", action="store_true", help="Use Set-of-Marks DSL")
 
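
For reference, an in-process mirror of the CLI defined above; the parquet filename is hypothetical:

```python
import argparse

# Same flags as the trainer's __main__ block; parsed in-process for illustration.
parser = argparse.ArgumentParser(description="Train VLM with TRL + Unsloth")
parser.add_argument("--parquet", required=True, help="Path to parquet file")
parser.add_argument("--output", default="checkpoints", help="Output directory")
parser.add_argument("--model", default="unsloth/Qwen2.5-VL-7B-Instruct", help="Model name")
parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
parser.add_argument("--use-som", action="store_true", help="Use Set-of-Marks DSL")

args = parser.parse_args(["--parquet", "episodes.parquet", "--use-som"])
print(args.model, args.epochs, args.use_som)
# unsloth/Qwen2.5-VL-7B-Instruct 3 True
```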