nanogpt 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/nano_gpt/trainer.rb ADDED
@@ -0,0 +1,221 @@
+ # frozen_string_literal: true
+
+ require "fileutils"
+
+ module NanoGPT
+   # Training loop for GPT models
+   class Trainer
+     attr_reader :model, :optimizer, :config, :iter_num, :best_val_loss
+
+     def initialize(model:, data_loader:, config: {})
+       @model = model
+       @data_loader = data_loader
+       @config = default_config.merge(config)
+
+       @iter_num = 0
+       @best_val_loss = Float::INFINITY
+
+       setup_optimizer
+       setup_lr_scheduler
+     end
+
+     def default_config
+       {
+         out_dir: "out",
+         eval_interval: 250,
+         log_interval: 10,
+         eval_iters: 200,
+         eval_only: false,
+         always_save_checkpoint: false,
+
+         # Optimizer
+         learning_rate: 1e-3,
+         weight_decay: 1e-1,
+         beta1: 0.9,
+         beta2: 0.99,
+         grad_clip: 1.0,
+
+         # LR scheduler
+         decay_lr: true,
+         warmup_iters: 100,
+         lr_decay_iters: 5000,
+         min_lr: 1e-4,
+
+         # Training
+         max_iters: 5000,
+         gradient_accumulation_steps: 1,
+
+         device: "cpu"
+       }
+     end
+
+     def train
+       puts "Starting training..."
+       puts "Tokens per iteration: #{tokens_per_iter}"
+
+       @model.train
+       x, y = @data_loader.get_batch(:train)
+       t0 = Time.now
+
+       while @iter_num <= @config[:max_iters]
+         # Set learning rate for this iteration
+         lr = @config[:decay_lr] ? @lr_scheduler.step(@optimizer, @iter_num) : @config[:learning_rate]
+
+         # Evaluate and checkpoint
+         if @iter_num % @config[:eval_interval] == 0
+           losses = estimate_loss
+           puts "step #{@iter_num}: train loss #{losses[:train].round(4)}, val loss #{losses[:val].round(4)}"
+
+           if losses[:val] < @best_val_loss || @config[:always_save_checkpoint]
+             @best_val_loss = [losses[:val], @best_val_loss].min
+             save_checkpoint if @iter_num > 0
+           end
+         end
+
+         break if @iter_num == 0 && @config[:eval_only]
+
+         # Forward/backward with gradient accumulation
+         @optimizer.zero_grad
+
+         accumulated_loss = 0.0
+         @config[:gradient_accumulation_steps].times do |micro_step|
+           logits, loss = @model.call(x, targets: y)
+           loss = loss / @config[:gradient_accumulation_steps]
+           accumulated_loss += loss.item
+           loss.backward
+
+           # Prefetch next batch
+           x, y = @data_loader.get_batch(:train)
+         end
+
+         # Gradient clipping (manual implementation since torch.rb lacks clip_grad_norm_)
+         if @config[:grad_clip] > 0.0
+           clip_grad_norm(@model.parameters, @config[:grad_clip])
+         end
+
+         # Optimizer step
+         @optimizer.step
+
+         # Logging
+         t1 = Time.now
+         dt = t1 - t0
+         t0 = t1
+
+         if @iter_num % @config[:log_interval] == 0
+           puts "iter #{@iter_num}: loss #{accumulated_loss.round(4)}, time #{(dt * 1000).round(2)}ms, lr #{lr.round(6)}"
+         end
+
+         @iter_num += 1
+       end
+
+       puts "Training complete!"
+     end
+
+     def estimate_loss
+       @model.eval
+       out = {}
+
+       [:train, :val].each do |split|
+         losses = []
+         @config[:eval_iters].times do
+           x, y = @data_loader.get_batch(split)
+           Torch.no_grad do
+             _logits, loss = @model.call(x, targets: y)
+             losses << loss.item
+           end
+         end
+         out[split] = losses.sum / losses.size
+       end
+
+       @model.train
+       out
+     end
+
+     def save_checkpoint
+       FileUtils.mkdir_p(@config[:out_dir])
+       path = File.join(@config[:out_dir], "ckpt.pt")
+
+       # Note: torch.rb doesn't support optimizer.state_dict yet
+       # We save model state and training metadata
+       # Convert symbol keys to strings for Torch.save compatibility
+       checkpoint = {
+         "model" => @model.state_dict,
+         "model_args" => stringify_keys(@model.config.to_h),
+         "iter_num" => @iter_num,
+         "best_val_loss" => @best_val_loss,
+         "config" => stringify_keys(@config)
+       }
+
+       Torch.save(checkpoint, path)
+       puts "Saved checkpoint to #{path}"
+     end
+
+     def load_checkpoint(path)
+       checkpoint = Torch.load(path)
+
+       @model.load_state_dict(checkpoint["model"])
+       @iter_num = checkpoint["iter_num"]
+       @best_val_loss = checkpoint["best_val_loss"]
+
+       # Reinitialize optimizer (since we can't restore optimizer state in torch.rb)
+       setup_optimizer
+
+       puts "Loaded checkpoint from #{path} (iter #{@iter_num})"
+       checkpoint
+     end
+
+     private
+
+     # Convert symbol keys to strings recursively (for Torch.save)
+     def stringify_keys(hash)
+       hash.transform_keys(&:to_s).transform_values do |v|
+         v.is_a?(Hash) ? stringify_keys(v) : v
+       end
+     end
+
+     # Manual gradient clipping (torch.rb doesn't have clip_grad_norm_)
+     def clip_grad_norm(parameters, max_norm)
+       total_norm = 0.0
+       parameters.each do |p|
+         next unless p.grad
+
+         param_norm = p.grad.data.norm(2).item
+         total_norm += param_norm ** 2
+       end
+       total_norm = Math.sqrt(total_norm)
+
+       clip_coef = max_norm / (total_norm + 1e-6)
+       if clip_coef < 1
+         parameters.each do |p|
+           next unless p.grad
+
+           p.grad.data.mul!(clip_coef)
+         end
+       end
+
+       total_norm
+     end
+
+     def setup_optimizer
+       @optimizer = @model.configure_optimizers(
+         weight_decay: @config[:weight_decay],
+         learning_rate: @config[:learning_rate],
+         betas: [@config[:beta1], @config[:beta2]],
+         device_type: @config[:device]
+       )
+     end
+
+     def setup_lr_scheduler
+       @lr_scheduler = LRScheduler.new(
+         learning_rate: @config[:learning_rate],
+         min_lr: @config[:min_lr],
+         warmup_iters: @config[:warmup_iters],
+         lr_decay_iters: @config[:lr_decay_iters]
+       )
+     end
+
+     def tokens_per_iter
+       @config[:gradient_accumulation_steps] * @data_loader.batch_size * @data_loader.block_size
+     end
+   end
+ end
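
Taken together, Trainer depends only on a small duck-typed surface: the model must respond to call(x, targets:), parameters, state_dict / load_state_dict, configure_optimizers, train and eval, and the data loader must respond to get_batch(split), batch_size and block_size. The following usage sketch is under those assumptions; build_model and build_data_loader are hypothetical placeholders, since the concrete model and data-loader constructors are not part of this diff.

require "nano_gpt"

model       = build_model        # hypothetical helper returning a NanoGPT model
data_loader = build_data_loader  # hypothetical helper; must respond to get_batch/batch_size/block_size

trainer = NanoGPT::Trainer.new(
  model: model,
  data_loader: data_loader,
  config: {
    max_iters: 1_000,     # override the 5_000 default
    eval_interval: 100,
    out_dir: "out"
  }
)

trainer.train                            # evaluates every 100 iters, writes out/ckpt.pt when val loss improves
trainer.load_checkpoint("out/ckpt.pt")   # later: restores weights, iter_num and best_val_loss

Note that because torch.rb cannot serialize optimizer state yet (as the comments in save_checkpoint and load_checkpoint point out), resuming from ckpt.pt reinitializes the optimizer; only the weights, iter_num and best_val_loss carry over.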
data/lib/nano_gpt/version.rb ADDED
@@ -0,0 +1,5 @@
+ # frozen_string_literal: true
+
+ module NanoGPT
+   VERSION = "0.1.0"
+ end
data/lib/nano_gpt.rb ADDED
@@ -0,0 +1,18 @@
+ # frozen_string_literal: true
+
+ require "torch"
+ require "numo/narray"
+
+ require_relative "nano_gpt/version"
+ require_relative "nano_gpt/device"
+ require_relative "nano_gpt/config"
+ require_relative "nano_gpt/layers/layer_norm"
+ require_relative "nano_gpt/layers/mlp"
+ require_relative "nano_gpt/layers/causal_self_attention"
+ require_relative "nano_gpt/layers/block"
+ require_relative "nano_gpt/model"
+ require_relative "nano_gpt/tokenizer"
+ require_relative "nano_gpt/data_loader"
+ require_relative "nano_gpt/lr_scheduler"
+ require_relative "nano_gpt/trainer"
+ require_relative "nano_gpt/train_config"
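
lib/nano_gpt/lr_scheduler.rb is required here but its contents are not part of this excerpt. From the Trainer above we know only its interface: LRScheduler.new(learning_rate:, min_lr:, warmup_iters:, lr_decay_iters:) and step(optimizer, iter_num), which returns the learning rate it applied. Assuming it mirrors upstream nanoGPT's schedule (linear warmup followed by cosine decay down to min_lr), a sketch could look like this; the released file may differ.

module NanoGPT
  # Sketch only -- not the released lr_scheduler.rb.
  class LRScheduler
    def initialize(learning_rate:, min_lr:, warmup_iters:, lr_decay_iters:)
      @learning_rate = learning_rate
      @min_lr = min_lr
      @warmup_iters = warmup_iters
      @lr_decay_iters = lr_decay_iters
    end

    # Matches the call site in Trainer#train: applies and returns the LR for this iteration.
    # Assumes torch.rb optimizers expose param_groups as an array of hashes, as PyTorch does.
    def step(optimizer, iter_num)
      lr = lr_at(iter_num)
      optimizer.param_groups.each { |group| group[:lr] = lr }
      lr
    end

    def lr_at(iter_num)
      # Linear warmup for the first warmup_iters steps
      return @learning_rate * (iter_num + 1) / (@warmup_iters + 1).to_f if iter_num < @warmup_iters
      # Past the decay horizon, hold at the floor
      return @min_lr if iter_num > @lr_decay_iters

      # Cosine decay from learning_rate down to min_lr in between
      decay_ratio = (iter_num - @warmup_iters).to_f / (@lr_decay_iters - @warmup_iters)
      coeff = 0.5 * (1.0 + Math.cos(Math::PI * decay_ratio))
      @min_lr + coeff * (@learning_rate - @min_lr)
    end
  end
end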
data/nanogpt.gemspec ADDED
@@ -0,0 +1,37 @@
+ # frozen_string_literal: true
+
+ require_relative "lib/nano_gpt/version"
+
+ Gem::Specification.new do |spec|
+   spec.name = "nanogpt"
+   spec.version = NanoGPT::VERSION
+   spec.authors = ["Chris Hasiński"]
+   spec.email = ["krzysztof.hasinski@gmail.com"]
+
+   spec.summary = "A Ruby port of Karpathy's nanoGPT"
+   spec.description = "Train and run GPT models in Ruby using torch.rb. " \
+                      "A minimal, educational implementation of GPT-2 style models."
+   spec.homepage = "https://github.com/khasinski/nanogpt-rb"
+   spec.license = "MIT"
+   spec.required_ruby_version = ">= 3.1.0"
+
+   spec.metadata["homepage_uri"] = spec.homepage
+   spec.metadata["source_code_uri"] = spec.homepage
+   spec.metadata["changelog_uri"] = "#{spec.homepage}/blob/main/CHANGELOG.md"
+
+   spec.files = Dir.chdir(__dir__) do
+     `git ls-files -z`.split("\x0").reject do |f|
+       (File.expand_path(f) == __FILE__) ||
+         f.start_with?("spec/", "test/", ".git", ".github", "vendor/", "python/")
+     end
+   end
+   spec.bindir = "exe"
+   spec.executables = ["nanogpt"]
+   spec.require_paths = ["lib"]
+
+   spec.add_dependency "torch-rb", "~> 0.14"
+   spec.add_dependency "numo-narray", "~> 0.9"
+   spec.add_dependency "tiktoken_ruby", "~> 0.0"
+
+   spec.add_development_dependency "rspec", "~> 3.12"
+ end
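
For an application consuming this release, the gemspec above translates into a one-line Gemfile entry; torch-rb in turn needs a LibTorch installation on the machine (see the torch-rb README). A sketch:

# Gemfile
source "https://rubygems.org"

gem "nanogpt", "0.1.0"  # pulls in torch-rb ~> 0.14, numo-narray ~> 0.9 and tiktoken_ruby at runtime

# In application code, a single require loads lib/nano_gpt.rb and everything it requires:
#   require "nano_gpt"

The gem also installs the nanogpt command-line executable declared in spec.executables.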
metadata ADDED
@@ -0,0 +1,133 @@
+ --- !ruby/object:Gem::Specification
+ name: nanogpt
+ version: !ruby/object:Gem::Version
+   version: 0.1.0
+ platform: ruby
+ authors:
+ - Chris Hasiński
+ autorequire:
+ bindir: exe
+ cert_chain: []
+ date: 2025-12-02 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: torch-rb
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '0.14'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '0.14'
+ - !ruby/object:Gem::Dependency
+   name: numo-narray
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '0.9'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '0.9'
+ - !ruby/object:Gem::Dependency
+   name: tiktoken_ruby
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '0.0'
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '0.0'
+ - !ruby/object:Gem::Dependency
+   name: rspec
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '3.12'
+   type: :development
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: '3.12'
+ description: Train and run GPT models in Ruby using torch.rb. A minimal, educational
+   implementation of GPT-2 style models.
+ email:
+ - krzysztof.hasinski@gmail.com
+ executables:
+ - nanogpt
+ extensions: []
+ extra_rdoc_files: []
+ files:
+ - Gemfile
+ - Gemfile.lock
+ - README.md
+ - bin/bench
+ - bin/sample
+ - bin/train
+ - config/train_gpt2.json
+ - config/train_shakespeare.json
+ - config/train_shakespeare_char.json
+ - data/openwebtext/prepare.rb
+ - data/shakespeare/prepare.rb
+ - data/shakespeare_char/input.txt
+ - data/shakespeare_char/prepare.rb
+ - exe/nanogpt
+ - lib/nano_gpt.rb
+ - lib/nano_gpt/config.rb
+ - lib/nano_gpt/data_loader.rb
+ - lib/nano_gpt/device.rb
+ - lib/nano_gpt/layers/block.rb
+ - lib/nano_gpt/layers/causal_self_attention.rb
+ - lib/nano_gpt/layers/layer_norm.rb
+ - lib/nano_gpt/layers/mlp.rb
+ - lib/nano_gpt/lr_scheduler.rb
+ - lib/nano_gpt/model.rb
+ - lib/nano_gpt/tokenizer.rb
+ - lib/nano_gpt/train_config.rb
+ - lib/nano_gpt/trainer.rb
+ - lib/nano_gpt/version.rb
+ - nanogpt.gemspec
+ homepage: https://github.com/khasinski/nanogpt-rb
+ licenses:
+ - MIT
+ metadata:
+   homepage_uri: https://github.com/khasinski/nanogpt-rb
+   source_code_uri: https://github.com/khasinski/nanogpt-rb
+   changelog_uri: https://github.com/khasinski/nanogpt-rb/blob/main/CHANGELOG.md
+ post_install_message:
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: 3.1.0
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 3.1.6
+ signing_key:
+ specification_version: 4
+ summary: A Ruby port of Karpathy's nanoGPT
+ test_files: []