nanogpt 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,23 @@
+ # frozen_string_literal: true
+
+ module NanoGPT
+   module Layers
+     # Feed-forward network with GELU activation
+     class MLP < Torch::NN::Module
+       def initialize(config)
+         super()
+         @c_fc = Torch::NN::Linear.new(config.n_embd, 4 * config.n_embd, bias: config.bias)
+         @gelu = Torch::NN::GELU.new
+         @c_proj = Torch::NN::Linear.new(4 * config.n_embd, config.n_embd, bias: config.bias)
+         @dropout = Torch::NN::Dropout.new(p: config.dropout)
+       end
+
+       def forward(x)
+         x = @c_fc.call(x)
+         x = @gelu.call(x)
+         x = @c_proj.call(x)
+         @dropout.call(x)
+       end
+     end
+   end
+ end
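
Not part of the published diff: a minimal usage sketch of the MLP layer above, assuming torch-rb is installed and the class is loaded as defined. The `Config` struct is a hypothetical stand-in covering only the fields MLP reads; the gem's real config class is not shown in this hunk.

    require "torch"

    # Hypothetical stand-in for the gem's config object (n_embd, bias, dropout only).
    Config = Struct.new(:n_embd, :bias, :dropout, keyword_init: true)
    config = Config.new(n_embd: 384, bias: false, dropout: 0.2)

    mlp = NanoGPT::Layers::MLP.new(config)
    x = Torch.randn(8, 256, config.n_embd)  # (batch, seq_len, n_embd)
    y = mlp.call(x)                         # same shape back: (8, 256, 384)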
@@ -0,0 +1,42 @@
+ # frozen_string_literal: true
+
+ module NanoGPT
+   # Cosine learning rate scheduler with linear warmup
+   class LRScheduler
+     attr_reader :learning_rate, :min_lr, :warmup_iters, :lr_decay_iters
+
+     def initialize(learning_rate:, min_lr:, warmup_iters:, lr_decay_iters:)
+       @learning_rate = learning_rate
+       @min_lr = min_lr
+       @warmup_iters = warmup_iters
+       @lr_decay_iters = lr_decay_iters
+     end
+
+     # Get learning rate for given iteration
+     def get_lr(iter)
+       # 1) Linear warmup for warmup_iters steps
+       if iter < @warmup_iters
+         return @learning_rate * (iter + 1).to_f / (@warmup_iters + 1)
+       end
+
+       # 2) If iter > lr_decay_iters, return min learning rate
+       if iter > @lr_decay_iters
+         return @min_lr
+       end
+
+       # 3) In between, use cosine decay down to min learning rate
+       decay_ratio = (iter - @warmup_iters).to_f / (@lr_decay_iters - @warmup_iters)
+       coeff = 0.5 * (1.0 + Math.cos(Math::PI * decay_ratio))
+       @min_lr + coeff * (@learning_rate - @min_lr)
+     end
+
+     # Apply learning rate to optimizer
+     def step(optimizer, iter)
+       lr = get_lr(iter)
+       optimizer.param_groups.each do |group|
+         group[:lr] = lr
+       end
+       lr
+     end
+   end
+ end
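
Not part of the published diff: a short sketch of how the schedule above behaves with the gem's default hyperparameters (learning_rate 1e-3, min_lr 1e-4, warmup_iters 100, lr_decay_iters 5000), assuming the class is loaded as defined.

    sched = NanoGPT::LRScheduler.new(
      learning_rate: 1e-3, min_lr: 1e-4,
      warmup_iters: 100, lr_decay_iters: 5000
    )

    sched.get_lr(0)     # ~9.9e-6  (linear warmup: 1e-3 * 1/101)
    sched.get_lr(100)   # 1e-3     (warmup done, cosine coeff = 1.0)
    sched.get_lr(2550)  # ~5.5e-4  (halfway through the cosine decay)
    sched.get_lr(6000)  # 1e-4     (past lr_decay_iters, clamped to min_lr)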
@@ -0,0 +1,218 @@
+ # frozen_string_literal: true
+
+ module NanoGPT
+   # GPT Language Model
+   class GPT < Torch::NN::Module
+     attr_reader :config
+
+     def initialize(config)
+       super()
+       raise ArgumentError, "vocab_size must be set" unless config.vocab_size
+       raise ArgumentError, "block_size must be set" unless config.block_size
+
+       @config = config
+
+       # Token and position embeddings
+       @wte = Torch::NN::Embedding.new(config.vocab_size, config.n_embd)
+       @wpe = Torch::NN::Embedding.new(config.block_size, config.n_embd)
+       @drop = Torch::NN::Dropout.new(p: config.dropout)
+
+       # Transformer blocks
+       @h = Torch::NN::ModuleList.new(
+         config.n_layer.times.map { Layers::Block.new(config) }
+       )
+
+       # Final layer norm
+       @ln_f = Layers::LayerNorm.new(config.n_embd, bias: config.bias)
+
+       # Note: We use weight tying - lm_head shares weights with wte
+       # Instead of a separate Linear layer, we use wte.weight directly in forward
+
+       # Initialize weights
+       apply(method(:_init_weights))
+
+       # Special scaled init for residual projections (per GPT-2 paper)
+       named_parameters.each do |name, param|
+         if name.end_with?("c_proj.weight")
+           Torch::NN::Init.normal!(param, mean: 0.0, std: 0.02 / Math.sqrt(2 * config.n_layer))
+         end
+       end
+
+       puts format("number of parameters: %.2fM", num_params / 1e6)
+     end
+
+     def num_params(non_embedding: true)
+       n_params = parameters.sum(&:numel)
+       n_params -= @wpe.weight.numel if non_embedding
+       n_params
+     end
+
+     # Estimate model flops utilization (MFU)
+     # See PaLM paper Appendix B: https://arxiv.org/abs/2204.02311
+     def estimate_mfu(fwdbwd_per_iter, dt)
+       n = num_params
+       cfg = @config
+       l = cfg.n_layer
+       h = cfg.n_head
+       q = cfg.n_embd / cfg.n_head
+       t = cfg.block_size
+
+       # FLOPs per token and per forward-backward pass
+       flops_per_token = 6 * n + 12 * l * h * q * t
+       flops_per_fwdbwd = flops_per_token * t
+       flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
+
+       # Express throughput as ratio of A100 bfloat16 peak FLOPS (312 TFLOPS)
+       flops_achieved = flops_per_iter / dt
+       flops_promised = 312e12
+       flops_achieved / flops_promised
+     end
+
+     def forward(idx, targets: nil)
+       b, t = idx.shape
+       raise ArgumentError, "Sequence length #{t} exceeds block_size #{@config.block_size}" if t > @config.block_size
+
+       device = idx.device
+
+       # Position indices
+       pos = Torch.arange(0, t, dtype: :long, device: device)
+
+       # Embeddings
+       tok_emb = @wte.call(idx) # (B, T, n_embd)
+       pos_emb = @wpe.call(pos) # (T, n_embd)
+       x = @drop.call(tok_emb + pos_emb)
+
+       # Transformer blocks
+       @h.each { |block| x = block.call(x) }
+
+       # Final layer norm
+       x = @ln_f.call(x)
+
+       if targets
+         # Training: compute logits for all positions using tied weights
+         logits = Torch::NN::Functional.linear(x, @wte.weight, nil)
+         loss = Torch::NN::Functional.cross_entropy(
+           logits.view(-1, logits.size(-1)),
+           targets.view(-1),
+           ignore_index: -1
+         )
+       else
+         # Inference: only compute logits for last position (optimization)
+         # Use narrow to get last position: x[:, -1:, :]
+         x_last = x.narrow(1, x.size(1) - 1, 1)
+         logits = Torch::NN::Functional.linear(x_last, @wte.weight, nil)
+         loss = nil
+       end
+
+       [logits, loss]
+     end
+
+     def generate(idx, max_new_tokens, temperature: 1.0, top_k: nil)
+       Torch.no_grad do
+         max_new_tokens.times do
+           # Crop context if exceeds block_size
+           idx_cond = if idx.size(1) <= @config.block_size
+                        idx
+                      else
+                        idx.narrow(1, idx.size(1) - @config.block_size, @config.block_size)
+                      end
+
+           # Forward pass
+           logits, _loss = forward(idx_cond)
+
+           # Get logits for last position and scale by temperature
+           # logits shape is (B, 1, vocab_size), squeeze to (B, vocab_size)
+           logits = logits.squeeze(1) / temperature
+
+           # Optional top-k filtering
+           if top_k
+             k = [top_k, logits.size(-1)].min
+             v, _indices = logits.topk(k)
+             # Get the k-th largest value as threshold
+             threshold = v.narrow(1, k - 1, 1)
+             # Mask out values below threshold
+             logits = logits.masked_fill(logits.lt(threshold), -Float::INFINITY)
+           end
+
+           # Sample from probability distribution
+           probs = Torch::NN::Functional.softmax(logits, dim: -1)
+           idx_next = Torch.multinomial(probs, num_samples: 1)
+
+           # Append to sequence
+           idx = Torch.cat([idx, idx_next], dim: 1)
+         end
+       end
+
+       idx
+     end
+
+     def crop_block_size(block_size)
+       raise ArgumentError, "Cannot crop to larger block_size" if block_size > @config.block_size
+
+       @config.block_size = block_size
+
+       # Create new embedding with cropped weights
+       new_wpe = Torch::NN::Embedding.new(block_size, @config.n_embd)
+       Torch.no_grad do
+         new_wpe.weight.copy!(@wpe.weight[0...block_size])
+       end
+       @wpe = new_wpe
+
+       # Update attention masks in all blocks
+       @h.each do |block|
+         attn = block.instance_variable_get(:@attn)
+         next unless attn.instance_variable_defined?(:@mask)
+
+         mask = attn.instance_variable_get(:@mask)
+         attn.instance_variable_set(:@mask, mask[nil, nil, 0...block_size, 0...block_size])
+       end
+     end
+
+     def configure_optimizers(weight_decay:, learning_rate:, betas:, device_type:)
+       # Separate parameters into decay and no-decay groups
+       # All 2D+ params (weights) get weight decay, 1D params (biases, layernorm) don't
+       decay_params = []
+       nodecay_params = []
+
+       parameters.each do |param|
+         next unless param.requires_grad
+
+         if param.dim >= 2
+           decay_params << param
+         else
+           nodecay_params << param
+         end
+       end
+
+       num_decay = decay_params.sum(&:numel)
+       num_nodecay = nodecay_params.sum(&:numel)
+       puts "num decayed parameter tensors: #{decay_params.size}, with #{num_decay} parameters"
+       puts "num non-decayed parameter tensors: #{nodecay_params.size}, with #{num_nodecay} parameters"
+
+       # Create optimizer with parameter groups (using symbol keys for torch.rb)
+       Torch::Optim::AdamW.new(
+         [
+           { params: decay_params, weight_decay: weight_decay },
+           { params: nodecay_params, weight_decay: 0.0 }
+         ],
+         lr: learning_rate,
+         betas: betas
+       )
+     end
+
+     private
+
+     def _init_weights(mod)
+       case mod
+       when Torch::NN::Linear
+         Torch::NN::Init.normal!(mod.weight, mean: 0.0, std: 0.02)
+         # Check if bias exists (it won't when bias: false)
+         if mod.instance_variable_defined?(:@bias) && mod.instance_variable_get(:@bias)
+           Torch::NN::Init.zeros!(mod.instance_variable_get(:@bias))
+         end
+       when Torch::NN::Embedding
+         Torch::NN::Init.normal!(mod.weight, mean: 0.0, std: 0.02)
+       end
+     end
+   end
+ end
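
Not part of the published diff: a rough usage sketch of the GPT class above. It assumes torch-rb plus the gem's Layers::Block and Layers::LayerNorm classes (referenced but not shown in this hunk), and uses a hypothetical Struct as a stand-in for the gem's config class, which is also not part of this excerpt. Tensors are built with Torch.tensor to keep the example self-contained; #forward is called directly to pass the targets keyword.

    require "torch"

    # Hypothetical config stand-in covering only the fields GPT reads.
    Config = Struct.new(
      :vocab_size, :block_size, :n_layer, :n_head, :n_embd, :dropout, :bias,
      keyword_init: true
    )
    config = Config.new(vocab_size: 65, block_size: 256, n_layer: 6, n_head: 6,
                        n_embd: 384, dropout: 0.0, bias: false)

    model = NanoGPT::GPT.new(config)

    # Training-style step: forward returns [logits, loss] when targets are given.
    idx     = Torch.tensor([[1, 2, 3, 4]], dtype: :long)
    targets = Torch.tensor([[2, 3, 4, 5]], dtype: :long)
    _logits, loss = model.forward(idx, targets: targets)

    # Generation: start from a single token id and sample 20 more.
    start = Torch.tensor([[0]], dtype: :long)
    out = model.generate(start, 20, temperature: 0.8, top_k: 50)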
@@ -0,0 +1,106 @@
+ # frozen_string_literal: true
+
+ require "json"
+
+ module NanoGPT
+   # Base tokenizer interface
+   class Tokenizer
+     attr_reader :vocab_size
+
+     def encode(text)
+       raise NotImplementedError
+     end
+
+     def decode(ids)
+       raise NotImplementedError
+     end
+
+     # Auto-detect and load the appropriate tokenizer
+     # If meta.json exists, use character-level; otherwise use GPT-2 BPE
+     def self.for_dataset(dataset_dir)
+       meta_path = File.join(dataset_dir, "meta.json")
+       if File.exist?(meta_path)
+         CharTokenizer.from_file(meta_path)
+       else
+         GPT2Tokenizer.new
+       end
+     end
+   end
+
+   # Character-level tokenizer
+   class CharTokenizer < Tokenizer
+     attr_reader :stoi, :itos
+
+     def initialize(stoi: nil, itos: nil)
+       super()
+       @stoi = stoi || {}
+       @itos = itos || {}
+       @vocab_size = @stoi.size
+     end
+
+     # Build vocabulary from text
+     def self.from_text(text)
+       chars = text.chars.uniq.sort
+       stoi = chars.each_with_index.to_h
+       itos = chars.each_with_index.map { |c, i| [i, c] }.to_h
+       new(stoi: stoi, itos: itos)
+     end
+
+     # Load from meta.json file
+     def self.from_file(path)
+       meta = JSON.parse(File.read(path))
+       # Convert string keys to integers for itos
+       itos = meta["itos"].transform_keys(&:to_i)
+       new(stoi: meta["stoi"], itos: itos)
+     end
+
+     # Encode string to list of integers
+     def encode(text)
+       text.chars.map { |c| @stoi[c] }
+     end
+
+     # Decode list of integers to string
+     def decode(ids)
+       ids.map { |i| @itos[i] }.join
+     end
+
+     # Save to meta.json file
+     def save(path)
+       meta = {
+         "vocab_size" => @vocab_size,
+         "stoi" => @stoi,
+         "itos" => @itos.transform_keys(&:to_s)
+       }
+       File.write(path, JSON.pretty_generate(meta))
+     end
+   end
+
+   # GPT-2 BPE tokenizer using tiktoken
+   class GPT2Tokenizer < Tokenizer
+     GPT2_VOCAB_SIZE = 50257
+     EOT_TOKEN = "<|endoftext|>"
+
+     def initialize
+       super()
+       require "tiktoken_ruby"
+       # GPT-2 uses the r50k_base encoding
+       @enc = Tiktoken.get_encoding(:r50k_base)
+       @vocab_size = GPT2_VOCAB_SIZE
+     end
+
+     # Encode string to list of integers
+     def encode(text)
+       @enc.encode(text)
+     end
+
+     # Decode list of integers to string
+     def decode(ids)
+       @enc.decode(ids)
+     end
+
+     # Get the end-of-text token ID
+     def eot_token
+       @enc.encode(EOT_TOKEN).first
+     end
+   end
+ end
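
Not part of the published diff: a minimal round-trip sketch of the tokenizers above, assuming the classes are loaded as defined (and, for GPT2Tokenizer, that the tiktoken_ruby gem is installed).

    # Character-level tokenizer built from raw text.
    tok = NanoGPT::CharTokenizer.from_text("hello world")
    ids = tok.encode("hello")   # => [3, 2, 4, 4, 5] for this vocabulary (sorted unique chars)
    tok.decode(ids)             # => "hello"
    tok.save("meta.json")       # persists stoi/itos so Tokenizer.for_dataset can pick it up

    # GPT-2 BPE tokenizer.
    bpe = NanoGPT::GPT2Tokenizer.new
    bpe.decode(bpe.encode("hello world"))  # => "hello world"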
@@ -0,0 +1,259 @@
+ # frozen_string_literal: true
+
+ require "json"
+
+ module NanoGPT
+   # Configuration system for training and sampling
+   # Supports JSON config files with command-line overrides
+   #
+   # Priority (highest to lowest):
+   # 1. Command-line arguments (--key=value)
+   # 2. JSON config file (--config=path.json)
+   # 3. Default values
+   #
+   # Usage:
+   #   config = TrainConfig.load(ARGV)
+   #   config[:learning_rate] # => 0.001
+   #
+   class TrainConfig
+     # Defaults match bin/train exactly
+     DEFAULTS = {
+       # I/O
+       out_dir: "out-shakespeare-char",
+       eval_interval: 250,
+       log_interval: 10,
+       eval_iters: 200,
+       eval_only: false,
+       always_save_checkpoint: false,
+       init_from: "scratch", # 'scratch' or 'resume'
+
+       # Data
+       dataset: "shakespeare_char",
+       batch_size: 64,
+       block_size: 256,
+       gradient_accumulation_steps: 1,
+
+       # Model
+       n_layer: 6,
+       n_head: 6,
+       n_embd: 384,
+       dropout: 0.2,
+       bias: false,
+
+       # Optimizer
+       learning_rate: 1e-3,
+       weight_decay: 1e-1,
+       beta1: 0.9,
+       beta2: 0.99,
+       grad_clip: 1.0,
+
+       # LR scheduler
+       decay_lr: true,
+       warmup_iters: 100,
+       lr_decay_iters: 5000,
+       min_lr: 1e-4,
+
+       # Training
+       max_iters: 5000,
+
+       # System
+       device: "auto"
+     }.freeze
+
+     attr_reader :values
+
+     def initialize(values = {})
+       @values = DEFAULTS.merge(values)
+     end
+
+     def [](key)
+       @values[key.to_sym]
+     end
+
+     def []=(key, value)
+       @values[key.to_sym] = value
+     end
+
+     def to_h
+       @values.dup
+     end
+
+     # Load config from command-line args
+     # Supports:
+     #   --config=path/to/config.json (load JSON file)
+     #   --key=value (override specific values)
+     def self.load(args)
+       config = new
+
+       # First pass: find and load JSON config file
+       config_file = nil
+       args.each do |arg|
+         if arg.start_with?("--config=")
+           config_file = arg.split("=", 2).last
+           break
+         end
+       end
+
+       if config_file
+         config.load_json(config_file)
+       end
+
+       # Second pass: apply command-line overrides
+       args.each do |arg|
+         next unless arg.start_with?("--") && arg.include?("=")
+         next if arg.start_with?("--config=")
+
+         key, val = arg[2..].split("=", 2)
+         key = key.to_sym
+
+         unless config.values.key?(key)
+           puts "Warning: Unknown config key: #{key}"
+           next
+         end
+
+         config[key] = parse_value(val, config[key])
+         puts "Override: #{key} = #{config[key]}"
+       end
+
+       config
+     end
+
+     # Load values from JSON file
+     def load_json(path)
+       unless File.exist?(path)
+         raise "Config file not found: #{path}"
+       end
+
+       json = JSON.parse(File.read(path))
+       puts "Loaded config from #{path}"
+
+       json.each do |key, val|
+         key = key.to_sym
+         unless @values.key?(key)
+           puts "Warning: Unknown config key in JSON: #{key}"
+           next
+         end
+         @values[key] = val
+       end
+
+       self
+     end
+
+     # Save current config to JSON file
+     def save_json(path)
+       File.write(path, JSON.pretty_generate(@values))
+       puts "Saved config to #{path}"
+     end
+
+     private
+
+     def self.parse_value(val, existing)
+       case existing
+       when Integer then val.to_i
+       when Float then val.to_f
+       when TrueClass, FalseClass then val.downcase == "true"
+       else val
+       end
+     end
+   end
+
+   # Configuration for sampling/generation
+   class SampleConfig
+     # Defaults match bin/sample exactly
+     DEFAULTS = {
+       out_dir: "out-shakespeare-char",
+       dataset: "shakespeare_char",
+       start: "\n",
+       num_samples: 5,
+       max_new_tokens: 500,
+       temperature: 0.8,
+       top_k: 200,
+       seed: 1337,
+       device: "auto"
+     }.freeze
+
+     attr_reader :values
+
+     def initialize(values = {})
+       @values = DEFAULTS.merge(values)
+     end
+
+     def [](key)
+       @values[key.to_sym]
+     end
+
+     def []=(key, value)
+       @values[key.to_sym] = value
+     end
+
+     def to_h
+       @values.dup
+     end
+
+     def self.load(args)
+       config = new
+
+       # First pass: find and load JSON config file
+       config_file = nil
+       args.each do |arg|
+         if arg.start_with?("--config=")
+           config_file = arg.split("=", 2).last
+           break
+         end
+       end
+
+       if config_file
+         config.load_json(config_file)
+       end
+
+       # Second pass: apply command-line overrides
+       args.each do |arg|
+         next unless arg.start_with?("--") && arg.include?("=")
+         next if arg.start_with?("--config=")
+
+         key, val = arg[2..].split("=", 2)
+         key = key.to_sym
+
+         unless config.values.key?(key)
+           puts "Warning: Unknown config key: #{key}"
+           next
+         end
+
+         config[key] = parse_value(val, config[key])
+       end
+
+       config
+     end
+
+     def load_json(path)
+       unless File.exist?(path)
+         raise "Config file not found: #{path}"
+       end
+
+       json = JSON.parse(File.read(path))
+       puts "Loaded config from #{path}"
+
+       json.each do |key, val|
+         key = key.to_sym
+         unless @values.key?(key)
+           puts "Warning: Unknown config key in JSON: #{key}"
+           next
+         end
+         @values[key] = val
+       end
+
+       self
+     end
+
+     private
+
+     def self.parse_value(val, existing)
+       case existing
+       when Integer then val.to_i
+       when Float then val.to_f
+       when TrueClass, FalseClass then val.downcase == "true"
+       else val
+       end
+     end
+   end
+ end
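
Not part of the published diff: a short sketch of the override behaviour implemented above, assuming the classes are loaded as defined. Values from the command line are coerced to the type of the corresponding default, and keys not present in DEFAULTS only produce a warning. The argument strings below are illustrative.

    # Defaults only.
    config = NanoGPT::TrainConfig.new
    config[:learning_rate]  # => 0.001
    config[:n_layer]        # => 6

    # Command-line overrides on top of the defaults (no JSON file involved).
    config = NanoGPT::TrainConfig.load(["--batch_size=32", "--decay_lr=false", "--max_iters=2000"])
    config[:batch_size]  # => 32     (coerced to Integer because the default is an Integer)
    config[:decay_lr]    # => false  (coerced to a boolean)
    config[:out_dir]     # => "out-shakespeare-char" (default retained)

    # Persist the resolved settings for reproducibility.
    config.save_json("resolved_config.json")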