fine 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. checksums.yaml +7 -0
  2. data/.rspec +3 -0
  3. data/CHANGELOG.md +38 -0
  4. data/Gemfile +6 -0
  5. data/Gemfile.lock +167 -0
  6. data/LICENSE +21 -0
  7. data/README.md +212 -0
  8. data/Rakefile +6 -0
  9. data/docs/installation.md +151 -0
  10. data/docs/tutorials/llm-fine-tuning.md +246 -0
  11. data/docs/tutorials/model-export.md +200 -0
  12. data/docs/tutorials/siglip2-image-classification.md +130 -0
  13. data/docs/tutorials/siglip2-object-recognition.md +203 -0
  14. data/docs/tutorials/siglip2-similarity-search.md +152 -0
  15. data/docs/tutorials/text-classification.md +233 -0
  16. data/docs/tutorials/text-embeddings.md +211 -0
  17. data/examples/basic_classification.rb +70 -0
  18. data/examples/data/tool_calls.jsonl +30 -0
  19. data/examples/demo_training.rb +78 -0
  20. data/examples/finetune_gemma3_tools.rb +135 -0
  21. data/examples/real_llm_test.rb +128 -0
  22. data/examples/real_text_classification_test.rb +90 -0
  23. data/examples/real_text_embedder_test.rb +110 -0
  24. data/examples/real_training_test.rb +88 -0
  25. data/examples/test_export.rb +28 -0
  26. data/examples/test_image_classifier.rb +79 -0
  27. data/examples/test_llm.rb +100 -0
  28. data/examples/test_text_classifier.rb +59 -0
  29. data/lib/fine/callbacks/base.rb +140 -0
  30. data/lib/fine/callbacks/progress_bar.rb +66 -0
  31. data/lib/fine/configuration.rb +106 -0
  32. data/lib/fine/datasets/data_loader.rb +63 -0
  33. data/lib/fine/datasets/image_dataset.rb +203 -0
  34. data/lib/fine/datasets/instruction_dataset.rb +226 -0
  35. data/lib/fine/datasets/text_data_loader.rb +88 -0
  36. data/lib/fine/datasets/text_dataset.rb +266 -0
  37. data/lib/fine/error.rb +49 -0
  38. data/lib/fine/export/gguf_exporter.rb +424 -0
  39. data/lib/fine/export/onnx_exporter.rb +249 -0
  40. data/lib/fine/export.rb +53 -0
  41. data/lib/fine/hub/config_loader.rb +145 -0
  42. data/lib/fine/hub/model_downloader.rb +136 -0
  43. data/lib/fine/hub/safetensors_loader.rb +108 -0
  44. data/lib/fine/image_classifier.rb +256 -0
  45. data/lib/fine/llm.rb +336 -0
  46. data/lib/fine/models/base.rb +48 -0
  47. data/lib/fine/models/bert_encoder.rb +202 -0
  48. data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
  49. data/lib/fine/models/causal_lm.rb +279 -0
  50. data/lib/fine/models/classification_head.rb +24 -0
  51. data/lib/fine/models/gemma3_decoder.rb +244 -0
  52. data/lib/fine/models/llama_decoder.rb +297 -0
  53. data/lib/fine/models/sentence_transformer.rb +202 -0
  54. data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
  55. data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
  56. data/lib/fine/text_classifier.rb +250 -0
  57. data/lib/fine/text_embedder.rb +221 -0
  58. data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
  59. data/lib/fine/training/llm_trainer.rb +212 -0
  60. data/lib/fine/training/text_trainer.rb +275 -0
  61. data/lib/fine/training/trainer.rb +194 -0
  62. data/lib/fine/transforms/compose.rb +28 -0
  63. data/lib/fine/transforms/normalize.rb +33 -0
  64. data/lib/fine/transforms/resize.rb +35 -0
  65. data/lib/fine/transforms/to_tensor.rb +53 -0
  66. data/lib/fine/version.rb +3 -0
  67. data/lib/fine.rb +112 -0
  68. data/mise.toml +2 -0
  69. metadata +240 -0
data/lib/fine/export.rb
@@ -0,0 +1,53 @@
+ # frozen_string_literal: true
+
+ require_relative "export/onnx_exporter"
+ require_relative "export/gguf_exporter"
+
+ module Fine
+   # Export models to various deployment formats
+   #
+   # @example Export to ONNX
+   #   Fine::Export.to_onnx(classifier, "model.onnx")
+   #
+   # @example Export LLM to GGUF
+   #   Fine::Export.to_gguf(llm, "model.gguf", quantization: :q4_0)
+   module Export
+     class << self
+       # Export any Fine model to ONNX format
+       #
+       # @param model [TextClassifier, TextEmbedder, ImageClassifier, LLM] The model
+       # @param path [String] Output path
+       # @param options [Hash] Export options
+       # @return [String] The output path
+       def to_onnx(model, path, **options)
+         ONNXExporter.export(model, path, **options)
+       end
+
+       # Export LLM to GGUF format
+       #
+       # @param model [LLM] The LLM model
+       # @param path [String] Output path
+       # @param quantization [Symbol] Quantization type (:f16, :q4_0, :q8_0, etc.)
+       # @param options [Hash] Additional options (e.g., :metadata)
+       # @return [String] The output path
+       def to_gguf(model, path, quantization: :f16, **options)
+         GGUFExporter.export(model, path, quantization: quantization, **options)
+       end
+
+       # List available quantization options for GGUF
+       #
+       # @return [Hash] Quantization types with descriptions
+       def gguf_quantization_options
+         {
+           f32: "32-bit float (largest, no quality loss)",
+           f16: "16-bit float (good balance)",
+           q8_0: "8-bit quantization (smaller, minimal quality loss)",
+           q4_0: "4-bit quantization (smallest, some quality loss)",
+           q4_k: "4-bit K-quant (better quality than q4_0)",
+           q5_k: "5-bit K-quant (good quality/size balance)",
+           q6_k: "6-bit K-quant (high quality)"
+         }
+       end
+     end
+   end
+ end
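
`Fine::Export` is a thin façade over the two exporters. A minimal usage sketch, assuming `llm` is an already fine-tuned `Fine::LLM` instance (the object itself is hypothetical, not defined here):

    require "fine"

    # Inspect the supported GGUF quantization types before choosing one.
    Fine::Export.gguf_quantization_options.each do |type, description|
      puts "#{type}: #{description}"
    end

    # Export with 4-bit quantization; both helpers return the output path.
    Fine::Export.to_gguf(llm, "model.gguf", quantization: :q4_0)
    Fine::Export.to_onnx(llm, "model.onnx")
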
data/lib/fine/hub/config_loader.rb
@@ -0,0 +1,145 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Hub
+     # Loads and parses model configuration from Hugging Face format
+     class ConfigLoader
+       attr_reader :config
+
+       def initialize(config_path)
+         @config_path = config_path
+         @config = load_config
+       end
+
+       def self.from_pretrained(model_path)
+         config_path = File.join(model_path, "config.json")
+         raise ConfigurationError, "Config not found: #{config_path}" unless File.exist?(config_path)
+
+         new(config_path)
+       end
+
+       # Vision encoder configuration
+       def hidden_size
+         vision_config["hidden_size"] || config["hidden_size"] || 768
+       end
+
+       def num_hidden_layers
+         vision_config["num_hidden_layers"] || config["num_hidden_layers"] || 12
+       end
+
+       def num_attention_heads
+         vision_config["num_attention_heads"] || config["num_attention_heads"] || 12
+       end
+
+       def intermediate_size
+         vision_config["intermediate_size"] || config["intermediate_size"] || 3072
+       end
+
+       def image_size
+         vision_config["image_size"] || config["image_size"] || 224
+       end
+
+       def patch_size
+         vision_config["patch_size"] || config["patch_size"] || 16
+       end
+
+       def num_channels
+         vision_config["num_channels"] || config["num_channels"] || 3
+       end
+
+       def layer_norm_eps
+         vision_config["layer_norm_eps"] || config["layer_norm_eps"] || 1e-6
+       end
+
+       def hidden_act
+         vision_config["hidden_act"] || config["hidden_act"] || "gelu"
+       end
+
+       def attention_dropout
+         vision_config["attention_dropout"] || config["attention_dropout"] || 0.0
+       end
+
+       # Text model configuration (BERT, DistilBERT, DeBERTa, etc.)
+       def vocab_size
+         config["vocab_size"] || 30522
+       end
+
+       def max_position_embeddings
+         config["max_position_embeddings"] || 512
+       end
+
+       def type_vocab_size
+         config["type_vocab_size"] || 2
+       end
+
+       def hidden_dropout_prob
+         config["hidden_dropout_prob"] || config["dropout"] || 0.1
+       end
+
+       # LLM configuration (Llama, Gemma, etc.)
+       def rope_theta
+         config["rope_theta"] || 10000.0
+       end
+
+       def rms_norm_eps
+         config["rms_norm_eps"] || 1e-6
+       end
+
+       def num_key_value_heads
+         config["num_key_value_heads"] || num_attention_heads
+       end
+
+       def use_bias
+         config["use_bias"] != false
+       end
+
+       def head_dim
+         config["head_dim"] || (hidden_size / num_attention_heads)
+       end
+
+       def attention_bias
+         config["attention_bias"] || false
+       end
+
+       def use_qk_norm
+         # Gemma 3 uses QK normalization
+         config.key?("query_pre_attn_scalar") || model_type&.include?("gemma3")
+       end
+
+       def use_pre_feedforward_layernorm
+         # Gemma 3 has additional layer norms
+         model_type&.include?("gemma3")
+       end
+
+       # Computed properties
+       def num_patches
+         (image_size / patch_size) ** 2
+       end
+
+       def model_type
+         config["model_type"]
+       end
+
+       # Raw access to config sections
+       def vision_config
+         config["vision_config"] || {}
+       end
+
+       def text_config
+         config["text_config"] || {}
+       end
+
+       def to_h
+         @config
+       end
+
+       private
+
+       def load_config
+         JSON.parse(File.read(@config_path))
+       rescue JSON::ParserError => e
+         raise ConfigurationError, "Invalid JSON in config: #{e.message}"
+       end
+     end
+   end
+ end
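
Every accessor above falls back to a sensible default when the key is missing, so partially specified configs still load. A sketch of reading a config, assuming "path/to/model" stands in for a directory containing a Hugging Face config.json (the path is illustrative):

    require "fine"

    loader = Fine::Hub::ConfigLoader.from_pretrained("path/to/model")

    loader.model_type    # value of "model_type" in config.json, or nil
    loader.hidden_size   # vision_config value, top-level value, or 768
    loader.num_patches   # computed as (image_size / patch_size) ** 2
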
data/lib/fine/hub/model_downloader.rb
@@ -0,0 +1,136 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Hub
+     # Downloads models from Hugging Face Hub
+     class ModelDownloader
+       HF_HUB_URL = "https://huggingface.co"
+       REQUIRED_FILES = %w[config.json].freeze
+       MODEL_FILES = %w[model.safetensors].freeze
+       OPTIONAL_FILES = %w[preprocessor_config.json tokenizer_config.json tokenizer.json].freeze
+
+       attr_reader :model_id, :cache_path
+
+       def initialize(model_id, cache_dir: nil)
+         @model_id = model_id
+         @cache_dir = cache_dir || Fine.cache_dir
+         @cache_path = File.join(@cache_dir, "models", model_id.tr("/", "--"))
+       end
+
+       # Download model files and return the cache path
+       def download(force: false)
+         return @cache_path if cached? && !force
+
+         FileUtils.mkdir_p(@cache_path)
+
+         download_required_files
+         download_optional_files
+
+         @cache_path
+       end
+
+       # Check if model is already cached
+       def cached?
+         return false unless File.directory?(@cache_path)
+
+         # Check required files
+         return false unless REQUIRED_FILES.all? { |f| File.exist?(File.join(@cache_path, f)) }
+
+         # Check for model weights (single file or sharded)
+         has_single_weights = File.exist?(File.join(@cache_path, "model.safetensors"))
+         has_sharded_weights = Dir.glob(File.join(@cache_path, "model-*.safetensors")).any?
+
+         has_single_weights || has_sharded_weights
+       end
+
+       # Get path to a specific file
+       def file_path(filename)
+         File.join(@cache_path, filename)
+       end
+
+       private
+
+       def download_required_files
+         REQUIRED_FILES.each do |filename|
+           download_file(filename)
+         end
+
+         # Download model weights (try single file first, then sharded)
+         download_model_weights
+       end
+
+       def download_model_weights
+         # Try single model file first
+         begin
+           download_file("model.safetensors")
+           return
+         rescue ModelNotFoundError
+           # Single file not found, try sharded
+         end
+
+         # Download index file to find shards
+         download_file("model.safetensors.index.json", required: false)
+         index_path = File.join(@cache_path, "model.safetensors.index.json")
+
+         if File.exist?(index_path)
+           index = JSON.parse(File.read(index_path))
+           weight_files = index["weight_map"].values.uniq
+
+           weight_files.each do |filename|
+             download_file(filename)
+           end
+         else
+           raise ModelNotFoundError.new(@model_id, "No model weights found")
+         end
+       end
+
+       def download_optional_files
+         OPTIONAL_FILES.each do |filename|
+           download_file(filename, required: false)
+         end
+       end
+
+       def download_file(filename, required: true)
+         local_path = File.join(@cache_path, filename)
+         return if File.exist?(local_path)
+
+         url = file_url(filename)
+
+         begin
+           puts "Downloading #{filename}..." if Fine.configuration&.progress_bar != false
+
+           headers = { "User-Agent" => "fine-ruby/#{Fine::VERSION}" }
+
+           # Add HuggingFace token if available
+           if (token = hf_token)
+             headers["Authorization"] = "Bearer #{token}"
+           end
+
+           tempfile = Down.download(url, headers: headers)
+
+           FileUtils.mv(tempfile.path, local_path)
+         rescue Down::NotFound
+           raise ModelNotFoundError.new(@model_id, "File not found: #{filename}") if required
+         rescue Down::Error => e
+           raise ModelNotFoundError.new(@model_id, "Download failed: #{e.message}") if required
+         end
+       end
+
+       def hf_token
+         # Check environment variable first
+         return ENV["HF_TOKEN"] if ENV["HF_TOKEN"]
+         return ENV["HUGGING_FACE_HUB_TOKEN"] if ENV["HUGGING_FACE_HUB_TOKEN"]
+
+         # Check standard HuggingFace cache location
+         token_path = File.expand_path("~/.cache/huggingface/token")
+         return File.read(token_path).strip if File.exist?(token_path)
+
+         nil
+       end
+
+       def file_url(filename)
+         "#{HF_HUB_URL}/#{@model_id}/resolve/main/#{filename}"
+       end
+     end
+   end
+ end
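
The downloader resolves each file against https://huggingface.co/<model_id>/resolve/main/<filename> and skips anything already on disk. A hedged usage sketch (the model ID comes from this gem's own examples; network access and, for gated models, an HF_TOKEN are assumed):

    require "fine"

    downloader = Fine::Hub::ModelDownloader.new("google/siglip2-base-patch16-224")

    path = downloader.download            # returns the cache path; no-op when cached?
    downloader.file_path("config.json")   # absolute path inside the cache directory
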
data/lib/fine/hub/safetensors_loader.rb
@@ -0,0 +1,108 @@
+ # frozen_string_literal: true
+
+ module Fine
+   module Hub
+     # Loads SafeTensors weights into torch.rb models
+     class SafetensorsLoader
+       # Load weights from safetensors file into a model
+       #
+       # @param model [Torch::NN::Module] The model to load weights into
+       # @param path [String] Path to the safetensors file
+       # @param strict [Boolean] If true, raise error on missing/unexpected keys
+       # @param prefix [String] Prefix to strip from weight names
+       # @return [Hash] Hash with :missing_keys and :unexpected_keys arrays
+       def self.load_into_model(model, path, strict: false, prefix: nil)
+         tensors = Safetensors::Torch.load_file(path)
+
+         # Get model's state dict keys
+         model_keys = model.state_dict.keys
+
+         # Map and filter tensors
+         mapped_tensors = {}
+         unexpected_keys = []
+
+         tensors.each do |name, tensor|
+           mapped_name = map_weight_name(name, prefix: prefix)
+
+           if model_keys.include?(mapped_name)
+             mapped_tensors[mapped_name] = tensor
+           else
+             unexpected_keys << name
+           end
+         end
+
+         # Find missing keys
+         missing_keys = model_keys - mapped_tensors.keys
+
+         # Raise error if strict mode and there are issues
+         if strict && (missing_keys.any? || unexpected_keys.any?)
+           raise WeightLoadingError.new(
+             "Weight loading failed",
+             missing_keys: missing_keys,
+             unexpected_keys: unexpected_keys
+           )
+         end
+
+         # Load the mapped tensors by manually copying data
+         # (torch.rb doesn't support strict: false yet)
+         # Use no_grad to avoid in-place operation errors on leaf tensors
+         Torch.no_grad do
+           state_dict = model.state_dict
+           mapped_tensors.each do |name, tensor|
+             if state_dict.key?(name)
+               state_dict[name].copy!(tensor)
+             end
+           end
+         end
+
+         { missing_keys: missing_keys, unexpected_keys: unexpected_keys }
+       end
+
+       # Map HuggingFace weight names to torch.rb model structure
+       #
+       # @param hf_name [String] Original HuggingFace weight name
+       # @param prefix [String] Optional prefix to strip
+       # @return [String] Mapped weight name for torch.rb
+       def self.map_weight_name(hf_name, prefix: nil)
+         name = hf_name.dup
+
+         # Strip prefix if provided
+         name = name.sub(/^#{Regexp.escape(prefix)}\.?/, "") if prefix
+
+         # SigLIP2 specific mappings
+         # Note: prefix (e.g., "vision_model") is already stripped above
+         name = name.sub("embeddings.patch_embedding", "patch_embed.proj")
+         name = name.sub("embeddings.position_embedding.weight", "pos_embed")
+         name = name.sub("encoder.layers", "blocks")
+         name = name.sub("post_layernorm", "norm")
+         # "head" already matches our model, so no mapping is needed
+
+         # Transformer block mappings
+         name = name.gsub(".self_attn.", ".attn.")
+         name = name.gsub(".layer_norm1.", ".norm1.")
+         name = name.gsub(".layer_norm2.", ".norm2.")
+         # mlp.fc1, mlp.fc2, q_proj, k_proj, v_proj, out_proj names match our model
+
+         name
+       end
+
+       # Load raw tensors from safetensors file
+       #
+       # @param path [String] Path to safetensors file
+       # @return [Hash<String, Torch::Tensor>] Hash of tensor name to tensor
+       def self.load_file(path)
+         Safetensors::Torch.load_file(path)
+       end
+
+       # List tensor names in a safetensors file
+       #
+       # @param path [String] Path to safetensors file
+       # @return [Array<String>] List of tensor names
+       def self.tensor_names(path)
+         Safetensors.safe_open(path, framework: "torch") do |f|
+           f.keys
+         end
+       end
+     end
+   end
+ end
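
A sketch of the loading flow, assuming `model` is a torch.rb module whose state dict uses the mapped names; both `model` and the weights path are illustrative, and the "vision_model" prefix is the example the source comments give:

    require "fine"

    path = "path/to/model/model.safetensors"

    # Peek at the raw HuggingFace tensor names before loading.
    puts Fine::Hub::SafetensorsLoader.tensor_names(path).first(5)

    # Non-strict load reports what was missing/unexpected instead of raising.
    result = Fine::Hub::SafetensorsLoader.load_into_model(model, path, prefix: "vision_model")
    warn "missing: #{result[:missing_keys].inspect}" if result[:missing_keys].any?
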
data/lib/fine/image_classifier.rb
@@ -0,0 +1,256 @@
+ # frozen_string_literal: true
+
+ module Fine
+   # High-level API for image classification fine-tuning
+   #
+   # @example Simple usage
+   #   classifier = Fine::ImageClassifier.new("google/siglip2-base-patch16-224")
+   #   classifier.fit(train_dir: "data/train", epochs: 3)
+   #   classifier.save("my_model")
+   #
+   # @example With configuration
+   #   classifier = Fine::ImageClassifier.new("google/siglip2-base-patch16-224") do |config|
+   #     config.learning_rate = 2e-4
+   #     config.batch_size = 32
+   #     config.on_epoch_end { |epoch, metrics| puts "Epoch #{epoch}: #{metrics}" }
+   #   end
+   #
+   class ImageClassifier
+     attr_reader :model, :config, :label_map, :model_id
+
+     # Create a new ImageClassifier
+     #
+     # @param model_id [String] Hugging Face model ID (e.g., "google/siglip2-base-patch16-224")
+     # @yield [config] Optional configuration block
+     # @yieldparam config [Configuration] Configuration object
+     def initialize(model_id, &block)
+       @model_id = model_id
+       @config = Configuration.new
+       @model = nil
+       @label_map = nil
+       @trained = false
+
+       # Apply configuration block
+       block&.call(@config)
+
+       # Add default progress bar if none configured
+       if @config.callbacks.empty? && Fine.configuration&.progress_bar != false
+         @config.callbacks << Callbacks::ProgressBar.new
+       end
+     end
+
+     # Load a fine-tuned classifier from disk
+     #
+     # @param path [String] Path to saved model directory
+     # @return [ImageClassifier]
+     def self.load(path)
+       config_path = File.join(path, "config.json")
+       raise ModelNotFoundError.new(path, "Model not found") unless File.exist?(config_path)
+
+       config_data = JSON.parse(File.read(config_path))
+
+       classifier = allocate
+       classifier.instance_variable_set(:@model_id, config_data["_model_id"] || "custom")
+       classifier.instance_variable_set(:@config, Configuration.new)
+       classifier.instance_variable_set(:@trained, true)
+
+       # Load label map
+       if config_data["label2id"]
+         classifier.instance_variable_set(:@label_map, config_data["label2id"])
+       elsif config_data["id2label"]
+         classifier.instance_variable_set(
+           :@label_map,
+           config_data["id2label"].transform_keys(&:to_i).invert
+         )
+       end
+
+       # Load model
+       classifier.instance_variable_set(
+         :@model,
+         Models::SigLIP2ForImageClassification.load(path)
+       )
+
+       classifier
+     end
+
+     # Fine-tune the model on a dataset
+     #
+     # @param train_dir [String] Path to training data directory
+     # @param val_dir [String, nil] Path to validation data directory
+     # @param epochs [Integer, nil] Number of epochs (overrides config)
+     # @return [Array<Hash>] Training history
+     def fit(train_dir:, val_dir: nil, epochs: nil)
+       # Override epochs if provided
+       @config.epochs = epochs if epochs
+
+       # Load datasets
+       transforms = build_transforms
+       train_dataset = Datasets::ImageDataset.from_directory(train_dir, transforms: transforms)
+
+       val_dataset = if val_dir
+         Datasets::ImageDataset.from_directory(val_dir, transforms: transforms)
+       end
+
+       # Store label map
+       @label_map = train_dataset.label_map
+       num_classes = train_dataset.num_classes
+
+       # Load or create model
+       @model = Models::SigLIP2ForImageClassification.from_pretrained(
+         @model_id,
+         num_labels: num_classes,
+         freeze_encoder: @config.freeze_encoder,
+         dropout: @config.dropout
+       )
+
+       # Create trainer and train
+       trainer = Training::Trainer.new(
+         @model,
+         @config,
+         train_dataset: train_dataset,
+         val_dataset: val_dataset
+       )
+
+       history = trainer.fit
+       @trained = true
+
+       history
+     end
+
+     # Fine-tune with explicit train/val split
+     #
+     # @param data_dir [String] Path to data directory
+     # @param val_split [Float] Fraction of data to use for validation
+     # @param epochs [Integer, nil] Number of epochs
+     # @return [Array<Hash>] Training history
+     def fit_with_split(data_dir:, val_split: 0.2, epochs: nil)
+       @config.epochs = epochs if epochs
+
+       transforms = build_transforms
+       full_dataset = Datasets::ImageDataset.from_directory(data_dir, transforms: transforms)
+       train_dataset, val_dataset = full_dataset.split(test_size: val_split)
+
+       @label_map = train_dataset.label_map
+       num_classes = train_dataset.num_classes
+
+       @model = Models::SigLIP2ForImageClassification.from_pretrained(
+         @model_id,
+         num_labels: num_classes,
+         freeze_encoder: @config.freeze_encoder,
+         dropout: @config.dropout
+       )
+
+       trainer = Training::Trainer.new(
+         @model,
+         @config,
+         train_dataset: train_dataset,
+         val_dataset: val_dataset
+       )
+
+       history = trainer.fit
+       @trained = true
+
+       history
+     end
+
+     # Make predictions on images
+     #
+     # @param images [String, Array<String>] Path(s) to image file(s)
+     # @param top_k [Integer] Number of top predictions to return
+     # @return [Array<Hash>] Predictions with :label and :score
+     def predict(images, top_k: 5)
+       raise TrainingError, "Model not trained or loaded" unless @trained && @model
+
+       images = [images] if images.is_a?(String)
+       transforms = build_inference_transforms
+
+       # Load and transform images
+       tensors = images.map do |path|
+         image = Vips::Image.new_from_file(path, access: :sequential)
+         transforms.call(image)
+       end
+
+       # Stack into batch
+       pixel_values = Torch.stack(tensors)
+
+       # Get predictions
+       @model.eval
+       probs = @model.predict_proba(pixel_values)
+
+       # Convert to result format
+       inverse_label_map = @label_map.invert
+
+       probs.to_a.map do |sample_probs|
+         # Get top-k predictions
+         sorted = sample_probs.each_with_index.sort_by { |prob, _| -prob }
+         top = sorted.first(top_k)
+
+         top.map do |prob, idx|
+           {
+             label: inverse_label_map[idx] || idx.to_s,
+             score: prob.round(4)
+           }
+         end
+       end
+     end
+
+     # Save the fine-tuned model
+     #
+     # @param path [String] Directory path to save to
+     def save(path)
+       raise TrainingError, "Model not trained" unless @trained && @model
+
+       @model.save(path, label_map: @label_map)
+
+       # Also save the original model ID for reference
+       config_path = File.join(path, "config.json")
+       config = JSON.parse(File.read(config_path))
+       config["_model_id"] = @model_id
+       File.write(config_path, JSON.pretty_generate(config))
+     end
+
+     # Get class names in order of their IDs
+     #
+     # @return [Array<String>] Class names
+     def class_names
+       return [] unless @label_map
+
+       @label_map.sort_by { |_, v| v }.map(&:first)
+     end
+
+     # Export to ONNX format
+     #
+     # @param path [String] Output path for ONNX file
+     # @param options [Hash] Export options
+     # @return [String] The output path
+     def export_onnx(path, **options)
+       Export.to_onnx(self, path, **options)
+     end
+
+     private
+
+     def build_transforms
+       transforms = []
+
+       # Add augmentation transforms if configured
+       if @config.augmentation_config.enabled?
+         transforms.concat(@config.augmentation_config.to_transforms)
+       end
+
+       # Core transforms
+       transforms << Transforms::Resize.new(@config.image_size)
+       transforms << Transforms::ToTensor.new
+       transforms << Transforms::Normalize.new
+
+       Transforms::Compose.new(transforms)
+     end
+
+     def build_inference_transforms
+       Transforms::Compose.new([
+         Transforms::Resize.new(@config.image_size),
+         Transforms::ToTensor.new,
+         Transforms::Normalize.new
+       ])
+     end
+   end
+ end
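
End to end, ImageClassifier ties the hub loaders, datasets, and trainer together. A sketch of the full loop, extending the class's own @example blocks; it assumes data/train and data/val are laid out as one subdirectory per class (an assumption based on ImageDataset.from_directory and its label_map), and the image path is illustrative:

    require "fine"

    classifier = Fine::ImageClassifier.new("google/siglip2-base-patch16-224") do |config|
      config.learning_rate = 2e-4
      config.batch_size = 32
    end

    classifier.fit(train_dir: "data/train", val_dir: "data/val", epochs: 3)
    classifier.save("my_model")

    # Reload later and predict; returns one array of
    # { label:, score: } hashes per input image.
    reloaded = Fine::ImageClassifier.load("my_model")
    reloaded.predict("photo.jpg", top_k: 3)
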