ignis-dl 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis-dl.rb +48 -0
- data/lib/nnw/ai/gpt2_loader.rb +144 -0
- data/lib/nnw/ai/inference.rb +224 -0
- data/lib/nnw/ai/kv_cache.rb +79 -0
- data/lib/nnw/ai/llama_loader.rb +100 -0
- data/lib/nnw/ai/loss.rb +170 -0
- data/lib/nnw/ai/nn/dropout.rb +68 -0
- data/lib/nnw/ai/nn/embedding.rb +86 -0
- data/lib/nnw/ai/nn/layer_norm.rb +54 -0
- data/lib/nnw/ai/nn/linear.rb +80 -0
- data/lib/nnw/ai/nn/module.rb +178 -0
- data/lib/nnw/ai/nn/rms_norm.rb +43 -0
- data/lib/nnw/ai/nn/sequential.rb +52 -0
- data/lib/nnw/ai/optim/adam.rb +63 -0
- data/lib/nnw/ai/optim/adamw.rb +63 -0
- data/lib/nnw/ai/optim/base.rb +90 -0
- data/lib/nnw/ai/optim/lr_scheduler.rb +118 -0
- data/lib/nnw/ai/optim/sgd.rb +49 -0
- data/lib/nnw/ai/safetensors.rb +220 -0
- data/lib/nnw/ai/server.rb +268 -0
- data/lib/nnw/ai/tokenizer.rb +413 -0
- data/lib/nnw/ai/trainer.rb +245 -0
- data/lib/nnw/ai/transformer/attention.rb +89 -0
- data/lib/nnw/ai/transformer/block.rb +90 -0
- data/lib/nnw/ai/transformer/feed_forward.rb +53 -0
- data/lib/nnw/ai/transformer/model.rb +189 -0
- data/lib/nnw/ai/transformer/modern.rb +191 -0
- data/lib/nnw/ai/transformer/swiglu.rb +39 -0
- data/lib/nnw/ai/weight_map.rb +139 -0
- metadata +91 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA256:
|
|
3
|
+
metadata.gz: 2e2b7a2b0882360bf9c07ca60b72aa71c3a43dbb9739aa9dde31182e8e9068ca
|
|
4
|
+
data.tar.gz: a34343e8f5e2334a79b8a5b6dc935891d8776ffe5b64f47832dccb2f87ce9fcc
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: 525582d82f78c61c0fedf0a31273f6010453c5f3e1753ff2908b6ba402f03de909aa6a7eb1dc12c0feb773e7a65c519f6be11a4dc44db6307f5aed4fe1baa65d
|
|
7
|
+
data.tar.gz: f7c0f62d4cdcb6ddcf8daf3af78a1d8580663f91dab1ef6ac8219c9b2ce445e6858368b9782ce779b0710ca403579c58505ff2079297599e6d454b3cad5c2d0e
|
data/README.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# ignis-dl
|
|
2
|
+
|
|
3
|
+
Neural-net layers + transformers for Ruby, built on the [Ignis](https://github.com/tigel-agm/Ignis) GPU/autograd stack. Installing this gem pulls in the whole dependency chain (`ignis` + `ignis-autograd`).
|
|
4
|
+
|
|
5
|
+
- **NN:** Linear, Embedding, LayerNorm, RMSNorm, Dropout; SGD/Adam/AdamW; losses.
|
|
6
|
+
- **Transformers:** multi-head + grouped-query attention, RoPE, SwiGLU, KV cache.
|
|
7
|
+
- **Loaders:** real GPT-2 and Llama (HuggingFace `safetensors`) — matches HF logits.
|
|
8
|
+
|
|
9
|
+
```ruby
|
|
10
|
+
require "ignis-dl"
|
|
11
|
+
model = Ignis::AI::LlamaLoader.from_pretrained("path/to/Llama-3.2-1B")
|
|
12
|
+
# ... matches HuggingFace; "The capital of France is" => " Paris"
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
Loads real GPT-2 and Llama-3.2 checkpoints (logits reach cosine similarity 1.0 against HuggingFace) and trains transformers from scratch. License: MIT.
|
data/lib/ignis-dl.rb
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
# ignis-dl — neural-net layers + transformers for Ruby, on the Ignis GPU/autograd
|
|
4
|
+
# stack. Loads real GPT-2 and Llama checkpoints (matching HuggingFace) and trains
|
|
5
|
+
# transformers from scratch. Requiring this pulls the whole chain (ignis +
|
|
6
|
+
# ignis-autograd), so it doubles as the meta-gem.
|
|
7
|
+
|
|
8
|
+
require "ignis-autograd"
|
|
9
|
+
|
|
10
|
+
# --- NN modules ---
|
|
11
|
+
require_relative "nnw/ai/nn/module"
|
|
12
|
+
require_relative "nnw/ai/nn/linear"
|
|
13
|
+
require_relative "nnw/ai/nn/embedding"
|
|
14
|
+
require_relative "nnw/ai/nn/layer_norm"
|
|
15
|
+
require_relative "nnw/ai/nn/rms_norm"
|
|
16
|
+
require_relative "nnw/ai/nn/dropout"
|
|
17
|
+
require_relative "nnw/ai/nn/sequential"
|
|
18
|
+
|
|
19
|
+
# --- Transformer ---
|
|
20
|
+
require_relative "nnw/ai/transformer/attention"
|
|
21
|
+
require_relative "nnw/ai/transformer/feed_forward"
|
|
22
|
+
require_relative "nnw/ai/transformer/swiglu"
|
|
23
|
+
require_relative "nnw/ai/transformer/block"
|
|
24
|
+
require_relative "nnw/ai/transformer/model"
|
|
25
|
+
require_relative "nnw/ai/transformer/modern"
|
|
26
|
+
require_relative "nnw/ai/kv_cache"
|
|
27
|
+
|
|
28
|
+
# --- Loss ---
|
|
29
|
+
require_relative "nnw/ai/loss"
|
|
30
|
+
|
|
31
|
+
# --- Optimizers ---
|
|
32
|
+
require_relative "nnw/ai/optim/base"
|
|
33
|
+
require_relative "nnw/ai/optim/sgd"
|
|
34
|
+
require_relative "nnw/ai/optim/adam"
|
|
35
|
+
require_relative "nnw/ai/optim/adamw"
|
|
36
|
+
require_relative "nnw/ai/optim/lr_scheduler"
|
|
37
|
+
|
|
38
|
+
# --- Weight loading + tokenizer ---
|
|
39
|
+
require_relative "nnw/ai/safetensors"
|
|
40
|
+
require_relative "nnw/ai/weight_map"
|
|
41
|
+
require_relative "nnw/ai/gpt2_loader"
|
|
42
|
+
require_relative "nnw/ai/llama_loader"
|
|
43
|
+
require_relative "nnw/ai/tokenizer"
|
|
44
|
+
|
|
45
|
+
# --- Inference + training ---
|
|
46
|
+
require_relative "nnw/ai/inference"
|
|
47
|
+
require_relative "nnw/ai/trainer"
|
|
48
|
+
require_relative "nnw/ai/server"
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module Ignis
|
|
6
|
+
module AI
|
|
7
|
+
# Loads HuggingFace GPT-2 safetensors checkpoints into an Ignis Transformer::Model.
#
# GPT-2 is NOT a drop-in for nn.Linear: its attention/MLP projections are
# `Conv1D` layers whose weights are stored as [in, out] (transposed vs.
# nn.Linear's [out, in]), the QKV projection is a single fused [embed, 3*embed]
# tensor, and the LM head is tied to the token embedding. This loader applies
# those transforms while copying weights in.
#
# Only F32 checkpoints are supported: tensor bytes are decoded as little-endian
# float32, and any other safetensors dtype is rejected with an explicit error.
#
# @example
#   model = Ignis::AI::Transformer::Model.gpt2_small
#   Ignis::AI::GPT2Loader.load(model, "gpt2")
#   model.eval!
module GPT2Loader
  module_function

  # Load weights from <dir>/model.safetensors into +model+ (in place).
  # @param model [Ignis::AI::Transformer::Model]
  # @param dir [String] directory holding model.safetensors
  # @return [model]
  # @raise [ArgumentError] if model.safetensors is missing or its header is truncated
  # @raise [RuntimeError] on a missing tensor/parameter, a non-F32 tensor, or a size mismatch
  def load(model, dir)
    path = File.join(dir, "model.safetensors")
    raise ArgumentError, "not found: #{path}" unless File.exist?(path)

    header, data_offset = read_header(path)
    params = model.named_parameters
    embed = model.embed_dim
    n_layer = model.num_layers

    File.open(path, "rb") do |f|
      # Fetch one tensor as { values: flat row-major Array<Float>, shape: Array<Integer> }.
      get = lambda do |name|
        meta = header.fetch(name) { raise "missing safetensors tensor: #{name}" }
        dtype = meta["dtype"]
        # Bytes are decoded as little-endian float32 ("e*"); reject other dtypes
        # explicitly instead of reinterpreting F16/BF16 data as garbage floats.
        unless dtype == "F32"
          raise "unsupported safetensors dtype #{dtype.inspect} for #{name} (GPT2Loader expects F32)"
        end

        b, e = meta["data_offsets"]
        f.seek(data_offset + b)
        { values: f.read(e - b).to_s.unpack("e*"), shape: meta["shape"] }
      end

      # Copy a flat Array<Float> into the named model parameter, checking element count.
      set = lambda do |pname, values|
        param = params[pname]
        raise "missing model param: #{pname}" unless param
        unless param.numel == values.length
          raise "size mismatch for #{pname}: param has #{param.numel}, data has #{values.length}"
        end

        param.data.from_host(values)
      end

      # Embeddings map across directly. wte is read once and reused for the tied head.
      wte = get.call("wte.weight")[:values]
      set.call("token_embedding.weight", wte)
      set.call("position_embedding.weight", get.call("wpe.weight")[:values])

      n_layer.times do |i|
        src = "h.#{i}"
        dst = "blocks.#{i}"

        set.call("#{dst}.norm1.weight", get.call("#{src}.ln_1.weight")[:values])
        set.call("#{dst}.norm1.bias", get.call("#{src}.ln_1.bias")[:values])
        set.call("#{dst}.norm2.weight", get.call("#{src}.ln_2.weight")[:values])
        set.call("#{dst}.norm2.bias", get.call("#{src}.ln_2.bias")[:values])

        # Fused QKV: c_attn.weight is Conv1D [embed, 3*embed]. Split columns into
        # q|k|v and transpose each into nn.Linear layout [out, in] = [embed, embed].
        cattn_w = get.call("#{src}.attn.c_attn.weight")[:values]
        cattn_b = get.call("#{src}.attn.c_attn.bias")[:values]
        %w[q k v].each_with_index do |proj, j|
          set.call("#{dst}.attention.#{proj}_proj.weight",
                   transpose_slice_cols(cattn_w, embed, 3 * embed, j * embed, embed))
          set.call("#{dst}.attention.#{proj}_proj.bias", cattn_b[(j * embed)...((j + 1) * embed)])
        end

        # Attention output projection: Conv1D [embed, embed] -> transpose.
        aproj = get.call("#{src}.attn.c_proj.weight")[:values]
        set.call("#{dst}.attention.out_proj.weight", transpose(aproj, embed, embed))
        set.call("#{dst}.attention.out_proj.bias", get.call("#{src}.attn.c_proj.bias")[:values])

        # MLP fc1: c_fc Conv1D [embed, ff] -> transpose -> [ff, embed].
        fc = get.call("#{src}.mlp.c_fc.weight")
        ff = fc[:shape][1]
        set.call("#{dst}.feed_forward.fc1.weight", transpose(fc[:values], embed, ff))
        set.call("#{dst}.feed_forward.fc1.bias", get.call("#{src}.mlp.c_fc.bias")[:values])

        # MLP fc2: c_proj Conv1D [ff, embed] -> transpose -> [embed, ff].
        fp = get.call("#{src}.mlp.c_proj.weight")[:values]
        set.call("#{dst}.feed_forward.fc2.weight", transpose(fp, ff, embed))
        set.call("#{dst}.feed_forward.fc2.bias", get.call("#{src}.mlp.c_proj.bias")[:values])
      end

      set.call("norm.weight", get.call("ln_f.weight")[:values])
      set.call("norm.bias", get.call("ln_f.bias")[:values])

      # LM head is tied to the token embedding (logits = x @ wte^T). Pass a copy so
      # the two parameters are never handed the same host Array.
      set.call("head.weight", wte.dup)
    end

    model
  end

  # Read the safetensors JSON header (8-byte little-endian length + JSON).
  # @param path [String]
  # @return [[Hash, Integer]] header without "__metadata__", and the byte offset of the data section
  # @raise [ArgumentError] if the file is too short to hold the declared header
  def read_header(path)
    File.open(path, "rb") do |f|
      prefix = f.read(8)
      raise ArgumentError, "truncated safetensors header: #{path}" unless prefix && prefix.bytesize == 8

      n = prefix.unpack1("Q<")
      json = f.read(n)
      raise ArgumentError, "truncated safetensors header: #{path}" unless json && json.bytesize == n

      h = JSON.parse(json)
      h.delete("__metadata__")
      [h, 8 + n]
    end
  end

  # Transpose a row-major [rows, cols] flat array → [cols, rows].
  # (A full-width column slice of #transpose_slice_cols.)
  def transpose(arr, rows, cols)
    transpose_slice_cols(arr, rows, cols, 0, cols)
  end

  # Slice columns [col_off, col_off+len) from a row-major [rows, total_cols] flat
  # array, then transpose → [len, rows]. (Used to split the fused QKV Conv1D weight.)
  # Plain while-loops are kept deliberately: this runs over millions of elements.
  def transpose_slice_cols(arr, rows, total_cols, col_off, len)
    out = Array.new(rows * len)
    r = 0
    while r < rows
      base = r * total_cols + col_off
      c = 0
      while c < len
        out[c * rows + r] = arr[base + c]
        c += 1
      end
      r += 1
    end
    out
  end
end
|
|
143
|
+
end
|
|
144
|
+
end
|
|
@@ -0,0 +1,224 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module AI
|
|
5
|
+
# TextGenerator — autoregressive inference with KV cache.
#
# Wraps a Transformer model + tokenizer and generates text token-by-token with
# temperature / top-k / top-p sampling. Tokens can be streamed to a block, and
# each one is published on Ignis::Shared::EventBus when that constant is defined.
class TextGenerator
  # @return [Transformer::Model]
  attr_reader :model

  # @return [Tokenizer]
  attr_reader :tokenizer

  # @param model [Transformer::Model] switched to eval mode (model.eval!) here
  # @param tokenizer [Tokenizer]
  def initialize(model, tokenizer)
    @model = model
    @tokenizer = tokenizer
    @model.eval!
  end

  # Generate a continuation of +prompt+.
  # @param prompt [String] input text
  # @param max_tokens [Integer] maximum new tokens to generate
  # @param temperature [Float] sampling temperature (0.0 = greedy)
  # @param top_k [Integer] top-k sampling (0 = disabled)
  # @param top_p [Float] nucleus sampling threshold (1.0 = disabled)
  # @param stop_tokens [Array<String>] strings that trigger early stop once they
  #   appear in the generated text
  # @param stream [Boolean] yield tokens as they are generated
  # @yield [String] each generated token (when stream: true)
  # @return [String] the prompt followed by the generated continuation, decoded
  #   together with special tokens skipped. Generation also stops at a special
  #   token (EOS) or when the KV cache is full.
  def generate(prompt, max_tokens: 128, temperature: 0.7,
               top_k: 50, top_p: 0.9, stop_tokens: [], stream: false, &block)
    input_ids = @tokenizer.encode(prompt)
    generated_ids = []
    cache = @model.make_kv_cache

    Tape.no_grad do
      # Prefill: stream the prompt through the cache one token at a time. Each
      # decode_step appends that token's K/V; the last token's logits predict
      # the first generated token. If the prompt exceeds the cache it is truncated.
      last_logits = nil
      input_ids.each do |tid|
        break if cache.full?

        last_logits = @model.decode_step(tid, cache).to_host
      end

      # Decode: sample from the running logits, feed the sampled token back to
      # extend the cache by one (O(prefix) per step, not O(prefix²) re-forward).
      max_tokens.times do |step|
        break if last_logits.nil? || cache.full?

        next_id = sample_token(last_logits, temperature: temperature,
                                            top_k: top_k, top_p: top_p)
        generated_ids << next_id

        token_text = @tokenizer.decode([next_id], skip_special_tokens: true)
        block.call(token_text) if stream && block

        if defined?(Ignis::Shared::EventBus)
          Ignis::Shared::EventBus.publish(:token_generated, {
            text: token_text, token_id: next_id, step: step
          })
        end

        # Re-decoding the whole continuation is only needed for stop strings.
        unless stop_tokens.empty?
          full_text = @tokenizer.decode(generated_ids, skip_special_tokens: true)
          break if stop_tokens.any? { |s| full_text.include?(s) }
        end
        break if @tokenizer.special_token_ids.include?(next_id) # EOS
        # Token budget spent: logits from another decode_step would never be
        # sampled, so skip that forward pass.
        break if step == max_tokens - 1

        # Advance: logits conditioned on prompt + everything generated so far.
        last_logits = @model.decode_step(next_id, cache).to_host
      end
    end

    @tokenizer.decode(input_ids + generated_ids, skip_special_tokens: true)
  end

  # Compute an embedding for +text+ from the model's token-embedding layer.
  #
  # NOTE(review): `tok_emb.mean` is called with no dimension argument; if
  # Tensor#mean reduces over all elements this returns a single scalar rather
  # than a mean-pooled per-dimension vector — confirm against Tensor#mean.
  # It also reads the model's @token_embedding ivar directly and always uses device 0.
  # @param text [String]
  # @return [Array<Float>] embedding vector (see note above)
  def embed(text)
    Tape.no_grad do
      input_ids = @tokenizer.encode(text)
      input_tensor = Tensor.from_host(input_ids, shape: [1, input_ids.length],
                                                 dtype: :int32, device_id: 0)
      # Get token embeddings, mean pool
      tok_emb = @model.instance_variable_get(:@token_embedding).call(input_tensor)
      tok_emb.mean.to_host
    end
  end

  private

  # Sample a token id from raw logits.
  # @param logits [Array<Float>]
  # @param temperature [Float] <= 0.0 selects the argmax (greedy)
  # @param top_k [Integer] keep only logits >= the k-th largest (0 or >= vocab = disabled)
  # @param top_p [Float] keep the smallest most-probable set whose mass exceeds top_p (>= 1.0 = disabled)
  # @return [Integer] sampled token id; never one whose probability was filtered to 0
  def sample_token(logits, temperature:, top_k:, top_p:)
    return logits.each_with_index.max_by { |v, _| v }[1] if temperature <= 0.0

    # Temperature scaling
    scaled = logits.map { |l| l / temperature }

    # Top-k: mask everything below the k-th largest logit (ties at the threshold survive).
    if top_k > 0 && top_k < scaled.length
      threshold = scaled.max(top_k).last
      scaled = scaled.map { |l| l >= threshold ? l : -Float::INFINITY }
    end

    # Numerically stable softmax; masked logits become probability 0.0.
    max_val = scaled.max
    exps = scaled.map { |l| Math.exp(l - max_val) }
    sum = exps.sum
    probs = exps.map { |e| e / sum }

    # Top-p (nucleus) filtering. Ties are broken by token id so the kept set is deterministic.
    if top_p < 1.0
      order = probs.each_index.sort_by { |i| [-probs[i], i] }
      cumsum = 0.0
      cutoff = order.length
      order.each_with_index do |idx, rank|
        cumsum += probs[idx]
        if cumsum > top_p
          cutoff = rank + 1
          break
        end
      end
      keep = Array.new(probs.length, false)
      order.first(cutoff).each { |i| keep[i] = true }
      probs = probs.each_with_index.map { |prob, i| keep[i] ? prob : 0.0 }
      total = probs.sum
      probs = probs.map { |prob| prob / total } if total > 0
    end

    # Multinomial sampling. Strict `<` means a zero-probability (filtered) token
    # can never be chosen, even when rand returns exactly 0.0.
    r = Kernel.rand
    cumulative = 0.0
    probs.each_with_index do |prob, i|
      cumulative += prob
      return i if r < cumulative
    end

    # Float rounding can leave the running total just below r; fall back to the
    # last token that still carries probability mass, never a filtered one.
    probs.rindex { |prob| prob > 0.0 } || probs.length - 1
  end
end
|
|
149
|
+
|
|
150
|
+
# BatchProcessor — serves generation requests from a background worker thread.
#
# Callers #submit prompts from any thread. The worker started by #start! drains
# the queue in batches of up to +max_batch_size+, waiting at most +max_wait_ms+
# for a batch to fill. Requests within a batch are currently generated one at a
# time (batched matmul is a future optimization).
class BatchProcessor
  # @param model [Transformer::Model]
  # @param tokenizer [Tokenizer]
  # @param max_batch_size [Integer] most requests collected per batch
  # @param max_wait_ms [Integer] max milliseconds to wait for batch fill
  def initialize(model, tokenizer, max_batch_size: 8, max_wait_ms: 50)
    @generator = TextGenerator.new(model, tokenizer)
    @max_batch_size = max_batch_size
    @max_wait_ms = max_wait_ms
    @queue = Queue.new
    @running = false
  end

  # Submit a request and block until the worker has produced its result.
  # NOTE: never returns unless #start! has been called (nothing drains the queue).
  # @param prompt [String]
  # @param params [Hash] keyword arguments forwarded to TextGenerator#generate
  # @return [String] generated text, or "Error: <message>" if generation raised
  def submit(prompt, **params)
    result_queue = Queue.new
    @queue << { prompt: prompt, params: params, result: result_queue }
    result_queue.pop
  end

  # Start the batch processing loop in a background thread.
  # @return [Thread]
  def start!
    @running = true
    @thread = Thread.new { batch_loop }
  end

  # Ask the worker to stop, then wait up to 5 seconds for it to exit.
  # @return [void]
  def stop!
    @running = false
    @thread&.join(5)
  end

  private

  # Worker loop: collect a batch, then generate each request serially and reply
  # on the request's own result queue. Errors are reported as "Error: ..." text.
  def batch_loop
    while @running
      collect_batch.each do |req|
        begin
          req[:result] << @generator.generate(req[:prompt], **req[:params])
        rescue => e
          req[:result] << "Error: #{e.message}"
        end
      end
    end
  end

  # Gather up to @max_batch_size queued requests, returning once @max_wait_ms has
  # elapsed. Sleeps 1 ms whenever the queue is empty so an idle worker does not
  # spin on the CPU.
  # @return [Array<Hash>] possibly empty
  def collect_batch
    batch = []
    deadline = Process.clock_gettime(Process::CLOCK_MONOTONIC) + @max_wait_ms / 1000.0
    while batch.size < @max_batch_size && Process.clock_gettime(Process::CLOCK_MONOTONIC) < deadline
      begin
        # Non-blocking pop raises ThreadError when the queue is empty.
        batch << @queue.pop(true)
      rescue ThreadError
        sleep(0.001)
      end
    end
    batch
  end
end
|
|
223
|
+
end
|
|
224
|
+
end
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module AI
|
|
5
|
+
# KVCache — per-layer key/value cache for incremental autoregressive decoding.
#
# Without a cache, generating token t re-runs attention over the entire prefix
# (O(t) work that grows every step → O(n²) generation). The cache stores each
# layer's projected K and V once; each new token only projects ITS key/value,
# appends them, and attends its single query over all cached keys.
#
# Buffers are preallocated to [max_seq_len, embed_dim] per layer so appending a
# row is an O(row) device→device copy (NvArray#write_rows!) rather than a
# realloc + full recopy. A contiguous [length+1, embed] view (slice along dim 0)
# exposes the live region — including the row just appended this step.
#
# Per token: #append for every layer, read #k_view / #v_view, then #advance!
# once. Appending or advancing beyond +max_seq_len+ raises IndexError instead
# of touching memory past the preallocated buffers.
class KVCache
  # @return [Integer] number of tokens currently cached (positions 0..length-1)
  attr_reader :length
  # @return [Integer]
  attr_reader :max_seq_len, :num_layers, :embed_dim

  # @param num_layers [Integer]
  # @param max_seq_len [Integer] capacity (generation cannot exceed this)
  # @param embed_dim [Integer]
  # @param device_id [Integer]
  def initialize(num_layers:, max_seq_len:, embed_dim:, device_id: 0)
    @num_layers = num_layers
    @max_seq_len = max_seq_len
    @embed_dim = embed_dim
    @device_id = device_id
    @length = 0
    @k = Array.new(num_layers) { make_buffer }
    @v = Array.new(num_layers) { make_buffer }
  end

  # Append a layer's K/V for the current token at row +length+.
  # @param layer [Integer] layer index
  # @param k_row [Ignis::Shared::NvArray] [1, embed] key projection of the new token
  # @param v_row [Ignis::Shared::NvArray] [1, embed] value projection of the new token
  # @return [void]
  # @raise [IndexError] if the cache already holds max_seq_len tokens
  def append(layer, k_row, v_row)
    if full?
      raise IndexError, "KVCache is full (max_seq_len=#{@max_seq_len}); cannot append at position #{@length}"
    end

    @k[layer].write_rows!(k_row, @length)
    @v[layer].write_rows!(v_row, @length)
  end

  # Contiguous [length+1, embed] view of cached keys for +layer+, INCLUDING the
  # row just appended this step. (Call after #append, before #advance!.)
  # @param layer [Integer]
  # @return [Ignis::Shared::NvArray] non-owning view
  def k_view(layer)
    @k[layer].slice(0, 0, @length + 1)
  end

  # Contiguous [length+1, embed] view of cached values for +layer+ (see #k_view).
  # @param layer [Integer]
  # @return [Ignis::Shared::NvArray] non-owning view of cached values
  def v_view(layer)
    @v[layer].slice(0, 0, @length + 1)
  end

  # Advance to the next position after every layer has appended this token.
  # @return [void]
  # @raise [IndexError] if the cache is already at capacity
  def advance!
    if full?
      raise IndexError, "KVCache is full (max_seq_len=#{@max_seq_len}); cannot advance past #{@length}"
    end

    @length += 1
  end

  # @return [Boolean] true if the cache is at capacity (no more tokens fit)
  def full?
    @length >= @max_seq_len
  end

  private

  def make_buffer
    Ignis::Shared::NvArray.new(shape: [@max_seq_len, @embed_dim],
                               dtype: :float32, device_id: @device_id).to_device
  end
end
|
|
78
|
+
end
|
|
79
|
+
end
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
|
|
5
|
+
module Ignis
|
|
6
|
+
module AI
|
|
7
|
+
# LlamaLoader — load HuggingFace Llama-family checkpoints (Llama-3.x, and other
# LlamaForCausalLM models: SmolLM, TinyLlama, …) into a Transformer::ModernModel.
#
# config.json sizes the model (RoPE base/scaling, GQA head counts, RMSNorm eps,
# head_dim). model.safetensors then supplies the weights: fp32 tensors are copied
# as-is, anything else goes through the on-device bf16→fp32 kernel. HF's
# nn.Linear weights are [out, in] and applied as x·Wᵀ — the SAME convention as
# Ignis::AI::NN::Linear — so weights map across with no transpose (unlike GPT-2's
# Conv1D). A checkpoint without lm_head.weight gets its LM head tied to the token
# embedding.
module LlamaLoader
  class << self
    # Size a ModernModel from <dir>/config.json, then fill it from the checkpoint.
    # Missing optional keys fall back to: num_key_value_heads → num_attention_heads,
    # rope_theta → 10000.0, rms_norm_eps → 1e-5.
    # @param dir [String] directory containing config.json + model.safetensors
    # @param device_id [Integer]
    # @return [Transformer::ModernModel]
    def from_pretrained(dir, device_id: 0)
      config = JSON.parse(File.read(File.join(dir, "config.json")))
      heads = config["num_attention_heads"]

      model = Transformer::ModernModel.new(
        vocab_size: config["vocab_size"],
        embed_dim: config["hidden_size"],
        num_heads: heads,
        num_kv_heads: config["num_key_value_heads"] || heads,
        num_layers: config["num_hidden_layers"],
        ff_dim: config["intermediate_size"],
        max_seq_len: config["max_position_embeddings"],
        rope_base: (config["rope_theta"] || 10000.0).to_f,
        rope_scaling: config["rope_scaling"],
        head_dim: config["head_dim"],
        eps: (config["rms_norm_eps"] || 1e-5).to_f,
        device_id: device_id
      )
      load(model, dir, device_id: device_id)
      model
    end

    # Copy every parameter of an existing ModernModel from dir/model.safetensors.
    # Raises RuntimeError when a parameter has no source tensor or the shapes differ.
    # @return [Integer] number of parameters loaded
    def load(model, dir, device_id: 0)
      tensors = Safetensors.load(File.join(dir, "model.safetensors"), device_id: device_id)

      embedding = tensors["model.embed_tokens.weight"]
      raise "LlamaLoader: missing model.embed_tokens.weight" unless embedding

      # Untied checkpoints ship an explicit lm_head; tied ones omit it (nil here).
      lm_head = tensors["lm_head.weight"]

      loaded = 0
      model.named_parameters.each do |name, param|
        src = name == "head.weight" ? (lm_head || embedding) : tensors[hf_name(name)]
        raise "LlamaLoader: no source weight for param #{name} (#{hf_name(name)})" unless src

        unless param.shape == src.shape
          raise "LlamaLoader: shape mismatch for #{name}: model #{param.shape} vs file #{src.shape}"
        end

        dequant_into!(src.data, param.data)
        loaded += 1
      end

      # Wait for the queued device copies/kernels before reporting success.
      Ignis.synchronize
      loaded
    end

    private

    # Map an Ignis ModernModel parameter name to its HuggingFace Llama tensor name.
    # ("head.weight" is handled in #load, because it may be tied or an explicit lm_head.)
    # Raises RuntimeError for any name it does not recognize.
    def hf_name(name)
      case name
      when "token_embedding.weight"
        "model.embed_tokens.weight"
      when "norm.weight"
        "model.norm.weight"
      when /\Ablocks\.(\d+)\.norm1\.weight\z/
        "model.layers.#{Regexp.last_match(1)}.input_layernorm.weight"
      when /\Ablocks\.(\d+)\.norm2\.weight\z/
        "model.layers.#{Regexp.last_match(1)}.post_attention_layernorm.weight"
      when /\Ablocks\.(\d+)\.attn\.([qkvo])_proj\.weight\z/
        "model.layers.#{Regexp.last_match(1)}.self_attn.#{Regexp.last_match(2)}_proj.weight"
      when /\Ablocks\.(\d+)\.mlp\.(gate|up|down)\.weight\z/
        "model.layers.#{Regexp.last_match(1)}.mlp.#{Regexp.last_match(2)}_proj.weight"
      else
        raise "LlamaLoader: unmapped parameter #{name}"
      end
    end

    # Fill an fp32 destination NvArray from +src_nv+ on-device. fp32 sources are
    # copied directly; every other dtype is run through the bf16→fp32 kernel.
    # NOTE(review): a non-bf16, non-fp32 source (e.g. fp16) would also take the
    # bf16 path — confirm Safetensors.load only yields :float32 or bf16 here.
    def dequant_into!(src_nv, dst_nv)
      unless src_nv.numel == dst_nv.numel
        raise "LlamaLoader: numel mismatch #{src_nv.numel} vs #{dst_nv.numel}"
      end

      if src_nv.dtype == :float32
        dst_nv.write_rows!(src_nv, 0)
        return
      end

      total = dst_nv.numel
      threads = 256
      kernel = Ignis::JIT::Kernels::Elementwise.bf16_to_f32
      kernel.launch(grid: [(total + threads - 1) / threads], block: [threads], args: [src_nv, dst_nv, total])
    end
  end
end
|
|
99
|
+
end
|
|
100
|
+
end
|