fine 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +20 -10
- data/docs/examples/image-classification-shapes.md +83 -0
- data/docs/examples/text-embeddings-faq.md +98 -0
- data/docs/quickstart.md +209 -0
- data/docs/tutorials/lora-tool-calling.md +306 -0
- data/examples/data/generate_tool_data.rb +261 -0
- data/examples/data/ollama_tool_calls.jsonl +40 -0
- data/examples/data/sentiment_reviews.jsonl +30 -0
- data/examples/data/shapes/circle/circle_1.jpg +0 -0
- data/examples/data/shapes/circle/circle_10.jpg +0 -0
- data/examples/data/shapes/circle/circle_2.jpg +0 -0
- data/examples/data/shapes/circle/circle_3.jpg +0 -0
- data/examples/data/shapes/circle/circle_4.jpg +0 -0
- data/examples/data/shapes/circle/circle_5.jpg +0 -0
- data/examples/data/shapes/circle/circle_6.jpg +0 -0
- data/examples/data/shapes/circle/circle_7.jpg +0 -0
- data/examples/data/shapes/circle/circle_8.jpg +0 -0
- data/examples/data/shapes/circle/circle_9.jpg +0 -0
- data/examples/data/shapes/square/square_1.jpg +0 -0
- data/examples/data/shapes/square/square_10.jpg +0 -0
- data/examples/data/shapes/square/square_2.jpg +0 -0
- data/examples/data/shapes/square/square_3.jpg +0 -0
- data/examples/data/shapes/square/square_4.jpg +0 -0
- data/examples/data/shapes/square/square_5.jpg +0 -0
- data/examples/data/shapes/square/square_6.jpg +0 -0
- data/examples/data/shapes/square/square_7.jpg +0 -0
- data/examples/data/shapes/square/square_8.jpg +0 -0
- data/examples/data/shapes/square/square_9.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_1.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_10.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_2.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_3.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_4.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_5.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_6.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_7.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_8.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_9.jpg +0 -0
- data/examples/data/support_faq_pairs.jsonl +30 -0
- data/examples/generate_shape_images.rb +94 -0
- data/examples/sentiment_classification.rb +87 -0
- data/examples/shape_classification.rb +87 -0
- data/examples/support_faq_embeddings.rb +105 -0
- data/examples/train_lora_tools.rb +218 -0
- data/lib/fine/configuration.rb +173 -15
- data/lib/fine/datasets/image_dataset.rb +14 -2
- data/lib/fine/datasets/instruction_dataset.rb +17 -2
- data/lib/fine/datasets/text_dataset.rb +15 -5
- data/lib/fine/hub/config_loader.rb +4 -4
- data/lib/fine/hub/safetensors_loader.rb +3 -2
- data/lib/fine/llm.rb +39 -10
- data/lib/fine/lora.rb +214 -0
- data/lib/fine/models/bert_encoder.rb +15 -6
- data/lib/fine/models/bert_for_sequence_classification.rb +35 -4
- data/lib/fine/models/causal_lm.rb +46 -5
- data/lib/fine/models/gemma3_decoder.rb +25 -6
- data/lib/fine/models/llama_decoder.rb +9 -8
- data/lib/fine/models/sentence_transformer.rb +1 -1
- data/lib/fine/tokenizers/auto_tokenizer.rb +15 -0
- data/lib/fine/training/text_trainer.rb +3 -1
- data/lib/fine/validators.rb +304 -0
- data/lib/fine/version.rb +1 -1
- data/lib/fine.rb +4 -0
- metadata +47 -2
data/lib/fine/lora.rb
ADDED

@@ -0,0 +1,214 @@
+# frozen_string_literal: true
+
+module Fine
+  # Low-Rank Adaptation (LoRA) for parameter-efficient fine-tuning
+  #
+  # LoRA freezes the pretrained model weights and injects trainable
+  # rank decomposition matrices into each layer, dramatically reducing
+  # the number of trainable parameters.
+  #
+  # @example
+  #   model = Fine::Models::CausalLM.from_pretrained("google/gemma-3-4b-it")
+  #   lora_model = Fine::LoRA.apply(model, rank: 8, alpha: 16, target_modules: ["q_proj", "v_proj"])
+  #   # Only LoRA parameters are trainable now
+  #
+  module LoRA
+    # LoRA Linear layer that wraps an existing Linear layer
+    class LoRALinear < Torch::NN::Module
+      attr_reader :in_features, :out_features, :rank, :alpha, :scaling
+
+      def initialize(original_layer, rank: 8, alpha: 16, dropout: 0.0)
+        super()
+
+        @in_features = original_layer.weight.shape[1]
+        @out_features = original_layer.weight.shape[0]
+        @rank = rank
+        @alpha = alpha
+        @scaling = alpha.to_f / rank
+
+        # Match dtype of original layer
+        @dtype = original_layer.weight.dtype
+        @device = original_layer.weight.device
+
+        # Store original layer (frozen)
+        @original = original_layer
+        @original.weight.requires_grad = false
+        @original.bias&.requires_grad = false if @original.respond_to?(:bias) && @original.bias
+
+        # LoRA matrices A and B - match dtype of original layer
+        # W' = W + (B @ A) * scaling
+        # A: (rank, in_features) - initialized with Kaiming uniform
+        # B: (out_features, rank) - initialized with zeros
+        @lora_a = Torch::NN::Parameter.new(
+          Torch.empty(@rank, @in_features, dtype: @dtype, device: @device)
+        )
+        @lora_b = Torch::NN::Parameter.new(
+          Torch.zeros(@out_features, @rank, dtype: @dtype, device: @device)
+        )
+
+        # Initialize A with Kaiming uniform (in float32, then convert)
+        temp_a = Torch.empty(@rank, @in_features)
+        Torch::NN::Init.kaiming_uniform!(temp_a, a: Math.sqrt(5))
+        @lora_a.data.copy!(temp_a.to(@dtype))
+
+        # Optional dropout
+        @dropout = dropout > 0 ? Torch::NN::Dropout.new(p: dropout) : nil
+      end
+
+      def forward(x)
+        # Original forward pass (frozen)
+        original_out = @original.call(x)
+
+        # LoRA forward: x @ A.T @ B.T * scaling
+        lora_out = x
+        lora_out = @dropout.call(lora_out) if @dropout
+        lora_out = lora_out.matmul(@lora_a.t)
+        lora_out = lora_out.matmul(@lora_b.t)
+        lora_out = lora_out * @scaling
+
+        original_out + lora_out
+      end
+
+      # Number of trainable parameters
+      def trainable_params
+        @rank * @in_features + @out_features * @rank
+      end
+
+      # Merge LoRA weights into original layer (for inference)
+      def merge!
+        Torch.no_grad do
+          delta_w = @lora_b.matmul(@lora_a) * @scaling
+          @original.weight.add!(delta_w)
+        end
+      end
+    end
+
+    class << self
+      # Apply LoRA to a model
+      #
+      # @param model [Torch::NN::Module] Model to apply LoRA to
+      # @param rank [Integer] LoRA rank (lower = fewer params, higher = more capacity)
+      # @param alpha [Integer] LoRA alpha (scaling factor)
+      # @param dropout [Float] Dropout probability for LoRA layers
+      # @param target_modules [Array<String>] Module names to apply LoRA to
+      # @return [Torch::NN::Module] Model with LoRA applied
+      def apply(model, rank: 8, alpha: 16, dropout: 0.0, target_modules: nil)
+        target_modules ||= default_target_modules
+
+        # First freeze all parameters
+        model.parameters.each { |p| p.requires_grad = false }
+
+        # Track replacements
+        replacements = []
+        total_lora_params = 0
+
+        # Find and replace target modules
+        find_modules(model, target_modules) do |parent, name, layer|
+          next unless layer.is_a?(Torch::NN::Linear)
+
+          lora_layer = LoRALinear.new(layer, rank: rank, alpha: alpha, dropout: dropout)
+          replacements << [parent, name, lora_layer]
+          total_lora_params += lora_layer.trainable_params
+        end
+
+        # Apply replacements
+        replacements.each do |parent, name, lora_layer|
+          parent.instance_variable_set("@#{name}", lora_layer)
+        end
+
+        # Calculate stats
+        total_params = count_params(model)
+        trainable = count_trainable_params(model)
+
+        puts " LoRA applied to #{replacements.size} layers"
+        puts " Total params: #{format_params(total_params)}"
+        puts " Trainable params: #{format_params(trainable)} (#{(trainable.to_f / total_params * 100).round(2)}%)"
+
+        model
+      end
+
+      # Merge LoRA weights into base model (for efficient inference)
+      def merge!(model)
+        find_lora_layers(model) do |lora_layer|
+          lora_layer.merge!
+        end
+        model
+      end
+
+      # Get only trainable (LoRA) parameters
+      def trainable_parameters(model)
+        params = []
+        find_lora_layers(model) do |lora_layer|
+          params << lora_layer.lora_a
+          params << lora_layer.lora_b
+        end
+        params
+      end
+
+      # Default modules to apply LoRA to (attention projections)
+      def default_target_modules
+        %w[q_proj k_proj v_proj o_proj]
+      end
+
+      private
+
+      def find_modules(model, target_names, parent = nil, prefix = "", &block)
+        model.instance_variables.each do |ivar|
+          name = ivar.to_s.delete_prefix("@")
+          child = model.instance_variable_get(ivar)
+
+          if child.is_a?(Torch::NN::Module)
+            full_name = prefix.empty? ? name : "#{prefix}.#{name}"
+
+            if target_names.any? { |t| name == t || name.end_with?(t) }
+              yield(model, name, child)
+            end
+
+            # Recurse into ModuleList
+            if child.is_a?(Torch::NN::ModuleList)
+              child.each_with_index do |layer, idx|
+                find_modules(layer, target_names, child, "#{full_name}[#{idx}]", &block)
+              end
+            else
+              find_modules(child, target_names, model, full_name, &block)
+            end
+          end
+        end
+      end
+
+      def find_lora_layers(model, &block)
+        model.instance_variables.each do |ivar|
+          child = model.instance_variable_get(ivar)
+
+          if child.is_a?(LoRALinear)
+            yield(child)
+          elsif child.is_a?(Torch::NN::ModuleList)
+            child.each { |layer| find_lora_layers(layer, &block) }
+          elsif child.is_a?(Torch::NN::Module)
+            find_lora_layers(child, &block)
+          end
+        end
+      end
+
+      def count_params(model)
+        model.parameters.sum { |p| p.numel }
+      end
+
+      def count_trainable_params(model)
+        model.parameters.select { |p| p.requires_grad }.sum { |p| p.numel }
+      end
+
+      def format_params(n)
+        if n >= 1_000_000_000
+          "#{(n / 1_000_000_000.0).round(2)}B"
+        elsif n >= 1_000_000
+          "#{(n / 1_000_000.0).round(2)}M"
+        elsif n >= 1_000
+          "#{(n / 1_000.0).round(2)}K"
+        else
+          n.to_s
+        end
+      end
+    end
+  end
+end
data/lib/fine/models/bert_encoder.rb
CHANGED

@@ -9,7 +9,7 @@ module Fine
     class BertEncoder < Base
       attr_reader :embeddings, :encoder, :pooler

-      def initialize(config)
+      def initialize(config, use_pooler: true)
         super(config)

         @hidden_size = config.hidden_size
@@ -21,6 +21,7 @@ module Fine
         @type_vocab_size = config.type_vocab_size || 2
         @layer_norm_eps = config.layer_norm_eps
         @hidden_dropout_prob = config.hidden_dropout_prob || 0.1
+        @use_pooler = use_pooler

         # Embeddings
         @word_embeddings = Torch::NN::Embedding.new(@vocab_size, @hidden_size)
@@ -42,9 +43,11 @@ module Fine
           end
         )

-        # Pooler (for [CLS] token representation)
-
-
+        # Pooler (for [CLS] token representation) - optional for models like DistilBERT
+        if @use_pooler
+          @pooler_dense = Torch::NN::Linear.new(@hidden_size, @hidden_size)
+          @pooler_activation = Torch::NN::Tanh.new
+        end
       end

       def forward(input_ids, attention_mask: nil, token_type_ids: nil)
@@ -83,8 +86,14 @@ module Fine

         # Pool the [CLS] token (first token)
         cls_output = hidden_states[0.., 0, 0..]
-
-
+
+        # Apply pooler if available, otherwise use CLS directly
+        pooled_output = if @use_pooler && @pooler_dense
+          temp = @pooler_dense.call(cls_output)
+          @pooler_activation.call(temp)
+        else
+          cls_output
+        end

         {
           last_hidden_state: hidden_states,
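The effect of the new use_pooler flag can be sketched with plain torch-rb tensors (the shapes below are made up for illustration): with the pooler enabled the [CLS] vector goes through a Linear + Tanh head, without it the raw [CLS] vector is returned as the pooled output.

    require "torch"

    hidden_states = Torch.randn(2, 6, 8)        # (batch, seq_len, hidden_size)
    cls_output    = hidden_states[0.., 0, 0..]  # first ([CLS]) token of each sequence

    pooler_dense      = Torch::NN::Linear.new(8, 8)
    pooler_activation = Torch::NN::Tanh.new

    # use_pooler: true  -> BERT-style pooled representation
    pooled_output = pooler_activation.call(pooler_dense.call(cls_output))
    # use_pooler: false -> DistilBERT-style: cls_output is used directly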
data/lib/fine/models/bert_for_sequence_classification.rb
CHANGED

@@ -30,8 +30,12 @@ module Fine
         load_result = load_pretrained_weights(model, weights_path)

         if load_result[:missing_keys].any?
-          # Only warn about unexpected missing keys
-
+          # Only warn about unexpected missing keys
+          # Expected missing: classifier (new), token_type_embeddings (DistilBERT), pooler (DistilBERT)
+          expected_missing = %w[classifier token_type_embeddings pooler_dense]
+          encoder_missing = load_result[:missing_keys].reject do |k|
+            expected_missing.any? { |exp| k.include?(exp) }
+          end
           if encoder_missing.any?
             warn "Missing encoder keys: #{encoder_missing.first(5).join(', ')}..."
           end
@@ -56,7 +60,7 @@ module Fine
         model = new(config, num_labels: num_labels)

         weights_path = File.join(path, "model.safetensors")
-        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false, skip_mapping: true)

         model
       end
@@ -66,8 +70,11 @@ module Fine

         @num_labels = num_labels

+        # Detect if this is DistilBERT (no pooler layer in pretrained weights)
+        use_pooler = config.model_type != "distilbert"
+
         # Encoder
-        @encoder = BertEncoder.new(config)
+        @encoder = BertEncoder.new(config, use_pooler: use_pooler)

         # Classification head
         @dropout = Torch::NN::Dropout.new(p: dropout)
@@ -83,6 +90,8 @@ module Fine
         )

         # Use pooled output for classification
+        # For DistilBERT (no pooler), this is the raw CLS token
+        # which works better than mean pooling for classification
         pooled_output = encoder_output[:pooler_output]
         pooled_output = @dropout.call(pooled_output)

@@ -186,6 +195,28 @@ module Fine
       def self.map_bert_weight_name(hf_name)
         name = hf_name.dup

+        # DistilBERT mappings (must come first as they're more specific)
+        if name.start_with?("distilbert.")
+          name = name.sub("distilbert.embeddings.word_embeddings", "encoder.word_embeddings")
+          name = name.sub("distilbert.embeddings.position_embeddings", "encoder.position_embeddings")
+          name = name.sub("distilbert.embeddings.LayerNorm", "encoder.embeddings_layer_norm")
+          name = name.gsub("distilbert.transformer.layer", "encoder.layers")
+
+          # DistilBERT attention naming
+          name = name.gsub(".attention.q_lin", ".attention.query")
+          name = name.gsub(".attention.k_lin", ".attention.key")
+          name = name.gsub(".attention.v_lin", ".attention.value")
+          name = name.gsub(".attention.out_lin", ".attention.out")
+          name = name.gsub(".sa_layer_norm", ".attention_layer_norm")
+
+          # DistilBERT FFN naming
+          name = name.gsub(".ffn.lin1", ".intermediate")
+          name = name.gsub(".ffn.lin2", ".output")
+
+          return name
+        end
+
+        # Standard BERT mappings
         # Embeddings
         name = name.sub("bert.embeddings.word_embeddings", "encoder.word_embeddings")
         name = name.sub("bert.embeddings.position_embeddings", "encoder.position_embeddings")
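To see what the new DistilBERT branch of map_bert_weight_name does, the substitution chain can be replayed on a single checkpoint key; the key below is only an illustrative example.

    # One DistilBERT checkpoint key walked through the new substitution chain
    name = "distilbert.transformer.layer.0.attention.q_lin.weight"
    name = name.gsub("distilbert.transformer.layer", "encoder.layers")
    name = name.gsub(".attention.q_lin", ".attention.query")
    name  # => "encoder.layers.0.attention.query.weight"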
data/lib/fine/models/causal_lm.rb
CHANGED

@@ -62,8 +62,8 @@
       def initialize(config)
         super(config)

-        # Use appropriate decoder based on model type
-        @decoder = if
+        # Use appropriate decoder based on model type or architectures
+        @decoder = if gemma3_architecture?(config)
           Gemma3Decoder.new(config)
         else
           LlamaDecoder.new(config)
@@ -134,7 +134,10 @@
         if do_sample
           # Top-k filtering
           if top_k > 0
-
+            # Torch.topk returns [values, indices] array in torch-rb
+            topk_values, _topk_indices = Torch.topk(next_token_logits, top_k)
+            threshold = topk_values[0.., -1, nil]
+            indices_to_remove = Torch.lt(next_token_logits, threshold)
             next_token_logits = next_token_logits.masked_fill(indices_to_remove, -Float::INFINITY)
           end

@@ -144,7 +147,7 @@
           cumulative_probs = Torch.cumsum(Torch::NN::Functional.softmax(sorted_logits, dim: -1), dim: -1)

           # Remove tokens with cumulative probability above threshold
-          sorted_indices_to_remove = cumulative_probs
+          sorted_indices_to_remove = Torch.gt(cumulative_probs, top_p)
           sorted_indices_to_remove[0.., 1..] = sorted_indices_to_remove[0.., 0...-1].clone
           sorted_indices_to_remove[0.., 0] = false

@@ -182,7 +185,12 @@
         weights_path = File.join(path, "model.safetensors")
         Safetensors::Torch.save_file(state_dict, weights_path)

-
+        # Preserve architecture info for proper decoder selection on load
+        save_config = @config.to_h.dup
+        # Don't overwrite model_type if it contains architecture info
+        save_config["model_type"] ||= "causal_lm"
+        # Mark which decoder type this model uses
+        save_config["_decoder_type"] = @decoder.class.name.split("::").last

         config_path = File.join(path, "config.json")
         File.write(config_path, JSON.pretty_generate(save_config))
@@ -190,10 +198,26 @@

       private

+      # Detect if this is a Gemma 3 model based on config
+      def gemma3_architecture?(config)
+        # Check model_type first
+        return true if config.model_type&.include?("gemma3")
+
+        # Check architectures array (HuggingFace format)
+        architectures = config.config["architectures"] || []
+        return true if architectures.any? { |a| a.downcase.include?("gemma3") }
+
+        # Check saved decoder type
+        return true if config.config["_decoder_type"] == "Gemma3Decoder"
+
+        false
+      end
+
       def self.load_pretrained_weights(model, weights_path)
         # Load and copy weights one at a time to minimize memory usage
         model_state = model.state_dict
         model_keys = model_state.keys
+        loaded_lm_head = false

         Torch.no_grad do
           Safetensors::Torch.load_file(weights_path).each do |name, tensor|
@@ -203,8 +227,16 @@
             # Convert dtype if needed
             tensor = tensor.to(target.dtype) if tensor.dtype != target.dtype
             target.copy!(tensor)
+            loaded_lm_head = true if mapped_name == "lm_head.weight"
           end
         end
+
+        # If lm_head wasn't in the weights file, tie it to embeddings
+        unless loaded_lm_head
+          embed_weight = model_state["decoder.embed_tokens.weight"]
+          lm_head_weight = model_state["lm_head.weight"]
+          lm_head_weight.copy!(embed_weight)
+        end
       end

       # Force garbage collection to free loaded tensors
@@ -219,6 +251,7 @@

         model_state = model.state_dict
         model_keys = model_state.keys
+        loaded_lm_head = false

         # Load each shard and copy weights immediately to minimize memory
         Torch.no_grad do
@@ -229,11 +262,19 @@
             target = model_state[mapped_name]
             tensor = tensor.to(target.dtype) if tensor.dtype != target.dtype
             target.copy!(tensor)
+            loaded_lm_head = true if mapped_name == "lm_head.weight"
           end
         end
         # GC after each shard to free memory
         GC.start
       end
+
+      # If lm_head wasn't in the weights file, tie it to embeddings
+      unless loaded_lm_head
+        embed_weight = model_state["decoder.embed_tokens.weight"]
+        lm_head_weight = model_state["lm_head.weight"]
+        lm_head_weight.copy!(embed_weight)
+      end
     end
   end

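The top-k filtering change above can be sanity-checked in isolation on a toy logits tensor; the values are arbitrary, and the snippet leans only on the torch-rb behaviour noted in the diff (Torch.topk returning a [values, indices] pair).

    require "torch"

    logits = Torch.tensor([[2.0, 0.5, 1.5, -1.0]])
    top_k  = 2

    values, _indices = Torch.topk(logits, top_k)
    threshold = values[0.., -1, nil]  # smallest kept logit per row, shape (batch, 1)
    filtered  = logits.masked_fill(Torch.lt(logits, threshold), -Float::INFINITY)
    # => only the two largest logits survive; the rest become -Infinity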
data/lib/fine/models/gemma3_decoder.rb
CHANGED

@@ -60,12 +60,12 @@
         position_ids ||= Torch.arange(seq_length, device: input_ids.device)
         position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

-        # Create causal mask
-        causal_mask = create_causal_mask(seq_length, hidden_states.device)
+        # Create causal mask (must match dtype of hidden_states)
+        causal_mask = create_causal_mask(seq_length, hidden_states.device, hidden_states.dtype)

         # Combine with attention mask if provided
         if attention_mask
-          expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+          expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(hidden_states.dtype)
           expanded_mask = expanded_mask.expand(-1, -1, seq_length, -1)
           causal_mask = causal_mask + (1.0 - expanded_mask) * -1e9
         end
@@ -87,9 +87,9 @@

       private

-      def create_causal_mask(seq_length, device)
+      def create_causal_mask(seq_length, device, dtype)
         mask = Torch.triu(
-          Torch.ones(seq_length, seq_length, device: device) * -1e9,
+          Torch.ones(seq_length, seq_length, device: device, dtype: dtype) * -1e9,
           diagonal: 1
         )
         mask.unsqueeze(0).unsqueeze(0)
@@ -112,7 +112,7 @@
           rms_norm_eps: rms_norm_eps
         )

-        @mlp =
+        @mlp = Gemma3MLP.new(
           hidden_size: hidden_size,
           intermediate_size: intermediate_size
         )
@@ -240,5 +240,24 @@
         x.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
       end
     end
+
+    # Gemma 3 MLP with GELU activation (not SiLU like Llama)
+    class Gemma3MLP < Torch::NN::Module
+      def initialize(hidden_size:, intermediate_size:)
+        super()
+
+        @gate_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+        @up_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+        @down_proj = Torch::NN::Linear.new(intermediate_size, hidden_size, bias: false)
+      end
+
+      def forward(x)
+        # GeGLU: gelu(gate) * up
+        # Using GELU with tanh approximation as per Gemma config
+        gate = Torch::NN::Functional.gelu(@gate_proj.call(x), approximate: "tanh")
+        up = @up_proj.call(x)
+        @down_proj.call(gate * up)
+      end
+    end
   end
 end
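The new Gemma3MLP implements the GeGLU pattern; here is a standalone sketch of the same computation with throwaway sizes (8 and 16 are arbitrary, chosen only for illustration).

    require "torch"

    hidden_size, intermediate_size = 8, 16
    gate_proj = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
    up_proj   = Torch::NN::Linear.new(hidden_size, intermediate_size, bias: false)
    down_proj = Torch::NN::Linear.new(intermediate_size, hidden_size, bias: false)

    x    = Torch.randn(1, 4, hidden_size)
    gate = Torch::NN::Functional.gelu(gate_proj.call(x), approximate: "tanh")
    out  = down_proj.call(gate * up_proj.call(x))  # gelu(gate) * up, projected back down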
data/lib/fine/models/llama_decoder.rb
CHANGED

@@ -52,13 +52,13 @@
         position_ids ||= Torch.arange(seq_length, device: input_ids.device)
         position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

-        # Create causal mask
-        causal_mask = create_causal_mask(seq_length, hidden_states.device)
+        # Create causal mask (must match dtype of hidden_states)
+        causal_mask = create_causal_mask(seq_length, hidden_states.device, hidden_states.dtype)

         # Combine with attention mask if provided
         if attention_mask
           # Expand attention mask: (batch, seq) -> (batch, 1, seq, seq)
-          expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+          expanded_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(hidden_states.dtype)
           expanded_mask = expanded_mask.expand(-1, -1, seq_length, -1)
           causal_mask = causal_mask + (1.0 - expanded_mask) * -1e9
         end
@@ -80,10 +80,10 @@

       private

-      def create_causal_mask(seq_length, device)
+      def create_causal_mask(seq_length, device, dtype)
         # Lower triangular mask for causal attention
         mask = Torch.triu(
-          Torch.ones(seq_length, seq_length, device: device) * -1e9,
+          Torch.ones(seq_length, seq_length, device: device, dtype: dtype) * -1e9,
           diagonal: 1
         )
         mask.unsqueeze(0).unsqueeze(0)
@@ -235,10 +235,11 @@
         seq_len = position_ids.max.item + 1
         build_cache(seq_len) if seq_len > @cos_cached.size(0)

-        # Move cached tensors to position_ids device
+        # Move cached tensors to position_ids device and match dtype of input
         device = position_ids.device
-
-
+        dtype = x.dtype
+        cos_cached = @cos_cached.to(device).to(dtype)
+        sin_cached = @sin_cached.to(device).to(dtype)

         cos = cos_cached[position_ids].unsqueeze(1)
         sin = sin_cached[position_ids].unsqueeze(1)
data/lib/fine/models/sentence_transformer.rb
CHANGED

@@ -43,7 +43,7 @@
         model = new(config, pooling_mode: pooling_mode)

         weights_path = File.join(path, "model.safetensors")
-        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false, skip_mapping: true)

         model
       end
data/lib/fine/tokenizers/auto_tokenizer.rb
CHANGED

@@ -135,6 +135,21 @@
       @tokenizer.decode(token_ids, skip_special_tokens: skip_special_tokens)
     end

+    # Encode without padding (for generation)
+    # Returns only the actual tokens, no padding
+    #
+    # @param text [String] Text to tokenize
+    # @return [Array<Integer>] Token IDs
+    def encode_for_generation(text)
+      # Temporarily disable padding
+      @tokenizer.no_padding
+      encoding = @tokenizer.encode(text)
+      ids = encoding.ids
+      # Re-enable padding
+      @tokenizer.enable_padding(length: @max_length)
+      ids
+    end
+
     # Get vocabulary size
     def vocab_size
       @tokenizer.vocab_size
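encode_for_generation toggles padding on the wrapped tokenizers object. The same pattern can be reproduced directly against the tokenizers gem; the checkpoint name and padding length below are placeholders, and the loader call is an assumption about that gem's API.

    require "tokenizers"

    tok = Tokenizers.from_pretrained("bert-base-uncased")  # placeholder checkpoint

    tok.no_padding                          # drop padding so generation sees only real tokens
    ids = tok.encode("Hello from Fine").ids
    tok.enable_padding(length: 128)         # restore fixed-length padding afterwards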
data/lib/fine/training/text_trainer.rb
CHANGED

@@ -228,10 +228,12 @@
         optimizer.zero_grad

         # For pair datasets, we get anchor and positive texts
-
+        # Use forward() directly during training (not encode() which uses no_grad)
+        output = @model.forward(
           batch[:input_ids],
           attention_mask: batch[:attention_mask]
         )
+        embeddings = output[:embeddings]

         # Multiple Negatives Ranking Loss
         # Treat other samples in batch as negatives