trainers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,78 @@
# frozen_string_literal: true

module Trainers
  # Wraps a frozen Torch::NN::Linear with low-rank A/B adapter matrices.
  #
  # Forward (unmerged): y = Wx + b + (x @ A^T @ B^T) * scaling
  # Forward (merged):   y = W'x + b, where W' already contains B @ A * scaling
  #
  # Only lora_A and lora_B are created as trainable parameters here; the
  # wrapped layer's weight and bias have requires_grad switched off.
  class LoraLinear < Torch::NN::Module
    attr_reader :in_features, :out_features, :r, :scaling

    # @param original_linear [Torch::NN::Linear] layer whose weight/bias are reused (and frozen)
    # @param r [Integer] adapter rank; must be a positive integer
    # @param lora_alpha [Numeric] numerator of the scaling factor (scaling = lora_alpha / r)
    # @param lora_dropout [Float] dropout probability applied to the adapter input; 0 disables it
    # @raise [ArgumentError] when r is not a positive integer
    def initialize(original_linear, r:, lora_alpha:, lora_dropout: 0.0)
      super()

      # r == 0 would silently produce an infinite scaling factor.
      unless r.is_a?(Integer) && r.positive?
        raise ArgumentError, "r must be a positive integer, got #{r.inspect}"
      end

      # Prefer the layer's recorded dimensions; fall back to the weight shape.
      @in_features = original_linear.instance_variable_get(:@in_features) ||
                     original_linear.weight.size(1)
      @out_features = original_linear.instance_variable_get(:@out_features) ||
                      original_linear.weight.size(0)
      @r = r
      @scaling = lora_alpha.to_f / r
      @merged = false

      # Freeze original parameters
      @weight = original_linear.weight
      @weight.requires_grad = false

      @bias = original_linear.bias
      @bias.requires_grad = false if @bias

      # LoRA low-rank matrices (the only parameters created trainable here)
      #   A: (r, in_features)  — Kaiming-uniform initialised
      #   B: (out_features, r) — zeros, so the adapter initially contributes nothing
      @lora_A = Torch::NN::Parameter.new(Torch.empty(@r, @in_features))
      Torch::NN::Init.kaiming_uniform!(@lora_A, a: Math.sqrt(5))

      @lora_B = Torch::NN::Parameter.new(Torch.zeros(@out_features, @r))

      @lora_dropout = lora_dropout > 0 ? Torch::NN::Dropout.new(p: lora_dropout) : nil
    end

    # @return [Boolean] true once #merge! has folded the adapter into the base weight
    def merged?
      @merged
    end

    # Applies the base linear map and, unless the adapter has been merged,
    # adds the scaled low-rank update.
    def forward(x)
      base_output = Torch::NN::F.linear(x, @weight, @bias)

      # After merge! the base weight already contains the adapter delta;
      # adding the LoRA path again would apply the update twice.
      return base_output if @merged

      lora_input = @lora_dropout ? @lora_dropout.call(x) : x
      lora_output = lora_input.matmul(@lora_A.t).matmul(@lora_B.t) * @scaling

      base_output + lora_output
    end

    # Merge LoRA weights into the base weight matrix (for inference).
    # Idempotent: a second call leaves the weight untouched.
    #
    # @return [self]
    def merge!
      return self if @merged

      Torch.no_grad do
        @weight.add!(@lora_B.matmul(@lora_A) * @scaling)
      end
      @merged = true
      self
    end

    # Extract LoRA adapter state for saving
    #
    # @return [Hash{String => Torch::Tensor}] "lora_A" and "lora_B" tensors
    def lora_state_dict
      { "lora_A" => @lora_A.data, "lora_B" => @lora_B.data }
    end

    # Load LoRA weights from saved state by copying them into the existing
    # adapter parameters.
    #
    # NOTE(review): if the adapter was already merged, the previously merged
    # delta stays in the base weight and the newly loaded matrices are not
    # used by #forward — load before merging.
    def load_lora_weights(lora_a_tensor, lora_b_tensor)
      Torch.no_grad do
        @lora_A.copy!(lora_a_tensor)
        @lora_B.copy!(lora_b_tensor)
      end
    end

    # Summary string of the layer's dimensions, rank and scaling.
    def extra_repr
      "in_features=#{@in_features}, out_features=#{@out_features}, " \
        "r=#{@r}, scaling=#{@scaling}"
    end
  end
end
@@ -0,0 +1,87 @@
# frozen_string_literal: true

module Trainers
  # Class-level entry points for attaching LoRA adapters to a model and for
  # merging, saving and loading them afterwards.
  class LoraModel
    class << self
      # Apply LoRA adapters to a model.
      #
      # Steps, in order:
      #   1. every parameter of the model is frozen
      #   2. Linear layers selected by config.target_modules are located
      #   3. each one is swapped for a LoraLinear wrapping it
      #   4. biases are re-enabled according to config.bias
      #
      # A Hash config is converted with LoraConfig.new(**config).
      # Raises when no module matches. Returns the same model object,
      # modified in place.
      def apply(model, config)
        config = LoraConfig.new(**config) if config.is_a?(Hash)

        # Everything starts frozen; adapters add the trainable parameters.
        model.parameters.each { |param| param.requires_grad = false }

        targets = LoraUtils.find_target_modules(model, config.target_modules)

        if targets.empty?
          message = "No Linear modules found matching target_modules: #{config.target_modules.inspect}. " \
                    "Available modules: #{model.named_modules.map(&:first).join(', ')}"
          raise message
        end

        targets.each do |name, linear|
          adapter = LoraLinear.new(
            linear,
            r: config.r,
            lora_alpha: config.lora_alpha,
            lora_dropout: config.lora_dropout
          )
          LoraUtils.replace_module(model, name, adapter)
        end

        # Any other value (e.g. :none) leaves biases frozen.
        case config.bias
        when :all then unfreeze_all_biases(model)
        when :lora_only then unfreeze_adapter_biases(model)
        end

        puts "LoRA applied to #{targets.size} modules: #{targets.keys.join(', ')}"
        LoraUtils.print_trainable_parameters(model)

        model
      end

      # Merge all LoRA weights back into base weights (for inference).
      # Prints how many adapters were merged and returns the model.
      def merge(model)
        adapters = model.named_modules.map { |_, mod| mod }.grep(LoraLinear)
        adapters.each(&:merge!)

        puts "Merged #{adapters.size} LoRA adapters into base model"
        model
      end

      # Save only the LoRA adapter weights; when a config is given, also
      # write it as lora_config.json inside output_dir.
      def save_adapters(model, output_dir, config: nil)
        SaveUtils.save_lora_adapters(model, output_dir)
        return unless config

        File.write(
          File.join(output_dir, "lora_config.json"),
          JSON.pretty_generate(config.to_h)
        )
      end

      # Load LoRA adapter weights into a model that already has LoRA applied
      def load_adapters(model, input_dir)
        SaveUtils.load_lora_adapters(model, input_dir)
      end

      private

      # Re-enables gradients for every parameter whose name contains "bias".
      def unfreeze_all_biases(model)
        model.named_parameters.each do |name, param|
          next unless name.include?("bias")

          param.requires_grad = true
        end
      end

      # Re-enables gradients only for the bias held by each LoraLinear.
      def unfreeze_adapter_biases(model)
        model.named_modules.each do |_, mod|
          next unless mod.is_a?(LoraLinear)

          bias = mod.instance_variable_get(:@bias)
          bias.requires_grad = true if bias
        end
      end
    end
  end
end
@@ -0,0 +1,73 @@
# frozen_string_literal: true

module Trainers
  # Helpers for locating Linear layers, swapping child modules in place and
  # reporting trainable-parameter counts.
  module LoraUtils
    # A dotted-path segment made only of digits (an index into a container).
    NUMERIC_PART = /\A\d+\z/
    private_constant :NUMERIC_PART

    # Find all Linear modules in a model matching the target pattern.
    #
    # @param model [#named_modules] yields (name, module) pairs
    # @param target_modules [Symbol, String, Array<String, Symbol>]
    #   :all_linear / "all_linear" selects every Linear; otherwise a module is
    #   selected when its name contains any of the given substrings. Any other
    #   value selects nothing.
    # @return [Hash{String => Object}] matching modules keyed by name
    def self.find_target_modules(model, target_modules)
      match_all = [:all_linear, "all_linear"].include?(target_modules)
      patterns =
        case target_modules
        when Array then target_modules.map(&:to_s)
        when String, Symbol then [target_modules.to_s]
        else []
        end

      model.named_modules.each_with_object({}) do |(name, mod), targets|
        next unless mod.is_a?(Torch::NN::Linear)

        targets[name] = mod if match_all || patterns.any? { |pattern| name.include?(pattern) }
      end
    end

    # Replace the child module addressed by a dotted path (e.g.
    # "encoder.layers.0.query") with new_module.
    #
    # Each path segment is resolved against the parent's @modules container
    # first (a Hash keyed by String or Integer, or an Array indexed by
    # position) and then against the instance variable of the same name.
    # NOTE(review): torch-rb is understood to keep add_module children in a
    # Hash keyed by String names — confirm against the installed version.
    #
    # @raise [ArgumentError] when target_name is empty (the root cannot be replaced)
    # @raise [RuntimeError] when a path segment cannot be resolved
    # @return [Object] new_module
    def self.replace_module(model, target_name, new_module)
      *ancestors, child_name = target_name.split(".")
      raise ArgumentError, "Cannot replace the root module (empty target name)" unless child_name

      parent = ancestors.reduce(model) do |current, part|
        lookup_child(current, part) ||
          raise("Could not find module part '#{part}' in #{target_name}")
      end

      assign_child(parent, child_name, new_module, target_name)
    end

    # Count total and trainable parameters, print a one-line summary and
    # return the figures.
    #
    # @return [Hash] { total:, trainable:, percentage: }
    def self.print_trainable_parameters(model)
      params = model.parameters
      total = params.sum(&:numel)
      trainable = params.select(&:requires_grad).sum(&:numel)

      pct = total.positive? ? (trainable.to_f / total * 100) : 0.0
      puts "trainable params: #{format_number(trainable)} || " \
           "all params: #{format_number(total)} || " \
           "trainable%: #{format('%.4f', pct)}%"

      { total: total, trainable: trainable, percentage: pct }
    end

    # Insert thousands separators: 1234567 -> "1,234,567".
    def self.format_number(n)
      n.to_s.reverse.gsub(/(\d{3})(?=\d)/, '\\1,').reverse
    end

    # Resolve one path segment to a child of parent, or nil when absent.
    def self.lookup_child(parent, name)
      container = parent.instance_variable_get(:@modules)

      case container
      when Hash
        key = container_key(container, name)
        return container[key] if key
      when Array
        return container[name.to_i] if NUMERIC_PART.match?(name)
      end

      # "@0" is not a legal instance variable name, so digits stop here.
      return nil if NUMERIC_PART.match?(name)

      parent.instance_variable_get(:"@#{name}")
    end

    # Store new_module as the child `name` of parent, in whichever place the
    # existing child lives.
    def self.assign_child(parent, name, new_module, target_name)
      container = parent.instance_variable_get(:@modules)
      numeric = NUMERIC_PART.match?(name)

      case container
      when Hash
        key = container_key(container, name)
        return container[key] = new_module if key
      when Array
        if numeric && name.to_i < container.length
          return container[name.to_i] = new_module
        end
      end

      raise "Could not find module part '#{name}' in #{target_name}" if numeric

      parent.instance_variable_set(:"@#{name}", new_module)
    end

    # The key under which `name` is stored in a Hash container: the String
    # itself, or its Integer value for digit-only names; nil when neither exists.
    def self.container_key(container, name)
      return name if container.key?(name)
      return nil unless NUMERIC_PART.match?(name)

      index = name.to_i
      container.key?(index) ? index : nil
    end

    private_class_method :lookup_child, :assign_child, :container_key
  end
end
@@ -0,0 +1,49 @@
# frozen_string_literal: true

module Trainers
  module Optimization
    # Substrings of parameter names that are excluded from weight decay.
    NO_DECAY_PATTERNS = %w[bias LayerNorm layer_norm layernorm].freeze
    private_constant :NO_DECAY_PATTERNS

    # Builds AdamW with up to two param groups:
    #   - parameters that receive args.weight_decay
    #   - parameters with weight decay 0.0 (names containing "bias",
    #     "LayerNorm", "layer_norm" or "layernorm")
    #
    # Only parameters with requires_grad set are included. A parameter object
    # reachable under several names (shared/tied weights) is passed to the
    # optimizer once, under the first name encountered.
    #
    # @param model [#named_parameters] yields (name, parameter) pairs
    # @param args [#weight_decay, #learning_rate, #adam_beta1, #adam_beta2, #adam_epsilon]
    # @raise [RuntimeError] when the model has no trainable parameters
    # @return [Torch::Optim::AdamW]
    def self.create_optimizer(model, args)
      trainable = model.named_parameters.to_a
                       .select { |_name, param| param.requires_grad }
                       .uniq { |_name, param| param.object_id }

      no_decay, decay = trainable.partition do |name, _param|
        NO_DECAY_PATTERNS.any? { |pattern| name.include?(pattern) }
      end

      param_groups = []
      param_groups << { params: decay.map(&:last), weight_decay: args.weight_decay } unless decay.empty?
      param_groups << { params: no_decay.map(&:last), weight_decay: 0.0 } unless no_decay.empty?

      if param_groups.empty?
        raise "No trainable parameters found. Did you forget to unfreeze the model or apply LoRA?"
      end

      Torch::Optim::AdamW.new(
        param_groups,
        lr: args.learning_rate,
        betas: [args.adam_beta1, args.adam_beta2],
        eps: args.adam_epsilon
      )
    end
  end
end
@@ -0,0 +1,67 @@
# frozen_string_literal: true

module Trainers
  module Optimization
    # Linear warmup then linear decay to 0.
    #
    # The multiplier rises from 0 to 1 over num_warmup_steps, then falls
    # linearly to 0 at num_training_steps and stays at 0 afterwards.
    #
    # @return [Torch::Optim::LRScheduler::LambdaLR]
    def self.get_linear_schedule_with_warmup(optimizer, num_warmup_steps:, num_training_steps:)
      lr_lambda = lambda do |current_step|
        next warmup_factor(current_step, num_warmup_steps) if current_step < num_warmup_steps

        remaining = num_training_steps - current_step
        total_decay = num_training_steps - num_warmup_steps
        # max(1, ...) guards the division when warmup covers the whole run.
        [0.0, remaining.to_f / [1.0, total_decay].max].max
      end

      Torch::Optim::LRScheduler::LambdaLR.new(optimizer, lr_lambda)
    end

    # Linear warmup then cosine decay.
    #
    # With the default num_cycles of 0.5 the multiplier follows half a cosine
    # wave from 1 down to 0 between the end of warmup and num_training_steps.
    #
    # @return [Torch::Optim::LRScheduler::LambdaLR]
    def self.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps:, num_training_steps:, num_cycles: 0.5)
      lr_lambda = lambda do |current_step|
        next warmup_factor(current_step, num_warmup_steps) if current_step < num_warmup_steps

        progress = (current_step - num_warmup_steps).to_f /
                   [1, num_training_steps - num_warmup_steps].max
        [0.0, 0.5 * (1.0 + Math.cos(Math::PI * num_cycles * 2.0 * progress))].max
      end

      Torch::Optim::LRScheduler::LambdaLR.new(optimizer, lr_lambda)
    end

    # Linear warmup then constant LR (multiplier 1.0).
    #
    # @return [Torch::Optim::LRScheduler::LambdaLR]
    def self.get_constant_schedule_with_warmup(optimizer, num_warmup_steps:)
      lr_lambda = lambda do |current_step|
        current_step < num_warmup_steps ? warmup_factor(current_step, num_warmup_steps) : 1.0
      end

      Torch::Optim::LRScheduler::LambdaLR.new(optimizer, lr_lambda)
    end

    # Dispatcher: pick scheduler by type (:linear, :cosine or :constant;
    # the equivalent Strings are accepted too).
    #
    # num_training_steps is ignored for :constant.
    #
    # @raise [ArgumentError] for any other type
    def self.create_scheduler(type, optimizer, num_warmup_steps:, num_training_steps:)
      # Accept "linear" as well as :linear (e.g. when read from a config file).
      kind = type.is_a?(String) ? type.to_sym : type

      case kind
      when :linear
        get_linear_schedule_with_warmup(optimizer,
                                        num_warmup_steps: num_warmup_steps,
                                        num_training_steps: num_training_steps)
      when :cosine
        get_cosine_schedule_with_warmup(optimizer,
                                        num_warmup_steps: num_warmup_steps,
                                        num_training_steps: num_training_steps)
      when :constant
        get_constant_schedule_with_warmup(optimizer,
                                          num_warmup_steps: num_warmup_steps)
      else
        raise ArgumentError, "Unknown scheduler type: #{type}. Use :linear, :cosine, or :constant"
      end
    end

    # Fraction of the base LR during warmup: current_step / num_warmup_steps,
    # with the divisor clamped to at least 1.
    def self.warmup_factor(current_step, num_warmup_steps)
      current_step.to_f / [1, num_warmup_steps].max
    end
    private_class_method :warmup_factor
  end
end
@@ -0,0 +1,84 @@
# frozen_string_literal: true

require "json"
require "fileutils"

module Trainers
  # Persistence helpers: full-model checkpoints and LoRA-only adapter files,
  # both written in safetensors format. The safetensors library is required
  # lazily, only when a tensor file is actually written or read.
  module SaveUtils
    # Save a model (and optionally tokenizer and training args) to output_dir.
    #
    # Writes model.safetensors; tokenizer.json when the tokenizer responds to
    # #save; training_args.json when training_args is given.
    #
    # @return [String] output_dir
    def self.save_pretrained(model, tokenizer, output_dir, training_args: nil)
      FileUtils.mkdir_p(output_dir)

      # Parameters are unwrapped to their underlying tensor data before saving.
      tensors = model.state_dict.to_h do |name, param|
        [name, param.is_a?(Torch::NN::Parameter) ? param.data : param]
      end

      model_path = File.join(output_dir, "model.safetensors")
      require "safetensors"
      Safetensors::Torch.save_file(tensors, model_path)
      puts "Model saved to #{model_path}"

      # A nil tokenizer, or one without #save, is skipped.
      tokenizer.save(File.join(output_dir, "tokenizer.json")) if tokenizer.respond_to?(:save)

      if training_args
        args_path = File.join(output_dir, "training_args.json")
        File.write(args_path, JSON.pretty_generate(training_args.to_h))
      end

      output_dir
    end

    # Save the lora_A / lora_B tensors of every LoraLinear in the model to
    # adapter_model.safetensors, keyed "<module name>.lora_A" / ".lora_B".
    #
    # @return [String, nil] output_dir, or nil when the model has no adapters
    #   (the directory is still created in that case)
    def self.save_lora_adapters(model, output_dir)
      FileUtils.mkdir_p(output_dir)

      lora_state = {}
      model.named_modules.each do |name, mod|
        next unless mod.is_a?(Trainers::LoraLinear)

        mod.lora_state_dict.each do |param_name, tensor|
          lora_state["#{name}.#{param_name}"] = tensor
        end
      end

      if lora_state.empty?
        puts "No LoRA adapters found in model"
        return
      end

      adapter_path = File.join(output_dir, "adapter_model.safetensors")
      require "safetensors"
      Safetensors::Torch.save_file(lora_state, adapter_path)
      puts "LoRA adapters saved to #{adapter_path} (#{lora_state.size} tensors)"

      output_dir
    end

    # Load adapter tensors from input_dir/adapter_model.safetensors into the
    # LoraLinear modules of a model that already has LoRA applied.
    #
    # Modules for which the file lacks either tensor are left unchanged and
    # reported on stderr, so a name mismatch does not pass unnoticed.
    #
    # @raise [RuntimeError] when the adapter file does not exist
    def self.load_lora_adapters(model, input_dir)
      adapter_path = File.join(input_dir, "adapter_model.safetensors")
      raise "Adapter file not found: #{adapter_path}" unless File.exist?(adapter_path)

      require "safetensors"
      lora_state = Safetensors::Torch.load_file(adapter_path)

      missing = []
      model.named_modules.each do |name, mod|
        next unless mod.is_a?(Trainers::LoraLinear)

        lora_a = lora_state["#{name}.lora_A"]
        lora_b = lora_state["#{name}.lora_B"]

        if lora_a && lora_b
          mod.load_lora_weights(lora_a, lora_b)
        else
          missing << name
        end
      end

      unless missing.empty?
        warn "No adapter weights found in #{adapter_path} for LoRA modules: #{missing.join(', ')}"
      end

      puts "LoRA adapters loaded from #{adapter_path}"
    end
  end
end