gptmed 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -0
- gptmed/configs/__init__.py +1 -0
- gptmed/configs/train_config.py +154 -0
- gptmed/data/__init__.py +5 -0
- gptmed/data/parsers/__init__.py +10 -0
- gptmed/data/parsers/medquad_parser.py +257 -0
- gptmed/data/parsers/text_formatter.py +148 -0
- gptmed/inference/__init__.py +1 -0
- gptmed/inference/decoding_utils.py +190 -0
- gptmed/inference/generation_config.py +83 -0
- gptmed/inference/generator.py +253 -0
- gptmed/inference/sampling.py +261 -0
- gptmed/model/__init__.py +9 -0
- gptmed/model/architecture/__init__.py +35 -0
- gptmed/model/architecture/attention.py +188 -0
- gptmed/model/architecture/decoder_block.py +130 -0
- gptmed/model/architecture/embeddings.py +146 -0
- gptmed/model/architecture/feedforward.py +109 -0
- gptmed/model/architecture/transformer.py +204 -0
- gptmed/model/configs/__init__.py +17 -0
- gptmed/model/configs/model_config.py +155 -0
- gptmed/tokenizer/__init__.py +7 -0
- gptmed/tokenizer/tokenize_data.py +286 -0
- gptmed/tokenizer/train_tokenizer.py +218 -0
- gptmed/training/__init__.py +1 -0
- gptmed/training/dataset.py +183 -0
- gptmed/training/train.py +272 -0
- gptmed/training/trainer.py +331 -0
- gptmed/training/utils.py +212 -0
- gptmed/utils/__init__.py +1 -0
- gptmed/utils/checkpoints.py +224 -0
- gptmed/utils/logging.py +189 -0
- gptmed-0.0.1.dist-info/METADATA +325 -0
- gptmed-0.0.1.dist-info/RECORD +38 -0
- gptmed-0.0.1.dist-info/WHEEL +5 -0
- gptmed-0.0.1.dist-info/entry_points.txt +3 -0
- gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
- gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/training/train.py
ADDED
@@ -0,0 +1,272 @@
"""
Main Training Script

PURPOSE:
Entry point for training the GPT model.
Ties everything together and starts training.

WHAT THIS FILE DOES:
1. Load configuration (model + training)
2. Create model
3. Load tokenized data
4. Create optimizer
5. Start training
6. Handle command-line arguments

USAGE:
    python training/train.py                   # Use default config
    python training/train.py --batch-size 32   # Override batch size
    python training/train.py --resume          # Resume from checkpoint

PACKAGES USED:
- torch: PyTorch
- argparse: Command-line arguments
- pathlib: Path handling

FILES FROM THIS PROJECT:
- All model, training, and utility modules

EXECUTION ORDER:
1. Parse arguments
2. Set random seeds (reproducibility)
3. Create model
4. Load data
5. Create optimizer
6. Initialize trainer
7. Start training
"""

import torch
import argparse
from pathlib import Path
import random
import numpy as np
import sys

# Add parent directory to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from llm_med.model.architecture import GPTTransformer
from llm_med.model.configs.model_config import get_small_config, get_tiny_config
from llm_med.configs.train_config import get_default_config, get_quick_test_config
from llm_med.training.dataset import create_dataloaders
from llm_med.training.trainer import Trainer


def set_seed(seed: int):
    """
    Set random seeds for reproducibility.

    Args:
        seed: Random seed

    Why this matters:
    - Makes training reproducible
    - Critical for debugging (can recreate issues)
    - Scientific experiments need reproducibility
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Make cudnn deterministic (slower but reproducible)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def count_parameters(model):
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def main():
    parser = argparse.ArgumentParser(description="Train GPT model on MedQuAD")

    # Model config
    parser.add_argument(
        "--model-size",
        type=str,
        default="small",
        choices=["tiny", "small", "medium"],
        help="Model size (tiny/small/medium)",
    )

    # Training config
    parser.add_argument(
        "--batch-size", type=int, default=None, help="Batch size (overrides config)"
    )
    parser.add_argument(
        "--learning-rate", type=float, default=None, help="Learning rate (overrides config)"
    )
    parser.add_argument(
        "--num-epochs", type=int, default=None, help="Number of epochs (overrides config)"
    )
    parser.add_argument(
        "--quick-test", action="store_true", help="Use quick test config (small batches, few steps)"
    )

    # Paths
    parser.add_argument(
        "--train-data", type=str, default="./data/tokenized/train.npy", help="Path to training data"
    )
    parser.add_argument(
        "--val-data", type=str, default="./data/tokenized/val.npy", help="Path to validation data"
    )
    parser.add_argument(
        "--checkpoint-dir", type=str, default="./model/checkpoints", help="Checkpoint directory"
    )

    # Resume training
    parser.add_argument("--resume", action="store_true", help="Resume from latest checkpoint")
    parser.add_argument(
        "--resume-from", type=str, default=None, help="Resume from specific checkpoint"
    )

    # Device
    parser.add_argument(
        "--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to train on"
    )

    # Misc
    parser.add_argument("--seed", type=int, default=42, help="Random seed")

    args = parser.parse_args()

    print("=" * 60)
    print("GPT Training - MedQuAD")
    print("=" * 60)

    # Check CUDA availability
    if args.device == "cuda" and not torch.cuda.is_available():
        print("WARNING: CUDA not available, using CPU")
        args.device = "cpu"

    # Set random seed
    print(f"\nSetting random seed: {args.seed}")
    set_seed(args.seed)

    # Load configurations
    print(f"\nLoading configurations...")

    # Model config
    if args.model_size == "tiny":
        model_config = get_tiny_config()
    elif args.model_size == "small":
        model_config = get_small_config()
    else:
        raise ValueError(f"Unknown model size: {args.model_size}")

    print(f"Model config: {args.model_size}")
    print(f" d_model: {model_config.d_model}")
    print(f" n_layers: {model_config.n_layers}")
    print(f" n_heads: {model_config.n_heads}")

    # Training config
    if args.quick_test:
        train_config = get_quick_test_config()
        print("Using quick test config (fast debugging)")
    else:
        train_config = get_default_config()

    # Override with command-line args
    if args.batch_size is not None:
        train_config.batch_size = args.batch_size
    if args.learning_rate is not None:
        train_config.learning_rate = args.learning_rate
    if args.num_epochs is not None:
        train_config.num_epochs = args.num_epochs
    if args.train_data:
        train_config.train_data_path = args.train_data
    if args.val_data:
        train_config.val_data_path = args.val_data
    if args.checkpoint_dir:
        train_config.checkpoint_dir = args.checkpoint_dir

    train_config.device = args.device
    train_config.seed = args.seed

    print(f"\nTraining config:")
    print(f" Batch size: {train_config.batch_size}")
    print(f" Learning rate: {train_config.learning_rate}")
    print(f" Num epochs: {train_config.num_epochs}")
    print(f" Device: {train_config.device}")

    # Create model
    print(f"\nCreating model...")
    model = GPTTransformer(model_config)
    total_params = count_parameters(model)
    print(f"Model created with {total_params:,} parameters")
    print(f"Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB")

    # Load data
    print(f"\nLoading data...")
    train_loader, val_loader = create_dataloaders(
        train_path=Path(train_config.train_data_path),
        val_path=Path(train_config.val_data_path),
        batch_size=train_config.batch_size,
        num_workers=0,
    )

    # Create optimizer
    print(f"\nCreating optimizer...")
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_config.learning_rate,
        betas=train_config.betas,
        eps=train_config.eps,
        weight_decay=train_config.weight_decay,
    )
    print(f"Optimizer: AdamW")
    print(f" LR: {train_config.learning_rate}")
    print(f" Weight decay: {train_config.weight_decay}")

    # Create trainer
    print(f"\nInitializing trainer...")
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        config=train_config,
        device=args.device,
    )

    # Resume if requested
    if args.resume or args.resume_from:
        print(f"\nResuming training...")
        checkpoint_path = Path(args.resume_from) if args.resume_from else None
        trainer.resume_from_checkpoint(checkpoint_path)

    # Start training
    print(f"\n{'='*60}")
    print("Ready to train!")
    print(f"{'='*60}\n")

    try:
        trainer.train()
    except KeyboardInterrupt:
        print("\n\nTraining interrupted by user")
        print("Saving checkpoint...")
        trainer.checkpoint_manager.save_checkpoint(
            model=model,
            optimizer=optimizer,
            step=trainer.global_step,
            epoch=trainer.current_epoch,
            val_loss=trainer.best_val_loss,
            model_config=model_config.to_dict(),
            train_config=train_config.to_dict(),
        )
        print("Checkpoint saved. You can resume with --resume")

    print("\n" + "=" * 60)
    print("Training finished!")
    print("=" * 60)
    print(f"\nBest model saved in: {train_config.checkpoint_dir}/best_model.pt")
    print(f"Logs saved in: {train_config.log_dir}")


if __name__ == "__main__":
    main()
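The script above is the command-line entry point; the same pipeline can be driven programmatically. The sketch below is a minimal, illustrative condensation of what main() does with default settings. It is not part of the wheel: it assumes the modules resolve under the gptmed namespace that this package actually ships (the source above imports them as llm_med, which the wheel does not contain), and it assumes tokenized train.npy/val.npy files already exist at the configured paths.

# Minimal sketch of the wiring main() performs, for notebook/programmatic use.
# ASSUMPTION: gptmed-namespace imports work and tokenized data is already present.
from pathlib import Path

import torch

from gptmed.model.architecture import GPTTransformer
from gptmed.model.configs.model_config import get_small_config
from gptmed.configs.train_config import get_default_config
from gptmed.training.dataset import create_dataloaders
from gptmed.training.trainer import Trainer

model_config = get_small_config()      # model hyperparameters (d_model, n_layers, n_heads, ...)
train_config = get_default_config()    # batch size, LR schedule, checkpoint/log directories, ...

model = GPTTransformer(model_config)

train_loader, val_loader = create_dataloaders(
    train_path=Path(train_config.train_data_path),
    val_path=Path(train_config.val_data_path),
    batch_size=train_config.batch_size,
    num_workers=0,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config.learning_rate,
    betas=train_config.betas,
    eps=train_config.eps,
    weight_decay=train_config.weight_decay,
)

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    config=train_config,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
trainer.train()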
gptmed/training/trainer.py
ADDED
@@ -0,0 +1,331 @@
"""
Trainer Class

PURPOSE:
Core training loop logic. This is the heart of Phase 3.
Handles forward pass, backward pass, optimization, and evaluation.

WHAT THIS FILE DOES:
1. Training loop: iterate over batches, compute loss, backprop
2. Evaluation: compute validation loss
3. Checkpointing: save model periodically
4. Logging: track metrics

TRAINING ALGORITHM:
For each batch:
    1. Forward pass: model(input) → logits
    2. Compute loss: CrossEntropyLoss(logits, targets)
    3. Backward pass: loss.backward()
    4. Clip gradients: prevent explosion
    5. Optimizer step: update weights
    6. Update learning rate: warmup + decay

PACKAGES USED:
- torch: PyTorch training
- time: Track speed

FILES FROM THIS PROJECT:
- model/architecture/transformer.py: Model
- training/dataset.py: DataLoader
- training/utils.py: Helper functions
- utils/logging.py: Metrics logging
- utils/checkpoints.py: Save/load

TENSOR SHAPES:
- Input: [batch_size, seq_len]
- Logits: [batch_size, seq_len, vocab_size]
- Targets: [batch_size, seq_len]
- Loss: scalar

COMMON TRAINING ISSUES:
- Loss = NaN → gradient explosion (reduce LR, check grad clipping)
- Loss stuck → LR too low, bad initialization
- Slow convergence → LR too low, increase it
- Overfitting → add dropout, weight decay
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import time
from pathlib import Path
from typing import Optional

from llm_med.model.architecture import GPTTransformer
from llm_med.training.utils import (
    clip_grad_norm,
    get_lr_with_warmup,
    set_learning_rate,
    estimate_loss_dataloader,
    compute_perplexity,
)
from llm_med.utils.logging import MetricsLogger, log_training_step, log_validation
from llm_med.utils.checkpoints import CheckpointManager


class Trainer:
    """
    Training orchestrator for GPT model.

    Handles the full training loop including:
    - Forward/backward passes
    - Optimization
    - Evaluation
    - Checkpointing
    - Logging
    """

    def __init__(
        self,
        model: GPTTransformer,
        train_loader: DataLoader,
        val_loader: DataLoader,
        optimizer: torch.optim.Optimizer,
        config,  # TrainingConfig
        device: str = "cuda",
    ):
        """
        Args:
            model: GPT model to train
            train_loader: Training data loader
            val_loader: Validation data loader
            optimizer: Optimizer (e.g., AdamW)
            config: TrainingConfig object
            device: Device to train on
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.config = config
        self.device = device

        # Initialize utilities
        self.logger = MetricsLogger(log_dir=config.log_dir, experiment_name="gpt_training")

        self.checkpoint_manager = CheckpointManager(
            checkpoint_dir=config.checkpoint_dir, keep_last_n=config.keep_last_n
        )

        # Training state
        self.global_step = 0
        self.current_epoch = 0
        self.best_val_loss = float("inf")

        # Calculate total steps
        steps_per_epoch = len(train_loader)
        if config.max_steps > 0:
            self.total_steps = config.max_steps
        else:
            self.total_steps = steps_per_epoch * config.num_epochs

        print(f"\nTrainer initialized:")
        print(f" Device: {device}")
        print(f" Total steps: {self.total_steps}")
        print(f" Steps per epoch: {steps_per_epoch}")
        print(f" Num epochs: {config.num_epochs}")

    def train_step(self, batch: tuple) -> dict:
        """
        Single training step.

        Args:
            batch: (input_ids, target_ids) tuple

        Returns:
            Dictionary with step metrics
        """
        # Move batch to device
        input_ids, target_ids = batch
        input_ids = input_ids.to(self.device)
        target_ids = target_ids.to(self.device)

        # Forward pass
        logits = self.model(input_ids)

        # Compute loss
        # CrossEntropyLoss expects:
        # - Input: [N, C] where N = batch_size * seq_len, C = vocab_size
        # - Target: [N] with class indices
        batch_size, seq_len, vocab_size = logits.shape
        logits_flat = logits.view(batch_size * seq_len, vocab_size)
        targets_flat = target_ids.view(batch_size * seq_len)

        loss = nn.functional.cross_entropy(logits_flat, targets_flat)

        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()

        # Clip gradients (CRITICAL for stability)
        grad_norm = clip_grad_norm(self.model, self.config.grad_clip)

        # Optimizer step
        self.optimizer.step()

        # Return metrics
        return {
            "loss": loss.item(),
            "grad_norm": grad_norm,
            "batch_size": batch_size,
            "seq_len": seq_len,
        }

    def evaluate(self) -> dict:
        """
        Evaluate on validation set.

        Returns:
            Dictionary with validation metrics
        """
        print("\nRunning validation...")

        val_loss = estimate_loss_dataloader(
            self.model, self.val_loader, self.device, max_batches=self.config.eval_iters
        )

        val_perplexity = compute_perplexity(val_loss)

        log_validation(self.global_step, val_loss, val_perplexity)

        return {"val_loss": val_loss, "val_perplexity": val_perplexity}

    def train(self):
        """
        Main training loop.

        This is where everything comes together.
        """
        print("\n" + "=" * 60)
        print("Starting Training")
        print("=" * 60)

        self.model.train()

        # Training loop
        for epoch in range(self.config.num_epochs):
            self.current_epoch = epoch

            print(f"\n{'='*60}")
            print(f"Epoch {epoch + 1}/{self.config.num_epochs}")
            print(f"{'='*60}")

            epoch_start_time = time.time()

            for batch_idx, batch in enumerate(self.train_loader):
                step_start_time = time.time()

                # Update learning rate (warmup + decay)
                lr = get_lr_with_warmup(
                    step=self.global_step,
                    warmup_steps=self.config.warmup_steps,
                    max_lr=self.config.learning_rate,
                    min_lr=self.config.min_lr,
                    max_steps=self.total_steps,
                    decay_type=self.config.lr_decay,
                )
                set_learning_rate(self.optimizer, lr)

                # Training step
                metrics = self.train_step(batch)

                # Calculate tokens per second
                step_time = time.time() - step_start_time
                tokens_per_sec = (metrics["batch_size"] * metrics["seq_len"]) / step_time

                # Log to console
                if self.global_step % self.config.log_interval == 0:
                    log_training_step(
                        step=self.global_step,
                        loss=metrics["loss"],
                        lr=lr,
                        grad_norm=metrics["grad_norm"],
                        tokens_per_sec=tokens_per_sec,
                    )

                # Log metrics
                self.logger.log(
                    self.global_step,
                    {
                        "train_loss": metrics["loss"],
                        "learning_rate": lr,
                        "grad_norm": metrics["grad_norm"],
                        "tokens_per_sec": tokens_per_sec,
                    },
                )

                # Evaluate
                if self.global_step % self.config.eval_interval == 0 and self.global_step > 0:
                    val_metrics = self.evaluate()

                    # Log validation metrics
                    self.logger.log(self.global_step, val_metrics)

                    # Check if best model
                    if val_metrics["val_loss"] < self.best_val_loss:
                        self.best_val_loss = val_metrics["val_loss"]
                        is_best = True
                    else:
                        is_best = False

                    # Save checkpoint
                    self.checkpoint_manager.save_checkpoint(
                        model=self.model,
                        optimizer=self.optimizer,
                        step=self.global_step,
                        epoch=epoch,
                        val_loss=val_metrics["val_loss"],
                        model_config=self.model.config.to_dict(),
                        train_config=self.config.to_dict(),
                        is_best=is_best,
                    )

                    self.model.train()  # Back to training mode

                # Save checkpoint periodically
                if self.global_step % self.config.save_interval == 0 and self.global_step > 0:
                    self.checkpoint_manager.save_checkpoint(
                        model=self.model,
                        optimizer=self.optimizer,
                        step=self.global_step,
                        epoch=epoch,
                        val_loss=self.best_val_loss,
                        model_config=self.model.config.to_dict(),
                        train_config=self.config.to_dict(),
                    )

                self.global_step += 1

                # Check if reached max steps
                if self.config.max_steps > 0 and self.global_step >= self.config.max_steps:
                    print(f"\nReached max_steps ({self.config.max_steps}). Stopping training.")
                    return

            # End of epoch
            epoch_time = time.time() - epoch_start_time
            print(f"\nEpoch {epoch + 1} completed in {epoch_time:.2f}s")

        print("\n" + "=" * 60)
        print("Training Complete!")
        print("=" * 60)
        print(f"Best validation loss: {self.best_val_loss:.4f}")

    def resume_from_checkpoint(self, checkpoint_path: Optional[Path] = None):
        """
        Resume training from a checkpoint.

        Args:
            checkpoint_path: Path to checkpoint (or None for latest)
        """
        checkpoint = self.checkpoint_manager.load_checkpoint(
            model=self.model,
            optimizer=self.optimizer,
            checkpoint_path=checkpoint_path,
            device=self.device,
        )

        self.global_step = checkpoint["step"]
        self.current_epoch = checkpoint["epoch"]
        self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))

        print(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")