fine 0.1.0
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0
--- /dev/null
+++ b/data/lib/fine/models/bert_for_sequence_classification.rb
@@ -0,0 +1,226 @@
+# frozen_string_literal: true
+
+module Fine
+  module Models
+    # BERT model with classification head for sequence classification
+    class BertForSequenceClassification < Base
+      attr_reader :encoder, :classifier, :num_labels
+
+      # Load a pretrained model from Hugging Face Hub
+      #
+      # @param model_id [String] HuggingFace model ID
+      # @param num_labels [Integer] Number of classification labels
+      # @param dropout [Float] Dropout rate for classifier
+      # @return [BertForSequenceClassification]
+      def self.from_pretrained(model_id, num_labels:, dropout: 0.1)
+        # Download model files
+        downloader = Hub::ModelDownloader.new(model_id)
+        model_path = downloader.download
+
+        # Load config
+        config = Hub::ConfigLoader.from_pretrained(model_path)
+
+        # Create model
+        model = new(config, num_labels: num_labels, dropout: dropout)
+
+        # Load pretrained weights into encoder
+        weights_path = downloader.file_path("model.safetensors")
+
+        if File.exist?(weights_path)
+          load_result = load_pretrained_weights(model, weights_path)
+
+          if load_result[:missing_keys].any?
+            # Only warn about unexpected missing keys (classifier is expected to be missing)
+            encoder_missing = load_result[:missing_keys].reject { |k| k.include?("classifier") }
+            if encoder_missing.any?
+              warn "Missing encoder keys: #{encoder_missing.first(5).join(', ')}..."
+            end
+          end
+        end
+
+        model
+      end
+
+      # Load from saved fine-tuned model
+      #
+      # @param path [String] Path to saved model directory
+      # @return [BertForSequenceClassification]
+      def self.load(path)
+        raise ModelNotFoundError.new(path) unless File.directory?(path)
+
+        config = Hub::ConfigLoader.from_pretrained(path)
+        num_labels = config.config["num_labels"] || config.config["id2label"]&.size
+
+        raise ConfigurationError, "Cannot determine num_labels from saved model" unless num_labels
+
+        model = new(config, num_labels: num_labels)
+
+        weights_path = File.join(path, "model.safetensors")
+        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+
+        model
+      end
+
+      def initialize(config, num_labels:, dropout: 0.1)
+        super(config)
+
+        @num_labels = num_labels
+
+        # Encoder
+        @encoder = BertEncoder.new(config)
+
+        # Classification head
+        @dropout = Torch::NN::Dropout.new(p: dropout)
+        @classifier = Torch::NN::Linear.new(config.hidden_size, num_labels)
+      end
+
+      def forward(input_ids, attention_mask: nil, token_type_ids: nil, labels: nil)
+        # Get encoder outputs
+        encoder_output = @encoder.call(
+          input_ids,
+          attention_mask: attention_mask,
+          token_type_ids: token_type_ids
+        )
+
+        # Use pooled output for classification
+        pooled_output = encoder_output[:pooler_output]
+        pooled_output = @dropout.call(pooled_output)
+
+        # Classification logits
+        logits = @classifier.call(pooled_output)
+
+        # Compute loss if labels provided
+        if labels
+          loss = Torch::NN::Functional.cross_entropy(logits, labels)
+          { loss: loss, logits: logits }
+        else
+          { logits: logits }
+        end
+      end
+
+      # Predict class for input
+      #
+      # @param input_ids [Torch::Tensor] Input token IDs
+      # @param attention_mask [Torch::Tensor] Attention mask
+      # @return [Torch::Tensor] Predicted class indices
+      def predict(input_ids, attention_mask: nil, token_type_ids: nil)
+        eval
+        Torch.no_grad do
+          output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
+          output[:logits].argmax(dim: 1)
+        end
+      end
+
+      # Get class probabilities
+      #
+      # @return [Torch::Tensor] Class probabilities
+      def predict_proba(input_ids, attention_mask: nil, token_type_ids: nil)
+        eval
+        Torch.no_grad do
+          output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
+          Torch::NN::Functional.softmax(output[:logits], dim: 1)
+        end
+      end
+
+      # Save the model
+      #
+      # @param path [String] Directory path
+      # @param label_map [Hash, nil] Label mapping
+      def save(path, label_map: nil)
+        FileUtils.mkdir_p(path)
+
+        # Save weights
+        weights_path = File.join(path, "model.safetensors")
+        Safetensors::Torch.save_file(state_dict, weights_path)
+
+        # Build config
+        save_config = @config.to_h.merge(
+          "num_labels" => @num_labels,
+          "model_type" => "bert_classification"
+        )
+
+        if label_map
+          save_config["id2label"] = label_map.invert.sort.to_h
+          save_config["label2id"] = label_map
+        end
+
+        config_path = File.join(path, "config.json")
+        File.write(config_path, JSON.pretty_generate(save_config))
+      end
+
+      private
+
+      def self.load_pretrained_weights(model, weights_path)
+        tensors = Safetensors::Torch.load_file(weights_path)
+
+        # Map HuggingFace weight names to our model structure
+        mapped = {}
+        missing_keys = []
+        unexpected_keys = []
+
+        model_keys = model.state_dict.keys
+
+        tensors.each do |name, tensor|
+          mapped_name = map_bert_weight_name(name)
+
+          if model_keys.include?(mapped_name)
+            mapped[mapped_name] = tensor
+          else
+            unexpected_keys << name
+          end
+        end
+
+        missing_keys = model_keys - mapped.keys
+
+        # Use no_grad and copy! since torch.rb doesn't support strict: false
+        Torch.no_grad do
+          state_dict = model.state_dict
+          mapped.each do |name, tensor|
+            state_dict[name].copy!(tensor) if state_dict.key?(name)
+          end
+        end
+
+        { missing_keys: missing_keys, unexpected_keys: unexpected_keys }
+      end
+
+      def self.map_bert_weight_name(hf_name)
+        name = hf_name.dup
+
+        # Embeddings
+        name = name.sub("bert.embeddings.word_embeddings", "encoder.word_embeddings")
+        name = name.sub("bert.embeddings.position_embeddings", "encoder.position_embeddings")
+        name = name.sub("bert.embeddings.token_type_embeddings", "encoder.token_type_embeddings")
+        name = name.sub("bert.embeddings.LayerNorm", "encoder.embeddings_layer_norm")
+        name = name.sub("embeddings.word_embeddings", "encoder.word_embeddings")
+        name = name.sub("embeddings.position_embeddings", "encoder.position_embeddings")
+        name = name.sub("embeddings.token_type_embeddings", "encoder.token_type_embeddings")
+        name = name.sub("embeddings.LayerNorm", "encoder.embeddings_layer_norm")
+
+        # Encoder layers
+        name = name.gsub("bert.encoder.layer", "encoder.layers")
+        name = name.gsub("encoder.layer", "encoder.layers")
+
+        # Attention
+        name = name.gsub(".attention.self.query", ".attention.query")
+        name = name.gsub(".attention.self.key", ".attention.key")
+        name = name.gsub(".attention.self.value", ".attention.value")
+        name = name.gsub(".attention.output.dense", ".attention.out")
+        name = name.gsub(".attention.output.LayerNorm", ".attention_layer_norm")
+
+        # FFN
+        name = name.gsub(".intermediate.dense", ".intermediate")
+        name = name.gsub(".output.dense", ".output")
+        name = name.gsub(".output.LayerNorm", ".output_layer_norm")
+
+        # Pooler
+        name = name.sub("bert.pooler.dense", "encoder.pooler_dense")
+        name = name.sub("pooler.dense", "encoder.pooler_dense")
+
+        # Classifier
+        name = name.sub("classifier", "classifier")
+
+        name
+      end
+    end
+  end
+end
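A minimal usage sketch for this class, added for orientation and not part of the package diff. The Fine::Tokenizers::AutoTokenizer calls are assumptions inferred from data/lib/fine/tokenizers/auto_tokenizer.rb in the listing above; the actual method names and return shape may differ.

# Hedged sketch -- not from the package; tokenizer interface is hypothetical.
require "fine"

# Pretrained BERT encoder plus a freshly initialized 2-label head
model = Fine::Models::BertForSequenceClassification.from_pretrained(
  "bert-base-uncased", num_labels: 2
)

# Hypothetical tokenizer usage -- assumed to return a hash of tensors
tokenizer = Fine::Tokenizers::AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer.encode("Ruby makes fine-tuning pleasant.")

pred_ids = model.predict(encoded[:input_ids], attention_mask: encoded[:attention_mask])
probs = model.predict_proba(encoded[:input_ids], attention_mask: encoded[:attention_mask])

# Persist with a label map; save writes model.safetensors and config.json
model.save("tmp/bert-sentiment", label_map: { "negative" => 0, "positive" => 1 })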
--- /dev/null
+++ b/data/lib/fine/models/causal_lm.rb
@@ -0,0 +1,279 @@
+# frozen_string_literal: true
+
+module Fine
+  module Models
+    # Causal Language Model for text generation
+    #
+    # Wraps a decoder model with a language modeling head for next-token prediction.
+    class CausalLM < Base
+      attr_reader :decoder, :lm_head
+
+      # Load from HuggingFace Hub
+      #
+      # @param model_id [String] Model ID (e.g., "google/gemma-2b", "meta-llama/Llama-3.2-1B")
+      # @param dtype [Symbol] Data type (:float32, :float16, :bfloat16, or :auto)
+      # @return [CausalLM]
+      def self.from_pretrained(model_id, dtype: :auto)
+        downloader = Hub::ModelDownloader.new(model_id)
+        model_path = downloader.download
+
+        config = Hub::ConfigLoader.from_pretrained(model_path)
+
+        # Determine dtype from config if auto
+        if dtype == :auto
+          config_dtype = config.config["torch_dtype"]
+          dtype = case config_dtype
+                  when "bfloat16" then :bfloat16
+                  when "float16" then :float16
+                  else :float32
+                  end
+        end
+
+        model = new(config)
+
+        # Convert model to target dtype before loading weights (saves memory)
+        model.to(dtype) if dtype != :float32
+
+        # Load weights
+        weights_path = downloader.file_path("model.safetensors")
+        if File.exist?(weights_path)
+          load_pretrained_weights(model, weights_path)
+        else
+          # Try sharded weights
+          load_sharded_weights(model, model_path)
+        end
+
+        model
+      end
+
+      # Load from saved model
+      def self.load(path)
+        raise ModelNotFoundError.new(path) unless File.directory?(path)
+
+        config = Hub::ConfigLoader.from_pretrained(path)
+        model = new(config)
+
+        weights_path = File.join(path, "model.safetensors")
+        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+
+        model
+      end
+
+      def initialize(config)
+        super(config)
+
+        # Use appropriate decoder based on model type
+        @decoder = if config.model_type&.include?("gemma3")
+                     Gemma3Decoder.new(config)
+                   else
+                     LlamaDecoder.new(config)
+                   end
+
+        # LM head (often tied to embeddings)
+        @lm_head = Torch::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
+
+        # Optionally tie weights with embeddings
+        @tie_word_embeddings = config.config["tie_word_embeddings"] != false
+      end
+
+      def forward(input_ids, attention_mask: nil, labels: nil, return_logits: nil)
+        # Default: return logits only if no labels (inference mode)
+        return_logits = labels.nil? if return_logits.nil?
+
+        # Get decoder outputs
+        outputs = @decoder.call(input_ids, attention_mask: attention_mask)
+        hidden_states = outputs[:last_hidden_state]
+
+        # LM head
+        logits = @lm_head.call(hidden_states)
+
+        # Compute loss if labels provided
+        if labels
+          # Shift for next-token prediction
+          shift_logits = logits[0.., 0...-1, 0..].contiguous
+          shift_labels = labels[0.., 1..].contiguous
+
+          # Compute cross entropy loss
+          vocab_size = logits.size(-1)
+          loss = Torch::NN::Functional.cross_entropy(
+            shift_logits.view(-1, vocab_size),
+            shift_labels.view(-1),
+            ignore_index: -100
+          )
+
+          # Don't return logits during training to save memory
+          return_logits ? { loss: loss, logits: logits } : { loss: loss }
+        else
+          { logits: logits }
+        end
+      end
+
+      # Generate text autoregressively
+      #
+      # @param input_ids [Torch::Tensor] Input token IDs
+      # @param max_new_tokens [Integer] Maximum tokens to generate
+      # @param temperature [Float] Sampling temperature
+      # @param top_p [Float] Nucleus sampling threshold
+      # @param top_k [Integer] Top-k sampling
+      # @param do_sample [Boolean] Whether to sample or use greedy decoding
+      # @return [Torch::Tensor] Generated token IDs
+      def generate(input_ids, max_new_tokens: 100, temperature: 1.0, top_p: 0.9,
+                   top_k: 50, do_sample: true, eos_token_id: nil, pad_token_id: nil)
+        eval
+        generated = input_ids.clone
+
+        Torch.no_grad do
+          max_new_tokens.times do
+            # Forward pass
+            outputs = forward(generated)
+            next_token_logits = outputs[:logits][0.., -1, 0..]
+
+            # Apply temperature
+            next_token_logits = next_token_logits / temperature if temperature != 1.0
+
+            if do_sample
+              # Top-k filtering
+              if top_k > 0
+                indices_to_remove = next_token_logits < Torch.topk(next_token_logits, top_k).values[0.., -1, nil]
+                next_token_logits = next_token_logits.masked_fill(indices_to_remove, -Float::INFINITY)
+              end
+
+              # Top-p (nucleus) filtering
+              if top_p < 1.0
+                sorted_logits, sorted_indices = Torch.sort(next_token_logits, descending: true)
+                cumulative_probs = Torch.cumsum(Torch::NN::Functional.softmax(sorted_logits, dim: -1), dim: -1)
+
+                # Remove tokens with cumulative probability above threshold
+                sorted_indices_to_remove = cumulative_probs > top_p
+                sorted_indices_to_remove[0.., 1..] = sorted_indices_to_remove[0.., 0...-1].clone
+                sorted_indices_to_remove[0.., 0] = false
+
+                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                next_token_logits = next_token_logits.masked_fill(indices_to_remove, -Float::INFINITY)
+              end
+
+              # Sample
+              probs = Torch::NN::Functional.softmax(next_token_logits, dim: -1)
+              next_token = Torch.multinomial(probs, num_samples: 1)
+            else
+              # Greedy
+              next_token = next_token_logits.argmax(dim: -1, keepdim: true)
+            end
+
+            # Append to generated
+            generated = Torch.cat([generated, next_token], dim: 1)
+
+            # Check for EOS
+            if eos_token_id
+              # Handle both single and array EOS token IDs
+              eos_ids = eos_token_id.is_a?(Array) ? eos_token_id : [eos_token_id]
+              next_token_val = next_token[0, 0].item rescue next_token.to(:int64)[0, 0].item
+              break if eos_ids.include?(next_token_val)
+            end
+          end
+        end
+
+        generated
+      end
+
+      def save(path)
+        FileUtils.mkdir_p(path)
+
+        weights_path = File.join(path, "model.safetensors")
+        Safetensors::Torch.save_file(state_dict, weights_path)
+
+        save_config = @config.to_h.merge("model_type" => "causal_lm")
+
+        config_path = File.join(path, "config.json")
+        File.write(config_path, JSON.pretty_generate(save_config))
+      end
+
+      private
+
+      def self.load_pretrained_weights(model, weights_path)
+        # Load and copy weights one at a time to minimize memory usage
+        model_state = model.state_dict
+        model_keys = model_state.keys
+
+        Torch.no_grad do
+          Safetensors::Torch.load_file(weights_path).each do |name, tensor|
+            mapped_name = map_llama_weight_name(name)
+            if model_keys.include?(mapped_name)
+              target = model_state[mapped_name]
+              # Convert dtype if needed
+              tensor = tensor.to(target.dtype) if tensor.dtype != target.dtype
+              target.copy!(tensor)
+            end
+          end
+        end
+
+        # Force garbage collection to free loaded tensors
+        GC.start
+      end
+
+      def self.load_sharded_weights(model, model_path)
+        # Find all safetensors shards
+        shards = Dir.glob(File.join(model_path, "model-*.safetensors")).sort
+
+        return if shards.empty?
+
+        model_state = model.state_dict
+        model_keys = model_state.keys
+
+        # Load each shard and copy weights immediately to minimize memory
+        Torch.no_grad do
+          shards.each do |shard_path|
+            Safetensors::Torch.load_file(shard_path).each do |name, tensor|
+              mapped_name = map_llama_weight_name(name)
+              if model_keys.include?(mapped_name)
+                target = model_state[mapped_name]
+                tensor = tensor.to(target.dtype) if tensor.dtype != target.dtype
+                target.copy!(tensor)
+              end
+            end
+            # GC after each shard to free memory
+            GC.start
+          end
+        end
+      end
+
+      def self.map_llama_weight_name(hf_name)
+        name = hf_name.dup
+
+        # Embeddings
+        name = name.sub("model.embed_tokens", "decoder.embed_tokens")
+        name = name.sub("lm_head", "lm_head")
+
+        # Layers
+        name = name.gsub("model.layers", "decoder.layers")
+
+        # Attention (works for both Llama and Gemma)
+        name = name.gsub(".self_attn.q_proj", ".self_attn.q_proj")
+        name = name.gsub(".self_attn.k_proj", ".self_attn.k_proj")
+        name = name.gsub(".self_attn.v_proj", ".self_attn.v_proj")
+        name = name.gsub(".self_attn.o_proj", ".self_attn.o_proj")
+
+        # Gemma 3 QK normalization
+        name = name.gsub(".self_attn.q_norm", ".self_attn.q_norm")
+        name = name.gsub(".self_attn.k_norm", ".self_attn.k_norm")
+
+        # MLP
+        name = name.gsub(".mlp.gate_proj", ".mlp.gate_proj")
+        name = name.gsub(".mlp.up_proj", ".mlp.up_proj")
+        name = name.gsub(".mlp.down_proj", ".mlp.down_proj")
+
+        # Norms (standard)
+        name = name.gsub(".input_layernorm", ".input_layernorm")
+        name = name.gsub(".post_attention_layernorm", ".post_attention_layernorm")
+
+        # Gemma 3 additional norms
+        name = name.gsub(".pre_feedforward_layernorm", ".pre_feedforward_layernorm")
+        name = name.gsub(".post_feedforward_layernorm", ".post_feedforward_layernorm")
+
+        name = name.sub("model.norm", "decoder.norm")
+
+        name
+      end
+    end
+  end
+end
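For orientation, a generation sketch (not part of the diff). The token IDs and eos_token_id below are placeholders; in practice both come from a tokenizer, and the model ID is only the example given in the doc comment. Note that generate re-runs the full forward pass over the whole sequence on every step (there is no KV cache), so generation cost grows quadratically with output length.

# Hedged sketch: placeholder token IDs; illustrative model ID only.
model = Fine::Models::CausalLM.from_pretrained("meta-llama/Llama-3.2-1B", dtype: :bfloat16)

prompt_ids = Torch.tensor([[1, 15043, 3186]]) # placeholder IDs, shape [1, seq_len]

output_ids = model.generate(
  prompt_ids,
  max_new_tokens: 32,
  temperature: 0.7,  # < 1.0 sharpens the distribution
  top_k: 50,         # keep only the 50 most likely tokens
  top_p: 0.9,        # then the smallest set covering 90% of probability mass
  do_sample: true,
  eos_token_id: 2    # placeholder; the real ID comes from the tokenizer config
)

new_token_ids = output_ids[0.., prompt_ids.size(1)..] # generated continuation only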
--- /dev/null
+++ b/data/lib/fine/models/classification_head.rb
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module Fine
+  module Models
+    # Simple classification head (linear layer)
+    class ClassificationHead < Torch::NN::Module
+      attr_reader :in_features, :num_classes
+
+      def initialize(in_features, num_classes, dropout: 0.0)
+        super()
+        @in_features = in_features
+        @num_classes = num_classes
+
+        @dropout = Torch::NN::Dropout.new(p: dropout) if dropout.positive?
+        @classifier = Torch::NN::Linear.new(in_features, num_classes)
+      end
+
+      def forward(x)
+        x = @dropout.call(x) if @dropout
+        @classifier.call(x)
+      end
+    end
+  end
+end
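A short sketch of how this head composes with pooled features (not from the diff; shapes are illustrative):

# Hedged sketch: 768-dim pooled features mapped to 5 classes
head = Fine::Models::ClassificationHead.new(768, 5, dropout: 0.1)

features = Torch.randn(4, 768) # batch of 4 pooled embeddings
logits = head.call(features)   # Torch::Tensor of shape [4, 5]
# Dropout is applied only while the module is in training mode;
# call head.eval before inference to disable it.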