fine 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0

data/lib/fine/callbacks/base.rb
@@ -0,0 +1,140 @@
+# frozen_string_literal: true
+
+module Fine
+  module Callbacks
+    # Base class for training callbacks
+    class Base
+      def on_train_begin(trainer); end
+      def on_train_end(trainer); end
+      def on_epoch_begin(trainer, epoch); end
+      def on_epoch_end(trainer, epoch, metrics); end
+      def on_batch_begin(trainer, batch_idx); end
+      def on_batch_end(trainer, batch_idx, loss); end
+    end
+
+    # Callback that wraps lambda functions
+    class LambdaCallback < Base
+      def initialize(on_train_begin: nil, on_train_end: nil,
+                     on_epoch_begin: nil, on_epoch_end: nil,
+                     on_batch_begin: nil, on_batch_end: nil)
+        super()
+        @on_train_begin_fn = on_train_begin
+        @on_train_end_fn = on_train_end
+        @on_epoch_begin_fn = on_epoch_begin
+        @on_epoch_end_fn = on_epoch_end
+        @on_batch_begin_fn = on_batch_begin
+        @on_batch_end_fn = on_batch_end
+      end
+
+      def on_train_begin(trainer)
+        @on_train_begin_fn&.call(trainer)
+      end
+
+      def on_train_end(trainer)
+        @on_train_end_fn&.call(trainer)
+      end
+
+      def on_epoch_begin(trainer, epoch)
+        @on_epoch_begin_fn&.call(epoch)
+      end
+
+      def on_epoch_end(trainer, epoch, metrics)
+        @on_epoch_end_fn&.call(epoch, metrics)
+      end
+
+      def on_batch_begin(trainer, batch_idx)
+        @on_batch_begin_fn&.call(batch_idx)
+      end
+
+      def on_batch_end(trainer, batch_idx, loss)
+        @on_batch_end_fn&.call(batch_idx, loss)
+      end
+    end
+
+    # Early stopping callback
+    class EarlyStopping < Base
+      attr_reader :patience, :monitor, :best_value, :wait
+
+      def initialize(patience: 3, monitor: :val_loss, mode: :min)
+        super()
+        @patience = patience
+        @monitor = monitor
+        @mode = mode
+        @best_value = nil
+        @wait = 0
+      end
+
+      def on_epoch_end(trainer, _epoch, metrics)
+        current = metrics[@monitor]
+        return unless current
+
+        if @best_value.nil? || improved?(current)
+          @best_value = current
+          @wait = 0
+        else
+          @wait += 1
+          if @wait >= @patience
+            puts "Early stopping triggered after #{@patience} epochs without improvement"
+            trainer.stop_training = true
+          end
+        end
+      end
+
+      private
+
+      def improved?(current)
+        if @mode == :min
+          current < @best_value
+        else
+          current > @best_value
+        end
+      end
+    end
+
+    # Model checkpoint callback
+    class ModelCheckpoint < Base
+      def initialize(path:, save_best_only: true, monitor: :val_loss, mode: :min)
+        super()
+        @path = path
+        @save_best_only = save_best_only
+        @monitor = monitor
+        @mode = mode
+        @best_value = nil
+      end
+
+      def on_epoch_end(trainer, epoch, metrics)
+        current = metrics[@monitor]
+
+        if @save_best_only
+          return unless current
+
+          if @best_value.nil? || improved?(current)
+            @best_value = current
+            save_checkpoint(trainer, epoch, metrics)
+          end
+        else
+          save_checkpoint(trainer, epoch, metrics)
+        end
+      end
+
+      private
+
+      def improved?(current)
+        if @mode == :min
+          current < @best_value
+        else
+          current > @best_value
+        end
+      end
+
+      def save_checkpoint(trainer, epoch, metrics)
+        checkpoint_path = @path.include?("{epoch}") ?
+          @path.gsub("{epoch}", epoch.to_s) :
+          @path
+
+        trainer.model.save(checkpoint_path, label_map: trainer.label_map)
+        puts "Saved checkpoint to #{checkpoint_path}"
+      end
+    end
+  end
+end
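
For orientation, here is a minimal sketch of how these callbacks compose. Only the three callback classes above come from the gem; the `checkpoints/epoch-{epoch}` path is illustrative, and the trainer protocol they rely on (`config`, `stop_training=`, `model.save`, `label_map`) is taken from the method bodies above rather than shown here.

    # Hypothetical wiring; each entry would be passed to a trainer's callback list.
    callbacks = [
      Fine::Callbacks::EarlyStopping.new(patience: 2, monitor: :val_loss, mode: :min),
      Fine::Callbacks::ModelCheckpoint.new(path: "checkpoints/epoch-{epoch}", save_best_only: true),
      Fine::Callbacks::LambdaCallback.new(
        # LambdaCallback invokes this as block.call(epoch, metrics)
        on_epoch_end: ->(epoch, metrics) { puts "epoch #{epoch}: #{metrics.inspect}" }
      )
    ]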

data/lib/fine/callbacks/progress_bar.rb
@@ -0,0 +1,66 @@
+# frozen_string_literal: true
+
+module Fine
+  module Callbacks
+    # Progress bar callback using TTY::ProgressBar
+    class ProgressBar < Base
+      def initialize(show_epoch: true, show_batch: true)
+        super()
+        @show_epoch = show_epoch
+        @show_batch = show_batch
+        @epoch_bar = nil
+        @batch_bar = nil
+      end
+
+      def on_train_begin(trainer)
+        return unless @show_epoch
+
+        @epoch_bar = TTY::ProgressBar.new(
+          "Training [:bar] :current/:total epochs",
+          total: trainer.config.epochs,
+          width: 30
+        )
+      end
+
+      def on_epoch_begin(trainer, epoch)
+        return unless @show_batch
+
+        @batch_bar = TTY::ProgressBar.new(
+          "  Epoch #{epoch + 1} [:bar] :current/:total batches :rate/s",
+          total: trainer.train_loader.size,
+          width: 25,
+          hide_cursor: true
+        )
+      end
+
+      def on_batch_end(_trainer, _batch_idx, _loss)
+        @batch_bar&.advance
+      end
+
+      def on_epoch_end(_trainer, epoch, metrics)
+        @batch_bar&.finish
+
+        # Format metrics for display
+        metrics_str = metrics.map { |k, v| "#{k}: #{format_value(v)}" }.join(", ")
+        puts "  #{metrics_str}"
+
+        @epoch_bar&.advance
+      end
+
+      def on_train_end(_trainer)
+        @epoch_bar&.finish
+        puts "Training complete!"
+      end
+
+      private
+
+      def format_value(v)
+        case v
+        when Float then format("%.4f", v)
+        when Torch::Tensor then format("%.4f", v.item)
+        else v.to_s
+        end
+      end
+    end
+  end
+end
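
The format strings use tty-progressbar's token syntax (`:bar`, `:current`, `:total`, `:rate`). Wiring the bar in is just another callback registration; a sketch, assuming the `config.callbacks` array defined in `configuration.rb` below:

    require "tty-progressbar"  # gem backing TTY::ProgressBar

    config = Fine::Configuration.new
    config.callbacks << Fine::Callbacks::ProgressBar.new(show_epoch: true, show_batch: false)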

data/lib/fine/configuration.rb
@@ -0,0 +1,106 @@
+# frozen_string_literal: true
+
+module Fine
+  # Configuration for training runs
+  class Configuration
+    # Training hyperparameters
+    attr_accessor :epochs, :batch_size, :learning_rate, :weight_decay
+    attr_accessor :warmup_steps, :warmup_ratio
+    attr_accessor :optimizer, :scheduler
+
+    # Model configuration
+    attr_accessor :freeze_encoder, :dropout, :num_labels
+
+    # Data configuration
+    attr_accessor :image_size
+
+    # Callbacks
+    attr_accessor :callbacks
+
+    # Augmentation
+    attr_reader :augmentation_config
+
+    def initialize
+      # Training defaults
+      @epochs = 3
+      @batch_size = 32
+      @learning_rate = 2e-4
+      @weight_decay = 0.02
+      @warmup_steps = 0
+      @warmup_ratio = 0.0
+      @optimizer = :adamw
+      @scheduler = :cosine
+
+      # Model defaults
+      @freeze_encoder = false
+      @dropout = 0.1
+      @num_labels = nil # auto-detect from dataset
+
+      # Data defaults
+      @image_size = 224
+
+      # Callbacks
+      @callbacks = []
+
+      # Augmentation
+      @augmentation_config = AugmentationConfig.new
+    end
+
+    def augmentation
+      yield @augmentation_config if block_given?
+      @augmentation_config
+    end
+
+    # Register a callback for epoch end
+    def on_epoch_end(&block)
+      @callbacks << Callbacks::LambdaCallback.new(on_epoch_end: block)
+    end
+
+    # Register a callback for batch end
+    def on_batch_end(&block)
+      @callbacks << Callbacks::LambdaCallback.new(on_batch_end: block)
+    end
+
+    # Register a callback for train begin
+    def on_train_begin(&block)
+      @callbacks << Callbacks::LambdaCallback.new(on_train_begin: block)
+    end
+
+    # Register a callback for train end
+    def on_train_end(&block)
+      @callbacks << Callbacks::LambdaCallback.new(on_train_end: block)
+    end
+  end
+
+  # Configuration for data augmentation
+  class AugmentationConfig
+    attr_accessor :random_horizontal_flip, :random_vertical_flip
+    attr_accessor :random_rotation, :color_jitter
+    attr_accessor :random_resized_crop
+
+    def initialize
+      @random_horizontal_flip = false
+      @random_vertical_flip = false
+      @random_rotation = 0
+      @color_jitter = nil
+      @random_resized_crop = nil
+    end
+
+    def enabled?
+      @random_horizontal_flip ||
+        @random_vertical_flip ||
+        @random_rotation.positive? ||
+        @color_jitter ||
+        @random_resized_crop
+    end
+
+    def to_transforms
+      transforms = []
+      transforms << Transforms::RandomHorizontalFlip.new if @random_horizontal_flip
+      transforms << Transforms::RandomVerticalFlip.new if @random_vertical_flip
+      transforms << Transforms::RandomRotation.new(@random_rotation) if @random_rotation.positive?
+      # Add more transforms as implemented
+      transforms
+    end
+  end
+end
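
A short sketch of the configuration DSL these two classes define; everything below follows directly from the accessors, the block-yielding `augmentation` method, and the `on_*` registration helpers above:

    config = Fine::Configuration.new
    config.epochs = 5
    config.learning_rate = 1e-4

    # `augmentation` yields the AugmentationConfig when given a block
    config.augmentation do |aug|
      aug.random_horizontal_flip = true
      aug.random_rotation = 15
    end

    # Registers a LambdaCallback under the hood; block receives (epoch, metrics)
    config.on_epoch_end { |epoch, metrics| puts "#{epoch}: #{metrics[:val_loss]}" }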

data/lib/fine/datasets/data_loader.rb
@@ -0,0 +1,63 @@
+# frozen_string_literal: true
+
+module Fine
+  module Datasets
+    # DataLoader for batching dataset samples
+    class DataLoader
+      include Enumerable
+
+      attr_reader :dataset, :batch_size, :shuffle, :drop_last
+
+      # @param dataset [ImageDataset] The dataset to load from
+      # @param batch_size [Integer] Number of samples per batch
+      # @param shuffle [Boolean] Whether to shuffle indices each epoch
+      # @param drop_last [Boolean] Whether to drop the last incomplete batch
+      def initialize(dataset, batch_size:, shuffle: false, drop_last: false)
+        @dataset = dataset
+        @batch_size = batch_size
+        @shuffle = shuffle
+        @drop_last = drop_last
+        @indices = nil
+      end
+
+      # Iterate over batches
+      def each_batch
+        return enum_for(:each_batch) unless block_given?
+
+        indices = (0...@dataset.size).to_a
+        indices.shuffle! if @shuffle
+
+        indices.each_slice(@batch_size) do |batch_indices|
+          next if @drop_last && batch_indices.size < @batch_size
+
+          yield collate(batch_indices)
+        end
+      end
+
+      alias each each_batch
+
+      # Number of batches
+      def size
+        n = @dataset.size / @batch_size
+        n += 1 unless @drop_last || (@dataset.size % @batch_size).zero?
+        n
+      end
+
+      alias num_batches size
+
+      private
+
+      def collate(indices)
+        samples = indices.map { |i| @dataset[i] }
+
+        # Stack pixel_values into a single tensor
+        pixel_values = Torch.stack(samples.map { |s| s[:pixel_values] })
+
+        # Stack labels into a single tensor
+        labels = Torch.tensor(samples.map { |s| s[:label] }, dtype: :long)
+
+        { pixel_values: pixel_values, labels: labels }
+      end
+    end
+  end
+end
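
A sketch of consuming the loader; `dataset` is assumed to be an `ImageDataset` (defined below) whose samples expose `:pixel_values` and `:label`, matching what `collate` expects:

    loader = Fine::Datasets::DataLoader.new(dataset, batch_size: 32, shuffle: true, drop_last: true)

    loader.each_batch do |batch|
      batch[:pixel_values] # Torch::Tensor, the stacked image tensors for this batch
      batch[:labels]       # Torch::Tensor of dtype long
    end

    loader.size # number of batches per epoch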

data/lib/fine/datasets/image_dataset.rb
@@ -0,0 +1,203 @@
+# frozen_string_literal: true
+
+module Fine
+  module Datasets
+    # Dataset for loading images from a directory structure
+    #
+    # Expected structure:
+    #   data/
+    #     class1/
+    #       image1.jpg
+    #       image2.jpg
+    #     class2/
+    #       image3.jpg
+    #
+    class ImageDataset
+      include Enumerable
+
+      attr_reader :images, :labels, :label_map, :inverse_label_map, :transforms
+
+      IMAGE_EXTENSIONS = %w[.jpg .jpeg .png .webp .bmp .gif].freeze
+
+      # Create dataset from a directory with class subdirectories
+      #
+      # @param path [String] Path to the root directory
+      # @param transforms [Transforms::Compose, nil] Optional transforms to apply
+      # @return [ImageDataset]
+      def self.from_directory(path, transforms: nil)
+        raise DatasetError, "Directory not found: #{path}" unless File.directory?(path)
+
+        images = []
+        labels = []
+
+        # Get sorted list of class directories
+        label_names = Dir.children(path)
+                         .select { |f| File.directory?(File.join(path, f)) }
+                         .reject { |f| f.start_with?(".") }
+                         .sort
+
+        raise DatasetError, "No class directories found in #{path}" if label_names.empty?
+
+        # Build label map
+        label_map = label_names.each_with_index.to_h
+
+        # Collect images from each class directory
+        label_names.each do |label_name|
+          class_dir = File.join(path, label_name)
+          label_id = label_map[label_name]
+
+          Dir.glob(File.join(class_dir, "*")).each do |image_path|
+            next unless image_file?(image_path)
+
+            images << image_path
+            labels << label_id
+          end
+        end
+
+        raise DatasetError, "No images found in #{path}" if images.empty?
+
+        new(images: images, labels: labels, label_map: label_map, transforms: transforms)
+      end
+
+      # Create dataset from explicit arrays
+      #
+      # @param images [Array<String>] Array of image paths
+      # @param labels [Array<Integer, String>] Array of labels
+      # @param label_map [Hash, nil] Optional mapping of label names to IDs
+      # @param transforms [Transforms::Compose, nil] Optional transforms
+      def initialize(images:, labels:, label_map: nil, transforms: nil)
+        raise ArgumentError, "images and labels must have same length" if images.size != labels.size
+
+        @images = images
+        @transforms = transforms || default_transforms
+
+        # Build label map if not provided
+        if label_map
+          @label_map = label_map
+        else
+          unique_labels = labels.uniq.sort
+          @label_map = unique_labels.each_with_index.to_h
+        end
+
+        # Convert string labels to integers if needed
+        @labels = labels.map do |label|
+          label.is_a?(Integer) ? label : @label_map[label]
+        end
+
+        # Build inverse mapping
+        @inverse_label_map = @label_map.invert
+      end
+
+      # Get a single item from the dataset
+      #
+      # @param index [Integer] Index of the item
+      # @return [Hash] Hash with :pixel_values and :label keys
+      def [](index)
+        image = load_image(@images[index])
+        image = @transforms.call(image)
+
+        { pixel_values: image, label: @labels[index] }
+      end
+
+      # Number of items in the dataset
+      def size
+        @images.size
+      end
+      alias length size
+
+      # Iterate over all items
+      def each
+        return enum_for(:each) unless block_given?
+
+        size.times { |i| yield self[i] }
+      end
+
+      # Number of classes
+      def num_classes
+        @label_map.size
+      end
+
+      # Get class names in order
+      def class_names
+        @inverse_label_map.sort.map(&:last)
+      end
+
+      # Split dataset into train and validation sets
+      #
+      # @param test_size [Float] Fraction of data to use for validation (0.0-1.0)
+      # @param shuffle [Boolean] Whether to shuffle before splitting
+      # @param stratify [Boolean] Whether to maintain class distribution
+      # @param seed [Integer, nil] Random seed for reproducibility
+      # @return [Array<ImageDataset, ImageDataset>] Train and validation datasets
+      def split(test_size: 0.2, shuffle: true, stratify: true, seed: nil)
+        rng = seed ? Random.new(seed) : Random.new
+
+        indices = (0...size).to_a
+        indices = indices.shuffle(random: rng) if shuffle && !stratify
+
+        if stratify
+          train_indices, val_indices = stratified_split(indices, test_size, rng)
+        else
+          split_idx = (size * (1 - test_size)).round
+          train_indices = indices[0...split_idx]
+          val_indices = indices[split_idx..]
+        end
+
+        train_set = subset(train_indices)
+        val_set = subset(val_indices)
+
+        [train_set, val_set]
+      end
+
+      private
+
+      def self.image_file?(path)
+        return false unless File.file?(path)
+
+        ext = File.extname(path).downcase
+        IMAGE_EXTENSIONS.include?(ext)
+      end
+
+      def load_image(path)
+        Vips::Image.new_from_file(path, access: :sequential)
+      rescue Vips::Error => e
+        raise ImageProcessingError.new(path, "Failed to load image: #{e.message}")
+      end
+
+      def default_transforms
+        Transforms::Compose.new([
+          Transforms::Resize.new(224),
+          Transforms::ToTensor.new,
+          Transforms::Normalize.new
+        ])
+      end
+
+      def subset(indices)
+        ImageDataset.new(
+          images: indices.map { |i| @images[i] },
+          labels: indices.map { |i| @labels[i] },
+          label_map: @label_map,
+          transforms: @transforms
+        )
+      end
+
+      def stratified_split(indices, test_size, rng)
+        train_indices = []
+        val_indices = []
+
+        # Group indices by label
+        by_label = indices.group_by { |i| @labels[i] }
+
+        by_label.each_value do |label_indices|
+          shuffled = label_indices.shuffle(random: rng)
+          split_idx = (shuffled.size * (1 - test_size)).round
+
+          train_indices.concat(shuffled[0...split_idx])
+          val_indices.concat(shuffled[split_idx..])
+        end
+
+        [train_indices.shuffle(random: rng), val_indices.shuffle(random: rng)]
+      end
+    end
+  end
+end
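
A sketch of the directory-based workflow these methods support, assuming an image folder laid out as in the class comment (`data/` here is an illustrative path):

    dataset = Fine::Datasets::ImageDataset.from_directory("data/")
    puts dataset.class_names.inspect # e.g. ["class1", "class2"]

    # Stratified 80/20 split with a fixed seed for reproducibility
    train_set, val_set = dataset.split(test_size: 0.2, stratify: true, seed: 42)

    # Feed the training half into the DataLoader shown earlier
    train_loader = Fine::Datasets::DataLoader.new(train_set, batch_size: 32, shuffle: true)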