fine 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0
data/lib/fine/llm.rb
ADDED
@@ -0,0 +1,336 @@
# frozen_string_literal: true

module Fine
  # High-level API for LLM fine-tuning
  #
  # @example Basic fine-tuning
  #   llm = Fine::LLM.new("meta-llama/Llama-3.2-1B")
  #   llm.fit(train_file: "instructions.jsonl", epochs: 3)
  #   llm.save("my_llama")
  #
  # @example Generation
  #   llm = Fine::LLM.load("my_llama")
  #   response = llm.generate("What is Ruby?", max_new_tokens: 100)
  #
  # @example With configuration
  #   llm = Fine::LLM.new("google/gemma-2b") do |config|
  #     config.epochs = 3
  #     config.batch_size = 4
  #     config.learning_rate = 1e-5
  #     config.max_length = 1024
  #   end
  #
  class LLM
    attr_reader :model, :config, :tokenizer, :model_id

    # Create a new LLM for fine-tuning
    #
    # @param model_id [String] HuggingFace model ID
    # @yield [config] Optional configuration block
    def initialize(model_id, &block)
      @model_id = model_id
      @config = LLMConfiguration.new
      @model = nil
      @tokenizer = nil
      @trained = false

      block&.call(@config)

      if @config.callbacks.empty? && Fine.configuration&.progress_bar != false
        @config.callbacks << Callbacks::ProgressBar.new
      end
    end

    # Load a fine-tuned LLM from disk
    #
    # @param path [String] Path to saved model
    # @return [LLM]
    def self.load(path)
      config_path = File.join(path, "config.json")
      raise ModelNotFoundError.new(path) unless File.exist?(config_path)

      config_data = JSON.parse(File.read(config_path))

      llm = allocate
      llm.instance_variable_set(:@model_id, config_data["_model_id"] || "custom")
      llm.instance_variable_set(:@config, LLMConfiguration.new)
      llm.instance_variable_set(:@trained, true)

      # Load tokenizer
      tokenizer_path = File.join(path, "tokenizer.json")
      tokenizer = if File.exist?(tokenizer_path)
        Tokenizers::AutoTokenizer.new(path, max_length: config_data["max_length"] || 2048)
      else
        Tokenizers::AutoTokenizer.from_pretrained(
          config_data["_model_id"],
          max_length: config_data["max_length"] || 2048
        )
      end
      llm.instance_variable_set(:@tokenizer, tokenizer)

      # Load model
      llm.instance_variable_set(
        :@model,
        Models::CausalLM.load(path)
      )

      llm
    end

    # Fine-tune on instruction data
    #
    # @param train_file [String] Path to training data (JSONL)
    # @param val_file [String, nil] Path to validation data
    # @param format [Symbol] Data format (:alpaca, :sharegpt, :simple, :auto)
    # @param epochs [Integer, nil] Override config epochs
    # @return [Array<Hash>] Training history
    def fit(train_file:, val_file: nil, format: :auto, epochs: nil)
      @config.epochs = epochs if epochs

      # Download model files first (including tokenizer)
      downloader = Hub::ModelDownloader.new(@model_id)
      model_path = downloader.download

      # Load tokenizer from cache or HuggingFace
      tokenizer_path = File.join(model_path, "tokenizer.json")
      @tokenizer = if File.exist?(tokenizer_path)
        Tokenizers::AutoTokenizer.new(model_path, max_length: @config.max_length)
      else
        Tokenizers::AutoTokenizer.from_pretrained(@model_id, max_length: @config.max_length)
      end

      # Set pad token if not set
      @config.pad_token_id ||= @tokenizer.pad_token_id || @tokenizer.eos_token_id || 0

      # Load datasets
      train_dataset = Datasets::InstructionDataset.from_jsonl(
        train_file,
        tokenizer: @tokenizer,
        format: format,
        max_length: @config.max_length
      )

      val_dataset = if val_file
        Datasets::InstructionDataset.from_jsonl(
          val_file,
          tokenizer: @tokenizer,
          format: format,
          max_length: @config.max_length
        )
      end

      # Load model
      @model = Models::CausalLM.from_pretrained(@model_id)

      # Freeze layers if configured
      if @config.freeze_layers && @config.freeze_layers > 0
        freeze_bottom_layers(@config.freeze_layers)
      end

      # Train
      trainer = Training::LLMTrainer.new(
        @model,
        @config,
        train_dataset: train_dataset,
        val_dataset: val_dataset
      )

      history = trainer.fit
      @trained = true

      history
    end

    # Fine-tune with automatic train/val split
    #
    # @param data_file [String] Path to data file
    # @param val_split [Float] Fraction for validation
    # @param format [Symbol] Data format
    # @return [Array<Hash>] Training history
    def fit_with_split(data_file:, val_split: 0.1, format: :auto, epochs: nil)
      @config.epochs = epochs if epochs

      # Download model files first (including tokenizer)
      downloader = Hub::ModelDownloader.new(@model_id)
      model_path = downloader.download

      # Load tokenizer from cache or HuggingFace
      tokenizer_path = File.join(model_path, "tokenizer.json")
      @tokenizer = if File.exist?(tokenizer_path)
        Tokenizers::AutoTokenizer.new(model_path, max_length: @config.max_length)
      else
        Tokenizers::AutoTokenizer.from_pretrained(@model_id, max_length: @config.max_length)
      end

      @config.pad_token_id ||= @tokenizer.pad_token_id || @tokenizer.eos_token_id || 0

      full_dataset = Datasets::InstructionDataset.from_jsonl(
        data_file,
        tokenizer: @tokenizer,
        format: format,
        max_length: @config.max_length
      )

      train_dataset, val_dataset = full_dataset.split(test_size: val_split)

      @model = Models::CausalLM.from_pretrained(@model_id)

      if @config.freeze_layers && @config.freeze_layers > 0
        freeze_bottom_layers(@config.freeze_layers)
      end

      trainer = Training::LLMTrainer.new(
        @model,
        @config,
        train_dataset: train_dataset,
        val_dataset: val_dataset
      )

      history = trainer.fit
      @trained = true

      history
    end

    # Generate text
    #
    # @param prompt [String] Input prompt
    # @param max_new_tokens [Integer] Maximum tokens to generate
    # @param temperature [Float] Sampling temperature (higher = more random)
    # @param top_p [Float] Nucleus sampling threshold
    # @param top_k [Integer] Top-k sampling
    # @param do_sample [Boolean] Whether to sample (false = greedy)
    # @return [String] Generated text
    def generate(prompt, max_new_tokens: 100, temperature: 0.7, top_p: 0.9, top_k: 50, do_sample: true)
      raise TrainingError, "Model not loaded" unless @model && @tokenizer

      # Tokenize prompt (without tensors for easier manipulation)
      encoding = @tokenizer.encode(prompt, return_tensors: false)
      input_ids = Torch.tensor([encoding[:input_ids].first])

      # Move to device
      input_ids = input_ids.to(Fine.device)
      @model.to(Fine.device)

      # Generate
      output_ids = @model.generate(
        input_ids,
        max_new_tokens: max_new_tokens,
        temperature: temperature,
        top_p: top_p,
        top_k: top_k,
        do_sample: do_sample,
        eos_token_id: @tokenizer.eos_token_id,
        pad_token_id: @config.pad_token_id
      )

      # Decode
      @tokenizer.decode(output_ids[0].to_a)
    end

    # Chat-style generation
    #
    # @param messages [Array<Hash>] Messages with :role and :content
    # @param kwargs [Hash] Generation parameters
    # @return [String] Assistant response
    def chat(messages, **kwargs)
      prompt = format_chat_prompt(messages)
      full_response = generate(prompt, **kwargs)

      # Extract just the assistant response
      if full_response.include?("### Response:")
        full_response.split("### Response:").last.strip
      else
        # Remove the prompt from the response
        full_response[prompt.length..].strip
      end
    end

    # Save the model
    #
    # @param path [String] Directory to save to
    def save(path)
      raise TrainingError, "Model not trained or loaded" unless @model

      @model.save(path)
      @tokenizer.save(path)

      # Update config with model ID
      config_path = File.join(path, "config.json")
      config = JSON.parse(File.read(config_path))
      config["_model_id"] = @model_id
      config["max_length"] = @config.max_length
      File.write(config_path, JSON.pretty_generate(config))
    end

    # Export to GGUF format for llama.cpp, ollama, etc.
    #
    # @param path [String] Output path for GGUF file
    # @param quantization [Symbol] Quantization type (:f16, :q4_0, :q8_0, etc.)
    # @param metadata [Hash] Additional metadata
    # @return [String] The output path
    def export_gguf(path, quantization: :f16, **options)
      Export.to_gguf(self, path, quantization: quantization, **options)
    end

    # Export to ONNX format
    #
    # @param path [String] Output path for ONNX file
    # @param options [Hash] Export options
    # @return [String] The output path
    def export_onnx(path, **options)
      Export.to_onnx(self, path, **options)
    end

    private

    def freeze_bottom_layers(num_layers)
      # Freeze embedding
      @model.decoder.embed_tokens.parameters.each { |p| p.requires_grad = false }

      # Freeze bottom N layers
      @model.decoder.layers[0...num_layers].each do |layer|
        layer.parameters.each { |p| p.requires_grad = false }
      end
    end

    def format_chat_prompt(messages)
      prompt = ""

      messages.each do |msg|
        case msg[:role]
        when "system"
          prompt += "### System:\n#{msg[:content]}\n\n"
        when "user"
          prompt += "### Instruction:\n#{msg[:content]}\n\n"
        when "assistant"
          prompt += "### Response:\n#{msg[:content]}\n\n"
        end
      end

      # Add response prefix for the model to continue
      prompt += "### Response:\n" unless prompt.end_with?("### Response:\n")

      prompt
    end
  end

  # Configuration for LLM fine-tuning
  class LLMConfiguration < Configuration
    attr_accessor :max_length, :warmup_steps, :gradient_accumulation_steps,
                  :max_grad_norm, :freeze_layers, :pad_token_id

    def initialize
      super
      @max_length = 2048
      @learning_rate = 2e-5
      @batch_size = 4
      @epochs = 1
      @warmup_steps = 100
      @gradient_accumulation_steps = 4
      @max_grad_norm = 1.0
      @freeze_layers = 0
      @pad_token_id = nil
    end
  end
end
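For orientation, a minimal end-to-end sketch of the API this file defines. The method calls (new with a config block, fit, chat, save, export_gguf) all appear above; the data file name and its Alpaca-style JSONL fields are illustrative assumptions, not part of the gem's documentation.

# Hypothetical usage sketch (assumed data file and field names)
require "fine"

llm = Fine::LLM.new("meta-llama/Llama-3.2-1B") do |config|
  config.epochs = 1
  config.batch_size = 2
  config.freeze_layers = 8 # freeze embeddings plus the bottom 8 decoder layers
end

# instructions.jsonl: one JSON object per line, e.g. (Alpaca-style, assumed):
#   {"instruction": "What is Ruby?", "output": "Ruby is a dynamic, object-oriented language."}
llm.fit(train_file: "instructions.jsonl", format: :alpaca)

# Roles are strings, matching the case/when in format_chat_prompt
puts llm.chat([{ role: "user", content: "What is Ruby?" }])

llm.save("my_llama")
llm.export_gguf("my_llama.gguf", quantization: :q8_0)

Note that chat wraps messages in the "### Instruction:" / "### Response:" template and strips everything up to the last "### Response:", so it works best when the training data was formatted the same way.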
data/lib/fine/models/base.rb
ADDED
@@ -0,0 +1,48 @@
# frozen_string_literal: true

module Fine
  module Models
    # Base class for all Fine models
    class Base < Torch::NN::Module
      attr_reader :config

      def initialize(config)
        super()
        @config = config
      end

      # Save model weights and config to a directory
      #
      # @param path [String] Directory path to save to
      def save_pretrained(path)
        FileUtils.mkdir_p(path)

        # Save weights as safetensors
        weights_path = File.join(path, "model.safetensors")
        Safetensors::Torch.save_file(state_dict, weights_path)

        # Save config
        config_path = File.join(path, "config.json")
        File.write(config_path, JSON.pretty_generate(@config.to_h))
      end

      # Freeze all parameters (for feature extraction)
      def freeze!
        parameters.each { |p| p.requires_grad = false }
        self
      end

      # Unfreeze all parameters
      def unfreeze!
        parameters.each { |p| p.requires_grad = true }
        self
      end

      # Get the number of parameters (all, or only trainable)
      def num_parameters(trainable_only: false)
        params = trainable_only ? parameters.select(&:requires_grad) : parameters
        params.sum { |p| p.numel }
      end
    end
  end
end
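A short sketch of how the freezing helpers compose, assuming concrete models such as Models::CausalLM inherit them from this base class (the checkpoint path is illustrative):

# Hypothetical sketch: any Fine::Models::Base subclass gets these helpers
model = Fine::Models::CausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

total = model.num_parameters
model.freeze!                               # requires_grad = false everywhere
model.num_parameters(trainable_only: true)  # => 0
model.unfreeze!
model.num_parameters(trainable_only: true)  # => total again

model.save_pretrained("checkpoints/base")   # writes model.safetensors + config.json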
data/lib/fine/models/bert_encoder.rb
ADDED
@@ -0,0 +1,202 @@
# frozen_string_literal: true

module Fine
  module Models
    # BERT/DistilBERT/DeBERTa Encoder
    #
    # Implements a transformer encoder for text understanding tasks.
    # Supports loading pretrained weights from HuggingFace Hub.
    class BertEncoder < Base
      attr_reader :embeddings, :encoder, :pooler

      def initialize(config)
        super(config)

        @hidden_size = config.hidden_size
        @num_layers = config.num_hidden_layers
        @num_heads = config.num_attention_heads
        @intermediate_size = config.intermediate_size
        @vocab_size = config.vocab_size
        @max_position_embeddings = config.max_position_embeddings
        @type_vocab_size = config.type_vocab_size || 2
        @layer_norm_eps = config.layer_norm_eps
        @hidden_dropout_prob = config.hidden_dropout_prob || 0.1

        # Embeddings
        @word_embeddings = Torch::NN::Embedding.new(@vocab_size, @hidden_size)
        @position_embeddings = Torch::NN::Embedding.new(@max_position_embeddings, @hidden_size)
        @token_type_embeddings = Torch::NN::Embedding.new(@type_vocab_size, @hidden_size)
        @embeddings_layer_norm = Torch::NN::LayerNorm.new(@hidden_size, eps: @layer_norm_eps)
        @embeddings_dropout = Torch::NN::Dropout.new(p: @hidden_dropout_prob)

        # Transformer layers
        @layers = Torch::NN::ModuleList.new(
          @num_layers.times.map do
            BertLayer.new(
              hidden_size: @hidden_size,
              num_heads: @num_heads,
              intermediate_size: @intermediate_size,
              layer_norm_eps: @layer_norm_eps,
              hidden_dropout_prob: @hidden_dropout_prob
            )
          end
        )

        # Pooler (for [CLS] token representation)
        @pooler_dense = Torch::NN::Linear.new(@hidden_size, @hidden_size)
        @pooler_activation = Torch::NN::Tanh.new
      end

      def forward(input_ids, attention_mask: nil, token_type_ids: nil)
        batch_size, seq_length = input_ids.shape

        # Create position IDs
        position_ids = Torch.arange(seq_length, device: input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Default token type IDs to zeros
        token_type_ids ||= Torch.zeros_like(input_ids)

        # Embeddings
        word_embeds = @word_embeddings.call(input_ids)
        position_embeds = @position_embeddings.call(position_ids)
        token_type_embeds = @token_type_embeddings.call(token_type_ids)

        embeddings = word_embeds + position_embeds + token_type_embeds
        embeddings = @embeddings_layer_norm.call(embeddings)
        embeddings = @embeddings_dropout.call(embeddings)

        # Create attention mask for transformer:
        # convert from (batch, seq) to (batch, 1, 1, seq) for broadcasting
        if attention_mask
          extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
          extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        else
          extended_attention_mask = nil
        end

        # Transformer layers
        hidden_states = embeddings
        @layers.each do |layer|
          hidden_states = layer.call(hidden_states, attention_mask: extended_attention_mask)
        end

        # Pool the [CLS] token (first token)
        cls_output = hidden_states[0.., 0, 0..]
        pooled_output = @pooler_dense.call(cls_output)
        pooled_output = @pooler_activation.call(pooled_output)

        {
          last_hidden_state: hidden_states,
          pooler_output: pooled_output
        }
      end

      # Get the [CLS] token embedding (useful for classification)
      def get_pooled_output(input_ids, attention_mask: nil, token_type_ids: nil)
        output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
        output[:pooler_output]
      end

      # Get mean of all token embeddings (useful for sentence embeddings)
      def get_mean_output(input_ids, attention_mask: nil, token_type_ids: nil)
        output = forward(input_ids, attention_mask: attention_mask, token_type_ids: token_type_ids)
        hidden_states = output[:last_hidden_state]

        if attention_mask
          # Mask padding tokens before taking the mean
          mask = attention_mask.unsqueeze(-1).expand_as(hidden_states).float
          sum_embeddings = (hidden_states * mask).sum(dim: 1)
          sum_mask = mask.sum(dim: 1).clamp(min: 1e-9)
          sum_embeddings / sum_mask
        else
          hidden_states.mean(dim: 1)
        end
      end
    end

    # Single BERT transformer layer
    class BertLayer < Torch::NN::Module
      def initialize(hidden_size:, num_heads:, intermediate_size:, layer_norm_eps:, hidden_dropout_prob:)
        super()

        @attention = BertAttention.new(
          hidden_size: hidden_size,
          num_heads: num_heads,
          dropout: hidden_dropout_prob
        )
        @attention_layer_norm = Torch::NN::LayerNorm.new(hidden_size, eps: layer_norm_eps)
        @attention_dropout = Torch::NN::Dropout.new(p: hidden_dropout_prob)

        @intermediate = Torch::NN::Linear.new(hidden_size, intermediate_size)
        @intermediate_act = Torch::NN::GELU.new
        @output = Torch::NN::Linear.new(intermediate_size, hidden_size)
        @output_layer_norm = Torch::NN::LayerNorm.new(hidden_size, eps: layer_norm_eps)
        @output_dropout = Torch::NN::Dropout.new(p: hidden_dropout_prob)
      end

      def forward(hidden_states, attention_mask: nil)
        # Self-attention with residual
        attention_output = @attention.call(hidden_states, attention_mask: attention_mask)
        attention_output = @attention_dropout.call(attention_output)
        hidden_states = @attention_layer_norm.call(hidden_states + attention_output)

        # FFN with residual
        intermediate_output = @intermediate.call(hidden_states)
        intermediate_output = @intermediate_act.call(intermediate_output)
        layer_output = @output.call(intermediate_output)
        layer_output = @output_dropout.call(layer_output)
        @output_layer_norm.call(hidden_states + layer_output)
      end
    end

    # BERT multi-head self-attention
    class BertAttention < Torch::NN::Module
      def initialize(hidden_size:, num_heads:, dropout:)
        super()

        @num_heads = num_heads
        @head_dim = hidden_size / num_heads
        @scale = @head_dim ** -0.5

        @query = Torch::NN::Linear.new(hidden_size, hidden_size)
        @key = Torch::NN::Linear.new(hidden_size, hidden_size)
        @value = Torch::NN::Linear.new(hidden_size, hidden_size)
        @out = Torch::NN::Linear.new(hidden_size, hidden_size)
        @dropout = Torch::NN::Dropout.new(p: dropout)
      end

      def forward(hidden_states, attention_mask: nil)
        batch_size, seq_length, _ = hidden_states.shape

        # Project to Q, K, V
        q = @query.call(hidden_states)
        k = @key.call(hidden_states)
        v = @value.call(hidden_states)

        # Reshape for multi-head attention
        q = q.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_length, @num_heads, @head_dim).transpose(1, 2)

        # Attention scores
        scores = Torch.matmul(q, k.transpose(-2, -1)) * @scale

        # Apply attention mask
        scores = scores + attention_mask if attention_mask

        # Softmax and dropout
        attn_probs = Torch::NN::Functional.softmax(scores, dim: -1)
        attn_probs = @dropout.call(attn_probs)

        # Apply attention to values
        context = Torch.matmul(attn_probs, v)

        # Reshape back
        context = context.transpose(1, 2).contiguous.view(batch_size, seq_length, -1)

        @out.call(context)
      end
    end
  end
end
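To make the masking arithmetic in BertEncoder#forward concrete, here is a worked example of the additive attention mask; the tensor values are illustrative.

# Worked example of the (batch, seq) -> (batch, 1, 1, seq) additive mask used above
require "torch"

attention_mask = Torch.tensor([[1.0, 1.0, 1.0, 0.0]]) # last position is padding
extended = attention_mask.unsqueeze(1).unsqueeze(2)   # shape: [1, 1, 1, 4]
additive = (1.0 - extended) * -10000.0                # => [[[[0, 0, 0, -10000]]]]
# Added to the raw attention scores, real tokens are left unchanged while padded
# positions get -10000, which softmax maps to a (near-)zero attention weight.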