trainers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 3fd96717156bbd2f4d515f3b1a5aaea897f581b1888bb527124de305a6c2d536
4
+ data.tar.gz: 4a4a764def0528487ee20adb5e1edce52d7a3d267f1c99e399f5f48a7a5d9e2b
5
+ SHA512:
6
+ metadata.gz: 4a36d98cfc6a61d3fe7cad411a653e242b7f2a11f402e7add4ea2300bd9ee59601c8408b864a1094ba6486444b4b9197c524d89de9b50e8c74ee263b3d59ab0b
7
+ data.tar.gz: da2ec1f4fd3d69e76f3d6a812aec49454ff9d99b99006e1b53f0086e4a60d1958d69b6a1b7447e033fcbf927c412cc3fea50954d0c3cd72e5ce49adf4140d95f
data/CHANGELOG.md ADDED
@@ -0,0 +1,17 @@
1
+ # Changelog
2
+
3
+ ## 0.1.0 (2026-06-07)
4
+
5
+ - Initial release
6
+ - Trainer class with train/evaluate/predict
7
+ - TrainingArguments configuration
8
+ - AdamW optimizer with weight decay param groups
9
+ - Linear, cosine, and constant learning rate schedulers with warmup
10
+ - LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning
11
+ - LoRA adapter save/load via safetensors
12
+ - LoRA merge for inference deployment
13
+ - DataCollatorWithPadding for dynamic batch padding
14
+ - Callback system with PrinterCallback and EarlyStoppingCallback
15
+ - Model saving via safetensors
16
+ - CPU and MPS (Apple Silicon) device support
17
+ - Gradient accumulation and gradient clipping
data/LICENSE.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Ruby ML Community
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
data/README.md ADDED
@@ -0,0 +1,293 @@
1
+ # trainers-rb
2
+
3
+ Fine-tune transformer models in Ruby.
4
+
5
+ trainers-rb provides a training loop, LoRA (Low-Rank Adaptation), learning rate scheduling, and model serialization for HuggingFace transformer models loaded via [transformers-rb](https://github.com/ankane/transformers-ruby). It builds on [torch-rb](https://github.com/ankane/torch.rb) for autograd, optimizers, and tensor operations.
6
+
7
+ All the heavy lifting happens in LibTorch C++ kernels. Ruby is the conductor.
8
+
9
+ ## Installation
10
+
11
+ Add to your Gemfile:
12
+
13
+ ```ruby
14
+ gem "trainers-rb"
15
+ ```
16
+
17
+ Or install directly:
18
+
19
+ ```bash
20
+ gem install trainers-rb
21
+ ```
22
+
23
+ ### Prerequisites
24
+
25
+ trainers-rb depends on torch-rb, which requires LibTorch:
26
+
27
+ ```bash
28
+ # macOS arm64
29
+ curl -L -o /tmp/libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.4.0.zip
30
+ unzip /tmp/libtorch.zip -d ~/libtorch
31
+ bundle config set build.torch-rb --with-torch-dir=$HOME/libtorch/libtorch
32
+ ```
33
+
34
+ ## Quick Start
35
+
36
+ ```ruby
37
+ require "trainers-rb"
38
+
39
+ # Load a pre-trained model and tokenizer
40
+ model, tokenizer = Trainers.from_pretrained(
41
+ "distilbert-base-uncased",
42
+ task: :sequence_classification,
43
+ num_labels: 2
44
+ )
45
+
46
+ # Prepare your dataset (`texts` and `labels` are your own arrays; build `val_dataset` the same way)
47
+ train_data = texts.map.with_index do |text, i|
48
+ encoded = tokenizer.(text, truncation: true, max_length: 128)
49
+ {
50
+ input_ids: encoded["input_ids"],
51
+ attention_mask: encoded["attention_mask"],
52
+ labels: labels[i]
53
+ }
54
+ end
55
+ train_dataset = Trainers::Dataset.new(train_data)
56
+
57
+ # Configure and train
58
+ args = Trainers::TrainingArguments.new(
59
+ output_dir: "./output",
60
+ num_train_epochs: 3,
61
+ learning_rate: 2e-5,
62
+ eval_strategy: :epoch
63
+ )
64
+
65
+ trainer = Trainers::Trainer.new(
66
+ model: model,
67
+ args: args,
68
+ train_dataset: train_dataset,
69
+ eval_dataset: val_dataset,
70
+ tokenizer: tokenizer,
71
+ data_collator: Trainers::DataCollatorWithPadding.new(tokenizer: tokenizer),
72
+ compute_metrics: ->(eval_pred) {
73
+ preds = eval_pred.predictions.argmax(1)
74
+ correct = preds.eq(eval_pred.label_ids).sum.item
75
+ { accuracy: correct.to_f / eval_pred.label_ids.size(0) }
76
+ }
77
+ )
78
+
79
+ trainer.train
80
+ trainer.save_model("./my-model")
81
+ ```
82
+
83
+ ## LoRA (Parameter-Efficient Fine-Tuning)
84
+
85
+ Freeze over 99% of the model's parameters and train only small low-rank adapter matrices:
86
+
87
+ ```ruby
88
+ # Apply LoRA to specific layers
89
+ config = Trainers::LoraConfig.new(
90
+ r: 8, # rank
91
+ lora_alpha: 16, # scaling factor
92
+ lora_dropout: 0.1,
93
+ target_modules: ["query", "value"], # which Linear layers to adapt
94
+ bias: :none # :none, :all, or :lora_only
95
+ )
96
+
97
+ Trainers::LoraModel.apply(model, config)
98
+ # => LoRA applied to 12 modules: ...
99
+ # => trainable params: 294,912 || all params: 66,955,010 || trainable%: 0.4404%
100
+
101
+ # Train as usual
102
+ trainer.train
103
+
104
+ # Save just the adapters (tiny files)
105
+ Trainers::LoraModel.save_adapters(model, "./lora-adapters")
106
+
107
+ # Or merge into base model for inference
108
+ Trainers::LoraModel.merge(model)
109
+ trainer.save_model("./merged-model")
110
+ ```
111
+
112
+ ### Loading saved LoRA adapters
113
+
114
+ ```ruby
115
+ model, tokenizer = Trainers.from_pretrained("distilbert-base-uncased", num_labels: 2)
116
+ Trainers::LoraModel.apply(model, config)
117
+ Trainers::LoraModel.load_adapters(model, "./lora-adapters")
118
+ ```
119
+
120
+ ## Training Arguments
121
+
122
+ | Argument | Default | Description |
123
+ |----------|---------|-------------|
124
+ | `output_dir` | `"./output"` | Directory for checkpoints and saved models |
125
+ | `num_train_epochs` | `3` | Number of training epochs |
126
+ | `per_device_train_batch_size` | `8` | Training batch size |
127
+ | `per_device_eval_batch_size` | `8` | Evaluation batch size |
128
+ | `learning_rate` | `5e-5` | Peak learning rate for AdamW |
129
+ | `weight_decay` | `0.0` | Weight decay (applied to non-bias, non-norm params) |
130
+ | `max_grad_norm` | `1.0` | Max gradient norm for clipping |
131
+ | `gradient_accumulation_steps` | `1` | Accumulate gradients over N steps |
132
+ | `warmup_steps` | `0` | Linear warmup steps |
133
+ | `warmup_ratio` | `0.0` | Warmup as fraction of total steps (alternative to warmup_steps) |
134
+ | `lr_scheduler_type` | `:linear` | `:linear`, `:cosine`, or `:constant` |
135
+ | `eval_strategy` | `:no` | When to evaluate: `:no`, `:epoch`, or `:steps` |
136
+ | `eval_steps` | `nil` | Evaluate every N steps (when `eval_strategy: :steps`) |
137
+ | `save_strategy` | `:epoch` | When to save: `:no`, `:epoch`, or `:steps` |
138
+ | `save_total_limit` | `nil` | Keep only the last N checkpoints |
139
+ | `logging_steps` | `500` | Log every N steps |
140
+ | `seed` | `42` | Random seed |
141
+ | `no_mps` | `false` | Force CPU even if MPS is available |
142
+
143
+ ## Callbacks
144
+
145
+ Built-in callbacks:
146
+
147
+ ```ruby
148
+ # Early stopping
149
+ early_stop = Trainers::EarlyStoppingCallback.new(
150
+ patience: 3,
151
+ threshold: 0.01,
152
+ metric_name: "eval_loss"
153
+ )
154
+
155
+ trainer = Trainers::Trainer.new(
156
+ model: model,
157
+ args: args,
158
+ callbacks: [early_stop],
159
+ # ...
160
+ )
161
+ ```
162
+
163
+ Custom callbacks:
164
+
165
+ ```ruby
166
+ class WandbCallback < Trainers::TrainerCallback
167
+ def on_log(args, state, control, logs: nil, **)
168
+ # send logs to Weights & Biases, MLflow, etc.
169
+ end
170
+
171
+ def on_evaluate(args, state, control, metrics: nil, **)
172
+ # log evaluation metrics
173
+ end
174
+ end
175
+ ```
176
+
177
+ ### Callback hooks
178
+
179
+ | Hook | When it fires |
180
+ |------|---------------|
181
+ | `on_train_begin` | Before the first step |
182
+ | `on_train_end` | After the last step |
183
+ | `on_epoch_begin` | Start of each epoch |
184
+ | `on_epoch_end` | End of each epoch |
185
+ | `on_step_begin` | Before each training step |
186
+ | `on_step_end` | After each training step |
187
+ | `on_log` | When metrics are logged |
188
+ | `on_evaluate` | After evaluation |
189
+ | `on_save` | After saving a checkpoint |
190
+
191
+ ## Learning Rate Schedulers
192
+
193
+ Three schedules are available, all with optional linear warmup:
194
+
195
+ ```ruby
196
+ # Linear warmup then linear decay to 0 (default)
197
+ args = Trainers::TrainingArguments.new(lr_scheduler_type: :linear, warmup_steps: 100)
198
+
199
+ # Linear warmup then cosine decay to 0
200
+ args = Trainers::TrainingArguments.new(lr_scheduler_type: :cosine, warmup_steps: 100)
201
+
202
+ # Linear warmup then constant
203
+ args = Trainers::TrainingArguments.new(lr_scheduler_type: :constant, warmup_steps: 100)
204
+ ```
205
+
206
+ ## Data Utilities
207
+
208
+ ### Dataset
209
+
210
+ Wrap an array of hashes:
211
+
212
+ ```ruby
213
+ data = [
214
+ { input_ids: [101, 2023, 2003], attention_mask: [1, 1, 1], labels: 1 },
215
+ { input_ids: [101, 2919, 2143], attention_mask: [1, 1, 1], labels: 0 },
216
+ ]
217
+ dataset = Trainers::Dataset.new(data)
218
+ ```
219
+
220
+ ### Data Collators
221
+
222
+ Dynamic padding collator (pads each batch to the longest sequence in that batch, or to a fixed length when `max_length:` is given):
223
+
224
+ ```ruby
225
+ collator = Trainers::DataCollatorWithPadding.new(tokenizer: tokenizer)
226
+ ```
227
+
228
+ Default collator (no padding, expects uniform-length inputs):
229
+
230
+ ```ruby
231
+ collator = Trainers::DefaultDataCollator.new
232
+ ```
233
+
234
+ ## Supported Tasks
235
+
236
+ trainers-rb works with any `Torch::NN::Module`. The `Trainers.from_pretrained` convenience method supports these transformers-rb model classes:
237
+
238
+ | Task | Model class |
239
+ |------|-------------|
240
+ | `:sequence_classification` | `AutoModelForSequenceClassification` |
241
+ | `:token_classification` | `AutoModelForTokenClassification` |
242
+ | `:question_answering` | `AutoModelForQuestionAnswering` |
243
+
244
+ You can also use any custom model:
245
+
246
+ ```ruby
247
+ trainer = Trainers::Trainer.new(model: my_custom_model, args: args, ...)
248
+ ```
249
+
250
+ ## Device Support
251
+
252
+ trainers-rb auto-detects the best available device:
253
+
254
+ - **CPU** — always available
255
+ - **MPS** — Apple Silicon GPU, used automatically when available
256
+
257
+ ```ruby
258
+ # Force CPU
259
+ args = Trainers::TrainingArguments.new(no_mps: true)
260
+
261
+ # Or set explicitly
262
+ args = Trainers::TrainingArguments.new(device: Torch.device("mps"))
263
+ ```
264
+
265
+ ## Architecture
266
+
267
+ ```
268
+ trainers-rb
269
+ -> transformers-rb (model loading, tokenizers, HF Hub)
270
+ -> torch-rb (autograd, nn modules, optimizers)
271
+ -> tokenizers (HuggingFace Rust tokenizers via FFI)
272
+ -> safetensors (weight file I/O)
273
+ ```
274
+
275
+ trainers-rb adds the training layer that transformers-rb intentionally omits. Both gems call into the same LibTorch C++ kernels for the actual computation.
276
+
277
+ ## Roadmap
278
+
279
+ - [ ] More model architectures in transformers-rb (GPT-2, Llama for text generation)
280
+ - [ ] Mixed precision training (fp16/bf16)
281
+ - [ ] Gradient checkpointing for memory efficiency
282
+ - [ ] Dataset streaming for large datasets
283
+ - [ ] Distributed training
284
+ - [ ] Integration with ONNX export for deployment
285
+ - [ ] QLoRA (quantized base model + LoRA)
286
+
287
+ ## Contributing
288
+
289
+ Bug reports and pull requests are welcome on GitHub.
290
+
291
+ ## License
292
+
293
+ MIT
@@ -0,0 +1,128 @@
1
+ # frozen_string_literal: true
2
+
3
module Trainers
  # Progress counters and history for one training run. Callbacks receive it as
  # their +state+ argument (e.g. PrinterCallback reads #global_step, #max_steps
  # and #num_train_epochs).
  class TrainerState
    attr_accessor :epoch, :global_step, :max_steps, :num_train_epochs,
                  :total_flos, :best_metric, :best_model_checkpoint,
                  :log_history

    def initialize
      @epoch = 0.0
      @global_step = 0
      @max_steps = 0
      @num_train_epochs = 0
      @total_flos = 0
      @best_metric = nil
      @best_model_checkpoint = nil
      @log_history = []
    end
  end

  # Flags a callback can raise to request an action from the training loop.
  # For example, EarlyStoppingCallback sets #should_training_stop. All flags
  # start false. NOTE(review): the loop that consults these flags is not part
  # of this file.
  class TrainerControl
    attr_accessor :should_training_stop, :should_epoch_stop,
                  :should_save, :should_evaluate, :should_log

    def initialize
      @should_training_stop = false
      @should_epoch_stop = false
      @should_save = false
      @should_evaluate = false
      @should_log = false
    end
  end

  # Base class for trainer callbacks. Every hook is a no-op, so subclasses
  # override only the ones they need. Each hook receives (args, state, control)
  # plus event-specific keywords such as +logs:+ or +metrics:+.
  class TrainerCallback
    def on_train_begin(args, state, control, **kwargs); end
    def on_train_end(args, state, control, **kwargs); end
    def on_epoch_begin(args, state, control, **kwargs); end
    def on_epoch_end(args, state, control, **kwargs); end
    def on_step_begin(args, state, control, **kwargs); end
    def on_step_end(args, state, control, **kwargs); end
    def on_log(args, state, control, logs: nil, **kwargs); end
    def on_evaluate(args, state, control, metrics: nil, **kwargs); end
    def on_save(args, state, control, **kwargs); end
  end

  # Prints training progress, log entries and evaluation metrics to stdout.
  class PrinterCallback < TrainerCallback
    def on_log(args, state, control, logs: nil, **kwargs)
      return unless logs
      output = logs.map { |k, v| "#{k}: #{format_value(v)}" }.join(" ")
      puts "[step #{state.global_step}] #{output}"
    end

    def on_train_begin(args, state, control, **kwargs)
      puts "Starting training: #{state.num_train_epochs} epochs, #{state.max_steps} total steps"
    end

    def on_train_end(args, state, control, **kwargs)
      puts "Training complete. Total steps: #{state.global_step}"
    end

    def on_evaluate(args, state, control, metrics: nil, **kwargs)
      return unless metrics
      output = metrics.map { |k, v| "#{k}: #{format_value(v)}" }.join(" ")
      puts "[eval step #{state.global_step}] #{output}"
    end

    private

    # Floats are printed with 4 decimals; everything else uses #to_s.
    def format_value(v)
      v.is_a?(Float) ? format("%.4f", v) : v.to_s
    end
  end

  # Requests a stop once the monitored metric has failed to improve for
  # +patience+ consecutive evaluations.
  class EarlyStoppingCallback < TrainerCallback
    # @param patience [Integer] number of consecutive non-improving evaluations
    #   tolerated before stopping
    # @param threshold [Numeric] margin by which a value must beat the best
    #   value to count as an improvement
    # @param metric_name [String, Symbol] metrics key to monitor; looked up as
    #   both a String and a Symbol key
    # @param greater_is_better [Boolean, nil] direction of improvement. When nil,
    #   it is inferred from the name: names containing "loss" are minimized and
    #   all others are maximized.
    def initialize(patience: 3, threshold: 0.0, metric_name: "eval_loss", greater_is_better: nil)
      @patience = patience
      @threshold = threshold
      # Normalize to a String. Symbols have no #include?, and metrics may be
      # keyed by either form.
      @metric_name = metric_name.to_s
      @greater_is_better = greater_is_better.nil? ? !@metric_name.include?("loss") : greater_is_better
      @best_value = nil
      @wait_count = 0
    end

    def on_evaluate(args, state, control, metrics: nil, **kwargs)
      return unless metrics

      current = metrics[@metric_name] || metrics[@metric_name.to_sym]
      return unless current

      if @best_value.nil? || improved?(current, @best_value)
        @best_value = current
        @wait_count = 0
      else
        @wait_count += 1
        if @wait_count >= @patience
          puts "Early stopping triggered after #{@wait_count} evaluations without improvement"
          control.should_training_stop = true
        end
      end
    end

    private

    # True when +current+ beats +best+ by more than @threshold in the
    # configured direction.
    def improved?(current, best)
      if @greater_is_better
        current > best + @threshold
      else
        current < best - @threshold
      end
    end
  end

  # Dispatches callback events to every registered callback that responds to them.
  class CallbackHandler
    # @param callbacks [Array<TrainerCallback>, TrainerCallback, nil] nil means
    #   "no callbacks"
    def initialize(callbacks)
      @callbacks = Array(callbacks)
    end

    # Invokes +event+ (e.g. :on_log) on each callback, in order.
    # @return [TrainerControl] the same +control+ object, possibly mutated by callbacks
    def fire(event, args, state, control, **kwargs)
      @callbacks.each do |cb|
        cb.public_send(event, args, state, control, **kwargs) if cb.respond_to?(event)
      end
      control
    end
  end
end
@@ -0,0 +1,131 @@
1
+ # frozen_string_literal: true
2
+
3
module Trainers
  # Collates an Array of per-example feature Hashes into one batch Hash,
  # padding variable-length sequences so they can be stacked into a tensor.
  #
  # For each key, the type of the FIRST example's value decides the conversion:
  # - Array: padded (and truncated when +max_length+ is set) into a 2-D tensor
  # - Torch::Tensor: 0-d tensors are stacked; others are padded along dim 0 and stacked
  # - Integer / Float: converted to an int64 / float32 tensor
  # - anything else: passed through as a plain Array of values
  class DataCollatorWithPadding
    attr_reader :tokenizer, :padding, :max_length, :pad_to_multiple_of

    # @param tokenizer [Object] consulted only for #pad_token_id (falls back to 0)
    # @param padding [Object] stored and exposed via #padding.
    #   NOTE(review): #call does not consult it; every batch is padded.
    # @param max_length [Integer, nil] fixed target length. When nil, the longest
    #   sequence in the batch is the target.
    # @param pad_to_multiple_of [Integer, nil] round the target length up to a
    #   multiple of this value
    def initialize(tokenizer:, padding: true, max_length: nil, pad_to_multiple_of: nil)
      @tokenizer = tokenizer
      @padding = padding
      @max_length = max_length
      @pad_to_multiple_of = pad_to_multiple_of
    end

    # @param features [Array<Hash>] one Hash per example; keys come from the first example
    # @return [Hash] key => batched value, or {} for empty input
    def call(features)
      return {} if features.empty?

      features.first.keys.each_with_object({}) do |key, batch|
        batch[key] = collate(key, features.map { |f| f[key] })
      end
    end

    private

    # Converts all examples' values for one key into a batched value (see class comment).
    def collate(key, values)
      sample = values.first
      if sample.is_a?(Array)
        pad_and_stack(key, values)
      elsif sample.is_a?(Torch::Tensor)
        sample.dim == 0 ? Torch.stack(values) : pad_and_stack_tensors(key, values)
      elsif sample.is_a?(Integer)
        Torch.tensor(values, dtype: :int64)
      elsif sample.is_a?(Float)
        Torch.tensor(values, dtype: :float32)
      else
        values
      end
    end

    # Pads (and truncates when needed) Ruby Array sequences to a common length
    # and builds one tensor. Uses float32 when any element is a Float and
    # int64 otherwise.
    def pad_and_stack(key, sequences)
      max_len = compute_max_length(sequences.map(&:length))
      pad_value = pad_value_for(key)

      padded = sequences.map do |seq|
        # Truncate first: with an explicit max_length a sequence can exceed the
        # target, and Array#* raises ArgumentError on a negative count. This
        # matches the truncation done for tensors.
        kept = seq.first(max_len)
        kept + [pad_value] * (max_len - kept.length)
      end

      # int64 would silently truncate float-valued features.
      dtype = padded.flatten.any?(Float) ? :float32 : :int64
      Torch.tensor(padded, dtype: dtype)
    end

    # Pads or truncates tensors along dim 0 to a common length and stacks them.
    def pad_and_stack_tensors(key, tensors)
      max_len = compute_max_length(tensors.map { |t| t.size(0) })
      pad_value = pad_value_for(key)

      padded = tensors.map do |t|
        len = t.size(0)
        if len < max_len
          Torch.cat([t, Torch.full([max_len - len], pad_value, dtype: t.dtype)])
        else
          t[0...max_len]
        end
      end

      Torch.stack(padded)
    end

    # Padding value for a key, chosen by substring match on its name. Array
    # and tensor inputs share this rule so labels are padded the same way in
    # both paths:
    # - input ids: the tokenizer's pad id
    # - attention masks: 0
    # - labels: -100 (the conventional ignore index for cross-entropy losses)
    # - anything else: 0
    def pad_value_for(key)
      name = key.to_s
      if name.include?("input_id")
        pad_token_id
      elsif name.include?("attention_mask")
        0
      elsif name.include?("label")
        -100
      else
        0
      end
    end

    # Target length: +max_length+ if set, else the longest length, rounded up
    # to a multiple of +pad_to_multiple_of+ when that is set.
    def compute_max_length(lengths)
      max_len = @max_length || lengths.max
      if @pad_to_multiple_of
        max_len = ((max_len + @pad_to_multiple_of - 1) / @pad_to_multiple_of) * @pad_to_multiple_of
      end
      max_len
    end

    # The tokenizer's pad id, or 0 when the tokenizer lacks one.
    def pad_token_id
      if @tokenizer.respond_to?(:pad_token_id)
        @tokenizer.pad_token_id || 0
      else
        0
      end
    end
  end

  # Collator that does no padding: tensors are stacked and Arrays are handed
  # to Torch.tensor as-is, so every example must already have the same length.
  class DefaultDataCollator
    # @param features [Array<Hash>] one Hash per example; keys come from the first example
    # @return [Hash] key => batched value, or {} for empty input
    def call(features)
      return {} if features.empty?

      batch = {}
      features.first.keys.each do |key|
        values = features.map { |f| f[key] }

        if values.first.is_a?(Torch::Tensor)
          batch[key] = Torch.stack(values)
        elsif values.first.is_a?(Integer)
          batch[key] = Torch.tensor(values, dtype: :int64)
        elsif values.first.is_a?(Float)
          batch[key] = Torch.tensor(values, dtype: :float32)
        elsif values.first.is_a?(Array)
          batch[key] = Torch.tensor(values)
        else
          batch[key] = values
        end
      end

      batch
    end
  end
end
@@ -0,0 +1,26 @@
1
+ # frozen_string_literal: true
2
+
3
module Trainers
  # Minimal in-memory dataset. It wraps a collection of examples (e.g. an
  # Array of feature Hashes) and exposes indexed access, a size, and
  # Enumerable via #each.
  class Dataset
    include Enumerable

    # @return [Array] the wrapped collection, held by reference (not copied)
    attr_reader :data

    # @param data [Array] examples to wrap
    def initialize(data)
      @data = data
    end

    # Number of wrapped examples; also available as #length.
    def size
      data.size
    end
    alias length size

    # @param index [Integer] position of the example
    # @return the element the wrapped collection returns for +index+
    def [](index)
      data[index]
    end

    # Yields each example in order. Without a block, returns whatever the
    # wrapped collection's #each returns (an Enumerator for Arrays).
    def each(&block)
      data.each(&block)
    end
  end
end
@@ -0,0 +1,34 @@
1
+ # frozen_string_literal: true
2
+
3
module Trainers
  # Hyper-parameters for LoRA (Low-Rank Adaptation) adapters.
  #
  # Every option in DEFAULTS may be passed as a keyword; omitted options take
  # their default. Unknown keywords raise ArgumentError, so typos such as
  # +rank:+ for +r:+ are not silently ignored.
  class LoraConfig
    attr_accessor :r, :lora_alpha, :lora_dropout, :target_modules,
                  :bias, :task_type

    DEFAULTS = {
      r: 8,
      lora_alpha: 16,
      lora_dropout: 0.0,
      target_modules: ["query", "value"].freeze, # or :all_linear
      bias: :none, # :none, :all, :lora_only
      task_type: :sequence_classification
    }.freeze

    # @raise [ArgumentError] on unknown keywords, or when +r+ is not a positive Integer
    def initialize(**kwargs)
      unknown = kwargs.keys - DEFAULTS.keys
      unless unknown.empty?
        raise ArgumentError, "unknown LoraConfig option(s): #{unknown.map(&:inspect).join(', ')}"
      end

      DEFAULTS.each do |key, default|
        # dup the default so instances never share (and mutate) the default Array.
        value = kwargs.key?(key) ? kwargs[key] : default.dup
        instance_variable_set(:"@#{key}", value)
      end

      # A rank of 0 (or less) is meaningless and would make #scaling divide by zero.
      raise ArgumentError, "r must be a positive Integer, got #{@r.inspect}" unless @r.is_a?(Integer) && @r.positive?
    end

    # @return [Float] lora_alpha / r
    def scaling
      @lora_alpha.to_f / @r
    end

    # @return [Hash{Symbol=>Object}] every option, keyed as in DEFAULTS
    def to_h
      DEFAULTS.keys.each_with_object({}) do |key, hash|
        hash[key] = public_send(key)
      end
    end
  end
end