langtune-0.1.19-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langtune/__init__.py +315 -0
- langtune/acceleration.py +132 -0
- langtune/api.py +320 -0
- langtune/auth.py +434 -0
- langtune/callbacks.py +268 -0
- langtune/cli.py +687 -0
- langtune/client.py +721 -0
- langtune/config.py +356 -0
- langtune/data.py +526 -0
- langtune/distributed.py +154 -0
- langtune/facade.py +174 -0
- langtune/finetune.py +491 -0
- langtune/generation.py +95 -0
- langtune/logging_utils.py +182 -0
- langtune/metrics.py +345 -0
- langtune/model/__init__.py +20 -0
- langtune/model/hub.py +109 -0
- langtune/model/loader.py +84 -0
- langtune/model/safetensors.py +104 -0
- langtune/model/weights.py +100 -0
- langtune/models.py +19 -0
- langtune/nn/fast_transformer.py +399 -0
- langtune/nn/layers.py +178 -0
- langtune/nn/transformer.py +254 -0
- langtune/optimizations.py +870 -0
- langtune/py.typed +2 -0
- langtune/schedulers.py +234 -0
- langtune/tokenizers.py +275 -0
- langtune/trainer.py +889 -0
- langtune/training/neftune.py +80 -0
- langtune/utils.py +337 -0
- langtune-0.1.19.dist-info/METADATA +257 -0
- langtune-0.1.19.dist-info/RECORD +37 -0
- langtune-0.1.19.dist-info/WHEEL +5 -0
- langtune-0.1.19.dist-info/entry_points.txt +2 -0
- langtune-0.1.19.dist-info/licenses/LICENSE +21 -0
- langtune-0.1.19.dist-info/top_level.txt +1 -0
langtune/trainer.py
ADDED
@@ -0,0 +1,889 @@
"""
trainer.py: Training utilities for Langtune
"""

import os
import time
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, OneCycleLR
from torch.utils.data import DataLoader
from typing import Dict, Any, Optional, Callable, List
import logging
from pathlib import Path
import json
import numpy as np
from tqdm import tqdm
import wandb
from contextlib import contextmanager

from .models import LoRALanguageModel
from .config import Config
from .data import DataCollator

logger = logging.getLogger(__name__)

class EarlyStopping:
    """Early stopping utility."""

    def __init__(self, patience: int = 5, threshold: float = 0.001, mode: str = "min"):
        self.patience = patience
        self.threshold = threshold
        self.mode = mode
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, score: float) -> bool:
        """Check if training should stop early."""
        if self.best_score is None:
            self.best_score = score
        elif self.mode == "min":
            if score < self.best_score - self.threshold:
                self.best_score = score
                self.counter = 0
            else:
                self.counter += 1
        else:  # mode == "max"
            if score > self.best_score + self.threshold:
                self.best_score = score
                self.counter = 0
            else:
                self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True

        return self.early_stop

class MetricsTracker:
    """Track training and validation metrics."""

    def __init__(self):
        self.metrics = {}
        self.history = []

    def update(self, metrics: Dict[str, float]):
        """Update metrics."""
        for key, value in metrics.items():
            if key not in self.metrics:
                self.metrics[key] = []
            self.metrics[key].append(value)

    def get_average(self, key: str, window: int = None) -> float:
        """Get average of a metric."""
        if key not in self.metrics:
            return 0.0

        values = self.metrics[key]
        if window is None:
            return np.mean(values)
        else:
            return np.mean(values[-window:])

    def get_latest(self, key: str) -> float:
        """Get latest value of a metric."""
        if key not in self.metrics or not self.metrics[key]:
            return 0.0
        return self.metrics[key][-1]

    def log_epoch(self):
        """Log epoch metrics."""
        epoch_metrics = {}
        for key, values in self.metrics.items():
            epoch_metrics[f"epoch_{key}"] = np.mean(values)

        self.history.append(epoch_metrics)
        self.metrics = {}  # Reset for next epoch

        return epoch_metrics

class ModelCheckpoint:
    """Model checkpointing utility."""

    def __init__(
        self,
        save_dir: str,
        save_best_only: bool = True,
        save_total_limit: int = 3,
        monitor: str = "val_loss",
        mode: str = "min"
    ):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.save_best_only = save_best_only
        self.save_total_limit = save_total_limit
        self.monitor = monitor
        self.mode = mode
        self.best_score = None
        self.checkpoints = []

    def save(self, model: nn.Module, optimizer, scheduler, epoch: int, metrics: Dict[str, float]):
        """Save model checkpoint."""
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
            "metrics": metrics
        }

        # Determine if this is the best checkpoint
        current_score = metrics.get(self.monitor, float('inf') if self.mode == "min" else float('-inf'))
        is_best = False

        if self.best_score is None:
            self.best_score = current_score
            is_best = True
        elif self.mode == "min" and current_score < self.best_score:
            self.best_score = current_score
            is_best = True
        elif self.mode == "max" and current_score > self.best_score:
            self.best_score = current_score
            is_best = True

        # Save checkpoint
        if not self.save_best_only or is_best:
            checkpoint_path = self.save_dir / f"checkpoint_epoch_{epoch}.pt"
            torch.save(checkpoint, checkpoint_path)
            self.checkpoints.append(checkpoint_path)

        if is_best:
            best_path = self.save_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            logger.info(f"New best model saved with {self.monitor}={current_score:.4f}")

        # Clean up old checkpoints
        if len(self.checkpoints) > self.save_total_limit:
            old_checkpoint = self.checkpoints.pop(0)
            if old_checkpoint.exists():
                old_checkpoint.unlink()

    def load(self, model: nn.Module, optimizer, scheduler, checkpoint_path: str):
        """Load model checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        if scheduler and checkpoint["scheduler_state_dict"]:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        return checkpoint["epoch"], checkpoint["metrics"]

class Trainer:
    """
    Main trainer class for fine-tuning language models.
    """

    def __init__(
        self,
        model: LoRALanguageModel,
        config: Config,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None
    ):
        self.model = model
        self.config = config
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader

        # Setup device
        self.device = self._setup_device()
        self.model.to(self.device)

        # Setup optimizer and scheduler
        self.optimizer = self._setup_optimizer()
        self.scheduler = self._setup_scheduler()

        # Setup utilities
        self.metrics_tracker = MetricsTracker()
        self.early_stopping = EarlyStopping(
            patience=config.training.early_stopping_patience,
            threshold=config.training.early_stopping_threshold
        )
        self.checkpointer = ModelCheckpoint(
            save_dir=config.output_dir,
            save_total_limit=config.training.save_total_limit,
            monitor="val_loss"
        )

        # Setup mixed precision
        self.scaler = torch.cuda.amp.GradScaler() if config.training.mixed_precision else None

        # Setup logging
        self._setup_logging()

        logger.info(f"Trainer initialized on device: {self.device}")
        logger.info(f"Model parameters: {self.model.count_parameters():,}")
        logger.info(f"LoRA parameters: {self.model.count_lora_parameters():,}")

    def _setup_device(self) -> torch.device:
        """Setup training device."""
        if self.config.device == "auto":
            if torch.cuda.is_available():
                device = torch.device("cuda")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                device = torch.device("mps")
            else:
                device = torch.device("cpu")
        else:
            device = torch.device(self.config.device)

        return device

    def _setup_optimizer(self):
        """Setup optimizer."""
        # Only optimize LoRA parameters if using LoRA
        if hasattr(self.model, 'count_lora_parameters') and self.model.count_lora_parameters() > 0:
            lora_params = []
            for name, param in self.model.named_parameters():
                if 'lora' in name.lower():
                    lora_params.append(param)

            logger.info(f"Optimizing {len(lora_params)} LoRA parameter groups")
            return AdamW(lora_params, lr=self.config.training.learning_rate, weight_decay=self.config.training.weight_decay)
        else:
            return AdamW(self.model.parameters(), lr=self.config.training.learning_rate, weight_decay=self.config.training.weight_decay)

    def _setup_scheduler(self):
        """Setup learning rate scheduler."""
        total_steps = len(self.train_dataloader) * self.config.training.num_epochs

        if self.config.training.warmup_steps > 0:
            return OneCycleLR(
                self.optimizer,
                max_lr=self.config.training.learning_rate,
                total_steps=total_steps,
                pct_start=self.config.training.warmup_steps / total_steps
            )
        else:
            return CosineAnnealingLR(self.optimizer, T_max=total_steps)

    def _setup_logging(self):
        """Setup logging and experiment tracking."""
        # Setup Weights & Biases if available
        try:
            wandb.init(
                project="langtune",
                config=self.config.__dict__ if hasattr(self.config, '__dict__') else {},
                name=f"run_{int(time.time())}"
            )
            self.use_wandb = True
        except:
            self.use_wandb = False
            logger.warning("Weights & Biases not available, using local logging only")

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        """Train for one epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Epoch {epoch+1}/{self.config.training.num_epochs}",
            leave=False
        )

        for batch_idx, batch in enumerate(progress_bar):
            # Move batch to device
            batch = {k: v.to(self.device) for k, v in batch.items()}

            # Forward pass with mixed precision
            if self.scaler:
                with torch.cuda.amp.autocast():
                    outputs = self.model(**batch)
                    loss = outputs["loss"]

                # Backward pass
                self.scaler.scale(loss).backward()

                # Gradient clipping
                if self.config.training.max_grad_norm > 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training.max_grad_norm)

                # Optimizer step
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                outputs = self.model(**batch)
                loss = outputs["loss"]

                # Backward pass
                loss.backward()

                # Gradient clipping
                if self.config.training.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training.max_grad_norm)

                # Optimizer step
                self.optimizer.step()

            # Scheduler step
            if self.scheduler:
                self.scheduler.step()

            # Clear gradients
            self.optimizer.zero_grad()

            # Update metrics
            total_loss += loss.item()
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

            # Logging
            if batch_idx % self.config.training.logging_steps == 0:
                current_lr = self.optimizer.param_groups[0]['lr']
                metrics = {
                    "train_loss": loss.item(),
                    "learning_rate": current_lr,
                    "epoch": epoch
                }

                self.metrics_tracker.update(metrics)

                if self.use_wandb:
                    wandb.log(metrics)

                logger.info(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}, LR: {current_lr:.2e}")

        avg_loss = total_loss / num_batches
        return {"train_loss": avg_loss}

    def validate(self, epoch: int) -> Dict[str, float]:
        """Validate the model."""
        if self.val_dataloader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_dataloader, desc="Validation", leave=False):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                if self.scaler:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(**batch)
                        loss = outputs["loss"]
                else:
                    outputs = self.model(**batch)
                    loss = outputs["loss"]

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches
        return {"val_loss": avg_loss}

    def train(self, resume_from_checkpoint: Optional[str] = None):
        """Main training loop."""
        start_epoch = 0

        # Resume from checkpoint if provided
        if resume_from_checkpoint:
            start_epoch, _ = self.checkpointer.load(
                self.model, self.optimizer, self.scheduler, resume_from_checkpoint
            )
            logger.info(f"Resumed training from epoch {start_epoch}")

        logger.info("Starting training...")

        for epoch in range(start_epoch, self.config.training.num_epochs):
            # Train epoch
            train_metrics = self.train_epoch(epoch)

            # Validate
            val_metrics = self.validate(epoch)

            # Combine metrics
            all_metrics = {**train_metrics, **val_metrics}

            # Update metrics tracker
            self.metrics_tracker.update(all_metrics)

            # Log epoch metrics
            epoch_metrics = self.metrics_tracker.log_epoch()

            # Log to wandb
            if self.use_wandb:
                wandb.log(epoch_metrics)

            # Save checkpoint
            self.checkpointer.save(self.model, self.optimizer, self.scheduler, epoch, all_metrics)

            # Early stopping
            if val_metrics and "val_loss" in val_metrics:
                if self.early_stopping(val_metrics["val_loss"]):
                    logger.info(f"Early stopping triggered at epoch {epoch+1}")
                    break

            # Log epoch summary
            logger.info(f"Epoch {epoch+1} completed - Train Loss: {train_metrics['train_loss']:.4f}, Val Loss: {val_metrics.get('val_loss', 'N/A')}")

        logger.info("Training completed!")

        # Final evaluation on test set
        if self.test_dataloader:
            test_metrics = self.evaluate()
            logger.info(f"Final test metrics: {test_metrics}")

    def evaluate(self) -> Dict[str, float]:
        """Evaluate the model on test set."""
        if self.test_dataloader is None:
            logger.warning("No test dataloader provided")
            return {}

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.test_dataloader, desc="Testing"):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                if self.scaler:
                    with torch.cuda.amp.autocast():
                        outputs = self.model(**batch)
                        loss = outputs["loss"]
                else:
                    outputs = self.model(**batch)
                    loss = outputs["loss"]

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches
        return {"test_loss": avg_loss}

    def generate_sample(self, prompt: str, max_length: int = 100) -> str:
        """Generate a sample from the model."""
        self.model.eval()

        # Simple tokenization (in practice, you'd use a proper tokenizer)
        input_ids = torch.tensor([ord(c) for c in prompt[:50]], dtype=torch.long).unsqueeze(0).to(self.device)

        with torch.no_grad():
            generated = self.model.generate(
                input_ids,
                max_length=max_length,
                temperature=0.8,
                top_k=50,
                top_p=0.9
            )

        # Simple decoding
        generated_text = "".join([chr(i) for i in generated[0].cpu().tolist()])
        return generated_text


class FastTrainer:
    """
    Optimized trainer with:
    - Gradient accumulation for effective larger batches
    - Enhanced mixed precision training
    - Memory monitoring and optimization
    - Support for FastLoRALanguageModel
    """

    def __init__(
        self,
        model: nn.Module,
        config: Config,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        test_dataloader: Optional[DataLoader] = None,
        gradient_accumulation_steps: int = 4,
        mixed_precision: str = "fp16"  # fp16, bf16, or fp32
    ):
        self.model = model
        self.config = config
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.gradient_accumulation_steps = gradient_accumulation_steps

        # Setup device
        self.device = self._setup_device()
        self.model.to(self.device)

        # Freeze base model if using FastLoRALanguageModel
        if hasattr(self.model, 'freeze_base_model'):
            self.model.freeze_base_model()

        # Setup mixed precision
        self.mixed_precision = mixed_precision
        self._setup_amp()

        # Setup optimizer (only trainable params)
        self.optimizer = self._setup_optimizer()
        self.scheduler = self._setup_scheduler()

        # Utilities
        self.metrics_tracker = MetricsTracker()
        self.early_stopping = EarlyStopping(
            patience=config.training.early_stopping_patience,
            threshold=config.training.early_stopping_threshold
        )
        self.checkpointer = ModelCheckpoint(
            save_dir=config.output_dir,
            save_total_limit=config.training.save_total_limit,
            monitor="val_loss"
        )

        # Setup logging
        self._setup_logging()

        # Log configuration
        self._log_training_info()

    def _setup_device(self) -> torch.device:
        if self.config.device == "auto":
            if torch.cuda.is_available():
                return torch.device("cuda")
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                return torch.device("mps")
            else:
                return torch.device("cpu")
        return torch.device(self.config.device)

    def _setup_amp(self):
        """Setup automatic mixed precision."""
        try:
            from .optimizations import MixedPrecisionTrainer

            if self.mixed_precision == "bf16":
                dtype = torch.bfloat16
            elif self.mixed_precision == "fp16":
                dtype = torch.float16
            else:
                dtype = torch.float32

            self.amp_trainer = MixedPrecisionTrainer(
                enabled=(self.mixed_precision != "fp32" and self.device.type == "cuda"),
                dtype=dtype
            )
        except ImportError:
            # Fallback to standard scaler
            self.amp_trainer = None
            self.scaler = torch.cuda.amp.GradScaler() if self.device.type == "cuda" else None

    def _setup_optimizer(self):
        """Setup optimizer for trainable parameters only."""
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]

        if len(trainable_params) == 0:
            logger.warning("No trainable parameters found! Check model configuration.")
            trainable_params = list(self.model.parameters())

        logger.info(f"Optimizing {len(trainable_params)} parameter groups")

        return AdamW(
            trainable_params,
            lr=self.config.training.learning_rate,
            weight_decay=self.config.training.weight_decay
        )

    def _setup_scheduler(self):
        steps_per_epoch = len(self.train_dataloader) // self.gradient_accumulation_steps
        total_steps = steps_per_epoch * self.config.training.num_epochs

        if self.config.training.warmup_steps > 0:
            return OneCycleLR(
                self.optimizer,
                max_lr=self.config.training.learning_rate,
                total_steps=total_steps,
                pct_start=self.config.training.warmup_steps / max(total_steps, 1)
            )
        return CosineAnnealingLR(self.optimizer, T_max=total_steps)

    def _setup_logging(self):
        try:
            wandb.init(
                project="langtune-fast",
                config={
                    "gradient_accumulation": self.gradient_accumulation_steps,
                    "mixed_precision": self.mixed_precision,
                    **({k: v for k, v in self.config.__dict__.items() if not k.startswith('_')})
                },
                name=f"fast_run_{int(time.time())}"
            )
            self.use_wandb = True
        except:
            self.use_wandb = False

    def _log_training_info(self):
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        logger.info(f"FastTrainer initialized on {self.device}")
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")
        logger.info(f"Gradient accumulation: {self.gradient_accumulation_steps}")
        logger.info(f"Mixed precision: {self.mixed_precision}")
        logger.info(f"Effective batch size: {self.config.training.batch_size * self.gradient_accumulation_steps}")

    def _log_memory(self, prefix: str = ""):
        if self.device.type == "cuda":
            try:
                from .optimizations import log_memory_usage
                log_memory_usage(prefix)
            except ImportError:
                allocated = torch.cuda.memory_allocated() / 1e9
                logger.info(f"{prefix}GPU Memory: {allocated:.2f} GB")

    def train_epoch(self, epoch: int) -> Dict[str, float]:
        self.model.train()
        total_loss = 0.0
        num_steps = 0
        accumulated_loss = 0.0

        progress_bar = tqdm(
            self.train_dataloader,
            desc=f"Epoch {epoch+1}/{self.config.training.num_epochs}",
            leave=False
        )

        self.optimizer.zero_grad()

        for batch_idx, batch in enumerate(progress_bar):
            batch = {k: v.to(self.device) for k, v in batch.items()}

            # Forward with AMP
            if self.amp_trainer:
                with self.amp_trainer.autocast_context:
                    outputs = self.model(**batch)
                    loss = outputs["loss"] / self.gradient_accumulation_steps

                # Scale and backward
                scaled_loss = self.amp_trainer.scale_loss(loss)
                scaled_loss.backward()
            else:
                outputs = self.model(**batch)
                loss = outputs["loss"] / self.gradient_accumulation_steps
                loss.backward()

            accumulated_loss += loss.item()

            # Optimizer step after accumulation
            if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
                # Gradient clipping
                if self.config.training.max_grad_norm > 0:
                    if self.amp_trainer:
                        self.amp_trainer.unscale_gradients(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config.training.max_grad_norm
                    )

                # Step
                if self.amp_trainer:
                    self.amp_trainer.step(self.optimizer)
                else:
                    self.optimizer.step()

                if self.scheduler:
                    self.scheduler.step()

                self.optimizer.zero_grad()

                # Track
                step_loss = accumulated_loss * self.gradient_accumulation_steps
                total_loss += step_loss
                num_steps += 1
                accumulated_loss = 0.0

                progress_bar.set_postfix({"loss": f"{step_loss:.4f}"})

                # Log periodically
                if num_steps % (self.config.training.logging_steps // self.gradient_accumulation_steps + 1) == 0:
                    lr = self.optimizer.param_groups[0]['lr']
                    metrics = {"train_loss": step_loss, "learning_rate": lr, "epoch": epoch}
                    self.metrics_tracker.update(metrics)
                    if self.use_wandb:
                        wandb.log(metrics)

        # Handle remaining batches
        if accumulated_loss > 0:
            if self.amp_trainer:
                self.amp_trainer.step(self.optimizer)
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()

        self._log_memory(f"Epoch {epoch+1} end - ")

        return {"train_loss": total_loss / max(num_steps, 1)}

    def validate(self, epoch: int) -> Dict[str, float]:
        if self.val_dataloader is None:
            return {}

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_dataloader, desc="Validation", leave=False):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                if self.amp_trainer:
                    with self.amp_trainer.autocast_context:
                        outputs = self.model(**batch)
                else:
                    outputs = self.model(**batch)

                total_loss += outputs["loss"].item()
                num_batches += 1

        return {"val_loss": total_loss / max(num_batches, 1)}

    def train(self, resume_from_checkpoint: Optional[str] = None):
        start_epoch = 0

        if resume_from_checkpoint:
            start_epoch, _ = self.checkpointer.load(
                self.model, self.optimizer, self.scheduler, resume_from_checkpoint
            )
            logger.info(f"Resumed from epoch {start_epoch}")

        logger.info("Starting optimized training...")
        self._log_memory("Training start - ")

        for epoch in range(start_epoch, self.config.training.num_epochs):
            train_metrics = self.train_epoch(epoch)
            val_metrics = self.validate(epoch)

            all_metrics = {**train_metrics, **val_metrics}
            self.metrics_tracker.update(all_metrics)
            epoch_metrics = self.metrics_tracker.log_epoch()

            if self.use_wandb:
                wandb.log(epoch_metrics)

            self.checkpointer.save(self.model, self.optimizer, self.scheduler, epoch, all_metrics)

            if val_metrics and "val_loss" in val_metrics:
                if self.early_stopping(val_metrics["val_loss"]):
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

            logger.info(
                f"Epoch {epoch+1} - Train: {train_metrics['train_loss']:.4f}, "
                f"Val: {val_metrics.get('val_loss', 'N/A')}"
            )

        logger.info("Training completed!")
        self._log_memory("Training end - ")

        # Cleanup
        try:
            from .optimizations import cleanup_memory
            cleanup_memory()
        except ImportError:
            if self.device.type == "cuda":
                torch.cuda.empty_cache()


def create_trainer(
    config: Config,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    test_dataloader: Optional[DataLoader] = None
) -> Trainer:
    """
    Create a trainer instance.

    Args:
        config: Training configuration
        train_dataloader: Training data loader
        val_dataloader: Validation data loader (optional)
        test_dataloader: Test data loader (optional)

    Returns:
        Trainer instance
    """
    # Create model
    model = LoRALanguageModel(
        vocab_size=config.model.vocab_size,
        embed_dim=config.model.embed_dim,
        num_layers=config.model.num_layers,
        num_heads=config.model.num_heads,
        max_seq_len=config.model.max_seq_len,
        mlp_ratio=config.model.mlp_ratio,
        dropout=config.model.dropout,
        lora_config=config.model.lora.__dict__ if config.model.lora else None
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        config=config,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader
    )

    return trainer


def create_fast_trainer(
    config: Config,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader] = None,
    test_dataloader: Optional[DataLoader] = None,
    gradient_accumulation_steps: int = 4,
    mixed_precision: str = "fp16"
) -> FastTrainer:
    """
    Create an optimized FastTrainer instance with FastLoRALanguageModel.

    Args:
        config: Training configuration
        train_dataloader: Training data loader
        val_dataloader: Validation data loader (optional)
        test_dataloader: Test data loader (optional)
        gradient_accumulation_steps: Steps to accumulate gradients
        mixed_precision: "fp16", "bf16", or "fp32"

    Returns:
        FastTrainer instance
    """
    from .models import FastLoRALanguageModel

    # Create optimized model
    model = FastLoRALanguageModel(
        vocab_size=config.model.vocab_size,
        embed_dim=config.model.embed_dim,
        num_layers=config.model.num_layers,
        num_heads=config.model.num_heads,
        max_seq_len=config.model.max_seq_len,
        mlp_ratio=config.model.mlp_ratio,
        dropout=config.model.dropout,
        lora_config=config.model.lora.__dict__ if config.model.lora else None,
        use_rope=True,
        use_flash_attention=True,
        use_gradient_checkpointing=True
    )

    # Create fast trainer
    trainer = FastTrainer(
        model=model,
        config=config,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        test_dataloader=test_dataloader,
        gradient_accumulation_steps=gradient_accumulation_steps,
        mixed_precision=mixed_precision
    )

    return trainer
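
For orientation, a minimal usage sketch (editorial; not part of the packaged file). It only exercises EarlyStopping and MetricsTracker, which are fully defined in trainer.py above, and assumes the modules imported at the top of the file (torch, numpy, tqdm, wandb) are installed; the commented factory call at the end shows the intended entry point, with the Config() construction assumed rather than taken from this diff.

from langtune.trainer import EarlyStopping, MetricsTracker

# Track a per-epoch validation loss and stop after 2 epochs without a >= 0.01 improvement.
stopper = EarlyStopping(patience=2, threshold=0.01, mode="min")
tracker = MetricsTracker()

for epoch, val_loss in enumerate([1.00, 0.80, 0.79, 0.795, 0.80]):
    tracker.update({"val_loss": val_loss})
    print(epoch, tracker.log_epoch())  # {"epoch_val_loss": ...}; tracker resets for the next epoch
    if stopper(val_loss):
        print(f"early stopping at epoch {epoch}")  # with these values, fires at epoch 3
        break

# Intended high-level entry point (Config() usage assumed; see langtune/config.py):
# trainer = create_fast_trainer(Config(), train_loader, val_loader,
#                               gradient_accumulation_steps=8, mixed_precision="bf16")
# trainer.train()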