fine 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0
data/lib/fine/datasets/instruction_dataset.rb
ADDED

@@ -0,0 +1,226 @@

# frozen_string_literal: true

module Fine
  module Datasets
    # Dataset for instruction/chat fine-tuning
    #
    # Supports common formats:
    # - Alpaca: {"instruction": "...", "input": "...", "output": "..."}
    # - ShareGPT: {"conversations": [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}]}
    # - Simple: {"prompt": "...", "completion": "..."}
    class InstructionDataset
      attr_reader :examples, :tokenizer, :max_length

      # Load from JSONL file
      #
      # @param path [String] Path to JSONL file
      # @param tokenizer [Tokenizers::AutoTokenizer] Tokenizer
      # @param format [Symbol] Data format (:alpaca, :sharegpt, :simple, :auto)
      # @param max_length [Integer] Maximum sequence length
      # @return [InstructionDataset]
      def self.from_jsonl(path, tokenizer:, format: :auto, max_length: 2048)
        examples = File.readlines(path).map { |line| JSON.parse(line, symbolize_names: true) }
        new(examples, tokenizer: tokenizer, format: format, max_length: max_length)
      end

      def initialize(examples, tokenizer:, format: :auto, max_length: 2048)
        @tokenizer = tokenizer
        @max_length = max_length
        @format = format == :auto ? detect_format(examples.first) : format

        @examples = examples.map { |ex| normalize_example(ex) }
      end

      def size
        @examples.size
      end

      def [](idx)
        example = @examples[idx]

        # Format as prompt + completion
        text = format_example(example)

        # Tokenize (without tensors for easier manipulation)
        encoding = @tokenizer.encode(text, return_tensors: false)
        input_ids = encoding[:input_ids].first

        # Get prompt length before truncation
        prompt_text = format_prompt_only(example)
        prompt_encoding = @tokenizer.encode(prompt_text, return_tensors: false)
        prompt_length = prompt_encoding[:input_ids].first.size

        # Truncate if needed, but ensure at least some completion tokens remain
        if input_ids.size > @max_length
          # If prompt alone is too long, truncate from the left (keep completion)
          if prompt_length >= @max_length - 10
            # Keep last max_length tokens (includes completion)
            input_ids = input_ids.last(@max_length)
            # No masking since we dropped the prompt prefix
            prompt_length = 0
          else
            input_ids = input_ids.first(@max_length)
          end
        end

        # Labels are same as input_ids for causal LM (predict next token)
        labels = input_ids.dup

        # Mask prompt tokens with -100 (ignored in loss), but only if there's room for completion
        if prompt_length > 0 && prompt_length < input_ids.size
          labels[0...prompt_length] = [-100] * prompt_length
        end
        # If prompt_length >= input_ids.size, don't mask (train on full sequence)

        {
          input_ids: Torch.tensor([input_ids]),
          labels: Torch.tensor([labels]),
          attention_mask: Torch.ones(1, input_ids.size)
        }
      end

      # Split dataset
      def split(test_size: 0.1, seed: 42)
        rng = Random.new(seed)
        indices = (0...size).to_a.shuffle(random: rng)

        split_idx = (size * (1 - test_size)).to_i
        train_indices = indices[0...split_idx]
        test_indices = indices[split_idx..]

        train_examples = train_indices.map { |i| @examples[i] }
        test_examples = test_indices.map { |i| @examples[i] }

        [
          self.class.new(train_examples, tokenizer: @tokenizer, format: @format, max_length: @max_length),
          self.class.new(test_examples, tokenizer: @tokenizer, format: @format, max_length: @max_length)
        ]
      end

      private

      def detect_format(example)
        if example.key?(:instruction)
          :alpaca
        elsif example.key?(:conversations)
          :sharegpt
        elsif example.key?(:prompt) || example.key?(:text)
          :simple
        else
          raise DatasetError, "Cannot detect format. Keys: #{example.keys}"
        end
      end

      def normalize_example(example)
        case @format
        when :alpaca
          {
            prompt: build_alpaca_prompt(example[:instruction], example[:input]),
            completion: example[:output] || example[:response]
          }
        when :sharegpt
          conversations = example[:conversations]
          # Take first human/assistant pair
          human = conversations.find { |c| c[:from] == "human" }
          assistant = conversations.find { |c| c[:from] == "gpt" || c[:from] == "assistant" }
          {
            prompt: human[:value],
            completion: assistant[:value]
          }
        when :simple
          {
            prompt: example[:prompt] || example[:text],
            completion: example[:completion] || example[:response] || ""
          }
        else
          raise DatasetError, "Unknown format: #{@format}"
        end
      end

      def build_alpaca_prompt(instruction, input = nil)
        if input && !input.empty?
          "### Instruction:\n#{instruction}\n\n### Input:\n#{input}\n\n### Response:\n"
        else
          "### Instruction:\n#{instruction}\n\n### Response:\n"
        end
      end

      def format_example(example)
        "#{example[:prompt]}#{example[:completion]}"
      end

      def format_prompt_only(example)
        example[:prompt]
      end
    end

    # Data loader for instruction dataset with dynamic padding
    class InstructionDataLoader
      include Enumerable

      def initialize(dataset, batch_size:, shuffle: true, pad_token_id: 0)
        @dataset = dataset
        @batch_size = batch_size
        @shuffle = shuffle
        @pad_token_id = pad_token_id
      end

      def each
        indices = (0...@dataset.size).to_a
        indices.shuffle! if @shuffle

        indices.each_slice(@batch_size) do |batch_indices|
          batch = batch_indices.map { |i| @dataset[i] }
          yield collate_batch(batch)
        end
      end

      def size
        (@dataset.size.to_f / @batch_size).ceil
      end

      private

      def collate_batch(batch)
        max_len = batch.map { |b| b[:input_ids].size(-1) }.max

        input_ids = []
        labels = []
        attention_masks = []

        batch.each do |item|
          seq_len = item[:input_ids].size(-1)
          pad_len = max_len - seq_len

          if pad_len > 0
            # Pad on the right
            input_ids << Torch.cat([
              item[:input_ids],
              Torch.full([1, pad_len], @pad_token_id)
            ], dim: 1)

            labels << Torch.cat([
              item[:labels],
              Torch.full([1, pad_len], -100) # Ignore padding in loss
            ], dim: 1)

            attention_masks << Torch.cat([
              item[:attention_mask],
              Torch.zeros(1, pad_len)
            ], dim: 1)
          else
            input_ids << item[:input_ids]
            labels << item[:labels]
            attention_masks << item[:attention_mask]
          end
        end

        {
          input_ids: Torch.cat(input_ids, dim: 0),
          labels: Torch.cat(labels, dim: 0),
          attention_mask: Torch.cat(attention_masks, dim: 0)
        }
      end
    end
  end
end
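
A minimal usage sketch for the two classes above (not part of the gem): it loads an Alpaca-format JSONL, splits it, and iterates padded batches. The model id, file name, and the assumption that Fine::Tokenizers::AutoTokenizer exposes a from_pretrained constructor are illustrative, not confirmed by this diff.

# Hypothetical usage sketch; `from_pretrained` and file names are assumptions.
require "fine"

tokenizer = Fine::Tokenizers::AutoTokenizer.from_pretrained("gpt2")
dataset = Fine::Datasets::InstructionDataset.from_jsonl(
  "train.jsonl", tokenizer: tokenizer, format: :alpaca, max_length: 1024
)
train, test = dataset.split(test_size: 0.1, seed: 42)

loader = Fine::Datasets::InstructionDataLoader.new(train, batch_size: 4, pad_token_id: 0)
loader.each do |batch|
  # Each batch stacks per-example tensors; prompt tokens and padding carry the
  # -100 label, so only completion tokens contribute to the training loss.
  p batch[:input_ids].shape
end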
data/lib/fine/datasets/text_data_loader.rb
ADDED

@@ -0,0 +1,88 @@

# frozen_string_literal: true

module Fine
  module Datasets
    # DataLoader for text datasets with dynamic padding
    class TextDataLoader
      include Enumerable

      attr_reader :dataset, :batch_size, :shuffle, :drop_last

      def initialize(dataset, batch_size:, shuffle: false, drop_last: false)
        @dataset = dataset
        @batch_size = batch_size
        @shuffle = shuffle
        @drop_last = drop_last
      end

      # Iterate over batches
      def each_batch
        return enum_for(:each_batch) unless block_given?

        indices = (0...@dataset.size).to_a
        indices.shuffle! if @shuffle

        indices.each_slice(@batch_size) do |batch_indices|
          next if @drop_last && batch_indices.size < @batch_size

          yield collate(batch_indices)
        end
      end

      alias each each_batch

      # Number of batches
      def size
        n = @dataset.size / @batch_size
        n += 1 unless @drop_last || (@dataset.size % @batch_size).zero?
        n
      end

      alias num_batches size

      private

      def collate(indices)
        samples = indices.map { |i| @dataset[i] }

        # Find max length in this batch for dynamic padding
        max_len = samples.map { |s| s[:input_ids].length }.max

        # Pad and stack
        input_ids = []
        attention_mask = []
        token_type_ids = []
        labels = []

        samples.each do |sample|
          ids = sample[:input_ids]
          mask = sample[:attention_mask]
          type_ids = sample[:token_type_ids]

          # Pad to max_len
          pad_len = max_len - ids.length
          if pad_len > 0
            ids = ids + Array.new(pad_len, 0)
            mask = mask + Array.new(pad_len, 0)
            type_ids = type_ids + Array.new(pad_len, 0) if type_ids
          end

          input_ids << ids
          attention_mask << mask
          token_type_ids << type_ids if type_ids
          labels << sample[:label]
        end

        result = {
          input_ids: Torch.tensor(input_ids, dtype: :long),
          attention_mask: Torch.tensor(attention_mask, dtype: :long),
          labels: Torch.tensor(labels, dtype: :long)
        }

        result[:token_type_ids] = Torch.tensor(token_type_ids, dtype: :long) if token_type_ids.first

        result
      end
    end
  end
end
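
The loader only depends on the dataset responding to #size and #[] with the hash shape produced by TextDataset#[], so it can be exercised with a stub. An illustrative sketch (not from the gem):

# Stub dataset: three-token samples, no token_type_ids (the loader skips them).
stub = Class.new do
  def size
    5
  end

  def [](i)
    { input_ids: [101, i, 102], attention_mask: [1, 1, 1], label: i % 2 }
  end
end.new

loader = Fine::Datasets::TextDataLoader.new(stub, batch_size: 2, shuffle: false)
loader.each_batch do |batch|
  p batch[:input_ids].shape # [2, 3] for full batches, [1, 3] for the last one
end
puts loader.num_batches # => 3 (two full batches + one partial; 2 with drop_last: true)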
data/lib/fine/datasets/text_dataset.rb
ADDED

@@ -0,0 +1,266 @@

# frozen_string_literal: true

module Fine
  module Datasets
    # Dataset for text classification tasks
    #
    # Supports JSONL and CSV formats:
    # JSONL: {"text": "...", "label": "positive"}
    # CSV: text,label (with header)
    #
    class TextDataset
      include Enumerable

      attr_reader :texts, :labels, :label_map, :inverse_label_map, :tokenizer

      # Load dataset from a JSONL file
      #
      # @param path [String] Path to JSONL file
      # @param tokenizer [AutoTokenizer] Tokenizer to use
      # @param text_column [String] Name of text field
      # @param label_column [String] Name of label field
      # @return [TextDataset]
      def self.from_jsonl(path, tokenizer:, text_column: "text", label_column: "label")
        raise DatasetError, "File not found: #{path}" unless File.exist?(path)

        texts = []
        labels = []

        File.foreach(path) do |line|
          next if line.strip.empty?

          data = JSON.parse(line)
          texts << data[text_column]
          labels << data[label_column]
        end

        raise DatasetError, "No data found in #{path}" if texts.empty?

        new(texts: texts, labels: labels, tokenizer: tokenizer)
      end

      # Load dataset from a CSV file
      #
      # @param path [String] Path to CSV file
      # @param tokenizer [AutoTokenizer] Tokenizer to use
      # @param text_column [String] Name of text column
      # @param label_column [String] Name of label column
      # @return [TextDataset]
      def self.from_csv(path, tokenizer:, text_column: "text", label_column: "label")
        require "csv"
        raise DatasetError, "File not found: #{path}" unless File.exist?(path)

        texts = []
        labels = []

        CSV.foreach(path, headers: true) do |row|
          texts << row[text_column]
          labels << row[label_column]
        end

        raise DatasetError, "No data found in #{path}" if texts.empty?

        new(texts: texts, labels: labels, tokenizer: tokenizer)
      end

      # Load from file (auto-detect format)
      #
      # @param path [String] Path to data file
      # @param tokenizer [AutoTokenizer] Tokenizer to use
      # @return [TextDataset]
      def self.from_file(path, tokenizer:, **kwargs)
        case File.extname(path).downcase
        when ".jsonl", ".json"
          from_jsonl(path, tokenizer: tokenizer, **kwargs)
        when ".csv"
          from_csv(path, tokenizer: tokenizer, **kwargs)
        else
          # Try JSONL first, then CSV
          begin
            from_jsonl(path, tokenizer: tokenizer, **kwargs)
          rescue JSON::ParserError
            from_csv(path, tokenizer: tokenizer, **kwargs)
          end
        end
      end

      def initialize(texts:, labels:, tokenizer:, label_map: nil)
        raise ArgumentError, "texts and labels must have same length" if texts.size != labels.size

        @texts = texts
        @tokenizer = tokenizer

        # Build label map if not provided
        if label_map
          @label_map = label_map
        else
          unique_labels = labels.uniq.sort
          @label_map = unique_labels.each_with_index.to_h
        end

        # Convert string labels to integers
        @labels = labels.map do |label|
          label.is_a?(Integer) ? label : @label_map[label]
        end

        # Build inverse mapping
        @inverse_label_map = @label_map.invert
      end

      # Get a single item from the dataset
      #
      # @param index [Integer] Index of the item
      # @return [Hash] Hash with tokenized inputs and label
      def [](index)
        text = @texts[index]
        encoding = @tokenizer.encode(text, return_tensors: false)

        {
          input_ids: encoding[:input_ids].first,
          attention_mask: encoding[:attention_mask].first,
          token_type_ids: encoding[:token_type_ids]&.first,
          label: @labels[index]
        }.compact
      end

      # Number of items in the dataset
      def size
        @texts.size
      end
      alias length size

      # Iterate over all items
      def each
        return enum_for(:each) unless block_given?

        size.times { |i| yield self[i] }
      end

      # Number of classes
      def num_classes
        @label_map.size
      end

      # Get class names in order
      def class_names
        @inverse_label_map.sort.map(&:last)
      end

      # Split dataset into train and validation sets
      #
      # @param test_size [Float] Fraction of data for validation (0.0-1.0)
      # @param shuffle [Boolean] Whether to shuffle before splitting
      # @param stratify [Boolean] Whether to maintain class distribution
      # @param seed [Integer, nil] Random seed
      # @return [Array<TextDataset, TextDataset>] Train and validation datasets
      def split(test_size: 0.2, shuffle: true, stratify: true, seed: nil)
        rng = seed ? Random.new(seed) : Random.new

        indices = (0...size).to_a
        indices = indices.shuffle(random: rng) if shuffle && !stratify

        if stratify
          train_indices, val_indices = stratified_split(indices, test_size, rng)
        else
          split_idx = (size * (1 - test_size)).round
          train_indices = indices[0...split_idx]
          val_indices = indices[split_idx..]
        end

        train_set = subset(train_indices)
        val_set = subset(val_indices)

        [train_set, val_set]
      end

      private

      def subset(indices)
        TextDataset.new(
          texts: indices.map { |i| @texts[i] },
          labels: indices.map { |i| @labels[i] },
          tokenizer: @tokenizer,
          label_map: @label_map
        )
      end

      def stratified_split(indices, test_size, rng)
        train_indices = []
        val_indices = []

        # Group indices by label
        by_label = indices.group_by { |i| @labels[i] }

        by_label.each_value do |label_indices|
          shuffled = label_indices.shuffle(random: rng)
          split_idx = (shuffled.size * (1 - test_size)).round

          train_indices.concat(shuffled[0...split_idx])
          val_indices.concat(shuffled[split_idx..])
        end

        [train_indices.shuffle(random: rng), val_indices.shuffle(random: rng)]
      end
    end

    # Dataset for text pair tasks (similarity, NLI, etc.)
    class TextPairDataset
      include Enumerable

      attr_reader :texts_a, :texts_b, :labels, :tokenizer

      # Load from JSONL with query/positive pairs
      #
      # @param path [String] Path to JSONL file
      # @param tokenizer [AutoTokenizer] Tokenizer to use
      # @return [TextPairDataset]
      def self.from_jsonl(path, tokenizer:, text_a_column: "query", text_b_column: "positive")
        raise DatasetError, "File not found: #{path}" unless File.exist?(path)

        texts_a = []
        texts_b = []
        labels = []

        File.foreach(path) do |line|
          next if line.strip.empty?

          data = JSON.parse(line)
          texts_a << data[text_a_column]
          texts_b << data[text_b_column]
          labels << (data["label"] || 1.0) # Default to positive pair
        end

        new(texts_a: texts_a, texts_b: texts_b, labels: labels, tokenizer: tokenizer)
      end

      def initialize(texts_a:, texts_b:, labels:, tokenizer:)
        @texts_a = texts_a
        @texts_b = texts_b
        @labels = labels
        @tokenizer = tokenizer
      end

      def [](index)
        encoding = @tokenizer.encode_pair(@texts_a[index], @texts_b[index], return_tensors: false)

        {
          input_ids: encoding[:input_ids].first,
          attention_mask: encoding[:attention_mask].first,
          token_type_ids: encoding[:token_type_ids]&.first,
          label: @labels[index]
        }.compact
      end

      def size
        @texts_a.size
      end
      alias length size

      def each
        return enum_for(:each) unless block_given?

        size.times { |i| yield self[i] }
      end
    end
  end
end
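
An illustrative sketch of the split behavior (the tokenizer and file name are assumed, not taken from the gem). With stratify: true, each label group is shuffled and split separately, so train and validation keep roughly the original class proportions even for imbalanced data:

# Hypothetical usage; `tokenizer` must respond to #encode as used above.
dataset = Fine::Datasets::TextDataset.from_jsonl("reviews.jsonl", tokenizer: tokenizer)
puts dataset.num_classes # distinct labels found in the file
p dataset.class_names    # label strings ordered by their integer ids

# Each class is split 80/20 on its own, then the index lists are reshuffled
# so batches are not grouped by label.
train, val = dataset.split(test_size: 0.2, seed: 42)
p [train.size, val.size]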
data/lib/fine/error.rb
ADDED

@@ -0,0 +1,49 @@

# frozen_string_literal: true

module Fine
  # Base error class for all Fine errors
  class Error < StandardError; end

  # Raised when a model cannot be found on Hugging Face Hub
  class ModelNotFoundError < Error
    attr_reader :model_id

    def initialize(model_id, message = nil)
      @model_id = model_id
      super(message || "Model not found: #{model_id}")
    end
  end

  # Raised when configuration is invalid
  class ConfigurationError < Error; end

  # Raised when there's an issue with dataset loading or processing
  class DatasetError < Error; end

  # Raised when training fails
  class TrainingError < Error; end

  # Raised when model weights cannot be loaded
  class WeightLoadingError < Error
    attr_reader :missing_keys, :unexpected_keys

    def initialize(message, missing_keys: [], unexpected_keys: [])
      @missing_keys = missing_keys
      @unexpected_keys = unexpected_keys
      super(message)
    end
  end

  # Raised when image processing fails
  class ImageProcessingError < Error
    attr_reader :path

    def initialize(path, message = nil)
      @path = path
      super(message || "Failed to process image: #{path}")
    end
  end

  # Raised when model export fails
  class ExportError < Error; end
end
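
Since every exception above inherits from Fine::Error, callers can rescue specific failures first and fall back to the base class. A brief sketch (illustrative only; `tokenizer` is assumed as before):

begin
  Fine::Datasets::TextDataset.from_jsonl("missing.jsonl", tokenizer: tokenizer)
rescue Fine::DatasetError => e
  warn "dataset problem: #{e.message}"
rescue Fine::Error => e
  warn "fine failed: #{e.message}"
end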