langtune 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langtune/trainer.py ADDED
@@ -0,0 +1,889 @@
+ """
+ trainer.py: Training utilities for Langtune
+ """
+
+ import time
+ import torch
+ import torch.nn as nn
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
+ from torch.utils.data import DataLoader
+ from typing import Dict, Optional
+ import logging
+ from pathlib import Path
+ import numpy as np
+ from tqdm import tqdm
+ import wandb
+
+ from .models import LoRALanguageModel
+ from .config import Config
+
+ logger = logging.getLogger(__name__)
+
+ class EarlyStopping:
+     """Early stopping utility."""
+
+     def __init__(self, patience: int = 5, threshold: float = 0.001, mode: str = "min"):
+         self.patience = patience
+         self.threshold = threshold
+         self.mode = mode
+         self.best_score = None
+         self.counter = 0
+         self.early_stop = False
+
+     def __call__(self, score: float) -> bool:
+         """Check if training should stop early."""
+         if self.best_score is None:
+             self.best_score = score
+         elif self.mode == "min":
+             if score < self.best_score - self.threshold:
+                 self.best_score = score
+                 self.counter = 0
+             else:
+                 self.counter += 1
+         else:  # mode == "max"
+             if score > self.best_score + self.threshold:
+                 self.best_score = score
+                 self.counter = 0
+             else:
+                 self.counter += 1
+
+         if self.counter >= self.patience:
+             self.early_stop = True
+
+         return self.early_stop
+
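+ # A minimal usage sketch for EarlyStopping (the loss values below are
+ # hypothetical, not part of this module): stop once the monitored score has
+ # failed to improve by `threshold` for `patience` consecutive checks.
+ #
+ #     stopper = EarlyStopping(patience=2, threshold=0.01, mode="min")
+ #     for val_loss in [1.00, 0.80, 0.79, 0.795]:   # 0.79 and 0.795 do not
+ #         if stopper(val_loss):                    # beat 0.80 by 0.01, so
+ #             break                                # this breaks on 0.795
+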
+ class MetricsTracker:
+     """Track training and validation metrics."""
+
+     def __init__(self):
+         self.metrics = {}
+         self.history = []
+
+     def update(self, metrics: Dict[str, float]):
+         """Update metrics."""
+         for key, value in metrics.items():
+             if key not in self.metrics:
+                 self.metrics[key] = []
+             self.metrics[key].append(value)
+
+     def get_average(self, key: str, window: Optional[int] = None) -> float:
+         """Get the average of a metric, optionally over the last `window` values."""
+         if key not in self.metrics:
+             return 0.0
+
+         values = self.metrics[key]
+         if window is None:
+             return float(np.mean(values))
+         return float(np.mean(values[-window:]))
+
+     def get_latest(self, key: str) -> float:
+         """Get the latest value of a metric."""
+         if key not in self.metrics or not self.metrics[key]:
+             return 0.0
+         return self.metrics[key][-1]
+
+     def log_epoch(self):
+         """Average the accumulated metrics, append them to the history, and reset."""
+         epoch_metrics = {}
+         for key, values in self.metrics.items():
+             epoch_metrics[f"epoch_{key}"] = float(np.mean(values))
+
+         self.history.append(epoch_metrics)
+         self.metrics = {}  # Reset for the next epoch
+
+         return epoch_metrics
+
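+ # A minimal sketch of the intended flow (values are made up): per-step
+ # metrics accumulate via update(), and log_epoch() collapses them into one
+ # "epoch_*" average per key, then clears the buffer for the next epoch.
+ #
+ #     tracker = MetricsTracker()
+ #     tracker.update({"train_loss": 1.2})
+ #     tracker.update({"train_loss": 0.8})
+ #     tracker.log_epoch()   # -> {"epoch_train_loss": 1.0}
+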
+ class ModelCheckpoint:
+     """Model checkpointing utility."""
+
+     def __init__(
+         self,
+         save_dir: str,
+         save_best_only: bool = True,
+         save_total_limit: int = 3,
+         monitor: str = "val_loss",
+         mode: str = "min"
+     ):
+         self.save_dir = Path(save_dir)
+         self.save_dir.mkdir(parents=True, exist_ok=True)
+         self.save_best_only = save_best_only
+         self.save_total_limit = save_total_limit
+         self.monitor = monitor
+         self.mode = mode
+         self.best_score = None
+         self.checkpoints = []
+
+     def save(self, model: nn.Module, optimizer, scheduler, epoch: int, metrics: Dict[str, float]):
+         """Save a model checkpoint."""
+         checkpoint = {
+             "epoch": epoch,
+             "model_state_dict": model.state_dict(),
+             "optimizer_state_dict": optimizer.state_dict(),
+             "scheduler_state_dict": scheduler.state_dict() if scheduler else None,
+             "metrics": metrics
+         }
+
+         # Determine whether this is the best checkpoint so far
+         current_score = metrics.get(self.monitor, float('inf') if self.mode == "min" else float('-inf'))
+         is_best = False
+
+         if self.best_score is None:
+             self.best_score = current_score
+             is_best = True
+         elif self.mode == "min" and current_score < self.best_score:
+             self.best_score = current_score
+             is_best = True
+         elif self.mode == "max" and current_score > self.best_score:
+             self.best_score = current_score
+             is_best = True
+
+         # Save the checkpoint
+         if not self.save_best_only or is_best:
+             checkpoint_path = self.save_dir / f"checkpoint_epoch_{epoch}.pt"
+             torch.save(checkpoint, checkpoint_path)
+             self.checkpoints.append(checkpoint_path)
+
+         if is_best:
+             best_path = self.save_dir / "best_model.pt"
+             torch.save(checkpoint, best_path)
+             logger.info(f"New best model saved with {self.monitor}={current_score:.4f}")
+
+         # Clean up old checkpoints
+         if len(self.checkpoints) > self.save_total_limit:
+             old_checkpoint = self.checkpoints.pop(0)
+             if old_checkpoint.exists():
+                 old_checkpoint.unlink()
+
+     def load(self, model: nn.Module, optimizer, scheduler, checkpoint_path: str):
+         """Load a model checkpoint."""
+         checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+         model.load_state_dict(checkpoint["model_state_dict"])
+         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+
+         if scheduler and checkpoint["scheduler_state_dict"]:
+             scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
+
+         return checkpoint["epoch"], checkpoint["metrics"]
+
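+ # A minimal usage sketch (paths and values are illustrative): with the default
+ # save_best_only=True, a checkpoint file is written only when the monitored
+ # metric improves, best_model.pt is refreshed alongside it, and at most
+ # save_total_limit epoch checkpoints are kept on disk.
+ #
+ #     ckpt = ModelCheckpoint("runs/exp1", monitor="val_loss", mode="min")
+ #     ckpt.save(model, optimizer, scheduler, epoch=0, metrics={"val_loss": 0.9})
+ #     epoch, metrics = ckpt.load(model, optimizer, scheduler, "runs/exp1/best_model.pt")
+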
+ class Trainer:
+     """
+     Main trainer class for fine-tuning language models.
+     """
+
+     def __init__(
+         self,
+         model: LoRALanguageModel,
+         config: Config,
+         train_dataloader: DataLoader,
+         val_dataloader: Optional[DataLoader] = None,
+         test_dataloader: Optional[DataLoader] = None
+     ):
+         self.model = model
+         self.config = config
+         self.train_dataloader = train_dataloader
+         self.val_dataloader = val_dataloader
+         self.test_dataloader = test_dataloader
+
+         # Setup device
+         self.device = self._setup_device()
+         self.model.to(self.device)
+
+         # Setup optimizer and scheduler
+         self.optimizer = self._setup_optimizer()
+         self.scheduler = self._setup_scheduler()
+
+         # Setup utilities
+         self.metrics_tracker = MetricsTracker()
+         self.early_stopping = EarlyStopping(
+             patience=config.training.early_stopping_patience,
+             threshold=config.training.early_stopping_threshold
+         )
+         self.checkpointer = ModelCheckpoint(
+             save_dir=config.output_dir,
+             save_total_limit=config.training.save_total_limit,
+             monitor="val_loss"
+         )
+
+         # Setup mixed precision (GradScaler is only useful on CUDA)
+         self.scaler = torch.cuda.amp.GradScaler() if (config.training.mixed_precision and self.device.type == "cuda") else None
+
+         # Setup logging
+         self._setup_logging()
+
+         logger.info(f"Trainer initialized on device: {self.device}")
+         logger.info(f"Model parameters: {self.model.count_parameters():,}")
+         logger.info(f"LoRA parameters: {self.model.count_lora_parameters():,}")
+
+     def _setup_device(self) -> torch.device:
+         """Setup the training device."""
+         if self.config.device == "auto":
+             if torch.cuda.is_available():
+                 device = torch.device("cuda")
+             elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                 device = torch.device("mps")
+             else:
+                 device = torch.device("cpu")
+         else:
+             device = torch.device(self.config.device)
+
+         return device
+
+     def _setup_optimizer(self):
+         """Setup the optimizer."""
+         # Only optimize LoRA parameters if the model uses LoRA
+         if hasattr(self.model, 'count_lora_parameters') and self.model.count_lora_parameters() > 0:
+             lora_params = [
+                 param for name, param in self.model.named_parameters()
+                 if 'lora' in name.lower()
+             ]
+             logger.info(f"Optimizing {len(lora_params)} LoRA parameter tensors")
+             return AdamW(lora_params, lr=self.config.training.learning_rate, weight_decay=self.config.training.weight_decay)
+         else:
+             return AdamW(self.model.parameters(), lr=self.config.training.learning_rate, weight_decay=self.config.training.weight_decay)
+
+     def _setup_scheduler(self):
+         """Setup the learning rate scheduler."""
+         total_steps = max(len(self.train_dataloader) * self.config.training.num_epochs, 1)
+
+         if self.config.training.warmup_steps > 0:
+             return OneCycleLR(
+                 self.optimizer,
+                 max_lr=self.config.training.learning_rate,
+                 total_steps=total_steps,
+                 pct_start=min(self.config.training.warmup_steps / total_steps, 1.0)
+             )
+         else:
+             return CosineAnnealingLR(self.optimizer, T_max=total_steps)
+
+     def _setup_logging(self):
+         """Setup logging and experiment tracking."""
+         # Use Weights & Biases if available
+         try:
+             wandb.init(
+                 project="langtune",
+                 config=self.config.__dict__ if hasattr(self.config, '__dict__') else {},
+                 name=f"run_{int(time.time())}"
+             )
+             self.use_wandb = True
+         except Exception:
+             self.use_wandb = False
+             logger.warning("Weights & Biases not available, using local logging only")
+
+     def train_epoch(self, epoch: int) -> Dict[str, float]:
+         """Train for one epoch."""
+         self.model.train()
+         total_loss = 0.0
+         num_batches = 0
+
+         progress_bar = tqdm(
+             self.train_dataloader,
+             desc=f"Epoch {epoch+1}/{self.config.training.num_epochs}",
+             leave=False
+         )
+
+         for batch_idx, batch in enumerate(progress_bar):
+             # Move batch to device
+             batch = {k: v.to(self.device) for k, v in batch.items()}
+
+             # Forward pass with mixed precision
+             if self.scaler:
+                 with torch.cuda.amp.autocast():
+                     outputs = self.model(**batch)
+                     loss = outputs["loss"]
+
+                 # Backward pass on the scaled loss
+                 self.scaler.scale(loss).backward()
+
+                 # Gradient clipping (gradients must be unscaled first)
+                 if self.config.training.max_grad_norm > 0:
+                     self.scaler.unscale_(self.optimizer)
+                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training.max_grad_norm)
+
+                 # Optimizer step
+                 self.scaler.step(self.optimizer)
+                 self.scaler.update()
+             else:
+                 outputs = self.model(**batch)
+                 loss = outputs["loss"]
+
+                 # Backward pass
+                 loss.backward()
+
+                 # Gradient clipping
+                 if self.config.training.max_grad_norm > 0:
+                     torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training.max_grad_norm)
+
+                 # Optimizer step
+                 self.optimizer.step()
+
+             # Scheduler step
+             if self.scheduler:
+                 self.scheduler.step()
+
+             # Clear gradients
+             self.optimizer.zero_grad()
+
+             # Update metrics
+             total_loss += loss.item()
+             num_batches += 1
+
+             # Update progress bar
+             progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+             # Logging
+             if self.config.training.logging_steps > 0 and batch_idx % self.config.training.logging_steps == 0:
+                 current_lr = self.optimizer.param_groups[0]['lr']
+                 metrics = {
+                     "train_loss": loss.item(),
+                     "learning_rate": current_lr,
+                     "epoch": epoch
+                 }
+
+                 self.metrics_tracker.update(metrics)
+
+                 if self.use_wandb:
+                     wandb.log(metrics)
+
+                 logger.info(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss.item():.4f}, LR: {current_lr:.2e}")
+
+         avg_loss = total_loss / max(num_batches, 1)
+         return {"train_loss": avg_loss}
+
+     def validate(self, epoch: int) -> Dict[str, float]:
+         """Validate the model."""
+         if self.val_dataloader is None:
+             return {}
+
+         self.model.eval()
+         total_loss = 0.0
+         num_batches = 0
+
+         with torch.no_grad():
+             for batch in tqdm(self.val_dataloader, desc="Validation", leave=False):
+                 batch = {k: v.to(self.device) for k, v in batch.items()}
+
+                 if self.scaler:
+                     with torch.cuda.amp.autocast():
+                         outputs = self.model(**batch)
+                         loss = outputs["loss"]
+                 else:
+                     outputs = self.model(**batch)
+                     loss = outputs["loss"]
+
+                 total_loss += loss.item()
+                 num_batches += 1
+
+         avg_loss = total_loss / max(num_batches, 1)
+         return {"val_loss": avg_loss}
+
+     def train(self, resume_from_checkpoint: Optional[str] = None):
+         """Main training loop."""
+         start_epoch = 0
+
+         # Resume from checkpoint if provided
+         if resume_from_checkpoint:
+             start_epoch, _ = self.checkpointer.load(
+                 self.model, self.optimizer, self.scheduler, resume_from_checkpoint
+             )
+             logger.info(f"Resumed training from epoch {start_epoch}")
+
+         logger.info("Starting training...")
+
+         for epoch in range(start_epoch, self.config.training.num_epochs):
+             # Train for one epoch
+             train_metrics = self.train_epoch(epoch)
+
+             # Validate
+             val_metrics = self.validate(epoch)
+
+             # Combine metrics
+             all_metrics = {**train_metrics, **val_metrics}
+
+             # Update the metrics tracker
+             self.metrics_tracker.update(all_metrics)
+
+             # Log epoch metrics
+             epoch_metrics = self.metrics_tracker.log_epoch()
+
+             # Log to wandb
+             if self.use_wandb:
+                 wandb.log(epoch_metrics)
+
+             # Save checkpoint
+             self.checkpointer.save(self.model, self.optimizer, self.scheduler, epoch, all_metrics)
+
+             # Early stopping
+             if "val_loss" in val_metrics:
+                 if self.early_stopping(val_metrics["val_loss"]):
+                     logger.info(f"Early stopping triggered at epoch {epoch+1}")
+                     break
+
+             # Log epoch summary
+             logger.info(f"Epoch {epoch+1} completed - Train Loss: {train_metrics['train_loss']:.4f}, Val Loss: {val_metrics.get('val_loss', 'N/A')}")
+
+         logger.info("Training completed!")
+
+         # Final evaluation on the test set
+         if self.test_dataloader:
+             test_metrics = self.evaluate()
+             logger.info(f"Final test metrics: {test_metrics}")
+
+     def evaluate(self) -> Dict[str, float]:
+         """Evaluate the model on the test set."""
+         if self.test_dataloader is None:
+             logger.warning("No test dataloader provided")
+             return {}
+
+         self.model.eval()
+         total_loss = 0.0
+         num_batches = 0
+
+         with torch.no_grad():
+             for batch in tqdm(self.test_dataloader, desc="Testing"):
+                 batch = {k: v.to(self.device) for k, v in batch.items()}
+
+                 if self.scaler:
+                     with torch.cuda.amp.autocast():
+                         outputs = self.model(**batch)
+                         loss = outputs["loss"]
+                 else:
+                     outputs = self.model(**batch)
+                     loss = outputs["loss"]
+
+                 total_loss += loss.item()
+                 num_batches += 1
+
+         avg_loss = total_loss / max(num_batches, 1)
+         return {"test_loss": avg_loss}
+
+     def generate_sample(self, prompt: str, max_length: int = 100) -> str:
+         """Generate a sample from the model."""
+         self.model.eval()
+
+         # Toy character-level "tokenization" via code points
+         # (in practice, use a proper tokenizer)
+         input_ids = torch.tensor([ord(c) for c in prompt[:50]], dtype=torch.long).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             generated = self.model.generate(
+                 input_ids,
+                 max_length=max_length,
+                 temperature=0.8,
+                 top_k=50,
+                 top_p=0.9
+             )
+
+         # Toy decoding; skip ids that are not valid code points
+         generated_text = "".join(chr(i) for i in generated[0].cpu().tolist() if 0 <= i < 0x110000)
+         return generated_text
+
+
+ class FastTrainer:
+     """
+     Optimized trainer with:
+     - Gradient accumulation for effectively larger batches
+     - Enhanced mixed precision training
+     - Memory monitoring and optimization
+     - Support for FastLoRALanguageModel
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         config: Config,
+         train_dataloader: DataLoader,
+         val_dataloader: Optional[DataLoader] = None,
+         test_dataloader: Optional[DataLoader] = None,
+         gradient_accumulation_steps: int = 4,
+         mixed_precision: str = "fp16"  # "fp16", "bf16", or "fp32"
+     ):
+         self.model = model
+         self.config = config
+         self.train_dataloader = train_dataloader
+         self.val_dataloader = val_dataloader
+         self.test_dataloader = test_dataloader
+         self.gradient_accumulation_steps = gradient_accumulation_steps
+
+         # Setup device
+         self.device = self._setup_device()
+         self.model.to(self.device)
+
+         # Freeze the base model if using FastLoRALanguageModel
+         if hasattr(self.model, 'freeze_base_model'):
+             self.model.freeze_base_model()
+
+         # Setup mixed precision
+         self.mixed_precision = mixed_precision
+         self._setup_amp()
+
+         # Setup optimizer (trainable parameters only)
+         self.optimizer = self._setup_optimizer()
+         self.scheduler = self._setup_scheduler()
+
+         # Utilities
+         self.metrics_tracker = MetricsTracker()
+         self.early_stopping = EarlyStopping(
+             patience=config.training.early_stopping_patience,
+             threshold=config.training.early_stopping_threshold
+         )
+         self.checkpointer = ModelCheckpoint(
+             save_dir=config.output_dir,
+             save_total_limit=config.training.save_total_limit,
+             monitor="val_loss"
+         )
+
+         # Setup logging
+         self._setup_logging()
+
+         # Log configuration
+         self._log_training_info()
+
+     def _setup_device(self) -> torch.device:
+         if self.config.device == "auto":
+             if torch.cuda.is_available():
+                 return torch.device("cuda")
+             elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+                 return torch.device("mps")
+             else:
+                 return torch.device("cpu")
+         return torch.device(self.config.device)
+
+     def _setup_amp(self):
+         """Setup automatic mixed precision."""
+         try:
+             from .optimizations import MixedPrecisionTrainer
+
+             if self.mixed_precision == "bf16":
+                 dtype = torch.bfloat16
+             elif self.mixed_precision == "fp16":
+                 dtype = torch.float16
+             else:
+                 dtype = torch.float32
+
+             self.amp_trainer = MixedPrecisionTrainer(
+                 enabled=(self.mixed_precision != "fp32" and self.device.type == "cuda"),
+                 dtype=dtype
+             )
+         except ImportError:
+             # Without the optimizations module, fall back to full precision
+             self.amp_trainer = None
+             if self.mixed_precision != "fp32":
+                 logger.warning("optimizations module unavailable; training in fp32")
+
+     def _setup_optimizer(self):
+         """Setup the optimizer over trainable parameters only."""
+         trainable_params = [p for p in self.model.parameters() if p.requires_grad]
+
+         if len(trainable_params) == 0:
+             logger.warning("No trainable parameters found! Check model configuration.")
+             trainable_params = list(self.model.parameters())
+
+         logger.info(f"Optimizing {len(trainable_params)} parameter tensors")
+
+         return AdamW(
+             trainable_params,
+             lr=self.config.training.learning_rate,
+             weight_decay=self.config.training.weight_decay
+         )
+
+     def _setup_scheduler(self):
+         # One optimizer step per `gradient_accumulation_steps` batches
+         steps_per_epoch = max(len(self.train_dataloader) // self.gradient_accumulation_steps, 1)
+         total_steps = steps_per_epoch * self.config.training.num_epochs
+
+         if self.config.training.warmup_steps > 0:
+             return OneCycleLR(
+                 self.optimizer,
+                 max_lr=self.config.training.learning_rate,
+                 total_steps=total_steps,
+                 pct_start=min(self.config.training.warmup_steps / max(total_steps, 1), 1.0)
+             )
+         return CosineAnnealingLR(self.optimizer, T_max=total_steps)
+
+     def _setup_logging(self):
+         try:
+             wandb.init(
+                 project="langtune-fast",
+                 config={
+                     "gradient_accumulation": self.gradient_accumulation_steps,
+                     "mixed_precision": self.mixed_precision,
+                     **{k: v for k, v in self.config.__dict__.items() if not k.startswith('_')}
+                 },
+                 name=f"fast_run_{int(time.time())}"
+             )
+             self.use_wandb = True
+         except Exception:
+             self.use_wandb = False
+
+     def _log_training_info(self):
+         total_params = sum(p.numel() for p in self.model.parameters())
+         trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+
+         logger.info(f"FastTrainer initialized on {self.device}")
+         logger.info(f"Total parameters: {total_params:,}")
+         logger.info(f"Trainable parameters: {trainable_params:,} ({100 * trainable_params / max(total_params, 1):.2f}%)")
+         logger.info(f"Gradient accumulation: {self.gradient_accumulation_steps}")
+         logger.info(f"Mixed precision: {self.mixed_precision}")
+         logger.info(f"Effective batch size: {self.config.training.batch_size * self.gradient_accumulation_steps}")
+
+     def _log_memory(self, prefix: str = ""):
+         if self.device.type == "cuda":
+             try:
+                 from .optimizations import log_memory_usage
+                 log_memory_usage(prefix)
+             except ImportError:
+                 allocated = torch.cuda.memory_allocated() / 1e9
+                 logger.info(f"{prefix}GPU Memory: {allocated:.2f} GB")
+
+     def train_epoch(self, epoch: int) -> Dict[str, float]:
+         self.model.train()
+         total_loss = 0.0
+         num_steps = 0
+         accumulated_loss = 0.0
+         log_every = max(self.config.training.logging_steps // self.gradient_accumulation_steps, 1)
+
+         progress_bar = tqdm(
+             self.train_dataloader,
+             desc=f"Epoch {epoch+1}/{self.config.training.num_epochs}",
+             leave=False
+         )
+
+         self.optimizer.zero_grad()
+
+         for batch_idx, batch in enumerate(progress_bar):
+             batch = {k: v.to(self.device) for k, v in batch.items()}
+
+             # Forward with AMP; divide the loss so gradients average over
+             # the accumulation window
+             if self.amp_trainer:
+                 with self.amp_trainer.autocast_context:
+                     outputs = self.model(**batch)
+                     loss = outputs["loss"] / self.gradient_accumulation_steps
+
+                 # Scale and backward
+                 scaled_loss = self.amp_trainer.scale_loss(loss)
+                 scaled_loss.backward()
+             else:
+                 outputs = self.model(**batch)
+                 loss = outputs["loss"] / self.gradient_accumulation_steps
+                 loss.backward()
+
+             accumulated_loss += loss.item()
+
+             # Optimizer step after a full accumulation window
+             if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
+                 # Gradient clipping
+                 if self.config.training.max_grad_norm > 0:
+                     if self.amp_trainer:
+                         self.amp_trainer.unscale_gradients(self.optimizer)
+                     torch.nn.utils.clip_grad_norm_(
+                         self.model.parameters(),
+                         self.config.training.max_grad_norm
+                     )
+
+                 # Step
+                 if self.amp_trainer:
+                     self.amp_trainer.step(self.optimizer)
+                 else:
+                     self.optimizer.step()
+
+                 if self.scheduler:
+                     self.scheduler.step()
+
+                 self.optimizer.zero_grad()
+
+                 # Track: accumulated_loss already holds the mean per-batch
+                 # loss over the window (each term was divided by the window
+                 # size), so it is reported as-is
+                 step_loss = accumulated_loss
+                 total_loss += step_loss
+                 num_steps += 1
+                 accumulated_loss = 0.0
+
+                 progress_bar.set_postfix({"loss": f"{step_loss:.4f}"})
+
+                 # Log periodically
+                 if num_steps % log_every == 0:
+                     lr = self.optimizer.param_groups[0]['lr']
+                     metrics = {"train_loss": step_loss, "learning_rate": lr, "epoch": epoch}
+                     self.metrics_tracker.update(metrics)
+                     if self.use_wandb:
+                         wandb.log(metrics)
+
+         # Flush gradients from any leftover batches that did not fill a window
+         if accumulated_loss > 0:
+             if self.amp_trainer:
+                 self.amp_trainer.step(self.optimizer)
+             else:
+                 self.optimizer.step()
+             self.optimizer.zero_grad()
+
+         self._log_memory(f"Epoch {epoch+1} end - ")
+
+         return {"train_loss": total_loss / max(num_steps, 1)}
+
+     def validate(self, epoch: int) -> Dict[str, float]:
+         if self.val_dataloader is None:
+             return {}
+
+         self.model.eval()
+         total_loss = 0.0
+         num_batches = 0
+
+         with torch.no_grad():
+             for batch in tqdm(self.val_dataloader, desc="Validation", leave=False):
+                 batch = {k: v.to(self.device) for k, v in batch.items()}
+
+                 if self.amp_trainer:
+                     with self.amp_trainer.autocast_context:
+                         outputs = self.model(**batch)
+                 else:
+                     outputs = self.model(**batch)
+
+                 total_loss += outputs["loss"].item()
+                 num_batches += 1
+
+         return {"val_loss": total_loss / max(num_batches, 1)}
+
+     def train(self, resume_from_checkpoint: Optional[str] = None):
+         start_epoch = 0
+
+         if resume_from_checkpoint:
+             start_epoch, _ = self.checkpointer.load(
+                 self.model, self.optimizer, self.scheduler, resume_from_checkpoint
+             )
+             logger.info(f"Resumed from epoch {start_epoch}")
+
+         logger.info("Starting optimized training...")
+         self._log_memory("Training start - ")
+
+         for epoch in range(start_epoch, self.config.training.num_epochs):
+             train_metrics = self.train_epoch(epoch)
+             val_metrics = self.validate(epoch)
+
+             all_metrics = {**train_metrics, **val_metrics}
+             self.metrics_tracker.update(all_metrics)
+             epoch_metrics = self.metrics_tracker.log_epoch()
+
+             if self.use_wandb:
+                 wandb.log(epoch_metrics)
+
+             self.checkpointer.save(self.model, self.optimizer, self.scheduler, epoch, all_metrics)
+
+             if "val_loss" in val_metrics:
+                 if self.early_stopping(val_metrics["val_loss"]):
+                     logger.info(f"Early stopping at epoch {epoch+1}")
+                     break
+
+             logger.info(
+                 f"Epoch {epoch+1} - Train: {train_metrics['train_loss']:.4f}, "
+                 f"Val: {val_metrics.get('val_loss', 'N/A')}"
+             )
+
+         logger.info("Training completed!")
+         self._log_memory("Training end - ")
+
+         # Cleanup
+         try:
+             from .optimizations import cleanup_memory
+             cleanup_memory()
+         except ImportError:
+             if self.device.type == "cuda":
+                 torch.cuda.empty_cache()
+
+
+ def create_trainer(
+     config: Config,
+     train_dataloader: DataLoader,
+     val_dataloader: Optional[DataLoader] = None,
+     test_dataloader: Optional[DataLoader] = None
+ ) -> Trainer:
+     """
+     Create a trainer instance.
+
+     Args:
+         config: Training configuration
+         train_dataloader: Training data loader
+         val_dataloader: Validation data loader (optional)
+         test_dataloader: Test data loader (optional)
+
+     Returns:
+         Trainer instance
+     """
+     # Create the model
+     model = LoRALanguageModel(
+         vocab_size=config.model.vocab_size,
+         embed_dim=config.model.embed_dim,
+         num_layers=config.model.num_layers,
+         num_heads=config.model.num_heads,
+         max_seq_len=config.model.max_seq_len,
+         mlp_ratio=config.model.mlp_ratio,
+         dropout=config.model.dropout,
+         lora_config=config.model.lora.__dict__ if config.model.lora else None
+     )
+
+     # Create the trainer
+     return Trainer(
+         model=model,
+         config=config,
+         train_dataloader=train_dataloader,
+         val_dataloader=val_dataloader,
+         test_dataloader=test_dataloader
+     )
+
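+ # A minimal usage sketch (the config loader below is hypothetical; substitute
+ # whatever this package actually exposes for building a Config and loaders):
+ #
+ #     config = Config.from_yaml("config.yaml")   # hypothetical loader
+ #     trainer = create_trainer(config, train_dataloader, val_dataloader)
+ #     trainer.train()
+ #     trainer.evaluate()
+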
+
+ def create_fast_trainer(
+     config: Config,
+     train_dataloader: DataLoader,
+     val_dataloader: Optional[DataLoader] = None,
+     test_dataloader: Optional[DataLoader] = None,
+     gradient_accumulation_steps: int = 4,
+     mixed_precision: str = "fp16"
+ ) -> FastTrainer:
+     """
+     Create an optimized FastTrainer instance with a FastLoRALanguageModel.
+
+     Args:
+         config: Training configuration
+         train_dataloader: Training data loader
+         val_dataloader: Validation data loader (optional)
+         test_dataloader: Test data loader (optional)
+         gradient_accumulation_steps: Number of batches to accumulate gradients over
+         mixed_precision: "fp16", "bf16", or "fp32"
+
+     Returns:
+         FastTrainer instance
+     """
+     from .models import FastLoRALanguageModel
+
+     # Create the optimized model
+     model = FastLoRALanguageModel(
+         vocab_size=config.model.vocab_size,
+         embed_dim=config.model.embed_dim,
+         num_layers=config.model.num_layers,
+         num_heads=config.model.num_heads,
+         max_seq_len=config.model.max_seq_len,
+         mlp_ratio=config.model.mlp_ratio,
+         dropout=config.model.dropout,
+         lora_config=config.model.lora.__dict__ if config.model.lora else None,
+         use_rope=True,
+         use_flash_attention=True,
+         use_gradient_checkpointing=True
+     )
+
+     # Create the fast trainer
+     return FastTrainer(
+         model=model,
+         config=config,
+         train_dataloader=train_dataloader,
+         val_dataloader=val_dataloader,
+         test_dataloader=test_dataloader,
+         gradient_accumulation_steps=gradient_accumulation_steps,
+         mixed_precision=mixed_precision
+     )
+
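+ # A minimal usage sketch. With batch_size=8 in the config and
+ # gradient_accumulation_steps=4, the optimizer steps once per 4 batches,
+ # for an effective batch size of 8 * 4 = 32 (values illustrative):
+ #
+ #     trainer = create_fast_trainer(
+ #         config, train_dataloader, val_dataloader,
+ #         gradient_accumulation_steps=4, mixed_precision="bf16"
+ #     )
+ #     trainer.train()
+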