fine 0.1.0
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0
data/lib/fine/training/llm_trainer.rb
@@ -0,0 +1,212 @@
# frozen_string_literal: true

module Fine
  module Training
    # Trainer for causal language model fine-tuning
    class LLMTrainer
      attr_reader :model, :config, :train_dataset, :val_dataset

      def initialize(model, config, train_dataset:, val_dataset: nil)
        @model = model
        @config = config
        @train_dataset = train_dataset
        @val_dataset = val_dataset
        @device = Fine.device
        @history = []
      end

      def fit
        @model.to(@device)
        @model.train

        optimizer = create_optimizer
        scheduler = create_scheduler(optimizer)

        @config.callbacks.each { |cb| cb.on_train_begin(model: @model, config: @config) }

        @config.epochs.times do |epoch|
          epoch_loss = train_epoch(optimizer, scheduler, epoch)

          metrics = { loss: epoch_loss }

          # Validation
          if @val_dataset
            val_loss, val_perplexity = evaluate
            metrics[:val_loss] = val_loss
            metrics[:val_perplexity] = val_perplexity
          end

          @history << { epoch: epoch + 1, **metrics }

          @config.callbacks.each { |cb| cb.on_epoch_end(self, epoch + 1, metrics) }
        end

        @config.callbacks.each { |cb| cb.on_train_end(history: @history) }

        @history
      end

      def evaluate
        @model.eval
        total_loss = 0.0
        num_batches = 0

        data_loader = create_data_loader(@val_dataset, shuffle: false)

        Torch.no_grad do
          data_loader.each do |batch|
            batch = move_to_device(batch)

            outputs = @model.forward(
              batch[:input_ids],
              attention_mask: batch[:attention_mask],
              labels: batch[:labels]
            )

            total_loss += outputs[:loss].to(:float32).item
            num_batches += 1
          end
        end

        @model.train

        avg_loss = total_loss / num_batches
        perplexity = Math.exp(avg_loss)

        [avg_loss, perplexity]
      end

      private

      def train_epoch(optimizer, scheduler, epoch)
        total_loss = 0.0
        num_batches = 0

        data_loader = create_data_loader(@train_dataset, shuffle: true)
        total_steps = data_loader.size

        @config.callbacks.each do |cb|
          cb.on_epoch_begin(epoch + 1, total_steps: total_steps) if cb.respond_to?(:on_epoch_begin)
        end

        data_loader.each_with_index do |batch, step|
          # Move batch to device
          input_ids = batch[:input_ids].to(@device)
          attention_mask = batch[:attention_mask].to(@device)
          labels = batch[:labels].to(@device)

          # Forward pass
          outputs = @model.forward(input_ids, attention_mask: attention_mask, labels: labels)

          # Get loss value before backward
          loss_value = outputs[:loss].detach.to(:float32).item

          # Backward pass - scale loss for gradient accumulation
          scaled_loss = outputs[:loss] / @config.gradient_accumulation_steps
          scaled_loss.backward

          # CRITICAL: Clear ALL references to free computation graph
          scaled_loss = nil
          outputs = nil
          input_ids = nil
          attention_mask = nil
          labels = nil
          batch = nil

          if (step + 1) % @config.gradient_accumulation_steps == 0
            # Gradient clipping
            if @config.max_grad_norm
              clip_grad_norm(@model.parameters, @config.max_grad_norm)
            end

            optimizer.step
            scheduler&.step
            optimizer.zero_grad

            # Force GC after each optimizer step to free computation graphs
            GC.start
          end

          total_loss += loss_value
          num_batches += 1

          @config.callbacks.each do |cb|
            cb.on_batch_end(self, step + 1, loss_value) if cb.respond_to?(:on_batch_end)
          end
        end

        total_loss / num_batches
      end

      def create_data_loader(dataset, shuffle:)
        Datasets::InstructionDataLoader.new(
          dataset,
          batch_size: @config.batch_size,
          shuffle: shuffle,
          pad_token_id: @config.pad_token_id || 0
        )
      end

      def create_optimizer
        # Separate weight decay for different parameter groups
        decay_params = []
        no_decay_params = []

        @model.named_parameters.each do |name, param|
          next unless param.requires_grad

          if name.include?("bias") || name.include?("layernorm") || name.include?("norm")
            no_decay_params << param
          else
            decay_params << param
          end
        end

        param_groups = [
          { params: decay_params, weight_decay: @config.weight_decay },
          { params: no_decay_params, weight_decay: 0.0 }
        ]

        Torch::Optim::AdamW.new(param_groups, lr: @config.learning_rate)
      end

      def create_scheduler(optimizer)
        return nil unless @config.warmup_steps && @config.warmup_steps > 0

        # Linear warmup then constant
        # Note: torch.rb scheduler support may be limited
        nil
      end

      def move_to_device(batch)
        batch.transform_values { |v| v.to(@device) }
      end

      # Manual gradient clipping implementation
      def clip_grad_norm(parameters, max_norm)
        total_norm = 0.0

        parameters.each do |param|
          next unless param.grad

          # Convert to float32 for .item (bfloat16 not supported)
          param_norm = param.grad.data.norm(2).to(:float32).item
          total_norm += param_norm ** 2
        end

        total_norm = Math.sqrt(total_norm)
        clip_coef = max_norm / (total_norm + 1e-6)

        if clip_coef < 1.0
          parameters.each do |param|
            next unless param.grad

            param.grad.data.mul!(clip_coef)
          end
        end

        total_norm
      end
    end
  end
end
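The clip_grad_norm helper above computes a single global L2 norm over all gradients and rescales them only when that norm exceeds max_grad_norm. As a plain-Ruby sketch of the same arithmetic, with made-up gradient norms instead of torch.rb tensors:

# Toy illustration of the global-norm clipping math used in LLMTrainer#clip_grad_norm.
grad_norms = [3.0, 4.0]                               # per-parameter gradient L2 norms
max_norm   = 1.0

total_norm = Math.sqrt(grad_norms.sum { |n| n**2 })   # => 5.0
clip_coef  = max_norm / (total_norm + 1e-6)           # => ~0.2

scaled = clip_coef < 1.0 ? grad_norms.map { |n| n * clip_coef } : grad_norms
# scaled => [~0.6, ~0.8]; the combined norm is now ~1.0, i.e. max_norm

With gradient accumulation, this clipping runs only on steps where (step + 1) % gradient_accumulation_steps == 0, so it acts on the gradients accumulated over the whole effective batch.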
data/lib/fine/training/text_trainer.rb
@@ -0,0 +1,275 @@
# frozen_string_literal: true

module Fine
  module Training
    # Trainer for text classification models
    class TextTrainer
      attr_reader :model, :config, :train_loader, :val_loader, :label_map
      attr_accessor :stop_training

      def initialize(model, config, train_dataset:, val_dataset: nil)
        @model = model
        @config = config
        @stop_training = false
        @label_map = train_dataset.label_map

        @train_loader = Datasets::TextDataLoader.new(
          train_dataset,
          batch_size: config.batch_size,
          shuffle: true
        )

        @val_loader = if val_dataset
          Datasets::TextDataLoader.new(
            val_dataset,
            batch_size: config.batch_size,
            shuffle: false
          )
        end

        @history = []
      end

      # Train the model
      #
      # @return [Array<Hash>] Training history
      def fit
        @model.train

        optimizer = build_optimizer
        scheduler = build_scheduler(optimizer)

        run_callbacks(:on_train_begin, self)

        @config.epochs.times do |epoch|
          break if @stop_training

          run_callbacks(:on_epoch_begin, self, epoch)

          train_metrics = train_epoch(optimizer, epoch)

          val_metrics = @val_loader ? evaluate : {}

          scheduler&.step

          metrics = train_metrics.merge(
            val_metrics.transform_keys { |k| :"val_#{k}" }
          )
          @history << metrics

          run_callbacks(:on_epoch_end, self, epoch, metrics)
        end

        run_callbacks(:on_train_end, self)

        @history
      end

      # Evaluate on validation set
      def evaluate
        @model.eval

        total_loss = 0.0
        correct = 0
        total = 0

        Torch.no_grad do
          @val_loader.each_batch do |batch|
            output = @model.call(
              batch[:input_ids],
              attention_mask: batch[:attention_mask],
              token_type_ids: batch[:token_type_ids],
              labels: batch[:labels]
            )

            total_loss += output[:loss].item * batch[:labels].size(0)
            predictions = output[:logits].argmax(dim: 1)
            correct += predictions.eq(batch[:labels]).sum.item
            total += batch[:labels].size(0)
          end
        end

        @model.train

        {
          loss: total_loss / total,
          accuracy: correct.to_f / total
        }
      end

      private

      def train_epoch(optimizer, _epoch)
        total_loss = 0.0
        correct = 0
        total = 0

        @train_loader.each_batch.with_index do |batch, batch_idx|
          run_callbacks(:on_batch_begin, self, batch_idx)

          optimizer.zero_grad

          output = @model.call(
            batch[:input_ids],
            attention_mask: batch[:attention_mask],
            token_type_ids: batch[:token_type_ids],
            labels: batch[:labels]
          )

          loss = output[:loss]
          loss.backward
          optimizer.step

          batch_loss = loss.item
          total_loss += batch_loss * batch[:labels].size(0)

          predictions = output[:logits].argmax(dim: 1)
          correct += predictions.eq(batch[:labels]).sum.item
          total += batch[:labels].size(0)

          run_callbacks(:on_batch_end, self, batch_idx, batch_loss)
        end

        {
          loss: total_loss / total,
          accuracy: correct.to_f / total
        }
      end

      def build_optimizer
        params = @model.parameters.select(&:requires_grad)

        case @config.optimizer
        when :adam
          Torch::Optim::Adam.new(params, lr: @config.learning_rate)
        when :adamw
          Torch::Optim::AdamW.new(
            params,
            lr: @config.learning_rate,
            weight_decay: @config.weight_decay
          )
        when :sgd
          Torch::Optim::SGD.new(params, lr: @config.learning_rate, momentum: 0.9)
        else
          raise ConfigurationError, "Unknown optimizer: #{@config.optimizer}"
        end
      end

      def build_scheduler(optimizer)
        return nil unless @config.scheduler

        case @config.scheduler
        when :cosine
          Torch::Optim::LRScheduler::CosineAnnealingLR.new(optimizer, @config.epochs)
        when :linear
          # Linear decay
          Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 1, gamma: 0.9)
        else
          nil
        end
      end

      def run_callbacks(method, *args)
        @config.callbacks.each do |callback|
          callback.send(method, *args)
        end
      end
    end

    # Trainer for text embedding models (contrastive learning)
    class EmbeddingTrainer
      attr_reader :model, :config, :train_loader
      attr_accessor :stop_training

      def initialize(model, config, train_dataset:)
        @model = model
        @config = config
        @stop_training = false

        @train_loader = Datasets::TextDataLoader.new(
          train_dataset,
          batch_size: config.batch_size,
          shuffle: true
        )

        @history = []
      end

      def fit
        @model.train

        optimizer = build_optimizer

        run_callbacks(:on_train_begin, self)

        @config.epochs.times do |epoch|
          break if @stop_training

          run_callbacks(:on_epoch_begin, self, epoch)

          metrics = train_epoch(optimizer)
          @history << metrics

          run_callbacks(:on_epoch_end, self, epoch, metrics)
        end

        run_callbacks(:on_train_end, self)

        @history
      end

      private

      def train_epoch(optimizer)
        total_loss = 0.0
        num_batches = 0

        @train_loader.each_batch do |batch|
          optimizer.zero_grad

          # For pair datasets, we get anchor and positive texts
          embeddings = @model.encode(
            batch[:input_ids],
            attention_mask: batch[:attention_mask]
          )

          # Multiple Negatives Ranking Loss
          # Treat other samples in batch as negatives
          loss = multiple_negatives_ranking_loss(embeddings)

          loss.backward
          optimizer.step

          total_loss += loss.item
          num_batches += 1
        end

        { loss: total_loss / num_batches }
      end

      def multiple_negatives_ranking_loss(embeddings, scale: 20.0)
        # Split embeddings into anchors and positives (assuming paired data)
        batch_size = embeddings.size(0) / 2
        anchors = embeddings[0...batch_size]
        positives = embeddings[batch_size..]

        # Compute similarity matrix
        scores = Torch.matmul(anchors, positives.transpose(0, 1)) * scale

        # Labels: diagonal is positive (index i matches index i)
        labels = Torch.arange(batch_size, device: embeddings.device)

        Torch::NN::Functional.cross_entropy(scores, labels)
      end

      def build_optimizer
        params = @model.parameters.select(&:requires_grad)
        Torch::Optim::AdamW.new(params, lr: @config.learning_rate, weight_decay: @config.weight_decay)
      end

      def run_callbacks(method, *args)
        @config.callbacks.each { |cb| cb.send(method, *args) }
      end
    end
  end
end
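multiple_negatives_ranking_loss above scores each anchor against every positive in the batch and treats the matching positive (the diagonal of the score matrix) as the correct class, so the other in-batch positives act as negatives. A standalone torch.rb sketch of that scoring step, with random embeddings standing in for real @model.encode output and the same anchors-then-positives batch layout assumed by the method:

require "torch"

batch_size = 3
dim = 8

# Stand-in batch: 3 anchor embeddings followed by their 3 positive embeddings.
embeddings = Torch.randn(2 * batch_size, dim)

anchors   = embeddings[0...batch_size]
positives = embeddings[batch_size..]

# scores[i, j] compares anchor i with positive j; the scale sharpens the softmax.
scores = Torch.matmul(anchors, positives.transpose(0, 1)) * 20.0

# The matching positive for anchor i sits at column i.
labels = Torch.arange(batch_size)

loss = Torch::NN::Functional.cross_entropy(scores, labels)
puts loss.item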
data/lib/fine/training/trainer.rb
@@ -0,0 +1,194 @@
# frozen_string_literal: true

module Fine
  module Training
    # Main training orchestrator
    class Trainer
      attr_reader :model, :config, :train_loader, :val_loader, :label_map
      attr_accessor :stop_training

      def initialize(model, config, train_dataset:, val_dataset: nil)
        @model = model
        @config = config
        @stop_training = false
        @label_map = train_dataset.label_map

        # Create data loaders
        @train_loader = Datasets::DataLoader.new(
          train_dataset,
          batch_size: config.batch_size,
          shuffle: true
        )

        @val_loader = if val_dataset
          Datasets::DataLoader.new(
            val_dataset,
            batch_size: config.batch_size,
            shuffle: false
          )
        end

        # History tracking
        @history = []
      end

      # Train the model
      #
      # @return [Array<Hash>] Training history (metrics per epoch)
      def fit
        @model.train

        # Build optimizer
        optimizer = build_optimizer
        scheduler = build_scheduler(optimizer)

        run_callbacks(:on_train_begin, self)

        @config.epochs.times do |epoch|
          break if @stop_training

          run_callbacks(:on_epoch_begin, self, epoch)

          # Training epoch
          train_metrics = train_epoch(optimizer, epoch)

          # Validation
          val_metrics = @val_loader ? evaluate : {}

          # Step scheduler
          scheduler&.step

          # Combine metrics
          metrics = train_metrics.merge(
            val_metrics.transform_keys { |k| :"val_#{k}" }
          )
          @history << metrics

          run_callbacks(:on_epoch_end, self, epoch, metrics)
        end

        run_callbacks(:on_train_end, self)

        @history
      end

      # Evaluate on validation set
      #
      # @return [Hash] Evaluation metrics
      def evaluate
        @model.eval

        total_loss = 0.0
        correct = 0
        total = 0

        Torch.no_grad do
          @val_loader.each_batch do |batch|
            output = @model.call(batch[:pixel_values], labels: batch[:labels])

            total_loss += output[:loss].item * batch[:labels].size(0)
            predictions = output[:logits].argmax(dim: 1)
            correct += (predictions == batch[:labels]).sum.item
            total += batch[:labels].size(0)
          end
        end

        @model.train

        {
          loss: total_loss / total,
          accuracy: correct.to_f / total
        }
      end

      private

      def train_epoch(optimizer, epoch)
        total_loss = 0.0
        correct = 0
        total = 0

        @train_loader.each_batch.with_index do |batch, batch_idx|
          run_callbacks(:on_batch_begin, self, batch_idx)

          # Zero gradients
          optimizer.zero_grad

          # Forward pass
          output = @model.call(batch[:pixel_values], labels: batch[:labels])
          loss = output[:loss]

          # Backward pass
          loss.backward

          # Update weights
          optimizer.step

          # Track metrics
          batch_loss = loss.item
          total_loss += batch_loss * batch[:labels].size(0)

          predictions = output[:logits].argmax(dim: 1)
          correct += predictions.eq(batch[:labels]).sum.item
          total += batch[:labels].size(0)

          run_callbacks(:on_batch_end, self, batch_idx, batch_loss)
        end

        {
          loss: total_loss / total,
          accuracy: correct.to_f / total
        }
      end

      def build_optimizer
        params = @model.parameters.select(&:requires_grad)

        case @config.optimizer
        when :adam
          Torch::Optim::Adam.new(params, lr: @config.learning_rate)
        when :adamw
          Torch::Optim::AdamW.new(
            params,
            lr: @config.learning_rate,
            weight_decay: @config.weight_decay
          )
        when :sgd
          Torch::Optim::SGD.new(
            params,
            lr: @config.learning_rate,
            momentum: 0.9
          )
        else
          raise ConfigurationError, "Unknown optimizer: #{@config.optimizer}"
        end
      end

      def build_scheduler(optimizer)
        return nil unless @config.scheduler

        case @config.scheduler
        when :cosine
          Torch::Optim::LRScheduler::CosineAnnealingLR.new(
            optimizer,
            @config.epochs
          )
        when :step
          Torch::Optim::LRScheduler::StepLR.new(
            optimizer,
            step_size: @config.epochs / 3,
            gamma: 0.1
          )
        else
          nil
        end
      end

      def run_callbacks(method, *args)
        @config.callbacks.each do |callback|
          callback.send(method, *args)
        end
      end
    end
  end
end
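Trainer#run_callbacks, like TextTrainer's, dispatches every hook with callback.send, so anything placed in config.callbacks needs to respond to all six hooks with the argument shapes used above. A minimal hand-rolled logging callback written against those calls; the gem also ships Fine::Callbacks::Base, whose interface is not shown in this hunk, so this is only an illustrative sketch:

# Minimal callback compatible with Trainer#run_callbacks / TextTrainer#run_callbacks.
class EpochLogger
  def on_train_begin(trainer); end
  def on_epoch_begin(trainer, epoch); end
  def on_batch_begin(trainer, batch_idx); end
  def on_batch_end(trainer, batch_idx, batch_loss); end

  def on_epoch_end(trainer, epoch, metrics)
    puts format("epoch %d  loss=%.4f  acc=%.4f", epoch + 1, metrics[:loss], metrics[:accuracy])
  end

  def on_train_end(trainer)
    puts "finished after #{trainer.config.epochs} configured epochs"
  end
end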
data/lib/fine/transforms/compose.rb
@@ -0,0 +1,28 @@
# frozen_string_literal: true

module Fine
  module Transforms
    # Composes multiple transforms into a single callable
    class Compose
      attr_reader :transforms

      def initialize(transforms)
        @transforms = transforms
      end

      def call(image)
        @transforms.reduce(image) { |img, transform| transform.call(img) }
      end

      def <<(transform)
        @transforms << transform
        self
      end

      def prepend(transform)
        @transforms.unshift(transform)
        self
      end
    end
  end
end
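Compose#call only requires that each transform respond to call, so any callable can be composed. A small usage sketch with procs standing in for the gem's Resize/Normalize/ToTensor transforms, whose constructors are not shown in this hunk:

require "fine"

# Procs respond to #call, which is all Compose needs.
pipeline = Fine::Transforms::Compose.new([
  ->(x) { x * 2 },  # stand-in for a real transform such as Resize
  ->(x) { x + 1 }   # stand-in for Normalize / ToTensor
])

pipeline.call(10)                   # => 21, transforms applied left to right

pipeline << ->(x) { x - 5 }         # append a step
pipeline.prepend(->(x) { x.to_f })  # or push one onto the front
pipeline.call(10)                   # => 16.0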