fine 0.1.0
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0
data/lib/fine/text_classifier.rb
@@ -0,0 +1,250 @@

```ruby
# frozen_string_literal: true

module Fine
  # High-level API for text classification
  #
  # @example Basic usage
  #   classifier = Fine::TextClassifier.new("distilbert-base-uncased")
  #   classifier.fit(train_file: "reviews.jsonl", epochs: 3)
  #   classifier.predict("This product is amazing!")
  #
  # @example With configuration
  #   classifier = Fine::TextClassifier.new("microsoft/deberta-v3-small") do |config|
  #     config.epochs = 5
  #     config.batch_size = 16
  #     config.learning_rate = 2e-5
  #   end
  #
  class TextClassifier
    attr_reader :model, :config, :tokenizer, :label_map, :model_id

    # Create a new TextClassifier
    #
    # @param model_id [String] HuggingFace model ID
    # @yield [config] Optional configuration block
    def initialize(model_id, &block)
      @model_id = model_id
      @config = TextConfiguration.new
      @model = nil
      @tokenizer = nil
      @label_map = nil
      @trained = false

      block&.call(@config)

      if @config.callbacks.empty? && Fine.configuration&.progress_bar != false
        @config.callbacks << Callbacks::ProgressBar.new
      end
    end

    # Load a fine-tuned classifier from disk
    #
    # @param path [String] Path to saved model directory
    # @return [TextClassifier]
    def self.load(path)
      config_path = File.join(path, "config.json")
      raise ModelNotFoundError.new(path) unless File.exist?(config_path)

      config_data = JSON.parse(File.read(config_path))

      classifier = allocate
      classifier.instance_variable_set(:@model_id, config_data["_model_id"] || "custom")
      classifier.instance_variable_set(:@config, TextConfiguration.new)
      classifier.instance_variable_set(:@trained, true)

      # Load label map
      if config_data["label2id"]
        classifier.instance_variable_set(:@label_map, config_data["label2id"])
      elsif config_data["id2label"]
        label_map = config_data["id2label"].transform_keys(&:to_i).invert
        classifier.instance_variable_set(:@label_map, label_map)
      end

      # Load tokenizer
      tokenizer_path = File.join(path, "tokenizer.json")
      tokenizer = if File.exist?(tokenizer_path)
        Tokenizers::AutoTokenizer.new(path, max_length: config_data["max_length"] || 512)
      else
        Tokenizers::AutoTokenizer.from_pretrained(
          config_data["_model_id"] || "distilbert-base-uncased",
          max_length: config_data["max_length"] || 512
        )
      end
      classifier.instance_variable_set(:@tokenizer, tokenizer)

      # Load model
      classifier.instance_variable_set(
        :@model,
        Models::BertForSequenceClassification.load(path)
      )

      classifier
    end

    # Fine-tune on a dataset
    #
    # @param train_file [String] Path to training data (JSONL or CSV)
    # @param val_file [String, nil] Path to validation data
    # @param epochs [Integer, nil] Override config epochs
    # @return [Array<Hash>] Training history
    def fit(train_file:, val_file: nil, epochs: nil)
      @config.epochs = epochs if epochs

      # Load tokenizer
      @tokenizer = Tokenizers::AutoTokenizer.from_pretrained(
        @model_id,
        max_length: @config.max_length
      )

      # Load datasets
      train_dataset = Datasets::TextDataset.from_file(train_file, tokenizer: @tokenizer)

      val_dataset = if val_file
        Datasets::TextDataset.from_file(val_file, tokenizer: @tokenizer, label_map: train_dataset.label_map)
      end

      @label_map = train_dataset.label_map
      num_classes = train_dataset.num_classes

      # Load model
      @model = Models::BertForSequenceClassification.from_pretrained(
        @model_id,
        num_labels: num_classes,
        dropout: @config.dropout
      )

      # Train
      trainer = Training::TextTrainer.new(
        @model,
        @config,
        train_dataset: train_dataset,
        val_dataset: val_dataset
      )

      history = trainer.fit
      @trained = true

      history
    end

    # Fine-tune with automatic train/val split
    #
    # @param data_file [String] Path to data file
    # @param val_split [Float] Fraction for validation
    # @return [Array<Hash>] Training history
    def fit_with_split(data_file:, val_split: 0.2, epochs: nil)
      @config.epochs = epochs if epochs

      @tokenizer = Tokenizers::AutoTokenizer.from_pretrained(
        @model_id,
        max_length: @config.max_length
      )

      full_dataset = Datasets::TextDataset.from_file(data_file, tokenizer: @tokenizer)
      train_dataset, val_dataset = full_dataset.split(test_size: val_split)

      @label_map = train_dataset.label_map

      @model = Models::BertForSequenceClassification.from_pretrained(
        @model_id,
        num_labels: train_dataset.num_classes,
        dropout: @config.dropout
      )

      trainer = Training::TextTrainer.new(
        @model,
        @config,
        train_dataset: train_dataset,
        val_dataset: val_dataset
      )

      history = trainer.fit
      @trained = true

      history
    end

    # Make predictions
    #
    # @param texts [String, Array<String>] Text(s) to classify
    # @param top_k [Integer] Number of top predictions to return
    # @return [Array<Array<Hash>>] Predictions with :label and :score
    def predict(texts, top_k: 5)
      raise TrainingError, "Model not trained or loaded" unless @trained && @model

      texts = [texts] if texts.is_a?(String)

      # Tokenize
      encoding = @tokenizer.encode(texts)

      # Get predictions
      @model.eval
      probs = @model.predict_proba(
        encoding[:input_ids],
        attention_mask: encoding[:attention_mask],
        token_type_ids: encoding[:token_type_ids]
      )

      # Convert to result format
      inverse_label_map = @label_map.invert

      probs.to_a.map do |sample_probs|
        sorted = sample_probs.each_with_index.sort_by { |prob, _| -prob }
        top = sorted.first([top_k, @label_map.size].min)

        top.map do |prob, idx|
          {
            label: inverse_label_map[idx] || idx.to_s,
            score: prob.round(4)
          }
        end
      end
    end

    # Save the model
    #
    # @param path [String] Directory to save to
    def save(path)
      raise TrainingError, "Model not trained" unless @trained && @model

      @model.save(path, label_map: @label_map)
      @tokenizer.save(path)

      # Update config with model ID and max_length
      config_path = File.join(path, "config.json")
      config = JSON.parse(File.read(config_path))
      config["_model_id"] = @model_id
      config["max_length"] = @config.max_length
      File.write(config_path, JSON.pretty_generate(config))
    end

    # Get class names
    def class_names
      return [] unless @label_map

      @label_map.sort_by { |_, v| v }.map(&:first)
    end

    # Export to ONNX format
    #
    # @param path [String] Output path for ONNX file
    # @param options [Hash] Export options
    # @return [String] The output path
    def export_onnx(path, **options)
      Export.to_onnx(self, path, **options)
    end
  end

  # Configuration for text models
  class TextConfiguration < Configuration
    attr_accessor :max_length, :warmup_ratio

    def initialize
      super
      @max_length = 256
      @warmup_ratio = 0.1
      @learning_rate = 2e-5 # Lower default for text models
      @batch_size = 16 # Smaller default for text
    end
  end
end
```
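For orientation, here is a minimal usage sketch of the class above (not part of the diff). It assumes the gem is loaded with `require "fine"` and that `reviews.jsonl` uses whatever text/label columns `Datasets::TextDataset.from_file` expects, which this excerpt does not show; the labels and scores in the comments are illustrative.

```ruby
require "fine"

# Configure and fine-tune a classifier. The JSONL schema is an assumption;
# see data/lib/fine/datasets/text_dataset.rb for the columns it actually accepts.
classifier = Fine::TextClassifier.new("distilbert-base-uncased") do |config|
  config.epochs = 3
  config.batch_size = 16
  config.max_length = 128
end

history = classifier.fit(train_file: "reviews.jsonl", val_file: "reviews_val.jsonl")
puts history.last # metrics hash for the final epoch

# Top-2 predictions for a single text (scores illustrative)
classifier.predict("Great battery life, terrible screen", top_k: 2)
# => [[{ label: "positive", score: 0.61 }, { label: "negative", score: 0.39 }]]

# Persist, then reload later without retraining
classifier.save("tmp/review-classifier")
reloaded = Fine::TextClassifier.load("tmp/review-classifier")
reloaded.class_names # labels ordered by their ids in the label map
```

Note that `save` writes `_model_id` and `max_length` into `config.json`, and `self.load` reads those same keys back, so the round trip works without re-specifying the base model.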
data/lib/fine/text_embedder.rb
@@ -0,0 +1,221 @@

```ruby
# frozen_string_literal: true

module Fine
  # High-level API for text embeddings
  #
  # @example Basic usage
  #   embedder = Fine::TextEmbedder.new("sentence-transformers/all-MiniLM-L6-v2")
  #   embedder.fit(train_file: "pairs.jsonl", epochs: 3)
  #   embedding = embedder.encode("Hello world")
  #
  # @example Without fine-tuning (use pretrained directly)
  #   embedder = Fine::TextEmbedder.new("sentence-transformers/all-MiniLM-L6-v2")
  #   embedding = embedder.encode("Hello world")
  #
  class TextEmbedder
    attr_reader :model, :config, :tokenizer, :model_id

    # Create a new TextEmbedder
    #
    # @param model_id [String] HuggingFace model ID
    # @yield [config] Optional configuration block
    def initialize(model_id, &block)
      @model_id = model_id
      @config = EmbeddingConfiguration.new
      @model = nil
      @tokenizer = nil
      @trained = false

      block&.call(@config)

      # Load tokenizer immediately for encoding
      @tokenizer = Tokenizers::AutoTokenizer.from_pretrained(
        model_id,
        max_length: @config.max_length
      )

      # Load pretrained model for immediate use
      @model = Models::SentenceTransformer.from_pretrained(
        model_id,
        pooling_mode: @config.pooling_mode
      )
      @trained = true # Pretrained is ready to use
    end

    # Load a fine-tuned embedder from disk
    #
    # @param path [String] Path to saved model
    # @return [TextEmbedder]
    def self.load(path)
      config_path = File.join(path, "config.json")
      raise ModelNotFoundError.new(path) unless File.exist?(config_path)

      config_data = JSON.parse(File.read(config_path))

      embedder = allocate
      embedder.instance_variable_set(:@model_id, config_data["_model_id"] || "custom")
      embedder.instance_variable_set(:@config, EmbeddingConfiguration.new)
      embedder.instance_variable_set(:@trained, true)

      # Load tokenizer
      tokenizer_path = File.join(path, "tokenizer.json")
      tokenizer = if File.exist?(tokenizer_path)
        Tokenizers::AutoTokenizer.new(path, max_length: config_data["max_length"] || 512)
      else
        Tokenizers::AutoTokenizer.from_pretrained(
          config_data["_model_id"],
          max_length: config_data["max_length"] || 512
        )
      end
      embedder.instance_variable_set(:@tokenizer, tokenizer)

      # Load model
      embedder.instance_variable_set(
        :@model,
        Models::SentenceTransformer.load(path)
      )

      embedder
    end

    # Fine-tune on pairs/triplets data
    #
    # @param train_file [String] Path to training data (JSONL)
    # @param epochs [Integer, nil] Override config epochs
    # @return [Array<Hash>] Training history
    def fit(train_file:, epochs: nil)
      @config.epochs = epochs if epochs

      # Load dataset
      train_dataset = Datasets::TextPairDataset.from_jsonl(
        train_file,
        tokenizer: @tokenizer,
        text_a_column: "query",
        text_b_column: "positive"
      )

      # Add progress bar callback
      if @config.callbacks.empty? && Fine.configuration&.progress_bar != false
        @config.callbacks << Callbacks::ProgressBar.new
      end

      # Train
      trainer = Training::EmbeddingTrainer.new(
        @model,
        @config,
        train_dataset: train_dataset
      )

      trainer.fit
    end

    # Encode text(s) to embeddings
    #
    # @param texts [String, Array<String>] Text(s) to encode
    # @return [Array<Float>, Array<Array<Float>>] Embedding(s)
    def encode(texts)
      raise TrainingError, "Model not loaded" unless @model

      single_input = texts.is_a?(String)
      texts = [texts] if single_input

      # Tokenize
      encoding = @tokenizer.encode(texts)

      # Get embeddings
      embeddings = @model.encode(
        encoding[:input_ids],
        attention_mask: encoding[:attention_mask]
      )

      # Convert to Ruby arrays
      result = embeddings.to_a

      single_input ? result.first : result
    end

    # Compute similarity between two texts
    #
    # @param text_a [String] First text
    # @param text_b [String] Second text
    # @return [Float] Cosine similarity score
    def similarity(text_a, text_b)
      emb_a = encode(text_a)
      emb_b = encode(text_b)

      cosine_similarity(emb_a, emb_b)
    end

    # Find most similar texts from a corpus
    #
    # @param query [String] Query text
    # @param corpus [Array<String>] Corpus to search
    # @param top_k [Integer] Number of results
    # @return [Array<Hash>] Results with :text, :score, :index
    def search(query, corpus, top_k: 5)
      query_emb = encode(query)
      corpus_embs = encode(corpus)

      scores = corpus_embs.map.with_index do |emb, idx|
        { text: corpus[idx], score: cosine_similarity(query_emb, emb), index: idx }
      end

      scores.sort_by { |s| -s[:score] }.first(top_k)
    end

    # Save the model
    #
    # @param path [String] Directory to save to
    def save(path)
      raise TrainingError, "Model not loaded" unless @model

      @model.save(path)
      @tokenizer.save(path)

      # Update config
      config_path = File.join(path, "config.json")
      config = JSON.parse(File.read(config_path))
      config["_model_id"] = @model_id
      config["max_length"] = @config.max_length
      File.write(config_path, JSON.pretty_generate(config))
    end

    # Get embedding dimension
    def embedding_dim
      @model.config.hidden_size
    end

    # Export to ONNX format
    #
    # @param path [String] Output path for ONNX file
    # @param options [Hash] Export options
    # @return [String] The output path
    def export_onnx(path, **options)
      Export.to_onnx(self, path, **options)
    end

    private

    def cosine_similarity(a, b)
      dot = a.zip(b).sum { |x, y| x * y }
      norm_a = Math.sqrt(a.sum { |x| x * x })
      norm_b = Math.sqrt(b.sum { |x| x * x })
      dot / (norm_a * norm_b)
    end
  end

  # Configuration for embedding models
  class EmbeddingConfiguration < Configuration
    attr_accessor :max_length, :pooling_mode, :loss

    def initialize
      super
      @max_length = 256
      @pooling_mode = :mean
      @loss = :multiple_negatives_ranking
      @learning_rate = 2e-5
      @batch_size = 32
      @epochs = 1
    end
  end
end
```
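Likewise, a short usage sketch of the embedder API above (not part of the diff). The pretrained checkpoint is loaded in the constructor, so `encode`, `similarity`, and `search` work without any fine-tuning; `fit` expects a JSONL file with `query` and `positive` columns, matching the `TextPairDataset.from_jsonl` call in the source. Scores in the comments are illustrative.

```ruby
require "fine"

embedder = Fine::TextEmbedder.new("sentence-transformers/all-MiniLM-L6-v2")

embedder.embedding_dim # hidden size of the base model, e.g. 384 for MiniLM-L6
vec = embedder.encode("How do I reset my password?")
vec.length # == embedder.embedding_dim

# Cosine similarity between two sentences (computed by the private helper above)
embedder.similarity("reset my password", "recover account access") # => Float in [-1.0, 1.0]

# Semantic search over a small corpus
docs = [
  "Resetting your password from the login page",
  "Updating billing information",
  "Troubleshooting two-factor authentication"
]
embedder.search("I forgot my password", docs, top_k: 2)
# => [{ text: "Resetting your password...", score: 0.78, index: 0 }, ...]

# Optional fine-tuning on (query, positive) pairs, then persist
embedder.fit(train_file: "pairs.jsonl", epochs: 1)
embedder.save("tmp/support-embedder")
```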
data/lib/fine/tokenizers/auto_tokenizer.rb
@@ -0,0 +1,208 @@

```ruby
# frozen_string_literal: true

require "tokenizers"

module Fine
  module Tokenizers
    # Wrapper around HuggingFace tokenizers
    class AutoTokenizer
      attr_reader :tokenizer, :model_id, :max_length

      # Load tokenizer from a pretrained model
      #
      # @param model_id [String] HuggingFace model ID
      # @param max_length [Integer] Maximum sequence length
      # @return [AutoTokenizer]
      def self.from_pretrained(model_id, max_length: 512)
        new(model_id, max_length: max_length)
      end

      def initialize(model_id, max_length: 512)
        @model_id = model_id
        @max_length = max_length
        @tokenizer = load_tokenizer(model_id)

        configure_tokenizer
      end

      # Tokenize a single text or batch of texts
      #
      # @param texts [String, Array<String>] Text(s) to tokenize
      # @param padding [Boolean] Whether to pad sequences
      # @param truncation [Boolean] Whether to truncate sequences
      # @param return_tensors [Boolean] Whether to return Torch tensors
      # @return [Hash] Hash with :input_ids, :attention_mask, and optionally :token_type_ids
      def encode(texts, padding: true, truncation: true, return_tensors: true)
        texts = [texts] if texts.is_a?(String)
        single_input = texts.size == 1

        # Encode all texts
        encodings = texts.map do |text|
          @tokenizer.encode(text)
        end

        # Get max length in batch for padding
        max_len = if padding
          [encodings.map { |e| e.ids.length }.max, @max_length].min
        else
          @max_length
        end

        # Build output arrays
        input_ids = []
        attention_mask = []
        token_type_ids = []

        encodings.each do |encoding|
          ids = encoding.ids
          mask = encoding.attention_mask
          type_ids = encoding.type_ids rescue Array.new(ids.length, 0)

          # Truncate if needed
          if truncation && ids.length > max_len
            ids = ids[0...max_len]
            mask = mask[0...max_len]
            type_ids = type_ids[0...max_len]
          end

          # Pad if needed
          if padding && ids.length < max_len
            pad_length = max_len - ids.length
            ids = ids + Array.new(pad_length, pad_token_id)
            mask = mask + Array.new(pad_length, 0)
            type_ids = type_ids + Array.new(pad_length, 0)
          end

          input_ids << ids
          attention_mask << mask
          token_type_ids << type_ids
        end

        result = {
          input_ids: input_ids,
          attention_mask: attention_mask
        }

        # Only include token_type_ids for BERT-style models
        result[:token_type_ids] = token_type_ids if has_token_type_ids?

        if return_tensors
          result.transform_values! { |v| Torch.tensor(v, dtype: :long) }
        end

        result
      end

      # Encode a pair of texts (for sentence pair tasks)
      #
      # @param text_a [String] First text
      # @param text_b [String] Second text
      # @return [Hash] Tokenized output
      def encode_pair(text_a, text_b, **kwargs)
        encoding = @tokenizer.encode(text_a, text_b)

        ids = encoding.ids
        mask = encoding.attention_mask
        type_ids = encoding.type_ids rescue Array.new(ids.length, 0)

        # Truncate if needed
        if kwargs.fetch(:truncation, true) && ids.length > @max_length
          ids = ids[0...@max_length]
          mask = mask[0...@max_length]
          type_ids = type_ids[0...@max_length]
        end

        result = {
          input_ids: [ids],
          attention_mask: [mask]
        }
        result[:token_type_ids] = [type_ids] if has_token_type_ids?

        if kwargs.fetch(:return_tensors, true)
          result.transform_values! { |v| Torch.tensor(v, dtype: :long) }
        end

        result
      end

      # Decode token IDs back to text
      #
      # @param token_ids [Array<Integer>] Token IDs
      # @param skip_special_tokens [Boolean] Whether to skip special tokens
      # @return [String] Decoded text
      def decode(token_ids, skip_special_tokens: true)
        token_ids = token_ids.to_a if token_ids.respond_to?(:to_a)
        @tokenizer.decode(token_ids, skip_special_tokens: skip_special_tokens)
      end

      # Get vocabulary size
      def vocab_size
        @tokenizer.vocab_size
      end

      # Get pad token ID
      def pad_token_id
        @tokenizer.token_to_id(@tokenizer.padding&.dig("pad_token") || "[PAD]") || 0
      end

      # Get CLS token ID
      def cls_token_id
        @tokenizer.token_to_id("[CLS]") || @tokenizer.token_to_id("<s>") || 0
      end

      # Get SEP token ID
      def sep_token_id
        @tokenizer.token_to_id("[SEP]") || @tokenizer.token_to_id("</s>") || 0
      end

      # Get EOS token ID
      def eos_token_id
        @tokenizer.token_to_id("</s>") || @tokenizer.token_to_id("[SEP]") || @tokenizer.token_to_id("<|endoftext|>") || 0
      end

      # Get BOS token ID
      def bos_token_id
        @tokenizer.token_to_id("<s>") || @tokenizer.token_to_id("[CLS]") || @tokenizer.token_to_id("<|startoftext|>") || 0
      end

      # Save tokenizer to directory
      def save(path)
        FileUtils.mkdir_p(path)
        @tokenizer.save(File.join(path, "tokenizer.json"))
      end

      private

      def load_tokenizer(model_id)
        # Check if it's a local path with tokenizer.json
        local_tokenizer_path = File.join(model_id, "tokenizer.json")
        if File.exist?(local_tokenizer_path)
          return ::Tokenizers::Tokenizer.from_file(local_tokenizer_path)
        end

        # Check if model_id itself is a tokenizer.json path
        if File.exist?(model_id) && model_id.end_with?("tokenizer.json")
          return ::Tokenizers::Tokenizer.from_file(model_id)
        end

        # Try to load from HuggingFace Hub
        ::Tokenizers::Tokenizer.from_pretrained(model_id)
      rescue StandardError => e
        raise ConfigurationError, "Failed to load tokenizer for #{model_id}: #{e.message}"
      end

      def configure_tokenizer
        # Enable truncation
        @tokenizer.enable_truncation(@max_length)

        # Enable padding
        @tokenizer.enable_padding(length: @max_length)
      end

      def has_token_type_ids?
        # BERT-style models use token_type_ids, RoBERTa/DistilBERT don't always need them
        @model_id.include?("bert") && !@model_id.include?("roberta")
      end
    end
  end
end
```
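Finally, a sketch of the tokenizer wrapper used on its own (not part of the diff), assuming the model ID resolves on the HuggingFace Hub; the outputs noted in comments are illustrative.

```ruby
require "fine"

tok = Fine::Tokenizers::AutoTokenizer.from_pretrained("bert-base-uncased", max_length: 16)

# Batch encoding: every sequence comes back the same length, padded with
# pad_token_id and capped at max_length; ask for plain Ruby arrays here.
batch = tok.encode(["Hello world", "A somewhat longer example sentence"], return_tensors: false)
batch.keys # => [:input_ids, :attention_mask, :token_type_ids] (model ID matches the BERT check)
batch[:input_ids][0].length == batch[:input_ids][1].length # => true

# With return_tensors: true (the default) the same hash holds Torch long tensors,
# ready to pass to BertForSequenceClassification.

# Round-trip a single sequence; padding/special tokens are skipped by default
ids = batch[:input_ids].first
tok.decode(ids) # ~> "hello world"

tok.pad_token_id # falls back to the id of "[PAD]", or 0
tok.vocab_size
```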