gptmed-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. gptmed/__init__.py +37 -0
  2. gptmed/configs/__init__.py +1 -0
  3. gptmed/configs/train_config.py +154 -0
  4. gptmed/data/__init__.py +5 -0
  5. gptmed/data/parsers/__init__.py +10 -0
  6. gptmed/data/parsers/medquad_parser.py +257 -0
  7. gptmed/data/parsers/text_formatter.py +148 -0
  8. gptmed/inference/__init__.py +1 -0
  9. gptmed/inference/decoding_utils.py +190 -0
  10. gptmed/inference/generation_config.py +83 -0
  11. gptmed/inference/generator.py +253 -0
  12. gptmed/inference/sampling.py +261 -0
  13. gptmed/model/__init__.py +9 -0
  14. gptmed/model/architecture/__init__.py +35 -0
  15. gptmed/model/architecture/attention.py +188 -0
  16. gptmed/model/architecture/decoder_block.py +130 -0
  17. gptmed/model/architecture/embeddings.py +146 -0
  18. gptmed/model/architecture/feedforward.py +109 -0
  19. gptmed/model/architecture/transformer.py +204 -0
  20. gptmed/model/configs/__init__.py +17 -0
  21. gptmed/model/configs/model_config.py +155 -0
  22. gptmed/tokenizer/__init__.py +7 -0
  23. gptmed/tokenizer/tokenize_data.py +286 -0
  24. gptmed/tokenizer/train_tokenizer.py +218 -0
  25. gptmed/training/__init__.py +1 -0
  26. gptmed/training/dataset.py +183 -0
  27. gptmed/training/train.py +272 -0
  28. gptmed/training/trainer.py +331 -0
  29. gptmed/training/utils.py +212 -0
  30. gptmed/utils/__init__.py +1 -0
  31. gptmed/utils/checkpoints.py +224 -0
  32. gptmed/utils/logging.py +189 -0
  33. gptmed-0.0.1.dist-info/METADATA +325 -0
  34. gptmed-0.0.1.dist-info/RECORD +38 -0
  35. gptmed-0.0.1.dist-info/WHEEL +5 -0
  36. gptmed-0.0.1.dist-info/entry_points.txt +3 -0
  37. gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
  38. gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/training/train.py
@@ -0,0 +1,272 @@
+ """
+ Main Training Script
+
+ PURPOSE:
+ Entry point for training the GPT model.
+ Ties everything together and starts training.
+
+ WHAT THIS FILE DOES:
+ 1. Load configuration (model + training)
+ 2. Create model
+ 3. Load tokenized data
+ 4. Create optimizer
+ 5. Start training
+ 6. Handle command-line arguments
+
+ USAGE:
+ python training/train.py # Use default config
+ python training/train.py --batch-size 32 # Override batch size
+ python training/train.py --resume # Resume from checkpoint
+
+ PACKAGES USED:
+ - torch: PyTorch
+ - argparse: Command-line arguments
+ - pathlib: Path handling
+
+ FILES FROM THIS PROJECT:
+ - All model, training, and utility modules
+
+ EXECUTION ORDER:
+ 1. Parse arguments
+ 2. Set random seeds (reproducibility)
+ 3. Create model
+ 4. Load data
+ 5. Create optimizer
+ 6. Initialize trainer
+ 7. Start training
+ """
+
+ import torch
+ import argparse
+ from pathlib import Path
+ import random
+ import numpy as np
+ import sys
+
+ # Add parent directory to path for imports
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+ from llm_med.model.architecture import GPTTransformer
+ from llm_med.model.configs.model_config import get_small_config, get_tiny_config
+ from llm_med.configs.train_config import get_default_config, get_quick_test_config
+ from llm_med.training.dataset import create_dataloaders
+ from llm_med.training.trainer import Trainer
+
+
+ def set_seed(seed: int):
+     """
+     Set random seeds for reproducibility.
+
+     Args:
+         seed: Random seed
+
+     Why this matters:
+     - Makes training reproducible
+     - Critical for debugging (can recreate issues)
+     - Scientific experiments need reproducibility
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+
+     # Make cudnn deterministic (slower but reproducible)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+
+ def count_parameters(model):
+     """Count trainable parameters."""
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Train GPT model on MedQuAD")
+
+     # Model config
+     parser.add_argument(
+         "--model-size",
+         type=str,
+         default="small",
+         choices=["tiny", "small", "medium"],
+         help="Model size (tiny/small/medium)",
+     )
+
+     # Training config
+     parser.add_argument(
+         "--batch-size", type=int, default=None, help="Batch size (overrides config)"
+     )
+     parser.add_argument(
+         "--learning-rate", type=float, default=None, help="Learning rate (overrides config)"
+     )
+     parser.add_argument(
+         "--num-epochs", type=int, default=None, help="Number of epochs (overrides config)"
+     )
+     parser.add_argument(
+         "--quick-test", action="store_true", help="Use quick test config (small batches, few steps)"
+     )
+
+     # Paths
+     parser.add_argument(
+         "--train-data", type=str, default="./data/tokenized/train.npy", help="Path to training data"
+     )
+     parser.add_argument(
+         "--val-data", type=str, default="./data/tokenized/val.npy", help="Path to validation data"
+     )
+     parser.add_argument(
+         "--checkpoint-dir", type=str, default="./model/checkpoints", help="Checkpoint directory"
+     )
+
+     # Resume training
+     parser.add_argument("--resume", action="store_true", help="Resume from latest checkpoint")
+     parser.add_argument(
+         "--resume-from", type=str, default=None, help="Resume from specific checkpoint"
+     )
+
+     # Device
+     parser.add_argument(
+         "--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device to train on"
+     )
+
+     # Misc
+     parser.add_argument("--seed", type=int, default=42, help="Random seed")
+
+     args = parser.parse_args()
+
+     print("=" * 60)
+     print("GPT Training - MedQuAD")
+     print("=" * 60)
+
+     # Check CUDA availability
+     if args.device == "cuda" and not torch.cuda.is_available():
+         print("WARNING: CUDA not available, using CPU")
+         args.device = "cpu"
+
+     # Set random seed
+     print(f"\nSetting random seed: {args.seed}")
+     set_seed(args.seed)
+
+     # Load configurations
+     print(f"\nLoading configurations...")
+
+     # Model config
+     if args.model_size == "tiny":
+         model_config = get_tiny_config()
+     elif args.model_size == "small":
+         model_config = get_small_config()
+     else:
+         raise ValueError(f"Unknown model size: {args.model_size}")
+
+     print(f"Model config: {args.model_size}")
+     print(f" d_model: {model_config.d_model}")
+     print(f" n_layers: {model_config.n_layers}")
+     print(f" n_heads: {model_config.n_heads}")
+
+     # Training config
+     if args.quick_test:
+         train_config = get_quick_test_config()
+         print("Using quick test config (fast debugging)")
+     else:
+         train_config = get_default_config()
+
+     # Override with command-line args
+     if args.batch_size is not None:
+         train_config.batch_size = args.batch_size
+     if args.learning_rate is not None:
+         train_config.learning_rate = args.learning_rate
+     if args.num_epochs is not None:
+         train_config.num_epochs = args.num_epochs
+     if args.train_data:
+         train_config.train_data_path = args.train_data
+     if args.val_data:
+         train_config.val_data_path = args.val_data
+     if args.checkpoint_dir:
+         train_config.checkpoint_dir = args.checkpoint_dir
+
+     train_config.device = args.device
+     train_config.seed = args.seed
+
+     print(f"\nTraining config:")
+     print(f" Batch size: {train_config.batch_size}")
+     print(f" Learning rate: {train_config.learning_rate}")
+     print(f" Num epochs: {train_config.num_epochs}")
+     print(f" Device: {train_config.device}")
+
+     # Create model
+     print(f"\nCreating model...")
+     model = GPTTransformer(model_config)
+     total_params = count_parameters(model)
+     print(f"Model created with {total_params:,} parameters")
+     print(f"Model size: ~{total_params * 4 / 1024 / 1024:.2f} MB")
+
+     # Load data
+     print(f"\nLoading data...")
+     train_loader, val_loader = create_dataloaders(
+         train_path=Path(train_config.train_data_path),
+         val_path=Path(train_config.val_data_path),
+         batch_size=train_config.batch_size,
+         num_workers=0,
+     )
+
+     # Create optimizer
+     print(f"\nCreating optimizer...")
+     optimizer = torch.optim.AdamW(
+         model.parameters(),
+         lr=train_config.learning_rate,
+         betas=train_config.betas,
+         eps=train_config.eps,
+         weight_decay=train_config.weight_decay,
+     )
+     print(f"Optimizer: AdamW")
+     print(f" LR: {train_config.learning_rate}")
+     print(f" Weight decay: {train_config.weight_decay}")
+
+     # Create trainer
+     print(f"\nInitializing trainer...")
+     trainer = Trainer(
+         model=model,
+         train_loader=train_loader,
+         val_loader=val_loader,
+         optimizer=optimizer,
+         config=train_config,
+         device=args.device,
+     )
+
+     # Resume if requested
+     if args.resume or args.resume_from:
+         print(f"\nResuming training...")
+         checkpoint_path = Path(args.resume_from) if args.resume_from else None
+         trainer.resume_from_checkpoint(checkpoint_path)
+
+     # Start training
+     print(f"\n{'='*60}")
+     print("Ready to train!")
+     print(f"{'='*60}\n")
+
+     try:
+         trainer.train()
+     except KeyboardInterrupt:
+         print("\n\nTraining interrupted by user")
+         print("Saving checkpoint...")
+         trainer.checkpoint_manager.save_checkpoint(
+             model=model,
+             optimizer=optimizer,
+             step=trainer.global_step,
+             epoch=trainer.current_epoch,
+             val_loss=trainer.best_val_loss,
+             model_config=model_config.to_dict(),
+             train_config=train_config.to_dict(),
+         )
+         print("Checkpoint saved. You can resume with --resume")
+
+     print("\n" + "=" * 60)
+     print("Training finished!")
+     print("=" * 60)
+     print(f"\nBest model saved in: {train_config.checkpoint_dir}/best_model.pt")
+     print(f"Logs saved in: {train_config.log_dir}")
+
+
+ if __name__ == "__main__":
+     main()
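
Editor's note: the module docstring above lists the EXECUTION ORDER that main() walks through. The minimal sketch below traces the same flow from Python rather than the command line. It reuses only names that appear in train.py itself; the import paths are copied verbatim from the script (they use the llm_med namespace even though the wheel installs a gptmed package), and the tokenized-data paths are illustrative, so treat this as an outline under those assumptions rather than a supported API of the package.

import torch
from pathlib import Path

# Imports as written in train.py above (llm_med namespace as published).
from llm_med.model.architecture import GPTTransformer
from llm_med.model.configs.model_config import get_tiny_config
from llm_med.configs.train_config import get_quick_test_config
from llm_med.training.dataset import create_dataloaders
from llm_med.training.trainer import Trainer

model_config = get_tiny_config()        # smallest model, as for --model-size tiny
train_config = get_quick_test_config()  # small batches / few steps, as for --quick-test

model = GPTTransformer(model_config)
train_loader, val_loader = create_dataloaders(
    train_path=Path("./data/tokenized/train.npy"),  # illustrative paths
    val_path=Path("./data/tokenized/val.npy"),
    batch_size=train_config.batch_size,
    num_workers=0,
)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=train_config.learning_rate,
    betas=train_config.betas,
    eps=train_config.eps,
    weight_decay=train_config.weight_decay,
)
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    config=train_config,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
trainer.train()

The CLI path in main() performs the same steps, plus argument overrides, the CUDA-availability fallback, and the KeyboardInterrupt checkpoint save.
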
gptmed/training/trainer.py
@@ -0,0 +1,331 @@
+ """
+ Trainer Class
+
+ PURPOSE:
+ Core training loop logic. This is the heart of Phase 3.
+ Handles forward pass, backward pass, optimization, and evaluation.
+
+ WHAT THIS FILE DOES:
+ 1. Training loop: iterate over batches, compute loss, backprop
+ 2. Evaluation: compute validation loss
+ 3. Checkpointing: save model periodically
+ 4. Logging: track metrics
+
+ TRAINING ALGORITHM:
+ For each batch:
+ 1. Forward pass: model(input) → logits
+ 2. Compute loss: CrossEntropyLoss(logits, targets)
+ 3. Backward pass: loss.backward()
+ 4. Clip gradients: prevent explosion
+ 5. Optimizer step: update weights
+ 6. Update learning rate: warmup + decay
+
+ PACKAGES USED:
+ - torch: PyTorch training
+ - time: Track speed
+
+ FILES FROM THIS PROJECT:
+ - model/architecture/transformer.py: Model
+ - training/dataset.py: DataLoader
+ - training/utils.py: Helper functions
+ - utils/logging.py: Metrics logging
+ - utils/checkpoints.py: Save/load
+
+ TENSOR SHAPES:
+ - Input: [batch_size, seq_len]
+ - Logits: [batch_size, seq_len, vocab_size]
+ - Targets: [batch_size, seq_len]
+ - Loss: scalar
+
+ COMMON TRAINING ISSUES:
+ - Loss = NaN → gradient explosion (reduce LR, check grad clipping)
+ - Loss stuck → LR too low, bad initialization
+ - Slow convergence → LR too low, increase it
+ - Overfitting → add dropout, weight decay
+ """
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import DataLoader
+ import time
+ from pathlib import Path
+ from typing import Optional
+
+ from llm_med.model.architecture import GPTTransformer
+ from llm_med.training.utils import (
+     clip_grad_norm,
+     get_lr_with_warmup,
+     set_learning_rate,
+     estimate_loss_dataloader,
+     compute_perplexity,
+ )
+ from llm_med.utils.logging import MetricsLogger, log_training_step, log_validation
+ from llm_med.utils.checkpoints import CheckpointManager
+
+
+ class Trainer:
+     """
+     Training orchestrator for GPT model.
+
+     Handles the full training loop including:
+     - Forward/backward passes
+     - Optimization
+     - Evaluation
+     - Checkpointing
+     - Logging
+     """
+
+     def __init__(
+         self,
+         model: GPTTransformer,
+         train_loader: DataLoader,
+         val_loader: DataLoader,
+         optimizer: torch.optim.Optimizer,
+         config, # TrainingConfig
+         device: str = "cuda",
+     ):
+         """
+         Args:
+             model: GPT model to train
+             train_loader: Training data loader
+             val_loader: Validation data loader
+             optimizer: Optimizer (e.g., AdamW)
+             config: TrainingConfig object
+             device: Device to train on
+         """
+         self.model = model.to(device)
+         self.train_loader = train_loader
+         self.val_loader = val_loader
+         self.optimizer = optimizer
+         self.config = config
+         self.device = device
+
+         # Initialize utilities
+         self.logger = MetricsLogger(log_dir=config.log_dir, experiment_name="gpt_training")
+
+         self.checkpoint_manager = CheckpointManager(
+             checkpoint_dir=config.checkpoint_dir, keep_last_n=config.keep_last_n
+         )
+
+         # Training state
+         self.global_step = 0
+         self.current_epoch = 0
+         self.best_val_loss = float("inf")
+
+         # Calculate total steps
+         steps_per_epoch = len(train_loader)
+         if config.max_steps > 0:
+             self.total_steps = config.max_steps
+         else:
+             self.total_steps = steps_per_epoch * config.num_epochs
+
+         print(f"\nTrainer initialized:")
+         print(f" Device: {device}")
+         print(f" Total steps: {self.total_steps}")
+         print(f" Steps per epoch: {steps_per_epoch}")
+         print(f" Num epochs: {config.num_epochs}")
+
+     def train_step(self, batch: tuple) -> dict:
+         """
+         Single training step.
+
+         Args:
+             batch: (input_ids, target_ids) tuple
+
+         Returns:
+             Dictionary with step metrics
+         """
+         # Move batch to device
+         input_ids, target_ids = batch
+         input_ids = input_ids.to(self.device)
+         target_ids = target_ids.to(self.device)
+
+         # Forward pass
+         logits = self.model(input_ids)
+
+         # Compute loss
+         # CrossEntropyLoss expects:
+         # - Input: [N, C] where N = batch_size * seq_len, C = vocab_size
+         # - Target: [N] with class indices
+         batch_size, seq_len, vocab_size = logits.shape
+         logits_flat = logits.view(batch_size * seq_len, vocab_size)
+         targets_flat = target_ids.view(batch_size * seq_len)
+
+         loss = nn.functional.cross_entropy(logits_flat, targets_flat)
+
+         # Backward pass
+         self.optimizer.zero_grad()
+         loss.backward()
+
+         # Clip gradients (CRITICAL for stability)
+         grad_norm = clip_grad_norm(self.model, self.config.grad_clip)
+
+         # Optimizer step
+         self.optimizer.step()
+
+         # Return metrics
+         return {
+             "loss": loss.item(),
+             "grad_norm": grad_norm,
+             "batch_size": batch_size,
+             "seq_len": seq_len,
+         }
+
+     def evaluate(self) -> dict:
+         """
+         Evaluate on validation set.
+
+         Returns:
+             Dictionary with validation metrics
+         """
+         print("\nRunning validation...")
+
+         val_loss = estimate_loss_dataloader(
+             self.model, self.val_loader, self.device, max_batches=self.config.eval_iters
+         )
+
+         val_perplexity = compute_perplexity(val_loss)
+
+         log_validation(self.global_step, val_loss, val_perplexity)
+
+         return {"val_loss": val_loss, "val_perplexity": val_perplexity}
+
+     def train(self):
+         """
+         Main training loop.
+
+         This is where everything comes together.
+         """
+         print("\n" + "=" * 60)
+         print("Starting Training")
+         print("=" * 60)
+
+         self.model.train()
+
+         # Training loop
+         for epoch in range(self.config.num_epochs):
+             self.current_epoch = epoch
+
+             print(f"\n{'='*60}")
+             print(f"Epoch {epoch + 1}/{self.config.num_epochs}")
+             print(f"{'='*60}")
+
+             epoch_start_time = time.time()
+
+             for batch_idx, batch in enumerate(self.train_loader):
+                 step_start_time = time.time()
+
+                 # Update learning rate (warmup + decay)
+                 lr = get_lr_with_warmup(
+                     step=self.global_step,
+                     warmup_steps=self.config.warmup_steps,
+                     max_lr=self.config.learning_rate,
+                     min_lr=self.config.min_lr,
+                     max_steps=self.total_steps,
+                     decay_type=self.config.lr_decay,
+                 )
+                 set_learning_rate(self.optimizer, lr)
+
+                 # Training step
+                 metrics = self.train_step(batch)
+
+                 # Calculate tokens per second
+                 step_time = time.time() - step_start_time
+                 tokens_per_sec = (metrics["batch_size"] * metrics["seq_len"]) / step_time
+
+                 # Log to console
+                 if self.global_step % self.config.log_interval == 0:
+                     log_training_step(
+                         step=self.global_step,
+                         loss=metrics["loss"],
+                         lr=lr,
+                         grad_norm=metrics["grad_norm"],
+                         tokens_per_sec=tokens_per_sec,
+                     )
+
+                 # Log metrics
+                 self.logger.log(
+                     self.global_step,
+                     {
+                         "train_loss": metrics["loss"],
+                         "learning_rate": lr,
+                         "grad_norm": metrics["grad_norm"],
+                         "tokens_per_sec": tokens_per_sec,
+                     },
+                 )
+
+                 # Evaluate
+                 if self.global_step % self.config.eval_interval == 0 and self.global_step > 0:
+                     val_metrics = self.evaluate()
+
+                     # Log validation metrics
+                     self.logger.log(self.global_step, val_metrics)
+
+                     # Check if best model
+                     if val_metrics["val_loss"] < self.best_val_loss:
+                         self.best_val_loss = val_metrics["val_loss"]
+                         is_best = True
+                     else:
+                         is_best = False
+
+                     # Save checkpoint
+                     self.checkpoint_manager.save_checkpoint(
+                         model=self.model,
+                         optimizer=self.optimizer,
+                         step=self.global_step,
+                         epoch=epoch,
+                         val_loss=val_metrics["val_loss"],
+                         model_config=self.model.config.to_dict(),
+                         train_config=self.config.to_dict(),
+                         is_best=is_best,
+                     )
+
+                     self.model.train() # Back to training mode
+
+                 # Save checkpoint periodically
+                 if self.global_step % self.config.save_interval == 0 and self.global_step > 0:
+                     self.checkpoint_manager.save_checkpoint(
+                         model=self.model,
+                         optimizer=self.optimizer,
+                         step=self.global_step,
+                         epoch=epoch,
+                         val_loss=self.best_val_loss,
+                         model_config=self.model.config.to_dict(),
+                         train_config=self.config.to_dict(),
+                     )
+
+                 self.global_step += 1
+
+                 # Check if reached max steps
+                 if self.config.max_steps > 0 and self.global_step >= self.config.max_steps:
+                     print(f"\nReached max_steps ({self.config.max_steps}). Stopping training.")
+                     return
+
+             # End of epoch
+             epoch_time = time.time() - epoch_start_time
+             print(f"\nEpoch {epoch + 1} completed in {epoch_time:.2f}s")
+
+         print("\n" + "=" * 60)
+         print("Training Complete!")
+         print("=" * 60)
+         print(f"Best validation loss: {self.best_val_loss:.4f}")
+
+     def resume_from_checkpoint(self, checkpoint_path: Optional[Path] = None):
+         """
+         Resume training from a checkpoint.
+
+         Args:
+             checkpoint_path: Path to checkpoint (or None for latest)
+         """
+         checkpoint = self.checkpoint_manager.load_checkpoint(
+             model=self.model,
+             optimizer=self.optimizer,
+             checkpoint_path=checkpoint_path,
+             device=self.device,
+         )
+
+         self.global_step = checkpoint["step"]
+         self.current_epoch = checkpoint["epoch"]
+         self.best_val_loss = checkpoint.get("best_val_loss", float("inf"))
+
+         print(f"Resumed from step {self.global_step}, epoch {self.current_epoch}")