ignis-dl 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 2e2b7a2b0882360bf9c07ca60b72aa71c3a43dbb9739aa9dde31182e8e9068ca
4
+ data.tar.gz: a34343e8f5e2334a79b8a5b6dc935891d8776ffe5b64f47832dccb2f87ce9fcc
5
+ SHA512:
6
+ metadata.gz: 525582d82f78c61c0fedf0a31273f6010453c5f3e1753ff2908b6ba402f03de909aa6a7eb1dc12c0feb773e7a65c519f6be11a4dc44db6307f5aed4fe1baa65d
7
+ data.tar.gz: f7c0f62d4cdcb6ddcf8daf3af78a1d8580663f91dab1ef6ac8219c9b2ce445e6858368b9782ce779b0710ca403579c58505ff2079297599e6d454b3cad5c2d0e
data/README.md ADDED
@@ -0,0 +1,15 @@
1
+ # ignis-dl
2
+
3
+ Neural-net layers + transformers for Ruby, on the [Ignis](https://github.com/tigel-agm/Ignis) GPU/autograd stack. Installing this pulls the whole chain (`ignis` + `ignis-autograd`).
4
+
5
+ - **NN:** Linear, Embedding, LayerNorm, RMSNorm, Dropout; SGD/Adam/AdamW; losses.
6
+ - **Transformers:** multi-head + grouped-query attention, RoPE, SwiGLU, KV cache.
7
+ - **Loaders:** real GPT-2 and Llama (HuggingFace `safetensors`) — matches HF logits.
8
+
9
+ ```ruby
10
+ require "ignis-dl"
11
+ model = Ignis::AI::LlamaLoader.from_pretrained("path/to/Llama-3.2-1B")
12
+ # ... matches HuggingFace; "The capital of France is" => " Paris"
13
+ ```
14
+
15
+ Loads real GPT-2 and Llama-3.2 checkpoints (logits match HuggingFace with cosine similarity 1.0) and trains transformers from scratch. Licensed under MIT.
data/lib/ignis-dl.rb ADDED
@@ -0,0 +1,48 @@
1
+ # frozen_string_literal: true
2
+
3
+ # ignis-dl — neural-net layers + transformers for Ruby, on the Ignis GPU/autograd
4
+ # stack. Loads real GPT-2 and Llama checkpoints (matching HuggingFace) and trains
5
+ # transformers from scratch. Requiring this pulls the whole chain (ignis +
6
+ # ignis-autograd), so it doubles as the meta-gem.
7
+
8
+ require "ignis-autograd"
9
+
10
+ # --- NN modules ---
11
+ require_relative "nnw/ai/nn/module"
12
+ require_relative "nnw/ai/nn/linear"
13
+ require_relative "nnw/ai/nn/embedding"
14
+ require_relative "nnw/ai/nn/layer_norm"
15
+ require_relative "nnw/ai/nn/rms_norm"
16
+ require_relative "nnw/ai/nn/dropout"
17
+ require_relative "nnw/ai/nn/sequential"
18
+
19
+ # --- Transformer ---
20
+ require_relative "nnw/ai/transformer/attention"
21
+ require_relative "nnw/ai/transformer/feed_forward"
22
+ require_relative "nnw/ai/transformer/swiglu"
23
+ require_relative "nnw/ai/transformer/block"
24
+ require_relative "nnw/ai/transformer/model"
25
+ require_relative "nnw/ai/transformer/modern"
26
+ require_relative "nnw/ai/kv_cache"
27
+
28
+ # --- Loss ---
29
+ require_relative "nnw/ai/loss"
30
+
31
+ # --- Optimizers ---
32
+ require_relative "nnw/ai/optim/base"
33
+ require_relative "nnw/ai/optim/sgd"
34
+ require_relative "nnw/ai/optim/adam"
35
+ require_relative "nnw/ai/optim/adamw"
36
+ require_relative "nnw/ai/optim/lr_scheduler"
37
+
38
+ # --- Weight loading + tokenizer ---
39
+ require_relative "nnw/ai/safetensors"
40
+ require_relative "nnw/ai/weight_map"
41
+ require_relative "nnw/ai/gpt2_loader"
42
+ require_relative "nnw/ai/llama_loader"
43
+ require_relative "nnw/ai/tokenizer"
44
+
45
+ # --- Inference + training ---
46
+ require_relative "nnw/ai/inference"
47
+ require_relative "nnw/ai/trainer"
48
+ require_relative "nnw/ai/server"
@@ -0,0 +1,144 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+
5
+ module Ignis
6
+ module AI
7
# Loads HuggingFace GPT-2 safetensors checkpoints into an Ignis Transformer::Model.
#
# GPT-2 is NOT a drop-in for nn.Linear: its attention/MLP projections are
# `Conv1D` layers whose weights are stored as [in, out] (transposed vs.
# nn.Linear's [out, in]), the QKV projection is a single fused [embed, 3*embed]
# tensor, and the LM head is tied to the token embedding. This loader applies
# those transforms while copying weights in.
#
# Tensors are looked up by their bare GPT2Model names ("wte.weight",
# "h.0.ln_1.weight", ...) and, failing that, with the "transformer." prefix
# used by GPT2LMHeadModel checkpoints. Only F32 tensors are supported.
#
# @example
#   model = Ignis::AI::Transformer::Model.gpt2_small
#   Ignis::AI::GPT2Loader.load(model, "gpt2")
#   model.eval!
module GPT2Loader
  module_function

  # Load weights from <dir>/model.safetensors into +model+ (in place).
  # @param model [Ignis::AI::Transformer::Model] must respond to
  #   #named_parameters (name => param), #embed_dim and #num_layers
  # @param dir [String] directory holding model.safetensors
  # @return [model]
  # @raise [ArgumentError] if model.safetensors does not exist
  # @raise [RuntimeError] on a missing tensor or parameter, a non-F32 tensor,
  #   truncated tensor data, or a size mismatch
  def load(model, dir)
    path = File.join(dir, "model.safetensors")
    raise ArgumentError, "not found: #{path}" unless File.exist?(path)

    header, data_offset = read_header(path)
    params = model.named_parameters
    embed = model.embed_dim
    n_layer = model.num_layers

    File.open(path, "rb") do |f|
      # Fetch one tensor as { values: Array<Float>, shape: Array<Integer> }.
      get = lambda do |name|
        meta = header[name] || header["transformer.#{name}"]
        raise "missing safetensors tensor: #{name}" unless meta

        # unpack("e*") decodes little-endian float32 only; any other dtype
        # (F16, BF16, F64, ...) would otherwise be silently misread.
        dtype = meta["dtype"]
        raise "unsupported dtype #{dtype} for #{name}: GPT2Loader reads F32 only" unless dtype == "F32"

        b, e = meta["data_offsets"]
        f.seek(data_offset + b)
        raw = f.read(e - b)
        raise "truncated safetensors data for #{name}" unless raw && raw.bytesize == e - b

        { values: raw.unpack("e*"), shape: meta["shape"] }
      end

      # Copy a flat value array into the named model parameter after a size check.
      set = lambda do |pname, values|
        param = params[pname]
        raise "missing model param: #{pname}" unless param
        unless param.numel == values.length
          raise "size mismatch for #{pname}: param has #{param.numel}, data has #{values.length}"
        end
        param.data.from_host(values)
      end

      # wte also serves as the tied LM head, so read it from disk only once.
      wte = get.call("wte.weight")[:values]

      # Embeddings (direct)
      set.call("token_embedding.weight", wte)
      set.call("position_embedding.weight", get.call("wpe.weight")[:values])

      n_layer.times do |i|
        set.call("blocks.#{i}.norm1.weight", get.call("h.#{i}.ln_1.weight")[:values])
        set.call("blocks.#{i}.norm1.bias", get.call("h.#{i}.ln_1.bias")[:values])
        set.call("blocks.#{i}.norm2.weight", get.call("h.#{i}.ln_2.weight")[:values])
        set.call("blocks.#{i}.norm2.bias", get.call("h.#{i}.ln_2.bias")[:values])

        # Fused QKV: c_attn.weight is Conv1D [embed, 3*embed]. Split columns into
        # q|k|v and transpose each into nn.Linear layout [out, in] = [embed, embed].
        cattn_w = get.call("h.#{i}.attn.c_attn.weight")[:values]
        cattn_b = get.call("h.#{i}.attn.c_attn.bias")[:values]
        %w[q k v].each_with_index do |proj, j|
          set.call("blocks.#{i}.attention.#{proj}_proj.weight",
                   transpose_slice_cols(cattn_w, embed, 3 * embed, j * embed, embed))
        end
        %w[q k v].each_with_index do |proj, j|
          set.call("blocks.#{i}.attention.#{proj}_proj.bias", cattn_b[(j * embed)...((j + 1) * embed)])
        end

        # Attention output projection: Conv1D [embed, embed] -> transpose.
        aproj = get.call("h.#{i}.attn.c_proj.weight")[:values]
        set.call("blocks.#{i}.attention.out_proj.weight", transpose(aproj, embed, embed))
        set.call("blocks.#{i}.attention.out_proj.bias", get.call("h.#{i}.attn.c_proj.bias")[:values])

        # MLP fc1: c_fc Conv1D [embed, ff] -> transpose -> [ff, embed].
        fc = get.call("h.#{i}.mlp.c_fc.weight")
        ff = fc[:shape][1]
        set.call("blocks.#{i}.feed_forward.fc1.weight", transpose(fc[:values], embed, ff))
        set.call("blocks.#{i}.feed_forward.fc1.bias", get.call("h.#{i}.mlp.c_fc.bias")[:values])

        # MLP fc2: c_proj Conv1D [ff, embed] -> transpose -> [embed, ff].
        fp = get.call("h.#{i}.mlp.c_proj.weight")[:values]
        set.call("blocks.#{i}.feed_forward.fc2.weight", transpose(fp, ff, embed))
        set.call("blocks.#{i}.feed_forward.fc2.bias", get.call("h.#{i}.mlp.c_proj.bias")[:values])
      end

      set.call("norm.weight", get.call("ln_f.weight")[:values])
      set.call("norm.bias", get.call("ln_f.bias")[:values])

      # LM head is tied to the token embedding (logits = x @ wte^T).
      set.call("head.weight", wte)
    end

    model
  end

  # Read the safetensors JSON header (8-byte little-endian length prefix, then
  # that many bytes of JSON). The "__metadata__" entry is dropped.
  # @param path [String]
  # @return [[Hash, Integer]] header and the byte offset where tensor data begins
  # @raise [RuntimeError] if the file is shorter than its declared header
  def read_header(path)
    File.open(path, "rb") do |f|
      len_bytes = f.read(8)
      raise "truncated safetensors header: #{path}" unless len_bytes && len_bytes.bytesize == 8

      n = len_bytes.unpack1("Q<")
      json = f.read(n)
      raise "truncated safetensors header: #{path}" unless json && json.bytesize == n

      header = JSON.parse(json)
      header.delete("__metadata__")
      [header, 8 + n]
    end
  end

  # Transpose a row-major [rows, cols] flat array → [cols, rows].
  def transpose(arr, rows, cols)
    transpose_slice_cols(arr, rows, cols, 0, cols)
  end

  # Slice columns [col_off, col_off+len) from a row-major [rows, total_cols]
  # flat array, then transpose → [len, rows]. (Used to split the fused QKV
  # Conv1D weight.) Plain while-loops keep this hot path cheap on large tensors.
  def transpose_slice_cols(arr, rows, total_cols, col_off, len)
    out = Array.new(rows * len)
    r = 0
    while r < rows
      base = r * total_cols + col_off
      c = 0
      while c < len
        out[c * rows + r] = arr[base + c]
        c += 1
      end
      r += 1
    end
    out
  end
end
143
+ end
144
+ end
@@ -0,0 +1,224 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module AI
5
# TextGenerator — autoregressive inference with KV cache.
#
# Wraps a Transformer model + tokenizer and generates text token-by-token with
# temperature / top-k / top-p sampling. Each generated token can be streamed to
# a block and, when Ignis::Shared::EventBus is loaded, is published as a
# :token_generated event.
class TextGenerator
  # @return [Transformer::Model]
  attr_reader :model

  # @return [Tokenizer]
  attr_reader :tokenizer

  # Puts +model+ into eval mode.
  # @param model [Transformer::Model]
  # @param tokenizer [Tokenizer]
  def initialize(model, tokenizer)
    @model = model
    @tokenizer = tokenizer
    @model.eval!
  end

  # Generate text from a prompt.
  # @param prompt [String] input text
  # @param max_tokens [Integer] maximum new tokens to generate
  # @param temperature [Float] sampling temperature (<= 0.0 = greedy)
  # @param top_k [Integer] top-k sampling (0 = disabled)
  # @param top_p [Float] nucleus sampling threshold (1.0 = disabled)
  # @param stop_tokens [Array<String>] strings that trigger early stop
  # @param stream [Boolean] yield tokens as they are generated
  # @yield [String] each generated token (when stream: true)
  # @return [String] the decoded prompt followed by the generated continuation
  def generate(prompt, max_tokens: 128, temperature: 0.7,
               top_k: 50, top_p: 0.9, stop_tokens: [], stream: false, &block)
    input_ids = @tokenizer.encode(prompt)
    generated_ids = []
    cache = @model.make_kv_cache

    Tape.no_grad do
      # Prefill: stream the prompt through the cache one token at a time. Each
      # decode_step appends that token's K/V; the last token's logits predict
      # the first generated token. A prompt longer than the cache is truncated.
      last_logits = nil
      input_ids.each do |tid|
        break if cache.full?

        last_logits = @model.decode_step(tid, cache).to_host
      end

      # Decode: sample from the running logits, then feed the sampled token back
      # to extend the cache by one row (no full re-forward of the prefix).
      max_tokens.times do |step|
        break if last_logits.nil? || cache.full?

        next_id = sample_token(last_logits, temperature: temperature,
                                            top_k: top_k, top_p: top_p)
        generated_ids << next_id

        token_text = @tokenizer.decode([next_id], skip_special_tokens: true)
        block.call(token_text) if stream && block

        if defined?(Ignis::Shared::EventBus)
          Ignis::Shared::EventBus.publish(:token_generated, {
            text: token_text, token_id: next_id, step: step
          })
        end

        # Re-decoding the whole continuation is only needed to look for stop strings.
        unless stop_tokens.empty?
          full_text = @tokenizer.decode(generated_ids, skip_special_tokens: true)
          break if stop_tokens.any? { |s| full_text.include?(s) }
        end
        break if @tokenizer.special_token_ids.include?(next_id) # EOS

        # Last requested token: its logits would never be sampled, so skip the
        # extra forward pass.
        break if step == max_tokens - 1

        # Advance: logits conditioned on prompt + everything generated so far.
        last_logits = @model.decode_step(next_id, cache).to_host
      end
    end

    @tokenizer.decode(input_ids + generated_ids, skip_special_tokens: true)
  end

  # Compute embeddings for text (useful for similarity/search).
  # Uses only the model's token-embedding layer (read from its @token_embedding
  # ivar); positions and transformer blocks are not applied.
  # @param text [String]
  # @return [Array<Float>] pooled embedding
  def embed(text)
    Tape.no_grad do
      input_ids = @tokenizer.encode(text)
      input_tensor = Tensor.from_host(input_ids, shape: [1, input_ids.length],
                                                 dtype: :int32, device_id: 0)
      # NOTE(review): `mean` is called without a dimension and may reduce over
      # every element (a single value) rather than pooling per embedding
      # dimension — confirm against Tensor#mean.
      tok_emb = @model.instance_variable_get(:@token_embedding).call(input_tensor)
      tok_emb.mean.to_host
    end
  end

  private

  # Sample a token id from raw logits.
  # @param logits [Array<Float>]
  # @param temperature [Float] <= 0.0 selects the argmax deterministically
  # @param top_k [Integer] keep only logits >= the k-th largest (0 = off)
  # @param top_p [Float] nucleus threshold (>= 1.0 = off)
  # @return [Integer] sampled token id
  def sample_token(logits, temperature:, top_k:, top_p:)
    return logits.each_with_index.max_by { |v, _| v }[1] if temperature <= 0.0

    scaled = logits.map { |l| l / temperature }

    # Top-k filtering: ties with the k-th largest logit are kept.
    if top_k > 0 && top_k < scaled.length
      threshold = scaled.max(top_k).last
      scaled = scaled.map { |l| l >= threshold ? l : -Float::INFINITY }
    end

    # Numerically stable softmax.
    max_val = scaled.max
    exps = scaled.map { |l| Math.exp(l - max_val) }
    total = exps.sum
    probs = exps.map { |e| e / total }

    probs = nucleus_filter(probs, top_p) if top_p < 1.0

    # Multinomial sampling by walking the cumulative distribution.
    r = Kernel.rand
    cumulative = 0.0
    probs.each_with_index do |p, i|
      cumulative += p
      return i if r <= cumulative
    end

    probs.length - 1 # Fallback for floating-point rounding in the cumulative sum
  end

  # Zero out everything outside the smallest high-probability set whose mass
  # exceeds +top_p+ (the token that crosses the threshold is kept), then
  # renormalize. Uses a boolean mask, so no Set dependency is needed.
  # @param probs [Array<Float>]
  # @param top_p [Float]
  # @return [Array<Float>]
  def nucleus_filter(probs, top_p)
    keep = Array.new(probs.length, false)
    cumsum = 0.0
    probs.each_index.sort_by { |i| -probs[i] }.each do |idx|
      keep[idx] = true
      cumsum += probs[idx]
      break if cumsum > top_p
    end

    filtered = probs.each_with_index.map { |p, i| keep[i] ? p : 0.0 }
    total = filtered.sum
    total > 0 ? filtered.map { |p| p / total } : filtered
  end
end
149
+
150
# BatchProcessor — queues generation requests and serves them from a single
# background worker thread.
#
# The worker gathers up to +max_batch_size+ requests, waiting at most
# +max_wait_ms+ for a batch to fill. Requests within a batch are currently run
# serially through one shared TextGenerator (batched matmul is future work).
class BatchProcessor
  # @param model [Transformer::Model]
  # @param tokenizer [Tokenizer]
  # @param max_batch_size [Integer]
  # @param max_wait_ms [Integer] max milliseconds to wait for batch fill
  def initialize(model, tokenizer, max_batch_size: 8, max_wait_ms: 50)
    @generator = TextGenerator.new(model, tokenizer)
    @max_batch_size = max_batch_size
    @max_wait_ms = max_wait_ms
    @queue = Queue.new
    @running = false
  end

  # Submit a request and block until the worker has processed it.
  # Blocks indefinitely if the worker has not been started with #start!.
  # @param prompt [String]
  # @param params [Hash] keyword arguments forwarded to TextGenerator#generate
  # @return [String] generated text, or "Error: <message>" if generation raised
  def submit(prompt, **params)
    result_queue = Queue.new
    @queue << { prompt: prompt, params: params, result: result_queue }
    result_queue.pop
  end

  # Start the batch processing loop in a background thread.
  # @return [Thread]
  def start!
    @running = true
    @thread = Thread.new { batch_loop }
  end

  # Stop the processor: the loop exits after its current batch, and this waits
  # up to 5 seconds for the thread. Requests still queued are not processed.
  # @return [Thread, nil] result of Thread#join
  def stop!
    @running = false
    @thread&.join(5)
  end

  private

  # Worker loop: gather a batch, then run each request in it.
  def batch_loop
    while @running
      collect_batch.each { |req| process_request(req) }
    end
  end

  # Pop up to @max_batch_size queued requests, giving up at the deadline.
  # @return [Array<Hash>] possibly empty
  def collect_batch
    batch = []
    deadline = monotonic_now + @max_wait_ms / 1000.0
    while batch.size < @max_batch_size && monotonic_now < deadline
      begin
        batch << @queue.pop(true)
      rescue ThreadError
        # Queue empty: back off briefly instead of spinning on the CPU.
        sleep(0.001)
      end
    end
    batch
  end

  # Run one request and deliver its result (or an error string) to the caller.
  def process_request(req)
    req[:result] << @generator.generate(req[:prompt], **req[:params])
  rescue => e
    req[:result] << "Error: #{e.message}"
  end

  # @return [Float] monotonic clock seconds (immune to wall-clock jumps)
  def monotonic_now
    Process.clock_gettime(Process::CLOCK_MONOTONIC)
  end
end
223
+ end
224
+ end
@@ -0,0 +1,79 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module AI
5
# KVCache — per-layer key/value cache for incremental autoregressive decoding.
#
# Without a cache, generating token t re-runs attention over the entire prefix
# (O(t) work that grows every step → O(n²) generation). The cache stores each
# layer's projected K and V once; each new token only projects ITS key/value,
# appends them, and attends its single query over all cached keys.
#
# Buffers are preallocated to [max_seq_len, embed_dim] per layer so appending a
# row is an O(row) device→device copy (NvArray#write_rows!) rather than a
# realloc + full recopy. A contiguous [length+1, embed] view (slice along dim 0)
# exposes the live region — including the row just appended this step.
#
# Per-token protocol: for each layer call #append and then read #k_view /
# #v_view; once every layer is done, call #advance! once. Once the cache is
# full, these calls raise IndexError instead of touching memory past the end
# of the buffers.
class KVCache
  # @return [Integer] number of tokens currently cached (positions 0..length-1)
  attr_reader :length
  # @return [Integer]
  attr_reader :max_seq_len, :num_layers, :embed_dim

  # @param num_layers [Integer]
  # @param max_seq_len [Integer] capacity (generation cannot exceed this)
  # @param embed_dim [Integer]
  # @param device_id [Integer]
  def initialize(num_layers:, max_seq_len:, embed_dim:, device_id: 0)
    @num_layers = num_layers
    @max_seq_len = max_seq_len
    @embed_dim = embed_dim
    @device_id = device_id
    @length = 0
    @k = Array.new(num_layers) { make_buffer }
    @v = Array.new(num_layers) { make_buffer }
  end

  # Append a layer's K/V for the current token at row +length+.
  # @param layer [Integer] layer index
  # @param k_row [Ignis::Shared::NvArray] [1, embed] key projection of the new token
  # @param v_row [Ignis::Shared::NvArray] [1, embed] value projection of the new token
  # @return [void]
  # @raise [IndexError] if the cache is full or +layer+ is out of range
  def append(layer, k_row, v_row)
    ensure_capacity!
    @k.fetch(layer).write_rows!(k_row, @length)
    @v.fetch(layer).write_rows!(v_row, @length)
  end

  # Contiguous [length+1, embed] view of cached keys for +layer+, INCLUDING the
  # row just appended this step. (Call after #append, before #advance!.)
  # @param layer [Integer]
  # @return [Ignis::Shared::NvArray] non-owning view
  # @raise [IndexError] if the cache is full or +layer+ is out of range
  def k_view(layer)
    ensure_capacity!
    @k.fetch(layer).slice(0, 0, @length + 1)
  end

  # Contiguous [length+1, embed] view of cached values for +layer+ (same rules
  # as #k_view).
  # @param layer [Integer]
  # @return [Ignis::Shared::NvArray] non-owning view of cached values
  # @raise [IndexError] if the cache is full or +layer+ is out of range
  def v_view(layer)
    ensure_capacity!
    @v.fetch(layer).slice(0, 0, @length + 1)
  end

  # Advance to the next position after every layer has appended this token.
  # @return [Integer] the new length
  # @raise [IndexError] if the cache is already full
  def advance!
    ensure_capacity!
    @length += 1
  end

  # @return [Boolean] true if the cache is at capacity (no more tokens fit)
  def full?
    @length >= @max_seq_len
  end

  private

  # Guard every row-indexed operation: row @length must exist in the buffers.
  def ensure_capacity!
    return unless full?

    raise IndexError, "KVCache is full (max_seq_len=#{@max_seq_len})"
  end

  # One zero-initialized-or-not [max_seq_len, embed_dim] fp32 buffer on the
  # configured device (initial contents depend on NvArray — never read past
  # row +length+).
  def make_buffer
    Ignis::Shared::NvArray.new(shape: [@max_seq_len, @embed_dim],
                               dtype: :float32, device_id: @device_id).to_device
  end
end
78
+ end
79
+ end
@@ -0,0 +1,100 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+
5
+ module Ignis
6
+ module AI
7
# LlamaLoader — loads HuggingFace Llama-family checkpoints (Llama-3.x and other
# LlamaForCausalLM models such as SmolLM and TinyLlama) into a
# Transformer::ModernModel.
#
# .from_pretrained sizes the model from config.json: vocab size, hidden size,
# attention and key/value head counts (GQA), layer count, FFN width, context
# length, RoPE theta and scaling, head_dim and RMSNorm eps. It then calls .load,
# which copies every parameter out of model.safetensors on-device. Tensors that
# are not fp32 go through a bf16→fp32 conversion kernel.
#
# Whether the LM head is tied is decided by the checkpoint itself, not by
# config.json: an explicit "lm_head.weight" is used when present, otherwise the
# token embedding is reused.
#
# HF nn.Linear weights are [out, in] and applied as x·Wᵀ — the same convention
# as Ignis::AI::NN::Linear — so no transpose is needed (unlike GPT-2's Conv1D).
module LlamaLoader
  class << self
    # Build a ModernModel from <dir>/config.json and load its weights into it.
    # @param dir [String] directory containing config.json + model.safetensors
    # @param device_id [Integer]
    # @return [Transformer::ModernModel]
    def from_pretrained(dir, device_id: 0)
      config = JSON.parse(File.read(File.join(dir, "config.json")))
      heads = config["num_attention_heads"]

      options = {
        vocab_size: config["vocab_size"],
        embed_dim: config["hidden_size"],
        num_heads: heads,
        # Configs without GQA omit num_key_value_heads (it equals num_heads).
        num_kv_heads: config["num_key_value_heads"] || heads,
        num_layers: config["num_hidden_layers"],
        ff_dim: config["intermediate_size"],
        max_seq_len: config["max_position_embeddings"],
        rope_base: (config["rope_theta"] || 10_000.0).to_f,
        rope_scaling: config["rope_scaling"],
        head_dim: config["head_dim"],
        eps: (config["rms_norm_eps"] || 1e-5).to_f,
        device_id: device_id
      }

      Transformer::ModernModel.new(**options).tap do |model|
        load(model, dir, device_id: device_id)
      end
    end

    # Copy weights from <dir>/model.safetensors into an existing ModernModel.
    # Every model parameter must have a same-shaped source tensor.
    # @param model [Transformer::ModernModel]
    # @param dir [String]
    # @param device_id [Integer]
    # @return [Integer] number of parameters loaded
    # @raise [RuntimeError] on a missing embedding/source tensor or a shape mismatch
    def load(model, dir, device_id: 0)
      tensors = Safetensors.load(File.join(dir, "model.safetensors"), device_id: device_id)

      embedding = tensors["model.embed_tokens.weight"]
      raise "LlamaLoader: missing model.embed_tokens.weight" unless embedding

      # Absent when the checkpoint ties the LM head to the token embedding.
      untied_head = tensors["lm_head.weight"]

      loaded = 0
      model.named_parameters.each do |name, param|
        source = name == "head.weight" ? (untied_head || embedding) : tensors[hf_name(name)]
        raise "LlamaLoader: no source weight for param #{name} (#{hf_name(name)})" unless source

        if param.shape != source.shape
          raise "LlamaLoader: shape mismatch for #{name}: model #{param.shape} vs file #{source.shape}"
        end

        dequant_into!(source.data, param.data)
        loaded += 1
      end

      Ignis.synchronize
      loaded
    end

    private

    # Map an Ignis ModernModel parameter name to its HuggingFace Llama tensor name.
    # "head.weight" is handled in .load because it may be tied or an explicit lm_head.
    # @raise [RuntimeError] for names with no known mapping
    def hf_name(name)
      case name
      when "token_embedding.weight" then "model.embed_tokens.weight"
      when "norm.weight" then "model.norm.weight"
      when /\Ablocks\.(\d+)\.norm1\.weight\z/
        "model.layers.#{Regexp.last_match(1)}.input_layernorm.weight"
      when /\Ablocks\.(\d+)\.norm2\.weight\z/
        "model.layers.#{Regexp.last_match(1)}.post_attention_layernorm.weight"
      when /\Ablocks\.(\d+)\.attn\.([qkvo])_proj\.weight\z/
        "model.layers.#{Regexp.last_match(1)}.self_attn.#{Regexp.last_match(2)}_proj.weight"
      when /\Ablocks\.(\d+)\.mlp\.(gate|up|down)\.weight\z/
        "model.layers.#{Regexp.last_match(1)}.mlp.#{Regexp.last_match(2)}_proj.weight"
      else
        raise "LlamaLoader: unmapped parameter #{name}"
      end
    end

    # Copy +src_nv+ into the fp32 +dst_nv+ on-device: fp32 sources are copied
    # directly; any other dtype is converted with the bf16→fp32 kernel.
    # NOTE(review): fp16 or other non-bf16 sources would also take the bf16
    # path — confirm Safetensors only yields fp32/bf16 here.
    # @raise [RuntimeError] if element counts differ
    def dequant_into!(src_nv, dst_nv)
      if src_nv.numel != dst_nv.numel
        raise "LlamaLoader: numel mismatch #{src_nv.numel} vs #{dst_nv.numel}"
      end

      if src_nv.dtype == :float32
        dst_nv.write_rows!(src_nv, 0)
        nil
      else
        threads = 256
        total = dst_nv.numel
        Ignis::JIT::Kernels::Elementwise.bf16_to_f32.launch(
          grid: [(total + threads - 1) / threads], block: [threads], args: [src_nv, dst_nv, total]
        )
      end
    end
  end
end
99
+ end
100
+ end