nanogpt 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/Gemfile +7 -0
- data/Gemfile.lock +42 -0
- data/README.md +102 -0
- data/bin/bench +210 -0
- data/bin/sample +76 -0
- data/bin/train +82 -0
- data/config/train_gpt2.json +19 -0
- data/config/train_shakespeare.json +14 -0
- data/config/train_shakespeare_char.json +14 -0
- data/data/openwebtext/prepare.rb +287 -0
- data/data/shakespeare/prepare.rb +61 -0
- data/data/shakespeare_char/input.txt +40000 -0
- data/data/shakespeare_char/prepare.rb +73 -0
- data/exe/nanogpt +338 -0
- data/lib/nano_gpt/config.rb +42 -0
- data/lib/nano_gpt/data_loader.rb +74 -0
- data/lib/nano_gpt/device.rb +56 -0
- data/lib/nano_gpt/layers/block.rb +25 -0
- data/lib/nano_gpt/layers/causal_self_attention.rb +73 -0
- data/lib/nano_gpt/layers/layer_norm.rb +21 -0
- data/lib/nano_gpt/layers/mlp.rb +23 -0
- data/lib/nano_gpt/lr_scheduler.rb +42 -0
- data/lib/nano_gpt/model.rb +218 -0
- data/lib/nano_gpt/tokenizer.rb +106 -0
- data/lib/nano_gpt/train_config.rb +259 -0
- data/lib/nano_gpt/trainer.rb +221 -0
- data/lib/nano_gpt/version.rb +5 -0
- data/lib/nano_gpt.rb +18 -0
- data/nanogpt.gemspec +37 -0
- metadata +133 -0
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module NanoGPT
  module Layers
    # Position-wise feed-forward sub-layer of a transformer block:
    # widen n_embd -> 4 * n_embd, apply GELU, project back to n_embd,
    # then apply dropout.
    class MLP < Torch::NN::Module
      # @param config [#n_embd, #bias, #dropout] model hyper-parameters
      def initialize(config)
        super()
        width = config.n_embd
        hidden_width = 4 * width
        # Sub-module ivar names (c_fc, gelu, c_proj, dropout) and their order are
        # kept stable: GPT matches parameter names ending in "c_proj.weight".
        @c_fc = Torch::NN::Linear.new(width, hidden_width, bias: config.bias)
        @gelu = Torch::NN::GELU.new
        @c_proj = Torch::NN::Linear.new(hidden_width, width, bias: config.bias)
        @dropout = Torch::NN::Dropout.new(p: config.dropout)
      end

      # @param x [Torch::Tensor] input whose last dimension is n_embd
      # @return [Torch::Tensor] output with the same shape as x
      def forward(x)
        activated = @gelu.call(@c_fc.call(x))
        @dropout.call(@c_proj.call(activated))
      end
    end
  end
end
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module NanoGPT
  # Cosine learning-rate schedule with linear warmup.
  #
  # * iter < warmup_iters:                    linear ramp, learning_rate * (iter + 1) / (warmup_iters + 1)
  # * warmup_iters <= iter <= lr_decay_iters: half-cosine decay from learning_rate down to min_lr
  # * iter > lr_decay_iters:                  constant min_lr
  class LRScheduler
    attr_reader :learning_rate, :min_lr, :warmup_iters, :lr_decay_iters

    # @param learning_rate [Numeric] peak learning rate reached at the end of warmup
    # @param min_lr [Numeric] floor the cosine decay ends at
    # @param warmup_iters [Integer] number of warmup iterations
    # @param lr_decay_iters [Integer] iteration at which the decay reaches min_lr
    def initialize(learning_rate:, min_lr:, warmup_iters:, lr_decay_iters:)
      @learning_rate = learning_rate
      @min_lr = min_lr
      @warmup_iters = warmup_iters
      @lr_decay_iters = lr_decay_iters
    end

    # Learning rate for a given (zero-based) iteration.
    #
    # @param iter [Integer]
    # @return [Float]
    def get_lr(iter)
      # 1) Linear warmup for warmup_iters steps
      return @learning_rate * (iter + 1).to_f / (@warmup_iters + 1) if iter < @warmup_iters

      # 2) Past the decay horizon, hold at the floor
      return @min_lr if iter > @lr_decay_iters

      # 3) Cosine decay down to min_lr. With an empty decay window
      #    (lr_decay_iters == warmup_iters) the ratio would be 0.0 / 0 = NaN,
      #    so treat the decay as already complete.
      decay_span = @lr_decay_iters - @warmup_iters
      decay_ratio = decay_span.positive? ? (iter - @warmup_iters).to_f / decay_span : 1.0
      coeff = 0.5 * (1.0 + Math.cos(Math::PI * decay_ratio))
      @min_lr + coeff * (@learning_rate - @min_lr)
    end

    # Write the learning rate for iter into every optimizer param group.
    #
    # @param optimizer [#param_groups] optimizer whose groups are hashes keyed by :lr
    # @param iter [Integer]
    # @return [Float] the learning rate that was applied
    def step(optimizer, iter)
      lr = get_lr(iter)
      optimizer.param_groups.each do |group|
        group[:lr] = lr
      end
      lr
    end
  end
end
|
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module NanoGPT
  # GPT language model: token + learned position embeddings, a stack of
  # transformer blocks, a final layer norm, and an output projection that
  # reuses (ties) the token-embedding matrix.
  class GPT < Torch::NN::Module
    attr_reader :config

    # @param config [#vocab_size, #block_size, #n_embd, #n_layer, #n_head, #dropout, #bias]
    #   model hyper-parameters; vocab_size and block_size must be set
    # @raise [ArgumentError] if vocab_size or block_size is missing
    def initialize(config)
      super()
      raise ArgumentError, "vocab_size must be set" unless config.vocab_size
      raise ArgumentError, "block_size must be set" unless config.block_size

      @config = config

      # Token and position embeddings
      @wte = Torch::NN::Embedding.new(config.vocab_size, config.n_embd)
      @wpe = Torch::NN::Embedding.new(config.block_size, config.n_embd)
      @drop = Torch::NN::Dropout.new(p: config.dropout)

      # Transformer blocks
      @h = Torch::NN::ModuleList.new(
        config.n_layer.times.map { Layers::Block.new(config) }
      )

      # Final layer norm
      @ln_f = Layers::LayerNorm.new(config.n_embd, bias: config.bias)

      # Weight tying: there is no separate lm_head module; forward projects the
      # hidden states onto @wte.weight directly.

      # N(0, 0.02) init for every Linear / Embedding (see _init_weights)
      apply(method(:_init_weights))

      # Residual projections get a smaller std scaled by 1/sqrt(2 * n_layer) (per GPT-2 paper)
      named_parameters.each do |name, param|
        if name.end_with?("c_proj.weight")
          Torch::NN::Init.normal!(param, mean: 0.0, std: 0.02 / Math.sqrt(2 * config.n_layer))
        end
      end

      puts format("number of parameters: %.2fM", num_params / 1e6)
    end

    # Number of parameters in the model.
    #
    # @param non_embedding [Boolean] when true, exclude the position-embedding
    #   weights (token embeddings stay counted because they double as the output layer)
    # @return [Integer]
    def num_params(non_embedding: true)
      n_params = parameters.sum(&:numel)
      n_params -= @wpe.weight.numel if non_embedding
      n_params
    end

    # Estimate model flops utilization (MFU) as a fraction of A100 bfloat16
    # peak (312 TFLOPS). See PaLM paper Appendix B: https://arxiv.org/abs/2204.02311
    #
    # @param fwdbwd_per_iter [Integer] forward+backward passes per iteration
    # @param dt [Float] wall-clock seconds per iteration
    # @return [Float]
    def estimate_mfu(fwdbwd_per_iter, dt)
      n = num_params
      cfg = @config
      l = cfg.n_layer
      h = cfg.n_head
      q = cfg.n_embd / cfg.n_head
      t = cfg.block_size

      # FLOPs per token and per forward-backward pass
      flops_per_token = 6 * n + 12 * l * h * q * t
      flops_per_fwdbwd = flops_per_token * t
      flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter

      # Express throughput as ratio of A100 bfloat16 peak FLOPS (312 TFLOPS)
      flops_achieved = flops_per_iter / dt
      flops_promised = 312e12
      flops_achieved / flops_promised
    end

    # Run the model.
    #
    # @param idx [Torch::Tensor] (B, T) token ids with T <= block_size
    # @param targets [Torch::Tensor, nil] target ids reshaped with view(-1);
    #   entries equal to -1 are ignored by the loss
    # @return [Array(Torch::Tensor, Torch::Tensor)] [logits, loss] when targets
    #   are given (logits for every position), otherwise [logits, nil] with
    #   logits for the last position only, shape (B, 1, vocab_size)
    # @raise [ArgumentError] if T exceeds block_size
    def forward(idx, targets: nil)
      _batch, t = idx.shape
      raise ArgumentError, "Sequence length #{t} exceeds block_size #{@config.block_size}" if t > @config.block_size

      device = idx.device

      # Position indices
      pos = Torch.arange(0, t, dtype: :long, device: device)

      # Embeddings
      tok_emb = @wte.call(idx) # (B, T, n_embd)
      pos_emb = @wpe.call(pos) # (T, n_embd)
      x = @drop.call(tok_emb + pos_emb)

      # Transformer blocks
      @h.each { |block| x = block.call(x) }

      # Final layer norm
      x = @ln_f.call(x)

      if targets
        # Training: compute logits for all positions using tied weights
        logits = Torch::NN::Functional.linear(x, @wte.weight, nil)
        loss = Torch::NN::Functional.cross_entropy(
          logits.view(-1, logits.size(-1)),
          targets.view(-1),
          ignore_index: -1
        )
      else
        # Inference: only compute logits for the last position: x[:, -1:, :]
        x_last = x.narrow(1, x.size(1) - 1, 1)
        logits = Torch::NN::Functional.linear(x_last, @wte.weight, nil)
        loss = nil
      end

      [logits, loss]
    end

    # Autoregressively sample max_new_tokens tokens and append them to idx.
    #
    # @param idx [Torch::Tensor] (B, T) conditioning token ids
    # @param max_new_tokens [Integer]
    # @param temperature [Float] logits are divided by this (expected > 0)
    # @param top_k [Integer, nil] keep only the k most likely tokens; nil or a
    #   non-positive value disables the filter
    # @return [Torch::Tensor] (B, T + max_new_tokens) token ids
    def generate(idx, max_new_tokens, temperature: 1.0, top_k: nil)
      Torch.no_grad do
        max_new_tokens.times do
          # Crop context if it exceeds block_size
          idx_cond = if idx.size(1) <= @config.block_size
                       idx
                     else
                       idx.narrow(1, idx.size(1) - @config.block_size, @config.block_size)
                     end

          # Forward pass
          logits, _loss = forward(idx_cond)

          # logits shape is (B, 1, vocab_size); squeeze to (B, vocab_size) and scale by temperature
          logits = logits.squeeze(1) / temperature

          # Optional top-k filtering. A non-positive k would make the narrow
          # below start at a negative offset, so it is treated as "disabled".
          if top_k&.positive?
            k = [top_k, logits.size(-1)].min
            v, _indices = logits.topk(k)
            # The k-th largest value per row is the threshold
            threshold = v.narrow(1, k - 1, 1)
            # Mask out values below the threshold
            logits = logits.masked_fill(logits.lt(threshold), -Float::INFINITY)
          end

          # Sample from the probability distribution
          probs = Torch::NN::Functional.softmax(logits, dim: -1)
          idx_next = Torch.multinomial(probs, num_samples: 1)

          # Append to sequence
          idx = Torch.cat([idx, idx_next], dim: 1)
        end
      end

      idx
    end

    # Shrink the context window to block_size: updates config.block_size,
    # keeps the first block_size rows of the position embedding, and slices
    # the @mask of each block's @attn sub-module when one is defined.
    #
    # NOTE(review): the mask is sliced with [nil, nil, range, range]; whether
    # nil acts as a full slice or inserts a new axis in torch-rb, and whether
    # @mask is also tracked as a registered buffer, depends on torch-rb and
    # CausalSelfAttention — confirm.
    #
    # @param block_size [Integer] new context length (<= current block_size)
    # @raise [ArgumentError] if block_size is larger than the current one
    def crop_block_size(block_size)
      raise ArgumentError, "Cannot crop to larger block_size" if block_size > @config.block_size

      @config.block_size = block_size

      # Create new embedding with cropped weights
      new_wpe = Torch::NN::Embedding.new(block_size, @config.n_embd)
      Torch.no_grad do
        new_wpe.weight.copy!(@wpe.weight[0...block_size])
      end
      @wpe = new_wpe

      # Update attention masks in all blocks
      @h.each do |block|
        attn = block.instance_variable_get(:@attn)
        next unless attn.instance_variable_defined?(:@mask)

        mask = attn.instance_variable_get(:@mask)
        attn.instance_variable_set(:@mask, mask[nil, nil, 0...block_size, 0...block_size])
      end
    end

    # Build an AdamW optimizer with two parameter groups: tensors with 2+
    # dimensions get weight decay; 1-D tensors (biases, norm weights) do not.
    #
    # @param weight_decay [Float]
    # @param learning_rate [Float]
    # @param betas [Array(Float, Float)]
    # @param device_type [String] accepted for interface compatibility; unused here
    # @return [Torch::Optim::AdamW]
    def configure_optimizers(weight_decay:, learning_rate:, betas:, device_type:)
      decay_params = []
      nodecay_params = []

      parameters.each do |param|
        next unless param.requires_grad

        if param.dim >= 2
          decay_params << param
        else
          nodecay_params << param
        end
      end

      num_decay = decay_params.sum(&:numel)
      num_nodecay = nodecay_params.sum(&:numel)
      puts "num decayed parameter tensors: #{decay_params.size}, with #{num_decay} parameters"
      puts "num non-decayed parameter tensors: #{nodecay_params.size}, with #{num_nodecay} parameters"

      # Parameter groups use symbol keys for torch.rb
      Torch::Optim::AdamW.new(
        [
          { params: decay_params, weight_decay: weight_decay },
          { params: nodecay_params, weight_decay: 0.0 }
        ],
        lr: learning_rate,
        betas: betas
      )
    end

    private

    # Per-module initializer passed to apply: N(0, 0.02) weights for Linear
    # and Embedding, zero biases for Linear layers that have one.
    def _init_weights(mod)
      case mod
      when Torch::NN::Linear
        Torch::NN::Init.normal!(mod.weight, mean: 0.0, std: 0.02)
        # Check if bias exists (it won't when bias: false)
        if mod.instance_variable_defined?(:@bias) && mod.instance_variable_get(:@bias)
          Torch::NN::Init.zeros!(mod.instance_variable_get(:@bias))
        end
      when Torch::NN::Embedding
        Torch::NN::Init.normal!(mod.weight, mean: 0.0, std: 0.02)
      end
    end
  end
end
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module NanoGPT
  # Base tokenizer interface. Subclasses implement encode/decode and set @vocab_size.
  class Tokenizer
    attr_reader :vocab_size

    # @param text [String]
    # @return [Array<Integer>] token ids
    def encode(text)
      raise NotImplementedError
    end

    # @param ids [Array<Integer>]
    # @return [String]
    def decode(ids)
      raise NotImplementedError
    end

    # Auto-detect and load the appropriate tokenizer for a dataset directory:
    # character-level when it contains meta.json, otherwise GPT-2 BPE.
    #
    # @param dataset_dir [String]
    # @return [Tokenizer]
    def self.for_dataset(dataset_dir)
      meta_path = File.join(dataset_dir, "meta.json")
      if File.exist?(meta_path)
        CharTokenizer.from_file(meta_path)
      else
        GPT2Tokenizer.new
      end
    end
  end

  # Character-level tokenizer backed by string->id (stoi) and id->string (itos) tables.
  class CharTokenizer < Tokenizer
    attr_reader :stoi, :itos

    # @param stoi [Hash{String=>Integer}, nil] character -> id
    # @param itos [Hash{Integer=>String}, nil] id -> character; derived from
    #   stoi when omitted so that decode works with only one table supplied
    def initialize(stoi: nil, itos: nil)
      super()
      @stoi = stoi || {}
      @itos = itos || @stoi.invert
      @vocab_size = @stoi.size
    end

    # Build a vocabulary from the sorted unique characters of text.
    #
    # @param text [String]
    # @return [CharTokenizer]
    def self.from_text(text)
      chars = text.chars.uniq.sort
      stoi = chars.each_with_index.to_h
      itos = chars.each_with_index.map { |c, i| [i, c] }.to_h
      new(stoi: stoi, itos: itos)
    end

    # Load from a meta.json file written by #save.
    #
    # @param path [String]
    # @return [CharTokenizer]
    def self.from_file(path)
      meta = JSON.parse(File.read(path))
      # JSON object keys are strings; itos is keyed by integer id
      itos = meta["itos"].transform_keys(&:to_i)
      new(stoi: meta["stoi"], itos: itos)
    end

    # Encode a string to a list of ids.
    #
    # @param text [String]
    # @return [Array<Integer>]
    # @raise [ArgumentError] if text contains a character outside the vocabulary
    #   (previously such characters silently became nil ids)
    def encode(text)
      text.each_char.map do |c|
        @stoi.fetch(c) { raise ArgumentError, "Character #{c.inspect} is not in the vocabulary" }
      end
    end

    # Decode a list of ids to a string. Ids missing from itos contribute
    # nothing to the result (nil joins as an empty string).
    #
    # @param ids [Array<Integer>]
    # @return [String]
    def decode(ids)
      ids.map { |i| @itos[i] }.join
    end

    # Save the vocabulary to a meta.json file.
    #
    # @param path [String]
    def save(path)
      meta = {
        "vocab_size" => @vocab_size,
        "stoi" => @stoi,
        "itos" => @itos.transform_keys(&:to_s)
      }
      File.write(path, JSON.pretty_generate(meta))
    end
  end

  # GPT-2 BPE tokenizer using tiktoken (loaded lazily on construction).
  class GPT2Tokenizer < Tokenizer
    GPT2_VOCAB_SIZE = 50257
    EOT_TOKEN = "<|endoftext|>"
    # Id of <|endoftext|> in GPT-2's r50k_base vocabulary (its last id).
    EOT_TOKEN_ID = 50256

    def initialize
      super()
      require "tiktoken_ruby"
      # GPT-2 uses the r50k_base encoding
      @enc = Tiktoken.get_encoding(:r50k_base)
      @vocab_size = GPT2_VOCAB_SIZE
    end

    # Encode a string to a list of ids.
    def encode(text)
      @enc.encode(text)
    end

    # Decode a list of ids to a string.
    def decode(ids)
      @enc.decode(ids)
    end

    # The end-of-text token id.
    #
    # Encoding the literal "<|endoftext|>" is not used: tiktoken only treats
    # special-token text as a special token when explicitly allowed, so a plain
    # encode yields the ordinary pieces ("<", "|", ...) and .first is not EOT.
    #
    # @return [Integer]
    def eot_token
      EOT_TOKEN_ID
    end
  end
end
|
|
@@ -0,0 +1,259 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module NanoGPT
  # Shared machinery for the JSON + command-line configuration classes
  # (TrainConfig, SampleConfig).
  #
  # An including class must define a frozen DEFAULTS hash with symbol keys.
  # Priority (highest to lowest):
  #   1. Command-line arguments (--key=value)
  #   2. JSON config file (--config=path.json)
  #   3. Default values
  #
  # Command-line strings are converted to the type of the key's DEFAULTS entry
  # (Integer, Float, true/false, otherwise kept as a String). The type comes from
  # DEFAULTS rather than the current value, so a JSON file storing e.g.
  # "dropout": 0 cannot make a later --dropout=0.1 parse as an Integer.
  module ConfigLoading
    TRUE_STRINGS = %w[true yes 1].freeze
    FALSE_STRINGS = %w[false no 0].freeze

    def self.included(base)
      base.extend(ClassMethods)
    end

    # Class-level API added to including classes.
    module ClassMethods
      # Load config from command-line args.
      #
      # Supports:
      #   --config=path/to/config.json   (load JSON file; first occurrence wins)
      #   --key=value                    (override specific values)
      # Other arguments are ignored; unknown keys print a warning and are skipped.
      #
      # @param args [Array<String>] typically ARGV
      # @return a new instance of the including class
      # @raise [ArgumentError] if an override cannot be parsed as the key's type
      # @raise [RuntimeError] if the --config file does not exist
      def load(args)
        config = new

        # First pass: find and load JSON config file
        config_arg = args.find { |arg| arg.start_with?("--config=") }
        config.load_json(config_arg.split("=", 2).last) if config_arg

        # Second pass: apply command-line overrides
        args.each do |arg|
          next unless arg.start_with?("--") && arg.include?("=")
          next if arg.start_with?("--config=")

          key, val = arg[2..].split("=", 2)
          key = key.to_sym

          unless config.values.key?(key)
            puts "Warning: Unknown config key: #{key}"
            next
          end

          config[key] = parse_override(key, val, config[key])
          puts "Override: #{key} = #{config[key]}" if log_overrides?
        end

        config
      end

      private

      # Whether load echoes each applied override; including classes may redefine it.
      def log_overrides?
        false
      end

      # Parse one override using the DEFAULTS entry as the type template
      # (falling back to the current value for keys supplied only via new).
      def parse_override(key, val, current)
        parse_value(val, self::DEFAULTS.fetch(key, current))
      rescue ArgumentError => e
        raise ArgumentError, "Invalid value for --#{key}: #{e.message}"
      end

      # Convert a command-line string to the type of existing. Malformed
      # numbers raise instead of silently becoming 0.
      def parse_value(val, existing)
        case existing
        when Integer then Integer(val, 10)
        when Float then Float(val)
        when true, false then parse_bool(val)
        else val
        end
      end

      # Accepts true/yes/1 and false/no/0, case-insensitively.
      def parse_bool(val)
        normalized = val.strip.downcase
        return true if TRUE_STRINGS.include?(normalized)
        return false if FALSE_STRINGS.include?(normalized)

        raise ArgumentError, "expected true or false, got #{val.inspect}"
      end
    end

    attr_reader :values

    # @param values [Hash{Symbol=>Object}] values merged over DEFAULTS
    def initialize(values = {})
      @values = self.class::DEFAULTS.merge(values)
    end

    def [](key)
      @values[key.to_sym]
    end

    def []=(key, value)
      @values[key.to_sym] = value
    end

    # @return [Hash] a shallow copy of the current values
    def to_h
      @values.dup
    end

    # Load values from a JSON object file. Keys not already present are skipped
    # with a warning. Integer JSON numbers for keys whose default is a Float are
    # widened to Float (e.g. "dropout": 0 becomes 0.0).
    #
    # @param path [String]
    # @return [self]
    # @raise [RuntimeError] if path does not exist
    def load_json(path)
      raise "Config file not found: #{path}" unless File.exist?(path)

      json = JSON.parse(File.read(path))
      puts "Loaded config from #{path}"

      json.each do |key, val|
        key = key.to_sym
        unless @values.key?(key)
          puts "Warning: Unknown config key in JSON: #{key}"
          next
        end
        @values[key] = coerce_json_value(key, val)
      end

      self
    end

    private

    def coerce_json_value(key, val)
      return val.to_f if val.is_a?(Integer) && self.class::DEFAULTS[key].is_a?(Float)

      val
    end
  end

  # Configuration for training.
  #
  # Usage:
  #   config = TrainConfig.load(ARGV)
  #   config[:learning_rate] # => 0.001
  class TrainConfig
    include ConfigLoading

    # Defaults match bin/train exactly
    DEFAULTS = {
      # I/O
      out_dir: "out-shakespeare-char",
      eval_interval: 250,
      log_interval: 10,
      eval_iters: 200,
      eval_only: false,
      always_save_checkpoint: false,
      init_from: "scratch", # 'scratch' or 'resume'

      # Data
      dataset: "shakespeare_char",
      batch_size: 64,
      block_size: 256,
      gradient_accumulation_steps: 1,

      # Model
      n_layer: 6,
      n_head: 6,
      n_embd: 384,
      dropout: 0.2,
      bias: false,

      # Optimizer
      learning_rate: 1e-3,
      weight_decay: 1e-1,
      beta1: 0.9,
      beta2: 0.99,
      grad_clip: 1.0,

      # LR scheduler
      decay_lr: true,
      warmup_iters: 100,
      lr_decay_iters: 5000,
      min_lr: 1e-4,

      # Training
      max_iters: 5000,

      # System
      device: "auto"
    }.freeze

    # Save current config to JSON file
    #
    # @param path [String]
    def save_json(path)
      File.write(path, JSON.pretty_generate(@values))
      puts "Saved config to #{path}"
    end

    # Training echoes every command-line override it applies.
    def self.log_overrides?
      true
    end
    private_class_method :log_overrides?
  end

  # Configuration for sampling/generation
  class SampleConfig
    include ConfigLoading

    # Defaults match bin/sample exactly
    DEFAULTS = {
      out_dir: "out-shakespeare-char",
      dataset: "shakespeare_char",
      start: "\n",
      num_samples: 5,
      max_new_tokens: 500,
      temperature: 0.8,
      top_k: 200,
      seed: 1337,
      device: "auto"
    }.freeze
  end
end
|