nanochat 0.1.0.pre

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,167 @@
+ #!/bin/bash
+
+ # Train a tiny nanochat checkpoint using Python nanochat
+ #
+ # This script is adapted from python-nanochat/dev/runcpu.sh
+ # Original: https://github.com/karpathy/nanochat by Andrej Karpathy
+ #
+ # REQUIREMENTS:
+ # - Python nanochat cloned at: ../python-nanochat or ./python-nanochat
+ # - Python 3.10+
+ # - Rust (for building rustbpe tokenizer)
+ # - ~1GB disk space for data
+ # - ~30 minutes on CPU or ~5 minutes on GPU
+ #
+ # USAGE:
+ # bash bin/train-with-python-nanochat.sh
+ #
+ # OUTPUT:
+ # Trained checkpoint at: ~/.cache/nanochat/model.pt
+ # Tokenizer at: ~/.cache/nanochat/tokenizer/tokenizer.json
+
+ set -e # Exit on error
+
+ echo "🔥 Train Tiny Nanochat Model"
+ echo "======================================================================"
+ echo ""
+ echo "This will train a d4 model (4 layers, minimal for demos)"
+ echo "Output: ~/.cache/nanochat/"
+ echo ""
+ echo "⏱️ Estimated time: ~30 minutes on CPU"
+ echo ""
+ echo "📝 Attribution: Using training scripts from"
+ echo " https://github.com/karpathy/nanochat by Andrej Karpathy"
+ echo ""
+ echo "======================================================================"
+ echo ""
+
+ # Find python-nanochat directory
+ if [ -d "python-nanochat" ]; then
+     PYTHON_NANOCHAT_DIR="python-nanochat"
+ elif [ -d "../python-nanochat" ]; then
+     PYTHON_NANOCHAT_DIR="../python-nanochat"
+ else
+     echo "❌ Python nanochat not found"
+     echo ""
+     echo "Clone it first:"
+     echo " git clone https://github.com/karpathy/nanochat python-nanochat"
+     echo ""
+     exit 1
+ fi
+
+ echo "✅ Python nanochat found at: $PYTHON_NANOCHAT_DIR"
+ echo ""
+
+ # Change to python-nanochat directory
+ cd "$PYTHON_NANOCHAT_DIR"
+
+ # Setup environment
+ export OMP_NUM_THREADS=1
+ NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
+ mkdir -p "$NANOCHAT_BASE_DIR"
+
+ echo "🔧 Setting up Python environment..."
+ command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
+ [ -d ".venv" ] || uv venv
+ uv sync --extra cpu
+ source .venv/bin/activate
+ echo ""
+
+ echo "🦀 Building Rust tokenizer..."
+ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+ source "$HOME/.cargo/env"
+ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
+ echo ""
+
+ echo "📦 Downloading evaluation bundle..."
+ EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
+ if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
+     curl -L -o eval_bundle.zip "$EVAL_BUNDLE_URL"
+     unzip -q eval_bundle.zip
+     rm eval_bundle.zip
+     mv eval_bundle "$NANOCHAT_BASE_DIR"
+ fi
+ echo ""
+
+ # Reset report
+ python -m nanochat.report reset
+
+ echo "📚 Downloading training data (~1GB)..."
+ python -m nanochat.dataset -n 4
+ echo ""
+
+ echo "🔤 Training tokenizer..."
+ python -m scripts.tok_train --max_chars=1000000000
+ python -m scripts.tok_eval
+ echo ""
+
+ echo "🚀 Training base model (50 iterations, ~30 mins)..."
+ python -m scripts.base_train \
+     --depth=4 \
+     --max_seq_len=1024 \
+     --device_batch_size=1 \
+     --total_batch_size=1024 \
+     --eval_every=50 \
+     --eval_tokens=4096 \
+     --core_metric_every=50 \
+     --core_metric_max_per_task=12 \
+     --sample_every=50 \
+     --num_iterations=50
+ echo ""
+
+ echo "📊 Evaluating base model..."
+ python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
+ python -m scripts.base_eval --max-per-task=16
+ echo ""
+
+ echo "🎯 Midtraining (100 iterations)..."
+ python -m scripts.mid_train \
+     --max_seq_len=1024 \
+     --device_batch_size=1 \
+     --eval_every=50 \
+     --eval_tokens=4096 \
+     --total_batch_size=1024 \
+     --num_iterations=100
+ echo ""
+
+ echo "💬 Supervised fine-tuning (100 iterations)..."
+ python -m scripts.chat_sft \
+     --device_batch_size=1 \
+     --target_examples_per_step=4 \
+     --num_iterations=100 \
+     --eval_steps=4 \
+     --eval_metrics_max_problems=16
+ echo ""
+
+ echo "📝 Generating training report..."
+ python -m nanochat.report generate
+ echo ""
+
+ echo "======================================================================"
+ echo "✅ Training complete!"
+ echo ""
+ echo "📦 Checkpoint location:"
+ echo " Model: $NANOCHAT_BASE_DIR/model.pt"
+ echo " Tokenizer: $NANOCHAT_BASE_DIR/tokenizer/tokenizer.json"
+ echo ""
+ echo "🎯 Next Steps - Use Your Model in Ruby"
+ echo "======================================================================"
+ echo ""
+ echo "# Interactive chat"
+ echo "ruby examples/chat_cli.rb"
+ echo ""
+ echo "# Web UI (visit http://localhost:8000)"
+ echo "ruby examples/chat_web.rb"
+ echo ""
+ echo "# Generate text"
+ echo "ruby examples/generate_text.rb 'Once upon a time'"
+ echo ""
+ echo "# Fine-tune on your data"
+ echo "ruby examples/finetune.rb --data my_data.txt --output custom.pt"
+ echo ""
+ echo "======================================================================"
+ echo ""
+ echo "📦 Optional: Package this checkpoint for distribution"
+ echo ""
+ echo "tar -czf nanochat-tiny-d4.tar.gz -C $(dirname $NANOCHAT_BASE_DIR) $(basename $NANOCHAT_BASE_DIR)"
+ echo ""
@@ -0,0 +1,40 @@
+ # frozen_string_literal: true
+
+ require 'fileutils'
+
+ module Nanochat
+   # Checkpoint loading and saving
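+   #
+   # Example usage (illustrative sketch; the paths here are assumptions):
+   #   checkpoint = Nanochat::CheckpointManager.load(File.expand_path('~/.cache/nanochat/model.pt'))
+   #   Nanochat::CheckpointManager.save('tmp/model.pt', model: model, config: config, step: 1000)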
+   module CheckpointManager
+     class << self
+       def load(path)
+         raise ArgumentError, "Checkpoint not found: #{path}" unless File.exist?(path)
+
+         Torch.load(path)
+       end
+
+       def save(path, model: nil, state_dict: nil, optimizer: nil, config: nil, **metadata)
+         raise ArgumentError, 'Must provide either model: or state_dict:' if model.nil? && state_dict.nil?
+
+         FileUtils.mkdir_p(File.dirname(path))
+
+         model_dict = model ? model.state_dict : state_dict
+         model_dict = convert_keys_to_strings(model_dict)
+
+         data = {
+           'model' => model_dict,
+           'config' => config&.to_h&.transform_keys(&:to_s),
+           **metadata.transform_keys(&:to_s)
+         }
+         data['optimizer'] = optimizer.state_dict if optimizer
+
+         Torch.save(data, path)
+       end
+
+       private
+
+       def convert_keys_to_strings(hash)
+         hash.transform_keys(&:to_s).transform_values do |value|
+           value.is_a?(Hash) ? convert_keys_to_strings(value) : value
+         end
+       end
+     end
+   end
+ end
@@ -0,0 +1,32 @@
+ # frozen_string_literal: true
+
+ require 'fileutils'
+
+ module Nanochat
+   # Common utilities
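+   #
+   # Example usage (illustrative):
+   #   Nanochat::Common.seed(42)           # reproducible runs
+   #   device = Nanochat::Common.device    # picks CUDA, then MPS, then CPU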
+   module Common
+     class << self
+       def device
+         @device ||= if Torch::CUDA.available?
+                       Torch.device('cuda')
+                     elsif defined?(Torch::Backends::MPS) && Torch::Backends::MPS.available?
+                       Torch.device('mps')
+                     else
+                       Torch.device('cpu')
+                     end
+       end
+
+       def seed(seed_value)
+         Torch.manual_seed(seed_value)
+         Torch::CUDA.manual_seed_all(seed_value) if Torch::CUDA.available?
+       end
+
+       def default_cache_dir = ENV.fetch('NANOCHAT_BASE_DIR') { File.expand_path('~/.cache/nanochat') }
+
+       def ensure_dir(path)
+         FileUtils.mkdir_p(path) unless File.directory?(path)
+         path
+       end
+     end
+   end
+ end
@@ -0,0 +1,49 @@
+ # frozen_string_literal: true
+
+ module Nanochat
+   # GPT model configuration
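+   #
+   # Example usage (illustrative; the tiny values below are arbitrary):
+   #   config = Nanochat::Config.new(vocab_size: 65_536, block_size: 1024,
+   #                                 n_embd: 256, n_head: 4, n_kv_head: 4, n_layer: 4)
+   #   config.validate!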
+   Config = Data.define(
+     :vocab_size,
+     :block_size,  # context length
+     :n_embd,      # embedding dimension
+     :n_head,      # query heads
+     :n_kv_head,   # key/value heads (MQA)
+     :n_layer      # transformer blocks
+   ) do
+     def self.default
+       new(
+         vocab_size: 50_304,
+         block_size: 1024,
+         n_embd: 768,
+         n_head: 6,
+         n_kv_head: 6,
+         n_layer: 12
+       )
+     end
+
+     def self.from_checkpoint(checkpoint)
+       config_dict = checkpoint['config'] || checkpoint[:config]
+       config_dict = config_dict.to_h.transform_keys(&:to_sym)
+       # Python nanochat stores the context length under sequence_len
+       config_dict[:block_size] ||= config_dict[:sequence_len]
+       # Ignore any extra keys the checkpoint config may carry
+       new(**config_dict.slice(*members))
+     end
+
+     def validate!
+       raise ArgumentError, "vocab_size (#{vocab_size}) must be positive" unless vocab_size.positive?
+       raise ArgumentError, "block_size (#{block_size}) must be positive" unless block_size.positive?
+       raise ArgumentError, "n_embd (#{n_embd}) must be positive" unless n_embd.positive?
+       raise ArgumentError, "n_head (#{n_head}) must be positive" unless n_head.positive?
+       raise ArgumentError, "n_kv_head (#{n_kv_head}) must be positive" unless n_kv_head.positive?
+       raise ArgumentError, "n_layer (#{n_layer}) must be positive" unless n_layer.positive?
+       raise ArgumentError, "n_embd (#{n_embd}) must be divisible by n_head (#{n_head})" unless (n_embd % n_head).zero?
+
+       unless n_kv_head <= n_head
+         raise ArgumentError,
+               "Invalid MQA: n_kv_head (#{n_kv_head}) must be <= n_head (#{n_head})"
+       end
+       return if (n_head % n_kv_head).zero?
+
+       raise ArgumentError,
+             "Invalid MQA: n_head (#{n_head}) must be divisible by n_kv_head (#{n_kv_head})"
+     end
+   end
+ end
@@ -0,0 +1,152 @@
+ # frozen_string_literal: true
+
+ # Nanochat Engine: Efficient inference with KV caching
+ # Ruby port of nanochat by Andrej Karpathy (https://github.com/karpathy/nanochat)
+
+ module Nanochat
+   # KV cache for efficient inference
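+   #
+   # Shape of the backing tensor (allocated lazily on the first insert):
+   #   [num_layers, 2 (key/value), batch_size, num_heads, seq_len, head_dim]
+   # insert_kv appends one step's keys/values for a layer and returns views of
+   # everything cached so far; @pos advances once the last layer has been written.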
+   class KVCache
+     attr_reader :pos
+
+     def initialize(batch_size, num_heads, seq_len, head_dim, num_layers)
+       @kv_shape = [num_layers, 2, batch_size, num_heads, seq_len, head_dim]
+       @kv_cache = nil
+       @pos = 0
+     end
+
+     def reset = @pos = 0
+
+     def insert_kv(layer_idx, key, value)
+       @kv_cache = Torch.empty(@kv_shape, dtype: key.dtype, device: key.device) if @kv_cache.nil?
+
+       _batch, _heads, t_add, _dim = key.size
+       t0 = @pos
+       t1 = @pos + t_add
+
+       if t1 > @kv_cache.size(4)
+         # Grow the cache: round the required length up to a multiple of 1024,
+         # allocate a larger tensor, and copy the existing entries across
+         # (an in-place resize would scramble the cached keys/values).
+         t_needed = t1 + 1024
+         t_needed = (t_needed + 1023) & ~1023
+         new_shape = @kv_shape.dup
+         new_shape[4] = t_needed
+         new_cache = Torch.empty(new_shape, dtype: @kv_cache.dtype, device: @kv_cache.device)
+         new_cache[0..-1, 0..-1, 0..-1, 0..-1, 0...@kv_cache.size(4), 0..-1] = @kv_cache
+         @kv_cache = new_cache
+         @kv_shape = new_shape
+       end
+
+       @kv_cache[layer_idx, 0, 0..-1, 0..-1, t0...t1, 0..-1] = key
+       @kv_cache[layer_idx, 1, 0..-1, 0..-1, t0...t1, 0..-1] = value
+
+       key_view = @kv_cache[layer_idx, 0, 0..-1, 0..-1, 0...t1, 0..-1]
+       value_view = @kv_cache[layer_idx, 1, 0..-1, 0..-1, 0...t1, 0..-1]
+
+       @pos = t1 if layer_idx == @kv_cache.size(0) - 1
+
+       [key_view, value_view]
+     end
+   end
+
+   # Text generation engine
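+   #
+   # Example usage (illustrative sketch; assumes a model and tokenizer are already loaded):
+   #   engine = Nanochat::Engine.new(model: model, tokenizer: tokenizer)
+   #   puts engine.generate('Once upon a time', max_tokens: 50, temperature: 0.8, top_k: 50)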
+   class Engine
+     def initialize(model:, tokenizer:, device: nil)
+       @model = model
+       @tokenizer = tokenizer
+       @device = device || Common.device
+       @model.to(@device)
+       @model.eval
+     end
+
+     def generate(prompt, max_tokens: 100, temperature: 1.0, top_k: nil, top_p: nil)
+       tokens = []
+       generate_stream(prompt, max_tokens:, temperature:, top_k:, top_p:) do |token_text, _token_id|
+         tokens << token_text
+       end
+       tokens.join
+     end
+
+     # Generate text with streaming. Yields token_text (String), token_id (Integer).
+     # Accepts string prompts or token arrays.
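+     # Returns the generated token ids as an Array.
+     #
+     # Example (illustrative):
+     #   engine.generate_stream('Hello', max_tokens: 20) { |text, _id| print text }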
+     def generate_stream(prompt, max_tokens: 100, temperature: 1.0, top_k: nil, top_p: nil)
+       tokens = prompt.is_a?(Array) ? prompt : @tokenizer.encode(prompt)
+       return if tokens.empty?
+
+       config = @model.config
+       kv_cache = KVCache.new(
+         1,
+         config.n_kv_head,
+         tokens.length + max_tokens,
+         config.n_embd / config.n_head,
+         config.n_layer
+       )
+
+       input_ids = Torch.tensor([tokens], dtype: :long).to(@device)
+       generated_tokens = []
+
+       Torch.no_grad do
+         max_tokens.times do
+           logits = @model.call(input_ids, kv_cache:)
+           next_token_logits = logits[0..-1, -1, 0..-1]
+
+           next_token = sample(next_token_logits, temperature, top_k, top_p)
+           token_id = next_token[0, 0].item
+
+           break if token_id == @tokenizer.eos_token_id
+
+           token_text = @tokenizer.decode([token_id])
+           yield(token_text, token_id) if block_given?
+
+           input_ids = next_token.view(1, 1)
+           generated_tokens << token_id
+         end
+       end
+
+       generated_tokens
+     end
+
+     private
+
+     def sample(logits, temperature, top_k, top_p)
+       return logits.argmax(-1, keepdim: true) if temperature.zero?
+
+       if top_k
+         k = [top_k, logits.size(-1)].min
+         vals, idx = Torch.topk(logits, k, dim: -1)
+         vals /= temperature
+         probs = Torch::NN::F.softmax(vals, dim: -1)
+         choice = Torch.multinomial(probs, num_samples: 1)
+         return idx.gather(1, choice)
+       end
+
+       # Top-p (nucleus) sampling
+       if top_p && top_p < 1.0
+         scaled_logits = logits / temperature
+         probs = Torch::NN::F.softmax(scaled_logits, dim: -1)
+
+         # Sort probabilities in descending order
+         sorted_probs, sorted_indices = probs.sort(dim: -1, descending: true)
+
+         # Compute cumulative probabilities
+         cumulative_probs = sorted_probs.cumsum(dim: -1)
+
+         # Remove tokens with cumulative probability above threshold
+         # Keep at least one token (the highest probability one)
+         sorted_indices_to_remove = Torch.gt(cumulative_probs, top_p)
+         sorted_indices_to_remove[0..-1, 0] = false
+
+         # Zero out probabilities for removed tokens
+         sorted_probs[sorted_indices_to_remove] = 0.0
+
+         # Renormalize probabilities
+         sorted_probs /= sorted_probs.sum(dim: -1, keepdim: true)
+
+         # Sample from filtered distribution
+         choice = Torch.multinomial(sorted_probs, num_samples: 1)
+
+         # Map back to original vocabulary indices
+         return sorted_indices.gather(1, choice)
+       end
+
+       scaled_logits = logits / temperature
+       probs = Torch::NN::F.softmax(scaled_logits, dim: -1)
+       Torch.multinomial(probs, num_samples: 1)
+     end
+   end
+ end