nanochat 0.1.0.pre
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +25 -0
- data/README.md +129 -0
- data/bin/nanochat-setup +186 -0
- data/bin/package-checkpoint +122 -0
- data/bin/speedrun.sh +32 -0
- data/bin/train-tiny-model +190 -0
- data/bin/train-with-python-nanochat.sh +167 -0
- data/lib/nanochat/checkpoint_manager.rb +40 -0
- data/lib/nanochat/common.rb +32 -0
- data/lib/nanochat/config.rb +49 -0
- data/lib/nanochat/engine.rb +152 -0
- data/lib/nanochat/gpt.rb +285 -0
- data/lib/nanochat/tokenizer.rb +119 -0
- data/lib/nanochat/version.rb +5 -0
- data/lib/nanochat.rb +27 -0
- metadata +91 -0
data/lib/nanochat/gpt.rb
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# Nanochat: Minimal ChatGPT implementation in Ruby
|
|
4
|
+
# Ruby port of nanochat by Andrej Karpathy (https://github.com/karpathy/nanochat)
|
|
5
|
+
module Nanochat
|
|
6
|
+
# RMSNorm without a learnable gain.
#
# Divides every vector along the last dimension by its root-mean-square
# (with 1e-5 added to the mean square for numerical stability).
#
# @param input [Torch::Tensor] tensor to normalize over its last dimension
# @return [Torch::Tensor] normalized tensor, same shape as input
def self.norm(input)
  mean_square = input.square.mean(-1, keepdim: true)
  inverse_rms = Torch.rsqrt(mean_square + 1e-5)
  input * inverse_rms
end
|
|
11
|
+
|
|
12
|
+
# Apply RoPE (Rotary Position Embeddings).
#
# The last dimension is split into two halves (a, b); each pair is rotated as
#   a' = a*cos + b*sin
#   b' = b*cos - a*sin
# and the halves are concatenated back together.
#
# @param input [Torch::Tensor] 4-D tensor; rotation acts on the last dimension
# @param cos [Torch::Tensor] cosine table, broadcastable against each half
# @param sin [Torch::Tensor] sine table, broadcastable against each half
# @return [Torch::Tensor] rotated tensor cast back to input's dtype
# @raise [ArgumentError] if input is not 4-D
def self.apply_rotary_emb(input, cos, sin)
  raise ArgumentError, 'Expected 4D tensor for multihead attention' unless input.ndim == 4

  half_width = input.shape[3] / 2
  first_half, second_half = input.split(half_width, dim: -1)

  rotated_first = (first_half * cos) + (second_half * sin)
  rotated_second = (second_half * cos) - (first_half * sin)

  # Mixed-dtype products (e.g. bfloat16 tables) may promote; restore input's dtype.
  Torch.cat([rotated_first, rotated_second], -1).to(dtype: input.dtype)
end
|
|
24
|
+
|
|
25
|
+
# Feed-forward network (MLP).
#
# Expands n_embd -> 4 * n_embd, applies a squared ReLU, then projects back to
# n_embd. Both linear layers are bias-free.
class MLP < Torch::NN::Module
  # @param config [Object] must respond to n_embd
  def initialize(config)
    super()
    width = config.n_embd
    hidden_width = 4 * width
    # Submodule instance-variable names (and their order) define the state_dict keys.
    @c_fc = Torch::NN::Linear.new(width, hidden_width, bias: false)
    @c_proj = Torch::NN::Linear.new(hidden_width, width, bias: false)
  end

  # @param input [Torch::Tensor] activations whose last dimension is n_embd
  # @return [Torch::Tensor] output with the same last dimension as input
  def forward(input)
    activated = Torch::NN::F.relu(@c_fc.call(input)).square # ReLU^2 activation
    @c_proj.call(activated)
  end
end
|
|
39
|
+
|
|
40
|
+
# Causal self-attention with grouped-query (multi-query) support.
#
# Queries use n_head heads; keys and values use n_kv_head heads, which are
# repeated to n_head before attention. Rotary embeddings followed by a
# parameter-free RMS norm are applied to queries and keys. Attention is
# computed manually (matmul + additive mask + softmax).
class CausalSelfAttention < Torch::NN::Module
  attr_reader :layer_idx

  # @param config [Object] must respond to n_head, n_kv_head and n_embd
  # @param layer_idx [Integer] this layer's index; passed to kv_cache.insert_kv
  # @raise [ArgumentError] if n_embd is not divisible by n_head, or n_kv_head is
  #   not a positive divisor of n_head
  def initialize(config, layer_idx)
    super()
    @layer_idx = layer_idx
    @n_head = config.n_head
    @n_kv_head = config.n_kv_head
    @n_embd = config.n_embd

    # Positivity is checked before the modulo so a zero head count raises
    # ArgumentError rather than ZeroDivisionError.
    unless @n_head.positive? && (@n_embd % @n_head).zero?
      raise ArgumentError, 'n_embd must be divisible by n_head'
    end
    unless @n_kv_head.positive? && @n_kv_head <= @n_head && (@n_head % @n_kv_head).zero?
      raise ArgumentError, 'Invalid MQA configuration'
    end

    @head_dim = @n_embd / @n_head

    # Submodule instance-variable names (and their order) define the state_dict keys.
    @c_q = Torch::NN::Linear.new(@n_embd, @n_head * @head_dim, bias: false)
    @c_k = Torch::NN::Linear.new(@n_embd, @n_kv_head * @head_dim, bias: false)
    @c_v = Torch::NN::Linear.new(@n_embd, @n_kv_head * @head_dim, bias: false)
    @c_proj = Torch::NN::Linear.new(@n_embd, @n_embd, bias: false)
  end

  # @param input [Torch::Tensor] (batch, seq_len, n_embd) activations
  # @param cos_sin [Array(Torch::Tensor, Torch::Tensor)] rotary tables for these positions
  # @param kv_cache [Object, nil] when given, insert_kv(layer_idx, k, v) must return
  #   the full (cached + current) keys and values
  # @return [Torch::Tensor] (batch, seq_len, n_embd) attention output
  def forward(input, cos_sin, kv_cache)
    batch_size, seq_len, = input.size

    # Project to Q, K, V: (B, T, H, D)
    query = @c_q.call(input).view(batch_size, seq_len, @n_head, @head_dim)
    key = @c_k.call(input).view(batch_size, seq_len, @n_kv_head, @head_dim)
    value = @c_v.call(input).view(batch_size, seq_len, @n_kv_head, @head_dim)

    # Rotary embeddings, then QK norm
    cos, sin = cos_sin
    query = Nanochat.norm(Nanochat.apply_rotary_emb(query, cos, sin))
    key = Nanochat.norm(Nanochat.apply_rotary_emb(key, cos, sin))

    # Make heads the batch dimension: (B, T, H, D) -> (B, H, T, D)
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    key, value = kv_cache.insert_kv(@layer_idx, key, value) if kv_cache

    tq = query.size(2) # number of queries
    tk = key.size(2) # number of keys (cached + current)

    # Grouped-query attention: repeat each kv head to serve its query group
    if @n_head != @n_kv_head
      num_groups = @n_head / @n_kv_head
      key = key.repeat_interleave(num_groups, dim: 1)
      value = value.repeat_interleave(num_groups, dim: 1)
    end

    scores = Torch.matmul(query, key.transpose(-2, -1)) / Math.sqrt(@head_dim)
    mask = build_causal_mask(tq, tk, query.device)
    scores += mask if mask

    attn_weights = Torch::NN::F.softmax(scores, dim: -1)
    y = Torch.matmul(attn_weights, value) # (B, H, Tq, D)

    # Re-assemble heads and project back
    y = y.transpose(1, 2).contiguous.view(batch_size, seq_len, -1)
    @c_proj.call(y)
  end

  private

  # Additive (tq, tk) mask: 0.0 where attention is allowed, -Infinity where blocked.
  # The tq queries are the last tq positions of the tk keys, so query i may
  # attend to key j iff j <= i + (tk - tq). That is the earlier prefix plus a
  # causal triangle, which is exactly tril with diagonal offset tk - tq. It
  # reduces to the plain causal mask when tq == tk.
  # Returns nil for a single query, which may attend to every key.
  def build_causal_mask(tq, tk, device)
    return nil if tq == 1

    allowed = Torch.ones([tq, tk], device: device).tril(diagonal: tk - tq)
    # masked_fill (not bool * -inf) so allowed entries stay 0.0 rather than NaN.
    Torch.zeros([tq, tk], device: device).masked_fill(allowed.logical_not, -Float::INFINITY)
  end
end
|
|
128
|
+
|
|
129
|
+
# Transformer Block.
#
# Pre-norm residual layout:
#   h   = x + attn(norm(x))
#   out = h + mlp(norm(h))
class Block < Torch::NN::Module
  # @param config [Object] model hyperparameters forwarded to the sublayers
  # @param layer_idx [Integer] index of this block, forwarded to attention
  def initialize(config, layer_idx)
    super()
    # Submodule instance-variable names (and their order) define the state_dict keys.
    @attn = CausalSelfAttention.new(config, layer_idx)
    @mlp = MLP.new(config)
  end

  # @param input [Torch::Tensor] residual stream activations
  # @param cos_sin [Array(Torch::Tensor, Torch::Tensor)] rotary tables
  # @param kv_cache [Object, nil] optional KV cache forwarded to attention
  # @return [Torch::Tensor] updated residual stream
  def forward(input, cos_sin, kv_cache)
    attention_out = @attn.call(Nanochat.norm(input), cos_sin, kv_cache)
    hidden = input + attention_out
    mlp_out = @mlp.call(Nanochat.norm(hidden))
    hidden + mlp_out
  end
end
|
|
144
|
+
|
|
145
|
+
# GPT transformer model.
#
# Token embedding -> RMS norm -> config.n_layer Blocks -> RMS norm -> untied
# LM head, with tanh soft-capping of the logits. Rotary cos/sin tables are
# precomputed once for block_size * ROTARY_CACHE_MULTIPLIER positions.
class GPT < Torch::NN::Module
  # Logits are squashed into (-SOFTCAP, SOFTCAP) via SOFTCAP * tanh(logits / SOFTCAP).
  SOFTCAP = 15
  # The rotary tables cover this many multiples of config.block_size positions.
  ROTARY_CACHE_MULTIPLIER = 10

  attr_reader :config

  # Load a GPT model from a checkpoint file.
  #
  # @param path [String] checkpoint location passed to CheckpointManager.load
  # @param config [Object] model config (see #initialize)
  # @return [GPT] model with the checkpoint's weights loaded
  # @raise [ArgumentError] if the checkpoint has no 'model' / :model entry
  def self.from_checkpoint(path, config)
    checkpoint = CheckpointManager.load(path)

    model_state = checkpoint['model'] || checkpoint[:model]
    raise ArgumentError, 'Checkpoint missing model state' unless model_state

    model = new(config)
    model.load_state_dict(model_state)

    model
  end

  # @param config [Object] must respond to validate!, vocab_size, n_embd,
  #   n_head, n_kv_head, n_layer and block_size
  def initialize(config)
    super()
    @config = config
    config.validate!

    # Transformer components (instance-variable names define state_dict keys)
    @wte = Torch::NN::Embedding.new(config.vocab_size, config.n_embd)
    blocks = (0...config.n_layer).map { |idx| Block.new(config, idx) }
    @blocks = Torch::NN::ModuleList.new(blocks)
    @lm_head = Torch::NN::Linear.new(config.n_embd, config.vocab_size, bias: false)

    # Precompute rotary embeddings
    @rotary_seq_len = config.block_size * ROTARY_CACHE_MULTIPLIER
    head_dim = config.n_embd / config.n_head
    cos, sin = precompute_rotary_embeddings(@rotary_seq_len, head_dim)

    # Non-persistent buffers: excluded from the state dict / checkpoints
    register_buffer('cos', cos, persistent: false)
    register_buffer('sin', sin, persistent: false)
  end

  # Run the model on a batch of token ids.
  #
  # @param idx [Torch::Tensor] token ids, destructured as (batch, seq_len)
  # @param targets [Torch::Tensor, nil] target ids; entries equal to -1 are ignored
  # @param kv_cache [Object, nil] cache responding to #pos and #insert_kv
  # @return [Torch::Tensor] cross-entropy loss when targets are given,
  #   otherwise the soft-capped logits from the LM head
  # @raise [ArgumentError] if the cache offset plus seq_len exceeds the rotary
  #   tables, or idx is not on the same device as the rotary buffers
  def forward(idx, targets: nil, kv_cache: nil)
    _, seq_len = idx.size

    # Rotary positions start at the cache position so new tokens continue
    # the sequence already held in the cache.
    t0 = kv_cache.nil? ? 0 : kv_cache.pos
    end_pos = t0 + seq_len
    max_len = @cos.size(1)
    raise ArgumentError, "Sequence too long: #{end_pos} > #{max_len}" if end_pos > max_len
    raise ArgumentError, 'Device mismatch' unless idx.device == @cos.device

    cos_sin = [@cos.narrow(1, t0, seq_len), @sin.narrow(1, t0, seq_len)]

    # Forward through transformer
    x = Nanochat.norm(@wte.call(idx))
    @blocks.each do |block|
      x = block.call(x, cos_sin, kv_cache)
    end
    x = Nanochat.norm(x)

    # Compute logits with softcap
    logits = @lm_head.call(x)
    logits = SOFTCAP * Torch.tanh(logits / SOFTCAP)

    return logits unless targets

    logits = logits.float
    Torch::NN::F.cross_entropy(
      logits.view(-1, logits.size(-1)),
      targets.view(-1),
      ignore_index: -1
    )
  end

  # Generate tokens autoregressively, yielding each new token id.
  #
  # Every step reruns the full forward pass over the prompt plus all tokens
  # generated so far (no KV cache). Without a block, returns an Enumerator over
  # the generated token ids.
  #
  # @param tokens [Array<Integer>] prompt token ids
  # @param max_tokens [Integer] number of tokens to generate
  # @param temperature [Numeric] sampling temperature; non-positive means greedy argmax
  # @param top_k [Integer, nil] if set, only the top_k highest logits can be sampled
  # @param seed [Integer] seed for the sampling generator (used only when temperature > 0)
  # @yieldparam token_id [Integer] the next generated token id
  # @raise [ArgumentError] if tokens is not an Array
  def generate(tokens, max_tokens:, temperature: 1.0, top_k: nil, seed: 42)
    raise ArgumentError, 'tokens must be an Array' unless tokens.is_a?(Array)
    return enum_for(:generate, tokens, max_tokens:, temperature:, top_k:, seed:) unless block_given?

    device = @wte.weight.device
    rng = temperature.positive? ? Torch::Generator.new(device).manual_seed(seed) : nil

    ids = Torch.tensor([tokens], dtype: :long, device: device)

    Torch.inference_mode do
      max_tokens.times do
        logits = forward(ids) # (B, T, vocab_size)
        logits = logits[0..-1, -1, 0..-1] # (B, vocab_size) - last position, preserve batch dim

        # Apply top-k filtering
        if top_k
          v, = Torch.topk(logits, [top_k, logits.size(-1)].min)
          logits[logits < v[0..-1, -1..-1]] = -Float::INFINITY
        end

        # Sample next token
        next_id = if temperature.positive?
                    probs = Torch::NN::F.softmax(logits / temperature, dim: -1)
                    Torch.multinomial(probs, num_samples: 1, generator: rng)
                  else
                    Torch.argmax(logits, dim: -1, keepdim: true)
                  end

        ids = Torch.cat([ids, next_id], dim: 1)
        yield next_id[0].item
      end
    end
  end

  private

  # Precompute rotary embedding tables.
  #
  # @param seq_len [Integer] number of positions to cover
  # @param head_dim [Integer] per-head dimension; frequencies use every second channel
  # @param base [Numeric] RoPE frequency base
  # @return [Array(Torch::Tensor, Torch::Tensor)] bfloat16 cos and sin tables of
  #   shape (1, seq_len, 1, head_dim / 2)
  def precompute_rotary_embeddings(seq_len, head_dim, base: 10_000)
    device = @wte.weight.device

    # Frequency computation
    channel_range = Torch.arange(0, head_dim, 2, dtype: :float32, device: device)
    inv_freq = 1.0 / (base**(channel_range / head_dim))

    # Time steps
    time = Torch.arange(seq_len, dtype: :float32, device: device)

    # Calculate rotation frequencies
    freqs = Torch.outer(time, inv_freq)
    cos = freqs.cos.bfloat16
    sin = freqs.sin.bfloat16

    # Add batch and head dimensions for broadcasting
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)

    [cos, sin]
  end
end
|
|
285
|
+
end
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# Nanochat Tokenizer: BPE tokenization using HuggingFace tokenizers
|
|
4
|
+
# Ruby port of nanochat by Andrej Karpathy (https://github.com/karpathy/nanochat)
|
|
5
|
+
|
|
6
|
+
require 'tokenizers'
|
|
7
|
+
|
|
8
|
+
module Nanochat
|
|
9
|
+
# BPE tokenizer wrapping HuggingFace's tokenizers gem.
# Compatible with Python nanochat's tokenizer.json format.
class Tokenizer
  SPECIAL_TOKENS = [
    '<|bos|>',
    '<|user_start|>', '<|user_end|>',
    '<|assistant_start|>', '<|assistant_end|>',
    '<|python_start|>', '<|python_end|>',
    '<|output_start|>', '<|output_end|>'
  ].freeze

  attr_reader :tokenizer, :bos_token_id

  # @param tokenizer [Object] wrapped tokenizer; must respond to encode, decode,
  #   token_to_id, id_to_token, vocab_size and save
  def initialize(tokenizer)
    @tokenizer = tokenizer
    # Falls back to id 0 when the vocabulary has no <|bos|> token.
    @bos_token_id = token_to_id('<|bos|>') || 0
  end

  # Load via Tokenizers.from_pretrained(identifier).
  def self.from_pretrained(identifier)
    tokenizer = Tokenizers.from_pretrained(identifier)
    new(tokenizer)
  end

  # Load a tokenizer.json file via Tokenizers.from_file(path).
  def self.from_file(path)
    tokenizer = Tokenizers.from_file(path)
    new(tokenizer)
  end

  # Load <directory>/tokenizer.json.
  #
  # @raise [ArgumentError] if the file does not exist
  def self.from_directory(directory)
    tokenizer_path = File.join(directory, 'tokenizer.json')
    raise ArgumentError, "Tokenizer file not found: #{tokenizer_path}" unless File.exist?(tokenizer_path)

    from_file(tokenizer_path)
  end

  def vocab_size = @tokenizer.vocab_size

  # EOS is the same as BOS for compatibility
  def eos_token_id = @bos_token_id

  # Encode text to token IDs.
  #
  # @param text [String, Array<String>] one string, or several to encode independently
  # @param prepend [String, Integer, nil] token string or id to put before each encoding
  # @param append [String, Integer, nil] token string or id to put after each encoding
  # @return [Array<Integer>, Array<Array<Integer>>] ids (nested for Array input)
  # @raise [ArgumentError] if text is neither a String nor an Array, or a
  #   prepend/append token string is not in the vocabulary
  def encode(text, prepend: nil, append: nil)
    prepend_id = resolve_token_id(prepend)
    append_id = resolve_token_id(append)

    case text
    when String
      encode_single(text, prepend_id, append_id)
    when Array
      text.map { |str| encode_single(str, prepend_id, append_id) }
    else
      raise ArgumentError, "Invalid input type: #{text.class}"
    end
  end

  # Decode token IDs back to text; an empty list decodes to ''.
  def decode(ids)
    return '' if ids.empty?

    @tokenizer.decode(ids)
  end

  def token_to_id(token) = @tokenizer.token_to_id(token)

  def id_to_token(token_id) = @tokenizer.id_to_token(token_id)

  # Write <directory>/tokenizer.json, creating the directory if needed.
  #
  # @return [String] path of the written file
  def save(directory)
    require 'fileutils'
    FileUtils.mkdir_p(directory)
    tokenizer_path = File.join(directory, 'tokenizer.json')
    @tokenizer.save(tokenizer_path)
    puts "Saved tokenizer to #{tokenizer_path}"
    tokenizer_path
  end

  # Train a new byte-level BPE tokenizer from text files.
  #
  # @param files [Array<String>] training text file paths
  # @param vocab_size [Integer] target vocabulary size
  # @param special_tokens [Array<String>] tokens reserved in the vocabulary
  # @return [Tokenizer]
  def self.train_from_files(files, vocab_size:, special_tokens: SPECIAL_TOKENS)
    # NOTE(review): the first special token (<|bos|> by default) doubles as
    # unk_token; confirm that is intended for a byte-level model.
    tokenizer = Tokenizers::Tokenizer.new(Tokenizers::Models::BPE.new(
      unk_token: special_tokens.first
    ))

    tokenizer.pre_tokenizer = Tokenizers::PreTokenizers::ByteLevel.new(
      add_prefix_space: false,
      use_regex: true
    )

    tokenizer.decoder = Tokenizers::Decoders::ByteLevel.new

    trainer = Tokenizers::Trainers::BpeTrainer.new(
      vocab_size:,
      special_tokens:,
      show_progress: true,
      min_frequency: 0
    )

    tokenizer.train(files, trainer)
    new(tokenizer)
  end

  private

  # Map a prepend/append argument to a token id. Strings are looked up in the
  # vocabulary (unknown strings raise instead of being silently dropped); any
  # other value, including nil, is returned unchanged.
  def resolve_token_id(token)
    return token unless token.is_a?(String)

    token_to_id(token) || raise(ArgumentError, "Unknown token: #{token.inspect}")
  end

  def encode_single(text, prepend_id, append_id)
    # dup so the wrapped encoding's own ids array is never mutated.
    ids = @tokenizer.encode(text).ids.dup

    ids.unshift(prepend_id) if prepend_id
    ids << append_id if append_id

    ids
  end
end
|
|
119
|
+
end
|
data/lib/nanochat.rb
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'torch'
|
|
4
|
+
require_relative 'nanochat/version'
|
|
5
|
+
require_relative 'nanochat/common'
|
|
6
|
+
require_relative 'nanochat/config'
|
|
7
|
+
require_relative 'nanochat/tokenizer'
|
|
8
|
+
require_relative 'nanochat/gpt'
|
|
9
|
+
require_relative 'nanochat/checkpoint_manager'
|
|
10
|
+
require_relative 'nanochat/engine'
|
|
11
|
+
|
|
12
|
+
# Nanochat: Minimal ChatGPT implementation in Ruby
# Port of https://github.com/karpathy/nanochat
module Nanochat
  class Error < StandardError; end

  # Build an inference Engine from a saved checkpoint.
  #
  # @param checkpoint_path [String] checkpoint file (or directory) passed to
  #   CheckpointManager.load
  # @param tokenizer [Object, nil] tokenizer to use; when nil, tokenizer.json is
  #   loaded from the checkpoint's directory (checkpoint_path itself if it is a
  #   directory, otherwise its parent)
  # @return [Engine]
  # @raise [ArgumentError] if the checkpoint has no model state, or no tokenizer
  #   was given and no tokenizer.json exists in that directory
  def self.load_model(checkpoint_path, tokenizer: nil)
    checkpoint = CheckpointManager.load(checkpoint_path)
    config = Config.from_checkpoint(checkpoint)

    model_state = checkpoint[:model] || checkpoint['model']
    raise ArgumentError, 'Checkpoint missing model state' unless model_state

    # Tokenizer#initialize wraps an existing tokenizer object and cannot be built
    # from a vocab size alone, so default to the tokenizer.json beside the checkpoint.
    # Resolved before building the model so a missing tokenizer fails fast.
    tokenizer ||= begin
      directory = File.directory?(checkpoint_path) ? checkpoint_path : File.dirname(checkpoint_path)
      Tokenizer.from_directory(directory)
    end

    model = GPT.new(config)
    model.load_state_dict(model_state)

    Engine.new(model: model, tokenizer: tokenizer)
  end
end
|
metadata
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
|
2
|
+
name: nanochat
|
|
3
|
+
version: !ruby/object:Gem::Version
|
|
4
|
+
version: 0.1.0.pre
|
|
5
|
+
platform: ruby
|
|
6
|
+
authors:
|
|
7
|
+
- Shannon Skipper
|
|
8
|
+
bindir: bin
|
|
9
|
+
cert_chain: []
|
|
10
|
+
date: 1980-01-02 00:00:00.000000000 Z
|
|
11
|
+
dependencies:
|
|
12
|
+
- !ruby/object:Gem::Dependency
|
|
13
|
+
name: tokenizers
|
|
14
|
+
requirement: !ruby/object:Gem::Requirement
|
|
15
|
+
requirements:
|
|
16
|
+
- - "~>"
|
|
17
|
+
- !ruby/object:Gem::Version
|
|
18
|
+
version: 0.6.1
|
|
19
|
+
type: :runtime
|
|
20
|
+
prerelease: false
|
|
21
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
22
|
+
requirements:
|
|
23
|
+
- - "~>"
|
|
24
|
+
- !ruby/object:Gem::Version
|
|
25
|
+
version: 0.6.1
|
|
26
|
+
- !ruby/object:Gem::Dependency
|
|
27
|
+
name: torch-rb
|
|
28
|
+
requirement: !ruby/object:Gem::Requirement
|
|
29
|
+
requirements:
|
|
30
|
+
- - "~>"
|
|
31
|
+
- !ruby/object:Gem::Version
|
|
32
|
+
version: 0.22.1
|
|
33
|
+
type: :runtime
|
|
34
|
+
prerelease: false
|
|
35
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
36
|
+
requirements:
|
|
37
|
+
- - "~>"
|
|
38
|
+
- !ruby/object:Gem::Version
|
|
39
|
+
version: 0.22.1
|
|
40
|
+
description: A minimal, hackable LLM implementation with inference capabilities. Port
|
|
41
|
+
of Karpathy's nanochat.
|
|
42
|
+
email:
|
|
43
|
+
- shannonskipper@gmail.com
|
|
44
|
+
executables:
|
|
45
|
+
- nanochat-setup
|
|
46
|
+
extensions: []
|
|
47
|
+
extra_rdoc_files: []
|
|
48
|
+
files:
|
|
49
|
+
- LICENSE
|
|
50
|
+
- README.md
|
|
51
|
+
- bin/nanochat-setup
|
|
52
|
+
- bin/package-checkpoint
|
|
53
|
+
- bin/speedrun.sh
|
|
54
|
+
- bin/train-tiny-model
|
|
55
|
+
- bin/train-with-python-nanochat.sh
|
|
56
|
+
- lib/nanochat.rb
|
|
57
|
+
- lib/nanochat/checkpoint_manager.rb
|
|
58
|
+
- lib/nanochat/common.rb
|
|
59
|
+
- lib/nanochat/config.rb
|
|
60
|
+
- lib/nanochat/engine.rb
|
|
61
|
+
- lib/nanochat/gpt.rb
|
|
62
|
+
- lib/nanochat/tokenizer.rb
|
|
63
|
+
- lib/nanochat/version.rb
|
|
64
|
+
homepage: https://github.com/havenwood/nanochat
|
|
65
|
+
licenses:
|
|
66
|
+
- MIT
|
|
67
|
+
metadata:
|
|
68
|
+
homepage_uri: https://github.com/havenwood/nanochat
|
|
69
|
+
source_code_uri: https://github.com/havenwood/nanochat/tree/main
|
|
70
|
+
changelog_uri: https://github.com/havenwood/nanochat/releases
|
|
71
|
+
bug_tracker_uri: https://github.com/havenwood/nanochat/issues
|
|
72
|
+
documentation_uri: https://rubydoc.info/gems/nanochat/0.1.0.pre
|
|
73
|
+
rubygems_mfa_required: 'true'
|
|
74
|
+
rdoc_options: []
|
|
75
|
+
require_paths:
|
|
76
|
+
- lib
|
|
77
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
|
78
|
+
requirements:
|
|
79
|
+
- - ">="
|
|
80
|
+
- !ruby/object:Gem::Version
|
|
81
|
+
version: 3.4.0
|
|
82
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
83
|
+
requirements:
|
|
84
|
+
- - ">="
|
|
85
|
+
- !ruby/object:Gem::Version
|
|
86
|
+
version: '0'
|
|
87
|
+
requirements: []
|
|
88
|
+
rubygems_version: 3.6.9
|
|
89
|
+
specification_version: 4
|
|
90
|
+
summary: Ruby port of nanochat, a minimal LLM
|
|
91
|
+
test_files: []
|