ignis-dl 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +15 -0
- data/lib/ignis-dl.rb +48 -0
- data/lib/nnw/ai/gpt2_loader.rb +144 -0
- data/lib/nnw/ai/inference.rb +224 -0
- data/lib/nnw/ai/kv_cache.rb +79 -0
- data/lib/nnw/ai/llama_loader.rb +100 -0
- data/lib/nnw/ai/loss.rb +170 -0
- data/lib/nnw/ai/nn/dropout.rb +68 -0
- data/lib/nnw/ai/nn/embedding.rb +86 -0
- data/lib/nnw/ai/nn/layer_norm.rb +54 -0
- data/lib/nnw/ai/nn/linear.rb +80 -0
- data/lib/nnw/ai/nn/module.rb +178 -0
- data/lib/nnw/ai/nn/rms_norm.rb +43 -0
- data/lib/nnw/ai/nn/sequential.rb +52 -0
- data/lib/nnw/ai/optim/adam.rb +63 -0
- data/lib/nnw/ai/optim/adamw.rb +63 -0
- data/lib/nnw/ai/optim/base.rb +90 -0
- data/lib/nnw/ai/optim/lr_scheduler.rb +118 -0
- data/lib/nnw/ai/optim/sgd.rb +49 -0
- data/lib/nnw/ai/safetensors.rb +220 -0
- data/lib/nnw/ai/server.rb +268 -0
- data/lib/nnw/ai/tokenizer.rb +413 -0
- data/lib/nnw/ai/trainer.rb +245 -0
- data/lib/nnw/ai/transformer/attention.rb +89 -0
- data/lib/nnw/ai/transformer/block.rb +90 -0
- data/lib/nnw/ai/transformer/feed_forward.rb +53 -0
- data/lib/nnw/ai/transformer/model.rb +189 -0
- data/lib/nnw/ai/transformer/modern.rb +191 -0
- data/lib/nnw/ai/transformer/swiglu.rb +39 -0
- data/lib/nnw/ai/weight_map.rb +139 -0
- metadata +91 -0
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module AI
|
|
5
|
+
# Trainer — complete training loop with gradient accumulation,
|
|
6
|
+
# checkpointing, and multi-GPU support via NvCCL.
|
|
7
|
+
class Trainer
|
|
8
|
+
# @return [Transformer::Model]
|
|
9
|
+
attr_reader :model
|
|
10
|
+
|
|
11
|
+
# @return [Optim::Base]
|
|
12
|
+
attr_reader :optimizer
|
|
13
|
+
|
|
14
|
+
# @return [Hash] training metrics
|
|
15
|
+
attr_reader :metrics
|
|
16
|
+
|
|
17
|
+
# @param model [Transformer::Model]
|
|
18
|
+
# @param optimizer [Optim::Base]
|
|
19
|
+
# @param scheduler [Optim::LRScheduler::*, nil]
|
|
20
|
+
# @param grad_accumulation_steps [Integer] accumulate gradients over N steps
|
|
21
|
+
# @param max_grad_norm [Float] gradient clipping norm
|
|
22
|
+
# @param use_nvccl [Boolean] enable multi-GPU gradient sync
|
|
23
|
+
# @param checkpoint_dir [String, nil] directory for saving checkpoints
|
|
24
|
+
def initialize(model:, optimizer:, scheduler: nil,
|
|
25
|
+
grad_accumulation_steps: 1, max_grad_norm: 1.0,
|
|
26
|
+
use_nvccl: false, checkpoint_dir: nil)
|
|
27
|
+
@model = model
|
|
28
|
+
@optimizer = optimizer
|
|
29
|
+
@scheduler = scheduler
|
|
30
|
+
@grad_accumulation_steps = grad_accumulation_steps
|
|
31
|
+
@max_grad_norm = max_grad_norm
|
|
32
|
+
@use_nvccl = use_nvccl
|
|
33
|
+
@checkpoint_dir = checkpoint_dir
|
|
34
|
+
@metrics = { steps: 0, total_loss: 0.0, best_loss: Float::INFINITY }
|
|
35
|
+
@model.train!
|
|
36
|
+
end
|
|
37
|
+
|
|
38
|
+
# Train for a specified number of steps.
|
|
39
|
+
# @param data_loader [DataLoader] provides batches
|
|
40
|
+
# @param steps [Integer] total training steps
|
|
41
|
+
# @param log_interval [Integer] log every N steps
|
|
42
|
+
# @param checkpoint_interval [Integer] save every N steps
|
|
43
|
+
# @param eval_fn [Proc, nil] evaluation function called at log intervals
|
|
44
|
+
# @return [Hash] final metrics
|
|
45
|
+
def train(data_loader, steps:, log_interval: 100,
|
|
46
|
+
checkpoint_interval: 1000, eval_fn: nil)
|
|
47
|
+
@model.train!
|
|
48
|
+
accumulated_loss = 0.0
|
|
49
|
+
|
|
50
|
+
steps.times do |step|
|
|
51
|
+
# Get batch
|
|
52
|
+
batch = data_loader.next_batch
|
|
53
|
+
input_ids = batch[:input_ids]
|
|
54
|
+
targets = batch[:targets]
|
|
55
|
+
|
|
56
|
+
# Forward pass
|
|
57
|
+
logits = @model.call(input_ids)
|
|
58
|
+
loss = Loss.cross_entropy(logits, targets)
|
|
59
|
+
|
|
60
|
+
# Scale loss for gradient accumulation
|
|
61
|
+
scaled_loss = loss * (1.0 / @grad_accumulation_steps)
|
|
62
|
+
|
|
63
|
+
# Backward pass
|
|
64
|
+
scaled_loss.backward!
|
|
65
|
+
|
|
66
|
+
accumulated_loss += loss.item
|
|
67
|
+
|
|
68
|
+
# Optimizer step (every grad_accumulation_steps)
|
|
69
|
+
if (step + 1) % @grad_accumulation_steps == 0
|
|
70
|
+
# Gradient clipping
|
|
71
|
+
grad_norm = @optimizer.clip_grad_norm!(@max_grad_norm)
|
|
72
|
+
|
|
73
|
+
# Multi-GPU gradient sync
|
|
74
|
+
if @use_nvccl
|
|
75
|
+
sync_gradients_nvccl!
|
|
76
|
+
end
|
|
77
|
+
|
|
78
|
+
# Optimizer step
|
|
79
|
+
@optimizer.step
|
|
80
|
+
@optimizer.zero_grad!
|
|
81
|
+
@scheduler&.step
|
|
82
|
+
|
|
83
|
+
@metrics[:steps] += 1
|
|
84
|
+
@metrics[:total_loss] += accumulated_loss / @grad_accumulation_steps
|
|
85
|
+
|
|
86
|
+
# Logging
|
|
87
|
+
if @metrics[:steps] % log_interval == 0
|
|
88
|
+
avg_loss = @metrics[:total_loss] / @metrics[:steps]
|
|
89
|
+
lr = @optimizer.lr
|
|
90
|
+
Ignis.logger.info(
|
|
91
|
+
"Step #{@metrics[:steps]} | Loss: #{'%.4f' % (accumulated_loss / @grad_accumulation_steps)} | " \
|
|
92
|
+
"Avg Loss: #{'%.4f' % avg_loss} | LR: #{'%.2e' % lr} | Grad Norm: #{'%.2f' % grad_norm}"
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
# EventBus publish
|
|
96
|
+
if defined?(Ignis::Shared::EventBus)
|
|
97
|
+
Ignis::Shared::EventBus.publish(:training_step, {
|
|
98
|
+
step: @metrics[:steps],
|
|
99
|
+
loss: accumulated_loss / @grad_accumulation_steps,
|
|
100
|
+
avg_loss: avg_loss,
|
|
101
|
+
lr: lr,
|
|
102
|
+
grad_norm: grad_norm
|
|
103
|
+
})
|
|
104
|
+
end
|
|
105
|
+
|
|
106
|
+
# Eval
|
|
107
|
+
if eval_fn
|
|
108
|
+
@model.eval!
|
|
109
|
+
eval_fn.call(@model, @metrics[:steps])
|
|
110
|
+
@model.train!
|
|
111
|
+
end
|
|
112
|
+
end
|
|
113
|
+
|
|
114
|
+
# Checkpointing
|
|
115
|
+
if @checkpoint_dir && @metrics[:steps] % checkpoint_interval == 0
|
|
116
|
+
save_checkpoint!
|
|
117
|
+
end
|
|
118
|
+
|
|
119
|
+
accumulated_loss = 0.0
|
|
120
|
+
end
|
|
121
|
+
|
|
122
|
+
# Clear tape each iteration
|
|
123
|
+
Tape.clear!
|
|
124
|
+
end
|
|
125
|
+
|
|
126
|
+
@metrics
|
|
127
|
+
end
|
|
128
|
+
|
|
129
|
+
# Save model checkpoint.
|
|
130
|
+
# @return [String] checkpoint path
|
|
131
|
+
def save_checkpoint!
|
|
132
|
+
return unless @checkpoint_dir
|
|
133
|
+
|
|
134
|
+
Dir.mkdir(@checkpoint_dir) unless Dir.exist?(@checkpoint_dir)
|
|
135
|
+
path = File.join(@checkpoint_dir, "checkpoint_step_#{@metrics[:steps]}.safetensors")
|
|
136
|
+
|
|
137
|
+
tensors = {}
|
|
138
|
+
@model.named_parameters.each do |name, param|
|
|
139
|
+
tensors[name] = param
|
|
140
|
+
end
|
|
141
|
+
|
|
142
|
+
Safetensors.save(tensors, path, metadata: {
|
|
143
|
+
"step" => @metrics[:steps].to_s,
|
|
144
|
+
"loss" => (@metrics[:total_loss] / [@metrics[:steps], 1].max).to_s,
|
|
145
|
+
"framework" => "nnw"
|
|
146
|
+
})
|
|
147
|
+
|
|
148
|
+
Ignis.logger.info("Checkpoint saved: #{path}")
|
|
149
|
+
path
|
|
150
|
+
end
|
|
151
|
+
|
|
152
|
+
# Load from checkpoint.
|
|
153
|
+
# @param path [String]
|
|
154
|
+
# @return [void]
|
|
155
|
+
def load_checkpoint!(path)
|
|
156
|
+
Safetensors.load_model(@model, path, strict: false)
|
|
157
|
+
Ignis.logger.info("Checkpoint loaded: #{path}")
|
|
158
|
+
end
|
|
159
|
+
|
|
160
|
+
private
|
|
161
|
+
|
|
162
|
+
# Sync gradients across GPUs via NvCCL AllReduce.
|
|
163
|
+
# @return [void]
|
|
164
|
+
def sync_gradients_nvccl!
|
|
165
|
+
return unless defined?(Ignis::Collective)
|
|
166
|
+
|
|
167
|
+
@model.parameters.each do |p|
|
|
168
|
+
next unless p.grad
|
|
169
|
+
Ignis::Collective.all_reduce(p.grad, op: :sum)
|
|
170
|
+
end
|
|
171
|
+
end
|
|
172
|
+
end
|
|
173
|
+
|
|
174
|
+
# DataLoader — batching, shuffling, and GPU prefetch for training data.
|
|
175
|
+
class DataLoader
|
|
176
|
+
# @param data [Array<Array<Integer>>] tokenized sequences
|
|
177
|
+
# @param batch_size [Integer]
|
|
178
|
+
# @param seq_len [Integer] sequence length per sample
|
|
179
|
+
# @param device_id [Integer]
|
|
180
|
+
# @param shuffle [Boolean]
|
|
181
|
+
def initialize(data, batch_size:, seq_len:, device_id: 0, shuffle: true)
|
|
182
|
+
@data = data.flatten
|
|
183
|
+
@batch_size = batch_size
|
|
184
|
+
@seq_len = seq_len
|
|
185
|
+
@device_id = device_id
|
|
186
|
+
@shuffle = shuffle
|
|
187
|
+
@position = 0
|
|
188
|
+
|
|
189
|
+
# Shuffle on init
|
|
190
|
+
reshuffle! if @shuffle
|
|
191
|
+
end
|
|
192
|
+
|
|
193
|
+
# Get next training batch.
|
|
194
|
+
# @return [Hash{Symbol => Tensor}] :input_ids, :targets
|
|
195
|
+
def next_batch
|
|
196
|
+
total_tokens = @batch_size * @seq_len
|
|
197
|
+
|
|
198
|
+
if @position + total_tokens + 1 > @data.length
|
|
199
|
+
@position = 0
|
|
200
|
+
reshuffle! if @shuffle
|
|
201
|
+
end
|
|
202
|
+
|
|
203
|
+
input_ids = []
|
|
204
|
+
targets = []
|
|
205
|
+
|
|
206
|
+
@batch_size.times do |b|
|
|
207
|
+
start = @position + b * @seq_len
|
|
208
|
+
input_ids.concat(@data[start, @seq_len])
|
|
209
|
+
targets.concat(@data[start + 1, @seq_len])
|
|
210
|
+
end
|
|
211
|
+
|
|
212
|
+
@position += total_tokens
|
|
213
|
+
|
|
214
|
+
input_nv = Ignis::Shared::NvArray.new(shape: [@batch_size, @seq_len], dtype: :int32,
|
|
215
|
+
device_id: @device_id)
|
|
216
|
+
input_nv.from_host(input_ids)
|
|
217
|
+
|
|
218
|
+
target_nv = Ignis::Shared::NvArray.new(shape: [@batch_size * @seq_len], dtype: :int32,
|
|
219
|
+
device_id: @device_id)
|
|
220
|
+
target_nv.from_host(targets)
|
|
221
|
+
|
|
222
|
+
{
|
|
223
|
+
input_ids: Tensor.new(data: input_nv, requires_grad: false),
|
|
224
|
+
targets: Tensor.new(data: target_nv, requires_grad: false)
|
|
225
|
+
}
|
|
226
|
+
end
|
|
227
|
+
|
|
228
|
+
# Number of batches per epoch.
|
|
229
|
+
# @return [Integer]
|
|
230
|
+
def num_batches
|
|
231
|
+
(@data.length - 1) / (@batch_size * @seq_len)
|
|
232
|
+
end
|
|
233
|
+
|
|
234
|
+
private
|
|
235
|
+
|
|
236
|
+
def reshuffle!
|
|
237
|
+
# Shuffle at chunk boundaries (not individual tokens)
|
|
238
|
+
chunk_size = @seq_len + 1
|
|
239
|
+
chunks = @data.each_slice(chunk_size).to_a
|
|
240
|
+
chunks.shuffle!
|
|
241
|
+
@data = chunks.flatten
|
|
242
|
+
end
|
|
243
|
+
end
|
|
244
|
+
end
|
|
245
|
+
end
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module AI
|
|
5
|
+
module Transformer
|
|
6
|
+
# Multi-Head Attention with optional causal masking.
|
|
7
|
+
# Supports standard scaled dot-product attention.
|
|
8
|
+
#
|
|
9
|
+
# @example
|
|
10
|
+
# attn = MultiHeadAttention.new(768, 12)
|
|
11
|
+
# out = attn.call(x, x, x, mask: causal_mask)
|
|
12
|
+
class MultiHeadAttention < NN::Module
|
|
13
|
+
# @param embed_dim [Integer] model dimension
|
|
14
|
+
# @param num_heads [Integer] number of attention heads
|
|
15
|
+
# @param dropout [Float] attention dropout rate
|
|
16
|
+
# @param bias [Boolean] whether projections have bias
|
|
17
|
+
# @param device_id [Integer]
|
|
18
|
+
def initialize(embed_dim, num_heads, dropout: 0.0, bias: true, causal: true, device_id: 0)
|
|
19
|
+
super()
|
|
20
|
+
raise ArgumentError, "embed_dim must be divisible by num_heads" unless (embed_dim % num_heads).zero?
|
|
21
|
+
|
|
22
|
+
@embed_dim = embed_dim
|
|
23
|
+
@num_heads = num_heads
|
|
24
|
+
@head_dim = embed_dim / num_heads
|
|
25
|
+
@scale = 1.0 / Math.sqrt(@head_dim)
|
|
26
|
+
@causal = causal
|
|
27
|
+
|
|
28
|
+
@q_proj = register_module("q_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
|
|
29
|
+
@k_proj = register_module("k_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
|
|
30
|
+
@v_proj = register_module("v_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
|
|
31
|
+
@out_proj = register_module("out_proj", NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id))
|
|
32
|
+
@attn_dropout = register_module("attn_dropout", NN::Dropout.new(p: dropout))
|
|
33
|
+
end
|
|
34
|
+
|
|
35
|
+
# Forward pass: scaled dot-product attention.
|
|
36
|
+
# @param query [Tensor] [batch, seq_q, embed_dim]
|
|
37
|
+
# @param key [Tensor] [batch, seq_k, embed_dim]
|
|
38
|
+
# @param value [Tensor] [batch, seq_k, embed_dim]
|
|
39
|
+
# @param mask [Tensor, nil] attention mask
|
|
40
|
+
# @return [Tensor] [batch, seq_q, embed_dim]
|
|
41
|
+
def forward(query, key, value, mask: nil)
|
|
42
|
+
# Project Q, K, V → [seq, embed_dim] each (batch = 1)
|
|
43
|
+
q = @q_proj.call(query)
|
|
44
|
+
k = @k_proj.call(key)
|
|
45
|
+
v = @v_proj.call(value)
|
|
46
|
+
|
|
47
|
+
# Real multi-head scaled dot-product attention with causal masking.
|
|
48
|
+
# Each head attends over its own [seq, head_dim] slice; the per-head
|
|
49
|
+
# Flash-Attention-2 kernel applies 1/sqrt(head_dim) scaling, the causal
|
|
50
|
+
# mask, and a numerically-stable online softmax. (A nil mask still means
|
|
51
|
+
# causal for a GPT-2-style decoder; pass causal: false at construction
|
|
52
|
+
# for bidirectional attention.)
|
|
53
|
+
context = q.sdpa(k, v, num_heads: @num_heads, causal: @causal)
|
|
54
|
+
|
|
55
|
+
# Output projection
|
|
56
|
+
@out_proj.call(context)
|
|
57
|
+
end
|
|
58
|
+
|
|
59
|
+
# Incremental attention for one new token via a KV cache (decode path).
|
|
60
|
+
# Projects only this token's q/k/v, appends k/v to the cache, then attends
|
|
61
|
+
# the single query over all cached keys/values (the new token is the last
|
|
62
|
+
# position, so it sees everything — no causal mask). Must run under no_grad.
|
|
63
|
+
# @param x [Tensor] [1, embed] hidden state of the new token (post-norm)
|
|
64
|
+
# @param cache [KVCache]
|
|
65
|
+
# @param layer [Integer] this block's layer index
|
|
66
|
+
# @return [Tensor] [1, embed]
|
|
67
|
+
def decode_step(x, cache, layer)
|
|
68
|
+
raise "decode_step requires a causal attention module" unless @causal
|
|
69
|
+
|
|
70
|
+
q = @q_proj.call(x)
|
|
71
|
+
k = @k_proj.call(x)
|
|
72
|
+
v = @v_proj.call(x)
|
|
73
|
+
|
|
74
|
+
cache.append(layer, k.data, v.data)
|
|
75
|
+
kview = Tensor.new(data: cache.k_view(layer), requires_grad: false)
|
|
76
|
+
vview = Tensor.new(data: cache.v_view(layer), requires_grad: false)
|
|
77
|
+
|
|
78
|
+
ctx = q.decode_sdpa(kview, vview, num_heads: @num_heads)
|
|
79
|
+
@out_proj.call(ctx)
|
|
80
|
+
end
|
|
81
|
+
|
|
82
|
+
# @return [String]
|
|
83
|
+
def to_s
|
|
84
|
+
"MultiHeadAttention(embed=#{@embed_dim}, heads=#{@num_heads}, head_dim=#{@head_dim})"
|
|
85
|
+
end
|
|
86
|
+
end
|
|
87
|
+
end
|
|
88
|
+
end
|
|
89
|
+
end
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module AI
|
|
5
|
+
module Transformer
|
|
6
|
+
# Single Transformer block: Attention + FF with residual connections.
|
|
7
|
+
# Supports both pre-norm (GPT-2, LLaMA) and post-norm (original) variants.
|
|
8
|
+
class Block < NN::Module
|
|
9
|
+
# @param embed_dim [Integer]
|
|
10
|
+
# @param num_heads [Integer]
|
|
11
|
+
# @param ff_dim [Integer]
|
|
12
|
+
# @param dropout [Float]
|
|
13
|
+
# @param pre_norm [Boolean] pre-norm (true) vs post-norm
|
|
14
|
+
# @param activation [Symbol]
|
|
15
|
+
# @param device_id [Integer]
|
|
16
|
+
def initialize(embed_dim, num_heads, ff_dim, dropout: 0.0,
|
|
17
|
+
pre_norm: true, activation: :gelu, device_id: 0)
|
|
18
|
+
super()
|
|
19
|
+
@pre_norm = pre_norm
|
|
20
|
+
|
|
21
|
+
@attention = register_module("attention",
|
|
22
|
+
MultiHeadAttention.new(embed_dim, num_heads, dropout: dropout, device_id: device_id))
|
|
23
|
+
@feed_forward = register_module("feed_forward",
|
|
24
|
+
FeedForward.new(embed_dim, ff_dim, activation: activation,
|
|
25
|
+
dropout: dropout, device_id: device_id))
|
|
26
|
+
@norm1 = register_module("norm1", NN::LayerNorm.new(embed_dim, device_id: device_id))
|
|
27
|
+
@norm2 = register_module("norm2", NN::LayerNorm.new(embed_dim, device_id: device_id))
|
|
28
|
+
@dropout = register_module("dropout", NN::Dropout.new(p: dropout))
|
|
29
|
+
end
|
|
30
|
+
|
|
31
|
+
# Forward pass.
|
|
32
|
+
# @param x [Tensor] input [batch*seq, embed_dim]
|
|
33
|
+
# @param mask [Tensor, nil] attention mask
|
|
34
|
+
# @return [Tensor]
|
|
35
|
+
def forward(x, mask: nil)
|
|
36
|
+
if @pre_norm
|
|
37
|
+
# Pre-norm (GPT-2 style): Norm → Attn → Residual, Norm → FF → Residual
|
|
38
|
+
residual = x
|
|
39
|
+
h = @norm1.call(x)
|
|
40
|
+
h = @attention.call(h, h, h, mask: mask)
|
|
41
|
+
h = @dropout.call(h)
|
|
42
|
+
x = residual + h
|
|
43
|
+
|
|
44
|
+
residual = x
|
|
45
|
+
h = @norm2.call(x)
|
|
46
|
+
h = @feed_forward.call(h)
|
|
47
|
+
h = @dropout.call(h)
|
|
48
|
+
residual + h
|
|
49
|
+
else
|
|
50
|
+
# Post-norm (original Transformer): Attn → Residual → Norm
|
|
51
|
+
residual = x
|
|
52
|
+
h = @attention.call(x, x, x, mask: mask)
|
|
53
|
+
h = @dropout.call(h)
|
|
54
|
+
x = @norm1.call(residual + h)
|
|
55
|
+
|
|
56
|
+
residual = x
|
|
57
|
+
h = @feed_forward.call(x)
|
|
58
|
+
h = @dropout.call(h)
|
|
59
|
+
@norm2.call(residual + h)
|
|
60
|
+
end
|
|
61
|
+
end
|
|
62
|
+
|
|
63
|
+
# Incremental forward for one new token (decode path, pre-norm only).
|
|
64
|
+
# Mirrors #forward but routes attention through the KV cache and operates
|
|
65
|
+
# on a single [1, embed] row. Dropout is identity in eval mode, so it is
|
|
66
|
+
# omitted. @param x [Tensor] [1, embed]; @param cache [KVCache];
|
|
67
|
+
# @param layer [Integer] this block's index. @return [Tensor] [1, embed]
|
|
68
|
+
def decode_step(x, cache, layer)
|
|
69
|
+
raise "Block#decode_step is only implemented for pre-norm blocks" unless @pre_norm
|
|
70
|
+
|
|
71
|
+
residual = x
|
|
72
|
+
h = @norm1.call(x)
|
|
73
|
+
h = @attention.decode_step(h, cache, layer)
|
|
74
|
+
x = residual + h
|
|
75
|
+
|
|
76
|
+
residual = x
|
|
77
|
+
h = @norm2.call(x)
|
|
78
|
+
h = @feed_forward.call(h)
|
|
79
|
+
residual + h
|
|
80
|
+
end
|
|
81
|
+
|
|
82
|
+
# @return [String]
|
|
83
|
+
def to_s
|
|
84
|
+
style = @pre_norm ? "pre-norm" : "post-norm"
|
|
85
|
+
"Block(#{style}, attn=#{@attention}, ff=#{@feed_forward})"
|
|
86
|
+
end
|
|
87
|
+
end
|
|
88
|
+
end
|
|
89
|
+
end
|
|
90
|
+
end
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module AI
|
|
5
|
+
module Transformer
|
|
6
|
+
# Feed-forward network: Linear → Activation → Dropout → Linear
|
|
7
|
+
# Used in each Transformer block after attention.
|
|
8
|
+
class FeedForward < NN::Module
|
|
9
|
+
# @param embed_dim [Integer] model dimension
|
|
10
|
+
# @param ff_dim [Integer] feed-forward hidden dimension (typically 4x embed_dim)
|
|
11
|
+
# @param activation [Symbol] :gelu, :relu, or :silu
|
|
12
|
+
# @param dropout [Float] dropout rate
|
|
13
|
+
# @param device_id [Integer]
|
|
14
|
+
def initialize(embed_dim, ff_dim, activation: :gelu, dropout: 0.0, device_id: 0)
|
|
15
|
+
super()
|
|
16
|
+
@activation = activation
|
|
17
|
+
|
|
18
|
+
@fc1 = register_module("fc1", NN::Linear.new(embed_dim, ff_dim, device_id: device_id))
|
|
19
|
+
@fc2 = register_module("fc2", NN::Linear.new(ff_dim, embed_dim, device_id: device_id))
|
|
20
|
+
@dropout = register_module("dropout", NN::Dropout.new(p: dropout))
|
|
21
|
+
end
|
|
22
|
+
|
|
23
|
+
# Forward pass.
|
|
24
|
+
# @param x [Tensor]
|
|
25
|
+
# @return [Tensor]
|
|
26
|
+
def forward(x)
|
|
27
|
+
h = @fc1.call(x)
|
|
28
|
+
h = apply_activation(h)
|
|
29
|
+
h = @dropout.call(h)
|
|
30
|
+
@fc2.call(h)
|
|
31
|
+
end
|
|
32
|
+
|
|
33
|
+
# @return [String]
|
|
34
|
+
def to_s
|
|
35
|
+
"FeedForward(fc1=#{@fc1}, activation=#{@activation}, fc2=#{@fc2})"
|
|
36
|
+
end
|
|
37
|
+
|
|
38
|
+
private
|
|
39
|
+
|
|
40
|
+
# @param x [Tensor]
|
|
41
|
+
# @return [Tensor]
|
|
42
|
+
def apply_activation(x)
|
|
43
|
+
case @activation
|
|
44
|
+
when :gelu then x.gelu
|
|
45
|
+
when :relu then x.relu
|
|
46
|
+
when :silu then x.silu
|
|
47
|
+
else raise ArgumentError, "Unknown activation: #{@activation}"
|
|
48
|
+
end
|
|
49
|
+
end
|
|
50
|
+
end
|
|
51
|
+
end
|
|
52
|
+
end
|
|
53
|
+
end
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Ignis
|
|
4
|
+
module AI
|
|
5
|
+
module Transformer
|
|
6
|
+
# Full Transformer language model.
|
|
7
|
+
#
|
|
8
|
+
# token_embedding → position_embedding → N × Block → LayerNorm → LM head
|
|
9
|
+
#
|
|
10
|
+
# Factory methods provide standard model configurations:
|
|
11
|
+
# .gpt2_small → 124M params
|
|
12
|
+
# .gpt2_medium → 345M params
|
|
13
|
+
# .gpt2_large → 774M params
|
|
14
|
+
class Model < NN::Module
|
|
15
|
+
# @return [Integer]
|
|
16
|
+
attr_reader :vocab_size, :embed_dim, :num_heads, :num_layers, :max_seq_len
|
|
17
|
+
|
|
18
|
+
# @param vocab_size [Integer] vocabulary size
|
|
19
|
+
# @param embed_dim [Integer] model dimension
|
|
20
|
+
# @param num_heads [Integer] attention heads per block
|
|
21
|
+
# @param num_layers [Integer] number of Transformer blocks
|
|
22
|
+
# @param ff_dim [Integer] feed-forward hidden dimension
|
|
23
|
+
# @param max_seq_len [Integer] maximum sequence length
|
|
24
|
+
# @param dropout [Float]
|
|
25
|
+
# @param activation [Symbol]
|
|
26
|
+
# @param pre_norm [Boolean]
|
|
27
|
+
# @param device_id [Integer]
|
|
28
|
+
def initialize(vocab_size:, embed_dim:, num_heads:, num_layers:,
|
|
29
|
+
ff_dim:, max_seq_len:, dropout: 0.0,
|
|
30
|
+
activation: :gelu, pre_norm: true, device_id: 0)
|
|
31
|
+
super()
|
|
32
|
+
@vocab_size = vocab_size
|
|
33
|
+
@embed_dim = embed_dim
|
|
34
|
+
@num_heads = num_heads
|
|
35
|
+
@num_layers = num_layers
|
|
36
|
+
@max_seq_len = max_seq_len
|
|
37
|
+
@device_id = device_id
|
|
38
|
+
|
|
39
|
+
@token_embedding = register_module("token_embedding",
|
|
40
|
+
NN::Embedding.new(vocab_size, embed_dim, device_id: device_id))
|
|
41
|
+
@position_embedding = register_module("position_embedding",
|
|
42
|
+
NN::Embedding.new(max_seq_len, embed_dim, device_id: device_id))
|
|
43
|
+
|
|
44
|
+
@blocks = []
|
|
45
|
+
num_layers.times do |i|
|
|
46
|
+
block = Block.new(embed_dim, num_heads, ff_dim,
|
|
47
|
+
dropout: dropout, pre_norm: pre_norm,
|
|
48
|
+
activation: activation, device_id: device_id)
|
|
49
|
+
@blocks << register_module("blocks.#{i}", block)
|
|
50
|
+
end
|
|
51
|
+
|
|
52
|
+
@norm = register_module("norm", NN::LayerNorm.new(embed_dim, device_id: device_id))
|
|
53
|
+
@head = register_module("head",
|
|
54
|
+
NN::Linear.new(embed_dim, vocab_size, bias: false, device_id: device_id))
|
|
55
|
+
@dropout = register_module("dropout", NN::Dropout.new(p: dropout))
|
|
56
|
+
end
|
|
57
|
+
|
|
58
|
+
# Forward pass: returns logits.
|
|
59
|
+
# @param input_ids [Tensor] token indices [batch_size, seq_len] (int32)
|
|
60
|
+
# @param mask [Tensor, nil] attention mask
|
|
61
|
+
# @return [Tensor] logits [batch_size * seq_len, vocab_size]
|
|
62
|
+
def forward(input_ids, mask: nil)
|
|
63
|
+
seq_len = input_ids.shape[-1]
|
|
64
|
+
|
|
65
|
+
# Create position indices
|
|
66
|
+
positions_data = (0...seq_len).to_a
|
|
67
|
+
pos_nv = Ignis::Shared::NvArray.new(shape: [seq_len], dtype: :int32,
|
|
68
|
+
device_id: input_ids.device_id)
|
|
69
|
+
pos_nv.from_host(positions_data)
|
|
70
|
+
positions = Tensor.new(data: pos_nv, requires_grad: false)
|
|
71
|
+
|
|
72
|
+
# Embeddings
|
|
73
|
+
tok_emb = @token_embedding.call(input_ids) # [batch, seq, embed]
|
|
74
|
+
pos_emb = @position_embedding.call(positions) # [seq, embed]
|
|
75
|
+
|
|
76
|
+
# Combine and dropout
|
|
77
|
+
x = tok_emb + pos_emb
|
|
78
|
+
x = @dropout.call(x)
|
|
79
|
+
|
|
80
|
+
# Transformer blocks
|
|
81
|
+
@blocks.each do |block|
|
|
82
|
+
x = block.call(x, mask: mask)
|
|
83
|
+
end
|
|
84
|
+
|
|
85
|
+
# Final norm and LM head
|
|
86
|
+
x = @norm.call(x)
|
|
87
|
+
@head.call(x) # → logits [batch*seq, vocab]
|
|
88
|
+
end
|
|
89
|
+
|
|
90
|
+
# Allocate a fresh KV cache sized for this model.
|
|
91
|
+
# @param device_id [Integer]
|
|
92
|
+
# @return [KVCache]
|
|
93
|
+
def make_kv_cache(device_id: @device_id)
|
|
94
|
+
KVCache.new(num_layers: @num_layers, max_seq_len: @max_seq_len,
|
|
95
|
+
embed_dim: @embed_dim, device_id: device_id)
|
|
96
|
+
end
|
|
97
|
+
|
|
98
|
+
# Incremental forward for ONE new token using a KV cache (decode path).
|
|
99
|
+
# Equivalent to the last-position logits of a full forward over the whole
|
|
100
|
+
# prefix, but O(prefix) instead of O(prefix²): only this token is projected
|
|
101
|
+
# and embedded; its query attends over cached K/V. Must run under
|
|
102
|
+
# Tape.no_grad (no autograd). Append order matches the prefix order, so
|
|
103
|
+
# callers feed the prompt token-by-token before sampling.
|
|
104
|
+
# @param token_id [Integer] the new token's id
|
|
105
|
+
# @param cache [KVCache]
|
|
106
|
+
# @return [Tensor] logits [1, vocab]
|
|
107
|
+
def decode_step(token_id, cache)
|
|
108
|
+
pos = cache.length
|
|
109
|
+
raise "KVCache full: position #{pos} exceeds max_seq_len #{@max_seq_len}" if pos >= @max_seq_len
|
|
110
|
+
|
|
111
|
+
tok = Tensor.from_host([token_id], shape: [1], dtype: :int32, device_id: @device_id)
|
|
112
|
+
pos_t = Tensor.from_host([pos], shape: [1], dtype: :int32, device_id: @device_id)
|
|
113
|
+
|
|
114
|
+
x = @token_embedding.call(tok) + @position_embedding.call(pos_t) # [1, embed]
|
|
115
|
+
@blocks.each_with_index { |block, i| x = block.decode_step(x, cache, i) }
|
|
116
|
+
cache.advance!
|
|
117
|
+
|
|
118
|
+
x = @norm.call(x)
|
|
119
|
+
@head.call(x) # [1, vocab]
|
|
120
|
+
end
|
|
121
|
+
|
|
122
|
+
# -------------------------------------------------------------------
|
|
123
|
+
# Factory methods for standard configurations
|
|
124
|
+
# -------------------------------------------------------------------
|
|
125
|
+
|
|
126
|
+
# GPT-2 Small: 124M parameters
|
|
127
|
+
# @param device_id [Integer]
|
|
128
|
+
# @return [Model]
|
|
129
|
+
def self.gpt2_small(device_id: 0)
|
|
130
|
+
new(
|
|
131
|
+
vocab_size: 50257,
|
|
132
|
+
embed_dim: 768,
|
|
133
|
+
num_heads: 12,
|
|
134
|
+
num_layers: 12,
|
|
135
|
+
ff_dim: 3072,
|
|
136
|
+
max_seq_len: 1024,
|
|
137
|
+
dropout: 0.1,
|
|
138
|
+
activation: :gelu,
|
|
139
|
+
pre_norm: true,
|
|
140
|
+
device_id: device_id
|
|
141
|
+
)
|
|
142
|
+
end
|
|
143
|
+
|
|
144
|
+
# GPT-2 Medium: 345M parameters
|
|
145
|
+
# @param device_id [Integer]
|
|
146
|
+
# @return [Model]
|
|
147
|
+
def self.gpt2_medium(device_id: 0)
|
|
148
|
+
new(
|
|
149
|
+
vocab_size: 50257,
|
|
150
|
+
embed_dim: 1024,
|
|
151
|
+
num_heads: 16,
|
|
152
|
+
num_layers: 24,
|
|
153
|
+
ff_dim: 4096,
|
|
154
|
+
max_seq_len: 1024,
|
|
155
|
+
dropout: 0.1,
|
|
156
|
+
activation: :gelu,
|
|
157
|
+
pre_norm: true,
|
|
158
|
+
device_id: device_id
|
|
159
|
+
)
|
|
160
|
+
end
|
|
161
|
+
|
|
162
|
+
# GPT-2 Large: 774M parameters
|
|
163
|
+
# @param device_id [Integer]
|
|
164
|
+
# @return [Model]
|
|
165
|
+
def self.gpt2_large(device_id: 0)
|
|
166
|
+
new(
|
|
167
|
+
vocab_size: 50257,
|
|
168
|
+
embed_dim: 1280,
|
|
169
|
+
num_heads: 20,
|
|
170
|
+
num_layers: 36,
|
|
171
|
+
ff_dim: 5120,
|
|
172
|
+
max_seq_len: 1024,
|
|
173
|
+
dropout: 0.1,
|
|
174
|
+
activation: :gelu,
|
|
175
|
+
pre_norm: true,
|
|
176
|
+
device_id: device_id
|
|
177
|
+
)
|
|
178
|
+
end
|
|
179
|
+
|
|
180
|
+
# @return [String]
|
|
181
|
+
def to_s
|
|
182
|
+
"TransformerModel(vocab=#{@vocab_size}, embed=#{@embed_dim}, " \
|
|
183
|
+
"heads=#{@num_heads}, layers=#{@num_layers}, " \
|
|
184
|
+
"params=#{num_parameters})"
|
|
185
|
+
end
|
|
186
|
+
end
|
|
187
|
+
end
|
|
188
|
+
end
|
|
189
|
+
end
|