openadapt-ml 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openadapt_ml/benchmarks/__init__.py +8 -0
- openadapt_ml/benchmarks/agent.py +90 -11
- openadapt_ml/benchmarks/azure.py +35 -6
- openadapt_ml/benchmarks/cli.py +4449 -201
- openadapt_ml/benchmarks/live_tracker.py +180 -0
- openadapt_ml/benchmarks/runner.py +41 -4
- openadapt_ml/benchmarks/viewer.py +1219 -0
- openadapt_ml/benchmarks/vm_monitor.py +610 -0
- openadapt_ml/benchmarks/waa.py +61 -4
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +222 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +539 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/benchmarks/waa_live.py +619 -0
- openadapt_ml/cloud/local.py +1555 -1
- openadapt_ml/cloud/ssh_tunnel.py +553 -0
- openadapt_ml/datasets/next_action.py +87 -68
- openadapt_ml/evals/grounding.py +26 -8
- openadapt_ml/evals/trajectory_matching.py +84 -36
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +226 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +531 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +717 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +265 -0
- openadapt_ml/ingest/__init__.py +3 -4
- openadapt_ml/ingest/capture.py +89 -81
- openadapt_ml/ingest/loader.py +116 -68
- openadapt_ml/ingest/synthetic.py +221 -159
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +817 -0
- openadapt_ml/retrieval/embeddings.py +629 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +160 -0
- openadapt_ml/runtime/policy.py +10 -10
- openadapt_ml/schema/__init__.py +104 -0
- openadapt_ml/schema/converters.py +541 -0
- openadapt_ml/schema/episode.py +457 -0
- openadapt_ml/scripts/compare.py +26 -16
- openadapt_ml/scripts/eval_policy.py +4 -5
- openadapt_ml/scripts/prepare_synthetic.py +14 -17
- openadapt_ml/scripts/train.py +81 -70
- openadapt_ml/training/benchmark_viewer.py +3225 -0
- openadapt_ml/training/trainer.py +120 -363
- openadapt_ml/training/trl_trainer.py +354 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/METADATA +102 -60
- openadapt_ml-0.2.0.dist-info/RECORD +86 -0
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.0.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/training/trainer.py
CHANGED
|
@@ -6,13 +6,7 @@ from dataclasses import dataclass, field
|
|
|
6
6
|
from pathlib import Path
|
|
7
7
|
from typing import Any, Callable, Dict, List, Optional
|
|
8
8
|
|
|
9
|
-
import torch
|
|
10
|
-
from torch.optim import Optimizer
|
|
11
|
-
from torch.optim.lr_scheduler import LambdaLR
|
|
12
|
-
from torch.utils.data import DataLoader, Dataset
|
|
13
|
-
|
|
14
|
-
from openadapt_ml.models.base_adapter import BaseVLMAdapter
|
|
15
|
-
from openadapt_ml.schemas.sessions import Episode
|
|
9
|
+
from openadapt_ml.schema import Episode, Step, Action, ActionType
|
|
16
10
|
from openadapt_ml.training.shared_ui import (
|
|
17
11
|
get_shared_header_css as _get_shared_header_css,
|
|
18
12
|
generate_shared_header_html as _generate_shared_header_html,
|
|
@@ -21,6 +15,10 @@ from openadapt_ml.training.shared_ui import (
|
|
|
21
15
|
from openadapt_ml.training.viewer import (
|
|
22
16
|
generate_unified_viewer_from_output_dir,
|
|
23
17
|
)
|
|
18
|
+
from openadapt_ml.training.benchmark_viewer import (
|
|
19
|
+
_get_azure_jobs_panel_css,
|
|
20
|
+
_get_azure_jobs_panel_html,
|
|
21
|
+
)
|
|
24
22
|
|
|
25
23
|
|
|
26
24
|
def setup_job_directory(base_dir: str | Path, job_id: str) -> Path:
|
|
@@ -116,6 +114,11 @@ class TrainingState:
|
|
|
116
114
|
capture_path: str = ""
|
|
117
115
|
config_path: str = ""
|
|
118
116
|
goal: str = "" # Task goal/description for the training run
|
|
117
|
+
# Model configuration
|
|
118
|
+
model_name: str = "" # e.g. "Qwen/Qwen3-VL-2B-Instruct"
|
|
119
|
+
lora_r: int = 0 # LoRA rank
|
|
120
|
+
lora_alpha: int = 0 # LoRA alpha
|
|
121
|
+
load_in_4bit: bool = False # Quantization
|
|
119
122
|
# Training progress
|
|
120
123
|
epoch: int = 0
|
|
121
124
|
step: int = 0
|
|
@@ -185,6 +188,11 @@ class TrainingState:
|
|
|
185
188
|
"capture_path": self.capture_path,
|
|
186
189
|
"config_path": self.config_path,
|
|
187
190
|
"goal": self.goal,
|
|
191
|
+
# Model configuration
|
|
192
|
+
"model_name": self.model_name,
|
|
193
|
+
"lora_r": self.lora_r,
|
|
194
|
+
"lora_alpha": self.lora_alpha,
|
|
195
|
+
"load_in_4bit": self.load_in_4bit,
|
|
188
196
|
"instance_type": self.instance_type,
|
|
189
197
|
"instance_ip": self.instance_ip,
|
|
190
198
|
"started_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(self.start_time)),
|
|
@@ -227,6 +235,11 @@ class TrainingLogger:
|
|
|
227
235
|
cloud_dashboard_url: str = "",
|
|
228
236
|
cloud_instance_id: str = "",
|
|
229
237
|
job_id: str = "",
|
|
238
|
+
# Model configuration
|
|
239
|
+
model_name: str = "",
|
|
240
|
+
lora_r: int = 0,
|
|
241
|
+
lora_alpha: int = 0,
|
|
242
|
+
load_in_4bit: bool = False,
|
|
230
243
|
):
|
|
231
244
|
# Generate job_id if not provided
|
|
232
245
|
if not job_id:
|
|
@@ -242,6 +255,10 @@ class TrainingLogger:
|
|
|
242
255
|
capture_path=capture_path,
|
|
243
256
|
config_path=config_path,
|
|
244
257
|
goal=goal,
|
|
258
|
+
model_name=model_name,
|
|
259
|
+
lora_r=lora_r,
|
|
260
|
+
lora_alpha=lora_alpha,
|
|
261
|
+
load_in_4bit=load_in_4bit,
|
|
245
262
|
instance_ip=instance_ip,
|
|
246
263
|
instance_type=instance_type,
|
|
247
264
|
total_epochs=config.num_train_epochs,
|
|
@@ -428,14 +445,18 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
428
445
|
# Generate comparison data with null predictions (shows "— No prediction")
|
|
429
446
|
comparison_data = []
|
|
430
447
|
for i, step in enumerate(episode.steps):
|
|
448
|
+
# Extract normalized coordinates if available
|
|
449
|
+
action_x, action_y = None, None
|
|
450
|
+
if step.action.normalized_coordinates:
|
|
451
|
+
action_x, action_y = step.action.normalized_coordinates
|
|
431
452
|
step_data = {
|
|
432
453
|
"index": i,
|
|
433
|
-
"time": step.t,
|
|
434
|
-
"image_path": step.observation.image_path,
|
|
454
|
+
"time": step.step_index,
|
|
455
|
+
"image_path": step.observation.screenshot_path,
|
|
435
456
|
"human_action": {
|
|
436
|
-
"type": step.action.type,
|
|
437
|
-
"x": step.action.x,
|
|
438
|
-
"y": step.action.y,
|
|
457
|
+
"type": step.action.type.value if isinstance(step.action.type, ActionType) else step.action.type,
|
|
458
|
+
"x": action_x,
|
|
459
|
+
"y": action_y,
|
|
439
460
|
"text": step.action.text,
|
|
440
461
|
},
|
|
441
462
|
"predicted_action": None, # Shows "— No prediction" in viewer
|
|
@@ -596,6 +617,42 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
596
617
|
.setup-log-line.current {{
|
|
597
618
|
color: var(--accent);
|
|
598
619
|
}}
|
|
620
|
+
.config-panel {{
|
|
621
|
+
background: var(--bg-secondary);
|
|
622
|
+
border: 1px solid var(--border-color);
|
|
623
|
+
border-radius: 12px;
|
|
624
|
+
padding: 16px 20px;
|
|
625
|
+
margin-bottom: 24px;
|
|
626
|
+
}}
|
|
627
|
+
.config-grid {{
|
|
628
|
+
display: grid;
|
|
629
|
+
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
|
630
|
+
gap: 16px;
|
|
631
|
+
}}
|
|
632
|
+
.config-item {{
|
|
633
|
+
display: flex;
|
|
634
|
+
flex-direction: column;
|
|
635
|
+
gap: 4px;
|
|
636
|
+
}}
|
|
637
|
+
.config-label {{
|
|
638
|
+
font-size: 0.7rem;
|
|
639
|
+
color: var(--text-secondary);
|
|
640
|
+
text-transform: uppercase;
|
|
641
|
+
letter-spacing: 0.5px;
|
|
642
|
+
}}
|
|
643
|
+
.config-value {{
|
|
644
|
+
font-family: "SF Mono", Monaco, monospace;
|
|
645
|
+
font-size: 0.85rem;
|
|
646
|
+
color: var(--text-primary);
|
|
647
|
+
}}
|
|
648
|
+
.config-value.model {{
|
|
649
|
+
color: var(--accent);
|
|
650
|
+
}}
|
|
651
|
+
.config-value.goal {{
|
|
652
|
+
font-family: -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
|
|
653
|
+
font-size: 0.8rem;
|
|
654
|
+
opacity: 0.9;
|
|
655
|
+
}}
|
|
599
656
|
.status {{
|
|
600
657
|
display: flex;
|
|
601
658
|
align-items: center;
|
|
@@ -754,6 +811,8 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
754
811
|
}}
|
|
755
812
|
/* Shared header styles (injected from _get_shared_header_css) */
|
|
756
813
|
{_get_shared_header_css()}
|
|
814
|
+
/* Azure ML Jobs panel styles */
|
|
815
|
+
{_get_azure_jobs_panel_css()}
|
|
757
816
|
.eval-panel {{
|
|
758
817
|
background: var(--bg-secondary);
|
|
759
818
|
border: 1px solid var(--border-color);
|
|
@@ -1095,6 +1154,33 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
1095
1154
|
</div>
|
|
1096
1155
|
</div>
|
|
1097
1156
|
|
|
1157
|
+
{_get_azure_jobs_panel_html()}
|
|
1158
|
+
|
|
1159
|
+
<div class="config-panel" id="config-panel">
|
|
1160
|
+
<div class="config-grid">
|
|
1161
|
+
<div class="config-item">
|
|
1162
|
+
<span class="config-label">Model</span>
|
|
1163
|
+
<span class="config-value model" id="config-model">{state.model_name or 'Not specified'}</span>
|
|
1164
|
+
</div>
|
|
1165
|
+
<div class="config-item">
|
|
1166
|
+
<span class="config-label">Goal</span>
|
|
1167
|
+
<span class="config-value goal" id="config-goal">{state.goal or 'Not specified'}</span>
|
|
1168
|
+
</div>
|
|
1169
|
+
<div class="config-item">
|
|
1170
|
+
<span class="config-label">LoRA</span>
|
|
1171
|
+
<span class="config-value" id="config-lora">{f'r={state.lora_r}, α={state.lora_alpha}' if state.lora_r else 'Not specified'}</span>
|
|
1172
|
+
</div>
|
|
1173
|
+
<div class="config-item">
|
|
1174
|
+
<span class="config-label">Quantization</span>
|
|
1175
|
+
<span class="config-value" id="config-quant">{'4-bit' if state.load_in_4bit else 'None'}</span>
|
|
1176
|
+
</div>
|
|
1177
|
+
<div class="config-item">
|
|
1178
|
+
<span class="config-label">Config</span>
|
|
1179
|
+
<span class="config-value" id="config-path">{state.config_path or 'Not specified'}</span>
|
|
1180
|
+
</div>
|
|
1181
|
+
</div>
|
|
1182
|
+
</div>
|
|
1183
|
+
|
|
1098
1184
|
<div class="stats-grid">
|
|
1099
1185
|
<div class="stat-card" id="card-epoch">
|
|
1100
1186
|
<div class="stat-label">Epoch Progress</div>
|
|
@@ -1519,6 +1605,28 @@ def generate_training_dashboard(state: TrainingState, config: TrainingConfig) ->
|
|
|
1519
1605
|
}}
|
|
1520
1606
|
}}
|
|
1521
1607
|
|
|
1608
|
+
// Update config panel
|
|
1609
|
+
const configModel = document.getElementById('config-model');
|
|
1610
|
+
const configGoal = document.getElementById('config-goal');
|
|
1611
|
+
const configLora = document.getElementById('config-lora');
|
|
1612
|
+
const configQuant = document.getElementById('config-quant');
|
|
1613
|
+
const configPath = document.getElementById('config-path');
|
|
1614
|
+
if (configModel && data.model_name) {{
|
|
1615
|
+
configModel.textContent = data.model_name;
|
|
1616
|
+
}}
|
|
1617
|
+
if (configGoal && data.goal) {{
|
|
1618
|
+
configGoal.textContent = data.goal;
|
|
1619
|
+
}}
|
|
1620
|
+
if (configLora && (data.lora_r || data.lora_alpha)) {{
|
|
1621
|
+
configLora.textContent = `r=${{data.lora_r || 0}}, α=${{data.lora_alpha || 0}}`;
|
|
1622
|
+
}}
|
|
1623
|
+
if (configQuant) {{
|
|
1624
|
+
configQuant.textContent = data.load_in_4bit ? '4-bit' : 'None';
|
|
1625
|
+
}}
|
|
1626
|
+
if (configPath && data.config_path) {{
|
|
1627
|
+
configPath.textContent = data.config_path;
|
|
1628
|
+
}}
|
|
1629
|
+
|
|
1522
1630
|
// Update setup panel if setup logs present
|
|
1523
1631
|
if (data.setup_logs && data.setup_logs.length > 0) {{
|
|
1524
1632
|
const setupPanel = document.getElementById('setup-panel');
|
|
@@ -2093,354 +2201,3 @@ def regenerate_local_dashboard(
|
|
|
2093
2201
|
print(f"Regenerated dashboard: {dashboard_path}")
|
|
2094
2202
|
|
|
2095
2203
|
return dashboard_path
|
|
2096
|
-
|
|
2097
|
-
|
|
2098
|
-
def run_epoch_evaluation(
|
|
2099
|
-
adapter: BaseVLMAdapter,
|
|
2100
|
-
episode: Episode,
|
|
2101
|
-
epoch: int,
|
|
2102
|
-
config: TrainingConfig,
|
|
2103
|
-
logger: "TrainingLogger",
|
|
2104
|
-
sample_indices: Optional[List[int]] = None,
|
|
2105
|
-
) -> Path:
|
|
2106
|
-
"""Run inference evaluation on sample steps after an epoch.
|
|
2107
|
-
|
|
2108
|
-
This generates a comparison_epoch{N}.html file showing human vs predicted actions.
|
|
2109
|
-
|
|
2110
|
-
Args:
|
|
2111
|
-
adapter: Trained adapter to use for inference
|
|
2112
|
-
episode: Episode with steps to evaluate
|
|
2113
|
-
epoch: Current epoch number
|
|
2114
|
-
config: Training configuration
|
|
2115
|
-
logger: Training logger for state tracking
|
|
2116
|
-
sample_indices: Specific step indices to evaluate (default: evenly spaced)
|
|
2117
|
-
|
|
2118
|
-
Returns:
|
|
2119
|
-
Path to generated comparison HTML file
|
|
2120
|
-
"""
|
|
2121
|
-
from openadapt_ml.scripts.compare import generate_comparison_html, predict_action, format_action
|
|
2122
|
-
|
|
2123
|
-
output_dir = Path(config.output_dir)
|
|
2124
|
-
output_dir.mkdir(parents=True, exist_ok=True)
|
|
2125
|
-
|
|
2126
|
-
# Select sample indices if not provided
|
|
2127
|
-
num_samples = min(config.eval_samples, len(episode.steps))
|
|
2128
|
-
if sample_indices is None:
|
|
2129
|
-
if num_samples >= len(episode.steps):
|
|
2130
|
-
sample_indices = list(range(len(episode.steps)))
|
|
2131
|
-
else:
|
|
2132
|
-
# Evenly space samples across the episode
|
|
2133
|
-
step_size = len(episode.steps) // num_samples
|
|
2134
|
-
sample_indices = [i * step_size for i in range(num_samples)]
|
|
2135
|
-
|
|
2136
|
-
print(f" Running inference on {len(sample_indices)} sample steps...")
|
|
2137
|
-
|
|
2138
|
-
# Switch adapter to eval mode
|
|
2139
|
-
adapter.eval()
|
|
2140
|
-
|
|
2141
|
-
comparison_data = []
|
|
2142
|
-
action_history: List[str] = []
|
|
2143
|
-
total_steps = len(episode.steps)
|
|
2144
|
-
|
|
2145
|
-
for i, step in enumerate(episode.steps):
|
|
2146
|
-
step_data = {
|
|
2147
|
-
"index": i,
|
|
2148
|
-
"time": step.t,
|
|
2149
|
-
"image_path": step.observation.image_path,
|
|
2150
|
-
"human_action": {
|
|
2151
|
-
"type": step.action.type,
|
|
2152
|
-
"x": step.action.x,
|
|
2153
|
-
"y": step.action.y,
|
|
2154
|
-
"text": step.action.text,
|
|
2155
|
-
},
|
|
2156
|
-
"predicted_action": None,
|
|
2157
|
-
"match": None,
|
|
2158
|
-
}
|
|
2159
|
-
|
|
2160
|
-
# Only run inference on selected samples (for speed)
|
|
2161
|
-
if i in sample_indices and step.observation.image_path:
|
|
2162
|
-
try:
|
|
2163
|
-
predicted = predict_action(
|
|
2164
|
-
adapter,
|
|
2165
|
-
step.observation.image_path,
|
|
2166
|
-
episode.goal,
|
|
2167
|
-
step_index=i,
|
|
2168
|
-
total_steps=total_steps,
|
|
2169
|
-
action_history=action_history.copy(),
|
|
2170
|
-
)
|
|
2171
|
-
step_data["predicted_action"] = predicted
|
|
2172
|
-
|
|
2173
|
-
# Check match and calculate distance
|
|
2174
|
-
if predicted and predicted.get("type") == step.action.type:
|
|
2175
|
-
step_data["match"] = True
|
|
2176
|
-
|
|
2177
|
-
# Calculate distance for click actions
|
|
2178
|
-
if step.action.type == "click":
|
|
2179
|
-
hx, hy = step.action.x or 0, step.action.y or 0
|
|
2180
|
-
px, py = predicted.get("x", 0), predicted.get("y", 0)
|
|
2181
|
-
distance = ((hx - px) ** 2 + (hy - py) ** 2) ** 0.5
|
|
2182
|
-
|
|
2183
|
-
# Log evaluation to training state
|
|
2184
|
-
logger.state.log_evaluation(
|
|
2185
|
-
epoch=epoch,
|
|
2186
|
-
sample_idx=i,
|
|
2187
|
-
image_path=step.observation.image_path,
|
|
2188
|
-
human_action=step_data["human_action"],
|
|
2189
|
-
predicted_action=predicted,
|
|
2190
|
-
)
|
|
2191
|
-
else:
|
|
2192
|
-
step_data["match"] = False
|
|
2193
|
-
|
|
2194
|
-
print(f" Step {i}: {step.action.type} -> {predicted.get('type') if predicted else 'none'}")
|
|
2195
|
-
|
|
2196
|
-
except Exception as e:
|
|
2197
|
-
print(f" Step {i}: inference failed - {e}")
|
|
2198
|
-
|
|
2199
|
-
# Build action history for context
|
|
2200
|
-
action_history.append(format_action(step.action, use_som=False))
|
|
2201
|
-
comparison_data.append(step_data)
|
|
2202
|
-
|
|
2203
|
-
# Switch back to train mode
|
|
2204
|
-
adapter.train()
|
|
2205
|
-
|
|
2206
|
-
# Generate comparison HTML
|
|
2207
|
-
output_path = output_dir / f"comparison_epoch{epoch}.html"
|
|
2208
|
-
capture_path = Path(logger.state.capture_path) if logger.state.capture_path else Path(".")
|
|
2209
|
-
|
|
2210
|
-
generate_comparison_html(capture_path, episode, comparison_data, output_path)
|
|
2211
|
-
print(f" Comparison saved: {output_path}")
|
|
2212
|
-
|
|
2213
|
-
# Also regenerate all dashboards to update navigation
|
|
2214
|
-
regenerate_all_dashboards(output_dir)
|
|
2215
|
-
|
|
2216
|
-
return output_path
|
|
2217
|
-
|
|
2218
|
-
|
|
2219
|
-
def _create_dataloader(dataset: Dataset, batch_size: int) -> DataLoader:
|
|
2220
|
-
# Use an identity collate_fn so that each batch is a List[Dict], matching
|
|
2221
|
-
# the expectations of adapters that operate on SFT-style samples.
|
|
2222
|
-
return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)
|
|
2223
|
-
|
|
2224
|
-
|
|
2225
|
-
def _create_lr_scheduler(
|
|
2226
|
-
optimizer: Optimizer,
|
|
2227
|
-
config: TrainingConfig,
|
|
2228
|
-
num_training_steps: int,
|
|
2229
|
-
) -> Optional[LambdaLR]:
|
|
2230
|
-
"""Create learning rate scheduler based on config.
|
|
2231
|
-
|
|
2232
|
-
Args:
|
|
2233
|
-
optimizer: The optimizer to schedule.
|
|
2234
|
-
config: Training configuration with lr_scheduler_type and warmup_ratio.
|
|
2235
|
-
num_training_steps: Total number of training steps.
|
|
2236
|
-
|
|
2237
|
-
Returns:
|
|
2238
|
-
LambdaLR scheduler or None if scheduler_type is "none" or "constant".
|
|
2239
|
-
"""
|
|
2240
|
-
scheduler_type = config.lr_scheduler_type.lower()
|
|
2241
|
-
|
|
2242
|
-
if scheduler_type in ("none", "constant"):
|
|
2243
|
-
return None
|
|
2244
|
-
|
|
2245
|
-
num_warmup_steps = int(num_training_steps * config.warmup_ratio)
|
|
2246
|
-
|
|
2247
|
-
if scheduler_type == "linear":
|
|
2248
|
-
def lr_lambda(current_step: int) -> float:
|
|
2249
|
-
if current_step < num_warmup_steps:
|
|
2250
|
-
# Linear warmup
|
|
2251
|
-
return float(current_step) / float(max(1, num_warmup_steps))
|
|
2252
|
-
# Linear decay
|
|
2253
|
-
return max(
|
|
2254
|
-
0.0,
|
|
2255
|
-
float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
|
|
2256
|
-
)
|
|
2257
|
-
elif scheduler_type == "cosine":
|
|
2258
|
-
import math
|
|
2259
|
-
def lr_lambda(current_step: int) -> float:
|
|
2260
|
-
if current_step < num_warmup_steps:
|
|
2261
|
-
# Linear warmup
|
|
2262
|
-
return float(current_step) / float(max(1, num_warmup_steps))
|
|
2263
|
-
# Cosine decay
|
|
2264
|
-
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
|
|
2265
|
-
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
|
2266
|
-
else:
|
|
2267
|
-
raise ValueError(f"Unknown lr_scheduler_type: {scheduler_type}. Use 'linear', 'cosine', 'constant', or 'none'.")
|
|
2268
|
-
|
|
2269
|
-
return LambdaLR(optimizer, lr_lambda)
|
|
2270
|
-
|
|
2271
|
-
|
|
2272
|
-
def train_supervised(
|
|
2273
|
-
adapter: BaseVLMAdapter,
|
|
2274
|
-
dataset: Dataset,
|
|
2275
|
-
config: TrainingConfig,
|
|
2276
|
-
optimizer: Optional[Optimizer] = None,
|
|
2277
|
-
logger: Optional[TrainingLogger] = None,
|
|
2278
|
-
episode: Optional[Episode] = None,
|
|
2279
|
-
) -> bool:
|
|
2280
|
-
"""Minimal supervised training loop skeleton.
|
|
2281
|
-
|
|
2282
|
-
This assumes that `adapter.prepare_inputs` and `adapter.compute_loss` are
|
|
2283
|
-
implemented. It will raise if those methods are not implemented.
|
|
2284
|
-
|
|
2285
|
-
Args:
|
|
2286
|
-
adapter: VLM adapter to train.
|
|
2287
|
-
dataset: Training dataset.
|
|
2288
|
-
config: Training configuration.
|
|
2289
|
-
optimizer: Optional optimizer (default: AdamW).
|
|
2290
|
-
logger: Optional training logger for visualization.
|
|
2291
|
-
episode: Optional episode for periodic evaluation (generates comparison_epoch{N}.html).
|
|
2292
|
-
|
|
2293
|
-
Returns:
|
|
2294
|
-
True if training completed successfully, False if aborted due to NaN/Inf loss.
|
|
2295
|
-
"""
|
|
2296
|
-
|
|
2297
|
-
device = adapter.device # type: ignore[attr-defined]
|
|
2298
|
-
dataloader = _create_dataloader(dataset, batch_size=config.per_device_train_batch_size)
|
|
2299
|
-
|
|
2300
|
-
if optimizer is None:
|
|
2301
|
-
optimizer = torch.optim.AdamW(
|
|
2302
|
-
adapter.model.parameters(), # type: ignore[arg-type]
|
|
2303
|
-
lr=config.learning_rate,
|
|
2304
|
-
weight_decay=config.weight_decay,
|
|
2305
|
-
)
|
|
2306
|
-
|
|
2307
|
-
# Create logger if not provided
|
|
2308
|
-
if logger is None:
|
|
2309
|
-
logger = TrainingLogger(config.output_dir, config)
|
|
2310
|
-
|
|
2311
|
-
# Calculate total training steps for scheduler
|
|
2312
|
-
num_training_steps = len(dataloader) * config.num_train_epochs // config.gradient_accumulation_steps
|
|
2313
|
-
|
|
2314
|
-
# Create learning rate scheduler
|
|
2315
|
-
lr_scheduler = _create_lr_scheduler(optimizer, config, num_training_steps)
|
|
2316
|
-
|
|
2317
|
-
total_steps = 0
|
|
2318
|
-
adapter.train()
|
|
2319
|
-
|
|
2320
|
-
# Early stopping tracking
|
|
2321
|
-
consecutive_low_loss = 0
|
|
2322
|
-
early_stopped = False
|
|
2323
|
-
user_stopped = False
|
|
2324
|
-
|
|
2325
|
-
for epoch in range(config.num_train_epochs):
|
|
2326
|
-
if early_stopped or user_stopped:
|
|
2327
|
-
break
|
|
2328
|
-
|
|
2329
|
-
for _, batch in enumerate(dataloader):
|
|
2330
|
-
# Check for stop signal from dashboard
|
|
2331
|
-
stop_file = Path(config.output_dir) / "STOP_TRAINING"
|
|
2332
|
-
if stop_file.exists():
|
|
2333
|
-
msg = "Stop signal received from dashboard. Stopping training..."
|
|
2334
|
-
print(msg)
|
|
2335
|
-
logger._log_to_terminal(msg)
|
|
2336
|
-
# Set termination status for dashboard
|
|
2337
|
-
logger.state.termination_status = "user_stop"
|
|
2338
|
-
logger.state.termination_message = "Training stopped by user via dashboard"
|
|
2339
|
-
logger.save()
|
|
2340
|
-
user_stopped = True
|
|
2341
|
-
stop_file.unlink() # Remove signal file
|
|
2342
|
-
break
|
|
2343
|
-
|
|
2344
|
-
# Batch is a List[Dict[str, Any]] of SFT-style samples; adapter is
|
|
2345
|
-
# responsible for converting it into model inputs.
|
|
2346
|
-
samples: List[Dict[str, Any]] = batch
|
|
2347
|
-
|
|
2348
|
-
inputs = adapter.prepare_inputs(samples)
|
|
2349
|
-
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
|
|
2350
|
-
|
|
2351
|
-
loss = adapter.compute_loss(inputs)
|
|
2352
|
-
|
|
2353
|
-
# Guard against invalid losses to avoid propagating NaNs/Infs
|
|
2354
|
-
if torch.isnan(loss) or torch.isinf(loss):
|
|
2355
|
-
msg = f"Encountered invalid loss at epoch={epoch} step={total_steps + 1}: {loss.item()}"
|
|
2356
|
-
print(msg)
|
|
2357
|
-
logger._log_to_terminal(msg)
|
|
2358
|
-
logger.on_train_end()
|
|
2359
|
-
return False
|
|
2360
|
-
|
|
2361
|
-
loss.backward()
|
|
2362
|
-
|
|
2363
|
-
if (total_steps + 1) % config.gradient_accumulation_steps == 0:
|
|
2364
|
-
torch.nn.utils.clip_grad_norm_(adapter.model.parameters(), config.max_grad_norm) # type: ignore[arg-type]
|
|
2365
|
-
optimizer.step()
|
|
2366
|
-
if lr_scheduler is not None:
|
|
2367
|
-
lr_scheduler.step()
|
|
2368
|
-
optimizer.zero_grad()
|
|
2369
|
-
|
|
2370
|
-
total_steps += 1
|
|
2371
|
-
loss_val = loss.item()
|
|
2372
|
-
|
|
2373
|
-
# Get current learning rate from optimizer
|
|
2374
|
-
current_lr = optimizer.param_groups[0]['lr']
|
|
2375
|
-
|
|
2376
|
-
# Log step
|
|
2377
|
-
logger.on_step(epoch, total_steps, loss_val, current_lr)
|
|
2378
|
-
|
|
2379
|
-
if config.logging_steps and total_steps % config.logging_steps == 0:
|
|
2380
|
-
msg = f"epoch={epoch} step={total_steps} loss={loss_val:.4f} lr={current_lr:.6f}"
|
|
2381
|
-
print(msg)
|
|
2382
|
-
logger._log_to_terminal(msg)
|
|
2383
|
-
|
|
2384
|
-
# Early stopping check
|
|
2385
|
-
if loss_val < config.early_stop_loss:
|
|
2386
|
-
consecutive_low_loss += 1
|
|
2387
|
-
if consecutive_low_loss >= config.early_stop_patience:
|
|
2388
|
-
msg = (
|
|
2389
|
-
f"Early stopping: loss ({loss_val:.6f}) below threshold "
|
|
2390
|
-
f"({config.early_stop_loss}) for {config.early_stop_patience} consecutive steps"
|
|
2391
|
-
)
|
|
2392
|
-
print(msg)
|
|
2393
|
-
logger._log_to_terminal(msg)
|
|
2394
|
-
# Set termination status for dashboard
|
|
2395
|
-
logger.state.termination_status = "auto_low_loss"
|
|
2396
|
-
logger.state.termination_message = (
|
|
2397
|
-
f"Loss reached {loss_val:.6f} (< {config.early_stop_loss}) "
|
|
2398
|
-
f"for {config.early_stop_patience} consecutive steps"
|
|
2399
|
-
)
|
|
2400
|
-
logger.save()
|
|
2401
|
-
early_stopped = True
|
|
2402
|
-
break
|
|
2403
|
-
else:
|
|
2404
|
-
consecutive_low_loss = 0
|
|
2405
|
-
|
|
2406
|
-
# End of epoch
|
|
2407
|
-
logger.on_epoch_end(epoch)
|
|
2408
|
-
|
|
2409
|
-
# Save checkpoint at end of each epoch
|
|
2410
|
-
if config.save_checkpoint_every_epoch:
|
|
2411
|
-
checkpoint_path = Path(config.checkpoint_dir) / f"epoch_{epoch}"
|
|
2412
|
-
checkpoint_path.mkdir(parents=True, exist_ok=True)
|
|
2413
|
-
try:
|
|
2414
|
-
adapter.save_checkpoint(str(checkpoint_path))
|
|
2415
|
-
msg = f"Checkpoint saved to {checkpoint_path}"
|
|
2416
|
-
print(msg)
|
|
2417
|
-
logger._log_to_terminal(msg)
|
|
2418
|
-
except Exception as e:
|
|
2419
|
-
msg = f"Warning: Failed to save checkpoint: {e}"
|
|
2420
|
-
print(msg)
|
|
2421
|
-
logger._log_to_terminal(msg)
|
|
2422
|
-
|
|
2423
|
-
# Run evaluation after each epoch (generates comparison_epoch{N}.html)
|
|
2424
|
-
if config.eval_every_epoch and episode is not None:
|
|
2425
|
-
try:
|
|
2426
|
-
print(f"Running epoch {epoch} evaluation...")
|
|
2427
|
-
run_epoch_evaluation(
|
|
2428
|
-
adapter=adapter,
|
|
2429
|
-
episode=episode,
|
|
2430
|
-
epoch=epoch,
|
|
2431
|
-
config=config,
|
|
2432
|
-
logger=logger,
|
|
2433
|
-
)
|
|
2434
|
-
except Exception as e:
|
|
2435
|
-
print(f"Warning: Epoch evaluation failed: {e}")
|
|
2436
|
-
import traceback
|
|
2437
|
-
traceback.print_exc()
|
|
2438
|
-
|
|
2439
|
-
# Set termination status if not already set (normal completion)
|
|
2440
|
-
if not logger.state.termination_status:
|
|
2441
|
-
logger.state.termination_status = "auto_complete"
|
|
2442
|
-
logger.state.termination_message = f"Training completed all {config.num_train_epochs} epochs"
|
|
2443
|
-
logger.save()
|
|
2444
|
-
|
|
2445
|
-
logger.on_train_end()
|
|
2446
|
-
return True
|