fine 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0

data/lib/fine/models/gemma3_decoder.rb
@@ -0,0 +1,244 @@
# frozen_string_literal: true

module Fine
  module Models
    # Gemma 3 decoder-only transformer
    #
    # Differences from Llama:
    # - Explicit head_dim config
    # - QK normalization (q_norm, k_norm)
    # - Pre/post feedforward layer norms
    # - Different GQA ratios
    class Gemma3Decoder < Base
      def initialize(config)
        super(config)

        @vocab_size = config.vocab_size
        @hidden_size = config.hidden_size
        @num_layers = config.num_hidden_layers
        @num_heads = config.num_attention_heads
        @num_kv_heads = config.num_key_value_heads || @num_heads
        @head_dim = config.head_dim
        @intermediate_size = config.intermediate_size
        @max_position_embeddings = config.max_position_embeddings || 32768
        @rms_norm_eps = config.rms_norm_eps || 1e-6
        @rope_theta = config.rope_theta || 1_000_000.0

        # Token embeddings
        @embed_tokens = Torch::NN::Embedding.new(@vocab_size, @hidden_size)

        # Transformer layers
        @layers = Torch::NN::ModuleList.new(
          @num_layers.times.map do
            Gemma3DecoderLayer.new(
              hidden_size: @hidden_size,
              num_heads: @num_heads,
              num_kv_heads: @num_kv_heads,
              head_dim: @head_dim,
              intermediate_size: @intermediate_size,
              rms_norm_eps: @rms_norm_eps,
              rope_theta: @rope_theta,
              max_position_embeddings: @max_position_embeddings
            )
          end
        )

        # Final layer norm
        @norm = RMSNorm.new(@hidden_size, eps: @rms_norm_eps)
      end

      def forward(input_ids, attention_mask: nil, position_ids: nil)
        batch_size, seq_length = input_ids.shape

        # Get token embeddings
        hidden_states = @embed_tokens.call(input_ids)

        # Normalize embeddings (Gemma specific)
        hidden_states = hidden_states * Math.sqrt(@hidden_size)

        # Create position IDs if not provided
        position_ids ||= Torch.arange(seq_length, device: input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Create causal mask
        causal_mask = create_causal_mask(seq_length, hidden_states.device)

        # Combine with attention mask if provided
        if attention_mask
          expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
          expanded_mask = expanded_mask.expand(-1, -1, seq_length, -1)
          causal_mask = causal_mask + (1.0 - expanded_mask) * -1e9
        end

        # Forward through layers
        @layers.each do |layer|
          hidden_states = layer.call(
            hidden_states,
            attention_mask: causal_mask,
            position_ids: position_ids
          )
        end

        # Final norm
        hidden_states = @norm.call(hidden_states)

        { last_hidden_state: hidden_states }
      end

      private

      def create_causal_mask(seq_length, device)
        mask = Torch.triu(
          Torch.ones(seq_length, seq_length, device: device) * -1e9,
          diagonal: 1
        )
        mask.unsqueeze(0).unsqueeze(0)
      end
    end

    # Single Gemma 3 decoder layer
    class Gemma3DecoderLayer < Torch::NN::Module
      def initialize(hidden_size:, num_heads:, num_kv_heads:, head_dim:,
                     intermediate_size:, rms_norm_eps:, rope_theta:, max_position_embeddings:)
        super()

        @self_attn = Gemma3Attention.new(
          hidden_size: hidden_size,
          num_heads: num_heads,
          num_kv_heads: num_kv_heads,
          head_dim: head_dim,
          rope_theta: rope_theta,
          max_position_embeddings: max_position_embeddings,
          rms_norm_eps: rms_norm_eps
        )

        @mlp = LlamaMLP.new(
          hidden_size: hidden_size,
          intermediate_size: intermediate_size
        )

        @input_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
        @post_attention_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
        @pre_feedforward_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
        @post_feedforward_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
      end

      def forward(hidden_states, attention_mask: nil, position_ids: nil)
        # Self attention with residual
        residual = hidden_states
        hidden_states = @input_layernorm.call(hidden_states)
        hidden_states = @self_attn.call(
          hidden_states,
          attention_mask: attention_mask,
          position_ids: position_ids
        )
        hidden_states = @post_attention_layernorm.call(hidden_states)
        hidden_states = residual + hidden_states

        # MLP with residual
        residual = hidden_states
        hidden_states = @pre_feedforward_layernorm.call(hidden_states)
        hidden_states = @mlp.call(hidden_states)
        hidden_states = @post_feedforward_layernorm.call(hidden_states)
        hidden_states = residual + hidden_states

        hidden_states
      end
    end

    # Gemma 3 attention with QK normalization
    class Gemma3Attention < Torch::NN::Module
      def initialize(hidden_size:, num_heads:, num_kv_heads:, head_dim:,
                     rope_theta:, max_position_embeddings:, rms_norm_eps:)
        super()

        @num_heads = num_heads
        @num_kv_heads = num_kv_heads
        @head_dim = head_dim
        @num_key_value_groups = num_heads / num_kv_heads

        @q_proj = Torch::NN::Linear.new(hidden_size, num_heads * head_dim, bias: false)
        @k_proj = Torch::NN::Linear.new(hidden_size, num_kv_heads * head_dim, bias: false)
        @v_proj = Torch::NN::Linear.new(hidden_size, num_kv_heads * head_dim, bias: false)
        @o_proj = Torch::NN::Linear.new(num_heads * head_dim, hidden_size, bias: false)

        # QK normalization (Gemma 3 specific)
        @q_norm = RMSNorm.new(head_dim, eps: rms_norm_eps)
        @k_norm = RMSNorm.new(head_dim, eps: rms_norm_eps)

        @rotary_emb = RotaryEmbedding.new(head_dim, max_position_embeddings, rope_theta)
      end

      def forward(hidden_states, attention_mask: nil, position_ids: nil)
        batch_size, seq_length, _ = hidden_states.shape

        # Project to Q, K, V
        query_states = @q_proj.call(hidden_states)
        key_states = @k_proj.call(hidden_states)
        value_states = @v_proj.call(hidden_states)

        # Reshape for multi-head attention
        query_states = query_states.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, @num_kv_heads, @head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, @num_kv_heads, @head_dim).transpose(1, 2)

        # Apply QK normalization
        query_states = apply_qk_norm(query_states, @q_norm)
        key_states = apply_qk_norm(key_states, @k_norm)

        # Apply rotary embeddings
        cos, sin = @rotary_emb.call(value_states, position_ids)
        query_states = apply_rotary_pos_emb(query_states, cos, sin)
        key_states = apply_rotary_pos_emb(key_states, cos, sin)

        # Repeat KV heads for grouped-query attention
        if @num_key_value_groups > 1
          key_states = repeat_kv(key_states, @num_key_value_groups)
          value_states = repeat_kv(value_states, @num_key_value_groups)
        end

        # Attention with softcapping (Gemma uses sqrt(head_dim) scaling)
        scale = @head_dim ** -0.5
        attn_weights = Torch.matmul(query_states, key_states.transpose(-2, -1)) * scale

        # Apply causal mask
        attn_weights = attn_weights + attention_mask if attention_mask

        attn_weights = Torch::NN::Functional.softmax(attn_weights, dim: -1)
        attn_output = Torch.matmul(attn_weights, value_states)

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous
        attn_output = attn_output.reshape(batch_size, seq_length, -1)

        @o_proj.call(attn_output)
      end

      private

      def apply_qk_norm(states, norm)
        # states: (batch, heads, seq, head_dim)
        # Need to apply norm per head
        batch, heads, seq, head_dim = states.shape
        states = states.transpose(1, 2).reshape(batch * seq, heads, head_dim)
        states = norm.call(states)
        states.reshape(batch, seq, heads, head_dim).transpose(1, 2)
      end

      def apply_rotary_pos_emb(x, cos, sin)
        x1 = x[0.., 0.., 0.., 0...(@head_dim / 2)]
        x2 = x[0.., 0.., 0.., (@head_dim / 2)..]
        rotated = Torch.cat([-x2, x1], dim: -1)
        (x * cos) + (rotated * sin)
      end

      def repeat_kv(x, n_rep)
        batch, num_kv_heads, seq_len, head_dim = x.shape
        return x if n_rep == 1

        x = x.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
        x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
      end
    end
  end
end
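
For orientation, here is a minimal sketch of driving the decoder above. It assumes torch-rb is available, that the gem's entry point (lib/fine.rb) loads the model classes, that Fine::Models::Base accepts the config object as the super(config) call suggests, and that a plain OpenStruct can stand in for that config; the field values are illustrative only, not a real Gemma 3 configuration.

require "fine"    # assumed to load torch-rb and the model classes
require "ostruct"

# Hypothetical tiny config: the attribute names mirror what
# Gemma3Decoder#initialize reads; the values are made up for illustration.
config = OpenStruct.new(
  vocab_size: 1_000,
  hidden_size: 64,
  num_hidden_layers: 2,
  num_attention_heads: 4,
  num_key_value_heads: 2,
  head_dim: 16,
  intermediate_size: 128,
  max_position_embeddings: 512,
  rms_norm_eps: 1e-6,
  rope_theta: 10_000.0
)

decoder = Fine::Models::Gemma3Decoder.new(config)

# One sequence of 8 random token ids; forward returns a hash with the
# final hidden states of shape (batch, seq, hidden_size).
input_ids = Torch.randint(0, config.vocab_size, [1, 8])
out = decoder.call(input_ids)
out[:last_hidden_state].shape  # => [1, 8, 64]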
data/lib/fine/models/llama_decoder.rb
@@ -0,0 +1,297 @@
# frozen_string_literal: true

module Fine
  module Models
    # Llama-style decoder-only transformer
    #
    # Compatible with Llama, Gemma, Mistral, Qwen architectures.
    # Uses RoPE positional embeddings, RMSNorm, and SwiGLU activation.
    class LlamaDecoder < Base
      def initialize(config)
        super(config)

        @vocab_size = config.vocab_size
        @hidden_size = config.hidden_size
        @num_layers = config.num_hidden_layers
        @num_heads = config.num_attention_heads
        @num_kv_heads = config.num_key_value_heads || @num_heads
        @intermediate_size = config.intermediate_size
        @max_position_embeddings = config.max_position_embeddings || 2048
        @rms_norm_eps = config.rms_norm_eps || 1e-6
        @rope_theta = config.rope_theta || 10000.0

        # Token embeddings
        @embed_tokens = Torch::NN::Embedding.new(@vocab_size, @hidden_size)

        # Transformer layers
        @layers = Torch::NN::ModuleList.new(
          @num_layers.times.map do
            LlamaDecoderLayer.new(
              hidden_size: @hidden_size,
              num_heads: @num_heads,
              num_kv_heads: @num_kv_heads,
              intermediate_size: @intermediate_size,
              rms_norm_eps: @rms_norm_eps,
              rope_theta: @rope_theta,
              max_position_embeddings: @max_position_embeddings
            )
          end
        )

        # Final layer norm
        @norm = RMSNorm.new(@hidden_size, eps: @rms_norm_eps)
      end

      def forward(input_ids, attention_mask: nil, position_ids: nil)
        batch_size, seq_length = input_ids.shape

        # Get token embeddings
        hidden_states = @embed_tokens.call(input_ids)

        # Create position IDs if not provided
        position_ids ||= Torch.arange(seq_length, device: input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Create causal mask
        causal_mask = create_causal_mask(seq_length, hidden_states.device)

        # Combine with attention mask if provided
        if attention_mask
          # Expand attention mask: (batch, seq) -> (batch, 1, seq, seq)
          expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
          expanded_mask = expanded_mask.expand(-1, -1, seq_length, -1)
          causal_mask = causal_mask + (1.0 - expanded_mask) * -1e9
        end

        # Forward through layers
        @layers.each do |layer|
          hidden_states = layer.call(
            hidden_states,
            attention_mask: causal_mask,
            position_ids: position_ids
          )
        end

        # Final norm
        hidden_states = @norm.call(hidden_states)

        { last_hidden_state: hidden_states }
      end

      private

      def create_causal_mask(seq_length, device)
        # Lower triangular mask for causal attention
        mask = Torch.triu(
          Torch.ones(seq_length, seq_length, device: device) * -1e9,
          diagonal: 1
        )
        mask.unsqueeze(0).unsqueeze(0)
      end
    end

    # Single Llama decoder layer
    class LlamaDecoderLayer < Torch::NN::Module
      def initialize(hidden_size:, num_heads:, num_kv_heads:, intermediate_size:,
                     rms_norm_eps:, rope_theta:, max_position_embeddings:)
        super()

        @self_attn = LlamaAttention.new(
          hidden_size: hidden_size,
          num_heads: num_heads,
          num_kv_heads: num_kv_heads,
          rope_theta: rope_theta,
          max_position_embeddings: max_position_embeddings
        )

        @mlp = LlamaMLP.new(
          hidden_size: hidden_size,
          intermediate_size: intermediate_size
        )

        @input_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
        @post_attention_layernorm = RMSNorm.new(hidden_size, eps: rms_norm_eps)
      end

      def forward(hidden_states, attention_mask: nil, position_ids: nil)
        # Self attention with residual
        residual = hidden_states
        hidden_states = @input_layernorm.call(hidden_states)
        hidden_states = @self_attn.call(
          hidden_states,
          attention_mask: attention_mask,
          position_ids: position_ids
        )
        hidden_states = residual + hidden_states

        # MLP with residual
        residual = hidden_states
        hidden_states = @post_attention_layernorm.call(hidden_states)
        hidden_states = @mlp.call(hidden_states)
        hidden_states = residual + hidden_states

        hidden_states
      end
    end

    # Llama attention with RoPE and grouped-query attention
    class LlamaAttention < Torch::NN::Module
      def initialize(hidden_size:, num_heads:, num_kv_heads:, rope_theta:, max_position_embeddings:)
        super()

        @num_heads = num_heads
        @num_kv_heads = num_kv_heads
        @head_dim = hidden_size / num_heads
        @num_key_value_groups = num_heads / num_kv_heads

        @q_proj = Torch::NN::Linear.new(hidden_size, num_heads * @head_dim, bias: false)
        @k_proj = Torch::NN::Linear.new(hidden_size, num_kv_heads * @head_dim, bias: false)
        @v_proj = Torch::NN::Linear.new(hidden_size, num_kv_heads * @head_dim, bias: false)
        @o_proj = Torch::NN::Linear.new(num_heads * @head_dim, hidden_size, bias: false)

        @rotary_emb = RotaryEmbedding.new(@head_dim, max_position_embeddings, rope_theta)
      end

      def forward(hidden_states, attention_mask: nil, position_ids: nil)
        batch_size, seq_length, _ = hidden_states.shape

        # Project to Q, K, V
        query_states = @q_proj.call(hidden_states)
        key_states = @k_proj.call(hidden_states)
        value_states = @v_proj.call(hidden_states)

        # Reshape for multi-head attention
        query_states = query_states.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, @num_kv_heads, @head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, @num_kv_heads, @head_dim).transpose(1, 2)

        # Apply rotary embeddings
        cos, sin = @rotary_emb.call(value_states, position_ids)
        query_states = apply_rotary_pos_emb(query_states, cos, sin)
        key_states = apply_rotary_pos_emb(key_states, cos, sin)

        # Repeat KV heads for grouped-query attention
        if @num_key_value_groups > 1
          key_states = repeat_kv(key_states, @num_key_value_groups)
          value_states = repeat_kv(value_states, @num_key_value_groups)
        end

        # Attention
        scale = @head_dim ** -0.5
        attn_weights = Torch.matmul(query_states, key_states.transpose(-2, -1)) * scale

        # Apply causal mask
        attn_weights = attn_weights + attention_mask if attention_mask

        attn_weights = Torch::NN::Functional.softmax(attn_weights, dim: -1)
        attn_output = Torch.matmul(attn_weights, value_states)

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous
        attn_output = attn_output.reshape(batch_size, seq_length, -1)

        @o_proj.call(attn_output)
      end

      private

      def apply_rotary_pos_emb(x, cos, sin)
        # x: (batch, heads, seq, head_dim)
        x1 = x[0.., 0.., 0.., 0...(@head_dim / 2)]
        x2 = x[0.., 0.., 0.., (@head_dim / 2)..]

        # Rotate
        rotated = Torch.cat([-x2, x1], dim: -1)
        (x * cos) + (rotated * sin)
      end

      def repeat_kv(x, n_rep)
        batch, num_kv_heads, seq_len, head_dim = x.shape
        return x if n_rep == 1

        x = x.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
        x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
      end
    end

    # Rotary Position Embedding
    class RotaryEmbedding < Torch::NN::Module
      def initialize(dim, max_position_embeddings, base)
        super()

        @dim = dim
        @max_position_embeddings = max_position_embeddings
        @base = base

        # Precompute inverse frequencies
        inv_freq = 1.0 / (base ** (Torch.arange(0, dim, 2).float / dim))
        register_buffer("inv_freq", inv_freq)

        # Build cos/sin cache
        build_cache(max_position_embeddings)
      end

      def call(x, position_ids)
        seq_len = position_ids.max.item + 1
        build_cache(seq_len) if seq_len > @cos_cached.size(0)

        # Move cached tensors to position_ids device if needed
        device = position_ids.device
        cos_cached = @cos_cached.to(device)
        sin_cached = @sin_cached.to(device)

        cos = cos_cached[position_ids].unsqueeze(1)
        sin = sin_cached[position_ids].unsqueeze(1)

        [cos, sin]
      end

      private

      def build_cache(seq_len)
        t = Torch.arange(seq_len, device: @inv_freq.device)
        freqs = Torch.outer(t, @inv_freq)
        emb = Torch.cat([freqs, freqs], dim: -1)

        @cos_cached = emb.cos
        @sin_cached = emb.sin
      end
    end

    # Llama MLP with SwiGLU activation
    class LlamaMLP < Torch::NN::Module
      def initialize(hidden_size:, intermediate_size:)
        super()

        @gate_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
        @up_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
        @down_proj = Torch::NN::Linear.new(intermediate_size, hidden_size, bias: false)
      end

      def forward(x)
        # SwiGLU: silu(gate) * up
        # SiLU = x * sigmoid(x)
        gate_out = @gate_proj.call(x)
        gate = gate_out * Torch.sigmoid(gate_out)
        up = @up_proj.call(x)
        @down_proj.call(gate * up)
      end
    end

    # RMS Normalization
    class RMSNorm < Torch::NN::Module
      def initialize(hidden_size, eps: 1e-6)
        super()

        @weight = Torch::NN::Parameter.new(Torch.ones(hidden_size))
        @eps = eps
      end

      def forward(hidden_states)
        variance = hidden_states.pow(2).mean(-1, keepdim: true)
        hidden_states = hidden_states * Torch.rsqrt(variance + @eps)
        @weight * hidden_states
      end
    end
  end
end
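
As a quick sanity check on the RMSNorm layer defined above (shared by both decoders), the sketch below, assuming torch-rb and the gem are loadable as before, normalizes a random tensor and confirms that the per-position root mean square comes out close to 1.0 while the learnable weight is still at its initial value of ones. Shapes and values are illustrative only.

require "fine"
require "torch"

norm = Fine::Models::RMSNorm.new(8, eps: 1e-6)

x = Torch.randn(2, 4, 8)
y = norm.call(x)

# RMS over the last dimension; with the default weight of ones this
# should be approximately 1.0 everywhere.
rms = y.pow(2).mean(-1).sqrt
puts format("rms in [%.3f, %.3f]", rms.min.item, rms.max.item)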