fine 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +7 -0
  2. data/.rspec +3 -0
  3. data/CHANGELOG.md +38 -0
  4. data/Gemfile +6 -0
  5. data/Gemfile.lock +167 -0
  6. data/LICENSE +21 -0
  7. data/README.md +212 -0
  8. data/Rakefile +6 -0
  9. data/docs/installation.md +151 -0
  10. data/docs/tutorials/llm-fine-tuning.md +246 -0
  11. data/docs/tutorials/model-export.md +200 -0
  12. data/docs/tutorials/siglip2-image-classification.md +130 -0
  13. data/docs/tutorials/siglip2-object-recognition.md +203 -0
  14. data/docs/tutorials/siglip2-similarity-search.md +152 -0
  15. data/docs/tutorials/text-classification.md +233 -0
  16. data/docs/tutorials/text-embeddings.md +211 -0
  17. data/examples/basic_classification.rb +70 -0
  18. data/examples/data/tool_calls.jsonl +30 -0
  19. data/examples/demo_training.rb +78 -0
  20. data/examples/finetune_gemma3_tools.rb +135 -0
  21. data/examples/real_llm_test.rb +128 -0
  22. data/examples/real_text_classification_test.rb +90 -0
  23. data/examples/real_text_embedder_test.rb +110 -0
  24. data/examples/real_training_test.rb +88 -0
  25. data/examples/test_export.rb +28 -0
  26. data/examples/test_image_classifier.rb +79 -0
  27. data/examples/test_llm.rb +100 -0
  28. data/examples/test_text_classifier.rb +59 -0
  29. data/lib/fine/callbacks/base.rb +140 -0
  30. data/lib/fine/callbacks/progress_bar.rb +66 -0
  31. data/lib/fine/configuration.rb +106 -0
  32. data/lib/fine/datasets/data_loader.rb +63 -0
  33. data/lib/fine/datasets/image_dataset.rb +203 -0
  34. data/lib/fine/datasets/instruction_dataset.rb +226 -0
  35. data/lib/fine/datasets/text_data_loader.rb +88 -0
  36. data/lib/fine/datasets/text_dataset.rb +266 -0
  37. data/lib/fine/error.rb +49 -0
  38. data/lib/fine/export/gguf_exporter.rb +424 -0
  39. data/lib/fine/export/onnx_exporter.rb +249 -0
  40. data/lib/fine/export.rb +53 -0
  41. data/lib/fine/hub/config_loader.rb +145 -0
  42. data/lib/fine/hub/model_downloader.rb +136 -0
  43. data/lib/fine/hub/safetensors_loader.rb +108 -0
  44. data/lib/fine/image_classifier.rb +256 -0
  45. data/lib/fine/llm.rb +336 -0
  46. data/lib/fine/models/base.rb +48 -0
  47. data/lib/fine/models/bert_encoder.rb +202 -0
  48. data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
  49. data/lib/fine/models/causal_lm.rb +279 -0
  50. data/lib/fine/models/classification_head.rb +24 -0
  51. data/lib/fine/models/gemma3_decoder.rb +244 -0
  52. data/lib/fine/models/llama_decoder.rb +297 -0
  53. data/lib/fine/models/sentence_transformer.rb +202 -0
  54. data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
  55. data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
  56. data/lib/fine/text_classifier.rb +250 -0
  57. data/lib/fine/text_embedder.rb +221 -0
  58. data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
  59. data/lib/fine/training/llm_trainer.rb +212 -0
  60. data/lib/fine/training/text_trainer.rb +275 -0
  61. data/lib/fine/training/trainer.rb +194 -0
  62. data/lib/fine/transforms/compose.rb +28 -0
  63. data/lib/fine/transforms/normalize.rb +33 -0
  64. data/lib/fine/transforms/resize.rb +35 -0
  65. data/lib/fine/transforms/to_tensor.rb +53 -0
  66. data/lib/fine/version.rb +3 -0
  67. data/lib/fine.rb +112 -0
  68. data/mise.toml +2 -0
  69. metadata +240 -0
data/lib/fine/llm.rb ADDED
@@ -0,0 +1,336 @@
+ # frozen_string_literal: true
+
+ module Fine
+   # High-level API for LLM fine-tuning
+   #
+   # @example Basic fine-tuning
+   #   llm = Fine::LLM.new("meta-llama/Llama-3.2-1B")
+   #   llm.fit(train_file: "instructions.jsonl", epochs: 3)
+   #   llm.save("my_llama")
+   #
+   # @example Generation
+   #   llm = Fine::LLM.load("my_llama")
+   #   response = llm.generate("What is Ruby?", max_new_tokens: 100)
+   #
+   # @example With configuration
+   #   llm = Fine::LLM.new("google/gemma-2b") do |config|
+   #     config.epochs = 3
+   #     config.batch_size = 4
+   #     config.learning_rate = 1e-5
+   #     config.max_length = 1024
+   #   end
+   #
+   class LLM
+     attr_reader :model, :config, :tokenizer, :model_id
+
+     # Create a new LLM for fine-tuning
+     #
+     # @param model_id [String] HuggingFace model ID
+     # @yield [config] Optional configuration block
+     def initialize(model_id, &block)
+       @model_id = model_id
+       @config = LLMConfiguration.new
+       @model = nil
+       @tokenizer = nil
+       @trained = false
+
+       block&.call(@config)
+
+       if @config.callbacks.empty? && Fine.configuration&.progress_bar != false
+         @config.callbacks << Callbacks::ProgressBar.new
+       end
+     end
+
+     # Load a fine-tuned LLM from disk
+     #
+     # @param path [String] Path to saved model
+     # @return [LLM]
+     def self.load(path)
+       config_path = File.join(path, "config.json")
+       raise ModelNotFoundError.new(path) unless File.exist?(config_path)
+
+       config_data = JSON.parse(File.read(config_path))
+
+       llm = allocate
+       llm.instance_variable_set(:@model_id, config_data["_model_id"] || "custom")
+       llm.instance_variable_set(:@config, LLMConfiguration.new)
+       llm.instance_variable_set(:@trained, true)
+
+       # Load tokenizer
+       tokenizer_path = File.join(path, "tokenizer.json")
+       tokenizer = if File.exist?(tokenizer_path)
+         Tokenizers::AutoTokenizer.new(path, max_length: config_data["max_length"] || 2048)
+       else
+         Tokenizers::AutoTokenizer.from_pretrained(
+           config_data["_model_id"],
+           max_length: config_data["max_length"] || 2048
+         )
+       end
+       llm.instance_variable_set(:@tokenizer, tokenizer)
+
+       # Load model
+       llm.instance_variable_set(
+         :@model,
+         Models::CausalLM.load(path)
+       )
+
+       llm
+     end
+
+     # Fine-tune on instruction data
+     #
+     # @param train_file [String] Path to training data (JSONL)
+     # @param val_file [String, nil] Path to validation data
+     # @param format [Symbol] Data format (:alpaca, :sharegpt, :simple, :auto)
+     # @param epochs [Integer, nil] Override config epochs
+     # @return [Array<Hash>] Training history
+     def fit(train_file:, val_file: nil, format: :auto, epochs: nil)
+       @config.epochs = epochs if epochs
+
+       # Download model files first (including tokenizer)
+       downloader = Hub::ModelDownloader.new(@model_id)
+       model_path = downloader.download
+
+       # Load tokenizer from cache or HuggingFace
+       tokenizer_path = File.join(model_path, "tokenizer.json")
+       @tokenizer = if File.exist?(tokenizer_path)
+         Tokenizers::AutoTokenizer.new(model_path, max_length: @config.max_length)
+       else
+         Tokenizers::AutoTokenizer.from_pretrained(@model_id, max_length: @config.max_length)
+       end
+
+       # Set pad token if not set
+       @config.pad_token_id ||= @tokenizer.pad_token_id || @tokenizer.eos_token_id || 0
+
+       # Load datasets
+       train_dataset = Datasets::InstructionDataset.from_jsonl(
+         train_file,
+         tokenizer: @tokenizer,
+         format: format,
+         max_length: @config.max_length
+       )
+
+       val_dataset = if val_file
+         Datasets::InstructionDataset.from_jsonl(
+           val_file,
+           tokenizer: @tokenizer,
+           format: format,
+           max_length: @config.max_length
+         )
+       end
+
+       # Load model
+       @model = Models::CausalLM.from_pretrained(@model_id)
+
+       # Freeze layers if configured
+       if @config.freeze_layers && @config.freeze_layers > 0
+         freeze_bottom_layers(@config.freeze_layers)
+       end
+
+       # Train
+       trainer = Training::LLMTrainer.new(
+         @model,
+         @config,
+         train_dataset: train_dataset,
+         val_dataset: val_dataset
+       )
+
+       history = trainer.fit
+       @trained = true
+
+       history
+     end
+
+     # Fine-tune with automatic train/val split
+     #
+     # @param data_file [String] Path to data file
+     # @param val_split [Float] Fraction for validation
+     # @param format [Symbol] Data format
+     # @return [Array<Hash>] Training history
+     def fit_with_split(data_file:, val_split: 0.1, format: :auto, epochs: nil)
+       @config.epochs = epochs if epochs
+
+       # Download model files first (including tokenizer)
+       downloader = Hub::ModelDownloader.new(@model_id)
+       model_path = downloader.download
+
+       # Load tokenizer from cache or HuggingFace
+       tokenizer_path = File.join(model_path, "tokenizer.json")
+       @tokenizer = if File.exist?(tokenizer_path)
+         Tokenizers::AutoTokenizer.new(model_path, max_length: @config.max_length)
+       else
+         Tokenizers::AutoTokenizer.from_pretrained(@model_id, max_length: @config.max_length)
+       end
+
+       @config.pad_token_id ||= @tokenizer.pad_token_id || @tokenizer.eos_token_id || 0
+
+       full_dataset = Datasets::InstructionDataset.from_jsonl(
+         data_file,
+         tokenizer: @tokenizer,
+         format: format,
+         max_length: @config.max_length
+       )
+
+       train_dataset, val_dataset = full_dataset.split(test_size: val_split)
+
+       @model = Models::CausalLM.from_pretrained(@model_id)
+
+       if @config.freeze_layers && @config.freeze_layers > 0
+         freeze_bottom_layers(@config.freeze_layers)
+       end
+
+       trainer = Training::LLMTrainer.new(
+         @model,
+         @config,
+         train_dataset: train_dataset,
+         val_dataset: val_dataset
+       )
+
+       history = trainer.fit
+       @trained = true
+
+       history
+     end
+
+     # Generate text
+     #
+     # @param prompt [String] Input prompt
+     # @param max_new_tokens [Integer] Maximum tokens to generate
+     # @param temperature [Float] Sampling temperature (higher = more random)
+     # @param top_p [Float] Nucleus sampling threshold
+     # @param top_k [Integer] Top-k sampling
+     # @param do_sample [Boolean] Whether to sample (false = greedy)
+     # @return [String] Generated text
+     def generate(prompt, max_new_tokens: 100, temperature: 0.7, top_p: 0.9, top_k: 50, do_sample: true)
+       raise TrainingError, "Model not loaded" unless @model && @tokenizer
+
+       # Tokenize prompt (without tensors for easier manipulation)
+       encoding = @tokenizer.encode(prompt, return_tensors: false)
+       input_ids = Torch.tensor([encoding[:input_ids].first])
+
+       # Move to device
+       input_ids = input_ids.to(Fine.device)
+       @model.to(Fine.device)
+
+       # Generate
+       output_ids = @model.generate(
+         input_ids,
+         max_new_tokens: max_new_tokens,
+         temperature: temperature,
+         top_p: top_p,
+         top_k: top_k,
+         do_sample: do_sample,
+         eos_token_id: @tokenizer.eos_token_id,
+         pad_token_id: @config.pad_token_id
+       )
+
+       # Decode
+       @tokenizer.decode(output_ids[0].to_a)
+     end
+
+     # Chat-style generation
+     #
+     # @param messages [Array<Hash>] Messages with :role and :content
+     # @param kwargs [Hash] Generation parameters
+     # @return [String] Assistant response
+     def chat(messages, **kwargs)
+       prompt = format_chat_prompt(messages)
+       full_response = generate(prompt, **kwargs)
+
+       # Extract just the assistant response
+       if full_response.include?("### Response:")
+         full_response.split("### Response:").last.strip
+       else
+         # Remove the prompt from the response
+         full_response[prompt.length..].strip
+       end
+     end
+
+     # Save the model
+     #
+     # @param path [String] Directory to save to
+     def save(path)
+       raise TrainingError, "Model not trained or loaded" unless @model
+
+       @model.save(path)
+       @tokenizer.save(path)
+
+       # Update config with model ID
+       config_path = File.join(path, "config.json")
+       config = JSON.parse(File.read(config_path))
+       config["_model_id"] = @model_id
+       config["max_length"] = @config.max_length
+       File.write(config_path, JSON.pretty_generate(config))
+     end
+
+     # Export to GGUF format for llama.cpp, ollama, etc.
+     #
+     # @param path [String] Output path for GGUF file
+     # @param quantization [Symbol] Quantization type (:f16, :q4_0, :q8_0, etc.)
+     # @param options [Hash] Additional exporter options (e.g. metadata)
+     # @return [String] The output path
+     def export_gguf(path, quantization: :f16, **options)
+       Export.to_gguf(self, path, quantization: quantization, **options)
+     end
+
+     # Export to ONNX format
+     #
+     # @param path [String] Output path for ONNX file
+     # @param options [Hash] Export options
+     # @return [String] The output path
+     def export_onnx(path, **options)
+       Export.to_onnx(self, path, **options)
+     end
+
+     private
+
+     def freeze_bottom_layers(num_layers)
+       # Freeze embedding
+       @model.decoder.embed_tokens.parameters.each { |p| p.requires_grad = false }
+
+       # Freeze bottom N layers
+       @model.decoder.layers[0...num_layers].each do |layer|
+         layer.parameters.each { |p| p.requires_grad = false }
+       end
+     end
+
+     def format_chat_prompt(messages)
+       prompt = ""
+
+       messages.each do |msg|
+         case msg[:role]
+         when "system"
+           prompt += "### System:\n#{msg[:content]}\n\n"
+         when "user"
+           prompt += "### Instruction:\n#{msg[:content]}\n\n"
+         when "assistant"
+           prompt += "### Response:\n#{msg[:content]}\n\n"
+         end
+       end
+
+       # Add response prefix for the model to continue
+       prompt += "### Response:\n" unless prompt.end_with?("### Response:\n")
+
+       prompt
+     end
+   end
+
+   # Configuration for LLM fine-tuning
+   class LLMConfiguration < Configuration
+     attr_accessor :max_length, :warmup_steps, :gradient_accumulation_steps,
+                   :max_grad_norm, :freeze_layers, :pad_token_id
+
+     def initialize
+       super
+       @max_length = 2048
+       @learning_rate = 2e-5
+       @batch_size = 4
+       @epochs = 1
+       @warmup_steps = 100
+       @gradient_accumulation_steps = 4
+       @max_grad_norm = 1.0
+       @freeze_layers = 0
+       @pad_token_id = nil
+     end
+   end
+ end
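
Read together, the class above gives a complete train → save → reload → serve loop. A minimal usage sketch follows; the model ID and file names are illustrative, and the JSONL line shows the conventional Alpaca-style record that the :alpaca format flag usually denotes (the schema InstructionDataset actually accepts lives in data/lib/fine/datasets/instruction_dataset.rb):

  # instructions.jsonl — one Alpaca-style record per line (illustrative):
  # {"instruction": "Summarize the text.", "input": "Ruby is ...", "output": "Ruby is a dynamic language."}

  require "fine"

  llm = Fine::LLM.new("meta-llama/Llama-3.2-1B") do |config|
    config.epochs = 2
    config.freeze_layers = 4   # keep embeddings and the bottom 4 decoder layers frozen
  end

  # Hold out 10% of the file for validation instead of supplying val_file:
  history = llm.fit_with_split(data_file: "instructions.jsonl", val_split: 0.1)
  llm.save("my_llama")

  # Reload later; chat renders the "### Instruction:" / "### Response:" turns for you
  llm = Fine::LLM.load("my_llama")
  puts llm.chat(
    [{ role: "system", content: "Answer briefly." },
     { role: "user", content: "What is Ruby?" }],
    max_new_tokens: 64
  )
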
data/lib/fine/models/base.rb ADDED
@@ -0,0 +1,48 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Models
+     # Base class for all Fine models
+     class Base < Torch::NN::Module
+       attr_reader :config
+
+       def initialize(config)
+         super()
+         @config = config
+       end
+
+       # Save model weights and config to a directory
+       #
+       # @param path [String] Directory path to save to
+       def save_pretrained(path)
+         FileUtils.mkdir_p(path)
+
+         # Save weights as safetensors
+         weights_path = File.join(path, "model.safetensors")
+         Safetensors::Torch.save_file(state_dict, weights_path)
+
+         # Save config
+         config_path = File.join(path, "config.json")
+         File.write(config_path, JSON.pretty_generate(@config.to_h))
+       end
+
+       # Freeze all parameters (for feature extraction)
+       def freeze!
+         parameters.each { |p| p.requires_grad = false }
+         self
+       end
+
+       # Unfreeze all parameters
+       def unfreeze!
+         parameters.each { |p| p.requires_grad = true }
+         self
+       end
+
+       # Get the number of parameters, optionally counting only trainable ones
+       def num_parameters(trainable_only: false)
+         params = trainable_only ? parameters.select(&:requires_grad) : parameters
+         params.sum { |p| p.numel }
+       end
+     end
+   end
+ end
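
Since every model inherits these helpers, freeze!/unfreeze! combine with num_parameters to verify what a training run will actually update. A small sketch against the interface above (TinyModel and its config object are hypothetical stand-ins):

  class TinyModel < Fine::Models::Base
    def initialize(config)
      super
      @proj = Torch::NN::Linear.new(16, 4)   # 16*4 weights + 4 biases = 68 params
    end
  end

  model = TinyModel.new(config)                          # config: any object responding to #to_h
  model.num_parameters                                   # => 68
  model.freeze!.num_parameters(trainable_only: true)     # => 0   (freeze! returns self)
  model.unfreeze!.num_parameters(trainable_only: true)   # => 68
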
data/lib/fine/models/bert_encoder.rb ADDED
@@ -0,0 +1,202 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Models
+     # BERT/DistilBERT/DeBERTa Encoder
+     #
+     # Implements a transformer encoder for text understanding tasks.
+     # Supports loading pretrained weights from HuggingFace Hub.
+     class BertEncoder < Base
+       attr_reader :embeddings, :encoder, :pooler
+
+       def initialize(config)
+         super(config)
+
+         @hidden_size = config.hidden_size
+         @num_layers = config.num_hidden_layers
+         @num_heads = config.num_attention_heads
+         @intermediate_size = config.intermediate_size
+         @vocab_size = config.vocab_size
+         @max_position_embeddings = config.max_position_embeddings
+         @type_vocab_size = config.type_vocab_size || 2
+         @layer_norm_eps = config.layer_norm_eps
+         @hidden_dropout_prob = config.hidden_dropout_prob || 0.1
+
+         # Embeddings
+         @word_embeddings = Torch::NN::Embedding.new(@vocab_size, @hidden_size)
+         @position_embeddings = Torch::NN::Embedding.new(@max_position_embeddings, @hidden_size)
+         @token_type_embeddings = Torch::NN::Embedding.new(@type_vocab_size, @hidden_size)
+         @embeddings_layer_norm = Torch::NN::LayerNorm.new(@hidden_size, eps: @layer_norm_eps)
+         @embeddings_dropout = Torch::NN::Dropout.new(p: @hidden_dropout_prob)
+
+         # Transformer layers
+         @layers = Torch::NN::ModuleList.new(
+           @num_layers.times.map do
+             BertLayer.new(
+               hidden_size: @hidden_size,
+               num_heads: @num_heads,
+               intermediate_size: @intermediate_size,
+               layer_norm_eps: @layer_norm_eps,
+               hidden_dropout_prob: @hidden_dropout_prob
+             )
+           end
+         )
+
+         # Pooler (for [CLS] token representation)
+         @pooler_dense = Torch::NN::Linear.new(@hidden_size, @hidden_size)
+         @pooler_activation = Torch::NN::Tanh.new
+       end
+
+       def forward(input_ids, attention_mask: nil, token_type_ids: nil)
+         batch_size, seq_length = input_ids.shape
+
+         # Create position IDs
+         position_ids = Torch.arange(seq_length, device: input_ids.device)
+         position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
+
+         # Default token type IDs to zeros
+         token_type_ids ||= Torch.zeros_like(input_ids)
+
+         # Embeddings
+         word_embeds = @word_embeddings.call(input_ids)
+         position_embeds = @position_embeddings.call(position_ids)
+         token_type_embeds = @token_type_embeddings.call(token_type_ids)
+
+         embeddings = word_embeds + position_embeds + token_type_embeds
+         embeddings = @embeddings_layer_norm.call(embeddings)
+         embeddings = @embeddings_dropout.call(embeddings)
+
+         # Create attention mask for transformer
+         # Convert from (batch, seq) to (batch, 1, 1, seq) for broadcasting
+         if attention_mask
+           extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+           extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+         else
+           extended_attention_mask = nil
+         end
+
+         # Transformer layers
+         hidden_states = embeddings
+         @layers.each do |layer|
+           hidden_states = layer.call(hidden_states, attention_mask: extended_attention_mask)
+         end
+
+         # Pool the [CLS] token (first token)
+         cls_output = hidden_states[0.., 0, 0..]
+         pooled_output = @pooler_dense.call(cls_output)
+         pooled_output = @pooler_activation.call(pooled_output)
+
+         {
+           last_hidden_state: hidden_states,
+           pooler_output: pooled_output
+         }
+       end
+
+       # Get the [CLS] token embedding (useful for classification)
+       def get_pooled_output(input_ids, attention_mask: nil, token_type_ids: nil)
+         output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
+         output[:pooler_output]
+       end
+
+       # Get mean of all token embeddings (useful for sentence embeddings)
+       def get_mean_output(input_ids, attention_mask: nil, token_type_ids: nil)
+         output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
+         hidden_states = output[:last_hidden_state]
+
+         if attention_mask
+           # Mask padding tokens before taking mean
+           mask = attention_mask.unsqueeze(-1).expand_as(hidden_states).float
+           sum_embeddings = (hidden_states * mask).sum(dim: 1)
+           sum_mask = mask.sum(dim: 1).clamp(min: 1e-9)
+           sum_embeddings / sum_mask
+         else
+           hidden_states.mean(dim: 1)
+         end
+       end
+     end
+
+     # Single BERT transformer layer
+     class BertLayer < Torch::NN::Module
+       def initialize(hidden_size:, num_heads:, intermediate_size:, layer_norm_eps:, hidden_dropout_prob:)
+         super()
+
+         @attention = BertAttention.new(
+           hidden_size: hidden_size,
+           num_heads: num_heads,
+           dropout: hidden_dropout_prob
+         )
+         @attention_layer_norm = Torch::NN::LayerNorm.new(hidden_size, eps: layer_norm_eps)
+         @attention_dropout = Torch::NN::Dropout.new(p: hidden_dropout_prob)
+
+         @intermediate = Torch::NN::Linear.new(hidden_size, intermediate_size)
+         @intermediate_act = Torch::NN::GELU.new
+         @output = Torch::NN::Linear.new(intermediate_size, hidden_size)
+         @output_layer_norm = Torch::NN::LayerNorm.new(hidden_size, eps: layer_norm_eps)
+         @output_dropout = Torch::NN::Dropout.new(p: hidden_dropout_prob)
+       end
+
+       def forward(hidden_states, attention_mask: nil)
+         # Self-attention with residual
+         attention_output = @attention.call(hidden_states, attention_mask: attention_mask)
+         attention_output = @attention_dropout.call(attention_output)
+         hidden_states = @attention_layer_norm.call(hidden_states + attention_output)
+
+         # FFN with residual
+         intermediate_output = @intermediate.call(hidden_states)
+         intermediate_output = @intermediate_act.call(intermediate_output)
+         layer_output = @output.call(intermediate_output)
+         layer_output = @output_dropout.call(layer_output)
+         @output_layer_norm.call(hidden_states + layer_output)
+       end
+     end
+
+     # BERT multi-head self-attention
+     class BertAttention < Torch::NN::Module
+       def initialize(hidden_size:, num_heads:, dropout:)
+         super()
+
+         @num_heads = num_heads
+         @head_dim = hidden_size / num_heads
+         @scale = @head_dim ** -0.5
+
+         @query = Torch::NN::Linear.new(hidden_size, hidden_size)
+         @key = Torch::NN::Linear.new(hidden_size, hidden_size)
+         @value = Torch::NN::Linear.new(hidden_size, hidden_size)
+         @out = Torch::NN::Linear.new(hidden_size, hidden_size)
+         @dropout = Torch::NN::Dropout.new(p: dropout)
+       end
+
+       def forward(hidden_states, attention_mask: nil)
+         batch_size, seq_length, _ = hidden_states.shape
+
+         # Project to Q, K, V
+         q = @query.call(hidden_states)
+         k = @key.call(hidden_states)
+         v = @value.call(hidden_states)
+
+         # Reshape for multi-head attention
+         q = q.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
+         k = k.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
+         v = v.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
+
+         # Attention scores
+         scores = Torch.matmul(q, k.transpose(-2, -1)) * @scale
+
+         # Apply attention mask
+         scores = scores + attention_mask if attention_mask
+
+         # Softmax and dropout
+         attn_probs = Torch::NN::Functional.softmax(scores, dim: -1)
+         attn_probs = @dropout.call(attn_probs)
+
+         # Apply attention to values
+         context = Torch.matmul(attn_probs, v)
+
+         # Reshape back
+         context = context.transpose(1, 2).contiguous.view(batch_size, seq_length, -1)
+
+         @out.call(context)
+       end
+     end
+   end
+ end
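
Two numeric conventions in this encoder deserve a note. The extended attention mask maps keep-positions (1) to 0 and padding (0) to -10000 before the softmax, so padded tokens receive near-zero attention weight; and get_mean_output divides by a mask sum clamped at 1e-9, so an all-padding row cannot divide by zero. A toy check of the masking arithmetic, using the same operations as the code above (values illustrative):

  mask = Torch.tensor([[1.0, 1.0, 0.0]])            # one sequence of 3 tokens, last is padding
  extended = (1.0 - mask) * -10000.0                # => [[0.0, 0.0, -10000.0]]
  scores = Torch.tensor([[2.0, 1.0, 3.0]]) + extended
  Torch::NN::Functional.softmax(scores, dim: -1)    # ≈ [[0.73, 0.27, 0.00]] — padding ignored
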