nanochat 0.1.0.pre
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/LICENSE +25 -0
- data/README.md +129 -0
- data/bin/nanochat-setup +186 -0
- data/bin/package-checkpoint +122 -0
- data/bin/speedrun.sh +32 -0
- data/bin/train-tiny-model +190 -0
- data/bin/train-with-python-nanochat.sh +167 -0
- data/lib/nanochat/checkpoint_manager.rb +40 -0
- data/lib/nanochat/common.rb +32 -0
- data/lib/nanochat/config.rb +49 -0
- data/lib/nanochat/engine.rb +152 -0
- data/lib/nanochat/gpt.rb +285 -0
- data/lib/nanochat/tokenizer.rb +119 -0
- data/lib/nanochat/version.rb +5 -0
- data/lib/nanochat.rb +27 -0
- metadata +91 -0
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
#!/bin/bash

# Train a tiny nanochat checkpoint using Python nanochat
#
# This script is adapted from python-nanochat/dev/runcpu.sh
# Original: https://github.com/karpathy/nanochat by Andrej Karpathy
#
# REQUIREMENTS:
# - Python nanochat cloned at: ../python-nanochat or ./python-nanochat
# - Python 3.10+
# - Rust (for building rustbpe tokenizer; installed via rustup if cargo is missing)
# - ~1GB disk space for data
# - ~30 minutes on CPU or ~5 minutes on GPU
#
# USAGE:
# bash bin/train-with-python-nanochat.sh
#
# ENVIRONMENT:
# NANOCHAT_BASE_DIR  Output directory (default: ~/.cache/nanochat). It is
#                    exported so the Python nanochat scripts started below
#                    inherit it; the Ruby gem reads the same variable.
#
# OUTPUT (paths relative to NANOCHAT_BASE_DIR, default ~/.cache/nanochat):
# Trained checkpoint at: model.pt
# Tokenizer at: tokenizer/tokenizer.json

# Exit on error. pipefail additionally makes a failed download in
# `curl ... | sh` abort the script instead of piping an empty installer into sh.
set -eo pipefail

# Honour a caller-provided NANOCHAT_BASE_DIR and export it so child processes
# (the Python training scripts) see the same location.
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"

echo "🔥 Train Tiny Nanochat Model"
echo "======================================================================"
echo ""
echo "This will train a d4 model (4 layers, minimal for demos)"
echo "Output: $NANOCHAT_BASE_DIR/"
echo ""
echo "⏱️ Estimated time: ~30 minutes on CPU"
echo ""
echo "📝 Attribution: Using training scripts from"
echo " https://github.com/karpathy/nanochat by Andrej Karpathy"
echo ""
echo "======================================================================"
echo ""

# Find python-nanochat directory
if [ -d "python-nanochat" ]; then
    PYTHON_NANOCHAT_DIR="python-nanochat"
elif [ -d "../python-nanochat" ]; then
    PYTHON_NANOCHAT_DIR="../python-nanochat"
else
    echo "❌ Python nanochat not found"
    echo ""
    echo "Clone it first:"
    echo " git clone https://github.com/karpathy/nanochat python-nanochat"
    echo ""
    exit 1
fi

echo "✅ Python nanochat found at: $PYTHON_NANOCHAT_DIR"
echo ""

# Change to python-nanochat directory; every command below runs from there.
cd "$PYTHON_NANOCHAT_DIR"

# Setup environment
export OMP_NUM_THREADS=1
mkdir -p "$NANOCHAT_BASE_DIR"

echo "🔧 Setting up Python environment..."
if ! command -v uv &> /dev/null; then
    curl -LsSf https://astral.sh/uv/install.sh | sh
    # The installer places uv in ~/.local/bin (older releases: ~/.cargo/bin),
    # which this already-running shell does not have on PATH yet.
    export PATH="$HOME/.local/bin:$HOME/.cargo/bin:$PATH"
fi
[ -d ".venv" ] || uv venv
uv sync --extra cpu
source .venv/bin/activate
echo ""

echo "🦀 Building Rust tokenizer..."
# Pick up an existing rustup installation first; only run the rustup
# installer when cargo is still unavailable.
if [ -f "$HOME/.cargo/env" ]; then
    source "$HOME/.cargo/env"
fi
if ! command -v cargo &> /dev/null; then
    curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
    source "$HOME/.cargo/env"
fi
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
echo ""

echo "📦 Downloading evaluation bundle..."
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
    # -f: fail on HTTP errors instead of saving an error page as the zip.
    curl -fL -o eval_bundle.zip "$EVAL_BUNDLE_URL"
    # -o: overwrite leftovers from an interrupted earlier run instead of prompting.
    unzip -q -o eval_bundle.zip
    rm eval_bundle.zip
    mv eval_bundle "$NANOCHAT_BASE_DIR"
fi
echo ""

# Reset report
python -m nanochat.report reset

echo "📚 Downloading training data (~1GB)..."
python -m nanochat.dataset -n 4
echo ""

echo "🔤 Training tokenizer..."
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_eval
echo ""

echo "🚀 Training base model (50 iterations, ~30 mins)..."
python -m scripts.base_train \
    --depth=4 \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --total_batch_size=1024 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --core_metric_every=50 \
    --core_metric_max_per_task=12 \
    --sample_every=50 \
    --num_iterations=50
echo ""

echo "📊 Evaluating base model..."
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=16
echo ""

echo "🎯 Midtraining (100 iterations)..."
python -m scripts.mid_train \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --eval_every=50 \
    --eval_tokens=4096 \
    --total_batch_size=1024 \
    --num_iterations=100
echo ""

echo "💬 Supervised fine-tuning (100 iterations)..."
python -m scripts.chat_sft \
    --device_batch_size=1 \
    --target_examples_per_step=4 \
    --num_iterations=100 \
    --eval_steps=4 \
    --eval_metrics_max_problems=16
echo ""

echo "📝 Generating training report..."
python -m nanochat.report generate
echo ""

echo "======================================================================"
echo "✅ Training complete!"
echo ""
echo "📦 Checkpoint location:"
echo " Model: $NANOCHAT_BASE_DIR/model.pt"
echo " Tokenizer: $NANOCHAT_BASE_DIR/tokenizer/tokenizer.json"
echo ""
echo "🎯 Next Steps - Use Your Model in Ruby"
echo "======================================================================"
echo ""
echo "# Interactive chat"
echo "ruby examples/chat_cli.rb"
echo ""
echo "# Web UI (visit http://localhost:8000)"
echo "ruby examples/chat_web.rb"
echo ""
echo "# Generate text"
echo "ruby examples/generate_text.rb 'Once upon a time'"
echo ""
echo "# Fine-tune on your data"
echo "ruby examples/finetune.rb --data my_data.txt --output custom.pt"
echo ""
echo "======================================================================"
echo ""
echo "📦 Optional: Package this checkpoint for distribution"
echo ""
# Quote the printed paths so the suggested command survives spaces in them.
echo "tar -czf nanochat-tiny-d4.tar.gz -C \"$(dirname "$NANOCHAT_BASE_DIR")\" \"$(basename "$NANOCHAT_BASE_DIR")\""
echo ""
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Nanochat
  # Writes and reads model checkpoints via torch-rb's Torch.save / Torch.load.
  module CheckpointManager
    class << self
      # Deserialize the checkpoint stored at +path+.
      #
      # @param path [String] checkpoint file
      # @return [Object] whatever Torch.load returns for that file
      # @raise [ArgumentError] if nothing exists at +path+
      def load(path)
        unless File.exist?(path)
          raise ArgumentError, "Checkpoint not found: #{path}"
        end

        Torch.load(path)
      end

      # Serialize a checkpoint Hash to +path+, creating parent directories.
      #
      # Stored keys:
      # - 'model'     weights from +model.state_dict+ (when +model+ is given) or
      #               +state_dict+, with keys converted to Strings recursively
      # - 'config'    +config.to_h+ with String keys, or nil when no config
      # - one entry per extra keyword in +metadata+, key converted to a String
      # - 'optimizer' +optimizer.state_dict+, only when an optimizer is given
      #
      # @raise [ArgumentError] if neither model: nor state_dict: is supplied
      # @return the result of Torch.save
      def save(path, model: nil, state_dict: nil, optimizer: nil, config: nil, **metadata)
        if model.nil? && state_dict.nil?
          raise ArgumentError, 'Must provide either model: or state_dict:'
        end

        FileUtils.mkdir_p(File.dirname(path))

        weights =
          if model
            model.state_dict
          else
            state_dict
          end

        payload = {
          'model' => convert_keys_to_strings(weights),
          'config' => config&.to_h&.transform_keys(&:to_s)
        }
        metadata.each { |meta_key, meta_value| payload[meta_key.to_s] = meta_value }
        payload['optimizer'] = optimizer.state_dict if optimizer

        Torch.save(payload, path)
      end

      private

      # Return a copy of +hash+ whose keys (and the keys of any nested Hash
      # values) are Strings; non-Hash values are carried over unchanged.
      def convert_keys_to_strings(hash)
        hash.each_with_object({}) do |(entry_key, entry_value), stringified|
          stringified[entry_key.to_s] =
            entry_value.is_a?(Hash) ? convert_keys_to_strings(entry_value) : entry_value
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require 'fileutils'
|
|
4
|
+
|
|
5
|
+
module Nanochat
  # Shared helpers: torch device selection, RNG seeding and cache directories.
  module Common
    class << self
      # The torch device to run on, picked once and memoized:
      # CUDA when available, otherwise MPS (if this torch-rb build defines
      # Torch::Backends::MPS and reports it available), otherwise CPU.
      def device
        @device ||= Torch.device(preferred_device_name)
      end

      # Seed torch's global RNG, and all CUDA devices when CUDA is available.
      # Returns nil when CUDA is unavailable.
      def seed(seed_value)
        Torch.manual_seed(seed_value)
        return unless Torch::CUDA.available?

        Torch::CUDA.manual_seed_all(seed_value)
      end

      # Base directory for nanochat files: $NANOCHAT_BASE_DIR when set,
      # otherwise the expanded path of ~/.cache/nanochat.
      def default_cache_dir
        from_env = ENV['NANOCHAT_BASE_DIR']
        from_env.nil? ? File.expand_path('~/.cache/nanochat') : from_env
      end

      # Create +path+ (including parents) if it is not already a directory.
      # Returns +path+ so calls can be chained.
      def ensure_dir(path)
        File.directory?(path) || FileUtils.mkdir_p(path)
        path
      end

      private

      # Name of the best available backend, checked in CUDA > MPS > CPU order.
      def preferred_device_name
        return 'cuda' if Torch::CUDA.available?
        return 'mps' if defined?(Torch::Backends::MPS) && Torch::Backends::MPS.available?

        'cpu'
      end
    end
  end
end
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Nanochat
  # GPT model hyperparameters as an immutable Data value object.
  Config = Data.define(
    :vocab_size,
    :block_size, # context length (Python nanochat configs call this sequence_len)
    :n_embd,     # embedding dimension
    :n_head,     # query heads
    :n_kv_head,  # key/value heads (MQA)
    :n_layer     # transformer blocks
  ) do
    # Default hyperparameters: 12 layers, 768-dim embeddings, 6 query/kv heads,
    # 1024-token context, 50_304-entry vocabulary.
    def self.default
      new(
        vocab_size: 50_304,
        block_size: 1024,
        n_embd: 768,
        n_head: 6,
        n_kv_head: 6,
        n_layer: 12
      )
    end

    # Build a Config from a loaded checkpoint Hash.
    #
    # Reads checkpoint['config'] (falling back to checkpoint[:config]); the
    # config's keys may be Strings or Symbols. A Python-style sequence_len
    # entry is translated to block_size (an explicit block_size wins) and then
    # dropped, because Data.new rejects unknown keywords. The checkpoint Hash
    # is not modified.
    #
    # @param checkpoint [Hash]
    # @return [Config]
    # @raise [ArgumentError] if the checkpoint has no config Hash, or the
    #   config has missing or unknown fields
    def self.from_checkpoint(checkpoint)
      raw = checkpoint['config'] || checkpoint[:config]
      raise ArgumentError, 'Checkpoint has no config' unless raw.is_a?(Hash)

      # Symbolized copy, so neither the key style nor the caller's Hash matters.
      attrs = raw.transform_keys(&:to_sym)
      sequence_len = attrs.delete(:sequence_len)
      attrs[:block_size] ||= sequence_len unless sequence_len.nil?

      new(**attrs)
    end

    # Check that every dimension is positive and the head layout is consistent.
    #
    # @raise [ArgumentError] describing the first violated constraint
    def validate!
      raise ArgumentError, "vocab_size (#{vocab_size}) must be positive" unless vocab_size.positive?
      raise ArgumentError, "block_size (#{block_size}) must be positive" unless block_size.positive?
      raise ArgumentError, "n_embd (#{n_embd}) must be positive" unless n_embd.positive?
      raise ArgumentError, "n_head (#{n_head}) must be positive" unless n_head.positive?
      raise ArgumentError, "n_kv_head (#{n_kv_head}) must be positive" unless n_kv_head.positive?
      raise ArgumentError, "n_layer (#{n_layer}) must be positive" unless n_layer.positive?
      raise ArgumentError, "n_embd (#{n_embd}) must be divisible by n_head (#{n_head})" unless (n_embd % n_head).zero?

      # Implied by the divisibility check below, but gives a clearer message.
      unless n_kv_head <= n_head
        raise ArgumentError,
              "Invalid MQA: n_kv_head (#{n_kv_head}) must be <= n_head (#{n_head})"
      end
      return if (n_head % n_kv_head).zero?

      raise ArgumentError,
            "Invalid MQA: n_head (#{n_head}) must be divisible by n_kv_head (#{n_kv_head})"
    end
  end
end
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# Nanochat Engine: Efficient inference with KV caching
|
|
4
|
+
# Ruby port of nanochat by Andrej Karpathy (https://github.com/karpathy/nanochat)
|
|
5
|
+
|
|
6
|
+
module Nanochat
|
|
7
|
+
# KV cache for efficient inference.
#
# All layers share one tensor of shape
# [num_layers, 2, batch_size, num_heads, seq_len, head_dim], where index 0/1 on
# dim 1 holds keys/values. It is allocated lazily on the first insert so it
# takes the dtype and device of the incoming keys.
class KVCache
  # Number of time positions written so far. It advances only after the last
  # layer's insert for a step.
  attr_reader :pos

  def initialize(batch_size, num_heads, seq_len, head_dim, num_layers)
    @kv_shape = [num_layers, 2, batch_size, num_heads, seq_len, head_dim]
    @kv_cache = nil
    @pos = 0
  end

  # Rewind to position 0; the existing storage is kept and overwritten.
  def reset = @pos = 0

  # Store +key+/+value+ for +layer_idx+ at positions pos...pos+t_add, where
  # t_add is the size of dim 2 of +key+. Returns views of everything cached for
  # that layer so far (positions 0...pos+t_add).
  #
  # key/value are expected as 4-D [batch, heads, time, dim] tensors (as the
  # destructuring below implies). The time axis grows when an insert would
  # overflow it.
  def insert_kv(layer_idx, key, value)
    @kv_cache = Torch.empty(@kv_shape, dtype: key.dtype, device: key.device) if @kv_cache.nil?

    _batch, _heads, t_add, _dim = key.size
    t0 = @pos
    t1 = @pos + t_add

    grow_time_axis(t1) if t1 > @kv_cache.size(4)

    @kv_cache[layer_idx, 0, 0..-1, 0..-1, t0...t1, 0..-1] = key
    @kv_cache[layer_idx, 1, 0..-1, 0..-1, t0...t1, 0..-1] = value

    key_view = @kv_cache[layer_idx, 0, 0..-1, 0..-1, 0...t1, 0..-1]
    value_view = @kv_cache[layer_idx, 1, 0..-1, 0..-1, 0...t1, 0..-1]

    # Every layer writes the same positions in a step, so advance only once.
    @pos = t1 if layer_idx == @kv_cache.size(0) - 1

    [key_view, value_view]
  end

  private

  # Extend the time axis (dim 4) to hold at least +t1+ positions plus 1024 of
  # slack, rounded up to a multiple of 1024.
  #
  # Fresh storage is concatenated along dim 4, so every cached entry keeps its
  # index. Resizing the contiguous buffer in place when a non-last dimension
  # changes would reinterpret the existing memory and scramble it.
  def grow_time_axis(t1)
    t_needed = (t1 + 1024 + 1023) & ~1023
    extra_shape = @kv_shape.dup
    extra_shape[4] = t_needed - @kv_cache.size(4)
    extra = Torch.empty(extra_shape, dtype: @kv_cache.dtype, device: @kv_cache.device)
    @kv_cache = Torch.cat([@kv_cache, extra], dim: 4).contiguous

    @kv_shape = @kv_shape.dup
    @kv_shape[4] = t_needed
  end
end
|
|
46
|
+
|
|
47
|
+
# Text generation engine: runs the model autoregressively with a KV cache.
class Engine
  # @param model  must respond to #config, #to, #eval and #call(input_ids, kv_cache:)
  # @param tokenizer  must respond to #encode, #decode and #eos_token_id
  # @param device  torch device; defaults to Common.device
  def initialize(model:, tokenizer:, device: nil)
    @model = model
    @tokenizer = tokenizer
    @device = device || Common.device
    @model.to(@device)
    @model.eval
  end

  # Generate a completion and return it as a single String.
  #
  # The generated ids are decoded in one call rather than joining per-token
  # decodes. With byte-level BPE tokenizers a single token may decode to an
  # incomplete UTF-8 sequence, so decoding together keeps such characters intact.
  # Returns an empty (unfrozen) String when nothing is generated.
  def generate(prompt, max_tokens: 100, temperature: 1.0, top_k: nil, top_p: nil)
    ids = generate_stream(prompt, max_tokens:, temperature:, top_k:, top_p:)
    return +'' if ids.nil? || ids.empty?

    @tokenizer.decode(ids)
  end

  # Generate text with streaming. Yields token_text (String), token_id (Integer).
  # Accepts string prompts or token arrays.
  #
  # Each yielded token_text is that single token decoded on its own. Generation
  # stops after +max_tokens+ tokens or when the tokenizer's EOS id is sampled;
  # EOS is neither yielded nor returned.
  #
  # @return [Array<Integer>, nil] generated ids; nil when the prompt is empty
  def generate_stream(prompt, max_tokens: 100, temperature: 1.0, top_k: nil, top_p: nil)
    tokens = prompt.is_a?(Array) ? prompt : @tokenizer.encode(prompt)
    return if tokens.empty?

    config = @model.config
    # Capacity = prompt length + max_tokens. The loop feeds at most
    # tokens.length + max_tokens - 1 positions, so no growth should be needed
    # (assuming the model inserts one entry per input position).
    kv_cache = KVCache.new(
      1,                             # batch size
      config.n_kv_head,
      tokens.length + max_tokens,
      config.n_embd / config.n_head, # head dimension
      config.n_layer
    )

    input_ids = Torch.tensor([tokens], dtype: :long).to(@device)
    generated_tokens = []

    Torch.no_grad do
      max_tokens.times do
        # The first call feeds the whole prompt; later calls feed only the
        # newest token, and the cache supplies earlier positions.
        logits = @model.call(input_ids, kv_cache:)
        # Last time step; logits presumably indexed [batch, time, vocab].
        next_token_logits = logits[0..-1, -1, 0..-1]

        next_token = sample(next_token_logits, temperature, top_k, top_p)
        token_id = next_token[0, 0].item

        break if token_id == @tokenizer.eos_token_id

        token_text = @tokenizer.decode([token_id])
        yield(token_text, token_id) if block_given?

        input_ids = next_token.view(1, 1)
        generated_tokens << token_id
      end
    end

    generated_tokens
  end

  private

  # Choose the next token from 2-D +logits+ (presumably [batch, vocab]).
  # Returns a [batch, 1] tensor of token ids.
  #
  # - temperature 0: greedy argmax.
  # - top_k: keeps only the k highest logits as candidates.
  # - top_p (< 1.0): among the candidates (after top_k, if given), keeps the
  #   most likely ones whose cumulative probability does not exceed top_p;
  #   the single most likely candidate is always kept.
  # The remaining probabilities, computed at the given temperature, are sampled once.
  def sample(logits, temperature, top_k, top_p)
    return logits.argmax(-1, keepdim: true) if temperature.zero?

    candidates = logits
    candidate_ids = nil # maps candidate columns back to vocabulary ids under top_k
    if top_k
      k = [top_k, logits.size(-1)].min
      candidates, candidate_ids = Torch.topk(logits, k, dim: -1)
    end

    probs = Torch::NN::F.softmax(candidates / temperature, dim: -1)

    choice =
      if top_p && top_p < 1.0
        nucleus_choice(probs, top_p)
      else
        Torch.multinomial(probs, num_samples: 1)
      end

    candidate_ids ? candidate_ids.gather(1, choice) : choice
  end

  # Top-p (nucleus) sampling over +probs+. Returns sampled column indices into
  # +probs+ as a [batch, 1] tensor.
  def nucleus_choice(probs, top_p)
    # Sort probabilities in descending order and accumulate them.
    sorted_probs, sorted_indices = probs.sort(dim: -1, descending: true)
    cumulative_probs = sorted_probs.cumsum(dim: -1)

    # Drop entries whose cumulative probability exceeds top_p, but always keep
    # the most likely one.
    sorted_indices_to_remove = Torch.gt(cumulative_probs, top_p)
    sorted_indices_to_remove[0..-1, 0] = false
    sorted_probs[sorted_indices_to_remove] = 0.0

    # Renormalize, sample, then map back to column indices of +probs+.
    sorted_probs /= sorted_probs.sum(dim: -1, keepdim: true)
    choice = Torch.multinomial(sorted_probs, num_samples: 1)
    sorted_indices.gather(1, choice)
  end
end
|
|
152
|
+
end
|