gptmed-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gptmed/__init__.py +37 -0
- gptmed/configs/__init__.py +1 -0
- gptmed/configs/train_config.py +154 -0
- gptmed/data/__init__.py +5 -0
- gptmed/data/parsers/__init__.py +10 -0
- gptmed/data/parsers/medquad_parser.py +257 -0
- gptmed/data/parsers/text_formatter.py +148 -0
- gptmed/inference/__init__.py +1 -0
- gptmed/inference/decoding_utils.py +190 -0
- gptmed/inference/generation_config.py +83 -0
- gptmed/inference/generator.py +253 -0
- gptmed/inference/sampling.py +261 -0
- gptmed/model/__init__.py +9 -0
- gptmed/model/architecture/__init__.py +35 -0
- gptmed/model/architecture/attention.py +188 -0
- gptmed/model/architecture/decoder_block.py +130 -0
- gptmed/model/architecture/embeddings.py +146 -0
- gptmed/model/architecture/feedforward.py +109 -0
- gptmed/model/architecture/transformer.py +204 -0
- gptmed/model/configs/__init__.py +17 -0
- gptmed/model/configs/model_config.py +155 -0
- gptmed/tokenizer/__init__.py +7 -0
- gptmed/tokenizer/tokenize_data.py +286 -0
- gptmed/tokenizer/train_tokenizer.py +218 -0
- gptmed/training/__init__.py +1 -0
- gptmed/training/dataset.py +183 -0
- gptmed/training/train.py +272 -0
- gptmed/training/trainer.py +331 -0
- gptmed/training/utils.py +212 -0
- gptmed/utils/__init__.py +1 -0
- gptmed/utils/checkpoints.py +224 -0
- gptmed/utils/logging.py +189 -0
- gptmed-0.0.1.dist-info/METADATA +325 -0
- gptmed-0.0.1.dist-info/RECORD +38 -0
- gptmed-0.0.1.dist-info/WHEEL +5 -0
- gptmed-0.0.1.dist-info/entry_points.txt +3 -0
- gptmed-0.0.1.dist-info/licenses/LICENSE +21 -0
- gptmed-0.0.1.dist-info/top_level.txt +1 -0
gptmed/inference/sampling.py
ADDED
@@ -0,0 +1,261 @@
"""
Sampling Strategies for Text Generation

PURPOSE:
Different methods to select the next token during generation.
Each strategy has trade-offs between quality, diversity, and speed.

WHAT THIS FILE DOES:
1. Greedy sampling: Always pick highest probability (deterministic)
2. Temperature sampling: Control randomness
3. Top-k sampling: Sample from top k tokens only
4. Top-p (nucleus) sampling: Sample from cumulative probability p

WHY DIFFERENT STRATEGIES:
- Greedy: Fast, deterministic, but boring and repetitive
- Temperature: Simple randomness control
- Top-k: Prevents sampling very unlikely tokens
- Top-p: More adaptive than top-k (adjusts to probability distribution)

PACKAGES USED:
- torch: PyTorch tensors and operations

FILES FROM THIS PROJECT:
- None (utility functions)

COMMON ISSUES:
- Temperature too low → boring, repetitive
- Temperature too high → incoherent
- Top-k too small → limited diversity
- Top-p too low → truncates good options
"""

import torch
import torch.nn.functional as F


def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    """
    Greedy sampling: Always pick the highest probability token.

    Args:
        logits: Logits from model [batch_size, vocab_size]

    Returns:
        Next token IDs [batch_size]

    Pros: Fast, deterministic, reproducible
    Cons: Boring, repetitive, gets stuck in loops

    Use when: You want deterministic outputs or testing
    """
    return torch.argmax(logits, dim=-1)


def temperature_sample(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Temperature sampling: Scale logits before softmax.

    Args:
        logits: Logits from model [batch_size, vocab_size]
        temperature: Temperature parameter (>0)

    Returns:
        Sampled token IDs [batch_size]

    How it works:
    - temperature = 1.0: No change (normal sampling)
    - temperature < 1.0: More conservative (peaks sharper)
    - temperature > 1.0: More random (distribution flatter)

    Example:
    - temp=0.1: Almost greedy
    - temp=0.7: Balanced (recommended)
    - temp=1.5: Very creative

    Pros: Simple, interpretable
    Cons: No control over tail probabilities
    """
    if temperature == 0.0:
        return greedy_sample(logits)

    # Scale logits by temperature
    scaled_logits = logits / temperature

    # Convert to probabilities
    probs = F.softmax(scaled_logits, dim=-1)

    # Sample from distribution
    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return next_token


def top_k_sample(logits: torch.Tensor, k: int, temperature: float = 1.0) -> torch.Tensor:
    """
    Top-k sampling: Only sample from top k tokens.

    Args:
        logits: Logits from model [batch_size, vocab_size]
        k: Number of top tokens to keep
        temperature: Temperature scaling

    Returns:
        Sampled token IDs [batch_size]

    How it works:
    - Keep only top k highest probability tokens
    - Set all other tokens to -inf (zero probability)
    - Sample from remaining tokens

    Why it helps:
    - Prevents sampling very unlikely tokens (noise)
    - Reduces incoherent outputs

    Typical values:
    - k=1: Greedy
    - k=10: Very conservative
    - k=50: Balanced
    - k=100: More diverse

    Limitation:
    - Fixed k doesn't adapt to probability distribution
    - If top-1 has 99% probability, k=50 wastes options
    """
    if k == 0 or k >= logits.size(-1):
        # No filtering
        return temperature_sample(logits, temperature)

    # Get top k values and indices
    top_k_logits, top_k_indices = torch.topk(logits, k, dim=-1)

    # Create filtered logits (set non-top-k to -inf)
    filtered_logits = torch.full_like(logits, float("-inf"))
    filtered_logits.scatter_(-1, top_k_indices, top_k_logits)

    # Sample with temperature
    return temperature_sample(filtered_logits, temperature)


def top_p_sample(logits: torch.Tensor, p: float, temperature: float = 1.0) -> torch.Tensor:
    """
    Top-p (nucleus) sampling: Sample from smallest set with cumulative prob >= p.

    Args:
        logits: Logits from model [batch_size, vocab_size]
        p: Cumulative probability threshold (0 < p <= 1)
        temperature: Temperature scaling

    Returns:
        Sampled token IDs [batch_size]

    How it works:
    1. Sort tokens by probability (descending)
    2. Find smallest set where cumulative probability >= p
    3. Sample only from this set

    Why better than top-k:
    - Adapts to probability distribution
    - When model is confident (one token has 90%), nucleus is small
    - When uncertain, nucleus is larger (more options)

    Typical values:
    - p=0.9: Conservative (90% probability mass)
    - p=0.95: Balanced (recommended)
    - p=0.99: More diverse

    Used in: GPT-3, ChatGPT, most modern LLMs
    """
    if p >= 1.0:
        # No filtering
        return temperature_sample(logits, temperature)

    # Scale by temperature first
    scaled_logits = logits / temperature

    # Convert to probabilities
    probs = F.softmax(scaled_logits, dim=-1)

    # Sort probabilities descending
    sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

    # Compute cumulative probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Find cutoff: first position where cumsum > p
    # Shift right by 1 to keep at least one token
    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False

    # Create mask in original order
    indices_to_remove = sorted_indices_to_remove.scatter(
        -1, sorted_indices, sorted_indices_to_remove
    )

    # Set removed indices to -inf
    filtered_logits = scaled_logits.clone()
    filtered_logits[indices_to_remove] = float("-inf")

    # Sample from filtered distribution
    probs = F.softmax(filtered_logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return next_token


def sample_next_token(
    logits: torch.Tensor, temperature: float = 1.0, top_k: int = 0, top_p: float = 1.0
) -> torch.Tensor:
    """
    Unified sampling function combining temperature, top-k, and top-p.

    Args:
        logits: Logits from model [batch_size, vocab_size]
        temperature: Temperature parameter
        top_k: Top-k filtering (0 = disabled)
        top_p: Top-p filtering (1.0 = disabled)

    Returns:
        Sampled token IDs [batch_size]

    Order of operations:
    1. Temperature scaling
    2. Top-k filtering (if enabled)
    3. Top-p filtering (if enabled)
    4. Sample from remaining distribution
    """
    # Greedy if temperature is 0
    if temperature == 0.0:
        return greedy_sample(logits)

    # Apply temperature
    scaled_logits = logits / temperature

    # Apply top-k if enabled
    if top_k > 0 and top_k < logits.size(-1):
        top_k_logits, top_k_indices = torch.topk(scaled_logits, top_k, dim=-1)
        filtered_logits = torch.full_like(scaled_logits, float("-inf"))
        filtered_logits.scatter_(-1, top_k_indices, top_k_logits)
        scaled_logits = filtered_logits

    # Apply top-p if enabled
    if top_p < 1.0:
        probs = F.softmax(scaled_logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        scaled_logits[indices_to_remove] = float("-inf")

    # Sample
    probs = F.softmax(scaled_logits, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)

    return next_token
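
For orientation, a minimal usage sketch of the sampling helpers above. This is not part of the package diff; it assumes the wheel is installed and uses an arbitrary placeholder vocabulary size.

import torch
from gptmed.inference.sampling import sample_next_token, top_p_sample

# Fake logits: batch_size=2, vocab_size=1000 (placeholder values)
logits = torch.randn(2, 1000)

# Combined temperature + top-k + top-p pipeline
next_ids = sample_next_token(logits, temperature=0.7, top_k=50, top_p=0.95)
print(next_ids.shape)  # torch.Size([2])

# Pure nucleus sampling
next_ids = top_p_sample(logits, p=0.9, temperature=0.8)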
gptmed/model/architecture/__init__.py
ADDED
@@ -0,0 +1,35 @@
"""
Model Architecture Components - Package Initializer

PURPOSE:
Makes it easy to import transformer components from a single location.

WHAT THIS FILE DOES:
Exposes main classes for import:
- from model.architecture import GPTTransformer
- from model.architecture import TransformerDecoderBlock
- from model.architecture import MultiHeadAttention

PACKAGES USED:
- None (just Python imports)

FILES FROM THIS PROJECT:
- All components in this architecture/ directory
"""

from .embeddings import TokenEmbedding, PositionalEmbedding, TokenPositionalEmbedding
from .attention import MultiHeadAttention, create_causal_mask
from .feedforward import FeedForward
from .decoder_block import TransformerDecoderBlock
from .transformer import GPTTransformer

__all__ = [
    "TokenEmbedding",
    "PositionalEmbedding",
    "TokenPositionalEmbedding",
    "MultiHeadAttention",
    "create_causal_mask",
    "FeedForward",
    "TransformerDecoderBlock",
    "GPTTransformer",
]
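
As a quick illustration of the convenience this initializer provides, a sketch of downstream imports, assuming the installed top-level package is gptmed (per top_level.txt):

# All architecture components come from one place
from gptmed.model.architecture import (
    MultiHeadAttention,
    TransformerDecoderBlock,
    create_causal_mask,
)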
gptmed/model/architecture/attention.py
ADDED
@@ -0,0 +1,188 @@
"""
Multi-Head Causal Self-Attention

PURPOSE:
This is the core mechanism that allows the model to attend to different parts
of the input sequence. "Causal" means the model can only look at previous tokens,
not future ones (essential for next-token prediction).

WHAT THIS STEP DOES:
1. Linear projections: Create Query, Key, Value matrices
   - Input: [batch_size, seq_len, d_model]
   - Q, K, V each: [batch_size, seq_len, d_model]

2. Split into multiple heads
   - Reshape to: [batch_size, n_heads, seq_len, d_head]
   - where d_head = d_model / n_heads

3. Scaled dot-product attention
   - Compute attention scores: Q @ K^T / sqrt(d_head)
   - Apply causal mask (CRITICAL: prevents looking at future)
   - Softmax to get attention weights
   - Apply to values: attention_weights @ V

4. Concatenate heads and project back
   - Output: [batch_size, seq_len, d_model]

PACKAGES USED:
- torch: PyTorch tensors and operations
- torch.nn: Linear layers, Dropout
- torch.nn.functional: Softmax
- math: sqrt for scaling

FILES FROM THIS PROJECT:
- None (this is a base component)

TENSOR SHAPES EXPLAINED:
- n_heads: Number of attention heads (4-8)
- d_head: Dimension per head (d_model / n_heads)
- Causal mask: Lower triangular matrix [seq_len, seq_len]

COMMON FAILURE MODES TO AVOID:
- Missing causal mask → model cheats by seeing future tokens
- Wrong mask shape → silent failures or crashes
- Not scaling attention scores → vanishing/exploding gradients
- Forgetting dropout → overfitting
- Wrong tensor transpose/reshape → incorrect attention patterns
- Not masking padding tokens → attending to meaningless tokens
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Causal Self-Attention.

    This is the CORE of the transformer. Understanding this is critical.

    Key concept: Attention lets each token "look at" other tokens to gather context.
    "Causal" means token at position i can ONLY look at positions <= i (not future).

    Tensor shape flow:
    Input: [batch_size, seq_len, d_model]
    Q,K,V: [batch_size, seq_len, d_model]
    Split: [batch_size, n_heads, seq_len, d_head]
    Attention: [batch_size, n_heads, seq_len, seq_len]
    Output: [batch_size, seq_len, d_model]
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimension
            n_heads: Number of attention heads
            dropout: Dropout probability
        """
        super().__init__()

        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads  # Dimension per head

        # Linear projections for Q, K, V
        # We use separate projections for each, not combined
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)

        # Output projection
        self.out_linear = nn.Linear(d_model, d_model)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Scaling factor for attention scores
        self.scale = math.sqrt(self.d_head)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Causal mask [seq_len, seq_len] or [batch_size, seq_len, seq_len]

        Returns:
            Output tensor [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.size()

        # 1. Linear projections
        # Each: [batch_size, seq_len, d_model]
        Q = self.q_linear(x)
        K = self.k_linear(x)
        V = self.v_linear(x)

        # 2. Split into multiple heads
        # Reshape: [batch_size, seq_len, n_heads, d_head]
        # Then transpose: [batch_size, n_heads, seq_len, d_head]
        Q = Q.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        # 3. Scaled dot-product attention
        # Q @ K^T: [batch_size, n_heads, seq_len, seq_len]
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # 4. Apply causal mask (CRITICAL!)
        # This prevents position i from attending to positions > i
        if mask is not None:
            # mask should be [seq_len, seq_len] or [batch_size, 1, seq_len, seq_len]
            attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))

        # 5. Softmax to get attention weights
        # [batch_size, n_heads, seq_len, seq_len]
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Apply dropout to attention weights (standard practice)
        attention_weights = self.dropout(attention_weights)

        # 6. Apply attention to values
        # [batch_size, n_heads, seq_len, d_head]
        attended_values = torch.matmul(attention_weights, V)

        # 7. Concatenate heads
        # Transpose back: [batch_size, seq_len, n_heads, d_head]
        # Then reshape: [batch_size, seq_len, d_model]
        attended_values = attended_values.transpose(1, 2).contiguous()
        attended_values = attended_values.view(batch_size, seq_len, self.d_model)

        # 8. Final linear projection
        output = self.out_linear(attended_values)

        return output


def create_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
    """
    Create a causal mask for autoregressive generation.

    The mask is a lower triangular matrix:
    [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]

    This ensures position i can only attend to positions <= i.

    Args:
        seq_len: Sequence length
        device: Device to create tensor on

    Returns:
        Causal mask [seq_len, seq_len]

    Why this works:
    - Position 0 can see position 0 only
    - Position 1 can see positions 0, 1
    - Position 2 can see positions 0, 1, 2
    - etc.

    This is the ESSENCE of causal language modeling!
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
    return mask  # [seq_len, seq_len]
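
A minimal smoke-test sketch for the attention module and mask helper above. This is not part of the package; the dimensions are arbitrary illustrative values and the wheel is assumed to be installed.

import torch
from gptmed.model.architecture.attention import MultiHeadAttention, create_causal_mask

batch_size, seq_len, d_model, n_heads = 2, 16, 64, 4

attn = MultiHeadAttention(d_model=d_model, n_heads=n_heads, dropout=0.1)
x = torch.randn(batch_size, seq_len, d_model)   # [batch, seq, d_model]
mask = create_causal_mask(seq_len)              # [16, 16] lower-triangular

out = attn(x, mask)
print(out.shape)  # torch.Size([2, 16, 64]), shape is preserved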
gptmed/model/architecture/decoder_block.py
ADDED
@@ -0,0 +1,130 @@
"""
Transformer Decoder Block

PURPOSE:
Combines attention, feed-forward network, residual connections, and layer
normalization into a single reusable block. Multiple blocks are stacked
to form the full transformer.

WHAT THIS STEP DOES:
1. Multi-head causal self-attention
   - With residual connection: x + attention(x)
   - With layer normalization

2. Feed-forward network
   - With residual connection: x + ffn(x)
   - With layer normalization

Architecture pattern (Pre-LN vs Post-LN):
We use Pre-LN (normalize before sublayer) because it's more stable:
    x = x + attention(LayerNorm(x))
    x = x + ffn(LayerNorm(x))

PACKAGES USED:
- torch: PyTorch tensors
- torch.nn: Module, LayerNorm, Dropout

FILES FROM THIS PROJECT:
- architecture/attention.py: Multi-head attention module
- architecture/feedforward.py: FFN module

TENSOR SHAPES:
- Input: [batch_size, seq_len, d_model]
- Output: [batch_size, seq_len, d_model] (unchanged)

COMMON FAILURE MODES TO AVOID:
- Post-LN instead of Pre-LN → training instability
- Forgetting residual connections → vanishing gradients
- Wrong LayerNorm dimension → incorrect normalization
- Dropout too high → underfitting
- Dropout too low → overfitting
"""

import torch
import torch.nn as nn

from .attention import MultiHeadAttention
from .feedforward import FeedForward


class TransformerDecoderBlock(nn.Module):
    """
    Single Transformer Decoder Block.

    This is one "layer" of the transformer. We stack multiple of these.

    Architecture (Pre-LN):
    1. LayerNorm → Multi-Head Attention → Residual
    2. LayerNorm → Feed-Forward → Residual

    Tensor shape flow:
    Input: [batch_size, seq_len, d_model]
    Output: [batch_size, seq_len, d_model] (same shape)

    Why Pre-LN instead of Post-LN?
    - Pre-LN: Normalize BEFORE sublayer → more stable gradients
    - Post-LN: Normalize AFTER sublayer → can have training instability
    - GPT-2 and modern transformers use Pre-LN
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        """
        Args:
            d_model: Model dimension
            n_heads: Number of attention heads
            d_ff: Feed-forward hidden dimension
            dropout: Dropout probability
        """
        super().__init__()

        # Multi-head self-attention
        self.attention = MultiHeadAttention(d_model, n_heads, dropout)

        # Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout)

        # Layer normalization (one for each sublayer)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout for residual connections
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Causal mask [seq_len, seq_len]

        Returns:
            Output tensor [batch_size, seq_len, d_model]

        Step-by-step explanation:
        1. Normalize input
        2. Apply attention
        3. Add residual (skip connection)
        4. Normalize result
        5. Apply feed-forward
        6. Add residual (skip connection)
        """
        # Sublayer 1: Self-Attention with residual
        # Pre-LN: normalize first
        normed = self.norm1(x)

        # Apply attention
        attention_output = self.attention(normed, mask)

        # Residual connection: x + attention(norm(x))
        x = x + self.dropout(attention_output)

        # Sublayer 2: Feed-Forward with residual
        # Pre-LN: normalize first
        normed = self.norm2(x)

        # Apply feed-forward
        ff_output = self.feed_forward(normed)

        # Residual connection: x + ffn(norm(x))
        x = x + self.dropout(ff_output)

        return x