fine 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. checksums.yaml +4 -4
  2. data/README.md +20 -10
  3. data/docs/examples/image-classification-shapes.md +83 -0
  4. data/docs/examples/text-embeddings-faq.md +98 -0
  5. data/docs/quickstart.md +209 -0
  6. data/docs/tutorials/lora-tool-calling.md +306 -0
  7. data/examples/data/generate_tool_data.rb +261 -0
  8. data/examples/data/ollama_tool_calls.jsonl +40 -0
  9. data/examples/data/sentiment_reviews.jsonl +30 -0
  10. data/examples/data/shapes/circle/circle_1.jpg +0 -0
  11. data/examples/data/shapes/circle/circle_10.jpg +0 -0
  12. data/examples/data/shapes/circle/circle_2.jpg +0 -0
  13. data/examples/data/shapes/circle/circle_3.jpg +0 -0
  14. data/examples/data/shapes/circle/circle_4.jpg +0 -0
  15. data/examples/data/shapes/circle/circle_5.jpg +0 -0
  16. data/examples/data/shapes/circle/circle_6.jpg +0 -0
  17. data/examples/data/shapes/circle/circle_7.jpg +0 -0
  18. data/examples/data/shapes/circle/circle_8.jpg +0 -0
  19. data/examples/data/shapes/circle/circle_9.jpg +0 -0
  20. data/examples/data/shapes/square/square_1.jpg +0 -0
  21. data/examples/data/shapes/square/square_10.jpg +0 -0
  22. data/examples/data/shapes/square/square_2.jpg +0 -0
  23. data/examples/data/shapes/square/square_3.jpg +0 -0
  24. data/examples/data/shapes/square/square_4.jpg +0 -0
  25. data/examples/data/shapes/square/square_5.jpg +0 -0
  26. data/examples/data/shapes/square/square_6.jpg +0 -0
  27. data/examples/data/shapes/square/square_7.jpg +0 -0
  28. data/examples/data/shapes/square/square_8.jpg +0 -0
  29. data/examples/data/shapes/square/square_9.jpg +0 -0
  30. data/examples/data/shapes/triangle/triangle_1.jpg +0 -0
  31. data/examples/data/shapes/triangle/triangle_10.jpg +0 -0
  32. data/examples/data/shapes/triangle/triangle_2.jpg +0 -0
  33. data/examples/data/shapes/triangle/triangle_3.jpg +0 -0
  34. data/examples/data/shapes/triangle/triangle_4.jpg +0 -0
  35. data/examples/data/shapes/triangle/triangle_5.jpg +0 -0
  36. data/examples/data/shapes/triangle/triangle_6.jpg +0 -0
  37. data/examples/data/shapes/triangle/triangle_7.jpg +0 -0
  38. data/examples/data/shapes/triangle/triangle_8.jpg +0 -0
  39. data/examples/data/shapes/triangle/triangle_9.jpg +0 -0
  40. data/examples/data/support_faq_pairs.jsonl +30 -0
  41. data/examples/generate_shape_images.rb +94 -0
  42. data/examples/sentiment_classification.rb +87 -0
  43. data/examples/shape_classification.rb +87 -0
  44. data/examples/support_faq_embeddings.rb +105 -0
  45. data/examples/train_lora_tools.rb +218 -0
  46. data/lib/fine/configuration.rb +173 -15
  47. data/lib/fine/datasets/image_dataset.rb +14 -2
  48. data/lib/fine/datasets/instruction_dataset.rb +17 -2
  49. data/lib/fine/datasets/text_dataset.rb +15 -5
  50. data/lib/fine/hub/config_loader.rb +4 -4
  51. data/lib/fine/hub/safetensors_loader.rb +3 -2
  52. data/lib/fine/llm.rb +39 -10
  53. data/lib/fine/lora.rb +214 -0
  54. data/lib/fine/models/bert_encoder.rb +15 -6
  55. data/lib/fine/models/bert_for_sequence_classification.rb +35 -4
  56. data/lib/fine/models/causal_lm.rb +46 -5
  57. data/lib/fine/models/gemma3_decoder.rb +25 -6
  58. data/lib/fine/models/llama_decoder.rb +9 -8
  59. data/lib/fine/models/sentence_transformer.rb +1 -1
  60. data/lib/fine/tokenizers/auto_tokenizer.rb +15 -0
  61. data/lib/fine/training/text_trainer.rb +3 -1
  62. data/lib/fine/validators.rb +304 -0
  63. data/lib/fine/version.rb +1 -1
  64. data/lib/fine.rb +4 -0
  65. metadata +47 -2
data/lib/fine/lora.rb ADDED
@@ -0,0 +1,214 @@
+ # frozen_string_literal: true
+
+ module Fine
+   # Low-Rank Adaptation (LoRA) for parameter-efficient fine-tuning
+   #
+   # LoRA freezes the pretrained model weights and injects trainable
+   # rank decomposition matrices into each layer, dramatically reducing
+   # the number of trainable parameters.
+   #
+   # @example
+   #   model = Fine::Models::CausalLM.from_pretrained("google/gemma-3-4b-it")
+   #   lora_model = Fine::LoRA.apply(model, rank: 8, alpha: 16, target_modules: ["q_proj", "v_proj"])
+   #   # Only LoRA parameters are trainable now
+   #
+   module LoRA
+     # LoRA Linear layer that wraps an existing Linear layer
+     class LoRALinear < Torch::NN::Module
+       attr_reader :in_features, :out_features, :rank, :alpha, :scaling
+
+       def initialize(original_layer, rank: 8, alpha: 16, dropout: 0.0)
+         super()
+
+         @in_features = original_layer.weight.shape[1]
+         @out_features = original_layer.weight.shape[0]
+         @rank = rank
+         @alpha = alpha
+         @scaling = alpha.to_f / rank
+
+         # Match dtype of original layer
+         @dtype = original_layer.weight.dtype
+         @device = original_layer.weight.device
+
+         # Store original layer (frozen)
+         @original = original_layer
+         @original.weight.requires_grad = false
+         @original.bias&.requires_grad = false if @original.respond_to?(:bias) && @original.bias
+
+         # LoRA matrices A and B - match dtype of original layer
+         # W' = W + (B @ A) * scaling
+         # A: (rank, in_features) - initialized with Kaiming uniform
+         # B: (out_features, rank) - initialized with zeros
+         @lora_a = Torch::NN::Parameter.new(
+           Torch.empty(@rank, @in_features, dtype: @dtype, device: @device)
+         )
+         @lora_b = Torch::NN::Parameter.new(
+           Torch.zeros(@out_features, @rank, dtype: @dtype, device: @device)
+         )
+
+         # Initialize A with Kaiming uniform (in float32, then convert)
+         temp_a = Torch.empty(@rank, @in_features)
+         Torch::NN::Init.kaiming_uniform!(temp_a, a: Math.sqrt(5))
+         @lora_a.data.copy!(temp_a.to(@dtype))
+
+         # Optional dropout
+         @dropout = dropout > 0 ? Torch::NN::Dropout.new(p: dropout) : nil
+       end
+
+       def forward(x)
+         # Original forward pass (frozen)
+         original_out = @original.call(x)
+
+         # LoRA forward: x @ A.T @ B.T * scaling
+         lora_out = x
+         lora_out = @dropout.call(lora_out) if @dropout
+         lora_out = lora_out.matmul(@lora_a.t)
+         lora_out = lora_out.matmul(@lora_b.t)
+         lora_out = lora_out * @scaling
+
+         original_out + lora_out
+       end
+
+       # Number of trainable parameters
+       def trainable_params
+         @rank * @in_features + @out_features * @rank
+       end
+
+       # Merge LoRA weights into original layer (for inference)
+       def merge!
+         Torch.no_grad do
+           delta_w = @lora_b.matmul(@lora_a) * @scaling
+           @original.weight.add!(delta_w)
+         end
+       end
+     end
+
+     class << self
+       # Apply LoRA to a model
+       #
+       # @param model [Torch::NN::Module] Model to apply LoRA to
+       # @param rank [Integer] LoRA rank (lower = fewer params, higher = more capacity)
+       # @param alpha [Integer] LoRA alpha (scaling factor)
+       # @param dropout [Float] Dropout probability for LoRA layers
+       # @param target_modules [Array<String>] Module names to apply LoRA to
+       # @return [Torch::NN::Module] Model with LoRA applied
+       def apply(model, rank: 8, alpha: 16, dropout: 0.0, target_modules: nil)
+         target_modules ||= default_target_modules
+
+         # First freeze all parameters
+         model.parameters.each { |p| p.requires_grad = false }
+
+         # Track replacements
+         replacements = []
+         total_lora_params = 0
+
+         # Find and replace target modules
+         find_modules(model, target_modules) do |parent, name, layer|
+           next unless layer.is_a?(Torch::NN::Linear)
+
+           lora_layer = LoRALinear.new(layer, rank: rank, alpha: alpha, dropout: dropout)
+           replacements << [parent, name, lora_layer]
+           total_lora_params += lora_layer.trainable_params
+         end
+
+         # Apply replacements
+         replacements.each do |parent, name, lora_layer|
+           parent.instance_variable_set("@#{name}", lora_layer)
+         end
+
+         # Calculate stats
+         total_params = count_params(model)
+         trainable = count_trainable_params(model)
+
+         puts " LoRA applied to #{replacements.size} layers"
+         puts " Total params: #{format_params(total_params)}"
+         puts " Trainable params: #{format_params(trainable)} (#{(trainable.to_f / total_params * 100).round(2)}%)"
+
+         model
+       end
+
+       # Merge LoRA weights into base model (for efficient inference)
+       def merge!(model)
+         find_lora_layers(model) do |lora_layer|
+           lora_layer.merge!
+         end
+         model
+       end
+
+       # Get only trainable (LoRA) parameters
+       def trainable_parameters(model)
+         params = []
+         find_lora_layers(model) do |lora_layer|
+           params << lora_layer.lora_a
+           params << lora_layer.lora_b
+         end
+         params
+       end
+
+       # Default modules to apply LoRA to (attention projections)
+       def default_target_modules
+         %w[q_proj k_proj v_proj o_proj]
+       end
+
+       private
+
+       def find_modules(model, target_names, parent = nil, prefix = "", &block)
+         model.instance_variables.each do |ivar|
+           name = ivar.to_s.delete_prefix("@")
+           child = model.instance_variable_get(ivar)
+
+           if child.is_a?(Torch::NN::Module)
+             full_name = prefix.empty? ? name : "#{prefix}.#{name}"
+
+             if target_names.any? { |t| name == t || name.end_with?(t) }
+               yield(model, name, child)
+             end
+
+             # Recurse into ModuleList
+             if child.is_a?(Torch::NN::ModuleList)
+               child.each_with_index do |layer, idx|
+                 find_modules(layer, target_names, child, "#{full_name}[#{idx}]", &block)
+               end
+             else
+               find_modules(child, target_names, model, full_name, &block)
+             end
+           end
+         end
+       end
+
+       def find_lora_layers(model, &block)
+         model.instance_variables.each do |ivar|
+           child = model.instance_variable_get(ivar)
+
+           if child.is_a?(LoRALinear)
+             yield(child)
+           elsif child.is_a?(Torch::NN::ModuleList)
+             child.each { |layer| find_lora_layers(layer, &block) }
+           elsif child.is_a?(Torch::NN::Module)
+             find_lora_layers(child, &block)
+           end
+         end
+       end
+
+       def count_params(model)
+         model.parameters.sum { |p| p.numel }
+       end
+
+       def count_trainable_params(model)
+         model.parameters.select { |p| p.requires_grad }.sum { |p| p.numel }
+       end
+
+       def format_params(n)
+         if n >= 1_000_000_000
+           "#{(n / 1_000_000_000.0).round(2)}B"
+         elsif n >= 1_000_000
+           "#{(n / 1_000_000.0).round(2)}M"
+         elsif n >= 1_000
+           "#{(n / 1_000.0).round(2)}K"
+         else
+           n.to_s
+         end
+       end
+     end
+   end
+ end
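
Note: a minimal usage sketch of the API added above. The model id comes from the @example comment; the optimizer choice and learning rate are illustrative, not prescribed by the gem, and the training loop is omitted.

  require "fine"

  model = Fine::Models::CausalLM.from_pretrained("google/gemma-3-4b-it")
  Fine::LoRA.apply(model, rank: 8, alpha: 16, target_modules: %w[q_proj v_proj])

  # Only the LoRA A/B matrices require gradients, so hand just those to the optimizer.
  optimizer = Torch::Optim::AdamW.new(Fine::LoRA.trainable_parameters(model), lr: 1.0e-4)

  # ... training loop elided ...

  # Fold the low-rank updates back into the frozen weights for plain inference.
  Fine::LoRA.merge!(model)
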
data/lib/fine/models/bert_encoder.rb CHANGED
@@ -9,7 +9,7 @@ module Fine
  class BertEncoder < Base
  attr_reader :embeddings, :encoder, :pooler
 
- def initialize(config)
+ def initialize(config, use_pooler: true)
  super(config)
 
  @hidden_size = config.hidden_size
@@ -21,6 +21,7 @@ module Fine
  @type_vocab_size = config.type_vocab_size || 2
  @layer_norm_eps = config.layer_norm_eps
  @hidden_dropout_prob = config.hidden_dropout_prob || 0.1
+ @use_pooler = use_pooler
 
  # Embeddings
  @word_embeddings = Torch::NN::Embedding.new(@vocab_size, @hidden_size)
@@ -42,9 +43,11 @@ module Fine
  end
  )
 
- # Pooler (for [CLS] token representation)
- @pooler_dense = Torch::NN::Linear.new(@hidden_size, @hidden_size)
- @pooler_activation = Torch::NN::Tanh.new
+ # Pooler (for [CLS] token representation) - optional for models like DistilBERT
+ if @use_pooler
+ @pooler_dense = Torch::NN::Linear.new(@hidden_size, @hidden_size)
+ @pooler_activation = Torch::NN::Tanh.new
+ end
  end
 
  def forward(input_ids, attention_mask: nil, token_type_ids: nil)
@@ -83,8 +86,14 @@ module Fine
 
  # Pool the [CLS] token (first token)
  cls_output = hidden_states[0.., 0, 0..]
- pooled_output = @pooler_dense.call(cls_output)
- pooled_output = @pooler_activation.call(pooled_output)
+
+ # Apply pooler if available, otherwise use CLS directly
+ pooled_output = if @use_pooler && @pooler_dense
+ temp = @pooler_dense.call(cls_output)
+ @pooler_activation.call(temp)
+ else
+ cls_output
+ end
 
  {
  last_hidden_state: hidden_states,
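
Note: an illustrative sketch only; the Fine::Models namespace is assumed from the file layout, and config / input tensors are assumed to be prepared elsewhere. With the pooler disabled (as for DistilBERT checkpoints), :pooler_output falls back to the raw [CLS] hidden state.

  encoder = Fine::Models::BertEncoder.new(config, use_pooler: false)
  out = encoder.call(input_ids, attention_mask: attention_mask)
  out[:pooler_output]  # same as out[:last_hidden_state][0.., 0, 0..] when the pooler is off
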
data/lib/fine/models/bert_for_sequence_classification.rb CHANGED
@@ -30,8 +30,12 @@ module Fine
  load_result = load_pretrained_weights(model, weights_path)
 
  if load_result[:missing_keys].any?
- # Only warn about unexpected missing keys (classifier is expected to be missing)
- encoder_missing = load_result[:missing_keys].reject { |k| k.include?("classifier") }
+ # Only warn about unexpected missing keys
+ # Expected missing: classifier (new), token_type_embeddings (DistilBERT), pooler (DistilBERT)
+ expected_missing = %w[classifier token_type_embeddings pooler_dense]
+ encoder_missing = load_result[:missing_keys].reject do |k|
+ expected_missing.any? { |exp| k.include?(exp) }
+ end
  if encoder_missing.any?
  warn "Missing encoder keys: #{encoder_missing.first(5).join(', ')}..."
  end
@@ -56,7 +60,7 @@ module Fine
  model = new(config, num_labels: num_labels)
 
  weights_path = File.join(path, "model.safetensors")
- Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+ Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false, skip_mapping: true)
 
  model
  end
@@ -66,8 +70,11 @@ module Fine
 
  @num_labels = num_labels
 
+ # Detect if this is DistilBERT (no pooler layer in pretrained weights)
+ use_pooler = config.model_type != "distilbert"
+
  # Encoder
- @encoder = BertEncoder.new(config)
+ @encoder = BertEncoder.new(config, use_pooler: use_pooler)
 
  # Classification head
  @dropout = Torch::NN::Dropout.new(p: dropout)
@@ -83,6 +90,8 @@ module Fine
  )
 
  # Use pooled output for classification
+ # For DistilBERT (no pooler), this is the raw CLS token
+ # which works better than mean pooling for classification
  pooled_output = encoder_output[:pooler_output]
  pooled_output = @dropout.call(pooled_output)
 
@@ -186,6 +195,28 @@ module Fine
  def self.map_bert_weight_name(hf_name)
  name = hf_name.dup
 
+ # DistilBERT mappings (must come first as they're more specific)
+ if name.start_with?("distilbert.")
+ name = name.sub("distilbert.embeddings.word_embeddings", "encoder.word_embeddings")
+ name = name.sub("distilbert.embeddings.position_embeddings", "encoder.position_embeddings")
+ name = name.sub("distilbert.embeddings.LayerNorm", "encoder.embeddings_layer_norm")
+ name = name.gsub("distilbert.transformer.layer", "encoder.layers")
+
+ # DistilBERT attention naming
+ name = name.gsub(".attention.q_lin", ".attention.query")
+ name = name.gsub(".attention.k_lin", ".attention.key")
+ name = name.gsub(".attention.v_lin", ".attention.value")
+ name = name.gsub(".attention.out_lin", ".attention.out")
+ name = name.gsub(".sa_layer_norm", ".attention_layer_norm")
+
+ # DistilBERT FFN naming
+ name = name.gsub(".ffn.lin1", ".intermediate")
+ name = name.gsub(".ffn.lin2", ".output")
+
+ return name
+ end
+
+ # Standard BERT mappings
  # Embeddings
  name = name.sub("bert.embeddings.word_embeddings", "encoder.word_embeddings")
  name = name.sub("bert.embeddings.position_embeddings", "encoder.position_embeddings")
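
Note: a worked example of the DistilBERT branch above; the class constant is assumed from the file layout and the method may be private, so this call is for illustration only.

  Fine::Models::BertForSequenceClassification.map_bert_weight_name(
    "distilbert.transformer.layer.0.attention.q_lin.weight"
  )
  # => "encoder.layers.0.attention.query.weight"
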
data/lib/fine/models/causal_lm.rb CHANGED
@@ -62,8 +62,8 @@ module Fine
  def initialize(config)
  super(config)
 
- # Use appropriate decoder based on model type
- @decoder = if config.model_type&.include?("gemma3")
+ # Use appropriate decoder based on model type or architectures
+ @decoder = if gemma3_architecture?(config)
  Gemma3Decoder.new(config)
  else
  LlamaDecoder.new(config)
@@ -134,7 +134,10 @@ module Fine
  if do_sample
  # Top-k filtering
  if top_k > 0
- indices_to_remove = next_token_logits < Torch.topk(next_token_logits, top_k).values[0.., -1, nil]
+ # Torch.topk returns [values, indices] array in torch-rb
+ topk_values, _topk_indices = Torch.topk(next_token_logits, top_k)
+ threshold = topk_values[0.., -1, nil]
+ indices_to_remove = Torch.lt(next_token_logits, threshold)
  next_token_logits = next_token_logits.masked_fill(indices_to_remove, -Float::INFINITY)
  end
 
@@ -144,7 +147,7 @@ module Fine
  cumulative_probs = Torch.cumsum(Torch::NN::Functional.softmax(sorted_logits, dim: -1), dim: -1)
 
  # Remove tokens with cumulative probability above threshold
- sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove = Torch.gt(cumulative_probs, top_p)
  sorted_indices_to_remove[0.., 1..] = sorted_indices_to_remove[0.., 0...-1].clone
  sorted_indices_to_remove[0.., 0] = false
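
Note: a self-contained sketch of the corrected top-k path (the tensor values are made up; the calls mirror those used in the hunk above).

  logits = Torch.tensor([[2.0, 0.5, 1.5, -1.0]])
  values, _indices = Torch.topk(logits, 2)          # torch-rb returns [values, indices]
  threshold = values[0.., -1, nil]                  # smallest kept logit, shape (1, 1)
  filtered = logits.masked_fill(Torch.lt(logits, threshold), -Float::INFINITY)
  Torch::NN::Functional.softmax(filtered, dim: -1)  # only the two kept tokens get probability mass
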
 
@@ -182,7 +185,12 @@ module Fine
  weights_path = File.join(path, "model.safetensors")
  Safetensors::Torch.save_file(state_dict, weights_path)
 
- save_config = @config.to_h.merge("model_type" => "causal_lm")
+ # Preserve architecture info for proper decoder selection on load
+ save_config = @config.to_h.dup
+ # Don't overwrite model_type if it contains architecture info
+ save_config["model_type"] ||= "causal_lm"
+ # Mark which decoder type this model uses
+ save_config["_decoder_type"] = @decoder.class.name.split("::").last
 
  config_path = File.join(path, "config.json")
  File.write(config_path, JSON.pretty_generate(save_config))
@@ -190,10 +198,26 @@ module Fine
 
  private
 
+ # Detect if this is a Gemma 3 model based on config
+ def gemma3_architecture?(config)
+ # Check model_type first
+ return true if config.model_type&.include?("gemma3")
+
+ # Check architectures array (HuggingFace format)
+ architectures = config.config["architectures"] || []
+ return true if architectures.any? { |a| a.downcase.include?("gemma3") }
+
+ # Check saved decoder type
+ return true if config.config["_decoder_type"] == "Gemma3Decoder"
+
+ false
+ end
+
  def self.load_pretrained_weights(model, weights_path)
  # Load and copy weights one at a time to minimize memory usage
  model_state = model.state_dict
  model_keys = model_state.keys
+ loaded_lm_head = false
 
  Torch.no_grad do
  Safetensors::Torch.load_file(weights_path).each do |name, tensor|
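
Note: the three checks in gemma3_architecture? above resolve as illustrated below; the concrete values are assumptions for illustration, not taken from any specific config.

  config.model_type                # e.g. "gemma3_text" contains "gemma3"  -> Gemma3Decoder
  config.config["architectures"]   # e.g. ["Gemma3ForCausalLM"]            -> Gemma3Decoder
  config.config["_decoder_type"]   # "Gemma3Decoder" after a save_pretrained round-trip
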
@@ -203,8 +227,16 @@ module Fine
  # Convert dtype if needed
  tensor = tensor.to(target.dtype) if tensor.dtype != target.dtype
  target.copy!(tensor)
+ loaded_lm_head = true if mapped_name == "lm_head.weight"
  end
  end
+
+ # If lm_head wasn't in the weights file, tie it to embeddings
+ unless loaded_lm_head
+ embed_weight = model_state["decoder.embed_tokens.weight"]
+ lm_head_weight = model_state["lm_head.weight"]
+ lm_head_weight.copy!(embed_weight)
+ end
  end
 
  # Force garbage collection to free loaded tensors
@@ -219,6 +251,7 @@ module Fine
 
  model_state = model.state_dict
  model_keys = model_state.keys
+ loaded_lm_head = false
 
  # Load each shard and copy weights immediately to minimize memory
  Torch.no_grad do
@@ -229,11 +262,19 @@ module Fine
  target = model_state[mapped_name]
  tensor = tensor.to(target.dtype) if tensor.dtype != target.dtype
  target.copy!(tensor)
+ loaded_lm_head = true if mapped_name == "lm_head.weight"
  end
  end
  # GC after each shard to free memory
  GC.start
  end
+
+ # If lm_head wasn't in the weights file, tie it to embeddings
+ unless loaded_lm_head
+ embed_weight = model_state["decoder.embed_tokens.weight"]
+ lm_head_weight = model_state["lm_head.weight"]
+ lm_head_weight.copy!(embed_weight)
+ end
  end
  end
 
data/lib/fine/models/gemma3_decoder.rb CHANGED
@@ -60,12 +60,12 @@ module Fine
  position_ids ||= Torch.arange(seq_length, device: input_ids.device)
  position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 
- # Create causal mask
- causal_mask = create_causal_mask(seq_length, hidden_states.device)
+ # Create causal mask (must match dtype of hidden_states)
+ causal_mask = create_causal_mask(seq_length, hidden_states.device, hidden_states.dtype)
 
  # Combine with attention mask if provided
  if attention_mask
- expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(hidden_states.dtype)
  expanded_mask = expanded_mask.expand(-1, -1, seq_length, -1)
  causal_mask = causal_mask + (1.0 - expanded_mask) * -1e9
  end
@@ -87,9 +87,9 @@ module Fine
 
  private
 
- def create_causal_mask(seq_length, device)
+ def create_causal_mask(seq_length, device, dtype)
  mask = Torch.triu(
- Torch.ones(seq_length, seq_length, device: device) * -1e9,
+ Torch.ones(seq_length, seq_length, device: device, dtype: dtype) * -1e9,
  diagonal: 1
  )
  mask.unsqueeze(0).unsqueeze(0)
@@ -112,7 +112,7 @@ module Fine
  rms_norm_eps: rms_norm_eps
  )
 
- @mlp = LlamaMLP.new(
+ @mlp = Gemma3MLP.new(
  hidden_size: hidden_size,
  intermediate_size: intermediate_size
  )
@@ -240,5 +240,24 @@ module Fine
  x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
  end
  end
+
+ # Gemma 3 MLP with GELU activation (not SiLU like Llama)
+ class Gemma3MLP < Torch::NN::Module
+ def initialize(hidden_size:, intermediate_size:)
+ super()
+
+ @gate_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+ @up_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+ @down_proj = Torch::NN::Linear.new(intermediate_size, hidden_size, bias: false)
+ end
+
+ def forward(x)
+ # GeGLU: gelu(gate) * up
+ # Using GELU with tanh approximation as per Gemma config
+ gate = Torch::NN::Functional.gelu(@gate_proj.call(x), approximate: "tanh")
+ up = @up_proj.call(x)
+ @down_proj.call(gate * up)
+ end
+ end
  end
  end
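
Note: a minimal sketch of the GeGLU block added above; the Fine::Models namespace and the tensor shapes are assumptions for illustration.

  mlp = Fine::Models::Gemma3MLP.new(hidden_size: 8, intermediate_size: 32)
  x = Torch.randn(1, 4, 8)   # (batch, seq, hidden)
  y = mlp.call(x)            # gelu(gate_proj(x), tanh approx) * up_proj(x), projected back to (1, 4, 8)
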
data/lib/fine/models/llama_decoder.rb CHANGED
@@ -52,13 +52,13 @@ module Fine
  position_ids ||= Torch.arange(seq_length, device: input_ids.device)
  position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
 
- # Create causal mask
- causal_mask = create_causal_mask(seq_length, hidden_states.device)
+ # Create causal mask (must match dtype of hidden_states)
+ causal_mask = create_causal_mask(seq_length, hidden_states.device, hidden_states.dtype)
 
  # Combine with attention mask if provided
  if attention_mask
  # Expand attention mask: (batch, seq) -> (batch, 1, seq, seq)
- expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(hidden_states.dtype)
  expanded_mask = expanded_mask.expand(-1, -1, seq_length, -1)
  causal_mask = causal_mask + (1.0 - expanded_mask) * -1e9
  end
@@ -80,10 +80,10 @@ module Fine
 
  private
 
- def create_causal_mask(seq_length, device)
+ def create_causal_mask(seq_length, device, dtype)
  # Lower triangular mask for causal attention
  mask = Torch.triu(
- Torch.ones(seq_length, seq_length, device: device) * -1e9,
+ Torch.ones(seq_length, seq_length, device: device, dtype: dtype) * -1e9,
  diagonal: 1
  )
  mask.unsqueeze(0).unsqueeze(0)
@@ -235,10 +235,11 @@ module Fine
  seq_len = position_ids.max.item + 1
  build_cache(seq_len) if seq_len > @cos_cached.size(0)
 
- # Move cached tensors to position_ids device if needed
+ # Move cached tensors to position_ids device and match dtype of input
  device = position_ids.device
- cos_cached = @cos_cached.to(device)
- sin_cached = @sin_cached.to(device)
+ dtype = x.dtype
+ cos_cached = @cos_cached.to(device).to(dtype)
+ sin_cached = @sin_cached.to(device).to(dtype)
 
  cos = cos_cached[position_ids].unsqueeze(1)
  sin = sin_cached[position_ids].unsqueeze(1)
data/lib/fine/models/sentence_transformer.rb CHANGED
@@ -43,7 +43,7 @@ module Fine
  model = new(config, pooling_mode: pooling_mode)
 
  weights_path = File.join(path, "model.safetensors")
- Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+ Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false, skip_mapping: true)
 
  model
  end
data/lib/fine/tokenizers/auto_tokenizer.rb CHANGED
@@ -135,6 +135,21 @@ module Fine
  @tokenizer.decode(token_ids, skip_special_tokens: skip_special_tokens)
  end
 
+ # Encode without padding (for generation)
+ # Returns only the actual tokens, no padding
+ #
+ # @param text [String] Text to tokenize
+ # @return [Array<Integer>] Token IDs
+ def encode_for_generation(text)
+ # Temporarily disable padding
+ @tokenizer.no_padding
+ encoding = @tokenizer.encode(text)
+ ids = encoding.ids
+ # Re-enable padding
+ @tokenizer.enable_padding(length: @max_length)
+ ids
+ end
+
  # Get vocabulary size
  def vocab_size
  @tokenizer.vocab_size
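
Note: illustrative use of the new method (tokenizer construction is assumed elsewhere). Unlike the padded encode path, the result is a variable-length array with no pad tokens, which is what autoregressive generation expects.

  prompt_ids = tokenizer.encode_for_generation("What is LoRA?")
  # => Array<Integer> of just the prompt tokens; padding is re-enabled afterwards
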
data/lib/fine/training/text_trainer.rb CHANGED
@@ -228,10 +228,12 @@ module Fine
  optimizer.zero_grad
 
  # For pair datasets, we get anchor and positive texts
- embeddings = @model.encode(
+ # Use forward() directly during training (not encode() which uses no_grad)
+ output = @model.forward(
  batch[:input_ids],
  attention_mask: batch[:attention_mask]
  )
+ embeddings = output[:embeddings]
 
  # Multiple Negatives Ranking Loss
  # Treat other samples in batch as negatives
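
Note: a hedged sketch of the Multiple Negatives Ranking Loss referenced above, assuming `anchor` and `positive` are (batch, dim) embedding matrices and using an illustrative scale factor; the gem's own implementation may differ in detail.

  scores = anchor.matmul(positive.t) * 20.0                   # (batch, batch) similarity matrix
  labels = Torch.arange(scores.size(0))                       # diagonal entries are the true pairs
  loss = Torch::NN::Functional.cross_entropy(scores, labels)  # other rows in the batch act as negatives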