fine 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +7 -0
  2. data/.rspec +3 -0
  3. data/CHANGELOG.md +38 -0
  4. data/Gemfile +6 -0
  5. data/Gemfile.lock +167 -0
  6. data/LICENSE +21 -0
  7. data/README.md +212 -0
  8. data/Rakefile +6 -0
  9. data/docs/installation.md +151 -0
  10. data/docs/tutorials/llm-fine-tuning.md +246 -0
  11. data/docs/tutorials/model-export.md +200 -0
  12. data/docs/tutorials/siglip2-image-classification.md +130 -0
  13. data/docs/tutorials/siglip2-object-recognition.md +203 -0
  14. data/docs/tutorials/siglip2-similarity-search.md +152 -0
  15. data/docs/tutorials/text-classification.md +233 -0
  16. data/docs/tutorials/text-embeddings.md +211 -0
  17. data/examples/basic_classification.rb +70 -0
  18. data/examples/data/tool_calls.jsonl +30 -0
  19. data/examples/demo_training.rb +78 -0
  20. data/examples/finetune_gemma3_tools.rb +135 -0
  21. data/examples/real_llm_test.rb +128 -0
  22. data/examples/real_text_classification_test.rb +90 -0
  23. data/examples/real_text_embedder_test.rb +110 -0
  24. data/examples/real_training_test.rb +88 -0
  25. data/examples/test_export.rb +28 -0
  26. data/examples/test_image_classifier.rb +79 -0
  27. data/examples/test_llm.rb +100 -0
  28. data/examples/test_text_classifier.rb +59 -0
  29. data/lib/fine/callbacks/base.rb +140 -0
  30. data/lib/fine/callbacks/progress_bar.rb +66 -0
  31. data/lib/fine/configuration.rb +106 -0
  32. data/lib/fine/datasets/data_loader.rb +63 -0
  33. data/lib/fine/datasets/image_dataset.rb +203 -0
  34. data/lib/fine/datasets/instruction_dataset.rb +226 -0
  35. data/lib/fine/datasets/text_data_loader.rb +88 -0
  36. data/lib/fine/datasets/text_dataset.rb +266 -0
  37. data/lib/fine/error.rb +49 -0
  38. data/lib/fine/export/gguf_exporter.rb +424 -0
  39. data/lib/fine/export/onnx_exporter.rb +249 -0
  40. data/lib/fine/export.rb +53 -0
  41. data/lib/fine/hub/config_loader.rb +145 -0
  42. data/lib/fine/hub/model_downloader.rb +136 -0
  43. data/lib/fine/hub/safetensors_loader.rb +108 -0
  44. data/lib/fine/image_classifier.rb +256 -0
  45. data/lib/fine/llm.rb +336 -0
  46. data/lib/fine/models/base.rb +48 -0
  47. data/lib/fine/models/bert_encoder.rb +202 -0
  48. data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
  49. data/lib/fine/models/causal_lm.rb +279 -0
  50. data/lib/fine/models/classification_head.rb +24 -0
  51. data/lib/fine/models/gemma3_decoder.rb +244 -0
  52. data/lib/fine/models/llama_decoder.rb +297 -0
  53. data/lib/fine/models/sentence_transformer.rb +202 -0
  54. data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
  55. data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
  56. data/lib/fine/text_classifier.rb +250 -0
  57. data/lib/fine/text_embedder.rb +221 -0
  58. data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
  59. data/lib/fine/training/llm_trainer.rb +212 -0
  60. data/lib/fine/training/text_trainer.rb +275 -0
  61. data/lib/fine/training/trainer.rb +194 -0
  62. data/lib/fine/transforms/compose.rb +28 -0
  63. data/lib/fine/transforms/normalize.rb +33 -0
  64. data/lib/fine/transforms/resize.rb +35 -0
  65. data/lib/fine/transforms/to_tensor.rb +53 -0
  66. data/lib/fine/version.rb +3 -0
  67. data/lib/fine.rb +112 -0
  68. data/mise.toml +2 -0
  69. metadata +240 -0
data/lib/fine/models/sentence_transformer.rb
@@ -0,0 +1,202 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Models
+     # Sentence Transformer for generating text embeddings
+     #
+     # Produces dense vector representations of sentences for semantic similarity,
+     # clustering, and retrieval tasks.
+     class SentenceTransformer < Base
+       attr_reader :encoder, :pooling_mode
+
+       POOLING_MODES = %i[cls mean max].freeze
+
+       # Load a pretrained sentence transformer
+       #
+       # @param model_id [String] HuggingFace model ID
+       # @param pooling_mode [Symbol] Pooling strategy (:cls, :mean, :max)
+       # @return [SentenceTransformer]
+       def self.from_pretrained(model_id, pooling_mode: :mean)
+         downloader = Hub::ModelDownloader.new(model_id)
+         model_path = downloader.download
+
+         config = Hub::ConfigLoader.from_pretrained(model_path)
+
+         model = new(config, pooling_mode: pooling_mode)
+
+         # Load weights
+         weights_path = downloader.file_path("model.safetensors")
+         if File.exist?(weights_path)
+           load_pretrained_weights(model, weights_path)
+         end
+
+         model
+       end
+
+       # Load from saved model
+       def self.load(path)
+         raise ModelNotFoundError.new(path) unless File.directory?(path)
+
+         config = Hub::ConfigLoader.from_pretrained(path)
+         pooling_mode = (config.config["pooling_mode"] || "mean").to_sym
+
+         model = new(config, pooling_mode: pooling_mode)
+
+         weights_path = File.join(path, "model.safetensors")
+         Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+
+         model
+       end
+
+       def initialize(config, pooling_mode: :mean)
+         super(config)
+
+         raise ArgumentError, "Invalid pooling mode: #{pooling_mode}" unless POOLING_MODES.include?(pooling_mode)
+
+         @pooling_mode = pooling_mode
+         @encoder = BertEncoder.new(config)
+
+         # Optional: normalize embeddings
+         @normalize = true
+       end
+
+       def forward(input_ids, attention_mask: nil, token_type_ids: nil)
+         encoder_output = @encoder.call(
+           input_ids,
+           attention_mask: attention_mask,
+           token_type_ids: token_type_ids
+         )
+
+         # Pool based on strategy
+         embeddings = case @pooling_mode
+                      when :cls
+                        encoder_output[:pooler_output]
+                      when :mean
+                        mean_pooling(encoder_output[:last_hidden_state], attention_mask)
+                      when :max
+                        max_pooling(encoder_output[:last_hidden_state], attention_mask)
+                      end
+
+         # L2 normalize
+         embeddings = Torch::NN::Functional.normalize(embeddings, p: 2, dim: 1) if @normalize
+
+         { embeddings: embeddings }
+       end
+
+       # Encode texts to embeddings
+       #
+       # @param input_ids [Torch::Tensor] Token IDs
+       # @param attention_mask [Torch::Tensor] Attention mask
+       # @return [Torch::Tensor] Embeddings
+       def encode(input_ids, attention_mask: nil, token_type_ids: nil)
+         eval
+         Torch.no_grad do
+           output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
+           output[:embeddings]
+         end
+       end
+
+       # Compute similarity between two sets of embeddings
+       #
+       # @param embeddings_a [Torch::Tensor] First embeddings
+       # @param embeddings_b [Torch::Tensor] Second embeddings
+       # @return [Torch::Tensor] Similarity scores
+       def similarity(embeddings_a, embeddings_b)
+         # Cosine similarity (embeddings should be normalized)
+         Torch.matmul(embeddings_a, embeddings_b.transpose(0, 1))
+       end
+
+       def save(path)
+         FileUtils.mkdir_p(path)
+
+         weights_path = File.join(path, "model.safetensors")
+         Safetensors::Torch.save_file(state_dict, weights_path)
+
+         save_config = @config.to_h.merge(
+           "model_type" => "sentence_transformer",
+           "pooling_mode" => @pooling_mode.to_s
+         )
+
+         config_path = File.join(path, "config.json")
+         File.write(config_path, JSON.pretty_generate(save_config))
+       end
+
+       private
+
+       def mean_pooling(hidden_states, attention_mask)
+         if attention_mask
+           mask = attention_mask.unsqueeze(-1).expand_as(hidden_states).float
+           sum_embeddings = (hidden_states * mask).sum(dim: 1)
+           sum_mask = mask.sum(dim: 1).clamp(min: 1e-9)
+           sum_embeddings / sum_mask
+         else
+           hidden_states.mean(dim: 1)
+         end
+       end
+
+       def max_pooling(hidden_states, attention_mask)
+         if attention_mask
+           mask = attention_mask.unsqueeze(-1).expand_as(hidden_states)
+           hidden_states = hidden_states.masked_fill(mask == 0, -1e9)
+         end
+         hidden_states.max(dim: 1).values
+       end
+
+       def self.load_pretrained_weights(model, weights_path)
+         tensors = Safetensors::Torch.load_file(weights_path)
+
+         mapped = {}
+         model_keys = model.state_dict.keys
+
+         tensors.each do |name, tensor|
+           # Try direct mapping first
+           mapped_name = map_sentence_transformer_weight_name(name)
+
+           if model_keys.include?(mapped_name)
+             mapped[mapped_name] = tensor
+           end
+         end
+
+         # Use no_grad and copy! since torch.rb doesn't support strict: false
+         Torch.no_grad do
+           state_dict = model.state_dict
+           mapped.each do |name, tensor|
+             state_dict[name].copy!(tensor) if state_dict.key?(name)
+           end
+         end
+       end
+
+       def self.map_sentence_transformer_weight_name(hf_name)
+         name = hf_name.dup
+
+         # sentence-transformers uses "0." prefix for the encoder module
+         name = name.sub(/^0\./, "encoder.")
+
+         # Then apply BERT mappings
+         name = name.sub("auto_model.", "encoder.")
+         name = name.sub("bert.embeddings.word_embeddings", "encoder.word_embeddings")
+         name = name.sub("bert.embeddings.position_embeddings", "encoder.position_embeddings")
+         name = name.sub("bert.embeddings.token_type_embeddings", "encoder.token_type_embeddings")
+         name = name.sub("bert.embeddings.LayerNorm", "encoder.embeddings_layer_norm")
+         name = name.sub("embeddings.word_embeddings", "encoder.word_embeddings")
+         name = name.sub("embeddings.position_embeddings", "encoder.position_embeddings")
+         name = name.sub("embeddings.token_type_embeddings", "encoder.token_type_embeddings")
+         name = name.sub("embeddings.LayerNorm", "encoder.embeddings_layer_norm")
+         name = name.gsub("bert.encoder.layer", "encoder.layers")
+         name = name.gsub("encoder.layer", "encoder.layers")
+         name = name.gsub(".attention.self.query", ".attention.query")
+         name = name.gsub(".attention.self.key", ".attention.key")
+         name = name.gsub(".attention.self.value", ".attention.value")
+         name = name.gsub(".attention.output.dense", ".attention.out")
+         name = name.gsub(".attention.output.LayerNorm", ".attention_layer_norm")
+         name = name.gsub(".intermediate.dense", ".intermediate")
+         name = name.gsub(".output.dense", ".output")
+         name = name.gsub(".output.LayerNorm", ".output_layer_norm")
+         name = name.sub("bert.pooler.dense", "encoder.pooler_dense")
+         name = name.sub("pooler.dense", "encoder.pooler_dense")
+
+         name
+       end
+     end
+   end
+ end
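
For orientation, a minimal usage sketch of the class above (not part of the packaged files). The model ID and the input tensors are placeholders; encode and similarity are the methods shown in the diff:

# Hypothetical usage sketch; the model ID and tokenized inputs are assumptions.
model = Fine::Models::SentenceTransformer.from_pretrained(
  "sentence-transformers/all-MiniLM-L6-v2",
  pooling_mode: :mean
)

# encode switches to eval mode and runs under Torch.no_grad.
a = model.encode(input_ids_a, attention_mask: mask_a)  # (batch_a, hidden)
b = model.encode(input_ids_b, attention_mask: mask_b)  # (batch_b, hidden)

# forward L2-normalizes the embeddings, so this matmul is cosine similarity.
scores = model.similarity(a, b)                        # (batch_a, batch_b)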
data/lib/fine/models/siglip2_for_image_classification.rb
@@ -0,0 +1,155 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Models
+     # SigLIP2 model with classification head for image classification
+     class SigLIP2ForImageClassification < Base
+       attr_reader :encoder, :classifier, :num_labels
+
+       # Load a pretrained model from Hugging Face Hub
+       #
+       # @param model_id [String] Hugging Face model ID (e.g., "google/siglip2-base-patch16-224")
+       # @param num_labels [Integer] Number of classification labels
+       # @param freeze_encoder [Boolean] Whether to freeze the vision encoder
+       # @param dropout [Float] Dropout rate for classifier
+       # @return [SigLIP2ForImageClassification]
+       def self.from_pretrained(model_id, num_labels:, freeze_encoder: false, dropout: 0.0)
+         # Download model files
+         downloader = Hub::ModelDownloader.new(model_id)
+         model_path = downloader.download
+
+         # Load config
+         config = Hub::ConfigLoader.from_pretrained(model_path)
+
+         # Create model
+         model = new(config, num_labels: num_labels, dropout: dropout)
+
+         # Load pretrained weights into encoder
+         weights_path = downloader.file_path("model.safetensors")
+         load_result = Hub::SafetensorsLoader.load_into_model(
+           model.encoder,
+           weights_path,
+           strict: false,
+           prefix: "vision_model"
+         )
+
+         # Log any issues
+         if load_result[:missing_keys].any?
+           warn "Missing keys in encoder: #{load_result[:missing_keys].first(5).join(', ')}..."
+         end
+
+         # Freeze encoder if requested
+         model.encoder.freeze! if freeze_encoder
+
+         model
+       end
+
+       # Load a fine-tuned model from disk
+       #
+       # @param path [String] Path to saved model directory
+       # @return [SigLIP2ForImageClassification]
+       def self.load(path)
+         raise ModelNotFoundError.new(path, "Model directory not found") unless File.directory?(path)
+
+         # Load config
+         config = Hub::ConfigLoader.from_pretrained(path)
+         num_labels = config.config["num_labels"] || config.config["id2label"]&.size
+
+         raise ConfigurationError, "Cannot determine num_labels from saved model" unless num_labels
+
+         # Create model
+         model = new(config, num_labels: num_labels)
+
+         # Load weights
+         weights_path = File.join(path, "model.safetensors")
+         Hub::SafetensorsLoader.load_into_model(model, weights_path, strict: false)
+
+         model
+       end
+
+       def initialize(config, num_labels:, dropout: 0.0)
+         super(config)
+
+         @num_labels = num_labels
+
+         # Vision encoder
+         @encoder = SigLIP2VisionEncoder.new(config)
+
+         # Classification head
+         @classifier = ClassificationHead.new(
+           config.hidden_size,
+           num_labels,
+           dropout: dropout
+         )
+       end
+
+       def forward(pixel_values, labels: nil)
+         # Get image features from encoder
+         features = @encoder.call(pixel_values)
+
+         # Classification logits
+         logits = @classifier.call(features)
+
+         # Compute loss if labels provided
+         if labels
+           loss = Torch::NN::Functional.cross_entropy(logits, labels)
+           { loss: loss, logits: logits }
+         else
+           { logits: logits }
+         end
+       end
+
+       # Predict class for an image tensor
+       #
+       # @param pixel_values [Torch::Tensor] Image tensor (batch, C, H, W)
+       # @return [Torch::Tensor] Predicted class indices
+       def predict(pixel_values)
+         eval
+         Torch.no_grad do
+           output = forward(pixel_values)
+           output[:logits].argmax(dim: 1)
+         end
+       end
+
+       # Get probabilities for each class
+       #
+       # @param pixel_values [Torch::Tensor] Image tensor (batch, C, H, W)
+       # @return [Torch::Tensor] Class probabilities
+       def predict_proba(pixel_values)
+         eval
+         Torch.no_grad do
+           output = forward(pixel_values)
+           Torch::NN::Functional.softmax(output[:logits], dim: 1)
+         end
+       end
+
+       # Save the model to disk
+       #
+       # @param path [String] Directory path to save to
+       # @param label_map [Hash, nil] Optional mapping of label names to IDs
+       def save(path, label_map: nil)
+         FileUtils.mkdir_p(path)
+
+         # Save weights
+         weights_path = File.join(path, "model.safetensors")
+         Safetensors::Torch.save_file(state_dict, weights_path)
+
+         # Build config with num_labels
+         save_config = @config.to_h.merge(
+           "num_labels" => @num_labels,
+           "model_type" => "siglip2_image_classification"
+         )
+
+         # Add label mapping if provided
+         if label_map
+           save_config["id2label"] = label_map.invert.sort.to_h
+           save_config["label2id"] = label_map
+         end
+
+         # Save config
+         config_path = File.join(path, "config.json")
+         File.write(config_path, JSON.pretty_generate(save_config))
+       end
+     end
+   end
+ end
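
A minimal sketch of the call pattern for the classifier above (not part of the packaged files), assuming pixel_values is an already-preprocessed (batch, 3, 224, 224) float tensor; the label count, label map, and input tensor are placeholders:

# Hypothetical usage sketch; num_labels, label_map, and pixel_values are assumptions.
model = Fine::Models::SigLIP2ForImageClassification.from_pretrained(
  "google/siglip2-base-patch16-224",
  num_labels: 3,
  freeze_encoder: true  # train only the classification head
)

probs   = model.predict_proba(pixel_values)  # (batch, 3) softmax probabilities
classes = model.predict(pixel_values)        # (batch,) argmax class indices

model.save("my_model", label_map: { "cat" => 0, "dog" => 1, "bird" => 2 })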
data/lib/fine/models/siglip2_vision_encoder.rb
@@ -0,0 +1,190 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Models
+     # SigLIP2 Vision Transformer Encoder
+     #
+     # Implements the vision encoder portion of SigLIP2 for image feature extraction.
+     # Architecture follows the standard ViT with patch embedding, transformer blocks,
+     # and pooling.
+     class SigLIP2VisionEncoder < Base
+       def initialize(config)
+         super(config)
+
+         @hidden_size = config.hidden_size
+         @num_layers = config.num_hidden_layers
+         @num_heads = config.num_attention_heads
+         @intermediate_size = config.intermediate_size
+         @image_size = config.image_size
+         @patch_size = config.patch_size
+         @num_channels = config.num_channels
+         @layer_norm_eps = config.layer_norm_eps
+
+         # Patch embedding
+         @patch_embed = PatchEmbedding.new(
+           image_size: @image_size,
+           patch_size: @patch_size,
+           in_channels: @num_channels,
+           embed_dim: @hidden_size
+         )
+
+         # Position embedding (learnable)
+         num_patches = (@image_size / @patch_size) ** 2
+         @pos_embed = Torch::NN::Parameter.new(
+           Torch.zeros(1, num_patches, @hidden_size)
+         )
+
+         # Transformer blocks
+         @blocks = Torch::NN::ModuleList.new(
+           @num_layers.times.map do
+             TransformerBlock.new(
+               hidden_size: @hidden_size,
+               num_heads: @num_heads,
+               intermediate_size: @intermediate_size,
+               layer_norm_eps: @layer_norm_eps
+             )
+           end
+         )
+
+         # Final layer norm
+         @norm = Torch::NN::LayerNorm.new(@hidden_size, eps: @layer_norm_eps)
+
+         # Initialize position embedding
+         init_pos_embed
+       end
+
+       def forward(pixel_values)
+         # pixel_values: (batch, channels, height, width)
+
+         # Patch embedding: (batch, num_patches, hidden_size)
+         x = @patch_embed.call(pixel_values)
+
+         # Add position embedding
+         x = x + @pos_embed
+
+         # Transformer blocks
+         @blocks.each do |block|
+           x = block.call(x)
+         end
+
+         # Final layer norm
+         x = @norm.call(x)
+
+         # Pool: take mean of all patch embeddings
+         x.mean(dim: 1)
+       end
+
+       private
+
+       def init_pos_embed
+         # Initialize with normal distribution (truncated normal not available in torch.rb)
+         # The values will be overwritten when loading pretrained weights
+         Torch::NN::Init.normal!(@pos_embed, mean: 0.0, std: 0.02)
+       end
+     end
+
+     # Patch embedding layer
+     class PatchEmbedding < Torch::NN::Module
+       def initialize(image_size:, patch_size:, in_channels:, embed_dim:)
+         super()
+
+         @image_size = image_size
+         @patch_size = patch_size
+         @num_patches = (image_size / patch_size) ** 2
+
+         # Use conv2d for efficient patch extraction
+         @proj = Torch::NN::Conv2d.new(
+           in_channels, embed_dim, patch_size,
+           stride: patch_size
+         )
+       end
+
+       def forward(x)
+         # x: (batch, channels, height, width)
+         # output: (batch, num_patches, embed_dim)
+
+         x = @proj.call(x)  # (batch, embed_dim, h/patch, w/patch)
+         x = x.flatten(2)   # (batch, embed_dim, num_patches)
+         x.transpose(1, 2)  # (batch, num_patches, embed_dim)
+       end
+     end
+
+     # Transformer block with self-attention and MLP
+     class TransformerBlock < Torch::NN::Module
+       def initialize(hidden_size:, num_heads:, intermediate_size:, layer_norm_eps:)
+         super()
+
+         @norm1 = Torch::NN::LayerNorm.new(hidden_size, eps: layer_norm_eps)
+         @attn = Attention.new(hidden_size: hidden_size, num_heads: num_heads)
+         @norm2 = Torch::NN::LayerNorm.new(hidden_size, eps: layer_norm_eps)
+         @mlp = MLP.new(hidden_size: hidden_size, intermediate_size: intermediate_size)
+       end
+
+       def forward(x)
+         # Pre-norm architecture
+         x = x + @attn.call(@norm1.call(x))
+         x = x + @mlp.call(@norm2.call(x))
+         x
+       end
+     end
+
+     # Multi-head self-attention
+     class Attention < Torch::NN::Module
+       def initialize(hidden_size:, num_heads:)
+         super()
+
+         @num_heads = num_heads
+         @head_dim = hidden_size / num_heads
+         @scale = @head_dim ** -0.5
+         @hidden_size = hidden_size
+
+         # Separate Q, K, V projections (matches HuggingFace)
+         @q_proj = Torch::NN::Linear.new(hidden_size, hidden_size)
+         @k_proj = Torch::NN::Linear.new(hidden_size, hidden_size)
+         @v_proj = Torch::NN::Linear.new(hidden_size, hidden_size)
+         @out_proj = Torch::NN::Linear.new(hidden_size, hidden_size)
+       end
+
+       def forward(x)
+         batch_size, seq_len, _ = x.shape
+
+         # Compute Q, K, V separately
+         q = @q_proj.call(x)
+         k = @k_proj.call(x)
+         v = @v_proj.call(x)
+
+         # Reshape for multi-head attention
+         q = q.reshape(batch_size, seq_len, @num_heads, @head_dim).transpose(1, 2)
+         k = k.reshape(batch_size, seq_len, @num_heads, @head_dim).transpose(1, 2)
+         v = v.reshape(batch_size, seq_len, @num_heads, @head_dim).transpose(1, 2)
+
+         # Scaled dot-product attention
+         attn = Torch.matmul(q, k.transpose(-2, -1)) * @scale
+         attn = attn.softmax(dim: -1)
+
+         # Apply attention to values
+         out = Torch.matmul(attn, v)  # (batch, heads, seq, head_dim)
+         out = out.transpose(1, 2).reshape(batch_size, seq_len, @hidden_size)
+
+         @out_proj.call(out)
+       end
+     end
+
+     # MLP (feed-forward network)
+     class MLP < Torch::NN::Module
+       def initialize(hidden_size:, intermediate_size:)
+         super()
+
+         @fc1 = Torch::NN::Linear.new(hidden_size, intermediate_size)
+         @act = Torch::NN::GELU.new
+         @fc2 = Torch::NN::Linear.new(intermediate_size, hidden_size)
+       end
+
+       def forward(x)
+         x = @fc1.call(x)
+         x = @act.call(x)
+         @fc2.call(x)
+       end
+     end
+   end
+ end
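
To make the tensor shapes through the encoder concrete, a short walkthrough under assumed config values (image_size 224, patch_size 16, hidden_size 768; these match the base checkpoint named in the classifier's docs but are not read from this diff):

# Shape walkthrough; all config values here are assumptions, not from this diff.
image_size, patch_size, hidden = 224, 16, 768
patches_per_side = image_size / patch_size  # => 14
num_patches = patches_per_side**2           # => 196 patch tokens

# pixel_values:               (batch, 3, 224, 224)
# after @proj (Conv2d):       (batch, 768, 14, 14)
# after flatten + transpose:  (batch, 196, 768)
# after blocks and @norm:     (batch, 196, 768)
# after mean(dim: 1):         (batch, 768) pooled feature fed to the classifier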