nanogpt 0.2.0 → 0.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Gemfile.lock +30 -1
- data/docs/ARCHITECTURE.md +429 -0
- data/exe/nanogpt +210 -233
- data/lib/nano_gpt/bpe_textfile_preparer.rb +105 -0
- data/lib/nano_gpt/data_loader.rb +5 -20
- data/lib/nano_gpt/layers/block.rb +6 -1
- data/lib/nano_gpt/layers/causal_self_attention.rb +11 -1
- data/lib/nano_gpt/model.rb +1 -7
- data/lib/nano_gpt/textfile_preparer.rb +189 -0
- data/lib/nano_gpt/train_config.rb +80 -146
- data/lib/nano_gpt/trainer.rb +21 -48
- data/lib/nano_gpt/version.rb +1 -1
- data/lib/nano_gpt/web/metrics_store.rb +136 -0
- data/lib/nano_gpt/web/server.rb +294 -0
- data/lib/nano_gpt/web/sse_notifier.rb +37 -0
- data/lib/nano_gpt/web/training_state.rb +56 -0
- data/lib/nano_gpt/web/training_worker.rb +153 -0
- data/lib/nano_gpt/web/views/layout.erb +78 -0
- data/lib/nano_gpt/web/views/run_detail.erb +432 -0
- data/lib/nano_gpt/web/views/runs.erb +434 -0
- data/lib/nano_gpt/web/web_trainer.rb +210 -0
- data/lib/nano_gpt/web.rb +9 -0
- data/lib/nano_gpt.rb +1 -0
- data/nanogpt.gemspec +4 -0
- metadata +71 -2
data/lib/nano_gpt/data_loader.rb
CHANGED
|
@@ -4,9 +4,10 @@ module NanoGPT
|
|
|
4
4
|
# Loads batches from binary token files
|
|
5
5
|
# Memory-efficient: reads from file each batch (like Python's memmap recreation)
|
|
6
6
|
class DataLoader
|
|
7
|
-
|
|
7
|
+
# Each token is stored as uint16 (2 bytes) in little-endian format
|
|
8
|
+
BYTES_PER_TOKEN = 2
|
|
8
9
|
|
|
9
|
-
|
|
10
|
+
attr_reader :block_size, :batch_size, :train_size, :val_size
|
|
10
11
|
|
|
11
12
|
def initialize(data_dir:, block_size:, batch_size:, device: "cpu")
|
|
12
13
|
@data_dir = data_dir
|
|
@@ -14,7 +15,6 @@ module NanoGPT
|
|
|
14
15
|
@batch_size = batch_size
|
|
15
16
|
@device = device
|
|
16
17
|
|
|
17
|
-
# Store file paths and sizes (NOT the data itself)
|
|
18
18
|
@train_path = File.join(data_dir, "train.bin")
|
|
19
19
|
@val_path = File.join(data_dir, "val.bin")
|
|
20
20
|
|
|
@@ -22,47 +22,32 @@ module NanoGPT
|
|
|
22
22
|
@val_size = File.size(@val_path) / BYTES_PER_TOKEN
|
|
23
23
|
end
|
|
24
24
|
|
|
25
|
-
def train_size
|
|
26
|
-
@train_size
|
|
27
|
-
end
|
|
28
|
-
|
|
29
|
-
def val_size
|
|
30
|
-
@val_size
|
|
31
|
-
end
|
|
32
|
-
|
|
33
25
|
# Get a batch of data
|
|
34
|
-
# Memory-efficient:
|
|
35
|
-
# (matches Python's memmap recreation pattern)
|
|
26
|
+
# Memory-efficient: reads only needed bytes per batch to avoid memory leak
|
|
36
27
|
def get_batch(split)
|
|
37
28
|
path = split == :train ? @train_path : @val_path
|
|
38
29
|
data_size = split == :train ? @train_size : @val_size
|
|
39
30
|
|
|
40
|
-
# Random starting indices
|
|
41
31
|
max_start = data_size - @block_size - 1
|
|
42
32
|
indices = Array.new(@batch_size) { rand(0..max_start) }
|
|
43
33
|
|
|
44
|
-
# Read only the bytes we need from file (memory-efficient)
|
|
45
|
-
# This mimics Python's memmap recreation per batch
|
|
46
34
|
x_arrays = []
|
|
47
35
|
y_arrays = []
|
|
48
36
|
|
|
49
37
|
File.open(path, "rb") do |f|
|
|
50
38
|
indices.each do |i|
|
|
51
|
-
# Read x: tokens[i:i+block_size]
|
|
52
39
|
f.seek(i * BYTES_PER_TOKEN)
|
|
53
40
|
x_bytes = f.read((@block_size + 1) * BYTES_PER_TOKEN)
|
|
54
|
-
tokens = x_bytes.unpack("S<*")
|
|
41
|
+
tokens = x_bytes.unpack("S<*") # uint16 little-endian
|
|
55
42
|
|
|
56
43
|
x_arrays << tokens[0...@block_size]
|
|
57
44
|
y_arrays << tokens[1..@block_size]
|
|
58
45
|
end
|
|
59
46
|
end
|
|
60
47
|
|
|
61
|
-
# Create tensors directly from arrays (avoiding Numo intermediate)
|
|
62
48
|
x = Torch.tensor(x_arrays, dtype: :long)
|
|
63
49
|
y = Torch.tensor(y_arrays, dtype: :long)
|
|
64
50
|
|
|
65
|
-
# Move to device (CPU, CUDA, or MPS)
|
|
66
51
|
if @device != "cpu"
|
|
67
52
|
x = x.to(@device)
|
|
68
53
|
y = y.to(@device)
|
|
@@ -4,6 +4,8 @@ module NanoGPT
|
|
|
4
4
|
module Layers
|
|
5
5
|
# Transformer block: LayerNorm -> Attention -> LayerNorm -> MLP
|
|
6
6
|
class Block < Torch::NN::Module
|
|
7
|
+
attr_reader :attn
|
|
8
|
+
|
|
7
9
|
def initialize(config)
|
|
8
10
|
super()
|
|
9
11
|
@ln_1 = LayerNorm.new(config.n_embd, bias: config.bias)
|
|
@@ -16,10 +18,13 @@ module NanoGPT
|
|
|
16
18
|
x = x + @attn.call(@ln_1.call(x))
|
|
17
19
|
x = x + @mlp.call(@ln_2.call(x))
|
|
18
20
|
# Trigger GC to free intermediate tensors (critical for torch.rb memory management)
|
|
19
|
-
# Ruby's GC doesn't run frequently enough during forward pass, causing memory accumulation
|
|
20
21
|
GC.start(full_mark: false, immediate_sweep: true)
|
|
21
22
|
x
|
|
22
23
|
end
|
|
24
|
+
|
|
25
|
+
def crop_mask(block_size)
|
|
26
|
+
@attn.crop_mask(block_size)
|
|
27
|
+
end
|
|
23
28
|
end
|
|
24
29
|
end
|
|
25
30
|
end
|
|
@@ -27,7 +27,10 @@ module NanoGPT
|
|
|
27
27
|
|
|
28
28
|
# Causal mask for manual attention (only used when @flash is false)
|
|
29
29
|
unless @flash
|
|
30
|
-
|
|
30
|
+
# Torch.tril segfaults with torch-rb 0.22.2 + PyTorch 2.10, use equivalent
|
|
31
|
+
rows = Torch.arange(config.block_size).unsqueeze(1)
|
|
32
|
+
cols = Torch.arange(config.block_size).unsqueeze(0)
|
|
33
|
+
mask = rows.ge(cols).to(dtype: :float32)
|
|
31
34
|
register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
|
|
32
35
|
end
|
|
33
36
|
end
|
|
@@ -68,6 +71,13 @@ module NanoGPT
|
|
|
68
71
|
# Output projection
|
|
69
72
|
@resid_dropout.call(@c_proj.call(y))
|
|
70
73
|
end
|
|
74
|
+
|
|
75
|
+
# Crop the causal mask to a smaller block size (for inference optimization)
|
|
76
|
+
def crop_mask(block_size)
|
|
77
|
+
return unless defined?(@mask) && @mask
|
|
78
|
+
|
|
79
|
+
@mask = @mask[nil, nil, 0...block_size, 0...block_size]
|
|
80
|
+
end
|
|
71
81
|
end
|
|
72
82
|
end
|
|
73
83
|
end
|
data/lib/nano_gpt/model.rb
CHANGED
|
@@ -159,13 +159,7 @@ module NanoGPT
|
|
|
159
159
|
@wpe = new_wpe
|
|
160
160
|
|
|
161
161
|
# Update attention masks in all blocks
|
|
162
|
-
@h.each
|
|
163
|
-
attn = block.instance_variable_get(:@attn)
|
|
164
|
-
next unless attn.instance_variable_defined?(:@mask)
|
|
165
|
-
|
|
166
|
-
mask = attn.instance_variable_get(:@mask)
|
|
167
|
-
attn.instance_variable_set(:@mask, mask[nil, nil, 0...block_size, 0...block_size])
|
|
168
|
-
end
|
|
162
|
+
@h.each { |block| block.crop_mask(block_size) }
|
|
169
163
|
end
|
|
170
164
|
|
|
171
165
|
def configure_optimizers(weight_decay:, learning_rate:, betas:, device_type:)
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "numo/narray"
|
|
4
|
+
require "json"
|
|
5
|
+
require "fileutils"
|
|
6
|
+
require "set"
|
|
7
|
+
|
|
8
|
+
module NanoGPT
  # Prepares a custom text file for training with character-level tokenization.
  #
  # Writes three files under ./data/<output_name>/ (relative to the working
  # directory at construction time):
  #   train.bin - the first (1 - val_ratio) of the characters, one uint16 LE token each
  #   val.bin   - the remaining characters, same format
  #   meta.json - "vocab_size", "itos" (id string -> char) and "stoi" (char -> id)
  #
  # The input is streamed line by line (one vocabulary pass plus one pass per
  # output split), so the whole text is never held in memory at once.
  class TextfilePreparer
    # Number of tokens buffered in memory before each write to disk.
    BUFFER_SIZE = 100_000
    # Bytes sampled from the start of the file when guessing its encoding.
    ENCODING_SAMPLE_BYTES = 100_000
    # Tokens are stored as uint16, so ids must lie in 0...65_536.
    MAX_VOCAB_SIZE = 65_536
    # Approximate number of characters between progress lines in the vocabulary scan.
    PROGRESS_INTERVAL = 100_000

    attr_reader :input_path, :output_dir, :val_ratio

    # @param input_path [String] path to the source text file
    # @param output_name [String, nil] dataset name; derived from the file name when nil
    # @param val_ratio [Float] fraction of characters, taken from the end of the file, used for validation
    def initialize(input_path:, output_name: nil, val_ratio: 0.1)
      @input_path = input_path
      @val_ratio = val_ratio
      @output_name = output_name || derive_output_name(input_path)
      @output_dir = File.join(Dir.pwd, "data", @output_name)
    end

    # Runs the whole pipeline: validate, detect encoding, build the vocabulary,
    # then write train.bin, val.bin and meta.json.
    #
    # @return [String] the dataset name (usable as --dataset=...)
    # @raise [RuntimeError] if the input is missing or empty, or has more
    #   distinct characters than fit in a uint16 token id
    def prepare
      validate_input!
      FileUtils.mkdir_p(@output_dir)

      print_header
      encoding = detect_encoding
      vocab, char_count = build_vocabulary(encoding)
      validate_vocab_size!(vocab.size)
      stoi, itos = build_mappings(vocab)
      train_chars, val_chars = calculate_split(char_count)

      write_train_bin(encoding, stoi, train_chars)
      write_val_bin(encoding, stoi, train_chars, val_chars)
      write_meta_json(vocab.size, stoi, itos)

      print_summary(train_chars, val_chars, vocab.size)
      @output_name
    end

    private

    # File name without extension, with anything outside [a-zA-Z0-9_-] replaced by "_".
    def derive_output_name(path)
      File.basename(path, ".*").gsub(/[^a-zA-Z0-9_-]/, "_")
    end

    # Fails fast on a missing or empty input instead of producing empty datasets.
    def validate_input!
      raise "File not found: #{@input_path}" unless File.exist?(@input_path)
      raise "File is empty: #{@input_path}" if File.zero?(@input_path)
    end

    # Token ids are written as uint16; larger ids would silently wrap around.
    def validate_vocab_size!(size)
      return if size <= MAX_VOCAB_SIZE

      raise "Vocabulary has #{size} distinct characters; at most #{MAX_VOCAB_SIZE} fit in uint16 token ids"
    end

    def print_header
      file_size = File.size(@input_path)
      puts "Preparing text file: #{@input_path}"
      puts "File size: #{(file_size / 1_000_000.0).round(2)} MB"
      puts "Output directory: #{@output_dir}"
      puts "Validation ratio: #{@val_ratio}"
      puts ""
    end

    # Returns "UTF-8" when the head of the file is valid UTF-8, otherwise
    # "Windows-1252:UTF-8" (read as Windows-1252, transcode to UTF-8).
    def detect_encoding
      sample = File.binread(@input_path, ENCODING_SAMPLE_BYTES) || String.new
      sample.force_encoding(Encoding::UTF_8)
      truncated = sample.bytesize == ENCODING_SAMPLE_BYTES
      encoding = valid_utf8_sample?(sample, truncated) ? "UTF-8" : "Windows-1252:UTF-8"
      puts " Detected encoding: #{encoding.split(':').first}"
      encoding
    end

    # A fixed-size sample can end part-way through a multi-byte character.
    # A UTF-8 character is at most 4 bytes, so up to 3 bytes may dangle.
    # When the sample was cut short, accept it if dropping 0..3 trailing
    # bytes makes it valid.
    def valid_utf8_sample?(sample, truncated)
      return sample.valid_encoding? unless truncated

      (0..3).any? { |drop| sample.byteslice(0, sample.bytesize - drop).valid_encoding? }
    end

    # Yields every character of the input in file order, decoded with +encoding+.
    def each_input_char(encoding, &block)
      File.foreach(@input_path, encoding: encoding) { |line| line.each_char(&block) }
    end

    # First pass: collects the distinct characters and counts all characters.
    # @return [Array(Array<String>, Integer)] sorted vocabulary and total character count
    def build_vocabulary(encoding)
      puts "Phase 1: Building vocabulary..."
      char_set = Set.new
      char_count = 0
      next_report = PROGRESS_INTERVAL

      File.foreach(@input_path, encoding: encoding) do |line|
        line.each_char { |c| char_set.add(c) }
        char_count += line.length
        next if char_count < next_report

        print "\r Scanned #{char_count} characters, #{char_set.size} unique..."
        next_report = char_count + PROGRESS_INTERVAL
      end
      puts "\r Scanned #{char_count} characters, #{char_set.size} unique..."
      puts "Vocabulary size: #{char_set.size}"

      [char_set.to_a.sort, char_count]
    end

    # @return [Array(Hash{String=>Integer}, Hash{Integer=>String})] stoi and itos tables
    def build_mappings(vocab)
      stoi = vocab.each_with_index.to_h
      itos = vocab.each_with_index.to_h { |c, i| [i, c] }
      [stoi, itos]
    end

    # Validation takes floor(total * val_ratio) characters; training gets the rest.
    def calculate_split(total_chars)
      val_chars = (total_chars * @val_ratio).to_i
      train_chars = total_chars - val_chars
      puts ""
      puts "Train: #{train_chars} characters"
      puts "Val: #{val_chars} characters"
      [train_chars, val_chars]
    end

    # Second pass: encodes the first +train_chars+ characters into train.bin.
    def write_train_bin(encoding, stoi, train_chars)
      puts ""
      puts "Phase 2: Encoding and writing train.bin..."
      train_path = File.join(@output_dir, "train.bin")

      write_tokens_to_bin(train_path, train_chars) do |on_char|
        written = 0
        # `break` leaves each_input_char; File.foreach closes the file on exit.
        each_input_char(encoding) do |c|
          break if written >= train_chars

          on_char.call(stoi.fetch(c))
          written += 1
        end
      end
    end

    # Third pass: skips the first +train_chars+ characters and encodes the rest into val.bin.
    def write_val_bin(encoding, stoi, train_chars, val_chars)
      puts "Phase 3: Encoding and writing val.bin..."
      val_path = File.join(@output_dir, "val.bin")

      write_tokens_to_bin(val_path, val_chars) do |on_char|
        position = 0
        each_input_char(encoding) do |c|
          on_char.call(stoi.fetch(c)) if position >= train_chars
          position += 1
        end
      end
    end

    # Opens +path+ for binary writing and yields a callable that accepts one
    # token id at a time. Tokens are buffered and flushed every BUFFER_SIZE
    # tokens, and once more when the block returns.
    def write_tokens_to_bin(path, total_chars)
      buffer = []
      chars_written = 0

      File.open(path, "wb") do |output|
        on_char = lambda do |token|
          buffer << token
          chars_written += 1

          if buffer.size >= BUFFER_SIZE
            flush_buffer(output, buffer)
            print "\r Written #{chars_written}/#{total_chars} characters..."
          end
        end

        yield on_char

        flush_buffer(output, buffer) unless buffer.empty?
      end
      puts ""
    end

    # Writes buffered token ids as little-endian uint16 ("S<"). This is the
    # layout DataLoader reads back with unpack("S<*"), independent of host
    # byte order.
    def flush_buffer(output, buffer)
      output.write(buffer.pack("S<*"))
      buffer.clear
    end

    def write_meta_json(vocab_size, stoi, itos)
      puts "Phase 4: Saving meta.json..."
      meta = {
        "vocab_size" => vocab_size,
        # JSON object keys must be strings.
        "itos" => itos.transform_keys(&:to_s),
        "stoi" => stoi
      }
      File.write(File.join(@output_dir, "meta.json"), JSON.pretty_generate(meta))
    end

    def print_summary(train_chars, val_chars, vocab_size)
      train_size_mb = File.size(File.join(@output_dir, "train.bin")) / 1_000_000.0
      val_size_mb = File.size(File.join(@output_dir, "val.bin")) / 1_000_000.0

      puts ""
      puts "Done!"
      puts " train.bin: #{train_chars} tokens (#{train_size_mb.round(2)} MB)"
      puts " val.bin: #{val_chars} tokens (#{val_size_mb.round(2)} MB)"
      puts " meta.json: vocab_size=#{vocab_size}"
      puts ""
      puts "To train:"
      puts " nanogpt train --dataset=#{@output_name}"
    end
  end
end
|
|
@@ -3,67 +3,22 @@
|
|
|
3
3
|
require "json"
|
|
4
4
|
|
|
5
5
|
module NanoGPT
|
|
6
|
-
#
|
|
6
|
+
# Base configuration class with shared functionality
|
|
7
7
|
# Supports JSON config files with command-line overrides
|
|
8
8
|
#
|
|
9
9
|
# Priority (highest to lowest):
|
|
10
10
|
# 1. Command-line arguments (--key=value)
|
|
11
11
|
# 2. JSON config file (--config=path.json)
|
|
12
12
|
# 3. Default values
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
#
|
|
18
|
-
class TrainConfig
|
|
19
|
-
# Defaults match bin/train exactly
|
|
20
|
-
DEFAULTS = {
|
|
21
|
-
# I/O
|
|
22
|
-
out_dir: "out-shakespeare-char",
|
|
23
|
-
eval_interval: 250,
|
|
24
|
-
log_interval: 10,
|
|
25
|
-
eval_iters: 200,
|
|
26
|
-
eval_only: false,
|
|
27
|
-
always_save_checkpoint: false,
|
|
28
|
-
init_from: "scratch", # 'scratch' or 'resume'
|
|
29
|
-
|
|
30
|
-
# Data
|
|
31
|
-
dataset: "shakespeare_char",
|
|
32
|
-
batch_size: 64,
|
|
33
|
-
block_size: 256,
|
|
34
|
-
gradient_accumulation_steps: 1,
|
|
35
|
-
|
|
36
|
-
# Model
|
|
37
|
-
n_layer: 6,
|
|
38
|
-
n_head: 6,
|
|
39
|
-
n_embd: 384,
|
|
40
|
-
dropout: 0.2,
|
|
41
|
-
bias: false,
|
|
42
|
-
|
|
43
|
-
# Optimizer
|
|
44
|
-
learning_rate: 1e-3,
|
|
45
|
-
weight_decay: 1e-1,
|
|
46
|
-
beta1: 0.9,
|
|
47
|
-
beta2: 0.99,
|
|
48
|
-
grad_clip: 1.0,
|
|
49
|
-
|
|
50
|
-
# LR scheduler
|
|
51
|
-
decay_lr: true,
|
|
52
|
-
warmup_iters: 100,
|
|
53
|
-
lr_decay_iters: 5000,
|
|
54
|
-
min_lr: 1e-4,
|
|
55
|
-
|
|
56
|
-
# Training
|
|
57
|
-
max_iters: 5000,
|
|
58
|
-
|
|
59
|
-
# System
|
|
60
|
-
device: "auto"
|
|
61
|
-
}.freeze
|
|
13
|
+
class BaseConfig
|
|
14
|
+
def self.defaults
|
|
15
|
+
raise NotImplementedError, "Subclasses must define DEFAULTS"
|
|
16
|
+
end
|
|
62
17
|
|
|
63
18
|
attr_reader :values
|
|
64
19
|
|
|
65
20
|
def initialize(values = {})
|
|
66
|
-
@values =
|
|
21
|
+
@values = self.class.defaults.merge(values)
|
|
67
22
|
end
|
|
68
23
|
|
|
69
24
|
def [](key)
|
|
@@ -79,25 +34,17 @@ module NanoGPT
|
|
|
79
34
|
end
|
|
80
35
|
|
|
81
36
|
# Load config from command-line args
|
|
82
|
-
# Supports:
|
|
83
|
-
# --config=path/to/config.json (load JSON file)
|
|
84
|
-
# --key=value (override specific values)
|
|
85
37
|
def self.load(args)
|
|
86
38
|
config = new
|
|
87
39
|
|
|
88
40
|
# First pass: find and load JSON config file
|
|
89
|
-
config_file = nil
|
|
90
41
|
args.each do |arg|
|
|
91
42
|
if arg.start_with?("--config=")
|
|
92
|
-
|
|
43
|
+
config.load_json(arg.split("=", 2).last)
|
|
93
44
|
break
|
|
94
45
|
end
|
|
95
46
|
end
|
|
96
47
|
|
|
97
|
-
if config_file
|
|
98
|
-
config.load_json(config_file)
|
|
99
|
-
end
|
|
100
|
-
|
|
101
48
|
# Second pass: apply command-line overrides
|
|
102
49
|
args.each do |arg|
|
|
103
50
|
next unless arg.start_with?("--") && arg.include?("=")
|
|
@@ -120,9 +67,7 @@ module NanoGPT
|
|
|
120
67
|
|
|
121
68
|
# Load values from JSON file
|
|
122
69
|
def load_json(path)
|
|
123
|
-
unless File.exist?(path)
|
|
124
|
-
raise "Config file not found: #{path}"
|
|
125
|
-
end
|
|
70
|
+
raise "Config file not found: #{path}" unless File.exist?(path)
|
|
126
71
|
|
|
127
72
|
json = JSON.parse(File.read(path))
|
|
128
73
|
puts "Loaded config from #{path}"
|
|
@@ -145,8 +90,6 @@ module NanoGPT
|
|
|
145
90
|
puts "Saved config to #{path}"
|
|
146
91
|
end
|
|
147
92
|
|
|
148
|
-
private
|
|
149
|
-
|
|
150
93
|
def self.parse_value(val, existing)
|
|
151
94
|
case existing
|
|
152
95
|
when Integer then val.to_i
|
|
@@ -155,11 +98,61 @@ module NanoGPT
|
|
|
155
98
|
else val
|
|
156
99
|
end
|
|
157
100
|
end
|
|
101
|
+
private_class_method :parse_value
|
|
102
|
+
end
|
|
103
|
+
|
|
104
|
+
# Configuration for training
|
|
105
|
+
class TrainConfig < BaseConfig
|
|
106
|
+
DEFAULTS = {
|
|
107
|
+
# I/O
|
|
108
|
+
out_dir: "out-shakespeare-char",
|
|
109
|
+
eval_interval: 250,
|
|
110
|
+
log_interval: 10,
|
|
111
|
+
eval_iters: 200,
|
|
112
|
+
eval_only: false,
|
|
113
|
+
always_save_checkpoint: false,
|
|
114
|
+
init_from: "scratch",
|
|
115
|
+
|
|
116
|
+
# Data
|
|
117
|
+
dataset: "shakespeare_char",
|
|
118
|
+
batch_size: 64,
|
|
119
|
+
block_size: 256,
|
|
120
|
+
gradient_accumulation_steps: 1,
|
|
121
|
+
|
|
122
|
+
# Model
|
|
123
|
+
n_layer: 6,
|
|
124
|
+
n_head: 6,
|
|
125
|
+
n_embd: 384,
|
|
126
|
+
dropout: 0.2,
|
|
127
|
+
bias: false,
|
|
128
|
+
|
|
129
|
+
# Optimizer
|
|
130
|
+
learning_rate: 1e-3,
|
|
131
|
+
weight_decay: 1e-1,
|
|
132
|
+
beta1: 0.9,
|
|
133
|
+
beta2: 0.99,
|
|
134
|
+
grad_clip: 1.0,
|
|
135
|
+
|
|
136
|
+
# LR scheduler
|
|
137
|
+
decay_lr: true,
|
|
138
|
+
warmup_iters: 100,
|
|
139
|
+
lr_decay_iters: 5000,
|
|
140
|
+
min_lr: 1e-4,
|
|
141
|
+
|
|
142
|
+
# Training
|
|
143
|
+
max_iters: 5000,
|
|
144
|
+
|
|
145
|
+
# System
|
|
146
|
+
device: "auto"
|
|
147
|
+
}.freeze
|
|
148
|
+
|
|
149
|
+
def self.defaults
|
|
150
|
+
DEFAULTS
|
|
151
|
+
end
|
|
158
152
|
end
|
|
159
153
|
|
|
160
154
|
# Configuration for sampling/generation
|
|
161
|
-
class SampleConfig
|
|
162
|
-
# Defaults match bin/sample exactly
|
|
155
|
+
class SampleConfig < BaseConfig
|
|
163
156
|
DEFAULTS = {
|
|
164
157
|
out_dir: "out-shakespeare-char",
|
|
165
158
|
dataset: "shakespeare_char",
|
|
@@ -172,88 +165,29 @@ module NanoGPT
|
|
|
172
165
|
device: "auto"
|
|
173
166
|
}.freeze
|
|
174
167
|
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
def initialize(values = {})
|
|
178
|
-
@values = DEFAULTS.merge(values)
|
|
179
|
-
end
|
|
180
|
-
|
|
181
|
-
def [](key)
|
|
182
|
-
@values[key.to_sym]
|
|
183
|
-
end
|
|
184
|
-
|
|
185
|
-
def []=(key, value)
|
|
186
|
-
@values[key.to_sym] = value
|
|
187
|
-
end
|
|
188
|
-
|
|
189
|
-
def to_h
|
|
190
|
-
@values.dup
|
|
191
|
-
end
|
|
192
|
-
|
|
193
|
-
def self.load(args)
|
|
194
|
-
config = new
|
|
195
|
-
|
|
196
|
-
# First pass: find and load JSON config file
|
|
197
|
-
config_file = nil
|
|
198
|
-
args.each do |arg|
|
|
199
|
-
if arg.start_with?("--config=")
|
|
200
|
-
config_file = arg.split("=", 2).last
|
|
201
|
-
break
|
|
202
|
-
end
|
|
203
|
-
end
|
|
204
|
-
|
|
205
|
-
if config_file
|
|
206
|
-
config.load_json(config_file)
|
|
207
|
-
end
|
|
208
|
-
|
|
209
|
-
# Second pass: apply command-line overrides
|
|
210
|
-
args.each do |arg|
|
|
211
|
-
next unless arg.start_with?("--") && arg.include?("=")
|
|
212
|
-
next if arg.start_with?("--config=")
|
|
213
|
-
|
|
214
|
-
key, val = arg[2..].split("=", 2)
|
|
215
|
-
key = key.to_sym
|
|
216
|
-
|
|
217
|
-
unless config.values.key?(key)
|
|
218
|
-
puts "Warning: Unknown config key: #{key}"
|
|
219
|
-
next
|
|
220
|
-
end
|
|
221
|
-
|
|
222
|
-
config[key] = parse_value(val, config[key])
|
|
223
|
-
end
|
|
224
|
-
|
|
225
|
-
config
|
|
226
|
-
end
|
|
227
|
-
|
|
228
|
-
def load_json(path)
|
|
229
|
-
unless File.exist?(path)
|
|
230
|
-
raise "Config file not found: #{path}"
|
|
231
|
-
end
|
|
232
|
-
|
|
233
|
-
json = JSON.parse(File.read(path))
|
|
234
|
-
puts "Loaded config from #{path}"
|
|
235
|
-
|
|
236
|
-
json.each do |key, val|
|
|
237
|
-
key = key.to_sym
|
|
238
|
-
unless @values.key?(key)
|
|
239
|
-
puts "Warning: Unknown config key in JSON: #{key}"
|
|
240
|
-
next
|
|
241
|
-
end
|
|
242
|
-
@values[key] = val
|
|
243
|
-
end
|
|
244
|
-
|
|
245
|
-
self
|
|
168
|
+
def self.defaults
|
|
169
|
+
DEFAULTS
|
|
246
170
|
end
|
|
171
|
+
end
|
|
247
172
|
|
|
248
|
-
|
|
173
|
+
# Configuration for benchmarking
|
|
174
|
+
class BenchConfig < BaseConfig
|
|
175
|
+
DEFAULTS = {
|
|
176
|
+
batch_size: 12,
|
|
177
|
+
block_size: 1024,
|
|
178
|
+
n_layer: 12,
|
|
179
|
+
n_head: 12,
|
|
180
|
+
n_embd: 768,
|
|
181
|
+
dropout: 0.0,
|
|
182
|
+
bias: false,
|
|
183
|
+
real_data: true,
|
|
184
|
+
dataset: "openwebtext",
|
|
185
|
+
seed: 1337,
|
|
186
|
+
device: "auto"
|
|
187
|
+
}.freeze
|
|
249
188
|
|
|
250
|
-
def self.
|
|
251
|
-
|
|
252
|
-
when Integer then val.to_i
|
|
253
|
-
when Float then val.to_f
|
|
254
|
-
when TrueClass, FalseClass then val.downcase == "true"
|
|
255
|
-
else val
|
|
256
|
-
end
|
|
189
|
+
def self.defaults
|
|
190
|
+
DEFAULTS
|
|
257
191
|
end
|
|
258
192
|
end
|
|
259
193
|
end
|