fine 0.1.0 → 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/README.md +20 -10
- data/docs/examples/image-classification-shapes.md +83 -0
- data/docs/examples/text-embeddings-faq.md +98 -0
- data/docs/quickstart.md +209 -0
- data/docs/tutorials/lora-tool-calling.md +306 -0
- data/examples/data/generate_tool_data.rb +261 -0
- data/examples/data/ollama_tool_calls.jsonl +40 -0
- data/examples/data/sentiment_reviews.jsonl +30 -0
- data/examples/data/shapes/circle/circle_1.jpg +0 -0
- data/examples/data/shapes/circle/circle_10.jpg +0 -0
- data/examples/data/shapes/circle/circle_2.jpg +0 -0
- data/examples/data/shapes/circle/circle_3.jpg +0 -0
- data/examples/data/shapes/circle/circle_4.jpg +0 -0
- data/examples/data/shapes/circle/circle_5.jpg +0 -0
- data/examples/data/shapes/circle/circle_6.jpg +0 -0
- data/examples/data/shapes/circle/circle_7.jpg +0 -0
- data/examples/data/shapes/circle/circle_8.jpg +0 -0
- data/examples/data/shapes/circle/circle_9.jpg +0 -0
- data/examples/data/shapes/square/square_1.jpg +0 -0
- data/examples/data/shapes/square/square_10.jpg +0 -0
- data/examples/data/shapes/square/square_2.jpg +0 -0
- data/examples/data/shapes/square/square_3.jpg +0 -0
- data/examples/data/shapes/square/square_4.jpg +0 -0
- data/examples/data/shapes/square/square_5.jpg +0 -0
- data/examples/data/shapes/square/square_6.jpg +0 -0
- data/examples/data/shapes/square/square_7.jpg +0 -0
- data/examples/data/shapes/square/square_8.jpg +0 -0
- data/examples/data/shapes/square/square_9.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_1.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_10.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_2.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_3.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_4.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_5.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_6.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_7.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_8.jpg +0 -0
- data/examples/data/shapes/triangle/triangle_9.jpg +0 -0
- data/examples/data/support_faq_pairs.jsonl +30 -0
- data/examples/generate_shape_images.rb +94 -0
- data/examples/sentiment_classification.rb +87 -0
- data/examples/shape_classification.rb +87 -0
- data/examples/support_faq_embeddings.rb +105 -0
- data/examples/train_lora_tools.rb +218 -0
- data/lib/fine/configuration.rb +173 -15
- data/lib/fine/datasets/image_dataset.rb +14 -2
- data/lib/fine/datasets/instruction_dataset.rb +17 -2
- data/lib/fine/datasets/text_dataset.rb +15 -5
- data/lib/fine/hub/config_loader.rb +4 -4
- data/lib/fine/hub/safetensors_loader.rb +3 -2
- data/lib/fine/llm.rb +39 -10
- data/lib/fine/lora.rb +214 -0
- data/lib/fine/models/bert_encoder.rb +15 -6
- data/lib/fine/models/bert_for_sequence_classification.rb +35 -4
- data/lib/fine/models/causal_lm.rb +46 -5
- data/lib/fine/models/gemma3_decoder.rb +25 -6
- data/lib/fine/models/llama_decoder.rb +9 -8
- data/lib/fine/models/sentence_transformer.rb +1 -1
- data/lib/fine/tokenizers/auto_tokenizer.rb +15 -0
- data/lib/fine/training/text_trainer.rb +3 -1
- data/lib/fine/validators.rb +304 -0
- data/lib/fine/version.rb +1 -1
- data/lib/fine.rb +4 -0
- metadata +47 -2
data/examples/train_lora_tools.rb
ADDED
@@ -0,0 +1,218 @@
+#!/usr/bin/env ruby
+# frozen_string_literal: true
+
+# Train larger model with LoRA for efficient fine-tuning
+# LoRA dramatically reduces memory by only training ~1% of parameters
+
+require "bundler/setup"
+require "fine"
+
+MAX_MEMORY_GB = 40
+MONITOR_INTERVAL = 2
+
+def get_memory_usage_gb
+  `ps -o rss= -p #{Process.pid}`.strip.to_i / 1024.0 / 1024.0
+end
+
+puts "=" * 70
+puts "LORA TOOL CALLING TRAINING"
+puts "=" * 70
+puts "Max memory limit: #{MAX_MEMORY_GB} GB"
+
+max_memory_seen = 0.0
+memory_exceeded = false
+
+monitor_thread = Thread.new do
+  loop do
+    mem = get_memory_usage_gb
+    max_memory_seen = mem if mem > max_memory_seen
+    if mem > MAX_MEMORY_GB
+      memory_exceeded = true
+      Thread.main.raise(Interrupt, "Memory limit exceeded: #{mem.round(2)} GB")
+    end
+    sleep MONITOR_INTERVAL
+  rescue => e
+    break if e.is_a?(Interrupt)
+  end
+end
+
+begin
+  Fine.configure { |c| c.progress_bar = false }
+
+  # Use larger dataset
+  data_path = File.expand_path("data/ollama_tool_calls_large.jsonl", __dir__)
+
+  # Try 4B model first, fall back to 1B if memory issues
+  model_id = ARGV[0] || "google/gemma-3-1b-it"
+
+  puts "\n[1/6] Loading model: #{model_id}..."
+  model = Fine::Models::CausalLM.from_pretrained(model_id)
+  puts "  Model loaded: #{get_memory_usage_gb.round(2)} GB"
+
+  puts "\n[2/6] Applying LoRA..."
+  # Apply LoRA to attention projections
+  Fine::LoRA.apply(
+    model,
+    rank: 32,       # Higher rank = more capacity for structured output
+    alpha: 64,      # Scaling factor
+    dropout: 0.05,  # Light dropout for regularization
+    target_modules: %w[q_proj k_proj v_proj o_proj] # All attention projections
+  )
+  puts "  LoRA applied: #{get_memory_usage_gb.round(2)} GB"
+
+  # Move to device
+  model.to(Fine.device)
+  model.train
+  puts "  On #{Fine.device}: #{get_memory_usage_gb.round(2)} GB"
+
+  puts "\n[3/6] Loading tokenizer..."
+  downloader = Fine::Hub::ModelDownloader.new(model_id)
+  model_path = downloader.download
+  tokenizer = Fine::Tokenizers::AutoTokenizer.new(model_path, max_length: 384)
+  puts "  Found tokenizer"
+
+  puts "\n[4/6] Loading training data..."
+  dataset = Fine::Datasets::InstructionDataset.from_jsonl(
+    data_path,
+    tokenizer: tokenizer,
+    format: :alpaca,
+    max_length: 384
+  )
+  puts "  #{dataset.size} examples loaded"
+
+  data_loader = Fine::Datasets::InstructionDataLoader.new(
+    dataset,
+    batch_size: 1,
+    shuffle: true,
+    pad_token_id: tokenizer.pad_token_id
+  )
+
+  puts "\n[5/6] Training with LoRA..."
+
+  # Only get LoRA parameters for optimizer
+  lora_params = Fine::LoRA.trainable_parameters(model)
+  optimizer = Torch::Optim::AdamW.new(lora_params, lr: 1e-4) # Higher LR for LoRA
+
+  epochs = 15 # More epochs for structured output learning
+  total_loss = 0.0
+  step = 0
+
+  epochs.times do |epoch|
+    epoch_loss = 0.0
+    batch_count = 0
+
+    data_loader.each do |batch|
+      input_ids = batch[:input_ids].to(Fine.device)
+      labels = batch[:labels].to(Fine.device)
+      attention_mask = batch[:attention_mask].to(Fine.device)
+
+      # Forward
+      outputs = model.forward(input_ids, attention_mask: attention_mask, labels: labels)
+      loss = outputs[:loss]
+
+      # Backward
+      loss.backward
+
+      # Optimizer step
+      optimizer.step
+      optimizer.zero_grad
+
+      epoch_loss += loss.to(:float32).item
+      batch_count += 1
+      step += 1
+    end
+
+    avg_loss = epoch_loss / batch_count
+    mem = get_memory_usage_gb
+    puts "  Epoch #{epoch + 1}: loss=#{avg_loss.round(4)} | Memory: #{mem.round(2)} GB"
+  end
+
+  puts "\n[6/6] Testing generation..."
+
+  model.eval
+  test_cases = [
+    {
+      prompt: "What's the weather in Tokyo?",
+      tools: "get_weather: Get current weather\n  Parameters: location (string, required)"
+    },
+    {
+      prompt: "Calculate 50 + 25 * 2",
+      tools: "calculate: Math calculator\n  Parameters: expression (string, required)"
+    },
+    {
+      prompt: "Search for Ruby tutorials",
+      tools: "search_web: Web search\n  Parameters: query (string, required)"
+    }
+  ]
+
+  test_cases.each do |tc|
+    full_prompt = <<~PROMPT
+      ### Instruction:
+      #{tc[:prompt]}
+
+      ### Input:
+      You have access to the following tools:
+
+      #{tc[:tools]}
+
+      Respond with a JSON tool call if a tool is needed.
+
+      ### Response:
+    PROMPT
+
+    ids = tokenizer.encode_for_generation(full_prompt)
+    input_ids = Torch.tensor([ids]).to(Fine.device)
+
+    Torch.no_grad do
+      output_ids = model.generate(
+        input_ids,
+        max_new_tokens: 150,
+        temperature: 0.1,
+        do_sample: false,
+        eos_token_id: tokenizer.eos_token_id
+      )
+      response = tokenizer.decode(output_ids[0].to_a)
+      generated = response.split("### Response:").last.to_s.strip
+
+      puts "\n  Q: #{tc[:prompt]}"
+      puts "  A: #{generated[0..200]}"
+
+      begin
+        json = JSON.parse(generated)
+        if json["tool_calls"]
+          puts "  [Valid Ollama format]"
+        end
+      rescue JSON::ParserError
+        puts "  [Not valid JSON]"
+      end
+    end
+  end
+
+  puts "\n" + "=" * 70
+  save_path = "/tmp/gemma3-lora-tools"
+
+  # Merge LoRA weights for inference
+  puts "Merging LoRA weights..."
+  Fine::LoRA.merge!(model)
+
+  model.save(save_path)
+  tokenizer.save(save_path)
+  puts "Model saved to: #{save_path}"
+  puts "Max memory used: #{max_memory_seen.round(2)} GB"
+  puts "=" * 70
+
+rescue Interrupt => e
+  if memory_exceeded
+    puts "\n\nTERMINATED: Memory limit exceeded!"
+    exit 1
+  else
+    puts "\n\nInterrupted"
+    exit 130
+  end
+rescue => e
+  puts "\nFailed: #{e.class}: #{e.message}"
+  puts e.backtrace.first(10).join("\n")
+  exit 1
+ensure
+  monitor_thread&.kill
+end
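For context, the script above loads `data/ollama_tool_calls_large.jsonl` in Alpaca format; that file is not part of this release (only `data/examples/data/ollama_tool_calls.jsonl` ships), so the record below is a hedged illustration of the structure `InstructionDataset.from_jsonl(..., format: :alpaca)` expects for tool calling. The field contents and the output file name are invented for the example.

```ruby
# Hypothetical Alpaca-format training record for tool calling; the real
# dataset's contents are not shown in this diff.
require "json"

record = {
  instruction: "What's the weather in Tokyo?",
  input: "You have access to the following tools:\n\n" \
         "get_weather: Get current weather\n  Parameters: location (string, required)\n\n" \
         "Respond with a JSON tool call if a tool is needed.",
  # The script's success check parses the response and looks for a "tool_calls" key.
  output: { tool_calls: [{ name: "get_weather", arguments: { location: "Tokyo" } }] }.to_json
}

File.open("ollama_tool_calls_large.jsonl", "a") { |f| f.puts(record.to_json) }
```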
data/lib/fine/configuration.rb
CHANGED
@@ -2,42 +2,95 @@
 
 module Fine
   # Configuration for training runs
+  #
+  # @example Basic usage
+  #   Fine::TextClassifier.new("distilbert-base-uncased") do |config|
+  #     config.epochs = 5
+  #     config.batch_size = 16
+  #   end
+  #
+  # @example With callbacks
+  #   config.on_epoch_end do |epoch, metrics|
+  #     puts "Epoch #{epoch}: loss=#{metrics[:loss]}"
+  #   end
+  #
   class Configuration
+    # Default values for all configurations
+    DEFAULTS = {
+      epochs: 3,
+      batch_size: 16,
+      learning_rate: 2e-5,
+      weight_decay: 0.01,
+      warmup_ratio: 0.1,
+      optimizer: :adamw,
+      scheduler: :linear,
+      dropout: 0.1,
+      image_size: 224
+    }.freeze
+
     # Training hyperparameters
+    # @!attribute epochs
+    #   @return [Integer] Number of training epochs (default: 3)
+    # @!attribute batch_size
+    #   @return [Integer] Samples per batch (default: 16)
+    # @!attribute learning_rate
+    #   @return [Float] Learning rate (default: 2e-5)
+    # @!attribute weight_decay
+    #   @return [Float] L2 regularization (default: 0.01)
     attr_accessor :epochs, :batch_size, :learning_rate, :weight_decay
+
+    # @!attribute warmup_steps
+    #   @return [Integer] Number of warmup steps (default: 0, use warmup_ratio instead)
+    # @!attribute warmup_ratio
+    #   @return [Float] Fraction of training for warmup (default: 0.1)
     attr_accessor :warmup_steps, :warmup_ratio
+
+    # @!attribute optimizer
+    #   @return [Symbol] Optimizer type (:adamw, :adam, :sgd) (default: :adamw)
+    # @!attribute scheduler
+    #   @return [Symbol] LR scheduler (:linear, :cosine, :constant) (default: :linear)
     attr_accessor :optimizer, :scheduler
 
     # Model configuration
+    # @!attribute freeze_encoder
+    #   @return [Boolean] Freeze encoder weights, only train head (default: false)
+    # @!attribute dropout
+    #   @return [Float] Dropout probability (default: 0.1)
+    # @!attribute num_labels
+    #   @return [Integer, nil] Number of output classes (auto-detected if nil)
     attr_accessor :freeze_encoder, :dropout, :num_labels
 
     # Data configuration
+    # @!attribute image_size
+    #   @return [Integer] Target image size for resizing (default: 224)
     attr_accessor :image_size
 
-    #
+    # @!attribute callbacks
+    #   @return [Array<Callbacks::Base>] Training callbacks
     attr_accessor :callbacks
 
-    #
+    # @!attribute augmentation_config
+    #   @return [AugmentationConfig] Data augmentation settings
     attr_reader :augmentation_config
 
     def initialize
-      # Training defaults
-      @epochs =
-      @batch_size =
-      @learning_rate =
-      @weight_decay =
+      # Training defaults - optimized for most tasks
+      @epochs = DEFAULTS[:epochs]
+      @batch_size = DEFAULTS[:batch_size]
+      @learning_rate = DEFAULTS[:learning_rate]
+      @weight_decay = DEFAULTS[:weight_decay]
       @warmup_steps = 0
-      @warmup_ratio =
-      @optimizer = :
-      @scheduler = :
+      @warmup_ratio = DEFAULTS[:warmup_ratio]
+      @optimizer = DEFAULTS[:optimizer]
+      @scheduler = DEFAULTS[:scheduler]
 
       # Model defaults
       @freeze_encoder = false
-      @dropout =
+      @dropout = DEFAULTS[:dropout]
      @num_labels = nil # auto-detect from dataset
 
       # Data defaults
-      @image_size =
+      @image_size = DEFAULTS[:image_size]
 
       # Callbacks
       @callbacks = []
@@ -46,34 +99,138 @@ module Fine
       @augmentation_config = AugmentationConfig.new
     end
 
+    # Configure data augmentation
+    #
+    # @yield [AugmentationConfig] The augmentation configuration
+    # @return [AugmentationConfig]
+    #
+    # @example
+    #   config.augmentation do |aug|
+    #     aug.random_horizontal_flip = true
+    #     aug.random_rotation = 15
+    #   end
     def augmentation
       yield @augmentation_config if block_given?
       @augmentation_config
     end
 
     # Register a callback for epoch end
+    #
+    # @yield [Integer, Hash] Epoch number and metrics hash
+    #
+    # @example
+    #   config.on_epoch_end do |epoch, metrics|
+    #     puts "Epoch #{epoch}: loss=#{metrics[:loss]}"
+    #   end
     def on_epoch_end(&block)
       @callbacks << Callbacks::LambdaCallback.new(on_epoch_end: block)
     end
 
     # Register a callback for batch end
+    #
+    # @yield [Integer, Float] Batch index and loss value
     def on_batch_end(&block)
       @callbacks << Callbacks::LambdaCallback.new(on_batch_end: block)
     end
 
-    # Register a callback for
+    # Register a callback for training start
+    #
+    # @yield [Hash] Training info (model, config)
     def on_train_begin(&block)
       @callbacks << Callbacks::LambdaCallback.new(on_train_begin: block)
     end
 
-    # Register a callback for
+    # Register a callback for training end
+    #
+    # @yield [Array<Hash>] Training history
     def on_train_end(&block)
       @callbacks << Callbacks::LambdaCallback.new(on_train_end: block)
     end
+
+    # Return configuration as a hash
+    def to_h
+      {
+        epochs: @epochs,
+        batch_size: @batch_size,
+        learning_rate: @learning_rate,
+        weight_decay: @weight_decay,
+        warmup_steps: @warmup_steps,
+        warmup_ratio: @warmup_ratio,
+        optimizer: @optimizer,
+        scheduler: @scheduler,
+        freeze_encoder: @freeze_encoder,
+        dropout: @dropout,
+        num_labels: @num_labels,
+        image_size: @image_size
+      }
+    end
+  end
+
+  # Configuration for text models (BERT, DistilBERT, DeBERTa)
+  class TextConfiguration < Configuration
+    # @!attribute max_length
+    #   @return [Integer] Maximum sequence length (default: 128)
+    attr_accessor :max_length
+
+    # Text model defaults
+    DEFAULTS = Configuration::DEFAULTS.merge(
+      max_length: 128,
+      batch_size: 16
+    ).freeze
+
+    def initialize
+      super
+      @max_length = DEFAULTS[:max_length]
+      @batch_size = DEFAULTS[:batch_size]
+    end
+  end
+
+  # Configuration for embedding models (Sentence Transformers)
+  class EmbeddingConfiguration < Configuration
+    # @!attribute max_length
+    #   @return [Integer] Maximum sequence length (default: 128)
+    # @!attribute pooling_mode
+    #   @return [Symbol] Pooling strategy (:mean, :cls, :max) (default: :mean)
+    # @!attribute loss
+    #   @return [Symbol] Loss function (:cosine, :contrastive, :triplet) (default: :cosine)
+    attr_accessor :max_length, :pooling_mode, :loss
+
+    # Embedding model defaults
+    DEFAULTS = Configuration::DEFAULTS.merge(
+      max_length: 128,
+      pooling_mode: :mean,
+      loss: :cosine,
+      batch_size: 32
+    ).freeze
+
+    def initialize
+      super
+      @max_length = DEFAULTS[:max_length]
+      @pooling_mode = DEFAULTS[:pooling_mode]
+      @loss = DEFAULTS[:loss]
+      @batch_size = DEFAULTS[:batch_size]
+    end
   end
 
   # Configuration for data augmentation
+  #
+  # @example
+  #   config.augmentation do |aug|
+  #     aug.random_horizontal_flip = true
+  #     aug.random_rotation = 15
+  #     aug.color_jitter = { brightness: 0.2, contrast: 0.2 }
+  #   end
   class AugmentationConfig
+    # @!attribute random_horizontal_flip
+    #   @return [Boolean] Randomly flip images horizontally (default: false)
+    # @!attribute random_vertical_flip
+    #   @return [Boolean] Randomly flip images vertically (default: false)
+    # @!attribute random_rotation
+    #   @return [Integer] Max rotation degrees (0 = disabled) (default: 0)
+    # @!attribute color_jitter
+    #   @return [Hash, nil] Color jitter settings { brightness:, contrast:, saturation:, hue: }
+    # @!attribute random_resized_crop
+    #   @return [Hash, nil] Random crop settings { scale:, ratio: }
     attr_accessor :random_horizontal_flip, :random_vertical_flip
     attr_accessor :random_rotation, :color_jitter
     attr_accessor :random_resized_crop
@@ -86,6 +243,7 @@ module Fine
       @random_resized_crop = nil
     end
 
+    # Check if any augmentation is enabled
     def enabled?
       @random_horizontal_flip ||
       @random_vertical_flip ||
@@ -94,12 +252,12 @@ module Fine
       @random_resized_crop
     end
 
+    # Convert to transform objects
     def to_transforms
       transforms = []
       transforms << Transforms::RandomHorizontalFlip.new if @random_horizontal_flip
       transforms << Transforms::RandomVerticalFlip.new if @random_vertical_flip
       transforms << Transforms::RandomRotation.new(@random_rotation) if @random_rotation.positive?
-      # Add more transforms as implemented
       transforms
     end
   end
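A quick sketch of how the new DEFAULTS constants and #to_h fit together, based only on the code shown above: each subclass redefines DEFAULTS by merging over the base hash, and #to_h reports the live instance values.

```ruby
config = Fine::EmbeddingConfiguration.new
config.loss = :triplet                       # override a single default

config.to_h[:batch_size]                     # => 32 (from EmbeddingConfiguration::DEFAULTS)
config.loss                                  # => :triplet
Fine::Configuration::DEFAULTS[:batch_size]   # => 16 (base defaults are untouched)
```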
data/lib/fine/datasets/image_dataset.rb
CHANGED
@@ -23,9 +23,21 @@
     #
     # @param path [String] Path to the root directory
     # @param transforms [Transforms::Compose, nil] Optional transforms to apply
+    # @param validate [Boolean] Whether to validate directory structure
     # @return [ImageDataset]
-
-
+    #
+    # @example Expected directory structure
+    #   # data/
+    #   #   cats/
+    #   #     cat1.jpg
+    #   #     cat2.jpg
+    #   #   dogs/
+    #   #     dog1.jpg
+    #   #     dog2.jpg
+    #   dataset = ImageDataset.from_directory("data/", transforms: transforms)
+    #
+    def self.from_directory(path, transforms: nil, validate: true)
+      Validators.validate_image_directory!(path) if validate
 
       images = []
       labels = []
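A minimal usage sketch of the new validate flag, pointed at the shape images shipped in this release (adjust the path to your checkout):

```ruby
# Validates the class-per-subdirectory layout before loading (the default).
dataset = Fine::Datasets::ImageDataset.from_directory("data/examples/data/shapes")

# Skip the structural check, e.g. for a directory you have already validated.
dataset = Fine::Datasets::ImageDataset.from_directory(
  "data/examples/data/shapes",
  validate: false
)
```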
data/lib/fine/datasets/instruction_dataset.rb
CHANGED
@@ -17,9 +17,24 @@
     # @param tokenizer [Tokenizers::AutoTokenizer] Tokenizer
     # @param format [Symbol] Data format (:alpaca, :sharegpt, :simple, :auto)
     # @param max_length [Integer] Maximum sequence length
+    # @param validate [Boolean] Whether to validate the file first
     # @return [InstructionDataset]
-
-
+    #
+    # @example Alpaca format
+    #   # {"instruction": "Summarize this", "input": "Long text...", "output": "Summary"}
+    #   dataset = InstructionDataset.from_jsonl("data.jsonl", tokenizer: tok)
+    #
+    # @example ShareGPT format
+    #   # {"conversations": [{"from": "human", "value": "Hi"}, {"from": "assistant", "value": "Hello!"}]}
+    #   dataset = InstructionDataset.from_jsonl("chat.jsonl", tokenizer: tok, format: :sharegpt)
+    #
+    def self.from_jsonl(path, tokenizer:, format: :auto, max_length: 2048, validate: true)
+      detected_format = Validators.validate_instructions!(path, format: format) if validate
+      format = detected_format if validate && format == :auto
+
+      examples = File.readlines(path).reject { |l| l.strip.empty? }.map do |line|
+        JSON.parse(line, symbolize_names: true)
+      end
       new(examples, tokenizer: tokenizer, format: format, max_length: max_length)
     end
 
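To make the new validation path concrete: with the default format: :auto, Validators.validate_instructions! both checks the file and returns the detected format, which is then used for parsing. A hedged sketch; the tokenizer path is a placeholder for a locally downloaded model directory.

```ruby
tok = Fine::Tokenizers::AutoTokenizer.new("/path/to/model", max_length: 384)

# Format is auto-detected from the records while the file is validated.
dataset = Fine::Datasets::InstructionDataset.from_jsonl(
  "data/examples/data/ollama_tool_calls.jsonl",
  tokenizer: tok
)

# Passing an explicit format and validate: false skips both detection and validation.
dataset = Fine::Datasets::InstructionDataset.from_jsonl(
  "chat.jsonl",
  tokenizer: tok,
  format: :sharegpt,
  validate: false
)
```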
data/lib/fine/datasets/text_dataset.rb
CHANGED
@@ -19,9 +19,17 @@
     # @param tokenizer [AutoTokenizer] Tokenizer to use
     # @param text_column [String] Name of text field
     # @param label_column [String] Name of label field
+    # @param validate [Boolean] Whether to validate the file first
     # @return [TextDataset]
-
-
+    #
+    # @example
+    #   # Expected JSONL format:
+    #   # {"text": "Great product!", "label": "positive"}
+    #   # {"text": "Terrible service", "label": "negative"}
+    #   dataset = TextDataset.from_jsonl("reviews.jsonl", tokenizer: tokenizer)
+    #
+    def self.from_jsonl(path, tokenizer:, text_column: "text", label_column: "label", validate: true)
+      Validators.validate_text_classification!(path) if validate
 
       texts = []
       labels = []
@@ -29,9 +37,11 @@ module Fine
       File.foreach(path) do |line|
         next if line.strip.empty?
 
-        data = JSON.parse(line)
-
-
+        data = JSON.parse(line, symbolize_names: true)
+        text_key = data.key?(text_column.to_sym) ? text_column.to_sym : text_column
+        label_key = data.key?(label_column.to_sym) ? label_column.to_sym : label_column
+        texts << data[text_key]
+        labels << data[label_key]
       end
 
       raise DatasetError, "No data found in #{path}" if texts.empty?
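The loader now parses each line with symbolize_names: true and resolves the configured column names against symbol keys first, so files whose fields are not literally "text"/"label" still load by passing the column options. A short sketch; the file name, column names, and tokenizer path are illustrative.

```ruby
tokenizer = Fine::Tokenizers::AutoTokenizer.new("/path/to/model", max_length: 128)

dataset = Fine::Datasets::TextDataset.from_jsonl(
  "reviews.jsonl",
  tokenizer: tokenizer,
  text_column: "review",        # e.g. {"review": "...", "sentiment": "positive"}
  label_column: "sentiment"
)
```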
data/lib/fine/hub/config_loader.rb
CHANGED
@@ -20,19 +20,19 @@
 
     # Vision encoder configuration
     def hidden_size
-      vision_config["hidden_size"] || config["hidden_size"] || 768
+      vision_config["hidden_size"] || config["hidden_size"] || config["dim"] || 768
     end
 
     def num_hidden_layers
-      vision_config["num_hidden_layers"] || config["num_hidden_layers"] || 12
+      vision_config["num_hidden_layers"] || config["num_hidden_layers"] || config["n_layers"] || 12
     end
 
     def num_attention_heads
-      vision_config["num_attention_heads"] || config["num_attention_heads"] || 12
+      vision_config["num_attention_heads"] || config["num_attention_heads"] || config["n_heads"] || 12
     end
 
     def intermediate_size
-      vision_config["intermediate_size"] || config["intermediate_size"] || 3072
+      vision_config["intermediate_size"] || config["intermediate_size"] || config["hidden_dim"] || 3072
     end
 
     def image_size
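The added fallbacks (dim, n_layers, n_heads, hidden_dim) match the key names DistilBERT-style config.json files use instead of the BERT-style names, which appears to be the motivation here. A tiny sketch of the chain resolving against such a config; the hash contents are illustrative and the vision_config lookup is omitted for brevity.

```ruby
# DistilBERT-style config keys instead of BERT-style ones.
config = { "dim" => 768, "n_layers" => 6, "n_heads" => 12, "hidden_dim" => 3072 }

config["hidden_size"] || config["dim"] || 768             # => 768, via "dim"
config["num_hidden_layers"] || config["n_layers"] || 12   # => 6,   via "n_layers"
```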
data/lib/fine/hub/safetensors_loader.rb
CHANGED
@@ -10,8 +10,9 @@
     # @param path [String] Path to the safetensors file
     # @param strict [Boolean] If true, raise error on missing/unexpected keys
     # @param prefix [String] Prefix to add/remove from weight names
+    # @param skip_mapping [Boolean] If true, skip weight name mapping (for loading saved Fine models)
     # @return [Hash] Hash with :missing_keys and :unexpected_keys arrays
-    def self.load_into_model(model, path, strict: false, prefix: nil)
+    def self.load_into_model(model, path, strict: false, prefix: nil, skip_mapping: false)
       tensors = Safetensors::Torch.load_file(path)
 
       # Get model's state dict keys
@@ -22,7 +23,7 @@ module Fine
       unexpected_keys = []
 
       tensors.each do |name, tensor|
-        mapped_name = map_weight_name(name, prefix: prefix)
+        mapped_name = skip_mapping ? name : map_weight_name(name, prefix: prefix)
 
         if model_keys.include?(mapped_name)
           mapped_tensors[mapped_name] = tensor
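A sketch of when skip_mapping is useful, per the @param note above: weights saved by Fine already carry the model's own parameter names, so the name-mapping step can be skipped when loading them back. The loader class name is inferred from the file path (Fine::Hub::SafetensorsLoader) and the file paths are illustrative.

```ruby
model = Fine::Models::CausalLM.from_pretrained("google/gemma-3-1b-it")

# Load weights previously written by Fine (e.g. the LoRA-merged checkpoint above).
result = Fine::Hub::SafetensorsLoader.load_into_model(
  model,
  "/tmp/gemma3-lora-tools/model.safetensors",
  skip_mapping: true
)

puts "missing: #{result[:missing_keys].size}, unexpected: #{result[:unexpected_keys].size}"
```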
data/lib/fine/llm.rb
CHANGED
@@ -204,9 +204,9 @@
     def generate(prompt, max_new_tokens: 100, temperature: 0.7, top_p: 0.9, top_k: 50, do_sample: true)
       raise TrainingError, "Model not loaded" unless @model && @tokenizer
 
-      # Tokenize prompt
-
-      input_ids = Torch.tensor([
+      # Tokenize prompt without padding for autoregressive generation
+      ids = @tokenizer.encode_for_generation(prompt)
+      input_ids = Torch.tensor([ids])
 
       # Move to device
       input_ids = input_ids.to(Fine.device)
@@ -316,19 +316,48 @@ module Fine
   end
 
   # Configuration for LLM fine-tuning
+  #
+  # @example
+  #   llm = Fine::LLM.new("google/gemma-3-1b-it") do |config|
+  #     config.epochs = 3
+  #     config.max_length = 512
+  #     config.learning_rate = 1e-5
+  #   end
+  #
   class LLMConfiguration < Configuration
+    # LLM-specific defaults
+    DEFAULTS = Configuration::DEFAULTS.merge(
+      max_length: 2048,
+      learning_rate: 2e-5,
+      batch_size: 4,
+      epochs: 3,
+      warmup_steps: 100,
+      gradient_accumulation_steps: 4,
+      max_grad_norm: 1.0
+    ).freeze
+
+    # @!attribute max_length
+    #   @return [Integer] Maximum sequence length (default: 2048)
+    # @!attribute gradient_accumulation_steps
+    #   @return [Integer] Accumulate gradients over N steps (default: 4)
+    # @!attribute max_grad_norm
+    #   @return [Float] Gradient clipping norm (default: 1.0)
+    # @!attribute freeze_layers
+    #   @return [Integer] Number of bottom layers to freeze (default: 0)
+    # @!attribute pad_token_id
+    #   @return [Integer, nil] Padding token ID (auto-detected if nil)
     attr_accessor :max_length, :warmup_steps, :gradient_accumulation_steps,
                   :max_grad_norm, :freeze_layers, :pad_token_id
 
     def initialize
       super
-      @max_length =
-      @learning_rate =
-      @batch_size =
-      @epochs =
-      @warmup_steps =
-      @gradient_accumulation_steps =
-      @max_grad_norm =
+      @max_length = DEFAULTS[:max_length]
+      @learning_rate = DEFAULTS[:learning_rate]
+      @batch_size = DEFAULTS[:batch_size]
+      @epochs = DEFAULTS[:epochs]
+      @warmup_steps = DEFAULTS[:warmup_steps]
+      @gradient_accumulation_steps = DEFAULTS[:gradient_accumulation_steps]
+      @max_grad_norm = DEFAULTS[:max_grad_norm]
       @freeze_layers = 0
       @pad_token_id = nil
     end
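Putting the llm.rb changes together: generation now encodes the prompt with encode_for_generation (no padding), and LLMConfiguration picks its defaults from the merged DEFAULTS hash. A usage sketch built only from APIs visible in this diff; it assumes the model is already loaded at this point, since #generate raises TrainingError otherwise.

```ruby
llm = Fine::LLM.new("google/gemma-3-1b-it") do |config|
  config.epochs = 3
  config.max_length = 512
  config.learning_rate = 1e-5
  config.gradient_accumulation_steps = 4   # effective batch = 4 x batch_size
end

# Greedy-ish sampling with the updated prompt encoding.
puts llm.generate(
  "Write a haiku about Ruby.",
  max_new_tokens: 60,
  temperature: 0.7,
  do_sample: true
)
```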