nanochat 0.1.0.pre

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/lib/nanochat/gpt.rb ADDED
@@ -0,0 +1,285 @@
+ # frozen_string_literal: true
+
+ # Nanochat: Minimal ChatGPT implementation in Ruby
+ # Ruby port of nanochat by Andrej Karpathy (https://github.com/karpathy/nanochat)
+ module Nanochat
+   # RMSNorm (no learnable parameters)
+   def self.norm(input)
+     variance = input.pow(2).mean(-1, keepdim: true)
+     input * Torch.rsqrt(variance + 1e-5)
+   end
+
+   # Apply RoPE (Rotary Position Embeddings)
+   def self.apply_rotary_emb(input, cos, sin)
+     raise ArgumentError, 'Expected 4D tensor for multihead attention' unless input.ndim == 4
+
+     dim = input.shape[3] / 2
+     # Split last dimension in half
+     x1, x2 = input.split(dim, dim: -1)
+     y1 = (x1 * cos) + (x2 * sin)
+     y2 = (x1 * (-sin)) + (x2 * cos)
+     out = Torch.cat([y1, y2], -1)
+     out.to(dtype: input.dtype)
+   end
+
+   # Feed-forward network (MLP)
+   class MLP < Torch::NN::Module
+     def initialize(config)
+       super()
+       @c_fc = Torch::NN::Linear.new(config.n_embd, 4 * config.n_embd, bias: false)
+       @c_proj = Torch::NN::Linear.new(4 * config.n_embd, config.n_embd, bias: false)
+     end
+
+     def forward(input)
+       x = @c_fc.call(input)
+       x = Torch::NN::F.relu(x).square # ReLU^2 activation
+       @c_proj.call(x)
+     end
+   end
+
+   # Causal self-attention with multi-query support
+   class CausalSelfAttention < Torch::NN::Module
+     attr_reader :layer_idx
+
+     def initialize(config, layer_idx)
+       super()
+       @layer_idx = layer_idx
+       @n_head = config.n_head
+       @n_kv_head = config.n_kv_head
+       @n_embd = config.n_embd
+       @head_dim = @n_embd / @n_head
+
+       raise ArgumentError, 'n_embd must be divisible by n_head' unless (@n_embd % @n_head).zero?
+       raise ArgumentError, 'Invalid MQA configuration' unless @n_kv_head <= @n_head && (@n_head % @n_kv_head).zero?
+
+       @c_q = Torch::NN::Linear.new(@n_embd, @n_head * @head_dim, bias: false)
+       @c_k = Torch::NN::Linear.new(@n_embd, @n_kv_head * @head_dim, bias: false)
+       @c_v = Torch::NN::Linear.new(@n_embd, @n_kv_head * @head_dim, bias: false)
+       @c_proj = Torch::NN::Linear.new(@n_embd, @n_embd, bias: false)
+     end
+
+     def forward(input, cos_sin, kv_cache)
+       batch_size, seq_len, = input.size
+
+       # Project to Q, K, V
+       query = @c_q.call(input).view(batch_size, seq_len, @n_head, @head_dim)
+       key = @c_k.call(input).view(batch_size, seq_len, @n_kv_head, @head_dim)
+       value = @c_v.call(input).view(batch_size, seq_len, @n_kv_head, @head_dim)
+
+       # Apply Rotary Embeddings and QK norm
+       cos, sin = cos_sin
+       query = Nanochat.apply_rotary_emb(query, cos, sin)
+       key = Nanochat.apply_rotary_emb(key, cos, sin)
+       query = Nanochat.norm(query)
+       key = Nanochat.norm(key)
+
+       # Transpose to make head the batch dimension: (B, T, H, D) -> (B, H, T, D)
+       query = query.transpose(1, 2)
+       key = key.transpose(1, 2)
+       value = value.transpose(1, 2)
+
+       # Apply KV cache if provided
+       key, value = kv_cache.insert_kv(@layer_idx, key, value) if kv_cache
+
+       tq = query.size(2) # number of queries
+       tk = key.size(2) # number of keys (total in cache + current)
+
+       # Expand key/value for Grouped Query Attention if needed
+       if @n_head != @n_kv_head
+         num_groups = @n_head / @n_kv_head
+         key = key.repeat_interleave(num_groups, dim: 1)
+         value = value.repeat_interleave(num_groups, dim: 1)
+       end
+
+       # Scaled dot-product attention (manual implementation for torch-rb)
+       scale = Math.sqrt(@head_dim)
+       scores = Torch.matmul(query, key.transpose(-2, -1)) / scale
+
+       # Apply causal mask
+       if kv_cache.nil? || tq == tk
+         # Training or full sequence: causal mask
+         mask = Torch.ones([tq, tk], device: query.device).tril
+         mask = Torch.zeros([tq, tk], device: query.device).masked_fill(mask.logical_not, -Float::INFINITY)
+         scores += mask
+       elsif tq == 1
+         # Inference with single query: attend to all keys (no mask needed)
+       else
+         # Inference with multiple queries: full attention to the cached prefix, causal mask within the new block
+         mask = Torch.zeros([tq, tk], device: query.device)
+         prefix_len = tk - tq
+         if prefix_len.positive?
+           # Full attention to prefix; build the causal block with masked_fill so unmasked
+           # positions stay 0.0 (multiplying a bool mask by -Infinity would produce NaN there)
+           causal = Torch.ones([tq, tq], device: query.device).tril
+           mask[0..-1, prefix_len..-1] = Torch.zeros([tq, tq], device: query.device).masked_fill(causal.logical_not, -Float::INFINITY)
+         else
+           causal = Torch.ones([tq, tk], device: query.device).tril
+           mask = Torch.zeros([tq, tk], device: query.device).masked_fill(causal.logical_not, -Float::INFINITY)
+         end
+         scores += mask
+       end
+
+       # Softmax and apply to values
+       attn_weights = Torch::NN::F.softmax(scores, dim: -1)
+       y = Torch.matmul(attn_weights, value) # (B, H, Tq, D)
+
+       # Re-assemble heads and project back
+       y = y.transpose(1, 2).contiguous.view(batch_size, seq_len, -1)
+       @c_proj.call(y)
+     end
+   end
+
+   # Transformer Block
+   class Block < Torch::NN::Module
+     def initialize(config, layer_idx)
+       super()
+       @attn = CausalSelfAttention.new(config, layer_idx)
+       @mlp = MLP.new(config)
+     end
+
+     def forward(input, cos_sin, kv_cache)
+       # Attention with pre-norm and residual
+       x = input + @attn.call(Nanochat.norm(input), cos_sin, kv_cache)
+       # MLP with pre-norm and residual
+       x + @mlp.call(Nanochat.norm(x))
+     end
+   end
+
+   # GPT transformer model
+   class GPT < Torch::NN::Module
+     SOFTCAP = 15
+     ROTARY_CACHE_MULTIPLIER = 10
+
+     attr_reader :config
+
+     # Load a GPT model from a checkpoint file
+     def self.from_checkpoint(path, config)
+       checkpoint = CheckpointManager.load(path)
+
+       model_state = checkpoint['model'] || checkpoint[:model]
+       raise ArgumentError, 'Checkpoint missing model state' unless model_state
+
+       model = new(config)
+       model.load_state_dict(model_state)
+
+       model
+     end
+
+     def initialize(config)
+       super()
+       @config = config
+       config.validate!
+
+       # Transformer components
+       @wte = Torch::NN::Embedding.new(config.vocab_size, config.n_embd)
+       blocks = (0...config.n_layer).map { |idx| Block.new(config, idx) }
+       @blocks = Torch::NN::ModuleList.new(blocks)
+       @lm_head = Torch::NN::Linear.new(config.n_embd, config.vocab_size, bias: false)
+
+       # Precompute rotary embeddings
+       @rotary_seq_len = config.block_size * ROTARY_CACHE_MULTIPLIER
+       head_dim = config.n_embd / config.n_head
+       cos, sin = precompute_rotary_embeddings(@rotary_seq_len, head_dim)
+
+       # Register as buffers (not saved to checkpoint)
+       register_buffer('cos', cos, persistent: false)
+       register_buffer('sin', sin, persistent: false)
+     end
+
+     def forward(idx, targets: nil, kv_cache: nil)
+       _, seq_len = idx.size
+
+       # Get rotary embeddings for current sequence
+       raise ArgumentError, "Sequence too long: #{seq_len} > #{@cos.size(1)}" if seq_len > @cos.size(1)
+       raise ArgumentError, 'Device mismatch' unless idx.device == @cos.device
+
+       # Offset rotary embeddings for KV cache
+       t0 = kv_cache.nil? ? 0 : kv_cache.pos
+       cos = @cos.narrow(1, t0, seq_len)
+       sin = @sin.narrow(1, t0, seq_len)
+       cos_sin = [cos, sin]
+
+       # Forward through transformer
+       x = @wte.call(idx)
+       x = Nanochat.norm(x)
+
+       @blocks.each do |block|
+         x = block.call(x, cos_sin, kv_cache)
+       end
+
+       x = Nanochat.norm(x)
+
+       # Compute logits with softcap
+       logits = @lm_head.call(x)
+       logits = SOFTCAP * Torch.tanh(logits / SOFTCAP)
+
+       if targets
+         logits = logits.float
+         Torch::NN::F.cross_entropy(
+           logits.view(-1, logits.size(-1)),
+           targets.view(-1),
+           ignore_index: -1
+         )
+       else
+         logits
+       end
+     end
+
+     # Generate tokens autoregressively
+     def generate(tokens, max_tokens:, temperature: 1.0, top_k: nil, seed: 42)
+       raise ArgumentError, 'tokens must be an Array' unless tokens.is_a?(Array)
+
+       device = @wte.weight.device
+       rng = temperature.positive? ? Torch::Generator.new(device).manual_seed(seed) : nil
+
+       ids = Torch.tensor([tokens], dtype: :long, device: device)
+
+       Torch.inference_mode do
+         max_tokens.times do
+           logits = forward(ids) # (B, T, vocab_size)
+           logits = logits[0..-1, -1, 0..-1] # (B, vocab_size) - last position, preserve batch dim
+
+           # Apply top-k filtering
+           if top_k
+             v, = Torch.topk(logits, [top_k, logits.size(-1)].min)
+             logits[logits < v[0..-1, -1..-1]] = -Float::INFINITY
+           end
+
+           # Sample next token
+           next_id = if temperature.positive?
+                       probs = Torch::NN::F.softmax(logits / temperature, dim: -1)
+                       Torch.multinomial(probs, num_samples: 1, generator: rng)
+                     else
+                       Torch.argmax(logits, dim: -1, keepdim: true)
+                     end
+
+           ids = Torch.cat([ids, next_id], dim: 1)
+           yield next_id[0].item
+         end
+       end
+     end
+
+     private
+
+     # Precompute rotary embeddings
+     def precompute_rotary_embeddings(seq_len, head_dim, base: 10_000)
+       device = @wte.weight.device
+
+       # Frequency computation
+       channel_range = Torch.arange(0, head_dim, 2, dtype: :float32, device: device)
+       inv_freq = 1.0 / (base**(channel_range / head_dim))
+
+       # Time steps
+       time = Torch.arange(seq_len, dtype: :float32, device: device)
+
+       # Calculate rotation frequencies
+       freqs = Torch.outer(time, inv_freq)
+       cos = freqs.cos.bfloat16
+       sin = freqs.sin.bfloat16
+
+       # Add batch and head dimensions for broadcasting
+       cos = cos.unsqueeze(0).unsqueeze(2)
+       sin = sin.unsqueeze(0).unsqueeze(2)
+
+       [cos, sin]
+     end
+   end
+ end
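A rough usage sketch (not part of the gem source): streaming generation with the GPT class above, assuming a config object that exposes the fields read in this file (n_layer, n_head, n_kv_head, n_embd, vocab_size, block_size), an already-loaded tokenizer, and a hypothetical checkpoint path.

# Hypothetical paths and objects; `generate` yields one token id per step, so the block streams text as it is produced.
model = Nanochat::GPT.from_checkpoint('model_000650.pt', config)
model.eval
prompt_ids = tokenizer.encode('Hello', prepend: '<|bos|>')
model.generate(prompt_ids, max_tokens: 64, temperature: 0.8, top_k: 50) do |token_id|
  print tokenizer.decode([token_id])
end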
data/lib/nanochat/tokenizer.rb ADDED
@@ -0,0 +1,119 @@
+ # frozen_string_literal: true
+
+ # Nanochat Tokenizer: BPE tokenization using HuggingFace tokenizers
+ # Ruby port of nanochat by Andrej Karpathy (https://github.com/karpathy/nanochat)
+
+ require 'tokenizers'
+
+ module Nanochat
+   # BPE tokenizer wrapping HuggingFace's tokenizers gem.
+   # Compatible with Python nanochat's tokenizer.json format.
+   class Tokenizer
+     SPECIAL_TOKENS = [
+       '<|bos|>',
+       '<|user_start|>', '<|user_end|>',
+       '<|assistant_start|>', '<|assistant_end|>',
+       '<|python_start|>', '<|python_end|>',
+       '<|output_start|>', '<|output_end|>'
+     ].freeze
+
+     attr_reader :tokenizer, :bos_token_id
+
+     def initialize(tokenizer)
+       @tokenizer = tokenizer
+       @bos_token_id = token_to_id('<|bos|>') || 0
+     end
+
+     def self.from_pretrained(identifier)
+       tokenizer = Tokenizers.from_pretrained(identifier)
+       new(tokenizer)
+     end
+
+     def self.from_file(path)
+       tokenizer = Tokenizers.from_file(path)
+       new(tokenizer)
+     end
+
+     def self.from_directory(directory)
+       tokenizer_path = File.join(directory, 'tokenizer.json')
+       raise ArgumentError, "Tokenizer file not found: #{tokenizer_path}" unless File.exist?(tokenizer_path)
+
+       from_file(tokenizer_path)
+     end
+
+     def vocab_size = @tokenizer.vocab_size
+
+     # EOS is the same as BOS for compatibility
+     def eos_token_id = @bos_token_id
+
+     # Encode text to token IDs. Accepts strings or arrays of strings.
+     # Pass prepend/append as token strings or IDs to add special tokens.
+     def encode(text, prepend: nil, append: nil)
+       prepend_id = prepend.is_a?(String) ? token_to_id(prepend) : prepend
+       append_id = append.is_a?(String) ? token_to_id(append) : append
+
+       case text
+       when String
+         encode_single(text, prepend_id, append_id)
+       when Array
+         text.map { |str| encode_single(str, prepend_id, append_id) }
+       else
+         raise ArgumentError, "Invalid input type: #{text.class}"
+       end
+     end
+
+     def decode(ids)
+       return '' if ids.empty?
+
+       @tokenizer.decode(ids)
+     end
+
+     def token_to_id(token) = @tokenizer.token_to_id(token)
+
+     def id_to_token(token_id) = @tokenizer.id_to_token(token_id)
+
+     def save(directory)
+       require 'fileutils'
+       FileUtils.mkdir_p(directory)
+       tokenizer_path = File.join(directory, 'tokenizer.json')
+       @tokenizer.save(tokenizer_path)
+       puts "Saved tokenizer to #{tokenizer_path}"
+     end
+
+     # Train a new BPE tokenizer from text files
+     def self.train_from_files(files, vocab_size:, special_tokens: SPECIAL_TOKENS)
+       tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::BPE.new(
+         unk_token: special_tokens.first
+       ))
+
+       tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::ByteLevel.new(
+         add_prefix_space: false,
+         use_regex: true
+       )
+
+       tokenizer.decoder = Tokenizers::Decoders::ByteLevel.new
+
+       trainer = Tokenizers::Trainers::BpeTrainer.new(
+         vocab_size:,
+         special_tokens:,
+         show_progress: true,
+         min_frequency: 0
+       )
+
+       tokenizer.train(files, trainer)
+       new(tokenizer)
+     end
+
+     private
+
+     def encode_single(text, prepend_id, append_id)
+       encoding = @tokenizer.encode(text)
+       ids = encoding.ids
+
+       ids.unshift(prepend_id) if prepend_id
+       ids << append_id if append_id
+
+       ids
+     end
+   end
+ end
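A minimal usage sketch for the tokenizer, assuming a tokenizer.json exported by Python nanochat (paths are hypothetical):

tok = Nanochat::Tokenizer.from_file('tokenizer.json')
ids = tok.encode('Hello, world!', prepend: '<|bos|>')
text = tok.decode(ids)

# Training a fresh BPE vocabulary from plain-text files works the same way:
new_tok = Nanochat::Tokenizer.train_from_files(['corpus.txt'], vocab_size: 32_768)
new_tok.save('tokenizer_out')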
data/lib/nanochat/version.rb ADDED
@@ -0,0 +1,5 @@
+ # frozen_string_literal: true
+
+ module Nanochat
+   VERSION = '0.1.0.pre'
+ end
data/lib/nanochat.rb ADDED
@@ -0,0 +1,27 @@
+ # frozen_string_literal: true
+
+ require 'torch'
+ require_relative 'nanochat/version'
+ require_relative 'nanochat/common'
+ require_relative 'nanochat/config'
+ require_relative 'nanochat/tokenizer'
+ require_relative 'nanochat/gpt'
+ require_relative 'nanochat/checkpoint_manager'
+ require_relative 'nanochat/engine'
+
+ # Nanochat: Minimal ChatGPT implementation in Ruby
+ # Port of https://github.com/karpathy/nanochat
+ module Nanochat
+   class Error < StandardError; end
+
+   def self.load_model(checkpoint_path, tokenizer: nil)
+     checkpoint = CheckpointManager.load(checkpoint_path)
+     config = Config.from_checkpoint(checkpoint)
+     model = GPT.new(config)
+     model.load_state_dict(checkpoint[:model] || checkpoint['model'])
+
+     tokenizer ||= Tokenizer.new(vocab_size: config.vocab_size)
+
+     Engine.new(model: model, tokenizer: tokenizer)
+   end
+ end
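An end-to-end sketch using the entry point above; the paths are hypothetical and the tokenizer is passed explicitly (the Engine class comes from lib/nanochat/engine.rb, whose interface is not shown in this diff):

tokenizer = Nanochat::Tokenizer.from_directory('out/tokenizer')
engine = Nanochat.load_model('out/model_000650.pt', tokenizer: tokenizer)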
metadata ADDED
@@ -0,0 +1,91 @@
+ --- !ruby/object:Gem::Specification
+ name: nanochat
+ version: !ruby/object:Gem::Version
+   version: 0.1.0.pre
+ platform: ruby
+ authors:
+ - Shannon Skipper
+ bindir: bin
+ cert_chain: []
+ date: 1980-01-02 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+   name: tokenizers
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.6.1
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.6.1
+ - !ruby/object:Gem::Dependency
+   name: torch-rb
+   requirement: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.22.1
+   type: :runtime
+   prerelease: false
+   version_requirements: !ruby/object:Gem::Requirement
+     requirements:
+     - - "~>"
+       - !ruby/object:Gem::Version
+         version: 0.22.1
+ description: A minimal, hackable LLM implementation with inference capabilities. Port
+   of Karpathy's nanochat.
+ email:
+ - shannonskipper@gmail.com
+ executables:
+ - nanochat-setup
+ extensions: []
+ extra_rdoc_files: []
+ files:
+ - LICENSE
+ - README.md
+ - bin/nanochat-setup
+ - bin/package-checkpoint
+ - bin/speedrun.sh
+ - bin/train-tiny-model
+ - bin/train-with-python-nanochat.sh
+ - lib/nanochat.rb
+ - lib/nanochat/checkpoint_manager.rb
+ - lib/nanochat/common.rb
+ - lib/nanochat/config.rb
+ - lib/nanochat/engine.rb
+ - lib/nanochat/gpt.rb
+ - lib/nanochat/tokenizer.rb
+ - lib/nanochat/version.rb
+ homepage: https://github.com/havenwood/nanochat
+ licenses:
+ - MIT
+ metadata:
+   homepage_uri: https://github.com/havenwood/nanochat
+   source_code_uri: https://github.com/havenwood/nanochat/tree/main
+   changelog_uri: https://github.com/havenwood/nanochat/releases
+   bug_tracker_uri: https://github.com/havenwood/nanochat/issues
+   documentation_uri: https://rubydoc.info/gems/nanochat/0.1.0.pre
+   rubygems_mfa_required: 'true'
+ rdoc_options: []
+ require_paths:
+ - lib
+ required_ruby_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: 3.4.0
+ required_rubygems_version: !ruby/object:Gem::Requirement
+   requirements:
+   - - ">="
+     - !ruby/object:Gem::Version
+       version: '0'
+ requirements: []
+ rubygems_version: 3.6.9
+ specification_version: 4
+ summary: Ruby port of nanochat, a minimal LLM
+ test_files: []
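Since the version is a prerelease (0.1.0.pre), installing it from RubyGems requires opting into prereleases, e.g. gem install nanochat --pre.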