gptmed-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gptmed/__init__.py +37 -0
  2. gptmed/configs/__init__.py +1 -0
  3. gptmed/configs/train_config.py +154 -0
  4. gptmed/data/__init__.py +5 -0
  5. gptmed/data/parsers/__init__.py +10 -0
  6. gptmed/data/parsers/medquad_parser.py +257 -0
  7. gptmed/data/parsers/text_formatter.py +148 -0
  8. gptmed/inference/__init__.py +1 -0
  9. gptmed/inference/decoding_utils.py +190 -0
  10. gptmed/inference/generation_config.py +83 -0
  11. gptmed/inference/generator.py +253 -0
  12. gptmed/inference/sampling.py +261 -0
  13. gptmed/model/__init__.py +9 -0
  14. gptmed/model/architecture/__init__.py +35 -0
  15. gptmed/model/architecture/attention.py +188 -0
  16. gptmed/model/architecture/decoder_block.py +130 -0
  17. gptmed/model/architecture/embeddings.py +146 -0
  18. gptmed/model/architecture/feedforward.py +109 -0
  19. gptmed/model/architecture/transformer.py +204 -0
  20. gptmed/model/configs/__init__.py +17 -0
  21. gptmed/model/configs/model_config.py +155 -0
  22. gptmed/tokenizer/__init__.py +7 -0
  23. gptmed/tokenizer/tokenize_data.py +286 -0
  24. gptmed/tokenizer/train_tokenizer.py +218 -0
  25. gptmed/training/__init__.py +1 -0
  26. gptmed/training/dataset.py +183 -0
  27. gptmed/training/train.py +272 -0
  28. gptmed/training/trainer.py +331 -0
  29. gptmed/training/utils.py +212 -0
  30. gptmed/utils/__init__.py +1 -0
  31. gptmed/utils/checkpoints.py +224 -0
  32. gptmed/utils/logging.py +189 -0
  33. gptmed-0.0.1.dist-info/METADATA +325 -0
  34. gptmed-0.0.1.dist-info/RECORD +38 -0
  35. gptmed-0.0.1.dist-info/WHEEL +5 -0
  36. gptmed-0.0.1.dist-info/entry_points.txt +3 -0
  37. gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
  38. gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/inference/sampling.py
@@ -0,0 +1,261 @@
+ """
+ Sampling Strategies for Text Generation
+
+ PURPOSE:
+ Different methods to select the next token during generation.
+ Each strategy has trade-offs between quality, diversity, and speed.
+
+ WHAT THIS FILE DOES:
+ 1. Greedy sampling: Always pick highest probability (deterministic)
+ 2. Temperature sampling: Control randomness
+ 3. Top-k sampling: Sample from top k tokens only
+ 4. Top-p (nucleus) sampling: Sample from cumulative probability p
+
+ WHY DIFFERENT STRATEGIES:
+ - Greedy: Fast, deterministic, but boring and repetitive
+ - Temperature: Simple randomness control
+ - Top-k: Prevents sampling very unlikely tokens
+ - Top-p: More adaptive than top-k (adjusts to probability distribution)
+
+ PACKAGES USED:
+ - torch: PyTorch tensors and operations
+
+ FILES FROM THIS PROJECT:
+ - None (utility functions)
+
+ COMMON ISSUES:
+ - Temperature too low → boring, repetitive
+ - Temperature too high → incoherent
+ - Top-k too small → limited diversity
+ - Top-p too low → truncates good options
+ """
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
+     """
+     Greedy sampling: Always pick the highest probability token.
+
+     Args:
+         logits: Logits from model [batch_size, vocab_size]
+
+     Returns:
+         Next token IDs [batch_size]
+
+     Pros: Fast, deterministic, reproducible
+     Cons: Boring, repetitive, gets stuck in loops
+
+     Use when: You want deterministic outputs or testing
+     """
+     return torch.argmax(logits, dim=-1)
+
+
+ def temperature_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
+     """
+     Temperature sampling: Scale logits before softmax.
+
+     Args:
+         logits: Logits from model [batch_size, vocab_size]
+         temperature: Temperature parameter (>0)
+
+     Returns:
+         Sampled token IDs [batch_size]
+
+     How it works:
+     - temperature = 1.0: No change (normal sampling)
+     - temperature < 1.0: More conservative (peaks sharper)
+     - temperature > 1.0: More random (distribution flatter)
+
+     Example:
+     - temp=0.1: Almost greedy
+     - temp=0.7: Balanced (recommended)
+     - temp=1.5: Very creative
+
+     Pros: Simple, interpretable
+     Cons: No control over tail probabilities
+     """
+     if temperature == 0.0:
+         return greedy_sample(logits)
+
+     # Scale logits by temperature
+     scaled_logits = logits / temperature
+
+     # Convert to probabilities
+     probs = F.softmax(scaled_logits, dim=-1)
+
+     # Sample from distribution
+     next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+     return next_token
+
+
+ def top_k_sample(logits: torch.Tensor, k: int, temperature: float = 1.0) -> torch.Tensor:
+     """
+     Top-k sampling: Only sample from top k tokens.
+
+     Args:
+         logits: Logits from model [batch_size, vocab_size]
+         k: Number of top tokens to keep
+         temperature: Temperature scaling
+
+     Returns:
+         Sampled token IDs [batch_size]
+
+     How it works:
+     - Keep only top k highest probability tokens
+     - Set all other tokens to -inf (zero probability)
+     - Sample from remaining tokens
+
+     Why it helps:
+     - Prevents sampling very unlikely tokens (noise)
+     - Reduces incoherent outputs
+
+     Typical values:
+     - k=1: Greedy
+     - k=10: Very conservative
+     - k=50: Balanced
+     - k=100: More diverse
+
+     Limitation:
+     - Fixed k doesn't adapt to probability distribution
+     - If top-1 has 99% probability, k=50 wastes options
+     """
+     if k == 0 or k >= logits.size(-1):
+         # No filtering
+         return temperature_sample(logits, temperature)
+
+     # Get top k values and indices
+     top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)
+
+     # Create filtered logits (set non-top-k to -inf)
+     filtered_logits = torch.full_like(logits, float("-inf"))
+     filtered_logits.scatter_(-1, top_k_indices, top_k_logits)
+
+     # Sample with temperature
+     return temperature_sample(filtered_logits, temperature)
+
+
+ def top_p_sample(logits: torch.Tensor, p: float, temperature: float = 1.0) -> torch.Tensor:
+     """
+     Top-p (nucleus) sampling: Sample from smallest set with cumulative prob >= p.
+
+     Args:
+         logits: Logits from model [batch_size, vocab_size]
+         p: Cumulative probability threshold (0 < p <= 1)
+         temperature: Temperature scaling
+
+     Returns:
+         Sampled token IDs [batch_size]
+
+     How it works:
+     1. Sort tokens by probability (descending)
+     2. Find smallest set where cumulative probability >= p
+     3. Sample only from this set
+
+     Why better than top-k:
+     - Adapts to probability distribution
+     - When model is confident (one token has 90%), nucleus is small
+     - When uncertain, nucleus is larger (more options)
+
+     Typical values:
+     - p=0.9: Conservative (90% probability mass)
+     - p=0.95: Balanced (recommended)
+     - p=0.99: More diverse
+
+     Used in: GPT-3, ChatGPT, most modern LLMs
+     """
+     if p >= 1.0:
+         # No filtering
+         return temperature_sample(logits, temperature)
+
+     # Scale by temperature first
+     scaled_logits = logits / temperature
+
+     # Convert to probabilities
+     probs = F.softmax(scaled_logits, dim=-1)
+
+     # Sort probabilities descending
+     sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+
+     # Compute cumulative probabilities
+     cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+     # Find cutoff: first position where cumsum > p
+     # Shift right by 1 to keep at least one token
+     sorted_indices_to_remove = cumulative_probs > p
+     sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+     sorted_indices_to_remove[..., 0] = False
+
+     # Create mask in original order
+     indices_to_remove = sorted_indices_to_remove.scatter(
+         -1, sorted_indices, sorted_indices_to_remove
+     )
+
+     # Set removed indices to -inf
+     filtered_logits = scaled_logits.clone()
+     filtered_logits[indices_to_remove] = float("-inf")
+
+     # Sample from filtered distribution
+     probs = F.softmax(filtered_logits, dim=-1)
+     next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+     return next_token
+
+
+ def sample_next_token(
+     logits: torch.Tensor, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0
+ ) -> torch.Tensor:
+     """
+     Unified sampling function combining temperature, top-k, and top-p.
+
+     Args:
+         logits: Logits from model [batch_size, vocab_size]
+         temperature: Temperature parameter
+         top_k: Top-k filtering (0 = disabled)
+         top_p: Top-p filtering (1.0 = disabled)
+
+     Returns:
+         Sampled token IDs [batch_size]
+
+     Order of operations:
+     1. Temperature scaling
+     2. Top-k filtering (if enabled)
+     3. Top-p filtering (if enabled)
+     4. Sample from remaining distribution
+     """
+     # Greedy if temperature is 0
+     if temperature == 0.0:
+         return greedy_sample(logits)
+
+     # Apply temperature
+     scaled_logits = logits / temperature
+
+     # Apply top-k if enabled
+     if top_k > 0 and top_k < logits.size(-1):
+         top_k_logits, top_k_indices = torch.topk(scaled_logits, top_k, dim=-1)
+         filtered_logits = torch.full_like(scaled_logits, float("-inf"))
+         filtered_logits.scatter_(-1, top_k_indices, top_k_logits)
+         scaled_logits = filtered_logits
+
+     # Apply top-p if enabled
+     if top_p < 1.0:
+         probs = F.softmax(scaled_logits, dim=-1)
+         sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
+         cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
+
+         sorted_indices_to_remove = cumulative_probs > top_p
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = False
+
+         indices_to_remove = sorted_indices_to_remove.scatter(
+             -1, sorted_indices, sorted_indices_to_remove
+         )
+         scaled_logits[indices_to_remove] = float("-inf")
+
+     # Sample
+     probs = F.softmax(scaled_logits, dim=-1)
+     next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+     return next_token
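
Editor's note: to see how these strategies relate in practice, here is a minimal usage sketch. It is illustrative only and not part of the wheel; it assumes the package installs under the gptmed layout shown in the file list above, and the vocabulary size and sampling parameters are arbitrary example values.

import torch

from gptmed.inference.sampling import (
    greedy_sample,
    temperature_sample,
    top_k_sample,
    top_p_sample,
    sample_next_token,
)

torch.manual_seed(0)
logits = torch.randn(2, 32000)  # fake model output: [batch_size=2, vocab_size=32000]

# Deterministic: always the argmax token.
print(greedy_sample(logits))

# Pure temperature sampling: 0.7 sharpens the distribution slightly.
print(temperature_sample(logits, temperature=0.7))

# Top-k: keep only the 50 most likely tokens, then sample at temperature 0.7.
print(top_k_sample(logits, k=50, temperature=0.7))

# Top-p (nucleus): keep the smallest set covering 95% of the probability mass.
print(top_p_sample(logits, p=0.95, temperature=0.7))

# Unified entry point combining all three filters; every call returns shape [2].
print(sample_next_token(logits, temperature=0.7, top_k=50, top_p=0.95))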
gptmed/model/__init__.py
@@ -0,0 +1,9 @@
+ """
+ MedLLM Model Package
+
+ This package contains the GPT-based transformer architecture for medical QA.
+ """
+
+ from gptmed.model.architecture import GPTTransformer
+
+ __all__ = ["GPTTransformer"]
gptmed/model/architecture/__init__.py
@@ -0,0 +1,35 @@
+ """
+ Model Architecture Components - Package Initializer
+
+ PURPOSE:
+ Makes it easy to import transformer components from a single location.
+
+ WHAT THIS FILE DOES:
+ Exposes main classes for import:
+ - from model.architecture import GPTTransformer
+ - from model.architecture import TransformerDecoderBlock
+ - from model.architecture import MultiHeadAttention
+
+ PACKAGES USED:
+ - None (just Python imports)
+
+ FILES FROM THIS PROJECT:
+ - All components in this architecture/ directory
+ """
+
+ from .embeddings import TokenEmbedding, PositionalEmbedding, TokenPositionalEmbedding
+ from .attention import MultiHeadAttention, create_causal_mask
+ from .feedforward import FeedForward
+ from .decoder_block import TransformerDecoderBlock
+ from .transformer import GPTTransformer
+
+ __all__ = [
+     "TokenEmbedding",
+     "PositionalEmbedding",
+     "TokenPositionalEmbedding",
+     "MultiHeadAttention",
+     "create_causal_mask",
+     "FeedForward",
+     "TransformerDecoderBlock",
+     "GPTTransformer",
+ ]
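
Editor's note: a quick illustrative check of the single-location import this initializer provides (not part of the package, and assuming the wheel installs as gptmed): every name in __all__ is reachable from gptmed.model.architecture directly.

import gptmed.model.architecture as arch

# Every exported component is importable from the one namespace.
for name in arch.__all__:
    assert hasattr(arch, name), f"missing export: {name}"

from gptmed.model.architecture import MultiHeadAttention, create_causal_mask

print(arch.__all__)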
gptmed/model/architecture/attention.py
@@ -0,0 +1,188 @@
+ """
+ Multi-Head Causal Self-Attention
+
+ PURPOSE:
+ This is the core mechanism that allows the model to attend to different parts
+ of the input sequence. "Causal" means the model can only look at previous tokens,
+ not future ones (essential for next-token prediction).
+
+ WHAT THIS STEP DOES:
+ 1. Linear projections: Create Query, Key, Value matrices
+    - Input: [batch_size, seq_len, d_model]
+    - Q, K, V each: [batch_size, seq_len, d_model]
+
+ 2. Split into multiple heads
+    - Reshape to: [batch_size, n_heads, seq_len, d_head]
+    - where d_head = d_model / n_heads
+
+ 3. Scaled dot-product attention
+    - Compute attention scores: Q @ K^T / sqrt(d_head)
+    - Apply causal mask (CRITICAL: prevents looking at future)
+    - Softmax to get attention weights
+    - Apply to values: attention_weights @ V
+
+ 4. Concatenate heads and project back
+    - Output: [batch_size, seq_len, d_model]
+
+ PACKAGES USED:
+ - torch: PyTorch tensors and operations
+ - torch.nn: Linear layers, Dropout
+ - torch.nn.functional: Softmax
+ - math: sqrt for scaling
+
+ FILES FROM THIS PROJECT:
+ - None (this is a base component)
+
+ TENSOR SHAPES EXPLAINED:
+ - n_heads: Number of attention heads (4-8)
+ - d_head: Dimension per head (d_model / n_heads)
+ - Causal mask: Lower triangular matrix [seq_len, seq_len]
+
+ COMMON FAILURE MODES TO AVOID:
+ - Missing causal mask → model cheats by seeing future tokens
+ - Wrong mask shape → silent failures or crashes
+ - Not scaling attention scores → vanishing/exploding gradients
+ - Forgetting dropout → overfitting
+ - Wrong tensor transpose/reshape → incorrect attention patterns
+ - Not masking padding tokens → attending to meaningless tokens
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+
+ class MultiHeadAttention(nn.Module):
+     """
+     Multi-Head Causal Self-Attention.
+
+     This is the CORE of the transformer. Understanding this is critical.
+
+     Key concept: Attention lets each token "look at" other tokens to gather context.
+     "Causal" means token at position i can ONLY look at positions <= i (not future).
+
+     Tensor shape flow:
+     Input: [batch_size, seq_len, d_model]
+     Q,K,V: [batch_size, seq_len, d_model]
+     Split: [batch_size, n_heads, seq_len, d_head]
+     Attention: [batch_size, n_heads, seq_len, seq_len]
+     Output: [batch_size, seq_len, d_model]
+     """
+
+     def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
+         """
+         Args:
+             d_model: Model dimension
+             n_heads: Number of attention heads
+             dropout: Dropout probability
+         """
+         super().__init__()
+
+         assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
+
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.d_head = d_model // n_heads  # Dimension per head
+
+         # Linear projections for Q, K, V
+         # We use separate projections for each, not combined
+         self.q_linear = nn.Linear(d_model, d_model)
+         self.k_linear = nn.Linear(d_model, d_model)
+         self.v_linear = nn.Linear(d_model, d_model)
+
+         # Output projection
+         self.out_linear = nn.Linear(d_model, d_model)
+
+         # Dropout
+         self.dropout = nn.Dropout(dropout)
+
+         # Scaling factor for attention scores
+         self.scale = math.sqrt(self.d_head)
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         """
+         Args:
+             x: Input tensor [batch_size, seq_len, d_model]
+             mask: Causal mask [seq_len, seq_len] or [batch_size, 1, seq_len, seq_len]
+
+         Returns:
+             Output tensor [batch_size, seq_len, d_model]
+         """
+         batch_size, seq_len, d_model = x.size()
+
+         # 1. Linear projections
+         # Each: [batch_size, seq_len, d_model]
+         Q = self.q_linear(x)
+         K = self.k_linear(x)
+         V = self.v_linear(x)
+
+         # 2. Split into multiple heads
+         # Reshape: [batch_size, seq_len, n_heads, d_head]
+         # Then transpose: [batch_size, n_heads, seq_len, d_head]
+         Q = Q.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
+         K = K.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
+         V = V.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
+
+         # 3. Scaled dot-product attention
+         # Q @ K^T: [batch_size, n_heads, seq_len, seq_len]
+         attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
+
+         # 4. Apply causal mask (CRITICAL!)
+         # This prevents position i from attending to positions > i
+         if mask is not None:
+             # mask should be [seq_len, seq_len] or [batch_size, 1, seq_len, seq_len]
+             attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))
+
+         # 5. Softmax to get attention weights
+         # [batch_size, n_heads, seq_len, seq_len]
+         attention_weights = F.softmax(attention_scores, dim=-1)
+
+         # Apply dropout to attention weights (standard practice)
+         attention_weights = self.dropout(attention_weights)
+
+         # 6. Apply attention to values
+         # [batch_size, n_heads, seq_len, d_head]
+         attended_values = torch.matmul(attention_weights, V)
+
+         # 7. Concatenate heads
+         # Transpose back: [batch_size, seq_len, n_heads, d_head]
+         # Then reshape: [batch_size, seq_len, d_model]
+         attended_values = attended_values.transpose(1, 2).contiguous()
+         attended_values = attended_values.view(batch_size, seq_len, self.d_model)
+
+         # 8. Final linear projection
+         output = self.out_linear(attended_values)
+
+         return output
+
+
+ def create_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
+     """
+     Create a causal mask for autoregressive generation.
+
+     The mask is a lower triangular matrix:
+     [[1, 0, 0, 0],
+      [1, 1, 0, 0],
+      [1, 1, 1, 0],
+      [1, 1, 1, 1]]
+
+     This ensures position i can only attend to positions <= i.
+
+     Args:
+         seq_len: Sequence length
+         device: Device to create tensor on
+
+     Returns:
+         Causal mask [seq_len, seq_len]
+
+     Why this works:
+     - Position 0 can see position 0 only
+     - Position 1 can see positions 0, 1
+     - Position 2 can see positions 0, 1, 2
+     - etc.
+
+     This is the ESSENCE of causal language modeling!
+     """
+     mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
+     return mask  # [seq_len, seq_len]
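
Editor's note: the shape flow and the effect of the causal mask can be checked with a small illustrative run. This is not part of the package; the hyperparameters are arbitrary and the import path assumes the gptmed layout from the file list.

import torch

from gptmed.model.architecture.attention import MultiHeadAttention, create_causal_mask

batch_size, seq_len, d_model, n_heads = 2, 8, 64, 4

attn = MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout=0.1)
attn.eval()  # disable dropout so the check below is deterministic

x = torch.randn(batch_size, seq_len, d_model)  # [2, 8, 64]
mask = create_causal_mask(seq_len)             # [8, 8] lower-triangular matrix of ones

with torch.no_grad():
    out = attn(x, mask)
print(out.shape)  # torch.Size([2, 8, 64]) -- same shape as the input

# Causality check: perturbing the LAST position must not change earlier outputs,
# because the mask zeroes out attention to "future" positions.
x_perturbed = x.clone()
x_perturbed[:, -1, :] = 0.0
with torch.no_grad():
    out_perturbed = attn(x_perturbed, mask)
print(torch.allclose(out[:, :-1], out_perturbed[:, :-1]))  # True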
gptmed/model/architecture/decoder_block.py
@@ -0,0 +1,130 @@
+ """
+ Transformer Decoder Block
+
+ PURPOSE:
+ Combines attention, feed-forward network, residual connections, and layer
+ normalization into a single reusable block. Multiple blocks are stacked
+ to form the full transformer.
+
+ WHAT THIS STEP DOES:
+ 1. Multi-head causal self-attention
+    - With residual connection: x + attention(x)
+    - With layer normalization
+
+ 2. Feed-forward network
+    - With residual connection: x + ffn(x)
+    - With layer normalization
+
+ Architecture pattern (Pre-LN vs Post-LN):
+ We use Pre-LN (normalize before sublayer) because it's more stable:
+     x = x + attention(LayerNorm(x))
+     x = x + ffn(LayerNorm(x))
+
+ PACKAGES USED:
+ - torch: PyTorch tensors
+ - torch.nn: Module, LayerNorm, Dropout
+
+ FILES FROM THIS PROJECT:
+ - architecture/attention.py: Multi-head attention module
+ - architecture/feedforward.py: FFN module
+
+ TENSOR SHAPES:
+ - Input: [batch_size, seq_len, d_model]
+ - Output: [batch_size, seq_len, d_model] (unchanged)
+
+ COMMON FAILURE MODES TO AVOID:
+ - Post-LN instead of Pre-LN → training instability
+ - Forgetting residual connections → vanishing gradients
+ - Wrong LayerNorm dimension → incorrect normalization
+ - Dropout too high → underfitting
+ - Dropout too low → overfitting
+ """
+
+ import torch
+ import torch.nn as nn
+
+ from .attention import MultiHeadAttention
+ from .feedforward import FeedForward
+
+
+ class TransformerDecoderBlock(nn.Module):
+     """
+     Single Transformer Decoder Block.
+
+     This is one "layer" of the transformer. We stack multiple of these.
+
+     Architecture (Pre-LN):
+     1. LayerNorm → Multi-Head Attention → Residual
+     2. LayerNorm → Feed-Forward → Residual
+
+     Tensor shape flow:
+     Input: [batch_size, seq_len, d_model]
+     Output: [batch_size, seq_len, d_model] (same shape)
+
+     Why Pre-LN instead of Post-LN?
+     - Pre-LN: Normalize BEFORE sublayer → more stable gradients
+     - Post-LN: Normalize AFTER sublayer → can have training instability
+     - GPT-2 and modern transformers use Pre-LN
+     """
+
+     def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
+         """
+         Args:
+             d_model: Model dimension
+             n_heads: Number of attention heads
+             d_ff: Feed-forward hidden dimension
+             dropout: Dropout probability
+         """
+         super().__init__()
+
+         # Multi-head self-attention
+         self.attention = MultiHeadAttention(d_model, n_heads, dropout)
+
+         # Feed-forward network
+         self.feed_forward = FeedForward(d_model, d_ff, dropout)
+
+         # Layer normalization (one for each sublayer)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+
+         # Dropout for residual connections
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         """
+         Args:
+             x: Input tensor [batch_size, seq_len, d_model]
+             mask: Causal mask [seq_len, seq_len]
+
+         Returns:
+             Output tensor [batch_size, seq_len, d_model]
+
+         Step-by-step explanation:
+         1. Normalize input
+         2. Apply attention
+         3. Add residual (skip connection)
+         4. Normalize result
+         5. Apply feed-forward
+         6. Add residual (skip connection)
+         """
+         # Sublayer 1: Self-Attention with residual
+         # Pre-LN: normalize first
+         normed = self.norm1(x)
+
+         # Apply attention
+         attention_output = self.attention(normed, mask)
+
+         # Residual connection: x + attention(norm(x))
+         x = x + self.dropout(attention_output)
+
+         # Sublayer 2: Feed-Forward with residual
+         # Pre-LN: normalize first
+         normed = self.norm2(x)
+
+         # Apply feed-forward
+         ff_output = self.feed_forward(normed)
+
+         # Residual connection: x + ffn(norm(x))
+         x = x + self.dropout(ff_output)
+
+         return x
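
Editor's note: to close out, a short sketch of how several of these blocks stack. It is illustrative only; the layer count and dimensions are arbitrary, and the full GPTTransformer in transformer.py presumably wires the blocks together in an equivalent way.

import torch
import torch.nn as nn

from gptmed.model.architecture import TransformerDecoderBlock, create_causal_mask

d_model, n_heads, d_ff, n_layers = 64, 4, 256, 3

# Stack a few Pre-LN decoder blocks; each one preserves [batch, seq_len, d_model].
blocks = nn.ModuleList(
    TransformerDecoderBlock(d_model, n_heads, d_ff, dropout=0.1) for _ in range(n_layers)
)

x = torch.randn(2, 16, d_model)        # [batch_size=2, seq_len=16, d_model=64]
mask = create_causal_mask(seq_len=16)  # one causal mask shared by every block

for block in blocks:
    x = block(x, mask)

print(x.shape)  # torch.Size([2, 16, 64])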