trainers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +17 -0
- data/LICENSE.txt +21 -0
- data/README.md +293 -0
- data/lib/trainers/callbacks.rb +128 -0
- data/lib/trainers/data/data_collator.rb +131 -0
- data/lib/trainers/data/dataset.rb +26 -0
- data/lib/trainers/lora/lora_config.rb +34 -0
- data/lib/trainers/lora/lora_linear.rb +78 -0
- data/lib/trainers/lora/lora_model.rb +87 -0
- data/lib/trainers/lora/lora_utils.rb +73 -0
- data/lib/trainers/optimization/optimizer.rb +49 -0
- data/lib/trainers/optimization/scheduler.rb +67 -0
- data/lib/trainers/save_utils.rb +84 -0
- data/lib/trainers/trainer.rb +340 -0
- data/lib/trainers/trainer_utils.rb +30 -0
- data/lib/trainers/training_arguments.rb +64 -0
- data/lib/trainers/version.rb +5 -0
- data/lib/trainers-rb.rb +1 -0
- data/lib/trainers.rb +43 -0
- metadata +149 -0
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
  # Wraps a Torch::NN::Linear with low-rank A/B adapter matrices.
  #
  # Forward while unmerged: y = Wx + b + (x @ A^T @ B^T) * scaling
  # Forward after #merge!:  y = W'x + b, with W' = W + (B @ A) * scaling
  #
  # The wrapped layer's weight and bias tensors are reused (not copied) and
  # have requires_grad switched off here; lora_A and lora_B are the only
  # parameters this class creates.
  class LoraLinear < Torch::NN::Module
    attr_reader :in_features, :out_features, :r, :scaling

    # @param original_linear [Torch::NN::Linear] layer to wrap; must expose
    #   +weight+ and +bias+ (bias may be nil)
    # @param r [Integer] adapter rank, must be positive
    # @param lora_alpha [Numeric] scaling numerator (scaling = lora_alpha / r)
    # @param lora_dropout [Numeric] dropout probability applied to the input of
    #   the adapter path only; values <= 0 disable dropout
    # @raise [ArgumentError] if r is not a positive Integer
    def initialize(original_linear, r:, lora_alpha:, lora_dropout: 0.0)
      super()

      # A zero or negative rank would produce an Infinity/NaN scaling factor.
      unless r.is_a?(Integer) && r.positive?
        raise ArgumentError, "r must be a positive Integer, got #{r.inspect}"
      end

      # Prefer the dimensions recorded on the layer; fall back to the weight shape.
      @in_features = original_linear.instance_variable_get(:@in_features) ||
                     original_linear.weight.size(1)
      @out_features = original_linear.instance_variable_get(:@out_features) ||
                      original_linear.weight.size(0)
      @r = r
      @scaling = lora_alpha.to_f / r
      @merged = false

      # Freeze original parameters
      @weight = original_linear.weight
      @weight.requires_grad = false

      @bias = original_linear.bias
      @bias.requires_grad = false if @bias

      # LoRA low-rank matrices
      # A: (r, in_features) - initialized with Kaiming uniform
      # B: (out_features, r) - initialized to zero, so B @ A is zero and the
      #    adapter path contributes nothing until B is updated
      @lora_A = Torch::NN::Parameter.new(Torch.empty(@r, @in_features))
      Torch::NN::Init.kaiming_uniform!(@lora_A, a: Math.sqrt(5))

      @lora_B = Torch::NN::Parameter.new(Torch.zeros(@out_features, @r))

      @lora_dropout = lora_dropout > 0 ? Torch::NN::Dropout.new(p: lora_dropout) : nil
    end

    # @return [Boolean] true once #merge! has folded the adapter into the weight
    def merged?
      @merged
    end

    # Computes the base linear output plus, while unmerged, the scaled adapter output.
    def forward(x)
      base_output = Torch::NN::F.linear(x, @weight, @bias)

      # After merge! the adapter delta already lives inside @weight; adding
      # the LoRA path again would apply it twice.
      return base_output if @merged

      lora_input = @lora_dropout ? @lora_dropout.call(x) : x
      lora_output = lora_input.matmul(@lora_A.t).matmul(@lora_B.t) * @scaling

      base_output + lora_output
    end

    # Merge LoRA weights into the base weight matrix (for inference).
    # Idempotent: a second call does not add the delta again.
    #
    # @return [self]
    def merge!
      return self if @merged

      Torch.no_grad do
        @weight.add!(@lora_B.matmul(@lora_A) * @scaling)
      end
      @merged = true
      self
    end

    # Extract LoRA adapter state for saving.
    #
    # @return [Hash] "lora_A" and "lora_B" mapped to the parameters' data
    def lora_state_dict
      { "lora_A" => @lora_A.data, "lora_B" => @lora_B.data }
    end

    # Copy saved adapter tensors into lora_A / lora_B without tracking gradients.
    # NOTE(review): if called after #merge!, the base weight still contains the
    # previously merged delta - confirm callers load before merging.
    def load_lora_weights(lora_a_tensor, lora_b_tensor)
      Torch.no_grad do
        @lora_A.copy!(lora_a_tensor)
        @lora_B.copy!(lora_b_tensor)
      end
    end

    # Summary string with the layer dimensions, rank and scaling.
    def extra_repr
      "in_features=#{@in_features}, out_features=#{@out_features}, " \
        "r=#{@r}, scaling=#{@scaling}"
    end
  end
end
|
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
  # Class-level entry points for attaching LoRA adapters to a model and for
  # merging, saving and loading them.
  class LoraModel
    # Apply LoRA adapters to a model.
    #
    # 1. Finds all Linear layers matching config.target_modules
    # 2. Freezes all base model parameters
    # 3. Replaces each target with a LoraLinear wrapping the original layer
    # 4. Re-enables bias gradients according to config.bias
    #
    # Targets are resolved before anything is mutated, so a call that raises
    # because nothing matched leaves the model's requires_grad flags untouched.
    #
    # @param model the model to modify in place
    # @param config [LoraConfig, Hash] a Hash is expanded into LoraConfig.new(**config)
    # @return the same model object
    # @raise [RuntimeError] if no module matches config.target_modules
    def self.apply(model, config)
      config = LoraConfig.new(**config) if config.is_a?(Hash)

      targets = LoraUtils.find_target_modules(model, config.target_modules)

      if targets.empty?
        raise "No Linear modules found matching target_modules: #{config.target_modules.inspect}. " \
              "Available modules: #{model.named_modules.map(&:first).join(', ')}"
      end

      # Freeze all base parameters
      model.parameters.each { |param| param.requires_grad = false }

      targets.each do |name, linear|
        adapter = LoraLinear.new(
          linear,
          r: config.r,
          lora_alpha: config.lora_alpha,
          lora_dropout: config.lora_dropout
        )

        LoraUtils.replace_module(model, name, adapter)
      end

      unfreeze_biases(model, config.bias)

      puts "LoRA applied to #{targets.size} modules: #{targets.keys.join(', ')}"
      LoraUtils.print_trainable_parameters(model)

      model
    end

    # Merge all LoRA weights back into base weights (for inference).
    #
    # @return the same model object
    def self.merge(model)
      count = 0
      model.named_modules.each do |_, mod|
        next unless mod.is_a?(LoraLinear)

        mod.merge!
        count += 1
      end
      puts "Merged #{count} LoRA adapters into base model"
      model
    end

    # Save only the LoRA adapter weights and, when a config is given, a
    # lora_config.json file next to them.
    def self.save_adapters(model, output_dir, config: nil)
      SaveUtils.save_lora_adapters(model, output_dir)

      if config
        # Required here so this method does not depend on another file
        # having loaded JSON first.
        require "json"
        config_path = File.join(output_dir, "lora_config.json")
        File.write(config_path, JSON.pretty_generate(config.to_h))
      end
    end

    # Load LoRA adapter weights into a model that already has LoRA applied.
    def self.load_adapters(model, input_dir)
      SaveUtils.load_lora_adapters(model, input_dir)
    end

    # Re-enables gradients on bias parameters according to the bias mode.
    # :all       - every parameter whose name contains "bias"
    # :lora_only - the bias held by each LoraLinear, when it has one
    # any other value (including :none) - biases stay frozen
    def self.unfreeze_biases(model, mode)
      case mode
      when :all
        model.named_parameters.each do |name, param|
          param.requires_grad = true if name.include?("bias")
        end
      when :lora_only
        model.named_modules.each do |_, mod|
          next unless mod.is_a?(LoraLinear)

          bias = mod.instance_variable_get(:@bias)
          bias.requires_grad = true if bias
        end
      end
    end
    private_class_method :unfreeze_biases
  end
end
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
  # Helpers for locating and swapping sub-modules and for reporting
  # parameter counts.
  module LoraUtils
    # Path segments made only of digits address an entry in a module registry.
    NUMERIC_PART = /\A\d+\z/.freeze
    private_constant :NUMERIC_PART

    # Find all Linear modules in a model matching the target specification.
    #
    # @param model object responding to +named_modules+ (name => module pairs)
    # @param target_modules [Symbol, Array<String, Symbol>] :all_linear selects
    #   every Linear; an Array selects Linears whose name contains any entry
    #   as a substring. Array entries are converted with +to_s+, so Symbols
    #   are accepted. Any other value matches nothing.
    # @return [Hash] module name => module
    def self.find_target_modules(model, target_modules)
      match_all = target_modules == :all_linear
      patterns = target_modules.is_a?(Array) ? target_modules.map(&:to_s) : []

      model.named_modules.each_with_object({}) do |(name, mod), targets|
        next unless mod.is_a?(Torch::NN::Linear)

        targets[name] = mod if match_all || patterns.any? { |pattern| name.include?(pattern) }
      end
    end

    # Replace the sub-module addressed by a dotted path (e.g. "layers.0.fc").
    #
    # Each path segment is resolved first as an instance variable on the
    # parent and then as an entry in the parent's @modules registry. The
    # registry may be a Hash (String keys such as "0" are tried before
    # Integer keys) or an Array.
    #
    # @raise [RuntimeError] if a segment cannot be resolved or the path is empty
    def self.replace_module(model, target_name, new_module)
      *ancestors, child_name = target_name.split(".")
      raise "Cannot replace the root module in place (empty module name)" unless child_name

      parent = ancestors.reduce(model) do |current, part|
        child_of(current, part) ||
          raise("Could not find module part '#{part}' in #{target_name}")
      end

      assign_child(parent, child_name, new_module, target_name)
    end

    # Count total and trainable parameters, print a one-line summary and
    # return the numbers.
    #
    # @return [Hash] :total, :trainable and :percentage (0.0 for an empty model)
    def self.print_trainable_parameters(model)
      total = 0
      trainable = 0

      model.parameters.each do |param|
        size = param.numel
        total += size
        trainable += size if param.requires_grad
      end

      pct = total > 0 ? (trainable.to_f / total * 100) : 0.0
      puts "trainable params: #{format_number(trainable)} || " \
           "all params: #{format_number(total)} || " \
           "trainable%: #{format('%.4f', pct)}%"

      { total: total, trainable: trainable, percentage: pct }
    end

    # Insert thousands separators: 1234567 -> "1,234,567".
    def self.format_number(n)
      n.to_s.reverse.gsub(/(\d{3})(?=\d)/, '\\1,').reverse
    end

    # Resolve one path segment below +parent+; returns nil when not found.
    def self.child_of(parent, part)
      numeric = part.match?(NUMERIC_PART)

      # "@0" is not a legal instance variable name, so numeric segments go
      # straight to the registry.
      unless numeric
        direct = parent.instance_variable_get(:"@#{part}")
        return direct if direct
      end

      registry = parent.instance_variable_get(:@modules)
      case registry
      when Hash
        registry.fetch(part) { numeric ? registry[part.to_i] : nil }
      when Array
        registry[part.to_i] if numeric
      end
    end
    private_class_method :child_of

    # Store +new_module+ under +child_name+ in whichever slot currently holds it.
    def self.assign_child(parent, child_name, new_module, target_name)
      registry = parent.instance_variable_get(:@modules)

      if child_name.match?(NUMERIC_PART)
        index = child_name.to_i
        if registry.is_a?(Hash) && registry.key?(child_name)
          registry[child_name] = new_module
        elsif registry.is_a?(Hash) && registry.key?(index)
          registry[index] = new_module
        elsif registry.is_a?(Array)
          registry[index] = new_module
        else
          raise "Could not find module part '#{child_name}' in #{target_name}"
        end
      elsif !parent.instance_variable_defined?(:"@#{child_name}") &&
            registry.is_a?(Hash) && registry.key?(child_name)
        # Child exists only in the registry, not as an instance variable.
        registry[child_name] = new_module
      else
        parent.instance_variable_set(:"@#{child_name}", new_module)
      end
    end
    private_class_method :assign_child
  end
end
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
  module Optimization
    # Builds an AdamW optimizer over the model's trainable parameters,
    # split into up to two parameter groups:
    #   - parameters that receive args.weight_decay
    #   - parameters whose name contains "bias", "LayerNorm", "layer_norm" or
    #     "layernorm", which get weight_decay 0.0
    #
    # Parameters with requires_grad off are left out; an empty group is not
    # passed to the optimizer.
    #
    # @param model object responding to +named_parameters+ (name => parameter)
    # @param args object providing weight_decay, learning_rate, adam_beta1,
    #   adam_beta2 and adam_epsilon
    # @raise [RuntimeError] if the model has no trainable parameters
    def self.create_optimizer(model, args)
      exempt_markers = ["bias", "LayerNorm", "layer_norm", "layernorm"]

      trainable = model.named_parameters.select { |_name, param| param.requires_grad }
      exempt, regularized = trainable.partition do |name, _param|
        exempt_markers.any? { |marker| name.include?(marker) }
      end

      param_groups = []
      unless regularized.empty?
        param_groups << { params: regularized.map(&:last), weight_decay: args.weight_decay }
      end
      unless exempt.empty?
        param_groups << { params: exempt.map(&:last), weight_decay: 0.0 }
      end

      if param_groups.empty?
        raise "No trainable parameters found. Did you forget to unfreeze the model or apply LoRA?"
      end

      Torch::Optim::AdamW.new(
        param_groups,
        lr: args.learning_rate,
        betas: [args.adam_beta1, args.adam_beta2],
        eps: args.adam_epsilon
      )
    end
  end
end
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
  module Optimization
    # Linear warmup then linear decay to 0.
    #
    # The multiplier rises from 0 to 1 over num_warmup_steps, then falls
    # linearly and is clamped at 0 from num_training_steps onward.
    #
    # @return [Torch::Optim::LRScheduler::LambdaLR]
    def self.get_linear_schedule_with_warmup(optimizer, num_warmup_steps:, num_training_steps:)
      # Guard against a zero or negative decay span.
      decay_span = [1, num_training_steps - num_warmup_steps].max

      lr_lambda = lambda do |current_step|
        if current_step < num_warmup_steps
          warmup_factor(current_step, num_warmup_steps)
        else
          [0.0, (num_training_steps - current_step).to_f / decay_span].max
        end
      end

      Torch::Optim::LRScheduler::LambdaLR.new(optimizer, lr_lambda)
    end

    # Linear warmup then cosine decay.
    #
    # With the default num_cycles of 0.5 the multiplier follows half a cosine
    # period, from 1 down to 0 at num_training_steps.
    #
    # @return [Torch::Optim::LRScheduler::LambdaLR]
    def self.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps:, num_training_steps:, num_cycles: 0.5)
      decay_span = [1, num_training_steps - num_warmup_steps].max

      lr_lambda = lambda do |current_step|
        if current_step < num_warmup_steps
          warmup_factor(current_step, num_warmup_steps)
        else
          progress = (current_step - num_warmup_steps).to_f / decay_span
          [0.0, 0.5 * (1.0 + Math.cos(Math::PI * num_cycles * 2.0 * progress))].max
        end
      end

      Torch::Optim::LRScheduler::LambdaLR.new(optimizer, lr_lambda)
    end

    # Linear warmup then constant LR (multiplier 1.0).
    #
    # @return [Torch::Optim::LRScheduler::LambdaLR]
    def self.get_constant_schedule_with_warmup(optimizer, num_warmup_steps:)
      lr_lambda = lambda do |current_step|
        current_step < num_warmup_steps ? warmup_factor(current_step, num_warmup_steps) : 1.0
      end

      Torch::Optim::LRScheduler::LambdaLR.new(optimizer, lr_lambda)
    end

    # Dispatcher: pick scheduler by type.
    #
    # @param type [Symbol, String] :linear, :cosine or :constant; Strings
    #   (e.g. values read back from JSON) are accepted and converted
    # @raise [ArgumentError] for any other type
    def self.create_scheduler(type, optimizer, num_warmup_steps:, num_training_steps:)
      kind = type.respond_to?(:to_sym) ? type.to_sym : type

      case kind
      when :linear
        get_linear_schedule_with_warmup(optimizer,
                                        num_warmup_steps: num_warmup_steps,
                                        num_training_steps: num_training_steps)
      when :cosine
        get_cosine_schedule_with_warmup(optimizer,
                                        num_warmup_steps: num_warmup_steps,
                                        num_training_steps: num_training_steps)
      when :constant
        # The constant schedule has no decay phase, so num_training_steps is unused.
        get_constant_schedule_with_warmup(optimizer,
                                          num_warmup_steps: num_warmup_steps)
      else
        raise ArgumentError, "Unknown scheduler type: #{type}. Use :linear, :cosine, or :constant"
      end
    end

    # Multiplier during warmup: current_step / num_warmup_steps, with the
    # divisor floored at 1.
    def self.warmup_factor(current_step, num_warmup_steps)
      current_step.to_f / [1, num_warmup_steps].max
    end
    private_class_method :warmup_factor
  end
end
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "fileutils"
|
|
5
|
+
|
|
6
|
+
module Trainers
  # Persistence helpers: full-model export and LoRA adapter save/load, all
  # using the safetensors format. The safetensors library is required
  # lazily inside each method that needs it.
  module SaveUtils
    # Write the model weights (model.safetensors), the tokenizer
    # (tokenizer.json, when the tokenizer responds to +save+) and the
    # training arguments (training_args.json, when given) into output_dir.
    #
    # @return [String] output_dir
    def self.save_pretrained(model, tokenizer, output_dir, training_args: nil)
      FileUtils.mkdir_p(output_dir)

      # Unwrap Parameters to their underlying data; other entries pass through.
      tensors = {}
      model.state_dict.each do |name, value|
        tensors[name] = value.is_a?(Torch::NN::Parameter) ? value.data : value
      end

      model_path = File.join(output_dir, "model.safetensors")
      require "safetensors"
      Safetensors::Torch.save_file(tensors, model_path)
      puts "Model saved to #{model_path}"

      # Save tokenizer
      if tokenizer && tokenizer.respond_to?(:save)
        tokenizer.save(File.join(output_dir, "tokenizer.json"))
      end

      # Save training args
      if training_args
        args_path = File.join(output_dir, "training_args.json")
        File.write(args_path, JSON.pretty_generate(training_args.to_h))
      end

      output_dir
    end

    # Write the lora_A / lora_B tensors of every LoraLinear in the model to
    # adapter_model.safetensors, keyed as "<module name>.lora_A" / ".lora_B".
    #
    # @return [String, nil] output_dir, or nil when the model holds no
    #   LoraLinear modules (nothing is written in that case)
    def self.save_lora_adapters(model, output_dir)
      FileUtils.mkdir_p(output_dir)

      lora_state = {}
      model.named_modules.each do |name, mod|
        next unless mod.is_a?(Trainers::LoraLinear)

        mod.lora_state_dict.each do |param_name, tensor|
          lora_state["#{name}.#{param_name}"] = tensor
        end
      end

      if lora_state.empty?
        puts "No LoRA adapters found in model"
        return
      end

      adapter_path = File.join(output_dir, "adapter_model.safetensors")
      require "safetensors"
      Safetensors::Torch.save_file(lora_state, adapter_path)
      puts "LoRA adapters saved to #{adapter_path} (#{lora_state.size} tensors)"

      output_dir
    end

    # Load adapter tensors from input_dir/adapter_model.safetensors into the
    # model's LoraLinear modules.
    #
    # A LoraLinear whose "<name>.lora_A" or "<name>.lora_B" entry is missing
    # is left unchanged; its name is reported on stderr so a partial load
    # does not go unnoticed.
    #
    # @raise [RuntimeError] if the adapter file does not exist
    def self.load_lora_adapters(model, input_dir)
      adapter_path = File.join(input_dir, "adapter_model.safetensors")
      raise "Adapter file not found: #{adapter_path}" unless File.exist?(adapter_path)

      require "safetensors"
      lora_state = Safetensors::Torch.load_file(adapter_path)

      skipped = []
      model.named_modules.each do |name, mod|
        next unless mod.is_a?(Trainers::LoraLinear)

        lora_a = lora_state["#{name}.lora_A"]
        lora_b = lora_state["#{name}.lora_B"]
        if lora_a && lora_b
          mod.load_lora_weights(lora_a, lora_b)
        else
          skipped << name
        end
      end

      unless skipped.empty?
        warn "LoRA adapter weights missing for #{skipped.size} module(s): #{skipped.join(', ')}"
      end

      puts "LoRA adapters loaded from #{adapter_path}"
    end
  end
end
|