ignis-dl 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,245 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module AI
5
# Trainer — training loop with gradient accumulation, gradient clipping,
# periodic logging / evaluation, checkpointing, and optional multi-GPU
# gradient synchronisation via NvCCL.
class Trainer
  # @return [Transformer::Model]
  attr_reader :model

  # @return [Optim::Base]
  attr_reader :optimizer

  # @return [Hash] training metrics: :steps (optimizer updates performed),
  #   :total_loss (sum of per-update mean losses) and :best_loss (lowest
  #   per-update mean loss seen so far)
  attr_reader :metrics

  # Stores the configuration, resets the metrics and switches the model to
  # training mode.
  #
  # @param model [Transformer::Model]
  # @param optimizer [Optim::Base]
  # @param scheduler [Optim::LRScheduler::*, nil]
  # @param grad_accumulation_steps [Integer] accumulate gradients over N micro-batches
  # @param max_grad_norm [Float] gradient clipping norm
  # @param use_nvccl [Boolean] enable multi-GPU gradient sync
  # @param checkpoint_dir [String, nil] directory for saving checkpoints
  # @raise [ArgumentError] if grad_accumulation_steps is not a positive Integer
  def initialize(model:, optimizer:, scheduler: nil,
                 grad_accumulation_steps: 1, max_grad_norm: 1.0,
                 use_nvccl: false, checkpoint_dir: nil)
    # A zero/negative value would otherwise surface as a ZeroDivisionError
    # (or an infinite loss scale) in the middle of #train.
    unless grad_accumulation_steps.is_a?(Integer) && grad_accumulation_steps.positive?
      raise ArgumentError,
            "grad_accumulation_steps must be a positive Integer, got #{grad_accumulation_steps.inspect}"
    end

    @model = model
    @optimizer = optimizer
    @scheduler = scheduler
    @grad_accumulation_steps = grad_accumulation_steps
    @max_grad_norm = max_grad_norm
    @use_nvccl = use_nvccl
    @checkpoint_dir = checkpoint_dir
    @metrics = { steps: 0, total_loss: 0.0, best_loss: Float::INFINITY }
    @model.train!
  end

  # Run the training loop.
  #
  # `steps` counts micro-batches: one optimizer update is applied after every
  # `grad_accumulation_steps` micro-batches. Micro-batches left over when
  # `steps` is not a multiple of `grad_accumulation_steps` contribute
  # gradients but do not trigger an update inside this call.
  #
  # @param data_loader [DataLoader] provides batches via #next_batch
  # @param steps [Integer] number of micro-batches to process
  # @param log_interval [Integer, nil] log every N optimizer updates (0 or nil disables)
  # @param checkpoint_interval [Integer, nil] save every N optimizer updates (0 or nil disables)
  # @param eval_fn [Proc, nil] called as eval_fn.call(model, step) at log intervals
  # @return [Hash] final metrics
  def train(data_loader, steps:, log_interval: 100,
            checkpoint_interval: 1000, eval_fn: nil)
    @model.train!
    accumulated_loss = 0.0

    steps.times do |step|
      accumulated_loss += run_micro_batch!(data_loader.next_batch)

      if ((step + 1) % @grad_accumulation_steps).zero?
        step_loss = accumulated_loss / @grad_accumulation_steps
        grad_norm = apply_optimizer_step!
        record_step!(step_loss)

        report_progress(step_loss, grad_norm, eval_fn) if interval_due?(log_interval)
        save_checkpoint! if @checkpoint_dir && interval_due?(checkpoint_interval)

        accumulated_loss = 0.0
      end

      # Clear tape each iteration
      Tape.clear!
    end

    @metrics
  end

  # Save model checkpoint (all named parameters plus step/loss metadata).
  # Missing parent directories of the checkpoint directory are created.
  #
  # @return [String, nil] checkpoint path, or nil when no checkpoint_dir is configured
  def save_checkpoint!
    return unless @checkpoint_dir

    # Stdlib, loaded lazily so that only checkpointing runs pay for it.
    require "fileutils"
    # mkdir_p also handles nested paths and an already-existing directory.
    FileUtils.mkdir_p(@checkpoint_dir)

    step = @metrics[:steps]
    path = File.join(@checkpoint_dir, "checkpoint_step_#{step}.safetensors")

    tensors = {}
    @model.named_parameters.each do |name, param|
      tensors[name] = param
    end

    Safetensors.save(tensors, path, metadata: {
      "step" => step.to_s,
      # Guard against division by zero when saving before the first update.
      "loss" => (@metrics[:total_loss] / [step, 1].max).to_s,
      "framework" => "nnw"
    })

    Ignis.logger.info("Checkpoint saved: #{path}")
    path
  end

  # Load model parameters from a checkpoint (non-strict).
  # @param path [String]
  # @return [void]
  def load_checkpoint!(path)
    Safetensors.load_model(@model, path, strict: false)
    Ignis.logger.info("Checkpoint loaded: #{path}")
  end

  private

  # Forward + backward for a single micro-batch.
  # @param batch [Hash] must provide :input_ids and :targets
  # @return [Numeric] the unscaled loss value of this micro-batch
  def run_micro_batch!(batch)
    logits = @model.call(batch[:input_ids])
    loss = Loss.cross_entropy(logits, batch[:targets])

    # Scale so that the accumulated gradient is the mean over micro-batches.
    (loss * (1.0 / @grad_accumulation_steps)).backward!

    loss.item
  end

  # Apply one optimizer update from the accumulated gradients.
  #
  # Gradients are synchronised across devices BEFORE clipping: clipping
  # local gradients first and reducing afterwards would leave the gradient
  # that is actually applied unbounded by max_grad_norm.
  #
  # @return [Object] whatever Optim::Base#clip_grad_norm! returns (the gradient norm)
  def apply_optimizer_step!
    sync_gradients_nvccl! if @use_nvccl
    grad_norm = @optimizer.clip_grad_norm!(@max_grad_norm)

    @optimizer.step
    @optimizer.zero_grad!
    @scheduler&.step

    grad_norm
  end

  # Fold the mean loss of one optimizer update into the metrics.
  # @param step_loss [Float]
  # @return [void]
  def record_step!(step_loss)
    @metrics[:steps] += 1
    @metrics[:total_loss] += step_loss
    @metrics[:best_loss] = step_loss if step_loss < @metrics[:best_loss]
  end

  # @param interval [Integer, nil]
  # @return [Boolean] true when the current update count is a multiple of a
  #   positive interval; nil or non-positive intervals never fire
  def interval_due?(interval)
    return false unless interval&.positive?

    (@metrics[:steps] % interval).zero?
  end

  # Log the current state, publish it on the EventBus when that constant is
  # loaded, and run the optional evaluation callback.
  # @param step_loss [Float] mean loss of the update just applied
  # @param grad_norm [Numeric] gradient norm reported by the optimizer
  # @param eval_fn [Proc, nil]
  # @return [void]
  def report_progress(step_loss, grad_norm, eval_fn)
    avg_loss = @metrics[:total_loss] / @metrics[:steps]
    lr = @optimizer.lr

    Ignis.logger.info(
      "Step #{@metrics[:steps]} | Loss: #{format('%.4f', step_loss)} | " \
      "Avg Loss: #{format('%.4f', avg_loss)} | LR: #{format('%.2e', lr)} | " \
      "Grad Norm: #{format('%.2f', grad_norm)}"
    )

    if defined?(Ignis::Shared::EventBus)
      Ignis::Shared::EventBus.publish(:training_step, {
        step: @metrics[:steps],
        loss: step_loss,
        avg_loss: avg_loss,
        lr: lr,
        grad_norm: grad_norm
      })
    end

    return unless eval_fn

    @model.eval!
    begin
      eval_fn.call(@model, @metrics[:steps])
    ensure
      # Never leave the model stuck in eval mode if the callback raises.
      @model.train!
    end
  end

  # Sync gradients across GPUs via NvCCL AllReduce; a no-op when
  # Ignis::Collective is not loaded.
  # NOTE(review): the reduction uses op: :sum and is not divided by the number
  # of participating devices here — confirm whether averaging happens elsewhere.
  # @return [void]
  def sync_gradients_nvccl!
    return unless defined?(Ignis::Collective)

    @model.parameters.each do |param|
      next unless param.grad

      Ignis::Collective.all_reduce(param.grad, op: :sum)
    end
  end
end
173
+
174
# DataLoader — turns a token stream into fixed-size next-token-prediction
# batches, optionally visiting the training windows in shuffled order.
#
# The flattened token stream is never reordered. A sample is a window of
# `seq_len` tokens starting at a multiple of `seq_len`; its targets are the
# same window shifted right by one token. Shuffling permutes the order of
# the window start offsets only, so every input/target pair stays coherent.
class DataLoader
  # @param data [Array<Array<Integer>>] tokenized sequences (flattened into one stream)
  # @param batch_size [Integer]
  # @param seq_len [Integer] sequence length per sample
  # @param device_id [Integer]
  # @param shuffle [Boolean] visit samples in random order (reshuffled every epoch)
  # @raise [ArgumentError] if batch_size or seq_len is not positive
  def initialize(data, batch_size:, seq_len:, device_id: 0, shuffle: true)
    raise ArgumentError, "batch_size must be positive, got #{batch_size.inspect}" unless batch_size.positive?
    raise ArgumentError, "seq_len must be positive, got #{seq_len.inspect}" unless seq_len.positive?

    @data = data.flatten
    @batch_size = batch_size
    @seq_len = seq_len
    @device_id = device_id
    @shuffle = shuffle
    @cursor = 0 # index into @sample_starts of the next unread sample

    # Every window needs seq_len tokens plus one look-ahead token for its
    # last target, hence the "- 1".
    sample_count = [(@data.length - 1) / @seq_len, 0].max
    @sample_starts = Array.new(sample_count) { |i| i * @seq_len }

    # Shuffle on init
    reshuffle! if @shuffle
  end

  # Get next training batch. When fewer than `batch_size` unread samples
  # remain, the loader wraps around to the start of a new epoch (reshuffling
  # first if enabled).
  #
  # @return [Hash{Symbol => Tensor}] :input_ids with NvArray shape
  #   [batch_size, seq_len] and :targets with shape [batch_size * seq_len]
  # @raise [ArgumentError] if the data cannot fill a single batch
  def next_batch
    if num_batches.zero?
      raise ArgumentError,
            "not enough tokens (#{@data.length}) for one batch of " \
            "#{@batch_size} x #{@seq_len} (+1 look-ahead token)"
    end

    if @cursor + @batch_size > @sample_starts.length
      @cursor = 0
      reshuffle! if @shuffle
    end

    starts = @sample_starts[@cursor, @batch_size]
    @cursor += @batch_size

    input_ids = starts.flat_map { |start| @data[start, @seq_len] }
    targets = starts.flat_map { |start| @data[start + 1, @seq_len] }

    {
      input_ids: to_tensor(input_ids, [@batch_size, @seq_len]),
      targets: to_tensor(targets, [@batch_size * @seq_len])
    }
  end

  # Number of batches per epoch.
  # @return [Integer] 0 when the data cannot fill a single batch
  def num_batches
    @sample_starts.length / @batch_size
  end

  private

  # Copy host integers into an int32 NvArray on the configured device and
  # wrap it in a Tensor that does not require gradients.
  # @param values [Array<Integer>]
  # @param shape [Array<Integer>]
  # @return [Tensor]
  def to_tensor(values, shape)
    array = Ignis::Shared::NvArray.new(shape: shape, dtype: :int32,
                                       device_id: @device_id)
    array.from_host(values)
    Tensor.new(data: array, requires_grad: false)
  end

  # Permute the sample visiting order in place; the token stream itself is
  # left untouched so windows and their shifted targets remain contiguous.
  # @return [void]
  def reshuffle!
    @sample_starts.shuffle!
  end
end
244
+ end
245
+ end
@@ -0,0 +1,89 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module AI
5
+ module Transformer
6
# Multi-head attention layer built from four linear projections (query,
# key, value, output) around Tensor#sdpa / Tensor#decode_sdpa.
#
# @example
#   attn = MultiHeadAttention.new(768, 12)
#   out = attn.call(x, x, x, mask: causal_mask)
class MultiHeadAttention < NN::Module
  # @param embed_dim [Integer] model dimension; must be divisible by num_heads
  # @param num_heads [Integer] number of attention heads
  # @param dropout [Float] rate given to the registered "attn_dropout" module
  # @param bias [Boolean] whether projections have bias
  # @param causal [Boolean] forwarded as `causal:` to Tensor#sdpa in #forward
  # @param device_id [Integer]
  # @raise [ArgumentError] if embed_dim is not a multiple of num_heads
  def initialize(embed_dim, num_heads, dropout: 0.0, bias: true, causal: true, device_id: 0)
    super()
    raise ArgumentError, "embed_dim must be divisible by num_heads" if (embed_dim % num_heads).nonzero?

    @embed_dim = embed_dim
    @num_heads = num_heads
    @causal = causal
    @head_dim = embed_dim / num_heads
    # Stored but not read anywhere in this class; per the author's notes the
    # sdpa kernel applies the 1/sqrt(head_dim) scaling itself.
    @scale = 1.0 / Math.sqrt(@head_dim)

    # All four projections are square embed_dim -> embed_dim linear layers.
    new_projection = lambda do
      NN::Linear.new(embed_dim, embed_dim, bias: bias, device_id: device_id)
    end

    @q_proj = register_module("q_proj", new_projection.call)
    @k_proj = register_module("k_proj", new_projection.call)
    @v_proj = register_module("v_proj", new_projection.call)
    @out_proj = register_module("out_proj", new_projection.call)
    # Registered as a sub-module; not applied in #forward or #decode_step.
    @attn_dropout = register_module("attn_dropout", NN::Dropout.new(p: dropout))
  end

  # Forward pass: project the inputs, run attention, project the result.
  #
  # Whether attention is causal is decided solely by the `causal:` flag given
  # at construction. NOTE: `mask` is accepted for interface compatibility but
  # is not read by this method.
  #
  # @param query [Tensor] presumably [batch, seq_q, embed_dim] — TODO confirm
  #   (an earlier inline note described [seq, embed_dim] with batch = 1)
  # @param key [Tensor] same layout as query, over seq_k
  # @param value [Tensor] same layout as key
  # @param mask [Tensor, nil] currently unused
  # @return [Tensor] output of the "out_proj" projection
  def forward(query, key, value, mask: nil)
    queries = @q_proj.call(query)
    keys = @k_proj.call(key)
    values = @v_proj.call(value)

    attended = queries.sdpa(keys, values, num_heads: @num_heads, causal: @causal)

    @out_proj.call(attended)
  end

  # Incremental attention for one new token via a KV cache (decode path).
  # Projects only this token's q/k/v, appends k/v to the cache for `layer`,
  # then attends the single query over the cached key/value views. No causal
  # flag is passed to decode_sdpa: the new token is the last position, so it
  # may see every cached entry. The author notes this must run under no_grad.
  #
  # @param x [Tensor] [1, embed] hidden state of the new token (post-norm)
  # @param cache [KVCache]
  # @param layer [Integer] this block's layer index
  # @return [Tensor] [1, embed]
  # @raise [RuntimeError] if the module was built with causal: false
  def decode_step(x, cache, layer)
    raise "decode_step requires a causal attention module" unless @causal

    query = @q_proj.call(x)
    new_key = @k_proj.call(x)
    new_value = @v_proj.call(x)

    cache.append(layer, new_key.data, new_value.data)
    cached_keys = Tensor.new(data: cache.k_view(layer), requires_grad: false)
    cached_values = Tensor.new(data: cache.v_view(layer), requires_grad: false)

    attended = query.decode_sdpa(cached_keys, cached_values, num_heads: @num_heads)
    @out_proj.call(attended)
  end

  # @return [String] short description with embed size, head count and head size
  def to_s
    "MultiHeadAttention(embed=#{@embed_dim}, heads=#{@num_heads}, head_dim=#{@head_dim})"
  end
end
87
+ end
88
+ end
89
+ end
@@ -0,0 +1,90 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module AI
5
+ module Transformer
6
# One Transformer block: self-attention and a feed-forward network, each
# wrapped in dropout and a residual connection. The LayerNorm is applied
# either before each sub-layer (pre-norm) or after each residual sum
# (post-norm), chosen at construction.
class Block < NN::Module
  # @param embed_dim [Integer]
  # @param num_heads [Integer]
  # @param ff_dim [Integer]
  # @param dropout [Float] shared by attention, feed-forward and the residual dropout
  # @param pre_norm [Boolean] pre-norm (true) vs post-norm
  # @param activation [Symbol]
  # @param device_id [Integer]
  def initialize(embed_dim, num_heads, ff_dim, dropout: 0.0,
                 pre_norm: true, activation: :gelu, device_id: 0)
    super()
    @pre_norm = pre_norm

    attention = MultiHeadAttention.new(embed_dim, num_heads,
                                       dropout: dropout, device_id: device_id)
    @attention = register_module("attention", attention)

    feed_forward = FeedForward.new(embed_dim, ff_dim, activation: activation,
                                                      dropout: dropout, device_id: device_id)
    @feed_forward = register_module("feed_forward", feed_forward)

    @norm1 = register_module("norm1", NN::LayerNorm.new(embed_dim, device_id: device_id))
    @norm2 = register_module("norm2", NN::LayerNorm.new(embed_dim, device_id: device_id))
    @dropout = register_module("dropout", NN::Dropout.new(p: dropout))
  end

  # Forward pass.
  # @param x [Tensor] input, presumably [batch*seq, embed_dim] — TODO confirm
  # @param mask [Tensor, nil] attention mask, passed through to the attention module
  # @return [Tensor]
  def forward(x, mask: nil)
    unless @pre_norm
      # Post-norm: sub-layer -> dropout -> residual add -> norm.
      attended = @dropout.call(@attention.call(x, x, x, mask: mask))
      x = @norm1.call(x + attended)

      transformed = @dropout.call(@feed_forward.call(x))
      return @norm2.call(x + transformed)
    end

    # Pre-norm: norm -> sub-layer -> dropout -> residual add.
    normed = @norm1.call(x)
    attended = @dropout.call(@attention.call(normed, normed, normed, mask: mask))
    x = x + attended

    normed = @norm2.call(x)
    transformed = @dropout.call(@feed_forward.call(normed))
    x + transformed
  end

  # Incremental forward for one new token (decode path, pre-norm only).
  # Same sequence of operations as the pre-norm branch of #forward, except
  # that attention goes through the KV cache and no dropout is applied (the
  # author notes dropout is the identity in eval mode).
  #
  # @param x [Tensor] [1, embed]
  # @param cache [KVCache]
  # @param layer [Integer] this block's index
  # @return [Tensor] [1, embed]
  # @raise [RuntimeError] if the block was built with pre_norm: false
  def decode_step(x, cache, layer)
    raise "Block#decode_step is only implemented for pre-norm blocks" unless @pre_norm

    attended = @attention.decode_step(@norm1.call(x), cache, layer)
    x = x + attended

    transformed = @feed_forward.call(@norm2.call(x))
    x + transformed
  end

  # @return [String] norm style plus the descriptions of both sub-modules
  def to_s
    style = if @pre_norm
              "pre-norm"
            else
              "post-norm"
            end
    "Block(#{style}, attn=#{@attention}, ff=#{@feed_forward})"
  end
end
88
+ end
89
+ end
90
+ end
@@ -0,0 +1,53 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module AI
5
+ module Transformer
6
# Position-wise feed-forward network used after attention in each
# Transformer block: Linear -> activation -> Dropout -> Linear.
class FeedForward < NN::Module
  # @param embed_dim [Integer] model dimension (input and output width)
  # @param ff_dim [Integer] feed-forward hidden dimension (typically 4x embed_dim)
  # @param activation [Symbol] :gelu, :relu, or :silu; other values are only
  #   rejected when #forward runs
  # @param dropout [Float] dropout rate
  # @param device_id [Integer]
  def initialize(embed_dim, ff_dim, activation: :gelu, dropout: 0.0, device_id: 0)
    super()
    @activation = activation

    expand = NN::Linear.new(embed_dim, ff_dim, device_id: device_id)
    @fc1 = register_module("fc1", expand)

    contract = NN::Linear.new(ff_dim, embed_dim, device_id: device_id)
    @fc2 = register_module("fc2", contract)

    @dropout = register_module("dropout", NN::Dropout.new(p: dropout))
  end

  # Forward pass.
  # @param x [Tensor]
  # @return [Tensor]
  # @raise [ArgumentError] if the configured activation is not supported
  def forward(x)
    hidden = apply_activation(@fc1.call(x))
    @fc2.call(@dropout.call(hidden))
  end

  # @return [String] description including both linear layers and the activation
  def to_s
    "FeedForward(fc1=#{@fc1}, activation=#{@activation}, fc2=#{@fc2})"
  end

  private

  # Apply the configured non-linearity by calling the matching Tensor method.
  # @param x [Tensor]
  # @return [Tensor]
  # @raise [ArgumentError] for any activation other than :gelu, :relu, :silu
  def apply_activation(x)
    return x.gelu if @activation == :gelu
    return x.relu if @activation == :relu
    return x.silu if @activation == :silu

    raise ArgumentError, "Unknown activation: #{@activation}"
  end
end
51
+ end
52
+ end
53
+ end
@@ -0,0 +1,189 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Ignis
4
+ module AI
5
+ module Transformer
6
# Full Transformer language model.
#
#   token_embedding + position_embedding → N × Block → LayerNorm → LM head
#
# Factory methods provide standard model configurations:
#   .gpt2_small  → 124M params
#   .gpt2_medium → 345M params
#   .gpt2_large  → 774M params
class Model < NN::Module
  # @return [Integer]
  attr_reader :vocab_size, :embed_dim, :num_heads, :num_layers, :max_seq_len

  # @param vocab_size [Integer] vocabulary size
  # @param embed_dim [Integer] model dimension
  # @param num_heads [Integer] attention heads per block
  # @param num_layers [Integer] number of Transformer blocks
  # @param ff_dim [Integer] feed-forward hidden dimension
  # @param max_seq_len [Integer] maximum sequence length (rows of the position table)
  # @param dropout [Float]
  # @param activation [Symbol]
  # @param pre_norm [Boolean]
  # @param device_id [Integer]
  def initialize(vocab_size:, embed_dim:, num_heads:, num_layers:,
                 ff_dim:, max_seq_len:, dropout: 0.0,
                 activation: :gelu, pre_norm: true, device_id: 0)
    super()
    @vocab_size = vocab_size
    @embed_dim = embed_dim
    @num_heads = num_heads
    @num_layers = num_layers
    @max_seq_len = max_seq_len
    @device_id = device_id

    @token_embedding = register_module(
      "token_embedding", NN::Embedding.new(vocab_size, embed_dim, device_id: device_id)
    )
    @position_embedding = register_module(
      "position_embedding", NN::Embedding.new(max_seq_len, embed_dim, device_id: device_id)
    )

    # Blocks are registered as "blocks.0", "blocks.1", … in layer order.
    @blocks = Array.new(num_layers) do |i|
      block = Block.new(embed_dim, num_heads, ff_dim,
                        dropout: dropout, pre_norm: pre_norm,
                        activation: activation, device_id: device_id)
      register_module("blocks.#{i}", block)
    end

    @norm = register_module("norm", NN::LayerNorm.new(embed_dim, device_id: device_id))
    @head = register_module(
      "head", NN::Linear.new(embed_dim, vocab_size, bias: false, device_id: device_id)
    )
    @dropout = register_module("dropout", NN::Dropout.new(p: dropout))
  end

  # Forward pass: returns logits.
  #
  # @param input_ids [Tensor] token indices [batch_size, seq_len] (int32)
  # @param mask [Tensor, nil] attention mask, passed through to every block
  # @return [Tensor] logits, presumably [batch_size * seq_len, vocab_size] — TODO confirm
  # @raise [ArgumentError] if the sequence is longer than max_seq_len
  def forward(input_ids, mask: nil)
    seq_len = input_ids.shape[-1]

    # The position table only has max_seq_len rows; a longer sequence would
    # index past its end.
    if seq_len > @max_seq_len
      raise ArgumentError, "sequence length #{seq_len} exceeds max_seq_len #{@max_seq_len}"
    end

    # Position indices 0...seq_len, created on the same device as the input.
    pos_nv = Ignis::Shared::NvArray.new(shape: [seq_len], dtype: :int32,
                                        device_id: input_ids.device_id)
    pos_nv.from_host((0...seq_len).to_a)
    positions = Tensor.new(data: pos_nv, requires_grad: false)

    # Embeddings
    tok_emb = @token_embedding.call(input_ids)     # [batch, seq, embed]
    pos_emb = @position_embedding.call(positions)  # [seq, embed]

    # Combine and dropout
    x = @dropout.call(tok_emb + pos_emb)

    # Transformer blocks
    @blocks.each { |block| x = block.call(x, mask: mask) }

    # Final norm and LM head
    @head.call(@norm.call(x))
  end

  # Allocate a fresh KV cache sized for this model.
  # @param device_id [Integer] defaults to the model's device
  # @return [KVCache]
  def make_kv_cache(device_id: @device_id)
    KVCache.new(num_layers: @num_layers, max_seq_len: @max_seq_len,
                embed_dim: @embed_dim, device_id: device_id)
  end

  # Incremental forward for ONE new token using a KV cache (decode path).
  # Only this token is embedded and projected; its position is the current
  # cache length and its query attends over the cached K/V of every block.
  # The cache is advanced once, after all blocks have appended their entry.
  # The author notes this must run under Tape.no_grad, and that callers feed
  # the prompt token-by-token (append order is the prefix order) before
  # sampling.
  #
  # @param token_id [Integer] the new token's id
  # @param cache [KVCache]
  # @return [Tensor] logits [1, vocab]
  # @raise [RuntimeError] if the cache already holds max_seq_len positions
  def decode_step(token_id, cache)
    pos = cache.length
    raise "KVCache full: position #{pos} exceeds max_seq_len #{@max_seq_len}" if pos >= @max_seq_len

    tok = Tensor.from_host([token_id], shape: [1], dtype: :int32, device_id: @device_id)
    pos_t = Tensor.from_host([pos], shape: [1], dtype: :int32, device_id: @device_id)

    x = @token_embedding.call(tok) + @position_embedding.call(pos_t) # [1, embed]
    @blocks.each_with_index { |block, i| x = block.decode_step(x, cache, i) }
    cache.advance!

    @head.call(@norm.call(x)) # [1, vocab]
  end

  # -------------------------------------------------------------------
  # Factory methods for standard configurations
  # -------------------------------------------------------------------

  # GPT-2 Small: 124M parameters
  # @param device_id [Integer]
  # @return [Model]
  def self.gpt2_small(device_id: 0)
    gpt2_family(embed_dim: 768, num_heads: 12, num_layers: 12, device_id: device_id)
  end

  # GPT-2 Medium: 345M parameters
  # @param device_id [Integer]
  # @return [Model]
  def self.gpt2_medium(device_id: 0)
    gpt2_family(embed_dim: 1024, num_heads: 16, num_layers: 24, device_id: device_id)
  end

  # GPT-2 Large: 774M parameters
  # @param device_id [Integer]
  # @return [Model]
  def self.gpt2_large(device_id: 0)
    gpt2_family(embed_dim: 1280, num_heads: 20, num_layers: 36, device_id: device_id)
  end

  # Build a model with the settings shared by all GPT-2 presets above:
  # 50257-token vocabulary, 1024 positions, dropout 0.1, GELU, pre-norm and a
  # feed-forward width of 4 × embed_dim (3072 / 4096 / 5120).
  # @param embed_dim [Integer]
  # @param num_heads [Integer]
  # @param num_layers [Integer]
  # @param device_id [Integer]
  # @return [Model]
  def self.gpt2_family(embed_dim:, num_heads:, num_layers:, device_id:)
    new(
      vocab_size: 50257,
      embed_dim: embed_dim,
      num_heads: num_heads,
      num_layers: num_layers,
      ff_dim: 4 * embed_dim,
      max_seq_len: 1024,
      dropout: 0.1,
      activation: :gelu,
      pre_norm: true,
      device_id: device_id
    )
  end
  private_class_method :gpt2_family

  # @return [String] summary including the parameter count from #num_parameters
  def to_s
    "TransformerModel(vocab=#{@vocab_size}, embed=#{@embed_dim}, " \
      "heads=#{@num_heads}, layers=#{@num_layers}, " \
      "params=#{num_parameters})"
  end
end
187
+ end
188
+ end
189
+ end