fine 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0
data/lib/fine/models/sentence_transformer.rb
@@ -0,0 +1,202 @@
# frozen_string_literal: true

module Fine
  module Models
    # Sentence Transformer for generating text embeddings
    #
    # Produces dense vector representations of sentences for semantic similarity,
    # clustering, and retrieval tasks.
    class SentenceTransformer < Base
      attr_reader :encoder, :pooling_mode

      POOLING_MODES = %i[cls mean max].freeze

      # Load a pretrained sentence transformer
      #
      # @param model_id [String] HuggingFace model ID
      # @param pooling_mode [Symbol] Pooling strategy (:cls, :mean, :max)
      # @return [SentenceTransformer]
      def self.from_pretrained(model_id, pooling_mode: :mean)
        downloader = Hub::ModelDownloader.new(model_id)
        model_path = downloader.download

        config = Hub::ConfigLoader.from_pretrained(model_path)

        model = new(config, pooling_mode: pooling_mode)

        # Load weights
        weights_path = downloader.file_path("model.safetensors")
        if File.exist?(weights_path)
          load_pretrained_weights(model, weights_path)
        end

        model
      end

      # Load from saved model
      def self.load(path)
        raise ModelNotFoundError.new(path) unless File.directory?(path)

        config = Hub::ConfigLoader.from_pretrained(path)
        pooling_mode = (config.config["pooling_mode"] || "mean").to_sym

        model = new(config, pooling_mode: pooling_mode)

        weights_path = File.join(path, "model.safetensors")
        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)

        model
      end

      def initialize(config, pooling_mode: :mean)
        super(config)

        raise ArgumentError, "Invalid pooling mode: #{pooling_mode}" unless POOLING_MODES.include?(pooling_mode)

        @pooling_mode = pooling_mode
        @encoder = BertEncoder.new(config)

        # Optional: normalize embeddings
        @normalize = true
      end

      def forward(input_ids, attention_mask: nil, token_type_ids: nil)
        encoder_output = @encoder.call(
          input_ids,
          attention_mask: attention_mask,
          token_type_ids: token_type_ids
        )

        # Pool based on strategy
        embeddings = case @pooling_mode
                     when :cls
                       encoder_output[:pooler_output]
                     when :mean
                       mean_pooling(encoder_output[:last_hidden_state], attention_mask)
                     when :max
                       max_pooling(encoder_output[:last_hidden_state], attention_mask)
                     end

        # L2 normalize
        embeddings = Torch::NN::Functional.normalize(embeddings, p: 2, dim: 1) if @normalize

        { embeddings: embeddings }
      end

      # Encode texts to embeddings
      #
      # @param input_ids [Torch::Tensor] Token IDs
      # @param attention_mask [Torch::Tensor] Attention mask
      # @return [Torch::Tensor] Embeddings
      def encode(input_ids, attention_mask: nil, token_type_ids: nil)
        eval
        Torch.no_grad do
          output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
          output[:embeddings]
        end
      end

      # Compute similarity between two sets of embeddings
      #
      # @param embeddings_a [Torch::Tensor] First embeddings
      # @param embeddings_b [Torch::Tensor] Second embeddings
      # @return [Torch::Tensor] Similarity scores
      def similarity(embeddings_a, embeddings_b)
        # Cosine similarity (embeddings should be normalized)
        Torch.matmul(embeddings_a, embeddings_b.transpose(0, 1))
      end

      def save(path)
        FileUtils.mkdir_p(path)

        weights_path = File.join(path, "model.safetensors")
        Safetensors::Torch.save_file(state_dict, weights_path)

        save_config = @config.to_h.merge(
          "model_type" => "sentence_transformer",
          "pooling_mode" => @pooling_mode.to_s
        )

        config_path = File.join(path, "config.json")
        File.write(config_path, JSON.pretty_generate(save_config))
      end

      private

      def mean_pooling(hidden_states, attention_mask)
        if attention_mask
          mask = attention_mask.unsqueeze(-1).expand_as(hidden_states).float
          sum_embeddings = (hidden_states * mask).sum(dim: 1)
          sum_mask = mask.sum(dim: 1).clamp(min: 1e-9)
          sum_embeddings / sum_mask
        else
          hidden_states.mean(dim: 1)
        end
      end

      def max_pooling(hidden_states, attention_mask)
        if attention_mask
          mask = attention_mask.unsqueeze(-1).expand_as(hidden_states)
          hidden_states = hidden_states.masked_fill(mask == 0, -1e9)
        end
        hidden_states.max(dim: 1).values
      end

      def self.load_pretrained_weights(model, weights_path)
        tensors = Safetensors::Torch.load_file(weights_path)

        mapped = {}
        model_keys = model.state_dict.keys

        tensors.each do |name, tensor|
          # Try direct mapping first
          mapped_name = map_sentence_transformer_weight_name(name)

          if model_keys.include?(mapped_name)
            mapped[mapped_name] = tensor
          end
        end

        # Use no_grad and copy! since torch.rb doesn't support strict: false
        Torch.no_grad do
          state_dict = model.state_dict
          mapped.each do |name, tensor|
            state_dict[name].copy!(tensor) if state_dict.key?(name)
          end
        end
      end

      def self.map_sentence_transformer_weight_name(hf_name)
        name = hf_name.dup

        # sentence-transformers uses "0." prefix for the encoder module
        name = name.sub(/^0\./, "encoder.")

        # Then apply BERT mappings
        name = name.sub("auto_model.", "encoder.")
        name = name.sub("bert.embeddings.word_embeddings", "encoder.word_embeddings")
        name = name.sub("bert.embeddings.position_embeddings", "encoder.position_embeddings")
        name = name.sub("bert.embeddings.token_type_embeddings", "encoder.token_type_embeddings")
        name = name.sub("bert.embeddings.LayerNorm", "encoder.embeddings_layer_norm")
        name = name.sub("embeddings.word_embeddings", "encoder.word_embeddings")
        name = name.sub("embeddings.position_embeddings", "encoder.position_embeddings")
        name = name.sub("embeddings.token_type_embeddings", "encoder.token_type_embeddings")
        name = name.sub("embeddings.LayerNorm", "encoder.embeddings_layer_norm")
        name = name.gsub("bert.encoder.layer", "encoder.layers")
        name = name.gsub("encoder.layer", "encoder.layers")
        name = name.gsub(".attention.self.query", ".attention.query")
        name = name.gsub(".attention.self.key", ".attention.key")
        name = name.gsub(".attention.self.value", ".attention.value")
        name = name.gsub(".attention.output.dense", ".attention.out")
        name = name.gsub(".attention.output.LayerNorm", ".attention_layer_norm")
        name = name.gsub(".intermediate.dense", ".intermediate")
        name = name.gsub(".output.dense", ".output")
        name = name.gsub(".output.LayerNorm", ".output_layer_norm")
        name = name.sub("bert.pooler.dense", "encoder.pooler_dense")
        name = name.sub("pooler.dense", "encoder.pooler_dense")

        name
      end
    end
  end
end
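The encode/similarity pair above is the whole inference path: encode, pool, L2-normalize, then a single matrix product yields cosine similarities. A minimal usage sketch, assuming the gem and torch-rb are installed; the model ID is illustrative, and the hard-coded token IDs stand in for real tokenizer output (in practice they would come from the gem's AutoTokenizer, whose API is not part of this hunk):

require "fine"

# Illustrative checkpoint; any BERT-style sentence-transformers model ID should work in principle.
model = Fine::Models::SentenceTransformer.from_pretrained(
  "sentence-transformers/all-MiniLM-L6-v2", pooling_mode: :mean
)

# Stand-in tokenizer output: (batch, seq_len) token IDs plus an all-ones attention mask.
input_ids = Torch.tensor([[101, 7592, 2088, 102], [101, 7592, 19081, 102]])
attention_mask = Torch.ones(2, 4, dtype: :int64)

embeddings = model.encode(input_ids, attention_mask: attention_mask)
scores = model.similarity(embeddings, embeddings) # (2, 2) cosine similarity matrix
puts scores.to_a.inspect

Because forward normalizes embeddings to unit length, the diagonal of scores is 1.0 and off-diagonal entries fall in [-1, 1].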
data/lib/fine/models/siglip2_for_image_classification.rb
@@ -0,0 +1,155 @@
# frozen_string_literal: true

module Fine
  module Models
    # SigLIP2 model with classification head for image classification
    class SigLIP2ForImageClassification < Base
      attr_reader :encoder, :classifier, :num_labels

      # Load a pretrained model from Hugging Face Hub
      #
      # @param model_id [String] Hugging Face model ID (e.g., "google/siglip2-base-patch16-224")
      # @param num_labels [Integer] Number of classification labels
      # @param freeze_encoder [Boolean] Whether to freeze the vision encoder
      # @param dropout [Float] Dropout rate for classifier
      # @return [SigLIP2ForImageClassification]
      def self.from_pretrained(model_id, num_labels:, freeze_encoder: false, dropout: 0.0)
        # Download model files
        downloader = Hub::ModelDownloader.new(model_id)
        model_path = downloader.download

        # Load config
        config = Hub::ConfigLoader.from_pretrained(model_path)

        # Create model
        model = new(config, num_labels: num_labels, dropout: dropout)

        # Load pretrained weights into encoder
        weights_path = downloader.file_path("model.safetensors")
        load_result = Hub::SafetensorsLoader.load_into_model(
          model.encoder,
          weights_path,
          strict: false,
          prefix: "vision_model"
        )

        # Log any issues
        if load_result[:missing_keys].any?
          warn "Missing keys in encoder: #{load_result[:missing_keys].first(5).join(', ')}..."
        end

        # Freeze encoder if requested
        model.encoder.freeze! if freeze_encoder

        model
      end

      # Load a fine-tuned model from disk
      #
      # @param path [String] Path to saved model directory
      # @return [SigLIP2ForImageClassification]
      def self.load(path)
        raise ModelNotFoundError.new(path, "Model directory not found") unless File.directory?(path)

        # Load config
        config = Hub::ConfigLoader.from_pretrained(path)
        num_labels = config.config["num_labels"] || config.config["id2label"]&.size

        raise ConfigurationError, "Cannot determine num_labels from saved model" unless num_labels

        # Create model
        model = new(config, num_labels: num_labels)

        # Load weights
        weights_path = File.join(path, "model.safetensors")
        Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)

        model
      end

      def initialize(config, num_labels:, dropout: 0.0)
        super(config)

        @num_labels = num_labels

        # Vision encoder
        @encoder = SigLIP2VisionEncoder.new(config)

        # Classification head
        @classifier = ClassificationHead.new(
          config.hidden_size,
          num_labels,
          dropout: dropout
        )
      end

      def forward(pixel_values, labels: nil)
        # Get image features from encoder
        features = @encoder.call(pixel_values)

        # Classification logits
        logits = @classifier.call(features)

        # Compute loss if labels provided
        if labels
          loss = Torch::NN::Functional.cross_entropy(logits, labels)
          { loss: loss, logits: logits }
        else
          { logits: logits }
        end
      end

      # Predict class for an image tensor
      #
      # @param pixel_values [Torch::Tensor] Image tensor (batch, C, H, W)
      # @return [Torch::Tensor] Predicted class indices
      def predict(pixel_values)
        eval
        Torch.no_grad do
          output = forward(pixel_values)
          output[:logits].argmax(dim: 1)
        end
      end

      # Get probabilities for each class
      #
      # @param pixel_values [Torch::Tensor] Image tensor (batch, C, H, W)
      # @return [Torch::Tensor] Class probabilities
      def predict_proba(pixel_values)
        eval
        Torch.no_grad do
          output = forward(pixel_values)
          Torch::NN::Functional.softmax(output[:logits], dim: 1)
        end
      end

      # Save the model to disk
      #
      # @param path [String] Directory path to save to
      # @param label_map [Hash, nil] Optional mapping of label names to IDs
      def save(path, label_map: nil)
        FileUtils.mkdir_p(path)

        # Save weights
        weights_path = File.join(path, "model.safetensors")
        Safetensors::Torch.save_file(state_dict, weights_path)

        # Build config with num_labels
        save_config = @config.to_h.merge(
          "num_labels" => @num_labels,
          "model_type" => "siglip2_image_classification"
        )

        # Add label mapping if provided
        if label_map
          save_config["id2label"] = label_map.invert.sort.to_h
          save_config["label2id"] = label_map
        end

        # Save config
        config_path = File.join(path, "config.json")
        File.write(config_path, JSON.pretty_generate(save_config))
      end
    end
  end
end
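A short end-to-end sketch of this class, assuming the gem and torch-rb are installed. The model ID matches the example in the docstring above; the random tensor is a stand-in for a preprocessed image batch (real inputs would come through the gem's transforms), and the save path and label map are illustrative:

require "fine"

model = Fine::Models::SigLIP2ForImageClassification.from_pretrained(
  "google/siglip2-base-patch16-224",
  num_labels: 3,
  freeze_encoder: true # train only the classification head
)

pixel_values = Torch.randn(1, 3, 224, 224) # stand-in for a normalized (batch, C, H, W) image
probs = model.predict_proba(pixel_values)  # (1, 3); each row sums to 1
pred  = model.predict(pixel_values)        # (1,) predicted class index
puts pred.to_a.inspect

# Saving with a label map writes id2label/label2id into config.json,
# which is how SigLIP2ForImageClassification.load recovers num_labels later.
model.save("tmp/siglip2-demo", label_map: { "bird" => 0, "cat" => 1, "dog" => 2 })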
data/lib/fine/models/siglip2_vision_encoder.rb
@@ -0,0 +1,190 @@
# frozen_string_literal: true

module Fine
  module Models
    # SigLIP2 Vision Transformer Encoder
    #
    # Implements the vision encoder portion of SigLIP2 for image feature extraction.
    # Architecture follows the standard ViT with patch embedding, transformer blocks,
    # and pooling.
    class SigLIP2VisionEncoder < Base
      def initialize(config)
        super(config)

        @hidden_size = config.hidden_size
        @num_layers = config.num_hidden_layers
        @num_heads = config.num_attention_heads
        @intermediate_size = config.intermediate_size
        @image_size = config.image_size
        @patch_size = config.patch_size
        @num_channels = config.num_channels
        @layer_norm_eps = config.layer_norm_eps

        # Patch embedding
        @patch_embed = PatchEmbedding.new(
          image_size: @image_size,
          patch_size: @patch_size,
          in_channels: @num_channels,
          embed_dim: @hidden_size
        )

        # Position embedding (learnable)
        num_patches = (@image_size / @patch_size)**2
        @pos_embed = Torch::NN::Parameter.new(
          Torch.zeros(1, num_patches, @hidden_size)
        )

        # Transformer blocks
        @blocks = Torch::NN::ModuleList.new(
          @num_layers.times.map do
            TransformerBlock.new(
              hidden_size: @hidden_size,
              num_heads: @num_heads,
              intermediate_size: @intermediate_size,
              layer_norm_eps: @layer_norm_eps
            )
          end
        )

        # Final layer norm
        @norm = Torch::NN::LayerNorm.new(@hidden_size, eps: @layer_norm_eps)

        # Initialize position embedding
        init_pos_embed
      end

      def forward(pixel_values)
        # pixel_values: (batch, channels, height, width)

        # Patch embedding: (batch, num_patches, hidden_size)
        x = @patch_embed.call(pixel_values)

        # Add position embedding
        x = x + @pos_embed

        # Transformer blocks
        @blocks.each do |block|
          x = block.call(x)
        end

        # Final layer norm
        x = @norm.call(x)

        # Pool: take mean of all patch embeddings
        x.mean(dim: 1)
      end

      private

      def init_pos_embed
        # Initialize with normal distribution (truncated normal not available in torch.rb)
        # The values will be overwritten when loading pretrained weights
        Torch::NN::Init.normal!(@pos_embed, mean: 0.0, std: 0.02)
      end
    end

    # Patch embedding layer
    class PatchEmbedding < Torch::NN::Module
      def initialize(image_size:, patch_size:, in_channels:, embed_dim:)
        super()

        @image_size = image_size
        @patch_size = patch_size
        @num_patches = (image_size / patch_size)**2

        # Use conv2d for efficient patch extraction
        @proj = Torch::NN::Conv2d.new(
          in_channels, embed_dim, patch_size,
          stride: patch_size
        )
      end

      def forward(x)
        # x: (batch, channels, height, width)
        # output: (batch, num_patches, embed_dim)

        x = @proj.call(x) # (batch, embed_dim, h/patch, w/patch)
        x = x.flatten(2)  # (batch, embed_dim, num_patches)
        x.transpose(1, 2) # (batch, num_patches, embed_dim)
      end
    end

    # Transformer block with self-attention and MLP
    class TransformerBlock < Torch::NN::Module
      def initialize(hidden_size:, num_heads:, intermediate_size:, layer_norm_eps:)
        super()

        @norm1 = Torch::NN::LayerNorm.new(hidden_size, eps: layer_norm_eps)
        @attn = Attention.new(hidden_size: hidden_size, num_heads: num_heads)
        @norm2 = Torch::NN::LayerNorm.new(hidden_size, eps: layer_norm_eps)
        @mlp = MLP.new(hidden_size: hidden_size, intermediate_size: intermediate_size)
      end

      def forward(x)
        # Pre-norm architecture
        x = x + @attn.call(@norm1.call(x))
        x = x + @mlp.call(@norm2.call(x))
        x
      end
    end

    # Multi-head self-attention
    class Attention < Torch::NN::Module
      def initialize(hidden_size:, num_heads:)
        super()

        @num_heads = num_heads
        @head_dim = hidden_size / num_heads
        @scale = @head_dim**-0.5
        @hidden_size = hidden_size

        # Separate Q, K, V projections (matches HuggingFace)
        @q_proj = Torch::NN::Linear.new(hidden_size, hidden_size)
        @k_proj = Torch::NN::Linear.new(hidden_size, hidden_size)
        @v_proj = Torch::NN::Linear.new(hidden_size, hidden_size)
        @out_proj = Torch::NN::Linear.new(hidden_size, hidden_size)
      end

      def forward(x)
        batch_size, seq_len, _ = x.shape

        # Compute Q, K, V separately
        q = @q_proj.call(x)
        k = @k_proj.call(x)
        v = @v_proj.call(x)

        # Reshape for multi-head attention
        q = q.reshape(batch_size, seq_len, @num_heads, @head_dim).transpose(1, 2)
        k = k.reshape(batch_size, seq_len, @num_heads, @head_dim).transpose(1, 2)
        v = v.reshape(batch_size, seq_len, @num_heads, @head_dim).transpose(1, 2)

        # Scaled dot-product attention
        attn = Torch.matmul(q, k.transpose(-2, -1)) * @scale
        attn = attn.softmax(dim: -1)

        # Apply attention to values
        out = Torch.matmul(attn, v) # (batch, heads, seq, head_dim)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, @hidden_size)

        @out_proj.call(out)
      end
    end

    # MLP (feed-forward network)
    class MLP < Torch::NN::Module
      def initialize(hidden_size:, intermediate_size:)
        super()

        @fc1 = Torch::NN::Linear.new(hidden_size, intermediate_size)
        @act = Torch::NN::GELU.new
        @fc2 = Torch::NN::Linear.new(intermediate_size, hidden_size)
      end

      def forward(x)
        x = @fc1.call(x)
        x = @act.call(x)
        @fc2.call(x)
      end
    end
  end
end
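The encoder's sequence length is fixed by the patch arithmetic: an H-by-H image with P-by-P patches yields (H / P)**2 tokens, which is also the length of @pos_embed. A small shape-flow sketch using the PatchEmbedding class above, assuming torch-rb is installed; the values mirror a base-size config (224-pixel images, 16-pixel patches, hidden size 768) and are illustrative:

require "fine"

embed = Fine::Models::PatchEmbedding.new(
  image_size: 224, patch_size: 16, in_channels: 3, embed_dim: 768
)

x = Torch.randn(2, 3, 224, 224) # two RGB images
patches = embed.call(x)         # conv stride = patch size, then flatten and transpose
puts patches.shape.inspect      # => [2, 196, 768], since (224 / 16)**2 == 196

The same 196-token sequence then flows unchanged through every TransformerBlock, and the final mean over dim 1 collapses it to one (batch, 768) feature vector per image for the classification head.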