langvision 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of langvision has been flagged as potentially problematic.
- langvision/__init__.py +77 -2
- langvision/callbacks/base.py +166 -7
- langvision/cli/__init__.py +85 -0
- langvision/cli/complete_cli.py +319 -0
- langvision/cli/config.py +344 -0
- langvision/cli/evaluate.py +201 -0
- langvision/cli/export.py +177 -0
- langvision/cli/finetune.py +165 -48
- langvision/cli/model_zoo.py +162 -0
- langvision/cli/train.py +27 -13
- langvision/cli/utils.py +258 -0
- langvision/components/attention.py +4 -1
- langvision/concepts/__init__.py +9 -0
- langvision/concepts/ccot.py +30 -0
- langvision/concepts/cot.py +29 -0
- langvision/concepts/dpo.py +37 -0
- langvision/concepts/grpo.py +25 -0
- langvision/concepts/lime.py +37 -0
- langvision/concepts/ppo.py +47 -0
- langvision/concepts/rlhf.py +40 -0
- langvision/concepts/rlvr.py +25 -0
- langvision/concepts/shap.py +37 -0
- langvision/data/enhanced_datasets.py +582 -0
- langvision/model_zoo.py +169 -2
- langvision/models/lora.py +189 -17
- langvision/models/multimodal.py +297 -0
- langvision/models/resnet.py +303 -0
- langvision/training/advanced_trainer.py +478 -0
- langvision/training/trainer.py +30 -2
- langvision/utils/config.py +180 -9
- langvision/utils/metrics.py +448 -0
- langvision/utils/setup.py +266 -0
- langvision-0.1.0.dist-info/METADATA +50 -0
- langvision-0.1.0.dist-info/RECORD +61 -0
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/WHEEL +1 -1
- langvision-0.1.0.dist-info/entry_points.txt +2 -0
- langvision-0.0.1.dist-info/METADATA +0 -463
- langvision-0.0.1.dist-info/RECORD +0 -40
- langvision-0.0.1.dist-info/entry_points.txt +0 -2
- langvision-0.0.1.dist-info/licenses/LICENSE +0 -21
- {langvision-0.0.1.dist-info → langvision-0.1.0.dist-info}/top_level.txt +0 -0
langvision/training/advanced_trainer.py NEW
@@ -0,0 +1,478 @@
+"""
+Advanced training utilities with LoRA fine-tuning, mixed precision, and comprehensive logging.
+"""
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from torch.cuda.amp import GradScaler, autocast
+from typing import Dict, Any, Optional, List, Callable, Union
+import logging
+import time
+import os
+import json
+from pathlib import Path
+import numpy as np
+from tqdm import tqdm
+import wandb
+from dataclasses import dataclass, asdict
+
+from ..models.lora import LoRAConfig
+from ..utils.metrics import MetricsTracker
+from ..callbacks.base import Callback
+from ..utils.device import get_device, to_device
+
+
+@dataclass
+class TrainingConfig:
+    """Configuration for training parameters."""
+    # Basic training settings
+    epochs: int = 100
+    batch_size: int = 32
+    learning_rate: float = 1e-4
+    weight_decay: float = 1e-4
+    warmup_epochs: int = 5
+
+    # LoRA settings
+    lora_config: Optional[LoRAConfig] = None
+    freeze_backbone: bool = True
+
+    # Optimization settings
+    optimizer: str = "adamw"  # "adam", "adamw", "sgd"
+    scheduler: str = "cosine"  # "cosine", "step", "plateau"
+    gradient_clip_norm: float = 1.0
+
+    # Mixed precision and performance
+    use_amp: bool = True
+    compile_model: bool = False
+
+    # Regularization
+    dropout: float = 0.1
+    label_smoothing: float = 0.0
+
+    # Logging and checkpointing
+    log_interval: int = 10
+    eval_interval: int = 1
+    save_interval: int = 5
+    save_top_k: int = 3
+
+    # Paths
+    output_dir: str = "./outputs"
+    experiment_name: str = "langvision_experiment"
+
+    # Distributed training
+    distributed: bool = False
+    local_rank: int = 0
+
+    # Early stopping
+    early_stopping_patience: int = 10
+    early_stopping_metric: str = "val_loss"
+    early_stopping_mode: str = "min"  # "min" or "max"
+
+
+class AdvancedTrainer:
+    """Advanced trainer with comprehensive features for vision model fine-tuning."""
+
+    def __init__(self,
+                 model: nn.Module,
+                 train_loader: DataLoader,
+                 val_loader: Optional[DataLoader] = None,
+                 config: Optional[TrainingConfig] = None,
+                 callbacks: Optional[List[Callback]] = None):
+
+        self.model = model
+        self.train_loader = train_loader
+        self.val_loader = val_loader
+        self.config = config or TrainingConfig()
+        self.callbacks = callbacks or []
+
+        # Setup device and distributed training
+        self.device = get_device()
+        self.model = to_device(self.model, self.device)
+
+        # Setup LoRA if configured
+        if self.config.lora_config:
+            self._setup_lora()
+
+        # Setup optimization
+        self.optimizer = self._create_optimizer()
+        self.scheduler = self._create_scheduler()
+        self.scaler = GradScaler() if self.config.use_amp else None
+
+        # Setup loss function
+        self.criterion = self._create_criterion()
+
+        # Setup logging
+        self.logger = self._setup_logging()
+        self.metrics_tracker = MetricsTracker()
+
+        # Training state
+        self.current_epoch = 0
+        self.global_step = 0
+        self.best_metric = float('inf') if self.config.early_stopping_mode == 'min' else float('-inf')
+        self.patience_counter = 0
+
+        # Create output directory
+        self.output_dir = Path(self.config.output_dir) / self.config.experiment_name
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+        # Save config
+        with open(self.output_dir / "config.json", "w") as f:
+            json.dump(asdict(self.config), f, indent=2)
+
+        # Model compilation for PyTorch 2.0+
+        if self.config.compile_model and hasattr(torch, 'compile'):
+            self.model = torch.compile(self.model)
+
+    def _setup_lora(self):
+        """Setup LoRA fine-tuning by freezing backbone parameters."""
+        if self.config.freeze_backbone:
+            # Freeze all parameters first
+            for param in self.model.parameters():
+                param.requires_grad = False
+
+            # Unfreeze LoRA parameters
+            for name, module in self.model.named_modules():
+                if hasattr(module, 'lora_A') or hasattr(module, 'lora_B'):
+                    for param in module.parameters():
+                        if 'lora' in param.name if hasattr(param, 'name') else True:
+                            param.requires_grad = True
+
+        # Log trainable parameters
+        total_params = sum(p.numel() for p in self.model.parameters())
+        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+        self.logger.info(f"Total parameters: {total_params:,}")
+        self.logger.info(f"Trainable parameters: {trainable_params:,}")
+        self.logger.info(f"Trainable ratio: {trainable_params/total_params:.2%}")
+
+    def _create_optimizer(self) -> optim.Optimizer:
+        """Create optimizer based on configuration."""
+        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
+
+        if self.config.optimizer.lower() == "adam":
+            return optim.Adam(trainable_params,
+                              lr=self.config.learning_rate,
+                              weight_decay=self.config.weight_decay)
+        elif self.config.optimizer.lower() == "adamw":
+            return optim.AdamW(trainable_params,
+                               lr=self.config.learning_rate,
+                               weight_decay=self.config.weight_decay)
+        elif self.config.optimizer.lower() == "sgd":
+            return optim.SGD(trainable_params,
+                             lr=self.config.learning_rate,
+                             momentum=0.9,
+                             weight_decay=self.config.weight_decay)
+        else:
+            raise ValueError(f"Unsupported optimizer: {self.config.optimizer}")
+
+    def _create_scheduler(self) -> Optional[optim.lr_scheduler._LRScheduler]:
+        """Create learning rate scheduler."""
+        if self.config.scheduler.lower() == "cosine":
+            return optim.lr_scheduler.CosineAnnealingLR(
+                self.optimizer,
+                T_max=self.config.epochs,
+                eta_min=self.config.learning_rate * 0.01
+            )
+        elif self.config.scheduler.lower() == "step":
+            return optim.lr_scheduler.StepLR(
+                self.optimizer,
+                step_size=self.config.epochs // 3,
+                gamma=0.1
+            )
+        elif self.config.scheduler.lower() == "plateau":
+            return optim.lr_scheduler.ReduceLROnPlateau(
+                self.optimizer,
+                mode=self.config.early_stopping_mode,
+                patience=self.config.early_stopping_patience // 2,
+                factor=0.5
+            )
+        else:
+            return None
+
+    def _create_criterion(self) -> nn.Module:
+        """Create loss function with label smoothing support."""
+        if self.config.label_smoothing > 0:
+            return nn.CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
+        else:
+            return nn.CrossEntropyLoss()
+
+    def _setup_logging(self) -> logging.Logger:
+        """Setup logging configuration."""
+        logger = logging.getLogger(f"langvision.{self.config.experiment_name}")
+        logger.setLevel(logging.INFO)
+
+        # File handler
+        log_file = self.output_dir / "training.log"
+        file_handler = logging.FileHandler(log_file)
+        file_handler.setLevel(logging.INFO)
+
+        # Console handler
+        console_handler = logging.StreamHandler()
+        console_handler.setLevel(logging.INFO)
+
+        # Formatter
+        formatter = logging.Formatter(
+            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+        )
+        file_handler.setFormatter(formatter)
+        console_handler.setFormatter(formatter)
+
+        logger.addHandler(file_handler)
+        logger.addHandler(console_handler)
+
+        return logger
+
+    def train_epoch(self) -> Dict[str, float]:
+        """Train for one epoch."""
+        self.model.train()
+        epoch_metrics = {}
+
+        # Progress bar
+        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch}")
+
+        for batch_idx, batch in enumerate(pbar):
+            # Move batch to device
+            batch = to_device(batch, self.device)
+
+            # Forward pass with mixed precision
+            with autocast(enabled=self.config.use_amp):
+                outputs = self.model(batch['images'])
+                loss = self.criterion(outputs, batch['labels'])
+
+            # Backward pass
+            self.optimizer.zero_grad()
+
+            if self.config.use_amp:
+                self.scaler.scale(loss).backward()
+
+                # Gradient clipping
+                if self.config.gradient_clip_norm > 0:
+                    self.scaler.unscale_(self.optimizer)
+                    torch.nn.utils.clip_grad_norm_(
+                        self.model.parameters(),
+                        self.config.gradient_clip_norm
+                    )
+
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+            else:
+                loss.backward()
+
+                # Gradient clipping
+                if self.config.gradient_clip_norm > 0:
+                    torch.nn.utils.clip_grad_norm_(
+                        self.model.parameters(),
+                        self.config.gradient_clip_norm
+                    )
+
+                self.optimizer.step()
+
+            # Update metrics
+            batch_size = batch['images'].size(0)
+            self.metrics_tracker.update('train_loss', loss.item(), batch_size)
+
+            # Calculate accuracy
+            with torch.no_grad():
+                pred = outputs.argmax(dim=1)
+                acc = (pred == batch['labels']).float().mean()
+                self.metrics_tracker.update('train_acc', acc.item(), batch_size)
+
+            # Update progress bar
+            pbar.set_postfix({
+                'loss': f"{loss.item():.4f}",
+                'acc': f"{acc.item():.4f}",
+                'lr': f"{self.optimizer.param_groups[0]['lr']:.6f}"
+            })
+
+            # Logging
+            if batch_idx % self.config.log_interval == 0:
+                self.logger.info(
+                    f"Epoch {self.current_epoch}, Batch {batch_idx}: "
+                    f"Loss={loss.item():.4f}, Acc={acc.item():.4f}"
+                )
+
+            self.global_step += 1
+
+        # Get epoch metrics
+        epoch_metrics = self.metrics_tracker.get_averages(['train_loss', 'train_acc'])
+        self.metrics_tracker.reset()
+
+        return epoch_metrics
+
+    def validate(self) -> Dict[str, float]:
+        """Validate the model."""
+        if self.val_loader is None:
+            return {}
+
+        self.model.eval()
+
+        with torch.no_grad():
+            pbar = tqdm(self.val_loader, desc="Validation")
+
+            for batch in pbar:
+                batch = to_device(batch, self.device)
+
+                with autocast(enabled=self.config.use_amp):
+                    outputs = self.model(batch['images'])
+                    loss = self.criterion(outputs, batch['labels'])
+
+                # Update metrics
+                batch_size = batch['images'].size(0)
+                self.metrics_tracker.update('val_loss', loss.item(), batch_size)
+
+                # Calculate accuracy
+                pred = outputs.argmax(dim=1)
+                acc = (pred == batch['labels']).float().mean()
+                self.metrics_tracker.update('val_acc', acc.item(), batch_size)
+
+                pbar.set_postfix({
+                    'val_loss': f"{loss.item():.4f}",
+                    'val_acc': f"{acc.item():.4f}"
+                })
+
+        # Get validation metrics
+        val_metrics = self.metrics_tracker.get_averages(['val_loss', 'val_acc'])
+        self.metrics_tracker.reset()
+
+        return val_metrics
+
+    def save_checkpoint(self, metrics: Dict[str, float], is_best: bool = False):
+        """Save model checkpoint."""
+        checkpoint = {
+            'epoch': self.current_epoch,
+            'global_step': self.global_step,
+            'model_state_dict': self.model.state_dict(),
+            'optimizer_state_dict': self.optimizer.state_dict(),
+            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
+            'scaler_state_dict': self.scaler.state_dict() if self.scaler else None,
+            'metrics': metrics,
+            'config': asdict(self.config)
+        }
+
+        # Save regular checkpoint
+        checkpoint_path = self.output_dir / f"checkpoint_epoch_{self.current_epoch}.pt"
+        torch.save(checkpoint, checkpoint_path)
+
+        # Save best checkpoint
+        if is_best:
+            best_path = self.output_dir / "best_model.pt"
+            torch.save(checkpoint, best_path)
+            self.logger.info(f"New best model saved with {self.config.early_stopping_metric}: {metrics.get(self.config.early_stopping_metric, 'N/A')}")
+
+        # Clean up old checkpoints (keep only top-k)
+        self._cleanup_checkpoints()
+
+    def _cleanup_checkpoints(self):
+        """Remove old checkpoints, keeping only the most recent ones."""
+        checkpoint_files = list(self.output_dir.glob("checkpoint_epoch_*.pt"))
+        if len(checkpoint_files) > self.config.save_top_k:
+            # Sort by epoch number
+            checkpoint_files.sort(key=lambda x: int(x.stem.split('_')[-1]))
+            # Remove oldest checkpoints
+            for old_checkpoint in checkpoint_files[:-self.config.save_top_k]:
+                old_checkpoint.unlink()
+
+    def train(self):
+        """Main training loop."""
+        self.logger.info("Starting training...")
+        self.logger.info(f"Training for {self.config.epochs} epochs")
+
+        # Call training start callbacks
+        for callback in self.callbacks:
+            callback.on_train_start(self)
+
+        try:
+            for epoch in range(self.config.epochs):
+                self.current_epoch = epoch
+
+                # Call epoch start callbacks
+                for callback in self.callbacks:
+                    callback.on_epoch_start(self, epoch)
+
+                # Training
+                train_metrics = self.train_epoch()
+
+                # Validation
+                val_metrics = {}
+                if epoch % self.config.eval_interval == 0:
+                    val_metrics = self.validate()
+
+                # Combine metrics
+                all_metrics = {**train_metrics, **val_metrics}
+
+                # Learning rate scheduling
+                if self.scheduler:
+                    if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+                        metric_value = all_metrics.get(self.config.early_stopping_metric)
+                        if metric_value is not None:
+                            self.scheduler.step(metric_value)
+                    else:
+                        self.scheduler.step()
+
+                # Logging
+                self.logger.info(f"Epoch {epoch} completed:")
+                for metric_name, metric_value in all_metrics.items():
+                    self.logger.info(f"  {metric_name}: {metric_value:.4f}")
+
+                # Check for best model
+                current_metric = all_metrics.get(self.config.early_stopping_metric)
+                is_best = False
+
+                if current_metric is not None:
+                    if self.config.early_stopping_mode == 'min':
+                        is_best = current_metric < self.best_metric
+                    else:
+                        is_best = current_metric > self.best_metric
+
+                    if is_best:
+                        self.best_metric = current_metric
+                        self.patience_counter = 0
+                    else:
+                        self.patience_counter += 1
+
+                # Save checkpoint
+                if epoch % self.config.save_interval == 0 or is_best:
+                    self.save_checkpoint(all_metrics, is_best)
+
+                # Call epoch end callbacks
+                for callback in self.callbacks:
+                    callback.on_epoch_end(self, epoch, all_metrics)
+
+                # Early stopping
+                if self.patience_counter >= self.config.early_stopping_patience:
+                    self.logger.info(f"Early stopping triggered after {self.patience_counter} epochs without improvement")
+                    break
+
+        except KeyboardInterrupt:
+            self.logger.info("Training interrupted by user")
+
+        except Exception as e:
+            self.logger.error(f"Training failed with error: {str(e)}")
+            raise
+
+        finally:
+            # Call training end callbacks
+            for callback in self.callbacks:
+                callback.on_train_end(self)
+
+            self.logger.info("Training completed!")
+
+    def load_checkpoint(self, checkpoint_path: str):
+        """Load model from checkpoint."""
+        checkpoint = torch.load(checkpoint_path, map_location=self.device)
+
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+        if self.scheduler and checkpoint.get('scheduler_state_dict'):
+            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+        if self.scaler and checkpoint.get('scaler_state_dict'):
+            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
+
+        self.current_epoch = checkpoint['epoch']
+        self.global_step = checkpoint['global_step']
+
+        self.logger.info(f"Loaded checkpoint from epoch {self.current_epoch}")
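The new train/validate loops index each batch as batch['images'] and batch['labels'], so AdvancedTrainer expects DataLoaders over dict-style datasets. For orientation, a minimal compatible dataset sketch follows; the class name and tensor shapes are illustrative, not part of the package. One thing to watch in review: as committed, __init__ calls _setup_logging() (which reads self.output_dir for the log file path) several lines before self.output_dir is assigned, and _setup_lora() logs through self.logger earlier still, so the construction order may warrant a follow-up fix.

import torch
from torch.utils.data import DataLoader, Dataset

class DictImageDataset(Dataset):
    """Synthetic stand-in whose items collate to the batch layout the new
    loop expects: {'images': Tensor[B, C, H, W], 'labels': Tensor[B]}."""

    def __init__(self, n: int = 64, num_classes: int = 10):
        self.images = torch.randn(n, 3, 32, 32)
        self.labels = torch.randint(0, num_classes, (n,))

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx):
        return {'images': self.images[idx], 'labels': self.labels[idx]}

# Default collation turns these dicts into the batch format consumed by
# AdvancedTrainer.train_epoch() and validate().
train_loader = DataLoader(DictImageDataset(), batch_size=32, shuffle=True)
val_loader = DataLoader(DictImageDataset(), batch_size=32)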
langvision/training/trainer.py CHANGED
@@ -1,7 +1,15 @@
 import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
 import os
 import logging
-from
+from typing import Optional, List, Dict, Any
+from tqdm import tqdm
+import time
+
+from ..callbacks.base import Callback
+from ..utils.device import get_device, to_device
 
 logger = logging.getLogger("langvision.trainer")
 
@@ -13,8 +21,13 @@ class Trainer:
     - Callbacks
     - Distributed/multi-GPU (use torch.nn.parallel.DistributedDataParallel)
     - TPU (use torch_xla)
+
+    Integration points for advanced LLM concepts:
+    - RLHF: Use RLHF-based feedback in the training loop
+    - PPO/DPO/GRPO/RLVR: Use RL-based optimization for policy/model updates
+    - LIME/SHAP: Use for model interpretability during/after training
     """
-    def __init__(self, model, optimizer, criterion, scheduler=None, scaler=None, callbacks=None, device='cpu'):
+    def __init__(self, model, optimizer, criterion, scheduler=None, scaler=None, callbacks=None, device='cpu', rlhf=None, ppo=None, dpo=None):
         self.model = model
         self.optimizer = optimizer
         self.criterion = criterion
@@ -22,6 +35,9 @@ class Trainer:
         self.scaler = scaler
         self.callbacks = callbacks or []
         self.device = device
+        self.rlhf = rlhf  # RLHF integration (optional)
+        self.ppo = ppo  # PPO integration (optional)
+        self.dpo = dpo  # DPO integration (optional)
         # GPU optimization
         if device.type == 'cuda':
             torch.backends.cudnn.benchmark = True
@@ -40,6 +56,18 @@
             imgs = imgs.to(self.device, non_blocking=True)
             labels = labels.to(self.device, non_blocking=True)
             self.optimizer.zero_grad()
+            # RLHF integration example (stub)
+            if self.rlhf is not None:
+                # TODO: Use self.rlhf.train(data, feedback) for RLHF-based updates
+                pass
+            # PPO integration example (stub)
+            if self.ppo is not None:
+                # TODO: Use self.ppo.step(state, action, reward) for PPO-based updates
+                pass
+            # DPO integration example (stub)
+            if self.dpo is not None:
+                # TODO: Use self.dpo.optimize_with_preferences(preferences) for DPO-based updates
+                pass
             if self.scaler:
                 with torch.cuda.amp.autocast():
                     outputs = self.model(imgs)
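The new rlhf/ppo/dpo constructor arguments are stored but only reach TODO stubs in the loop; the comments indicate the intended call shapes (rlhf.train(data, feedback), ppo.step(state, action, reward), dpo.optimize_with_preferences(preferences)). A hedged sketch of wiring one hook in follows, using a hypothetical helper; note that the cudnn check reads device.type, so passing a torch.device rather than the 'cpu' string default is the safe choice.

import torch
import torch.nn as nn
import torch.optim as optim

from langvision.training.trainer import Trainer

class DummyPPO:
    """Hypothetical helper matching the TODO's ppo.step(state, action, reward)."""

    def step(self, state, action, reward):
        pass  # a real PPO policy update would go here

model = nn.Linear(8, 2)
trainer = Trainer(
    model=model,
    optimizer=optim.AdamW(model.parameters(), lr=1e-4),
    criterion=nn.CrossEntropyLoss(),
    device=torch.device('cpu'),  # the code checks device.type == 'cuda'
    ppo=DummyPPO(),              # currently only reaches the `if self.ppo is not None:` stub
)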
|