fine 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.rspec +3 -0
- data/CHANGELOG.md +38 -0
- data/Gemfile +6 -0
- data/Gemfile.lock +167 -0
- data/LICENSE +21 -0
- data/README.md +212 -0
- data/Rakefile +6 -0
- data/docs/installation.md +151 -0
- data/docs/tutorials/llm-fine-tuning.md +246 -0
- data/docs/tutorials/model-export.md +200 -0
- data/docs/tutorials/siglip2-image-classification.md +130 -0
- data/docs/tutorials/siglip2-object-recognition.md +203 -0
- data/docs/tutorials/siglip2-similarity-search.md +152 -0
- data/docs/tutorials/text-classification.md +233 -0
- data/docs/tutorials/text-embeddings.md +211 -0
- data/examples/basic_classification.rb +70 -0
- data/examples/data/tool_calls.jsonl +30 -0
- data/examples/demo_training.rb +78 -0
- data/examples/finetune_gemma3_tools.rb +135 -0
- data/examples/real_llm_test.rb +128 -0
- data/examples/real_text_classification_test.rb +90 -0
- data/examples/real_text_embedder_test.rb +110 -0
- data/examples/real_training_test.rb +88 -0
- data/examples/test_export.rb +28 -0
- data/examples/test_image_classifier.rb +79 -0
- data/examples/test_llm.rb +100 -0
- data/examples/test_text_classifier.rb +59 -0
- data/lib/fine/callbacks/base.rb +140 -0
- data/lib/fine/callbacks/progress_bar.rb +66 -0
- data/lib/fine/configuration.rb +106 -0
- data/lib/fine/datasets/data_loader.rb +63 -0
- data/lib/fine/datasets/image_dataset.rb +203 -0
- data/lib/fine/datasets/instruction_dataset.rb +226 -0
- data/lib/fine/datasets/text_data_loader.rb +88 -0
- data/lib/fine/datasets/text_dataset.rb +266 -0
- data/lib/fine/error.rb +49 -0
- data/lib/fine/export/gguf_exporter.rb +424 -0
- data/lib/fine/export/onnx_exporter.rb +249 -0
- data/lib/fine/export.rb +53 -0
- data/lib/fine/hub/config_loader.rb +145 -0
- data/lib/fine/hub/model_downloader.rb +136 -0
- data/lib/fine/hub/safetensors_loader.rb +108 -0
- data/lib/fine/image_classifier.rb +256 -0
- data/lib/fine/llm.rb +336 -0
- data/lib/fine/models/base.rb +48 -0
- data/lib/fine/models/bert_encoder.rb +202 -0
- data/lib/fine/models/bert_for_sequence_classification.rb +226 -0
- data/lib/fine/models/causal_lm.rb +279 -0
- data/lib/fine/models/classification_head.rb +24 -0
- data/lib/fine/models/gemma3_decoder.rb +244 -0
- data/lib/fine/models/llama_decoder.rb +297 -0
- data/lib/fine/models/sentence_transformer.rb +202 -0
- data/lib/fine/models/siglip2_for_image_classification.rb +155 -0
- data/lib/fine/models/siglip2_vision_encoder.rb +190 -0
- data/lib/fine/text_classifier.rb +250 -0
- data/lib/fine/text_embedder.rb +221 -0
- data/lib/fine/tokenizers/auto_tokenizer.rb +208 -0
- data/lib/fine/training/llm_trainer.rb +212 -0
- data/lib/fine/training/text_trainer.rb +275 -0
- data/lib/fine/training/trainer.rb +194 -0
- data/lib/fine/transforms/compose.rb +28 -0
- data/lib/fine/transforms/normalize.rb +33 -0
- data/lib/fine/transforms/resize.rb +35 -0
- data/lib/fine/transforms/to_tensor.rb +53 -0
- data/lib/fine/version.rb +3 -0
- data/lib/fine.rb +112 -0
- data/mise.toml +2 -0
- metadata +240 -0
data/lib/fine/export.rb
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require_relative "export/onnx_exporter"
|
|
4
|
+
require_relative "export/gguf_exporter"
|
|
5
|
+
|
|
6
|
+
module Fine
  # Export models to various deployment formats.
  #
  # @example Export to ONNX
  #   Fine::Export.to_onnx(classifier, "model.onnx")
  #
  # @example Export LLM to GGUF
  #   Fine::Export.to_gguf(llm, "model.gguf", quantization: :q4_0)
  module Export
    # Descriptions of the GGUF quantization types advertised by
    # {Export.gguf_quantization_options}. Frozen so it can be shared safely.
    GGUF_QUANTIZATION_OPTIONS = {
      f32: "32-bit float (largest, no quality loss)",
      f16: "16-bit float (good balance)",
      q8_0: "8-bit quantization (smaller, minimal quality loss)",
      q4_0: "4-bit quantization (smallest, some quality loss)",
      q4_k: "4-bit K-quant (better quality than q4_0)",
      q5_k: "5-bit K-quant (good quality/size balance)",
      q6_k: "6-bit K-quant (high quality)"
    }.freeze

    class << self
      # Export any Fine model to ONNX format.
      #
      # @param model [TextClassifier, TextEmbedder, ImageClassifier, LLM] The model
      # @param path [String] Output path
      # @param options [Hash] Exporter-specific options, forwarded unchanged to ONNXExporter.export
      # @return [String] Whatever ONNXExporter.export returns (documented as the output path)
      def to_onnx(model, path, **options)
        ONNXExporter.export(model, path, **options)
      end

      # Export an LLM to GGUF format.
      #
      # @param model [LLM] The LLM model
      # @param path [String] Output path
      # @param quantization [Symbol] Quantization type (:f16 by default; see {gguf_quantization_options})
      # @param options [Hash] Any further exporter options (e.g. metadata — check
      #   GGUFExporter.export for the accepted keys), forwarded unchanged
      # @return [String] Whatever GGUFExporter.export returns (documented as the output path)
      def to_gguf(model, path, quantization: :f16, **options)
        GGUFExporter.export(model, path, quantization: quantization, **options)
      end

      # List the GGUF quantization options with human-readable descriptions.
      #
      # @return [Hash{Symbol => String}] A fresh, mutable copy; callers may modify it freely
      def gguf_quantization_options
        GGUF_QUANTIZATION_OPTIONS.dup
      end
    end
  end
end
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fine
  module Hub
    # Loads a Hugging Face-style config.json and exposes the architecture
    # hyperparameters used by Fine's models.
    #
    # Every accessor falls back to a hard-coded default when its key is absent.
    class ConfigLoader
      attr_reader :config

      # @param config_path [String] Path to a JSON config file
      # @raise [ConfigurationError] If the file cannot be read, is not valid JSON,
      #   or does not contain a JSON object at the top level
      def initialize(config_path)
        @config_path = config_path
        @config = load_config
      end

      # Load config.json from a model directory.
      #
      # @param model_path [String] Directory expected to contain config.json
      # @return [ConfigLoader]
      # @raise [ConfigurationError] If config.json is missing or invalid
      def self.from_pretrained(model_path)
        config_path = File.join(model_path, "config.json")
        raise ConfigurationError, "Config not found: #{config_path}" unless File.exist?(config_path)

        new(config_path)
      end

      # Shared architecture settings. These prefer the nested "vision_config"
      # section, then fall back to top-level keys (used by text/LLM configs),
      # then to a default.

      def hidden_size
        vision_config["hidden_size"] || config["hidden_size"] || 768
      end

      def num_hidden_layers
        vision_config["num_hidden_layers"] || config["num_hidden_layers"] || 12
      end

      def num_attention_heads
        vision_config["num_attention_heads"] || config["num_attention_heads"] || 12
      end

      def intermediate_size
        vision_config["intermediate_size"] || config["intermediate_size"] || 3072
      end

      def image_size
        vision_config["image_size"] || config["image_size"] || 224
      end

      def patch_size
        vision_config["patch_size"] || config["patch_size"] || 16
      end

      def num_channels
        vision_config["num_channels"] || config["num_channels"] || 3
      end

      def layer_norm_eps
        vision_config["layer_norm_eps"] || config["layer_norm_eps"] || 1e-6
      end

      def hidden_act
        vision_config["hidden_act"] || config["hidden_act"] || "gelu"
      end

      def attention_dropout
        vision_config["attention_dropout"] || config["attention_dropout"] || 0.0
      end

      # Text model configuration (BERT, DistilBERT, DeBERTa, etc.)

      def vocab_size
        config["vocab_size"] || 30522
      end

      def max_position_embeddings
        config["max_position_embeddings"] || 512
      end

      def type_vocab_size
        config["type_vocab_size"] || 2
      end

      # Falls back to the DistilBERT-style "dropout" key before the default.
      def hidden_dropout_prob
        config["hidden_dropout_prob"] || config["dropout"] || 0.1
      end

      # LLM configuration (Llama, Gemma, etc.)

      def rope_theta
        config["rope_theta"] || 10000.0
      end

      def rms_norm_eps
        config["rms_norm_eps"] || 1e-6
      end

      # Defaults to one KV head per attention head (no grouped-query attention).
      def num_key_value_heads
        config["num_key_value_heads"] || num_attention_heads
      end

      # True unless the config explicitly sets "use_bias" to false.
      def use_bias
        config["use_bias"] != false
      end

      def head_dim
        config["head_dim"] || (hidden_size / num_attention_heads)
      end

      def attention_bias
        config["attention_bias"] || false
      end

      # Whether the decoder should apply QK normalization (Gemma 3).
      #
      # @return [Boolean]
      def use_qk_norm
        # NOTE(review): "query_pre_attn_scalar" may also appear in non-Gemma-3
        # configs (e.g. Gemma 2) — confirm this key alone is a reliable signal.
        config.key?("query_pre_attn_scalar") || gemma3?
      end

      # Whether the decoder has the extra Gemma 3 pre-feedforward layer norm.
      #
      # @return [Boolean]
      def use_pre_feedforward_layernorm
        gemma3?
      end

      # Computed properties

      def num_patches
        (image_size / patch_size)**2
      end

      def model_type
        config["model_type"]
      end

      # Raw access to config sections

      def vision_config
        config["vision_config"] || {}
      end

      def text_config
        config["text_config"] || {}
      end

      def to_h
        @config
      end

      private

      # True when model_type names a Gemma 3 variant (e.g. "gemma3_text").
      def gemma3?
        model_type.to_s.include?("gemma3")
      end

      # Read and parse the JSON file, mapping every failure to ConfigurationError.
      def load_config
        parsed = JSON.parse(File.read(@config_path))
        raise ConfigurationError, "Config must be a JSON object: #{@config_path}" unless parsed.is_a?(Hash)

        parsed
      rescue JSON::ParserError => e
        raise ConfigurationError, "Invalid JSON in config: #{e.message}"
      rescue SystemCallError => e
        raise ConfigurationError, "Cannot read config #{@config_path}: #{e.message}"
      end
    end
  end
end
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fine
  module Hub
    # Downloads model files from the Hugging Face Hub into a local cache.
    #
    # Files are fetched from the repository's "main" revision. A file already
    # present in the cache is reused unless a forced download is requested.
    class ModelDownloader
      HF_HUB_URL = "https://huggingface.co"
      REQUIRED_FILES = %w[config.json].freeze
      MODEL_FILES = %w[model.safetensors].freeze
      OPTIONAL_FILES = %w[preprocessor_config.json tokenizer_config.json tokenizer.json].freeze
      # Index that lists the shard files of a sharded safetensors checkpoint.
      SHARD_INDEX_FILE = "model.safetensors.index.json"

      attr_reader :model_id, :cache_path

      # @param model_id [String] Hub repository id, e.g. "org/name"
      # @param cache_dir [String, nil] Cache root; defaults to Fine.cache_dir
      def initialize(model_id, cache_dir: nil)
        @model_id = model_id
        @cache_dir = cache_dir || Fine.cache_dir
        # String#tr maps character-by-character, so "/" becomes a single "-"
        # (the second "-" in the replacement set is unused). Kept unchanged so
        # existing cache directories keep resolving.
        @cache_path = File.join(@cache_dir, "models", model_id.tr("/", "--"))
      end

      # Download the model files if needed and return the cache path.
      #
      # @param force [Boolean] Re-download every file even if it is already cached
      # @return [String] The local cache directory
      # @raise [ModelNotFoundError] If config.json or the model weights cannot be fetched
      def download(force: false)
        return @cache_path if !force && cached?

        FileUtils.mkdir_p(@cache_path)

        download_required_files(force: force)
        download_optional_files(force: force)

        @cache_path
      end

      # Whether config.json and a complete set of model weights are cached.
      #
      # A single model.safetensors counts as complete. For a sharded checkpoint
      # with a readable index, every listed shard must be present, so an
      # interrupted download is not mistaken for a finished one. Without a
      # usable index, any model-*.safetensors file is accepted.
      #
      # @return [Boolean]
      def cached?
        return false unless File.directory?(@cache_path)
        return false unless REQUIRED_FILES.all? { |f| File.exist?(file_path(f)) }
        return true if File.exist?(file_path("model.safetensors"))

        shards = indexed_shard_files
        return shards.all? { |f| File.exist?(file_path(f)) } if shards

        Dir.glob(File.join(@cache_path, "model-*.safetensors")).any?
      end

      # Get the cache path for a specific file.
      #
      # @param filename [String]
      # @return [String]
      def file_path(filename)
        File.join(@cache_path, filename)
      end

      private

      # Fetch the required files, then the model weights.
      def download_required_files(force: false)
        REQUIRED_FILES.each { |filename| download_file(filename, force: force) }

        download_model_weights(force: force)
      end

      # Try a single model.safetensors first, then fall back to a sharded
      # checkpoint described by the index file.
      def download_model_weights(force: false)
        begin
          download_file("model.safetensors", force: force)
          return
        rescue ModelNotFoundError
          # No single-file checkpoint; try the sharded layout below.
        end

        download_file(SHARD_INDEX_FILE, required: false, force: force)
        shards = indexed_shard_files
        raise ModelNotFoundError.new(@model_id, "No model weights found") unless shards

        shards.each { |filename| download_file(filename, force: force) }
      end

      # Best-effort fetch of the optional files; missing ones are skipped.
      def download_optional_files(force: false)
        OPTIONAL_FILES.each { |filename| download_file(filename, required: false, force: force) }
      end

      # Unique shard file names from the cached index, or nil when the index
      # is absent, unparseable, or has an empty or missing "weight_map".
      def indexed_shard_files
        index_path = file_path(SHARD_INDEX_FILE)
        return nil unless File.exist?(index_path)

        index = JSON.parse(File.read(index_path))
        weight_map = index["weight_map"] if index.is_a?(Hash)
        return nil unless weight_map.is_a?(Hash) && !weight_map.empty?

        weight_map.values.uniq
      rescue JSON::ParserError
        nil
      end

      # Fetch one repository file into the cache.
      #
      # @param filename [String] Repository-relative file name
      # @param required [Boolean] Raise on download failure when true; otherwise skip silently
      # @param force [Boolean] Download even if the file already exists locally
      # @raise [ModelNotFoundError] For unsafe names, or on failure when required
      def download_file(filename, required: true, force: false)
        local_path = safe_cache_path(filename)
        return if !force && File.exist?(local_path)

        url = file_url(filename)

        begin
          puts "Downloading #{filename}..." if Fine.configuration&.progress_bar != false

          headers = { "User-Agent" => "fine-ruby/#{Fine::VERSION}" }
          token = hf_token
          headers["Authorization"] = "Bearer #{token}" if token

          tempfile = Down.download(url, headers: headers)
          # Release our descriptor before moving the file out from under the Tempfile.
          tempfile.close
          FileUtils.mkdir_p(File.dirname(local_path))
          FileUtils.mv(tempfile.path, local_path)
        rescue Down::NotFound
          raise ModelNotFoundError.new(@model_id, "File not found: #{filename}") if required
        rescue Down::Error => e
          raise ModelNotFoundError.new(@model_id, "Download failed: #{e.message}") if required
        end
      end

      # Resolve a file name inside the cache directory, rejecting names that
      # could escape it. Shard names come from a downloaded index, so they are
      # untrusted input.
      def safe_cache_path(filename)
        name = filename.to_s
        segments = name.split(%r{[/\\]})
        if name.empty? || name.start_with?("/", "\\") || segments.include?("..")
          raise ModelNotFoundError.new(@model_id, "Unsafe file name: #{filename}")
        end

        file_path(name)
      end

      # Hugging Face access token from the environment or the token file under
      # HF_HOME (default ~/.cache/huggingface). Blank values are ignored.
      def hf_token
        %w[HF_TOKEN HUGGING_FACE_HUB_TOKEN].each do |var|
          value = ENV[var]&.strip
          return value unless value.nil? || value.empty?
        end

        hf_home = File.expand_path(ENV.fetch("HF_HOME", "~/.cache/huggingface"))
        token_path = File.join(hf_home, "token")
        return nil unless File.exist?(token_path)

        token = File.read(token_path).strip
        token.empty? ? nil : token
      end

      def file_url(filename)
        "#{HF_HUB_URL}/#{@model_id}/resolve/main/#{filename}"
      end
    end
  end
end
|
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fine
  module Hub
    # Loads SafeTensors weights into torch.rb models, translating Hugging Face
    # parameter names into the names used by Fine's model classes.
    class SafetensorsLoader
      # Load weights from a safetensors file into a model by copying each
      # matching tensor into the model's existing state_dict entry.
      #
      # @param model [Torch::NN::Module] The model to load weights into
      # @param path [String] Path to the safetensors file
      # @param strict [Boolean] If true, raise on missing/unexpected keys
      # @param prefix [String, nil] Leading name component to strip (e.g. "vision_model")
      # @return [Hash] :missing_keys (model keys with no tensor) and
      #   :unexpected_keys (original file names that matched no model key)
      # @raise [WeightLoadingError] In strict mode when the keys do not line up
      def self.load_into_model(model, path, strict: false, prefix: nil)
        tensors = Safetensors::Torch.load_file(path)

        # A single state_dict Hash serves both as the key lookup (O(1) key?)
        # and as the copy destination below.
        state_dict = model.state_dict

        mapped_tensors = {}
        unexpected_keys = []

        tensors.each do |name, tensor|
          mapped_name = map_weight_name(name, prefix: prefix)

          if state_dict.key?(mapped_name)
            mapped_tensors[mapped_name] = tensor
          else
            unexpected_keys << name
          end
        end

        missing_keys = state_dict.keys - mapped_tensors.keys

        if strict && (missing_keys.any? || unexpected_keys.any?)
          raise WeightLoadingError.new(
            "Weight loading failed",
            missing_keys: missing_keys,
            unexpected_keys: unexpected_keys
          )
        end

        # Copy in place instead of load_state_dict (torch.rb doesn't support
        # strict: false yet). no_grad avoids in-place operation errors on leaf tensors.
        Torch.no_grad do
          mapped_tensors.each { |name, tensor| state_dict[name].copy!(tensor) }
        end

        { missing_keys: missing_keys, unexpected_keys: unexpected_keys }
      end

      # Map a Hugging Face weight name to the torch.rb model structure.
      #
      # The prefix is stripped only as a whole dotted component, so prefix
      # "vision_model" strips "vision_model.x" but leaves "vision_model_x" intact.
      # A trailing "." on the prefix is tolerated.
      #
      # @param hf_name [String] Original Hugging Face weight name
      # @param prefix [String, nil] Optional leading component to strip
      # @return [String] Mapped weight name for torch.rb
      def self.map_weight_name(hf_name, prefix: nil)
        name = hf_name.dup

        if prefix
          stem = prefix.to_s.chomp(".")
          name = name.sub(/\A#{Regexp.escape(stem)}(?:\.|\z)/, "")
        end

        # SigLIP2 specific mappings (prefix such as "vision_model" already stripped)
        name = name.sub("embeddings.patch_embedding", "patch_embed.proj")
        name = name.sub("embeddings.position_embedding.weight", "pos_embed")
        name = name.sub("encoder.layers", "blocks")
        name = name.sub("post_layernorm", "norm")

        # Transformer block mappings
        name = name.gsub(".self_attn.", ".attn.")
        name = name.gsub(".layer_norm1.", ".norm1.")
        name = name.gsub(".layer_norm2.", ".norm2.")
        # mlp.fc1, mlp.fc2, q_proj, k_proj, v_proj, out_proj and head names match our model

        name
      end

      # Load raw tensors from a safetensors file.
      #
      # @param path [String] Path to safetensors file
      # @return [Hash<String, Torch::Tensor>] Hash of tensor name to tensor
      def self.load_file(path)
        Safetensors::Torch.load_file(path)
      end

      # List tensor names in a safetensors file.
      #
      # @param path [String] Path to safetensors file
      # @return [Array<String>] List of tensor names
      def self.tensor_names(path)
        Safetensors.safe_open(path, framework: "torch") do |f|
          f.keys
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Fine
  # High-level API for image classification fine-tuning.
  #
  # @example Simple usage
  #   classifier = Fine::ImageClassifier.new("google/siglip2-base-patch16-224")
  #   classifier.fit(train_dir: "data/train", epochs: 3)
  #   classifier.save("my_model")
  #
  # @example With configuration
  #   classifier = Fine::ImageClassifier.new("google/siglip2-base-patch16-224") do |config|
  #     config.learning_rate = 2e-4
  #     config.batch_size = 32
  #     config.on_epoch_end { |epoch, metrics| puts "Epoch #{epoch}: #{metrics}" }
  #   end
  #
  class ImageClassifier
    attr_reader :model, :config, :label_map, :model_id

    # Create a new ImageClassifier.
    #
    # @param model_id [String] Hugging Face model ID (e.g., "google/siglip2-base-patch16-224")
    # @yield [config] Optional configuration block
    # @yieldparam config [Configuration] Configuration object
    def initialize(model_id, &block)
      @model_id = model_id
      @config = Configuration.new
      @model = nil
      @label_map = nil
      @trained = false

      block&.call(@config)

      # Add a default progress bar unless the block configured callbacks or
      # progress bars are globally disabled.
      if @config.callbacks.empty? && Fine.configuration&.progress_bar != false
        @config.callbacks << Callbacks::ProgressBar.new
      end
    end

    # Load a fine-tuned classifier from disk.
    #
    # @param path [String] Path to saved model directory
    # @return [ImageClassifier]
    # @raise [ModelNotFoundError] If path/config.json does not exist
    def self.load(path)
      config_path = File.join(path, "config.json")
      raise ModelNotFoundError.new(path, "Model not found") unless File.exist?(config_path)

      config_data = JSON.parse(File.read(config_path))

      classifier = allocate
      classifier.instance_variable_set(:@model_id, config_data["_model_id"] || "custom")
      # NOTE(review): a fresh Configuration is used, so inference relies on its
      # default image_size rather than the one used for training — confirm they match.
      classifier.instance_variable_set(:@config, Configuration.new)
      classifier.instance_variable_set(:@trained, true)

      # Label map is {label => index}; rebuild it from id2label when label2id is absent.
      if config_data["label2id"]
        classifier.instance_variable_set(:@label_map, config_data["label2id"])
      elsif config_data["id2label"]
        classifier.instance_variable_set(
          :@label_map,
          config_data["id2label"].transform_keys(&:to_i).invert
        )
      end

      classifier.instance_variable_set(
        :@model,
        Models::SigLIP2ForImageClassification.load(path)
      )

      classifier
    end

    # Fine-tune the model on a dataset.
    #
    # Training images use the augmented pipeline. Validation images use the
    # deterministic inference pipeline so random augmentation does not perturb
    # validation metrics.
    #
    # @param train_dir [String] Path to training data directory
    # @param val_dir [String, nil] Path to validation data directory
    # @param epochs [Integer, nil] Number of epochs (overrides config)
    # @return [Array<Hash>] Training history
    def fit(train_dir:, val_dir: nil, epochs: nil)
      @config.epochs = epochs if epochs

      train_dataset = Datasets::ImageDataset.from_directory(train_dir, transforms: build_transforms)
      val_dataset =
        (Datasets::ImageDataset.from_directory(val_dir, transforms: build_inference_transforms) if val_dir)

      train_on(train_dataset, val_dataset)
    end

    # Fine-tune with an automatic train/validation split of one directory.
    #
    # @param data_dir [String] Path to data directory
    # @param val_split [Float] Fraction of data to use for validation
    # @param epochs [Integer, nil] Number of epochs
    # @return [Array<Hash>] Training history
    def fit_with_split(data_dir:, val_split: 0.2, epochs: nil)
      @config.epochs = epochs if epochs

      # NOTE(review): both halves of the split come from one dataset built with
      # the (possibly augmenting) training transforms — confirm whether
      # ImageDataset#split can give the validation half inference transforms.
      full_dataset = Datasets::ImageDataset.from_directory(data_dir, transforms: build_transforms)
      train_dataset, val_dataset = full_dataset.split(test_size: val_split)

      train_on(train_dataset, val_dataset)
    end

    # Make predictions on images.
    #
    # @param images [String, #to_path, Enumerable<String, #to_path>] Path(s) to image file(s)
    # @param top_k [Integer] Number of top predictions to return per image
    # @return [Array<Array<Hash>>] One list per image of {label:, score:} hashes, best first.
    #   Returns [] for an empty input list.
    # @raise [TrainingError] If the model has not been trained or loaded
    def predict(images, top_k: 5)
      raise TrainingError, "Model not trained or loaded" unless @trained && @model

      paths = single_image?(images) ? [images] : images.to_a
      return [] if paths.empty?

      transforms = build_inference_transforms
      tensors = paths.map do |path|
        image = Vips::Image.new_from_file(path.to_s, access: :sequential)
        transforms.call(image)
      end
      pixel_values = Torch.stack(tensors)

      @model.eval
      # Inference only: skip autograd bookkeeping.
      probs = Torch.no_grad { @model.predict_proba(pixel_values) }

      # A loaded model may have no label map; fall back to the class index as the label.
      index_to_label = (@label_map || {}).invert

      probs.to_a.map do |sample_probs|
        ranked = sample_probs.each_with_index.sort_by { |prob, _| -prob }
        ranked.first(top_k).map do |prob, idx|
          { label: index_to_label[idx] || idx.to_s, score: prob.round(4) }
        end
      end
    end

    # Save the fine-tuned model, recording the original model ID in config.json.
    #
    # @param path [String] Directory path to save to
    # @raise [TrainingError] If the model has not been trained or loaded
    def save(path)
      raise TrainingError, "Model not trained" unless @trained && @model

      @model.save(path, label_map: @label_map)

      config_path = File.join(path, "config.json")
      config = JSON.parse(File.read(config_path))
      config["_model_id"] = @model_id
      File.write(config_path, JSON.pretty_generate(config))
    end

    # Get class names in order of their IDs.
    #
    # @return [Array<String>] Class names ([] when no label map is known)
    def class_names
      return [] unless @label_map

      @label_map.sort_by { |_, v| v }.map(&:first)
    end

    # Export to ONNX format.
    #
    # @param path [String] Output path for ONNX file
    # @param options [Hash] Export options
    # @return [String] The output path
    def export_onnx(path, **options)
      Export.to_onnx(self, path, **options)
    end

    private

    # Build the SigLIP2 classifier for the dataset's classes and train it.
    # Shared by #fit and #fit_with_split.
    def train_on(train_dataset, val_dataset)
      @label_map = train_dataset.label_map

      @model = Models::SigLIP2ForImageClassification.from_pretrained(
        @model_id,
        num_labels: train_dataset.num_classes,
        freeze_encoder: @config.freeze_encoder,
        dropout: @config.dropout
      )

      trainer = Training::Trainer.new(
        @model,
        @config,
        train_dataset: train_dataset,
        val_dataset: val_dataset
      )

      history = trainer.fit
      @trained = true

      history
    end

    # True when +images+ denotes one image path rather than a collection.
    def single_image?(images)
      images.is_a?(String) || images.respond_to?(:to_path)
    end

    # Training pipeline: optional augmentation, then resize/tensor/normalize.
    def build_transforms
      transforms = []

      transforms.concat(@config.augmentation_config.to_transforms) if @config.augmentation_config.enabled?

      transforms << Transforms::Resize.new(@config.image_size)
      transforms << Transforms::ToTensor.new
      transforms << Transforms::Normalize.new

      Transforms::Compose.new(transforms)
    end

    # Deterministic pipeline for inference and validation (no augmentation).
    def build_inference_transforms
      Transforms::Compose.new([
        Transforms::Resize.new(@config.image_size),
        Transforms::ToTensor.new,
        Transforms::Normalize.new
      ])
    end
  end
end
|