fine 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +7 -0
  2. data/.rspec +3 -0
  3. data/CHANGELOG.md +38 -0
  4. data/Gemfile +6 -0
  5. data/Gemfile.lock +167 -0
  6. data/LICENSE +21 -0
  7. data/README.md +212 -0
  8. data/Rakefile +6 -0
  9. data/docs/installation.md +151 -0
  10. data/docs/tutorials/llm-fine-tuning.md +246 -0
  11. data/docs/tutorials/model-export.md +200 -0
  12. data/docs/tutorials/siglip2-image-classification.md +130 -0
  13. data/docs/tutorials/siglip2-object-recognition.md +203 -0
  14. data/docs/tutorials/siglip2-similarity-search.md +152 -0
  15. data/docs/tutorials/text-classification.md +233 -0
  16. data/docs/tutorials/text-embeddings.md +211 -0
  17. data/examples/basic_classification.rb +70 -0
  18. data/examples/data/tool_calls.jsonl +30 -0
  19. data/examples/demo_training.rb +78 -0
  20. data/examples/finetune_gemma3_tools.rb +135 -0
  21. data/examples/real_llm_test.rb +128 -0
  22. data/examples/real_text_classification_test.rb +90 -0
  23. data/examples/real_text_embedder_test.rb +110 -0
  24. data/examples/real_training_test.rb +88 -0
  25. data/examples/test_export.rb +28 -0
  26. data/examples/test_image_classifier.rb +79 -0
  27. data/examples/test_llm.rb +100 -0
  28. data/examples/test_text_classifier.rb +59 -0
  29. data/lib/fine/callbacks/base.rb +140 -0
  30. data/lib/fine/callbacks/progress_bar.rb +66 -0
  31. data/lib/fine/configuration.rb +106 -0
  32. data/lib/fine/datasets/data_loader.rb +63 -0
  33. data/lib/fine/datasets/image_dataset.rb +203 -0
  34. data/lib/fine/datasets/instruction_dataset.rb +226 -0
  35. data/lib/fine/datasets/text_data_loader.rb +88 -0
  36. data/lib/fine/datasets/text_dataset.rb +266 -0
  37. data/lib/fine/error.rb +49 -0
  38. data/lib/fine/export/gguf_exporter.rb +424 -0
  39. data/lib/fine/export/onnx_exporter.rb +249 -0
  40. data/lib/fine/export.rb +53 -0
  41. data/lib/fine/hub/config_loader.rb +145 -0
  42. data/lib/fine/hub/model_downloader.rb +136 -0
  43. data/lib/fine/hub/safetensors_loader.rb +108 -0
  44. data/lib/fine/image_classifier.rb +256 -0
  45. data/lib/fine/llm.rb +336 -0
  46. data/lib/fine/models/base.rb +48 -0
  47. data/lib/fine/models/bert_encoder.rb +202 -0
  48. data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
  49. data/lib/fine/models/causal_lm.rb +279 -0
  50. data/lib/fine/models/classification_head.rb +24 -0
  51. data/lib/fine/models/gemma3_decoder.rb +244 -0
  52. data/lib/fine/models/llama_decoder.rb +297 -0
  53. data/lib/fine/models/sentence_transformer.rb +202 -0
  54. data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
  55. data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
  56. data/lib/fine/text_classifier.rb +250 -0
  57. data/lib/fine/text_embedder.rb +221 -0
  58. data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
  59. data/lib/fine/training/llm_trainer.rb +212 -0
  60. data/lib/fine/training/text_trainer.rb +275 -0
  61. data/lib/fine/training/trainer.rb +194 -0
  62. data/lib/fine/transforms/compose.rb +28 -0
  63. data/lib/fine/transforms/normalize.rb +33 -0
  64. data/lib/fine/transforms/resize.rb +35 -0
  65. data/lib/fine/transforms/to_tensor.rb +53 -0
  66. data/lib/fine/version.rb +3 -0
  67. data/lib/fine.rb +112 -0
  68. data/mise.toml +2 -0
  69. metadata +240 -0
data/lib/fine/text_classifier.rb
@@ -0,0 +1,250 @@
+ # frozen_string_literal: true
+
+ module Fine
+   # High-level API for text classification
+   #
+   # @example Basic usage
+   #   classifier = Fine::TextClassifier.new("distilbert-base-uncased")
+   #   classifier.fit(train_file: "reviews.jsonl", epochs: 3)
+   #   classifier.predict("This product is amazing!")
+   #
+   # @example With configuration
+   #   classifier = Fine::TextClassifier.new("microsoft/deberta-v3-small") do |config|
+   #     config.epochs = 5
+   #     config.batch_size = 16
+   #     config.learning_rate = 2e-5
+   #   end
+   #
+   class TextClassifier
+     attr_reader :model, :config, :tokenizer, :label_map, :model_id
+
+     # Create a new TextClassifier
+     #
+     # @param model_id [String] HuggingFace model ID
+     # @yield [config] Optional configuration block
+     def initialize(model_id, &block)
+       @model_id = model_id
+       @config = TextConfiguration.new
+       @model = nil
+       @tokenizer = nil
+       @label_map = nil
+       @trained = false
+
+       block&.call(@config)
+
+       if @config.callbacks.empty? && Fine.configuration&.progress_bar != false
+         @config.callbacks << Callbacks::ProgressBar.new
+       end
+     end
+
+     # Load a fine-tuned classifier from disk
+     #
+     # @param path [String] Path to saved model directory
+     # @return [TextClassifier]
+     def self.load(path)
+       config_path = File.join(path, "config.json")
+       raise ModelNotFoundError.new(path) unless File.exist?(config_path)
+
+       config_data = JSON.parse(File.read(config_path))
+
+       classifier = allocate
+       classifier.instance_variable_set(:@model_id, config_data["_model_id"] || "custom")
+       classifier.instance_variable_set(:@config, TextConfiguration.new)
+       classifier.instance_variable_set(:@trained, true)
+
+       # Load label map
+       if config_data["label2id"]
+         classifier.instance_variable_set(:@label_map, config_data["label2id"])
+       elsif config_data["id2label"]
+         label_map = config_data["id2label"].transform_keys(&:to_i).invert
+         classifier.instance_variable_set(:@label_map, label_map)
+       end
+
+       # Load tokenizer
+       tokenizer_path = File.join(path, "tokenizer.json")
+       tokenizer = if File.exist?(tokenizer_path)
+         Tokenizers::AutoTokenizer.new(path, max_length: config_data["max_length"] || 512)
+       else
+         Tokenizers::AutoTokenizer.from_pretrained(
+           config_data["_model_id"] || "distilbert-base-uncased",
+           max_length: config_data["max_length"] || 512
+         )
+       end
+       classifier.instance_variable_set(:@tokenizer, tokenizer)
+
+       # Load model
+       classifier.instance_variable_set(
+         :@model,
+         Models::BertForSequenceClassification.load(path)
+       )
+
+       classifier
+     end
+
+     # Fine-tune on a dataset
+     #
+     # @param train_file [String] Path to training data (JSONL or CSV)
+     # @param val_file [String, nil] Path to validation data
+     # @param epochs [Integer, nil] Override config epochs
+     # @return [Array<Hash>] Training history
+     def fit(train_file:, val_file: nil, epochs: nil)
+       @config.epochs = epochs if epochs
+
+       # Load tokenizer
+       @tokenizer = Tokenizers::AutoTokenizer.from_pretrained(
+         @model_id,
+         max_length: @config.max_length
+       )
+
+       # Load datasets
+       train_dataset = Datasets::TextDataset.from_file(train_file, tokenizer: @tokenizer)
+
+       val_dataset = if val_file
+         Datasets::TextDataset.from_file(val_file, tokenizer: @tokenizer, label_map: train_dataset.label_map)
+       end
+
+       @label_map = train_dataset.label_map
+       num_classes = train_dataset.num_classes
+
+       # Load model
+       @model = Models::BertForSequenceClassification.from_pretrained(
+         @model_id,
+         num_labels: num_classes,
+         dropout: @config.dropout
+       )
+
+       # Train
+       trainer = Training::TextTrainer.new(
+         @model,
+         @config,
+         train_dataset: train_dataset,
+         val_dataset: val_dataset
+       )
+
+       history = trainer.fit
+       @trained = true
+
+       history
+     end
+
+     # Fine-tune with automatic train/val split
+     #
+     # @param data_file [String] Path to data file
+     # @param val_split [Float] Fraction for validation
+     # @return [Array<Hash>] Training history
+     def fit_with_split(data_file:, val_split: 0.2, epochs: nil)
+       @config.epochs = epochs if epochs
+
+       @tokenizer = Tokenizers::AutoTokenizer.from_pretrained(
+         @model_id,
+         max_length: @config.max_length
+       )
+
+       full_dataset = Datasets::TextDataset.from_file(data_file, tokenizer: @tokenizer)
+       train_dataset, val_dataset = full_dataset.split(test_size: val_split)
+
+       @label_map = train_dataset.label_map
+
+       @model = Models::BertForSequenceClassification.from_pretrained(
+         @model_id,
+         num_labels: train_dataset.num_classes,
+         dropout: @config.dropout
+       )
+
+       trainer = Training::TextTrainer.new(
+         @model,
+         @config,
+         train_dataset: train_dataset,
+         val_dataset: val_dataset
+       )
+
+       history = trainer.fit
+       @trained = true
+
+       history
+     end
+
+     # Make predictions
+     #
+     # @param texts [String, Array<String>] Text(s) to classify
+     # @param top_k [Integer] Number of top predictions to return
+     # @return [Array<Array<Hash>>] Predictions with :label and :score
+     def predict(texts, top_k: 5)
+       raise TrainingError, "Model not trained or loaded" unless @trained && @model
+
+       texts = [texts] if texts.is_a?(String)
+
+       # Tokenize
+       encoding = @tokenizer.encode(texts)
+
+       # Get predictions
+       @model.eval
+       probs = @model.predict_proba(
+         encoding[:input_ids],
+         attention_mask: encoding[:attention_mask],
+         token_type_ids: encoding[:token_type_ids]
+       )
+
+       # Convert to result format
+       inverse_label_map = @label_map.invert
+
+       probs.to_a.map do |sample_probs|
+         sorted = sample_probs.each_with_index.sort_by { |prob, _| -prob }
+         top = sorted.first([top_k, @label_map.size].min)
+
+         top.map do |prob, idx|
+           {
+             label: inverse_label_map[idx] || idx.to_s,
+             score: prob.round(4)
+           }
+         end
+       end
+     end
+
+     # Save the model
+     #
+     # @param path [String] Directory to save to
+     def save(path)
+       raise TrainingError, "Model not trained" unless @trained && @model
+
+       @model.save(path, label_map: @label_map)
+       @tokenizer.save(path)
+
+       # Update config with model ID and max_length
+       config_path = File.join(path, "config.json")
+       config = JSON.parse(File.read(config_path))
+       config["_model_id"] = @model_id
+       config["max_length"] = @config.max_length
+       File.write(config_path, JSON.pretty_generate(config))
+     end
+
+     # Get class names
+     def class_names
+       return [] unless @label_map
+
+       @label_map.sort_by { |_, v| v }.map(&:first)
+     end
+
+     # Export to ONNX format
+     #
+     # @param path [String] Output path for ONNX file
+     # @param options [Hash] Export options
+     # @return [String] The output path
+     def export_onnx(path, **options)
+       Export.to_onnx(self, path, **options)
+     end
+   end
+
+   # Configuration for text models
+   class TextConfiguration < Configuration
+     attr_accessor :max_length, :warmup_ratio
+
+     def initialize
+       super
+       @max_length = 256
+       @warmup_ratio = 0.1
+       @learning_rate = 2e-5 # Lower default for text models
+       @batch_size = 16 # Smaller default for text
+     end
+   end
+ end
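
TextClassifier wires together the tokenizer, dataset, model, and trainer classes listed in this release. The sketch below chains its public methods based only on the signatures and doc comments in the hunk above; the file path, label names, and scores are illustrative assumptions, and the JSONL column layout expected by Datasets::TextDataset is not shown in this file.

```ruby
# Hypothetical end-to-end use of Fine::TextClassifier (paths, labels, and
# scores are illustrative, not taken from the gem).
require "fine"

classifier = Fine::TextClassifier.new("distilbert-base-uncased") do |config|
  config.epochs = 3 # TextConfiguration defaults: batch_size 16, learning_rate 2e-5
end

classifier.fit(train_file: "reviews.jsonl") # JSONL schema assumed
classifier.predict("Arrived quickly, works great", top_k: 2)
# => [[{ label: "positive", score: 0.98 }, { label: "negative", score: 0.02 }]]

classifier.save("./sentiment-model")
reloaded = Fine::TextClassifier.load("./sentiment-model")
reloaded.class_names
# => ["negative", "positive"] (order follows the stored label map)
```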
data/lib/fine/text_embedder.rb
@@ -0,0 +1,221 @@
+ # frozen_string_literal: true
+
+ module Fine
+   # High-level API for text embeddings
+   #
+   # @example Basic usage
+   #   embedder = Fine::TextEmbedder.new("sentence-transformers/all-MiniLM-L6-v2")
+   #   embedder.fit(train_file: "pairs.jsonl", epochs: 3)
+   #   embedding = embedder.encode("Hello world")
+   #
+   # @example Without fine-tuning (use pretrained directly)
+   #   embedder = Fine::TextEmbedder.new("sentence-transformers/all-MiniLM-L6-v2")
+   #   embedding = embedder.encode("Hello world")
+   #
+   class TextEmbedder
+     attr_reader :model, :config, :tokenizer, :model_id
+
+     # Create a new TextEmbedder
+     #
+     # @param model_id [String] HuggingFace model ID
+     # @yield [config] Optional configuration block
+     def initialize(model_id, &block)
+       @model_id = model_id
+       @config = EmbeddingConfiguration.new
+       @model = nil
+       @tokenizer = nil
+       @trained = false
+
+       block&.call(@config)
+
+       # Load tokenizer immediately for encoding
+       @tokenizer = Tokenizers::AutoTokenizer.from_pretrained(
+         model_id,
+         max_length: @config.max_length
+       )
+
+       # Load pretrained model for immediate use
+       @model = Models::SentenceTransformer.from_pretrained(
+         model_id,
+         pooling_mode: @config.pooling_mode
+       )
+       @trained = true # Pretrained is ready to use
+     end
+
+     # Load a fine-tuned embedder from disk
+     #
+     # @param path [String] Path to saved model
+     # @return [TextEmbedder]
+     def self.load(path)
+       config_path = File.join(path, "config.json")
+       raise ModelNotFoundError.new(path) unless File.exist?(config_path)
+
+       config_data = JSON.parse(File.read(config_path))
+
+       embedder = allocate
+       embedder.instance_variable_set(:@model_id, config_data["_model_id"] || "custom")
+       embedder.instance_variable_set(:@config, EmbeddingConfiguration.new)
+       embedder.instance_variable_set(:@trained, true)
+
+       # Load tokenizer
+       tokenizer_path = File.join(path, "tokenizer.json")
+       tokenizer = if File.exist?(tokenizer_path)
+         Tokenizers::AutoTokenizer.new(path, max_length: config_data["max_length"] || 512)
+       else
+         Tokenizers::AutoTokenizer.from_pretrained(
+           config_data["_model_id"],
+           max_length: config_data["max_length"] || 512
+         )
+       end
+       embedder.instance_variable_set(:@tokenizer, tokenizer)
+
+       # Load model
+       embedder.instance_variable_set(
+         :@model,
+         Models::SentenceTransformer.load(path)
+       )
+
+       embedder
+     end
+
+     # Fine-tune on pairs/triplets data
+     #
+     # @param train_file [String] Path to training data (JSONL)
+     # @param epochs [Integer, nil] Override config epochs
+     # @return [Array<Hash>] Training history
+     def fit(train_file:, epochs: nil)
+       @config.epochs = epochs if epochs
+
+       # Load dataset
+       train_dataset = Datasets::TextPairDataset.from_jsonl(
+         train_file,
+         tokenizer: @tokenizer,
+         text_a_column: "query",
+         text_b_column: "positive"
+       )
+
+       # Add progress bar callback
+       if @config.callbacks.empty? && Fine.configuration&.progress_bar != false
+         @config.callbacks << Callbacks::ProgressBar.new
+       end
+
+       # Train
+       trainer = Training::EmbeddingTrainer.new(
+         @model,
+         @config,
+         train_dataset: train_dataset
+       )
+
+       trainer.fit
+     end
+
+     # Encode text(s) to embeddings
+     #
+     # @param texts [String, Array<String>] Text(s) to encode
+     # @return [Array<Float>, Array<Array<Float>>] Embedding(s)
+     def encode(texts)
+       raise TrainingError, "Model not loaded" unless @model
+
+       single_input = texts.is_a?(String)
+       texts = [texts] if single_input
+
+       # Tokenize
+       encoding = @tokenizer.encode(texts)
+
+       # Get embeddings
+       embeddings = @model.encode(
+         encoding[:input_ids],
+         attention_mask: encoding[:attention_mask]
+       )
+
+       # Convert to Ruby arrays
+       result = embeddings.to_a
+
+       single_input ? result.first : result
+     end
+
+     # Compute similarity between two texts
+     #
+     # @param text_a [String] First text
+     # @param text_b [String] Second text
+     # @return [Float] Cosine similarity score
+     def similarity(text_a, text_b)
+       emb_a = encode(text_a)
+       emb_b = encode(text_b)
+
+       cosine_similarity(emb_a, emb_b)
+     end
+
+     # Find most similar texts from a corpus
+     #
+     # @param query [String] Query text
+     # @param corpus [Array<String>] Corpus to search
+     # @param top_k [Integer] Number of results
+     # @return [Array<Hash>] Results with :text, :score, :index
+     def search(query, corpus, top_k: 5)
+       query_emb = encode(query)
+       corpus_embs = encode(corpus)
+
+       scores = corpus_embs.map.with_index do |emb, idx|
+         { text: corpus[idx], score: cosine_similarity(query_emb, emb), index: idx }
+       end
+
+       scores.sort_by { |s| -s[:score] }.first(top_k)
+     end
+
+     # Save the model
+     #
+     # @param path [String] Directory to save to
+     def save(path)
+       raise TrainingError, "Model not loaded" unless @model
+
+       @model.save(path)
+       @tokenizer.save(path)
+
+       # Update config
+       config_path = File.join(path, "config.json")
+       config = JSON.parse(File.read(config_path))
+       config["_model_id"] = @model_id
+       config["max_length"] = @config.max_length
+       File.write(config_path, JSON.pretty_generate(config))
+     end
+
+     # Get embedding dimension
+     def embedding_dim
+       @model.config.hidden_size
+     end
+
+     # Export to ONNX format
+     #
+     # @param path [String] Output path for ONNX file
+     # @param options [Hash] Export options
+     # @return [String] The output path
+     def export_onnx(path, **options)
+       Export.to_onnx(self, path, **options)
+     end
+
+     private
+
+     def cosine_similarity(a, b)
+       dot = a.zip(b).sum { |x, y| x * y }
+       norm_a = Math.sqrt(a.sum { |x| x * x })
+       norm_b = Math.sqrt(b.sum { |x| x * x })
+       dot / (norm_a * norm_b)
+     end
+   end
+
+   # Configuration for embedding models
+   class EmbeddingConfiguration < Configuration
+     attr_accessor :max_length, :pooling_mode, :loss
+
+     def initialize
+       super
+       @max_length = 256
+       @pooling_mode = :mean
+       @loss = :multiple_negatives_ranking
+       @learning_rate = 2e-5
+       @batch_size = 32
+       @epochs = 1
+     end
+   end
+ end
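
TextEmbedder loads its tokenizer and pretrained SentenceTransformer in the constructor, so encode, similarity, and search work without a fit call. A minimal sketch follows, assuming the MiniLM model from the class's own @example (which produces 384-dimensional embeddings); the query and corpus strings are illustrative.

```ruby
# Minimal sketch of the pretrained path through Fine::TextEmbedder.
require "fine"

embedder = Fine::TextEmbedder.new("sentence-transformers/all-MiniLM-L6-v2")

vec = embedder.encode("How do I reset my password?")
vec.length # => 384, i.e. embedder.embedding_dim for this model

embedder.similarity("reset my password", "change my login credentials")
# => Float in [-1.0, 1.0]; higher for related sentences

embedder.search(
  "billing question",
  ["How to update a payment method", "Office opening hours", "Refund policy"],
  top_k: 2
)
# => [{ text: "...", score: ..., index: ... }, { ... }] sorted by score
```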
data/lib/fine/tokenizers/auto_tokenizer.rb
@@ -0,0 +1,208 @@
+ # frozen_string_literal: true
+
+ require "tokenizers"
+
+ module Fine
+   module Tokenizers
+     # Wrapper around HuggingFace tokenizers
+     class AutoTokenizer
+       attr_reader :tokenizer, :model_id, :max_length
+
+       # Load tokenizer from a pretrained model
+       #
+       # @param model_id [String] HuggingFace model ID
+       # @param max_length [Integer] Maximum sequence length
+       # @return [AutoTokenizer]
+       def self.from_pretrained(model_id, max_length: 512)
+         new(model_id, max_length: max_length)
+       end
+
+       def initialize(model_id, max_length: 512)
+         @model_id = model_id
+         @max_length = max_length
+         @tokenizer = load_tokenizer(model_id)
+
+         configure_tokenizer
+       end
+
+       # Tokenize a single text or batch of texts
+       #
+       # @param texts [String, Array<String>] Text(s) to tokenize
+       # @param padding [Boolean] Whether to pad sequences
+       # @param truncation [Boolean] Whether to truncate sequences
+       # @param return_tensors [Boolean] Whether to return Torch tensors
+       # @return [Hash] Hash with :input_ids, :attention_mask, and optionally :token_type_ids
+       def encode(texts, padding: true, truncation: true, return_tensors: true)
+         texts = [texts] if texts.is_a?(String)
+         single_input = texts.size == 1
+
+         # Encode all texts
+         encodings = texts.map do |text|
+           @tokenizer.encode(text)
+         end
+
+         # Get max length in batch for padding
+         max_len = if padding
+           [encodings.map { |e| e.ids.length }.max, @max_length].min
+         else
+           @max_length
+         end
+
+         # Build output arrays
+         input_ids = []
+         attention_mask = []
+         token_type_ids = []
+
+         encodings.each do |encoding|
+           ids = encoding.ids
+           mask = encoding.attention_mask
+           type_ids = encoding.type_ids rescue Array.new(ids.length, 0)
+
+           # Truncate if needed
+           if truncation && ids.length > max_len
+             ids = ids[0...max_len]
+             mask = mask[0...max_len]
+             type_ids = type_ids[0...max_len]
+           end
+
+           # Pad if needed
+           if padding && ids.length < max_len
+             pad_length = max_len - ids.length
+             ids = ids + Array.new(pad_length, pad_token_id)
+             mask = mask + Array.new(pad_length, 0)
+             type_ids = type_ids + Array.new(pad_length, 0)
+           end
+
+           input_ids << ids
+           attention_mask << mask
+           token_type_ids << type_ids
+         end
+
+         result = {
+           input_ids: input_ids,
+           attention_mask: attention_mask
+         }
+
+         # Only include token_type_ids for BERT-style models
+         result[:token_type_ids] = token_type_ids if has_token_type_ids?
+
+         if return_tensors
+           result.transform_values! { |v| Torch.tensor(v, dtype: :long) }
+         end
+
+         result
+       end
+
+       # Encode a pair of texts (for sentence pair tasks)
+       #
+       # @param text_a [String] First text
+       # @param text_b [String] Second text
+       # @return [Hash] Tokenized output
+       def encode_pair(text_a, text_b, **kwargs)
+         encoding = @tokenizer.encode(text_a, text_b)
+
+         ids = encoding.ids
+         mask = encoding.attention_mask
+         type_ids = encoding.type_ids rescue Array.new(ids.length, 0)
+
+         # Truncate if needed
+         if kwargs.fetch(:truncation, true) && ids.length > @max_length
+           ids = ids[0...@max_length]
+           mask = mask[0...@max_length]
+           type_ids = type_ids[0...@max_length]
+         end
+
+         result = {
+           input_ids: [ids],
+           attention_mask: [mask]
+         }
+         result[:token_type_ids] = [type_ids] if has_token_type_ids?
+
+         if kwargs.fetch(:return_tensors, true)
+           result.transform_values! { |v| Torch.tensor(v, dtype: :long) }
+         end
+
+         result
+       end
+
+       # Decode token IDs back to text
+       #
+       # @param token_ids [Array<Integer>] Token IDs
+       # @param skip_special_tokens [Boolean] Whether to skip special tokens
+       # @return [String] Decoded text
+       def decode(token_ids, skip_special_tokens: true)
+         token_ids = token_ids.to_a if token_ids.respond_to?(:to_a)
+         @tokenizer.decode(token_ids, skip_special_tokens: skip_special_tokens)
+       end
+
+       # Get vocabulary size
+       def vocab_size
+         @tokenizer.vocab_size
+       end
+
+       # Get pad token ID
+       def pad_token_id
+         @tokenizer.token_to_id(@tokenizer.padding&.dig("pad_token") || "[PAD]") || 0
+       end
+
+       # Get CLS token ID
+       def cls_token_id
+         @tokenizer.token_to_id("[CLS]") || @tokenizer.token_to_id("<s>") || 0
+       end
+
+       # Get SEP token ID
+       def sep_token_id
+         @tokenizer.token_to_id("[SEP]") || @tokenizer.token_to_id("</s>") || 0
+       end
+
+       # Get EOS token ID
+       def eos_token_id
+         @tokenizer.token_to_id("</s>") || @tokenizer.token_to_id("[SEP]") || @tokenizer.token_to_id("<|endoftext|>") || 0
+       end
+
+       # Get BOS token ID
+       def bos_token_id
+         @tokenizer.token_to_id("<s>") || @tokenizer.token_to_id("[CLS]") || @tokenizer.token_to_id("<|startoftext|>") || 0
+       end
+
+       # Save tokenizer to directory
+       def save(path)
+         FileUtils.mkdir_p(path)
+         @tokenizer.save(File.join(path, "tokenizer.json"))
+       end
+
+       private
+
+       def load_tokenizer(model_id)
+         # Check if it's a local path with tokenizer.json
+         local_tokenizer_path = File.join(model_id, "tokenizer.json")
+         if File.exist?(local_tokenizer_path)
+           return ::Tokenizers::Tokenizer.from_file(local_tokenizer_path)
+         end
+
+         # Check if model_id itself is a tokenizer.json path
+         if File.exist?(model_id) && model_id.end_with?("tokenizer.json")
+           return ::Tokenizers::Tokenizer.from_file(model_id)
+         end
+
+         # Try to load from HuggingFace Hub
+         ::Tokenizers::Tokenizer.from_pretrained(model_id)
+       rescue StandardError => e
+         raise ConfigurationError, "Failed to load tokenizer for #{model_id}: #{e.message}"
+       end
+
+       def configure_tokenizer
+         # Enable truncation
+         @tokenizer.enable_truncation(@max_length)
+
+         # Enable padding
+         @tokenizer.enable_padding(length: @max_length)
+       end
+
+       def has_token_type_ids?
+         # BERT-style models use token_type_ids, RoBERTa/DistilBERT don't always need them
+         @model_id.include?("bert") && !@model_id.include?("roberta")
+       end
+     end
+   end
+ end
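
AutoTokenizer is the piece the classifier and embedder both lean on, so a short sketch of its batch behaviour may help: because configure_tokenizer enables padding up to max_length, every row comes back the same length, and return_tensors: false keeps the output as plain Ruby arrays instead of Torch tensors. The model ID, padded length, and decoded string below are illustrative assumptions for an uncased BERT-style vocabulary.

```ruby
# Minimal sketch of Fine::Tokenizers::AutoTokenizer, assuming the
# distilbert-base-uncased vocabulary is available locally or via the Hub.
require "fine"

tok = Fine::Tokenizers::AutoTokenizer.from_pretrained(
  "distilbert-base-uncased",
  max_length: 128
)

batch = tok.encode(
  ["Hello world", "A slightly longer example sentence"],
  return_tensors: false
)
batch[:input_ids].length        # => 2 (one row per input text)
batch[:input_ids].first.length  # => 128 (padding is enabled up to max_length)
batch.key?(:token_type_ids)     # => true, since the model ID contains "bert"

tok.decode(batch[:input_ids].first)
# => "hello world" (special and padding tokens skipped by default)

tok.pad_token_id                # => 0 for BERT-style vocabularies
```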