openadapt-ml 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff shows the changes between publicly available package versions as published to their respective registries. It is provided for informational purposes only.
- openadapt_ml/baselines/__init__.py +121 -0
- openadapt_ml/baselines/adapter.py +185 -0
- openadapt_ml/baselines/cli.py +314 -0
- openadapt_ml/baselines/config.py +448 -0
- openadapt_ml/baselines/parser.py +922 -0
- openadapt_ml/baselines/prompts.py +787 -0
- openadapt_ml/benchmarks/__init__.py +13 -107
- openadapt_ml/benchmarks/agent.py +297 -374
- openadapt_ml/benchmarks/azure.py +62 -24
- openadapt_ml/benchmarks/azure_ops_tracker.py +521 -0
- openadapt_ml/benchmarks/cli.py +1874 -751
- openadapt_ml/benchmarks/trace_export.py +631 -0
- openadapt_ml/benchmarks/viewer.py +1236 -0
- openadapt_ml/benchmarks/vm_monitor.py +1111 -0
- openadapt_ml/benchmarks/waa_deploy/Dockerfile +216 -0
- openadapt_ml/benchmarks/waa_deploy/__init__.py +10 -0
- openadapt_ml/benchmarks/waa_deploy/api_agent.py +540 -0
- openadapt_ml/benchmarks/waa_deploy/start_waa_server.bat +53 -0
- openadapt_ml/cloud/azure_inference.py +3 -5
- openadapt_ml/cloud/lambda_labs.py +722 -307
- openadapt_ml/cloud/local.py +3194 -89
- openadapt_ml/cloud/ssh_tunnel.py +595 -0
- openadapt_ml/datasets/next_action.py +125 -96
- openadapt_ml/evals/grounding.py +32 -9
- openadapt_ml/evals/plot_eval_metrics.py +15 -13
- openadapt_ml/evals/trajectory_matching.py +120 -57
- openadapt_ml/experiments/demo_prompt/__init__.py +19 -0
- openadapt_ml/experiments/demo_prompt/format_demo.py +236 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_20251231_002125.json +83 -0
- openadapt_ml/experiments/demo_prompt/results/experiment_n30_20251231_165958.json +1100 -0
- openadapt_ml/experiments/demo_prompt/results/multistep_20251231_025051.json +182 -0
- openadapt_ml/experiments/demo_prompt/run_experiment.py +541 -0
- openadapt_ml/experiments/representation_shootout/__init__.py +70 -0
- openadapt_ml/experiments/representation_shootout/conditions.py +708 -0
- openadapt_ml/experiments/representation_shootout/config.py +390 -0
- openadapt_ml/experiments/representation_shootout/evaluator.py +659 -0
- openadapt_ml/experiments/representation_shootout/runner.py +687 -0
- openadapt_ml/experiments/waa_demo/__init__.py +10 -0
- openadapt_ml/experiments/waa_demo/demos.py +357 -0
- openadapt_ml/experiments/waa_demo/runner.py +732 -0
- openadapt_ml/experiments/waa_demo/tasks.py +151 -0
- openadapt_ml/export/__init__.py +9 -0
- openadapt_ml/export/__main__.py +6 -0
- openadapt_ml/export/cli.py +89 -0
- openadapt_ml/export/parquet.py +277 -0
- openadapt_ml/grounding/detector.py +18 -14
- openadapt_ml/ingest/__init__.py +11 -10
- openadapt_ml/ingest/capture.py +97 -86
- openadapt_ml/ingest/loader.py +120 -69
- openadapt_ml/ingest/synthetic.py +344 -193
- openadapt_ml/models/api_adapter.py +14 -4
- openadapt_ml/models/base_adapter.py +10 -2
- openadapt_ml/models/providers/__init__.py +288 -0
- openadapt_ml/models/providers/anthropic.py +266 -0
- openadapt_ml/models/providers/base.py +299 -0
- openadapt_ml/models/providers/google.py +376 -0
- openadapt_ml/models/providers/openai.py +342 -0
- openadapt_ml/models/qwen_vl.py +46 -19
- openadapt_ml/perception/__init__.py +35 -0
- openadapt_ml/perception/integration.py +399 -0
- openadapt_ml/retrieval/README.md +226 -0
- openadapt_ml/retrieval/USAGE.md +391 -0
- openadapt_ml/retrieval/__init__.py +91 -0
- openadapt_ml/retrieval/demo_retriever.py +843 -0
- openadapt_ml/retrieval/embeddings.py +630 -0
- openadapt_ml/retrieval/index.py +194 -0
- openadapt_ml/retrieval/retriever.py +162 -0
- openadapt_ml/runtime/__init__.py +50 -0
- openadapt_ml/runtime/policy.py +27 -14
- openadapt_ml/runtime/safety_gate.py +471 -0
- openadapt_ml/schema/__init__.py +113 -0
- openadapt_ml/schema/converters.py +588 -0
- openadapt_ml/schema/episode.py +470 -0
- openadapt_ml/scripts/capture_screenshots.py +530 -0
- openadapt_ml/scripts/compare.py +102 -61
- openadapt_ml/scripts/demo_policy.py +4 -1
- openadapt_ml/scripts/eval_policy.py +19 -14
- openadapt_ml/scripts/make_gif.py +1 -1
- openadapt_ml/scripts/prepare_synthetic.py +16 -17
- openadapt_ml/scripts/train.py +98 -75
- openadapt_ml/segmentation/README.md +920 -0
- openadapt_ml/segmentation/__init__.py +97 -0
- openadapt_ml/segmentation/adapters/__init__.py +5 -0
- openadapt_ml/segmentation/adapters/capture_adapter.py +420 -0
- openadapt_ml/segmentation/annotator.py +610 -0
- openadapt_ml/segmentation/cache.py +290 -0
- openadapt_ml/segmentation/cli.py +674 -0
- openadapt_ml/segmentation/deduplicator.py +656 -0
- openadapt_ml/segmentation/frame_describer.py +788 -0
- openadapt_ml/segmentation/pipeline.py +340 -0
- openadapt_ml/segmentation/schemas.py +622 -0
- openadapt_ml/segmentation/segment_extractor.py +634 -0
- openadapt_ml/training/azure_ops_viewer.py +1097 -0
- openadapt_ml/training/benchmark_viewer.py +3255 -19
- openadapt_ml/training/shared_ui.py +7 -7
- openadapt_ml/training/stub_provider.py +57 -35
- openadapt_ml/training/trainer.py +255 -441
- openadapt_ml/training/trl_trainer.py +403 -0
- openadapt_ml/training/viewer.py +323 -108
- openadapt_ml/training/viewer_components.py +180 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/METADATA +312 -69
- openadapt_ml-0.2.1.dist-info/RECORD +116 -0
- openadapt_ml/benchmarks/base.py +0 -366
- openadapt_ml/benchmarks/data_collection.py +0 -432
- openadapt_ml/benchmarks/runner.py +0 -381
- openadapt_ml/benchmarks/waa.py +0 -704
- openadapt_ml/schemas/__init__.py +0 -53
- openadapt_ml/schemas/sessions.py +0 -122
- openadapt_ml/schemas/validation.py +0 -252
- openadapt_ml-0.1.0.dist-info/RECORD +0 -55
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/WHEEL +0 -0
- {openadapt_ml-0.1.0.dist-info → openadapt_ml-0.2.1.dist-info}/licenses/LICENSE +0 -0
openadapt_ml/training/trl_trainer.py
@@ -0,0 +1,403 @@
+"""Simplified training using TRL SFTTrainer + Unsloth.
+
+This module provides a minimal, efficient training path for VLMs:
+- Unsloth for 2x speed, 50% less VRAM
+- TRL SFTTrainer for production-grade training
+- Direct integration with openadapt-ml data format
+
+Usage:
+    from openadapt_ml.training.trl_trainer import train_with_trl
+
+    # Train on episodes
+    train_with_trl(
+        episodes=episodes,
+        model_name="unsloth/Qwen2.5-VL-7B-Instruct",
+        output_dir="checkpoints/my_model",
+    )
+"""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from PIL import Image
+
+
+@dataclass
+class TRLTrainingConfig:
+    """Configuration for TRL-based training."""
+
+    # Model
+    model_name: str = "unsloth/Qwen2.5-VL-7B-Instruct"
+    load_in_4bit: bool = True
+    max_seq_length: int = 4096
+
+    # LoRA
+    lora_r: int = 16
+    lora_alpha: int = 32
+    lora_dropout: float = 0.0
+    finetune_vision_layers: bool = False  # Set True if grounding needs improvement
+
+    # Training
+    num_epochs: int = 3
+    batch_size: int = 1
+    gradient_accumulation_steps: int = 4
+    learning_rate: float = 2e-4
+    warmup_ratio: float = 0.03
+
+    # Output
+    output_dir: str = "checkpoints"
+    logging_steps: int = 10
+    save_strategy: str = "epoch"
+
+
+def _load_unsloth_model(config: TRLTrainingConfig):
+    """Load model with Unsloth optimizations.
+
+    Returns:
+        tuple: (model, tokenizer, is_unsloth) - is_unsloth indicates if Unsloth was used
+    """
+    # Check if Unsloth is explicitly disabled via environment variable
+    if os.environ.get("OPENADAPT_DISABLE_UNSLOTH", "").lower() in ("1", "true", "yes"):
+        print("Unsloth disabled via OPENADAPT_DISABLE_UNSLOTH environment variable")
+        return _load_standard_model(config)
+
+    try:
+        from unsloth import FastVisionModel
+
+        model, tokenizer = FastVisionModel.from_pretrained(
+            config.model_name,
+            load_in_4bit=config.load_in_4bit,
+            use_gradient_checkpointing="unsloth",
+            max_seq_length=config.max_seq_length,
+        )
+
+        # Apply LoRA
+        model = FastVisionModel.get_peft_model(
+            model,
+            finetune_vision_layers=config.finetune_vision_layers,
+            finetune_language_layers=True,
+            finetune_attention_modules=True,
+            finetune_mlp_modules=True,
+            r=config.lora_r,
+            lora_alpha=config.lora_alpha,
+            lora_dropout=config.lora_dropout,
+            random_state=42,
+        )
+
+        # Enable training mode
+        FastVisionModel.for_training(model)
+
+        print(
+            f"✓ Loaded {config.model_name} with Unsloth (4-bit: {config.load_in_4bit})"
+        )
+        return model, tokenizer, True
+
+    except ImportError:
+        print("⚠ Unsloth not installed, falling back to standard transformers")
+        return _load_standard_model(config)
+
+
+def _load_standard_model(config: TRLTrainingConfig):
+    """Fallback: Load model with standard transformers + peft.
+
+    Automatically detects vision-language models and uses the appropriate
+    model class (Qwen2VLForConditionalGeneration for VL models,
+    AutoModelForCausalLM for text-only models).
+    """
+    from transformers import AutoConfig, AutoProcessor
+    from peft import LoraConfig, get_peft_model
+    import torch
+
+    # Check if this is a vision-language model
+    model_config = AutoConfig.from_pretrained(
+        config.model_name, trust_remote_code=True
+    )
+    is_vl_model = (
+        "VL" in config.model_name.upper()
+        or "vision" in config.model_name.lower()
+        or hasattr(model_config, "vision_config")
+    )
+
+    if is_vl_model:
+        # Vision-language model - use Qwen2VLForConditionalGeneration or AutoModelForVision2Seq
+        try:
+            from transformers import Qwen2VLForConditionalGeneration
+
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                config.model_name,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+            )
+            print(" Using Qwen2VLForConditionalGeneration for VL model")
+        except (ImportError, ValueError, RuntimeError, TypeError):
+            # Fallback to AutoModelForVision2Seq for other VL models
+            from transformers import AutoModelForVision2Seq
+
+            model = AutoModelForVision2Seq.from_pretrained(
+                config.model_name,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+            )
+            print(" Using AutoModelForVision2Seq for VL model")
+    else:
+        # Text-only model
+        from transformers import AutoModelForCausalLM
+
+        model = AutoModelForCausalLM.from_pretrained(
+            config.model_name,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        print(" Using AutoModelForCausalLM for text-only model")
+
+    processor = AutoProcessor.from_pretrained(config.model_name, trust_remote_code=True)
+
+    # Apply LoRA - use SEQ_2_SEQ_LM for VL models, CAUSAL_LM for text-only
+    peft_config = LoraConfig(
+        r=config.lora_r,
+        lora_alpha=config.lora_alpha,
+        lora_dropout=config.lora_dropout,
+        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
+        task_type="SEQ_2_SEQ_LM" if is_vl_model else "CAUSAL_LM",
+    )
+    model = get_peft_model(model, peft_config)
+
+    print(f"✓ Loaded {config.model_name} with standard transformers")
+    return model, processor, False
+
+
+def _convert_samples_to_trl_format(
+    samples: List[Dict[str, Any]],
+    base_path: Optional[Path] = None,
+) -> List[Dict[str, Any]]:
+    """Convert openadapt-ml samples to TRL format.
+
+    The only change is loading image paths as PIL Images.
+
+    Args:
+        samples: List of samples from build_next_action_sft_samples()
+        base_path: Optional base path to resolve relative image paths
+
+    Returns:
+        List of samples with PIL Images instead of paths
+    """
+    trl_samples = []
+
+    for sample in samples:
+        # Load images as PIL
+        pil_images = []
+        for img_path in sample["images"]:
+            path = Path(img_path)
+            if base_path and not path.is_absolute():
+                path = base_path / path
+
+            if path.exists():
+                pil_images.append(Image.open(path).convert("RGB"))
+            else:
+                print(f"⚠ Image not found: {path}")
+                continue
+
+        if not pil_images:
+            continue  # Skip samples with missing images
+
+        trl_samples.append(
+            {
+                "images": pil_images,
+                "messages": sample["messages"],
+            }
+        )
+
+    return trl_samples
+
+
+def train_with_trl(
+    episodes: List,
+    config: Optional[TRLTrainingConfig] = None,
+    use_som: bool = False,
+    base_path: Optional[Path] = None,
+) -> str:
+    """Train a VLM using TRL SFTTrainer + Unsloth.
+
+    This is the simplified training entry point that replaces the legacy
+    custom training loop. It:
+    1. Converts episodes to TRL format
+    2. Loads model with Unsloth (or fallback)
+    3. Trains with TRL's SFTTrainer
+    4. Saves LoRA adapter
+
+    Args:
+        episodes: List of Episode objects from openadapt-ml schema
+        config: Training configuration (uses defaults if None)
+        use_som: If True, use Set-of-Marks DSL instead of coordinates
+        base_path: Base path for resolving relative image paths
+
+    Returns:
+        Path to saved checkpoint
+    """
+    from datasets import Dataset
+    from openadapt_ml.datasets.next_action import build_next_action_sft_samples
+
+    config = config or TRLTrainingConfig()
+
+    # Step 1: Convert episodes to SFT samples
+    print(f"Converting {len(episodes)} episodes to training samples...")
+    raw_samples = build_next_action_sft_samples(episodes, use_som=use_som)
+    print(f" Generated {len(raw_samples)} training samples")
+
+    # Step 2: Convert to TRL format (load images as PIL)
+    print("Loading images...")
+    trl_samples = _convert_samples_to_trl_format(raw_samples, base_path)
+    print(f" Loaded {len(trl_samples)} samples with images")
+
+    if not trl_samples:
+        raise ValueError("No valid training samples after loading images")
+
+    # Step 3: Create HuggingFace Dataset
+    dataset = Dataset.from_list(trl_samples)
+
+    # Step 4: Load model with Unsloth (or fallback)
+    model, tokenizer, is_unsloth = _load_unsloth_model(config)
+
+    # Step 5: Configure and run training
+    try:
+        from trl import SFTTrainer, SFTConfig
+
+        if is_unsloth:
+            # Unsloth-specific configuration
+            from unsloth.trainer import UnslothVisionDataCollator
+
+            training_args = SFTConfig(
+                output_dir=config.output_dir,
+                per_device_train_batch_size=config.batch_size,
+                gradient_accumulation_steps=config.gradient_accumulation_steps,
+                learning_rate=config.learning_rate,
+                num_train_epochs=config.num_epochs,
+                warmup_ratio=config.warmup_ratio,
+                lr_scheduler_type="cosine",
+                logging_steps=config.logging_steps,
+                save_strategy=config.save_strategy,
+                # Unsloth-specific settings
+                remove_unused_columns=False,
+                dataset_text_field="",
+                dataset_kwargs={"skip_prepare_dataset": True},
+            )
+
+            trainer = SFTTrainer(
+                model=model,
+                tokenizer=tokenizer,
+                data_collator=UnslothVisionDataCollator(model, tokenizer),
+                train_dataset=dataset,
+                args=training_args,
+            )
+        else:
+            # Standard TRL configuration
+            training_args = SFTConfig(
+                output_dir=config.output_dir,
+                per_device_train_batch_size=config.batch_size,
+                gradient_accumulation_steps=config.gradient_accumulation_steps,
+                learning_rate=config.learning_rate,
+                num_train_epochs=config.num_epochs,
+                warmup_ratio=config.warmup_ratio,
+                lr_scheduler_type="cosine",
+                logging_steps=config.logging_steps,
+                save_strategy=config.save_strategy,
+                max_length=None,  # Critical for VLMs
+                assistant_only_loss=False,  # Not supported for VL models yet
+            )
+
+            trainer = SFTTrainer(
+                model=model,
+                train_dataset=dataset,
+                args=training_args,
+            )
+
+        print(f"\n{'=' * 50}")
+        print("Starting training:")
+        print(f" Model: {config.model_name}")
+        print(f" Samples: {len(trl_samples)}")
+        print(f" Epochs: {config.num_epochs}")
+        print(f" Batch size: {config.batch_size}")
+        print(f" Unsloth: {is_unsloth}")
+        print(f" Output: {config.output_dir}")
+        print(f"{'=' * 50}\n")
+
+        trainer.train()
+
+        # Save the LoRA adapter
+        checkpoint_path = Path(config.output_dir) / "final"
+        trainer.save_model(str(checkpoint_path))
+        print(f"\n✓ Saved checkpoint to {checkpoint_path}")
+
+        return str(checkpoint_path)
+
+    except ImportError as e:
+        raise ImportError(
+            f"TRL not installed. Install with: pip install trl\nOriginal error: {e}"
+        )
+
+
+def train_from_parquet(
+    parquet_path: str,
+    config: Optional[TRLTrainingConfig] = None,
+    use_som: bool = False,
+) -> str:
+    """Train from a parquet file exported by openadapt-ml.
+
+    Args:
+        parquet_path: Path to parquet file with episode data
+        config: Training configuration
+        use_som: Use Set-of-Marks DSL
+
+    Returns:
+        Path to saved checkpoint
+    """
+    from openadapt_ml.export import from_parquet
+
+    print(f"Loading episodes from {parquet_path}...")
+    episodes = from_parquet(parquet_path)
+
+    base_path = Path(parquet_path).parent
+
+    return train_with_trl(
+        episodes=episodes,
+        config=config,
+        use_som=use_som,
+        base_path=base_path,
+    )
+
+
+if __name__ == "__main__":
+    # Simple CLI for testing
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Train VLM with TRL + Unsloth")
+    parser.add_argument("--parquet", required=True, help="Path to parquet file")
+    parser.add_argument("--output", default="checkpoints", help="Output directory")
+    parser.add_argument(
+        "--model", default="unsloth/Qwen2.5-VL-7B-Instruct", help="Model name"
+    )
+    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs")
+    parser.add_argument("--use-som", action="store_true", help="Use Set-of-Marks DSL")
+
+    args = parser.parse_args()
+
+    config = TRLTrainingConfig(
+        model_name=args.model,
+        output_dir=args.output,
+        num_epochs=args.epochs,
+    )
+
+    checkpoint = train_from_parquet(
+        parquet_path=args.parquet,
+        config=config,
+        use_som=args.use_som,
+    )
+
+    print(f"\nTraining complete! Checkpoint: {checkpoint}")
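
For orientation, here is a minimal usage sketch of the new module, based on the train_from_parquet entry point and TRLTrainingConfig fields shown in the diff above. The parquet path and output directory are placeholder values, and the training dependencies (torch, transformers, peft, datasets, trl, and optionally unsloth) are assumed to be installed:

from openadapt_ml.training.trl_trainer import TRLTrainingConfig, train_from_parquet

# Placeholder paths: point at a parquet export produced by openadapt_ml.export
# and a writable checkpoint directory.
config = TRLTrainingConfig(
    model_name="unsloth/Qwen2.5-VL-7B-Instruct",
    output_dir="checkpoints/trl_demo",
    num_epochs=1,
)

checkpoint = train_from_parquet(
    parquet_path="exports/episodes.parquet",
    config=config,
    use_som=False,  # True would train on the Set-of-Marks DSL instead of coordinates
)
print(f"LoRA adapter saved to: {checkpoint}")

The same run appears to be available from the command line via the module's __main__ block, e.g. python -m openadapt_ml.training.trl_trainer --parquet exports/episodes.parquet --output checkpoints/trl_demo --epochs 1.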