fine 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +7 -0
  2. data/.rspec +3 -0
  3. data/CHANGELOG.md +38 -0
  4. data/Gemfile +6 -0
  5. data/Gemfile.lock +167 -0
  6. data/LICENSE +21 -0
  7. data/README.md +212 -0
  8. data/Rakefile +6 -0
  9. data/docs/installation.md +151 -0
  10. data/docs/tutorials/llm-fine-tuning.md +246 -0
  11. data/docs/tutorials/model-export.md +200 -0
  12. data/docs/tutorials/siglip2-image-classification.md +130 -0
  13. data/docs/tutorials/siglip2-object-recognition.md +203 -0
  14. data/docs/tutorials/siglip2-similarity-search.md +152 -0
  15. data/docs/tutorials/text-classification.md +233 -0
  16. data/docs/tutorials/text-embeddings.md +211 -0
  17. data/examples/basic_classification.rb +70 -0
  18. data/examples/data/tool_calls.jsonl +30 -0
  19. data/examples/demo_training.rb +78 -0
  20. data/examples/finetune_gemma3_tools.rb +135 -0
  21. data/examples/real_llm_test.rb +128 -0
  22. data/examples/real_text_classification_test.rb +90 -0
  23. data/examples/real_text_embedder_test.rb +110 -0
  24. data/examples/real_training_test.rb +88 -0
  25. data/examples/test_export.rb +28 -0
  26. data/examples/test_image_classifier.rb +79 -0
  27. data/examples/test_llm.rb +100 -0
  28. data/examples/test_text_classifier.rb +59 -0
  29. data/lib/fine/callbacks/base.rb +140 -0
  30. data/lib/fine/callbacks/progress_bar.rb +66 -0
  31. data/lib/fine/configuration.rb +106 -0
  32. data/lib/fine/datasets/data_loader.rb +63 -0
  33. data/lib/fine/datasets/image_dataset.rb +203 -0
  34. data/lib/fine/datasets/instruction_dataset.rb +226 -0
  35. data/lib/fine/datasets/text_data_loader.rb +88 -0
  36. data/lib/fine/datasets/text_dataset.rb +266 -0
  37. data/lib/fine/error.rb +49 -0
  38. data/lib/fine/export/gguf_exporter.rb +424 -0
  39. data/lib/fine/export/onnx_exporter.rb +249 -0
  40. data/lib/fine/export.rb +53 -0
  41. data/lib/fine/hub/config_loader.rb +145 -0
  42. data/lib/fine/hub/model_downloader.rb +136 -0
  43. data/lib/fine/hub/safetensors_loader.rb +108 -0
  44. data/lib/fine/image_classifier.rb +256 -0
  45. data/lib/fine/llm.rb +336 -0
  46. data/lib/fine/models/base.rb +48 -0
  47. data/lib/fine/models/bert_encoder.rb +202 -0
  48. data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
  49. data/lib/fine/models/causal_lm.rb +279 -0
  50. data/lib/fine/models/classification_head.rb +24 -0
  51. data/lib/fine/models/gemma3_decoder.rb +244 -0
  52. data/lib/fine/models/llama_decoder.rb +297 -0
  53. data/lib/fine/models/sentence_transformer.rb +202 -0
  54. data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
  55. data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
  56. data/lib/fine/text_classifier.rb +250 -0
  57. data/lib/fine/text_embedder.rb +221 -0
  58. data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
  59. data/lib/fine/training/llm_trainer.rb +212 -0
  60. data/lib/fine/training/text_trainer.rb +275 -0
  61. data/lib/fine/training/trainer.rb +194 -0
  62. data/lib/fine/transforms/compose.rb +28 -0
  63. data/lib/fine/transforms/normalize.rb +33 -0
  64. data/lib/fine/transforms/resize.rb +35 -0
  65. data/lib/fine/transforms/to_tensor.rb +53 -0
  66. data/lib/fine/version.rb +3 -0
  67. data/lib/fine.rb +112 -0
  68. data/mise.toml +2 -0
  69. metadata +240 -0
data/lib/fine/models/gemma3_decoder.rb
@@ -0,0 +1,244 @@
+# frozen_string_literal: true
+
+module Fine
+  module Models
+    # Gemma 3 decoder-only transformer
+    #
+    # Differences from Llama:
+    # - Explicit head_dim config
+    # - QK normalization (q_norm, k_norm)
+    # - Pre/post feedforward layer norms
+    # - Different GQA ratios
+    class Gemma3Decoder < Base
+      def initialize(config)
+        super(config)
+
+        @vocab_size = config.vocab_size
+        @hidden_size = config.hidden_size
+        @num_layers = config.num_hidden_layers
+        @num_heads = config.num_attention_heads
+        @num_kv_heads = config.num_key_value_heads || @num_heads
+        @head_dim = config.head_dim
+        @intermediate_size = config.intermediate_size
+        @max_position_embeddings = config.max_position_embeddings || 32768
+        @rms_norm_eps = config.rms_norm_eps || 1e-6
+        @rope_theta = config.rope_theta || 1_000_000.0
+
+        # Token embeddings
+        @embed_tokens = Torch::NN::Embedding.new(@vocab_size, @hidden_size)
+
+        # Transformer layers
+        @layers = Torch::NN::ModuleList.new(
+          @num_layers.times.map do
+            Gemma3DecoderLayer.new(
+              hidden_size: @hidden_size,
+              num_heads: @num_heads,
+              num_kv_heads: @num_kv_heads,
+              head_dim: @head_dim,
+              intermediate_size: @intermediate_size,
+              rms_norm_eps: @rms_norm_eps,
+              rope_theta: @rope_theta,
+              max_position_embeddings: @max_position_embeddings
+            )
+          end
+        )
+
+        # Final layer norm
+        @norm = RMSNorm.new(@hidden_size, eps: @rms_norm_eps)
+      end
+
+      def forward(input_ids, attention_mask: nil, position_ids: nil)
+        batch_size, seq_length = input_ids.shape
+
+        # Get token embeddings
+        hidden_states = @embed_tokens.call(input_ids)
+
+        # Scale embeddings by sqrt(hidden_size) (Gemma specific)
+        hidden_states = hidden_states * Math.sqrt(@hidden_size)
+
+        # Create position IDs if not provided
+        position_ids ||= Torch.arange(seq_length, device: input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        # Create causal mask
+        causal_mask = create_causal_mask(seq_length, hidden_states.device)
+
+        # Combine with attention mask if provided
+        if attention_mask
+          expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+          expanded_mask = expanded_mask.expand(-1, -1, seq_length, -1)
+          causal_mask = causal_mask + (1.0 - expanded_mask) * -1e9
+        end
+
+        # Forward through layers
+        @layers.each do |layer|
+          hidden_states = layer.call(
+            hidden_states,
+            attention_mask: causal_mask,
+            position_ids: position_ids
+          )
+        end
+
+        # Final norm
+        hidden_states = @norm.call(hidden_states)
+
+        { last_hidden_state: hidden_states }
+      end
+
+      private
+
+      def create_causal_mask(seq_length, device)
+        mask = Torch.triu(
+          Torch.ones(seq_length, seq_length, device: device) * -1e9,
+          diagonal: 1
+        )
+        mask.unsqueeze(0).unsqueeze(0)
+      end
+    end
+
+    # Single Gemma 3 decoder layer
+    class Gemma3DecoderLayer < Torch::NN::Module
+      def initialize(hidden_size:, num_heads:, num_kv_heads:, head_dim:,
+                     intermediate_size:, rms_norm_eps:, rope_theta:, max_position_embeddings:)
+        super()
+
+        @self_attn = Gemma3Attention.new(
+          hidden_size: hidden_size,
+          num_heads: num_heads,
+          num_kv_heads: num_kv_heads,
+          head_dim: head_dim,
+          rope_theta: rope_theta,
+          max_position_embeddings: max_position_embeddings,
+          rms_norm_eps: rms_norm_eps
+        )
+
+        @mlp = LlamaMLP.new(
+          hidden_size: hidden_size,
+          intermediate_size: intermediate_size
+        )
+
+        @input_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
+        @post_attention_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
+        @pre_feedforward_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
+        @post_feedforward_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
+      end
+
+      def forward(hidden_states, attention_mask: nil, position_ids: nil)
+        # Self attention with residual
+        residual = hidden_states
+        hidden_states = @input_layernorm.call(hidden_states)
+        hidden_states = @self_attn.call(
+          hidden_states,
+          attention_mask: attention_mask,
+          position_ids: position_ids
+        )
+        hidden_states = @post_attention_layernorm.call(hidden_states)
+        hidden_states = residual + hidden_states
+
+        # MLP with residual
+        residual = hidden_states
+        hidden_states = @pre_feedforward_layernorm.call(hidden_states)
+        hidden_states = @mlp.call(hidden_states)
+        hidden_states = @post_feedforward_layernorm.call(hidden_states)
+        hidden_states = residual + hidden_states
+
+        hidden_states
+      end
+    end
+
+    # Gemma 3 attention with QK normalization
+    class Gemma3Attention < Torch::NN::Module
+      def initialize(hidden_size:, num_heads:, num_kv_heads:, head_dim:,
+                     rope_theta:, max_position_embeddings:, rms_norm_eps:)
+        super()
+
+        @num_heads = num_heads
+        @num_kv_heads = num_kv_heads
+        @head_dim = head_dim
+        @num_key_value_groups = num_heads / num_kv_heads
+
+        @q_proj = Torch::NN::Linear.new(hidden_size, num_heads * head_dim, bias: false)
+        @k_proj = Torch::NN::Linear.new(hidden_size, num_kv_heads * head_dim, bias: false)
+        @v_proj = Torch::NN::Linear.new(hidden_size, num_kv_heads * head_dim, bias: false)
+        @o_proj = Torch::NN::Linear.new(num_heads * head_dim, hidden_size, bias: false)
+
+        # QK normalization (Gemma 3 specific)
+        @q_norm = RMSNorm.new(head_dim, eps: rms_norm_eps)
+        @k_norm = RMSNorm.new(head_dim, eps: rms_norm_eps)
+
+        @rotary_emb = RotaryEmbedding.new(head_dim, max_position_embeddings, rope_theta)
+      end
+
+      def forward(hidden_states, attention_mask: nil, position_ids: nil)
+        batch_size, seq_length, _ = hidden_states.shape
+
+        # Project to Q, K, V
+        query_states = @q_proj.call(hidden_states)
+        key_states = @k_proj.call(hidden_states)
+        value_states = @v_proj.call(hidden_states)
+
+        # Reshape for multi-head attention
+        query_states = query_states.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, seq_length, @num_kv_heads, @head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, seq_length, @num_kv_heads, @head_dim).transpose(1, 2)
+
+        # Apply QK normalization
+        query_states = apply_qk_norm(query_states, @q_norm)
+        key_states = apply_qk_norm(key_states, @k_norm)
+
+        # Apply rotary embeddings
+        cos, sin = @rotary_emb.call(value_states, position_ids)
+        query_states = apply_rotary_pos_emb(query_states, cos, sin)
+        key_states = apply_rotary_pos_emb(key_states, cos, sin)
+
+        # Repeat KV heads for grouped-query attention
+        if @num_key_value_groups > 1
+          key_states = repeat_kv(key_states, @num_key_value_groups)
+          value_states = repeat_kv(value_states, @num_key_value_groups)
+        end
+
+        # Scaled dot-product attention (1 / sqrt(head_dim) scaling)
+        scale = @head_dim ** -0.5
+        attn_weights = Torch.matmul(query_states, key_states.transpose(-2, -1)) * scale
+
+        # Apply causal mask
+        attn_weights = attn_weights + attention_mask if attention_mask
+
+        attn_weights = Torch::NN::Functional.softmax(attn_weights, dim: -1)
+        attn_output = Torch.matmul(attn_weights, value_states)
+
+        # Reshape back
+        attn_output = attn_output.transpose(1, 2).contiguous
+        attn_output = attn_output.reshape(batch_size, seq_length, -1)
+
+        @o_proj.call(attn_output)
+      end
+
+      private
+
+      def apply_qk_norm(states, norm)
+        # states: (batch, heads, seq, head_dim)
+        # Need to apply norm per head
+        batch, heads, seq, head_dim = states.shape
+        states = states.transpose(1, 2).reshape(batch * seq, heads, head_dim)
+        states = norm.call(states)
+        states.reshape(batch, seq, heads, head_dim).transpose(1, 2)
+      end
+
+      def apply_rotary_pos_emb(x, cos, sin)
+        x1 = x[0.., 0.., 0.., 0...(@head_dim / 2)]
+        x2 = x[0.., 0.., 0.., (@head_dim / 2)..]
+        rotated = Torch.cat([-x2, x1], dim: -1)
+        (x * cos) + (rotated * sin)
+      end
+
+      def repeat_kv(x, n_rep)
+        batch, num_kv_heads, seq_len, head_dim = x.shape
+        return x if n_rep == 1
+
+        x = x.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
+        x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
+      end
+    end
+  end
+end
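For readers skimming the diff, here is a minimal usage sketch of the Gemma 3 decoder above. It assumes Fine::Models::Base (defined in data/lib/fine/models/base.rb, not shown in this hunk) is a thin Torch::NN::Module subclass that simply stores the config, and that require "fine" loads the model classes; the Gemma3Config struct and the toy dimensions below are hypothetical, chosen only to exercise the forward pass.

require "torch"
require "fine"

# Hypothetical config object: field names mirror the attributes read in Gemma3Decoder#initialize.
Gemma3Config = Struct.new(
  :vocab_size, :hidden_size, :num_hidden_layers, :num_attention_heads,
  :num_key_value_heads, :head_dim, :intermediate_size,
  :max_position_embeddings, :rms_norm_eps, :rope_theta,
  keyword_init: true
)

config = Gemma3Config.new(
  vocab_size: 1_000,            # toy vocabulary
  hidden_size: 64,
  num_hidden_layers: 2,
  num_attention_heads: 4,
  num_key_value_heads: 1,       # 4:1 grouped-query ratio
  head_dim: 16,
  intermediate_size: 128,
  max_position_embeddings: 512,
  rms_norm_eps: 1e-6,
  rope_theta: 1_000_000.0
)

decoder = Fine::Models::Gemma3Decoder.new(config)
input_ids = Torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8],
                          [9, 10, 11, 12, 13, 14, 15, 16]])  # (batch = 2, seq = 8)
attention_mask = Torch.ones(2, 8)                            # all ones: no padding
out = decoder.call(input_ids, attention_mask: attention_mask)
puts out[:last_hidden_state].shape.inspect                   # => [2, 8, 64]

Note that the decoder returns only last_hidden_state; the language-modeling head is applied elsewhere (presumably in data/lib/fine/models/causal_lm.rb).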
data/lib/fine/models/llama_decoder.rb
@@ -0,0 +1,297 @@
+# frozen_string_literal: true
+
+module Fine
+  module Models
+    # Llama-style decoder-only transformer
+    #
+    # Compatible with Llama, Gemma, Mistral, Qwen architectures.
+    # Uses RoPE positional embeddings, RMSNorm, and SwiGLU activation.
+    class LlamaDecoder < Base
+      def initialize(config)
+        super(config)
+
+        @vocab_size = config.vocab_size
+        @hidden_size = config.hidden_size
+        @num_layers = config.num_hidden_layers
+        @num_heads = config.num_attention_heads
+        @num_kv_heads = config.num_key_value_heads || @num_heads
+        @intermediate_size = config.intermediate_size
+        @max_position_embeddings = config.max_position_embeddings || 2048
+        @rms_norm_eps = config.rms_norm_eps || 1e-6
+        @rope_theta = config.rope_theta || 10000.0
+
+        # Token embeddings
+        @embed_tokens = Torch::NN::Embedding.new(@vocab_size, @hidden_size)
+
+        # Transformer layers
+        @layers = Torch::NN::ModuleList.new(
+          @num_layers.times.map do
+            LlamaDecoderLayer.new(
+              hidden_size: @hidden_size,
+              num_heads: @num_heads,
+              num_kv_heads: @num_kv_heads,
+              intermediate_size: @intermediate_size,
+              rms_norm_eps: @rms_norm_eps,
+              rope_theta: @rope_theta,
+              max_position_embeddings: @max_position_embeddings
+            )
+          end
+        )
+
+        # Final layer norm
+        @norm = RMSNorm.new(@hidden_size, eps: @rms_norm_eps)
+      end
+
+      def forward(input_ids, attention_mask: nil, position_ids: nil)
+        batch_size, seq_length = input_ids.shape
+
+        # Get token embeddings
+        hidden_states = @embed_tokens.call(input_ids)
+
+        # Create position IDs if not provided
+        position_ids ||= Torch.arange(seq_length, device: input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+        # Create causal mask
+        causal_mask = create_causal_mask(seq_length, hidden_states.device)
+
+        # Combine with attention mask if provided
+        if attention_mask
+          # Expand attention mask: (batch, seq) -> (batch, 1, seq, seq)
+          expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+          expanded_mask = expanded_mask.expand(-1, -1, seq_length, -1)
+          causal_mask = causal_mask + (1.0 - expanded_mask) * -1e9
+        end
+
+        # Forward through layers
+        @layers.each do |layer|
+          hidden_states = layer.call(
+            hidden_states,
+            attention_mask: causal_mask,
+            position_ids: position_ids
+          )
+        end
+
+        # Final norm
+        hidden_states = @norm.call(hidden_states)
+
+        { last_hidden_state: hidden_states }
+      end
+
+      private
+
+      def create_causal_mask(seq_length, device)
+        # Additive causal mask: -1e9 above the diagonal blocks attention to future positions
+        mask = Torch.triu(
+          Torch.ones(seq_length, seq_length, device: device) * -1e9,
+          diagonal: 1
+        )
+        mask.unsqueeze(0).unsqueeze(0)
+      end
+    end
+
+    # Single Llama decoder layer
+    class LlamaDecoderLayer < Torch::NN::Module
+      def initialize(hidden_size:, num_heads:, num_kv_heads:, intermediate_size:,
+                     rms_norm_eps:, rope_theta:, max_position_embeddings:)
+        super()
+
+        @self_attn = LlamaAttention.new(
+          hidden_size: hidden_size,
+          num_heads: num_heads,
+          num_kv_heads: num_kv_heads,
+          rope_theta: rope_theta,
+          max_position_embeddings: max_position_embeddings
+        )
+
+        @mlp = LlamaMLP.new(
+          hidden_size: hidden_size,
+          intermediate_size: intermediate_size
+        )
+
+        @input_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
+        @post_attention_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
+      end
+
+      def forward(hidden_states, attention_mask: nil, position_ids: nil)
+        # Self attention with residual
+        residual = hidden_states
+        hidden_states = @input_layernorm.call(hidden_states)
+        hidden_states = @self_attn.call(
+          hidden_states,
+          attention_mask: attention_mask,
+          position_ids: position_ids
+        )
+        hidden_states = residual + hidden_states
+
+        # MLP with residual
+        residual = hidden_states
+        hidden_states = @post_attention_layernorm.call(hidden_states)
+        hidden_states = @mlp.call(hidden_states)
+        hidden_states = residual + hidden_states
+
+        hidden_states
+      end
+    end
+
+    # Llama attention with RoPE and grouped-query attention
+    class LlamaAttention < Torch::NN::Module
+      def initialize(hidden_size:, num_heads:, num_kv_heads:, rope_theta:, max_position_embeddings:)
+        super()
+
+        @num_heads = num_heads
+        @num_kv_heads = num_kv_heads
+        @head_dim = hidden_size / num_heads
+        @num_key_value_groups = num_heads / num_kv_heads
+
+        @q_proj = Torch::NN::Linear.new(hidden_size, num_heads * @head_dim, bias: false)
+        @k_proj = Torch::NN::Linear.new(hidden_size, num_kv_heads * @head_dim, bias: false)
+        @v_proj = Torch::NN::Linear.new(hidden_size, num_kv_heads * @head_dim, bias: false)
+        @o_proj = Torch::NN::Linear.new(num_heads * @head_dim, hidden_size, bias: false)
+
+        @rotary_emb = RotaryEmbedding.new(@head_dim, max_position_embeddings, rope_theta)
+      end
+
+      def forward(hidden_states, attention_mask: nil, position_ids: nil)
+        batch_size, seq_length, _ = hidden_states.shape
+
+        # Project to Q, K, V
+        query_states = @q_proj.call(hidden_states)
+        key_states = @k_proj.call(hidden_states)
+        value_states = @v_proj.call(hidden_states)
+
+        # Reshape for multi-head attention
+        query_states = query_states.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, seq_length, @num_kv_heads, @head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, seq_length, @num_kv_heads, @head_dim).transpose(1, 2)
+
+        # Apply rotary embeddings
+        cos, sin = @rotary_emb.call(value_states, position_ids)
+        query_states = apply_rotary_pos_emb(query_states, cos, sin)
+        key_states = apply_rotary_pos_emb(key_states, cos, sin)
+
+        # Repeat KV heads for grouped-query attention
+        if @num_key_value_groups > 1
+          key_states = repeat_kv(key_states, @num_key_value_groups)
+          value_states = repeat_kv(value_states, @num_key_value_groups)
+        end
+
+        # Attention
+        scale = @head_dim ** -0.5
+        attn_weights = Torch.matmul(query_states, key_states.transpose(-2, -1)) * scale
+
+        # Apply causal mask
+        attn_weights = attn_weights + attention_mask if attention_mask
+
+        attn_weights = Torch::NN::Functional.softmax(attn_weights, dim: -1)
+        attn_output = Torch.matmul(attn_weights, value_states)
+
+        # Reshape back
+        attn_output = attn_output.transpose(1, 2).contiguous
+        attn_output = attn_output.reshape(batch_size, seq_length, -1)
+
+        @o_proj.call(attn_output)
+      end
+
+      private
+
+      def apply_rotary_pos_emb(x, cos, sin)
+        # x: (batch, heads, seq, head_dim)
+        x1 = x[0.., 0.., 0.., 0...(@head_dim / 2)]
+        x2 = x[0.., 0.., 0.., (@head_dim / 2)..]
+
+        # Rotate
+        rotated = Torch.cat([-x2, x1], dim: -1)
+        (x * cos) + (rotated * sin)
+      end
+
+      def repeat_kv(x, n_rep)
+        batch, num_kv_heads, seq_len, head_dim = x.shape
+        return x if n_rep == 1
+
+        x = x.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
+        x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
+      end
+    end
+
+    # Rotary Position Embedding
+    class RotaryEmbedding < Torch::NN::Module
+      def initialize(dim, max_position_embeddings, base)
+        super()
+
+        @dim = dim
+        @max_position_embeddings = max_position_embeddings
+        @base = base
+
+        # Precompute inverse frequencies
+        inv_freq = 1.0 / (base ** (Torch.arange(0, dim, 2).float / dim))
+        register_buffer("inv_freq", inv_freq)
+
+        # Build cos/sin cache
+        build_cache(max_position_embeddings)
+      end
+
+      def call(x, position_ids)
+        seq_len = position_ids.max.item + 1
+        build_cache(seq_len) if seq_len > @cos_cached.size(0)
+
+        # Move cached tensors to position_ids device if needed
+        device = position_ids.device
+        cos_cached = @cos_cached.to(device)
+        sin_cached = @sin_cached.to(device)
+
+        cos = cos_cached[position_ids].unsqueeze(1)
+        sin = sin_cached[position_ids].unsqueeze(1)
+
+        [cos, sin]
+      end
+
+      private
+
+      def build_cache(seq_len)
+        t = Torch.arange(seq_len, device: @inv_freq.device)
+        freqs = Torch.outer(t, @inv_freq)
+        emb = Torch.cat([freqs, freqs], dim: -1)
+
+        @cos_cached = emb.cos
+        @sin_cached = emb.sin
+      end
+    end
+
+    # Llama MLP with SwiGLU activation
+    class LlamaMLP < Torch::NN::Module
+      def initialize(hidden_size:, intermediate_size:)
+        super()
+
+        @gate_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+        @up_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+        @down_proj = Torch::NN::Linear.new(intermediate_size, hidden_size, bias: false)
+      end
+
+      def forward(x)
+        # SwiGLU: silu(gate) * up
+        # SiLU = x * sigmoid(x)
+        gate_out = @gate_proj.call(x)
+        gate = gate_out * Torch.sigmoid(gate_out)
+        up = @up_proj.call(x)
+        @down_proj.call(gate * up)
+      end
+    end
+
+    # RMS Normalization
+    class RMSNorm < Torch::NN::Module
+      def initialize(hidden_size, eps: 1e-6)
+        super()
+
+        @weight = Torch::NN::Parameter.new(Torch.ones(hidden_size))
+        @eps = eps
+      end
+
+      def forward(hidden_states)
+        variance = hidden_states.pow(2).mean(-1, keepdim: true)
+        hidden_states = hidden_states * Torch.rsqrt(variance + @eps)
+        @weight * hidden_states
+      end
+    end
+  end
+end
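The three building blocks at the bottom of this file are shared by both decoders and are easy to check in isolation: RMSNorm scales by 1 / sqrt(mean(x^2) + eps) and applies a learned per-feature weight, LlamaMLP computes down_proj(silu(gate_proj(x)) * up_proj(x)), and RotaryEmbedding returns cos/sin tables indexed by position id. A minimal sketch, assuming require "fine" loads these classes; the toy shapes are hypothetical.

require "torch"
require "fine"

x = Torch.randn(2, 8, 64)                       # (batch, seq, hidden)

# RMSNorm: normalize by the root mean square, then scale by a learned weight
norm = Fine::Models::RMSNorm.new(64)
puts norm.call(x).shape.inspect                 # => [2, 8, 64]

# SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x))
mlp = Fine::Models::LlamaMLP.new(hidden_size: 64, intermediate_size: 128)
puts mlp.call(x).shape.inspect                  # => [2, 8, 64]

# RoPE: cos/sin caches looked up by position id, broadcast over heads via unsqueeze(1)
rope = Fine::Models::RotaryEmbedding.new(16, 512, 10_000.0)
position_ids = Torch.arange(8).unsqueeze(0)     # (1, seq)
cos, sin = rope.call(x, position_ids)
puts cos.shape.inspect                          # => [1, 1, 8, 16]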