fine 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +7 -0
  2. data/.rspec +3 -0
  3. data/CHANGELOG.md +38 -0
  4. data/Gemfile +6 -0
  5. data/Gemfile.lock +167 -0
  6. data/LICENSE +21 -0
  7. data/README.md +212 -0
  8. data/Rakefile +6 -0
  9. data/docs/installation.md +151 -0
  10. data/docs/tutorials/llm-fine-tuning.md +246 -0
  11. data/docs/tutorials/model-export.md +200 -0
  12. data/docs/tutorials/siglip2-image-classification.md +130 -0
  13. data/docs/tutorials/siglip2-object-recognition.md +203 -0
  14. data/docs/tutorials/siglip2-similarity-search.md +152 -0
  15. data/docs/tutorials/text-classification.md +233 -0
  16. data/docs/tutorials/text-embeddings.md +211 -0
  17. data/examples/basic_classification.rb +70 -0
  18. data/examples/data/tool_calls.jsonl +30 -0
  19. data/examples/demo_training.rb +78 -0
  20. data/examples/finetune_gemma3_tools.rb +135 -0
  21. data/examples/real_llm_test.rb +128 -0
  22. data/examples/real_text_classification_test.rb +90 -0
  23. data/examples/real_text_embedder_test.rb +110 -0
  24. data/examples/real_training_test.rb +88 -0
  25. data/examples/test_export.rb +28 -0
  26. data/examples/test_image_classifier.rb +79 -0
  27. data/examples/test_llm.rb +100 -0
  28. data/examples/test_text_classifier.rb +59 -0
  29. data/lib/fine/callbacks/base.rb +140 -0
  30. data/lib/fine/callbacks/progress_bar.rb +66 -0
  31. data/lib/fine/configuration.rb +106 -0
  32. data/lib/fine/datasets/data_loader.rb +63 -0
  33. data/lib/fine/datasets/image_dataset.rb +203 -0
  34. data/lib/fine/datasets/instruction_dataset.rb +226 -0
  35. data/lib/fine/datasets/text_data_loader.rb +88 -0
  36. data/lib/fine/datasets/text_dataset.rb +266 -0
  37. data/lib/fine/error.rb +49 -0
  38. data/lib/fine/export/gguf_exporter.rb +424 -0
  39. data/lib/fine/export/onnx_exporter.rb +249 -0
  40. data/lib/fine/export.rb +53 -0
  41. data/lib/fine/hub/config_loader.rb +145 -0
  42. data/lib/fine/hub/model_downloader.rb +136 -0
  43. data/lib/fine/hub/safetensors_loader.rb +108 -0
  44. data/lib/fine/image_classifier.rb +256 -0
  45. data/lib/fine/llm.rb +336 -0
  46. data/lib/fine/models/base.rb +48 -0
  47. data/lib/fine/models/bert_encoder.rb +202 -0
  48. data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
  49. data/lib/fine/models/causal_lm.rb +279 -0
  50. data/lib/fine/models/classification_head.rb +24 -0
  51. data/lib/fine/models/gemma3_decoder.rb +244 -0
  52. data/lib/fine/models/llama_decoder.rb +297 -0
  53. data/lib/fine/models/sentence_transformer.rb +202 -0
  54. data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
  55. data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
  56. data/lib/fine/text_classifier.rb +250 -0
  57. data/lib/fine/text_embedder.rb +221 -0
  58. data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
  59. data/lib/fine/training/llm_trainer.rb +212 -0
  60. data/lib/fine/training/text_trainer.rb +275 -0
  61. data/lib/fine/training/trainer.rb +194 -0
  62. data/lib/fine/transforms/compose.rb +28 -0
  63. data/lib/fine/transforms/normalize.rb +33 -0
  64. data/lib/fine/transforms/resize.rb +35 -0
  65. data/lib/fine/transforms/to_tensor.rb +53 -0
  66. data/lib/fine/version.rb +3 -0
  67. data/lib/fine.rb +112 -0
  68. data/mise.toml +2 -0
  69. metadata +240 -0
@@ -0,0 +1,212 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Fine
4
+ module Training
5
+ # Trainer for causal language model fine-tuning
6
# Trainer for causal language model fine-tuning.
#
# Runs a training loop with gradient accumulation, optional manual
# gradient-norm clipping and optional per-epoch validation (mean loss and
# perplexity). Hooks on +config.callbacks+ are invoked at train, epoch and
# batch boundaries. A callback may set +trainer.stop_training = true+ (e.g.
# from +on_epoch_end+) to stop before the next epoch, the same contract the
# other trainers in Fine::Training expose.
class LLMTrainer
  attr_reader :model, :config, :train_dataset, :val_dataset
  attr_accessor :stop_training

  # @param model [Object] causal LM responding to +to+, +train+, +eval+,
  #   +forward(input_ids, attention_mask:, labels:)+ (returning a hash with
  #   :loss), +named_parameters+ and +parameters+
  # @param config [Object] training configuration (epochs, batch_size,
  #   learning_rate, weight_decay, gradient_accumulation_steps,
  #   max_grad_norm, warmup_steps, pad_token_id, callbacks)
  # @param train_dataset [Object] dataset wrapped by Datasets::InstructionDataLoader
  # @param val_dataset [Object, nil] optional validation dataset
  def initialize(model, config, train_dataset:, val_dataset: nil)
    @model = model
    @config = config
    @train_dataset = train_dataset
    @val_dataset = val_dataset
    @device = Fine.device
    @history = []
    @stop_training = false
  end

  # Train for +config.epochs+ epochs, or until +stop_training+ is set.
  #
  # @return [Array<Hash>] one entry per completed epoch with :epoch and :loss,
  #   plus :val_loss and :val_perplexity when a validation dataset is present
  def fit
    @model.to(@device)
    @model.train

    optimizer = create_optimizer
    scheduler = create_scheduler(optimizer)

    @config.callbacks.each { |cb| cb.on_train_begin(model: @model, config: @config) }

    @config.epochs.times do |epoch|
      break if @stop_training

      metrics = { loss: train_epoch(optimizer, scheduler, epoch) }

      if @val_dataset
        val_loss, val_perplexity = evaluate
        metrics[:val_loss] = val_loss
        metrics[:val_perplexity] = val_perplexity
      end

      @history << { epoch: epoch + 1, **metrics }

      @config.callbacks.each { |cb| cb.on_epoch_end(self, epoch + 1, metrics) }
    end

    @config.callbacks.each { |cb| cb.on_train_end(history: @history) }

    @history
  end

  # Compute the mean validation loss and its perplexity (exp of the mean loss).
  #
  # The model is put in eval mode for the pass and always switched back to
  # train mode afterwards, even if the forward pass raises.
  #
  # @return [Array(Float, Float)] [mean loss per batch, perplexity]; both are
  #   NaN when the validation loader yields no batches
  def evaluate
    @model.eval
    total_loss = 0.0
    num_batches = 0

    data_loader = create_data_loader(@val_dataset, shuffle: false)

    Torch.no_grad do
      data_loader.each do |batch|
        batch = move_to_device(batch)

        outputs = @model.forward(
          batch[:input_ids],
          attention_mask: batch[:attention_mask],
          labels: batch[:labels]
        )

        # Cast before .item (bfloat16 not supported by .item).
        total_loss += outputs[:loss].to(:float32).item
        num_batches += 1
      end
    end

    avg_loss = total_loss / num_batches
    [avg_loss, Math.exp(avg_loss)]
  ensure
    @model.train
  end

  private

  # Run one epoch, stepping the optimizer once per accumulation window.
  #
  # A trailing partial window (batch count not divisible by the accumulation
  # size) is flushed at the end of the epoch, so those gradients are neither
  # dropped nor carried into the next epoch's first window.
  #
  # @return [Float] mean unscaled loss per batch (NaN for an empty loader)
  def train_epoch(optimizer, scheduler, epoch)
    accum_steps = gradient_accumulation_steps
    total_loss = 0.0
    num_batches = 0

    data_loader = create_data_loader(@train_dataset, shuffle: true)
    total_steps = data_loader.size

    @config.callbacks.each do |cb|
      cb.on_epoch_begin(epoch + 1, total_steps: total_steps) if cb.respond_to?(:on_epoch_begin)
    end

    data_loader.each_with_index do |batch, step|
      input_ids = batch[:input_ids].to(@device)
      attention_mask = batch[:attention_mask].to(@device)
      labels = batch[:labels].to(@device)

      outputs = @model.forward(input_ids, attention_mask: attention_mask, labels: labels)

      # Detached scalar copy of the loss, taken before backward, for metrics.
      loss_value = outputs[:loss].detach.to(:float32).item

      # Scale so the accumulated gradient is an average over the window.
      scaled_loss = outputs[:loss] / accum_steps
      scaled_loss.backward

      # Drop every reference so this batch's computation graph can be freed.
      scaled_loss = nil
      outputs = nil
      input_ids = nil
      attention_mask = nil
      labels = nil
      batch = nil

      apply_optimizer_step(optimizer, scheduler) if ((step + 1) % accum_steps).zero?

      total_loss += loss_value
      num_batches += 1

      @config.callbacks.each do |cb|
        cb.on_batch_end(self, step + 1, loss_value) if cb.respond_to?(:on_batch_end)
      end
    end

    # Flush the trailing partial accumulation window, if any.
    apply_optimizer_step(optimizer, scheduler) unless (num_batches % accum_steps).zero?

    total_loss / num_batches
  end

  # Accumulation window size; nil or values below 1 mean "step every batch".
  def gradient_accumulation_steps
    [@config.gradient_accumulation_steps.to_i, 1].max
  end

  # Clip (when configured), step optimizer and scheduler, then reset gradients.
  def apply_optimizer_step(optimizer, scheduler)
    clip_grad_norm(@model.parameters, @config.max_grad_norm) if @config.max_grad_norm

    optimizer.step
    scheduler&.step
    optimizer.zero_grad

    # Force GC after each optimizer step to free computation graphs.
    GC.start
  end

  # Build the loader for +dataset+; pads with +config.pad_token_id+ (default 0).
  def create_data_loader(dataset, shuffle:)
    Datasets::InstructionDataLoader.new(
      dataset,
      batch_size: @config.batch_size,
      shuffle: shuffle,
      pad_token_id: @config.pad_token_id || 0
    )
  end

  # AdamW with two parameter groups: trainable parameters whose name contains
  # "bias" or "norm" get no weight decay; all others use +config.weight_decay+.
  def create_optimizer
    decay_params = []
    no_decay_params = []

    @model.named_parameters.each do |name, param|
      next unless param.requires_grad

      if name.include?("bias") || name.include?("norm")
        no_decay_params << param
      else
        decay_params << param
      end
    end

    param_groups = [
      { params: decay_params, weight_decay: @config.weight_decay },
      { params: no_decay_params, weight_decay: 0.0 }
    ]

    Torch::Optim::AdamW.new(param_groups, lr: @config.learning_rate)
  end

  # Learning-rate scheduler hook. Always returns nil at present, so
  # +config.warmup_steps+ has no effect.
  # NOTE(review): linear warmup is not implemented; torch.rb scheduler
  # support may be limited — confirm before wiring one in.
  def create_scheduler(optimizer)
    return nil unless @config.warmup_steps && @config.warmup_steps > 0

    nil
  end

  # Move every tensor in the batch hash to the trainer's device.
  def move_to_device(batch)
    batch.transform_values { |v| v.to(@device) }
  end

  # Manual gradient clipping: scale all gradients in place so that their
  # global L2 norm (sqrt of the sum of squared per-parameter norms) is at
  # most +max_norm+.
  #
  # @return [Float] the global gradient norm measured before clipping
  def clip_grad_norm(parameters, max_norm)
    total_norm = 0.0

    parameters.each do |param|
      next unless param.grad

      # Convert to float32 for .item (bfloat16 not supported)
      param_norm = param.grad.data.norm(2).to(:float32).item
      total_norm += param_norm**2
    end

    total_norm = Math.sqrt(total_norm)
    clip_coef = max_norm / (total_norm + 1e-6)

    if clip_coef < 1.0
      parameters.each do |param|
        next unless param.grad

        param.grad.data.mul!(clip_coef)
      end
    end

    total_norm
  end
end
211
+ end
212
+ end
@@ -0,0 +1,275 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Fine
4
+ module Training
5
+ # Trainer for text classification models
6
# Fine-tunes a text classification model.
#
# Batches are read with the keys :input_ids, :attention_mask,
# :token_type_ids and :labels and forwarded to +model.call+, whose result
# must contain :loss and :logits. Per-epoch metrics are the per-sample mean
# loss and accuracy; validation metrics (when a validation dataset is given)
# are merged in under "val_"-prefixed keys. Setting +stop_training+ to true,
# e.g. from a callback, ends training before the next epoch.
class TextTrainer
  attr_reader :model, :config, :train_loader, :val_loader, :label_map
  attr_accessor :stop_training

  def initialize(model, config, train_dataset:, val_dataset: nil)
    @model = model
    @config = config
    @stop_training = false
    @history = []
    @label_map = train_dataset.label_map

    per_batch = config.batch_size
    @train_loader = Datasets::TextDataLoader.new(train_dataset, batch_size: per_batch, shuffle: true)
    @val_loader =
      val_dataset ? Datasets::TextDataLoader.new(val_dataset, batch_size: per_batch, shuffle: false) : nil
  end

  # Run the full training loop.
  #
  # @return [Array<Hash>] metrics recorded after every completed epoch
  def fit
    @model.train
    optimizer = build_optimizer
    scheduler = build_scheduler(optimizer)
    run_callbacks(:on_train_begin, self)

    @config.epochs.times do |epoch_index|
      break if @stop_training

      run_callbacks(:on_epoch_begin, self, epoch_index)
      epoch_metrics = train_epoch(optimizer, epoch_index)
      validation = @val_loader ? evaluate : {}
      # The scheduler advances once per epoch.
      scheduler&.step

      prefixed = validation.transform_keys { |key| :"val_#{key}" }
      epoch_metrics = epoch_metrics.merge(prefixed)
      @history << epoch_metrics
      run_callbacks(:on_epoch_end, self, epoch_index, epoch_metrics)
    end

    run_callbacks(:on_train_end, self)
    @history
  end

  # Score the validation loader without tracking gradients, then switch the
  # model back to train mode.
  #
  # @return [Hash] :loss (per-sample mean) and :accuracy
  def evaluate
    @model.eval
    loss_sum = 0.0
    hits = 0
    seen = 0

    Torch.no_grad do
      @val_loader.each_batch do |batch|
        output = forward_batch(batch)
        gold = batch[:labels]
        count = gold.size(0)

        loss_sum += output[:loss].item * count
        hits += output[:logits].argmax(dim: 1).eq(gold).sum.item
        seen += count
      end
    end

    @model.train
    epoch_summary(loss_sum, hits, seen)
  end

  private

  # One optimizer step per batch; returns the epoch's loss/accuracy summary.
  def train_epoch(optimizer, _epoch)
    loss_sum = 0.0
    hits = 0
    seen = 0

    @train_loader.each_batch.with_index do |batch, index|
      run_callbacks(:on_batch_begin, self, index)

      optimizer.zero_grad
      output = forward_batch(batch)
      loss = output[:loss]
      loss.backward
      optimizer.step

      step_loss = loss.item
      gold = batch[:labels]
      count = gold.size(0)
      loss_sum += step_loss * count
      hits += output[:logits].argmax(dim: 1).eq(gold).sum.item
      seen += count

      run_callbacks(:on_batch_end, self, index, step_loss)
    end

    epoch_summary(loss_sum, hits, seen)
  end

  # Forward a tokenized batch (with labels, so the model also returns :loss).
  def forward_batch(batch)
    @model.call(
      batch[:input_ids],
      attention_mask: batch[:attention_mask],
      token_type_ids: batch[:token_type_ids],
      labels: batch[:labels]
    )
  end

  # Per-sample mean loss and accuracy from running sums.
  def epoch_summary(loss_sum, hits, seen)
    { loss: loss_sum / seen, accuracy: hits.to_f / seen }
  end

  # Optimizer over trainable parameters, chosen by +config.optimizer+.
  #
  # @raise [ConfigurationError] for any value other than :adam, :adamw, :sgd
  def build_optimizer
    trainable = @model.parameters.select(&:requires_grad)
    choice = @config.optimizer

    case choice
    when :adam then Torch::Optim::Adam.new(trainable, lr: @config.learning_rate)
    when :adamw then Torch::Optim::AdamW.new(trainable, lr: @config.learning_rate, weight_decay: @config.weight_decay)
    when :sgd then Torch::Optim::SGD.new(trainable, lr: @config.learning_rate, momentum: 0.9)
    else raise ConfigurationError, "Unknown optimizer: #{choice}"
    end
  end

  # Per-epoch LR scheduler, or nil when +config.scheduler+ is unset/unknown.
  def build_scheduler(optimizer)
    case @config.scheduler
    when :cosine
      Torch::Optim::LRScheduler::CosineAnnealingLR.new(optimizer, @config.epochs)
    when :linear
      # Despite the name this is StepLR(step_size: 1, gamma: 0.9): the LR is
      # multiplied by 0.9 after every epoch, i.e. exponential, not linear, decay.
      Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 1, gamma: 0.9)
    end
  end

  # Invoke +method+ on every configured callback (no respond_to? guard).
  def run_callbacks(method, *args)
    @config.callbacks.each { |callback| callback.send(method, *args) }
  end
end
178
+
179
+ # Trainer for text embedding models (contrastive learning)
180
# Trains a text embedding model with an in-batch contrastive objective
# (multiple negatives ranking loss).
#
# There is no validation pass; each epoch records only the mean batch loss.
# Setting +stop_training+ to true, e.g. from a callback, ends training before
# the next epoch.
class EmbeddingTrainer
  attr_reader :model, :config, :train_loader
  attr_accessor :stop_training

  def initialize(model, config, train_dataset:)
    @model = model
    @config = config
    @stop_training = false
    @history = []
    @train_loader = Datasets::TextDataLoader.new(train_dataset, batch_size: config.batch_size, shuffle: true)
  end

  # Run the training loop.
  #
  # @return [Array<Hash>] one { loss: Float } entry per completed epoch
  def fit
    @model.train
    optimizer = build_optimizer
    run_callbacks(:on_train_begin, self)

    @config.epochs.times do |epoch_index|
      break if @stop_training

      run_callbacks(:on_epoch_begin, self, epoch_index)
      epoch_metrics = train_epoch(optimizer)
      @history.push(epoch_metrics)
      run_callbacks(:on_epoch_end, self, epoch_index, epoch_metrics)
    end

    run_callbacks(:on_train_end, self)
    @history
  end

  private

  # One optimizer step per batch; returns the mean batch loss (NaN if empty).
  def train_epoch(optimizer)
    batch_losses = []

    @train_loader.each_batch do |batch|
      optimizer.zero_grad
      vectors = @model.encode(batch[:input_ids], attention_mask: batch[:attention_mask])
      loss = multiple_negatives_ranking_loss(vectors)
      loss.backward
      optimizer.step
      batch_losses << loss.item
    end

    # Sequential left fold (not Array#sum) keeps the same float accumulation order.
    { loss: batch_losses.inject(0.0, :+) / batch_losses.size }
  end

  # In-batch contrastive loss: row i of the first half is scored against every
  # row of the second half, and the matching row (same offset) is the target
  # class for a cross-entropy over those scores.
  #
  # NOTE(review): assumes the batch is laid out as [anchors..., positives...];
  # confirm the loader produces that order. With an odd row count the split
  # floors, leaving one extra unmatched positive column.
  # Similarity is a raw dot product times +scale+; it is cosine only if
  # +encode+ returns unit-length vectors — not verifiable here.
  def multiple_negatives_ranking_loss(embeddings, scale: 20.0)
    half = embeddings.size(0) / 2
    anchor_vecs = embeddings[0...half]
    positive_vecs = embeddings[half..]

    similarity = Torch.matmul(anchor_vecs, positive_vecs.transpose(0, 1)) * scale
    targets = Torch.arange(half, device: embeddings.device)

    Torch::NN::Functional.cross_entropy(similarity, targets)
  end

  # AdamW over trainable parameters with the configured LR and weight decay.
  def build_optimizer
    trainable = @model.parameters.select(&:requires_grad)
    Torch::Optim::AdamW.new(trainable, lr: @config.learning_rate, weight_decay: @config.weight_decay)
  end

  # Invoke +method+ on every configured callback (no respond_to? guard).
  def run_callbacks(method, *args)
    @config.callbacks.each do |callback|
      callback.send(method, *args)
    end
  end
end
274
+ end
275
+ end
@@ -0,0 +1,194 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Fine
4
+ module Training
5
+ # Main training orchestrator
6
# Main training orchestrator for classification models fed
# +batch[:pixel_values]+ and +batch[:labels]+.
#
# Per-epoch metrics are the per-sample mean loss and accuracy; when a
# validation dataset is supplied its metrics are merged in under
# "val_"-prefixed keys. Setting +stop_training+ to true, e.g. from a callback,
# ends training before the next epoch.
class Trainer
  attr_reader :model, :config, :train_loader, :val_loader, :label_map
  attr_accessor :stop_training

  # @param model [Object] responds to +call(pixel_values, labels:)+ and
  #   returns a hash with :loss and :logits
  # @param config [Object] training configuration
  # @param train_dataset [Object] dataset exposing +label_map+
  # @param val_dataset [Object, nil] optional validation dataset
  def initialize(model, config, train_dataset:, val_dataset: nil)
    @model = model
    @config = config
    @stop_training = false
    @label_map = train_dataset.label_map

    # Create data loaders
    @train_loader = Datasets::DataLoader.new(
      train_dataset,
      batch_size: config.batch_size,
      shuffle: true
    )

    @val_loader = if val_dataset
                    Datasets::DataLoader.new(
                      val_dataset,
                      batch_size: config.batch_size,
                      shuffle: false
                    )
                  end

    # History tracking
    @history = []
  end

  # Train the model.
  #
  # @return [Array<Hash>] training history (metrics per epoch)
  def fit
    @model.train

    optimizer = build_optimizer
    scheduler = build_scheduler(optimizer)

    run_callbacks(:on_train_begin, self)

    @config.epochs.times do |epoch|
      break if @stop_training

      run_callbacks(:on_epoch_begin, self, epoch)

      train_metrics = train_epoch(optimizer, epoch)
      val_metrics = @val_loader ? evaluate : {}

      # The scheduler advances once per epoch.
      scheduler&.step

      metrics = train_metrics.merge(
        val_metrics.transform_keys { |k| :"val_#{k}" }
      )
      @history << metrics

      run_callbacks(:on_epoch_end, self, epoch, metrics)
    end

    run_callbacks(:on_train_end, self)

    @history
  end

  # Evaluate on the validation set without tracking gradients.
  #
  # The model is always switched back to train mode afterwards, even if the
  # forward pass raises.
  #
  # @return [Hash] :loss (per-sample mean) and :accuracy
  def evaluate
    @model.eval

    total_loss = 0.0
    correct = 0
    total = 0

    Torch.no_grad do
      @val_loader.each_batch do |batch|
        output = @model.call(batch[:pixel_values], labels: batch[:labels])
        batch_size = batch[:labels].size(0)

        total_loss += output[:loss].item * batch_size
        # Explicit element-wise comparison, the same as #train_epoch uses.
        correct += output[:logits].argmax(dim: 1).eq(batch[:labels]).sum.item
        total += batch_size
      end
    end

    { loss: total_loss / total, accuracy: correct.to_f / total }
  ensure
    @model.train
  end

  private

  # One optimizer step per batch; returns the epoch's loss/accuracy summary.
  def train_epoch(optimizer, epoch)
    total_loss = 0.0
    correct = 0
    total = 0

    @train_loader.each_batch.with_index do |batch, batch_idx|
      run_callbacks(:on_batch_begin, self, batch_idx)

      optimizer.zero_grad

      output = @model.call(batch[:pixel_values], labels: batch[:labels])
      loss = output[:loss]
      loss.backward
      optimizer.step

      batch_loss = loss.item
      total_loss += batch_loss * batch[:labels].size(0)

      predictions = output[:logits].argmax(dim: 1)
      correct += predictions.eq(batch[:labels]).sum.item
      total += batch[:labels].size(0)

      run_callbacks(:on_batch_end, self, batch_idx, batch_loss)
    end

    { loss: total_loss / total, accuracy: correct.to_f / total }
  end

  # Optimizer over trainable parameters, chosen by +config.optimizer+.
  #
  # @raise [ConfigurationError] for any value other than :adam, :adamw, :sgd
  def build_optimizer
    params = @model.parameters.select(&:requires_grad)

    case @config.optimizer
    when :adam
      Torch::Optim::Adam.new(params, lr: @config.learning_rate)
    when :adamw
      Torch::Optim::AdamW.new(
        params,
        lr: @config.learning_rate,
        weight_decay: @config.weight_decay
      )
    when :sgd
      Torch::Optim::SGD.new(
        params,
        lr: @config.learning_rate,
        momentum: 0.9
      )
    else
      raise ConfigurationError, "Unknown optimizer: #{@config.optimizer}"
    end
  end

  # Per-epoch LR scheduler, or nil when +config.scheduler+ is unset/unknown.
  def build_scheduler(optimizer)
    return nil unless @config.scheduler

    case @config.scheduler
    when :cosine
      Torch::Optim::LRScheduler::CosineAnnealingLR.new(
        optimizer,
        @config.epochs
      )
    when :step
      # Decay by 10x every third of training. Clamp so short runs (epochs < 3)
      # do not produce step_size 0, which StepLR cannot divide by.
      Torch::Optim::LRScheduler::StepLR.new(
        optimizer,
        step_size: [@config.epochs / 3, 1].max,
        gamma: 0.1
      )
    end
  end

  # Invoke +method+ on every configured callback (no respond_to? guard).
  def run_callbacks(method, *args)
    @config.callbacks.each do |callback|
      callback.send(method, *args)
    end
  end
end
193
+ end
194
+ end
@@ -0,0 +1,28 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Fine
4
+ module Transforms
5
+ # Composes multiple transforms into a single callable
6
# Composes multiple transforms into a single callable.
#
# Each transform must respond to +call+. #call threads its argument through
# them in order, so the output of one transform is the input of the next.
class Compose
  attr_reader :transforms

  # @param transforms [Array<#call>] transforms to apply, in order. The array
  #   is copied, so later #<< / #prepend calls never mutate the caller's array
  #   (and a frozen array is accepted).
  def initialize(transforms = [])
    @transforms = transforms.dup
  end

  # @param image [Object] input handed to the first transform
  # @return [Object] output of the last transform (the input itself when there
  #   are no transforms)
  def call(image)
    @transforms.reduce(image) { |img, transform| transform.call(img) }
  end

  # Append a transform to the end of the pipeline.
  #
  # @return [self] for chaining
  def <<(transform)
    @transforms << transform
    self
  end

  # Insert a transform at the front of the pipeline.
  #
  # @return [self] for chaining
  def prepend(transform)
    @transforms.unshift(transform)
    self
  end
end
27
+ end
28
+ end