gptmed 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -0
- gptmed/configs/__init__.py +1 -0
- gptmed/configs/train_config.py +154 -0
- gptmed/data/__init__.py +5 -0
- gptmed/data/parsers/__init__.py +10 -0
- gptmed/data/parsers/medquad_parser.py +257 -0
- gptmed/data/parsers/text_formatter.py +148 -0
- gptmed/inference/__init__.py +1 -0
- gptmed/inference/decoding_utils.py +190 -0
- gptmed/inference/generation_config.py +83 -0
- gptmed/inference/generator.py +253 -0
- gptmed/inference/sampling.py +261 -0
- gptmed/model/__init__.py +9 -0
- gptmed/model/architecture/__init__.py +35 -0
- gptmed/model/architecture/attention.py +188 -0
- gptmed/model/architecture/decoder_block.py +130 -0
- gptmed/model/architecture/embeddings.py +146 -0
- gptmed/model/architecture/feedforward.py +109 -0
- gptmed/model/architecture/transformer.py +204 -0
- gptmed/model/configs/__init__.py +17 -0
- gptmed/model/configs/model_config.py +155 -0
- gptmed/tokenizer/__init__.py +7 -0
- gptmed/tokenizer/tokenize_data.py +286 -0
- gptmed/tokenizer/train_tokenizer.py +218 -0
- gptmed/training/__init__.py +1 -0
- gptmed/training/dataset.py +183 -0
- gptmed/training/train.py +272 -0
- gptmed/training/trainer.py +331 -0
- gptmed/training/utils.py +212 -0
- gptmed/utils/__init__.py +1 -0
- gptmed/utils/checkpoints.py +224 -0
- gptmed/utils/logging.py +189 -0
- gptmed-0.0.1.dist-info/METADATA +325 -0
- gptmed-0.0.1.dist-info/RECORD +38 -0
- gptmed-0.0.1.dist-info/WHEEL +5 -0
- gptmed-0.0.1.dist-info/entry_points.txt +3 -0
- gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
- gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/training/utils.py
ADDED
@@ -0,0 +1,212 @@
"""
Training Utilities

PURPOSE:
Helper functions for training loop:
- Gradient clipping
- Learning rate scheduling
- Training state management

WHAT THIS FILE DOES:
1. Clip gradients to prevent explosion
2. Calculate learning rate with warmup + decay
3. Compute gradient norms for monitoring

WHY THESE ARE CRITICAL:
- Gradient clipping: Prevents training collapse from exploding gradients
- LR warmup: Stabilizes early training (large steps can destabilize)
- LR decay: Helps model converge to better minima

PACKAGES USED:
- torch: Gradient operations
- math: Cosine calculations

FILES FROM THIS PROJECT:
- None (utility functions)

COMMON FAILURE MODES:
- No gradient clipping → NaN loss
- No warmup → unstable first epochs
- Constant LR → suboptimal convergence
"""

import torch
import torch.nn as nn
import math
from typing import Optional


def clip_grad_norm(model: nn.Module, max_norm: float) -> float:
    """
    Clip gradient norms to prevent explosion.

    Args:
        model: Model with gradients
        max_norm: Maximum gradient norm

    Returns:
        Total gradient norm (before clipping)

    How it works:
    - Compute total gradient norm across all parameters
    - If norm > max_norm, scale all gradients down
    - This prevents single large gradients from destroying training

    Typical values:
    - max_norm=1.0: Standard for transformers
    - max_norm=5.0: More lenient
    - max_norm=0.5: Very conservative
    """
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    return total_norm.item()


def get_lr_with_warmup(
    step: int,
    warmup_steps: int,
    max_lr: float,
    min_lr: float,
    max_steps: int,
    decay_type: str = "cosine",
) -> float:
    """
    Calculate learning rate with warmup and decay.

    Args:
        step: Current training step
        warmup_steps: Number of warmup steps
        max_lr: Peak learning rate
        min_lr: Minimum learning rate
        max_steps: Total training steps
        decay_type: 'cosine', 'linear', or 'constant'

    Returns:
        Learning rate for this step

    Schedule:
    1. Warmup (0 to warmup_steps): Linear increase from 0 to max_lr
    2. Decay (warmup_steps to max_steps): Cosine/linear decay to min_lr

    Why warmup?
    - Random initialization → large gradients early on
    - Large LR + large gradients = explosion
    - Warmup gives model time to stabilize
    """
    # Warmup phase
    if step < warmup_steps:
        # Linear warmup from 0 to max_lr
        return max_lr * (step / warmup_steps)

    # After warmup
    if decay_type == "constant":
        return max_lr

    # Progress through decay phase
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    progress = min(progress, 1.0)  # Cap at 1.0

    if decay_type == "cosine":
        # Cosine decay: smooth curve to min_lr
        lr = min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
    elif decay_type == "linear":
        # Linear decay: straight line to min_lr
        lr = max_lr - (max_lr - min_lr) * progress
    else:
        raise ValueError(f"Unknown decay_type: {decay_type}")

    return lr


def set_learning_rate(optimizer: torch.optim.Optimizer, lr: float):
    """
    Set learning rate for optimizer.

    Args:
        optimizer: PyTorch optimizer
        lr: New learning rate
    """
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def count_parameters(model: nn.Module) -> int:
    """Count trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def estimate_loss_single_batch(model: nn.Module, batch: tuple, device: str) -> float:
    """
    Compute loss on a single batch (for evaluation).

    Args:
        model: Model to evaluate
        batch: (input_ids, target_ids) batch
        device: Device to run on

    Returns:
        Loss value
    """
    input_ids, target_ids = batch
    input_ids = input_ids.to(device)
    target_ids = target_ids.to(device)

    # Forward pass
    logits = model(input_ids)

    # Compute loss
    # Reshape for CrossEntropyLoss:
    #   logits: [batch_size * seq_len, vocab_size]
    #   targets: [batch_size * seq_len]
    batch_size, seq_len, vocab_size = logits.shape
    logits = logits.view(batch_size * seq_len, vocab_size)
    targets = target_ids.view(batch_size * seq_len)

    loss = nn.functional.cross_entropy(logits, targets)

    return loss.item()


def estimate_loss_dataloader(
    model: nn.Module, dataloader, device: str, max_batches: Optional[int] = None
) -> float:
    """
    Estimate average loss over a dataloader.

    Args:
        model: Model to evaluate
        dataloader: DataLoader to evaluate on
        device: Device to run on
        max_batches: Maximum number of batches to evaluate (None = all)

    Returns:
        Average loss
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if max_batches is not None and i >= max_batches:
                break

            loss = estimate_loss_single_batch(model, batch, device)
            total_loss += loss
            num_batches += 1

    model.train()

    return total_loss / num_batches if num_batches > 0 else 0.0


def compute_perplexity(loss: float) -> float:
    """
    Compute perplexity from loss.

    Perplexity = exp(loss)

    Lower is better. Perplexity measures how "surprised" the model is.
    - Perplexity of 10 = model is choosing between ~10 likely tokens
    - Perplexity of 100 = model is very uncertain
    """
    return math.exp(loss)
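Taken together, these helpers cover one optimizer step: compute the scheduled LR, apply it, backprop, clip, then step. The following is a minimal usage sketch rather than code from the package; `model`, `optimizer`, and `train_loader` are assumed to already exist, and the hyperparameter values are purely illustrative.

# Hypothetical wiring of the utilities above into a training step.
# Assumes `model`, `optimizer`, and `train_loader` are already constructed.
import torch
from gptmed.training.utils import (
    clip_grad_norm,
    get_lr_with_warmup,
    set_learning_rate,
    compute_perplexity,
)

max_steps, warmup_steps = 10_000, 500   # illustrative values
max_lr, min_lr = 3e-4, 3e-5

for step, (input_ids, target_ids) in enumerate(train_loader):
    # Schedule the LR for this step, then apply it to the optimizer
    lr = get_lr_with_warmup(step, warmup_steps, max_lr, min_lr, max_steps, "cosine")
    set_learning_rate(optimizer, lr)

    # Forward/backward (same reshape pattern as estimate_loss_single_batch)
    logits = model(input_ids)
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), target_ids.view(-1)
    )
    optimizer.zero_grad()
    loss.backward()

    # Clip gradients before stepping; returns the pre-clip norm for monitoring
    grad_norm = clip_grad_norm(model, max_norm=1.0)
    optimizer.step()

    if step % 100 == 0:
        print(f"step={step} lr={lr:.2e} loss={loss.item():.4f} "
              f"ppl={compute_perplexity(loss.item()):.1f} grad_norm={grad_norm:.2f}")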
gptmed/utils/__init__.py
ADDED
@@ -0,0 +1 @@
"""Utils package for training utilities."""
gptmed/utils/checkpoints.py
ADDED
@@ -0,0 +1,224 @@
"""
Model Checkpointing Utilities

PURPOSE:
Save and load model checkpoints during training. This is essential for:
- Resuming interrupted training
- Saving best model based on validation loss
- Preventing loss of work from crashes

WHAT THIS FILE DOES:
1. Save model state_dict + optimizer + training state
2. Load checkpoints to resume training
3. Manage checkpoint files (keep only best/recent)
4. Save configuration alongside weights

PACKAGES USED:
- torch: Save/load model state
- pathlib: File management
- json: Save metadata

FILES FROM THIS PROJECT:
- model/checkpoints/ (checkpoint directory)

CHECKPOINT CONTENTS:
- model_state_dict: Model weights
- optimizer_state_dict: Optimizer state (for resuming)
- step: Current training step
- epoch: Current epoch
- best_val_loss: Best validation loss so far
- config: Model and training config

COMMON ISSUES:
- Not saving optimizer → can't resume training properly
- Not saving RNG state → non-reproducible results
- Checkpoints too large → disk space issues
- Not testing loading → corrupt checkpoints discovered too late
"""

import torch
from pathlib import Path
import json
from typing import Dict, Optional
import shutil


class CheckpointManager:
    """
    Manages model checkpoints during training.

    Handles saving, loading, and cleanup of checkpoint files.
    """

    def __init__(self, checkpoint_dir: Path, keep_last_n: int = 3):
        """
        Args:
            checkpoint_dir: Directory to save checkpoints
            keep_last_n: Number of recent checkpoints to keep
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.keep_last_n = keep_last_n

        # Track best validation loss
        self.best_val_loss = float("inf")

        print(f"Checkpoint directory: {self.checkpoint_dir}")

    def save_checkpoint(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        step: int,
        epoch: int,
        val_loss: float,
        model_config: dict,
        train_config: dict,
        is_best: bool = False,
    ):
        """
        Save a checkpoint.

        Args:
            model: Model to save
            optimizer: Optimizer to save
            step: Current training step
            epoch: Current epoch
            val_loss: Validation loss
            model_config: Model configuration
            train_config: Training configuration
            is_best: Whether this is the best model so far
        """
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "step": step,
            "epoch": epoch,
            "val_loss": val_loss,
            "best_val_loss": self.best_val_loss,
            "model_config": model_config,
            "train_config": train_config,
        }

        # Save regular checkpoint
        checkpoint_path = self.checkpoint_dir / f"checkpoint_step_{step}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f"Checkpoint saved: {checkpoint_path}")

        # Save as best if applicable
        if is_best or val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            print(f"Best model saved: {best_path} (val_loss: {val_loss:.4f})")

        # Save as latest (for easy resuming)
        latest_path = self.checkpoint_dir / "latest_checkpoint.pt"
        torch.save(checkpoint, latest_path)

        # Cleanup old checkpoints
        self._cleanup_old_checkpoints()

    def _cleanup_old_checkpoints(self):
        """Remove old checkpoints, keeping only last N."""
        # Get all checkpoint files (except best and latest)
        checkpoints = sorted(
            self.checkpoint_dir.glob("checkpoint_step_*.pt"),
            key=lambda p: int(p.stem.split("_")[-1]),
        )

        # Remove old ones
        if len(checkpoints) > self.keep_last_n:
            for ckpt in checkpoints[: -self.keep_last_n]:
                ckpt.unlink()
                print(f"Removed old checkpoint: {ckpt.name}")

    def load_checkpoint(
        self,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        checkpoint_path: Optional[Path] = None,
        device: str = "cuda",
    ) -> Dict:
        """
        Load a checkpoint.

        Args:
            model: Model to load weights into
            optimizer: Optimizer to load state into (optional)
            checkpoint_path: Specific checkpoint to load (or None for latest)
            device: Device to load to

        Returns:
            Checkpoint dictionary with metadata
        """
        # Use latest if no path specified
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_dir / "latest_checkpoint.pt"
        else:
            checkpoint_path = Path(checkpoint_path)

        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

        print(f"Loading checkpoint: {checkpoint_path}")

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Load model weights
        model.load_state_dict(checkpoint["model_state_dict"])

        # Load optimizer state if provided
        if optimizer is not None and "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        print(f"Loaded checkpoint from step {checkpoint['step']}, epoch {checkpoint['epoch']}")
        print(f"Validation loss: {checkpoint['val_loss']:.4f}")

        return checkpoint

    def has_checkpoint(self) -> bool:
        """Check if a checkpoint exists."""
        latest_path = self.checkpoint_dir / "latest_checkpoint.pt"
        return latest_path.exists()


def save_model_for_inference(
    model: torch.nn.Module, tokenizer_path: Path, save_path: Path, model_config: dict
):
    """
    Save model for inference (weights only, no optimizer).

    Args:
        model: Trained model
        tokenizer_path: Path to tokenizer model
        save_path: Where to save
        model_config: Model configuration
    """
    save_dict = {
        "model_state_dict": model.state_dict(),
        "model_config": model_config,
        "tokenizer_path": str(tokenizer_path),
    }

    torch.save(save_dict, save_path)
    print(f"Model saved for inference: {save_path}")


def load_model_for_inference(model: torch.nn.Module, checkpoint_path: Path, device: str = "cuda"):
    """
    Load model for inference.

    Args:
        model: Model architecture (will be populated with weights)
        checkpoint_path: Path to saved model
        device: Device to load to
    """
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    print(f"Model loaded for inference from: {checkpoint_path}")

    return checkpoint.get("model_config")
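A rough sketch of how CheckpointManager might be driven from a training loop follows; it is not taken from the package, and `model`, `optimizer`, the config dictionaries, and the loss value are placeholders.

# Hypothetical CheckpointManager usage; `model` and `optimizer` are assumed
# to exist, and the config dictionaries below are illustrative placeholders.
from pathlib import Path
from gptmed.utils.checkpoints import CheckpointManager

ckpt_mgr = CheckpointManager(Path("model/checkpoints"), keep_last_n=3)

# Resume if a previous run left a checkpoint behind
start_step = 0
if ckpt_mgr.has_checkpoint():
    state = ckpt_mgr.load_checkpoint(model, optimizer, device="cuda")
    start_step = state["step"] + 1

# ... training happens ...

# Periodically save; val_loss would come from estimate_loss_dataloader()
val_loss = 2.31
ckpt_mgr.save_checkpoint(
    model=model,
    optimizer=optimizer,
    step=start_step,
    epoch=0,
    val_loss=val_loss,
    model_config={"vocab_size": 8192},   # placeholder config
    train_config={"max_lr": 3e-4},       # placeholder config
)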
gptmed/utils/logging.py
ADDED
@@ -0,0 +1,189 @@
"""
Training Logging Utilities

PURPOSE:
Track and log training metrics (loss, learning rate, gradient norms).
This is CRITICAL for debugging training issues.

WHAT THIS FILE DOES:
1. Log training metrics to console
2. Track loss history
3. Compute moving averages
4. Save metrics to file

WHY LOGGING IS CRITICAL:
- Detect gradient explosions (loss → NaN)
- Spot overfitting (train loss ↓, val loss ↑)
- Monitor learning rate schedule
- Debug slow convergence

PACKAGES USED:
- json: Save metrics
- time: Track training speed
- statistics: Compute averages

FILES FROM THIS PROJECT:
- None (utility module)

COMMON ISSUES TO DETECT:
- Loss = NaN → exploding gradients, LR too high
- Loss stuck → LR too low, bad initialization
- Loss oscillating → LR too high, reduce it
- Val loss increasing → overfitting, add regularization
"""

import json
import time
from pathlib import Path
from typing import Dict, List, Optional
from collections import defaultdict


class MetricsLogger:
    """
    Simple metrics logger for training.

    Tracks loss, learning rate, and other metrics over time.
    """

    def __init__(self, log_dir: Path, experiment_name: str = "training"):
        """
        Args:
            log_dir: Directory to save logs
            experiment_name: Name for this training run
        """
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.experiment_name = experiment_name
        self.log_file = self.log_dir / f"{experiment_name}_metrics.jsonl"

        # Metrics storage
        self.metrics = defaultdict(list)

        # Timing
        self.start_time = time.time()
        self.step_times = []

        print(f"Logging to: {self.log_file}")

    def log(self, step: int, metrics: Dict[str, float]):
        """
        Log metrics for a training step.

        Args:
            step: Training step number
            metrics: Dictionary of metric_name -> value
        """
        # Add timestamp and step
        log_entry = {"step": step, "timestamp": time.time() - self.start_time, **metrics}

        # Store in memory
        for key, value in metrics.items():
            self.metrics[key].append(value)

        # Append to file (JSONL format)
        with open(self.log_file, "a") as f:
            f.write(json.dumps(log_entry) + "\n")

    def log_epoch(self, epoch: int, train_loss: float, val_loss: float):
        """Log epoch-level metrics."""
        print(f"\nEpoch {epoch}:")
        print(f" Train loss: {train_loss:.4f}")
        print(f" Val loss: {val_loss:.4f}")

        self.log(epoch, {"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})

    def get_last(self, metric_name: str, n: int = 1) -> Optional[float]:
        """Get the average of the last n values of a metric."""
        if metric_name not in self.metrics:
            return None
        values = self.metrics[metric_name]
        if len(values) < n:
            return None
        return sum(values[-n:]) / n

    def get_average(self, metric_name: str, window: int = 100) -> Optional[float]:
        """Get moving average of a metric."""
        if metric_name not in self.metrics:
            return None
        values = self.metrics[metric_name]
        if len(values) == 0:
            return None
        window = min(window, len(values))
        return sum(values[-window:]) / window


def log_training_step(
    step: int,
    loss: float,
    lr: float,
    grad_norm: float,
    tokens_per_sec: float,
    print_output: bool = True,
):
    """
    Log a single training step to console.

    Args:
        step: Training step
        loss: Current loss
        lr: Current learning rate
        grad_norm: Gradient norm (before clipping)
        tokens_per_sec: Processing speed
        print_output: Whether to print to console
    """
    if not print_output:
        return

    # Format output
    msg = f"Step {step:5d} | "
    msg += f"Loss: {loss:.4f} | "
    msg += f"LR: {lr:.2e} | "
    msg += f"Grad: {grad_norm:.3f} | "
    msg += f"Tok/s: {tokens_per_sec:.0f}"

    # Check for issues
    if loss != loss:  # NaN check
        msg += " [WARNING: NaN loss!]"
    elif loss > 100:
        msg += " [WARNING: High loss!]"

    if grad_norm > 10:
        msg += " [WARNING: Large gradients!]"

    print(msg)


def log_validation(step: int, val_loss: float, val_perplexity: float):
    """
    Log validation metrics.

    Args:
        step: Training step
        val_loss: Validation loss
        val_perplexity: Validation perplexity (exp(loss))
    """
    print(f"\n{'='*60}")
    print(f"Validation at step {step}")
    print(f" Loss: {val_loss:.4f}")
    print(f" Perplexity: {val_perplexity:.2f}")
    print(f"{'='*60}\n")


def save_training_summary(log_dir: Path, config: dict, final_metrics: dict):
    """
    Save final training summary.

    Args:
        log_dir: Directory to save summary
        config: Training configuration
        final_metrics: Final metrics (best val loss, etc.)
    """
    summary = {"config": config, "final_metrics": final_metrics, "timestamp": time.time()}

    summary_file = log_dir / "training_summary.json"
    with open(summary_file, "w") as f:
        json.dump(summary, f, indent=2)

    print(f"\nTraining summary saved to: {summary_file}")