trainers-rb 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +17 -0
- data/LICENSE.txt +21 -0
- data/README.md +293 -0
- data/lib/trainers/callbacks.rb +128 -0
- data/lib/trainers/data/data_collator.rb +131 -0
- data/lib/trainers/data/dataset.rb +26 -0
- data/lib/trainers/lora/lora_config.rb +34 -0
- data/lib/trainers/lora/lora_linear.rb +78 -0
- data/lib/trainers/lora/lora_model.rb +87 -0
- data/lib/trainers/lora/lora_utils.rb +73 -0
- data/lib/trainers/optimization/optimizer.rb +49 -0
- data/lib/trainers/optimization/scheduler.rb +67 -0
- data/lib/trainers/save_utils.rb +84 -0
- data/lib/trainers/trainer.rb +340 -0
- data/lib/trainers/trainer_utils.rb +30 -0
- data/lib/trainers/training_arguments.rb +64 -0
- data/lib/trainers/version.rb +5 -0
- data/lib/trainers-rb.rb +1 -0
- data/lib/trainers.rb +43 -0
- metadata +149 -0
checksums.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
---
|
|
2
|
+
SHA256:
|
|
3
|
+
metadata.gz: 3fd96717156bbd2f4d515f3b1a5aaea897f581b1888bb527124de305a6c2d536
|
|
4
|
+
data.tar.gz: 4a4a764def0528487ee20adb5e1edce52d7a3d267f1c99e399f5f48a7a5d9e2b
|
|
5
|
+
SHA512:
|
|
6
|
+
metadata.gz: 4a36d98cfc6a61d3fe7cad411a653e242b7f2a11f402e7add4ea2300bd9ee59601c8408b864a1094ba6486444b4b9197c524d89de9b50e8c74ee263b3d59ab0b
|
|
7
|
+
data.tar.gz: da2ec1f4fd3d69e76f3d6a812aec49454ff9d99b99006e1b53f0086e4a60d1958d69b6a1b7447e033fcbf927c412cc3fea50954d0c3cd72e5ce49adf4140d95f
|
data/CHANGELOG.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
## 0.1.0 (2026-06-07)
|
|
4
|
+
|
|
5
|
+
- Initial release
|
|
6
|
+
- Trainer class with train/evaluate/predict
|
|
7
|
+
- TrainingArguments configuration
|
|
8
|
+
- AdamW optimizer with weight decay param groups
|
|
9
|
+
- Linear, cosine, and constant learning rate schedulers with warmup
|
|
10
|
+
- LoRA (Low-Rank Adaptation) for parameter-efficient fine-tuning
|
|
11
|
+
- LoRA adapter save/load via safetensors
|
|
12
|
+
- LoRA merge for inference deployment
|
|
13
|
+
- DataCollatorWithPadding for dynamic batch padding
|
|
14
|
+
- Callback system with PrinterCallback and EarlyStoppingCallback
|
|
15
|
+
- Model saving via safetensors
|
|
16
|
+
- CPU and MPS (Apple Silicon) device support
|
|
17
|
+
- Gradient accumulation and gradient clipping
|
data/LICENSE.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Ruby ML Community
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
data/README.md
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
# trainers-rb
|
|
2
|
+
|
|
3
|
+
Fine-tune transformer models in Ruby.
|
|
4
|
+
|
|
5
|
+
trainers-rb provides a training loop, LoRA (Low-Rank Adaptation), learning rate scheduling, and model serialization for HuggingFace transformer models loaded via [transformers-rb](https://github.com/ankane/transformers-ruby). It builds on [torch-rb](https://github.com/ankane/torch.rb) for autograd, optimizers, and tensor operations.
|
|
6
|
+
|
|
7
|
+
All the heavy lifting happens in LibTorch C++ kernels. Ruby is the conductor.
|
|
8
|
+
|
|
9
|
+
## Installation
|
|
10
|
+
|
|
11
|
+
Add to your Gemfile:
|
|
12
|
+
|
|
13
|
+
```ruby
|
|
14
|
+
gem "trainers-rb"
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
Or install directly:
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
gem install trainers-rb
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
### Prerequisites
|
|
24
|
+
|
|
25
|
+
trainers-rb depends on torch-rb, which requires LibTorch:
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
# macOS arm64
|
|
29
|
+
curl -L -o /tmp/libtorch.zip https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.4.0.zip
|
|
30
|
+
unzip /tmp/libtorch.zip -d ~/libtorch
|
|
31
|
+
bundle config set build.torch-rb --with-torch-dir=$HOME/libtorch/libtorch
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
## Quick Start
|
|
35
|
+
|
|
36
|
+
```ruby
|
|
37
|
+
require "trainers-rb"
|
|
38
|
+
|
|
39
|
+
# Load a pre-trained model and tokenizer
|
|
40
|
+
model, tokenizer = Trainers.from_pretrained(
|
|
41
|
+
"distilbert-base-uncased",
|
|
42
|
+
task: :sequence_classification,
|
|
43
|
+
num_labels: 2
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
# Prepare your dataset
|
|
47
|
+
train_data = texts.map.with_index do |text, i|
|
|
48
|
+
encoded = tokenizer.(text, truncation: true, max_length: 128)
|
|
49
|
+
{
|
|
50
|
+
input_ids: encoded["input_ids"],
|
|
51
|
+
attention_mask: encoded["attention_mask"],
|
|
52
|
+
labels: labels[i]
|
|
53
|
+
}
|
|
54
|
+
end
|
|
55
|
+
train_dataset = Trainers::Dataset.new(train_data)
|
|
56
|
+
|
|
57
|
+
# Configure and train
|
|
58
|
+
args = Trainers::TrainingArguments.new(
|
|
59
|
+
output_dir: "./output",
|
|
60
|
+
num_train_epochs: 3,
|
|
61
|
+
learning_rate: 2e-5,
|
|
62
|
+
eval_strategy: :epoch
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
trainer = Trainers::Trainer.new(
|
|
66
|
+
model: model,
|
|
67
|
+
args: args,
|
|
68
|
+
train_dataset: train_dataset,
|
|
69
|
+
eval_dataset: val_dataset,
|
|
70
|
+
tokenizer: tokenizer,
|
|
71
|
+
data_collator: Trainers::DataCollatorWithPadding.new(tokenizer: tokenizer),
|
|
72
|
+
compute_metrics: ->(eval_pred) {
|
|
73
|
+
preds = eval_pred.predictions.argmax(1)
|
|
74
|
+
correct = preds.eq(eval_pred.label_ids).sum.item
|
|
75
|
+
{ accuracy: correct.to_f / eval_pred.label_ids.size(0) }
|
|
76
|
+
}
|
|
77
|
+
)
|
|
78
|
+
|
|
79
|
+
trainer.train
|
|
80
|
+
trainer.save_model("./my-model")
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
## LoRA (Parameter-Efficient Fine-Tuning)
|
|
84
|
+
|
|
85
|
+
Freeze more than 99% of the model's parameters and train only small low-rank adapter matrices:
|
|
86
|
+
|
|
87
|
+
```ruby
|
|
88
|
+
# Apply LoRA to specific layers
|
|
89
|
+
config = Trainers::LoraConfig.new(
|
|
90
|
+
r: 8, # rank
|
|
91
|
+
lora_alpha: 16, # scaling factor
|
|
92
|
+
lora_dropout: 0.1,
|
|
93
|
+
target_modules: ["query", "value"], # which Linear layers to adapt
|
|
94
|
+
bias: :none # :none, :all, or :lora_only
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
Trainers::LoraModel.apply(model, config)
|
|
98
|
+
# => LoRA applied to 12 modules: ...
|
|
99
|
+
# => trainable params: 294,912 || all params: 66,955,010 || trainable%: 0.4404%
|
|
100
|
+
|
|
101
|
+
# Train as usual
|
|
102
|
+
trainer.train
|
|
103
|
+
|
|
104
|
+
# Save just the adapters (tiny files)
|
|
105
|
+
Trainers::LoraModel.save_adapters(model, "./lora-adapters")
|
|
106
|
+
|
|
107
|
+
# Or merge into base model for inference
|
|
108
|
+
Trainers::LoraModel.merge(model)
|
|
109
|
+
trainer.save_model("./merged-model")
|
|
110
|
+
```
|
|
111
|
+
|
|
112
|
+
### Loading saved LoRA adapters
|
|
113
|
+
|
|
114
|
+
```ruby
|
|
115
|
+
model, tokenizer = Trainers.from_pretrained("distilbert-base-uncased", num_labels: 2)
|
|
116
|
+
Trainers::LoraModel.apply(model, config)
|
|
117
|
+
Trainers::LoraModel.load_adapters(model, "./lora-adapters")
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
## Training Arguments
|
|
121
|
+
|
|
122
|
+
| Argument | Default | Description |
|
|
123
|
+
|----------|---------|-------------|
|
|
124
|
+
| `output_dir` | `"./output"` | Directory for checkpoints and saved models |
|
|
125
|
+
| `num_train_epochs` | `3` | Number of training epochs |
|
|
126
|
+
| `per_device_train_batch_size` | `8` | Training batch size |
|
|
127
|
+
| `per_device_eval_batch_size` | `8` | Evaluation batch size |
|
|
128
|
+
| `learning_rate` | `5e-5` | Peak learning rate for AdamW |
|
|
129
|
+
| `weight_decay` | `0.0` | Weight decay (applied to non-bias, non-norm params) |
|
|
130
|
+
| `max_grad_norm` | `1.0` | Max gradient norm for clipping |
|
|
131
|
+
| `gradient_accumulation_steps` | `1` | Accumulate gradients over N steps |
|
|
132
|
+
| `warmup_steps` | `0` | Linear warmup steps |
|
|
133
|
+
| `warmup_ratio` | `0.0` | Warmup as fraction of total steps (alternative to warmup_steps) |
|
|
134
|
+
| `lr_scheduler_type` | `:linear` | `:linear`, `:cosine`, or `:constant` |
|
|
135
|
+
| `eval_strategy` | `:no` | When to evaluate: `:no`, `:epoch`, or `:steps` |
|
|
136
|
+
| `eval_steps` | `nil` | Evaluate every N steps (when `eval_strategy: :steps`) |
|
|
137
|
+
| `save_strategy` | `:epoch` | When to save: `:no`, `:epoch`, or `:steps` |
|
|
138
|
+
| `save_total_limit` | `nil` | Keep only the last N checkpoints |
|
|
139
|
+
| `logging_steps` | `500` | Log every N steps |
|
|
140
|
+
| `seed` | `42` | Random seed |
|
|
141
|
+
| `no_mps` | `false` | Force CPU even if MPS is available |
|
|
142
|
+
|
|
143
|
+
## Callbacks
|
|
144
|
+
|
|
145
|
+
Built-in callbacks:
|
|
146
|
+
|
|
147
|
+
```ruby
|
|
148
|
+
# Early stopping
|
|
149
|
+
early_stop = Trainers::EarlyStoppingCallback.new(
|
|
150
|
+
patience: 3,
|
|
151
|
+
threshold: 0.01,
|
|
152
|
+
metric_name: "eval_loss"
|
|
153
|
+
)
|
|
154
|
+
|
|
155
|
+
trainer = Trainers::Trainer.new(
|
|
156
|
+
model: model,
|
|
157
|
+
args: args,
|
|
158
|
+
callbacks: [early_stop],
|
|
159
|
+
# ...
|
|
160
|
+
)
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
Custom callbacks:
|
|
164
|
+
|
|
165
|
+
```ruby
|
|
166
|
+
class WandbCallback < Trainers::TrainerCallback
|
|
167
|
+
def on_log(args, state, control, logs: nil, **)
|
|
168
|
+
# send logs to Weights & Biases, MLflow, etc.
|
|
169
|
+
end
|
|
170
|
+
|
|
171
|
+
def on_evaluate(args, state, control, metrics: nil, **)
|
|
172
|
+
# log evaluation metrics
|
|
173
|
+
end
|
|
174
|
+
end
|
|
175
|
+
```
|
|
176
|
+
|
|
177
|
+
### Callback hooks
|
|
178
|
+
|
|
179
|
+
| Hook | When it fires |
|
|
180
|
+
|------|---------------|
|
|
181
|
+
| `on_train_begin` | Before the first step |
|
|
182
|
+
| `on_train_end` | After the last step |
|
|
183
|
+
| `on_epoch_begin` | Start of each epoch |
|
|
184
|
+
| `on_epoch_end` | End of each epoch |
|
|
185
|
+
| `on_step_begin` | Before each training step |
|
|
186
|
+
| `on_step_end` | After each training step |
|
|
187
|
+
| `on_log` | When metrics are logged |
|
|
188
|
+
| `on_evaluate` | After evaluation |
|
|
189
|
+
| `on_save` | After saving a checkpoint |
|
|
190
|
+
|
|
191
|
+
## Learning Rate Schedulers
|
|
192
|
+
|
|
193
|
+
Three schedules are available, all with optional linear warmup:
|
|
194
|
+
|
|
195
|
+
```ruby
|
|
196
|
+
# Linear warmup then linear decay to 0 (default)
|
|
197
|
+
args = Trainers::TrainingArguments.new(lr_scheduler_type: :linear, warmup_steps: 100)
|
|
198
|
+
|
|
199
|
+
# Linear warmup then cosine decay to 0
|
|
200
|
+
args = Trainers::TrainingArguments.new(lr_scheduler_type: :cosine, warmup_steps: 100)
|
|
201
|
+
|
|
202
|
+
# Linear warmup then constant
|
|
203
|
+
args = Trainers::TrainingArguments.new(lr_scheduler_type: :constant, warmup_steps: 100)
|
|
204
|
+
```
|
|
205
|
+
|
|
206
|
+
## Data Utilities
|
|
207
|
+
|
|
208
|
+
### Dataset
|
|
209
|
+
|
|
210
|
+
Wrap an array of hashes:
|
|
211
|
+
|
|
212
|
+
```ruby
|
|
213
|
+
data = [
|
|
214
|
+
{ input_ids: [101, 2023, 2003], attention_mask: [1, 1, 1], labels: 1 },
|
|
215
|
+
{ input_ids: [101, 2919, 2143], attention_mask: [1, 1, 1], labels: 0 },
|
|
216
|
+
]
|
|
217
|
+
dataset = Trainers::Dataset.new(data)
|
|
218
|
+
```
|
|
219
|
+
|
|
220
|
+
### Data Collators
|
|
221
|
+
|
|
222
|
+
Dynamic padding collator (pads each batch to the longest sequence in that batch):
|
|
223
|
+
|
|
224
|
+
```ruby
|
|
225
|
+
collator = Trainers::DataCollatorWithPadding.new(tokenizer: tokenizer)
|
|
226
|
+
```
|
|
227
|
+
|
|
228
|
+
Default collator (no padding, expects uniform-length inputs):
|
|
229
|
+
|
|
230
|
+
```ruby
|
|
231
|
+
collator = Trainers::DefaultDataCollator.new
|
|
232
|
+
```
|
|
233
|
+
|
|
234
|
+
## Supported Tasks
|
|
235
|
+
|
|
236
|
+
trainers-rb works with any `Torch::NN::Module`. The `Trainers.from_pretrained` convenience method supports these transformers-rb model classes:
|
|
237
|
+
|
|
238
|
+
| Task | Model class |
|
|
239
|
+
|------|-------------|
|
|
240
|
+
| `:sequence_classification` | `AutoModelForSequenceClassification` |
|
|
241
|
+
| `:token_classification` | `AutoModelForTokenClassification` |
|
|
242
|
+
| `:question_answering` | `AutoModelForQuestionAnswering` |
|
|
243
|
+
|
|
244
|
+
You can also use any custom model:
|
|
245
|
+
|
|
246
|
+
```ruby
|
|
247
|
+
trainer = Trainers::Trainer.new(model: my_custom_model, args: args, ...)
|
|
248
|
+
```
|
|
249
|
+
|
|
250
|
+
## Device Support
|
|
251
|
+
|
|
252
|
+
trainers-rb auto-detects the best available device:
|
|
253
|
+
|
|
254
|
+
- **CPU** — always available
|
|
255
|
+
- **MPS** — Apple Silicon GPU, used automatically when available
|
|
256
|
+
|
|
257
|
+
```ruby
|
|
258
|
+
# Force CPU
|
|
259
|
+
args = Trainers::TrainingArguments.new(no_mps: true)
|
|
260
|
+
|
|
261
|
+
# Or set explicitly
|
|
262
|
+
args = Trainers::TrainingArguments.new(device: Torch.device("mps"))
|
|
263
|
+
```
|
|
264
|
+
|
|
265
|
+
## Architecture
|
|
266
|
+
|
|
267
|
+
```
|
|
268
|
+
trainers-rb
|
|
269
|
+
-> transformers-rb (model loading, tokenizers, HF Hub)
|
|
270
|
+
-> torch-rb (autograd, nn modules, optimizers)
|
|
271
|
+
-> tokenizers (HuggingFace Rust tokenizers via FFI)
|
|
272
|
+
-> safetensors (weight file I/O)
|
|
273
|
+
```
|
|
274
|
+
|
|
275
|
+
trainers-rb adds the training layer that transformers-rb intentionally omits. Both gems call into the same LibTorch C++ kernels for the actual computation.
|
|
276
|
+
|
|
277
|
+
## Roadmap
|
|
278
|
+
|
|
279
|
+
- [ ] More model architectures in transformers-rb (GPT-2, Llama for text generation)
|
|
280
|
+
- [ ] Mixed precision training (fp16/bf16)
|
|
281
|
+
- [ ] Gradient checkpointing for memory efficiency
|
|
282
|
+
- [ ] Dataset streaming for large datasets
|
|
283
|
+
- [ ] Distributed training
|
|
284
|
+
- [ ] Integration with ONNX export for deployment
|
|
285
|
+
- [ ] QLoRA (quantized base model + LoRA)
|
|
286
|
+
|
|
287
|
+
## Contributing
|
|
288
|
+
|
|
289
|
+
Bug reports and pull requests are welcome on GitHub.
|
|
290
|
+
|
|
291
|
+
## License
|
|
292
|
+
|
|
293
|
+
MIT
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
|
|
4
|
+
# Mutable bookkeeping for a training run. Every field is a plain read/write
# attribute that starts at a neutral value.
class TrainerState
  attr_accessor :epoch, :global_step, :max_steps, :num_train_epochs
  attr_accessor :total_flos, :best_metric, :best_model_checkpoint
  attr_accessor :log_history

  # Counters start at zero (epoch as a Float), the best metric/checkpoint are
  # unset and the log history is a fresh empty Array.
  def initialize
    @epoch = 0.0
    @global_step = @max_steps = @num_train_epochs = @total_flos = 0
    @best_metric = @best_model_checkpoint = nil
    @log_history = []
  end
end
|
|
20
|
+
|
|
21
|
+
# Set of boolean flags through which callbacks signal requests back to the
# training loop. All flags start out false.
class TrainerControl
  attr_accessor :should_training_stop, :should_epoch_stop
  attr_accessor :should_save, :should_evaluate, :should_log

  def initialize
    @should_training_stop = @should_epoch_stop = false
    @should_save = @should_evaluate = @should_log = false
  end
end
|
|
33
|
+
|
|
34
|
+
# Base class for trainer callbacks. Every hook is a no-op returning nil;
# subclasses override only the hooks they need.
class TrainerCallback
  # Hooks that take (args, state, control) plus arbitrary keyword extras.
  %i[
    on_train_begin on_train_end
    on_epoch_begin on_epoch_end
    on_step_begin on_step_end
    on_save
  ].each do |hook|
    define_method(hook) { |args, state, control, **kwargs| }
  end

  # Receives the logged values through the +logs+ keyword.
  def on_log(args, state, control, logs: nil, **kwargs)
  end

  # Receives evaluation results through the +metrics+ keyword.
  def on_evaluate(args, state, control, metrics: nil, **kwargs)
  end
end
|
|
46
|
+
|
|
47
|
+
# Default callback: writes training progress to stdout via puts.
class PrinterCallback < TrainerCallback
  # Announces the planned number of epochs and total steps.
  def on_train_begin(args, state, control, **kwargs)
    puts "Starting training: #{state.num_train_epochs} epochs, #{state.max_steps} total steps"
  end

  # Announces the number of steps actually taken.
  def on_train_end(args, state, control, **kwargs)
    puts "Training complete. Total steps: #{state.global_step}"
  end

  # Prints the logged key/value pairs, prefixed with the current step.
  # Does nothing when +logs+ is nil or false.
  def on_log(args, state, control, logs: nil, **kwargs)
    return unless logs

    summary = summarize(logs)
    puts "[step #{state.global_step}] #{summary}"
  end

  # Prints the evaluation metrics, prefixed with the current step.
  # Does nothing when +metrics+ is nil or false.
  def on_evaluate(args, state, control, metrics: nil, **kwargs)
    return unless metrics

    summary = summarize(metrics)
    puts "[eval step #{state.global_step}] #{summary}"
  end

  private

  # Renders each pair as "name: value" and joins them with single spaces.
  def summarize(pairs)
    pairs.map { |name, value| "#{name}: #{format_value(value)}" }.join(" ")
  end

  # Floats are shown with four decimal places; everything else uses to_s.
  def format_value(v)
    return v.to_s unless v.is_a?(Float)

    format("%.4f", v)
  end
end
|
|
75
|
+
|
|
76
|
+
# Early stopping callback: asks the trainer to stop once the monitored metric
# has failed to improve for +patience+ consecutive evaluations.
#
# A metric whose name contains "loss" is treated as lower-is-better; any other
# metric is treated as higher-is-better. An evaluation counts as an improvement
# only when it beats the best value seen so far by more than +threshold+.
class EarlyStoppingCallback < TrainerCallback
  # @param patience [Integer] evaluations without improvement tolerated before stopping
  # @param threshold [Numeric] margin over the best value required to count as an improvement
  # @param metric_name [String, Symbol] key of the monitored metric in the metrics hash
  def initialize(patience: 3, threshold: 0.0, metric_name: "eval_loss")
    @patience = patience
    @threshold = threshold
    # Normalise to a String: Symbol has no #include?, so a Symbol name used to
    # raise NoMethodError in #improved? on the second evaluation.
    @metric_name = metric_name.to_s
    @best_value = nil
    @wait_count = 0
  end

  # Records the latest value of the monitored metric and sets
  # control.should_training_stop once patience is exhausted.
  # Does nothing when +metrics+ is nil or does not contain the monitored metric.
  def on_evaluate(args, state, control, metrics: nil, **kwargs)
    return unless metrics

    # Accept both string-keyed and symbol-keyed metrics hashes.
    current = metrics[@metric_name] || metrics[@metric_name.to_sym]
    return unless current

    if @best_value.nil? || improved?(current, @best_value)
      @best_value = current
      @wait_count = 0
    else
      @wait_count += 1
      if @wait_count >= @patience
        puts "Early stopping triggered after #{@wait_count} evaluations without improvement"
        control.should_training_stop = true
      end
    end
  end

  private

  # True when +current+ beats +best+ by more than the threshold, in the
  # direction implied by the metric name (down for "loss", up otherwise).
  def improved?(current, best)
    if @metric_name.include?("loss")
      current < best - @threshold
    else
      current > best + @threshold
    end
  end
end
|
|
114
|
+
|
|
115
|
+
# Fans a single callback event out to every registered callback, in
# registration order.
class CallbackHandler
  # @param callbacks [Array] objects that may implement some or all hook methods
  def initialize(callbacks)
    @callbacks = callbacks
  end

  # Invokes +event+ on each callback that responds to it, forwarding the
  # positional arguments and any keyword extras, then returns +control+.
  def fire(event, args, state, control, **kwargs)
    @callbacks.each do |callback|
      next unless callback.respond_to?(event)

      callback.send(event, args, state, control, **kwargs)
    end

    control
  end
end
|
|
128
|
+
end
|
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
|
|
4
|
+
# Collates a list of feature hashes into one batch hash, padding
# variable-length sequences to a common length.
#
# The target length is +max_length+ when given, otherwise the longest sequence
# in the batch, rounded up to a multiple of +pad_to_multiple_of+ when that is
# set. Sequences longer than the target length are truncated to it.
#
# +padding+ is stored and exposed through its reader but is not consulted by
# the collation logic in this class.
class DataCollatorWithPadding
  attr_reader :tokenizer, :padding, :max_length, :pad_to_multiple_of

  # @param tokenizer [Object] consulted for +pad_token_id+ when it responds to it
  # @param padding [Object] stored only (see class comment)
  # @param max_length [Integer, nil] fixed target length; nil means "longest in batch"
  # @param pad_to_multiple_of [Integer, nil] round the target length up to this multiple
  def initialize(tokenizer:, padding: true, max_length: nil, pad_to_multiple_of: nil)
    @tokenizer = tokenizer
    @padding = padding
    @max_length = max_length
    @pad_to_multiple_of = pad_to_multiple_of
  end

  # Builds the batch. The keys are taken from the first feature, and the type
  # of the first feature's value decides how each column is collated.
  #
  # @param features [Array<Hash>] one hash per example
  # @return [Hash] same keys as the first feature; {} for an empty list
  def call(features)
    return {} if features.empty?

    features.first.keys.each_with_object({}) do |key, batch|
      values = features.map { |f| f[key] }
      batch[key] = collate(key, values)
    end
  end

  private

  # Dispatches one column on the type of its first value:
  # Array -> padded int64 tensor; 0-dim Tensor -> stacked; other Tensor ->
  # padded and stacked; Integer -> int64 tensor; Float -> float32 tensor;
  # anything else is returned as the plain Ruby array.
  def collate(key, values)
    first = values.first

    if first.is_a?(Array)
      pad_and_stack(key, values)
    elsif first.is_a?(Torch::Tensor)
      first.dim == 0 ? Torch.stack(values) : pad_and_stack_tensors(key, values)
    elsif first.is_a?(Integer)
      Torch.tensor(values, dtype: :int64)
    elsif first.is_a?(Float)
      Torch.tensor(values, dtype: :float32)
    else
      values
    end
  end

  # Pads (or truncates) Ruby arrays to the target length and returns them as a
  # single int64 tensor.
  def pad_and_stack(key, sequences)
    max_len = compute_max_length(sequences.map(&:length))
    pad_value = pad_value_for(key)

    padded = sequences.map do |seq|
      missing = max_len - seq.length
      # A sequence longer than the target (possible when max_length is set) is
      # truncated, matching pad_and_stack_tensors, instead of raising on a
      # negative repeat count.
      missing.negative? ? seq.first(max_len) : seq + [pad_value] * missing
    end

    Torch.tensor(padded, dtype: :int64)
  end

  # Pads (or truncates) tensors along dimension 0 to the target length and
  # stacks them. The filler tensor is one-dimensional.
  # NOTE(review): presumably inputs are 1-D token sequences; confirm before
  # using with higher-rank tensors.
  def pad_and_stack_tensors(key, tensors)
    max_len = compute_max_length(tensors.map { |t| t.size(0) })
    pad_value = pad_value_for(key)

    padded = tensors.map do |t|
      missing = max_len - t.size(0)
      if missing.positive?
        filler = Torch.full([missing], pad_value, dtype: t.dtype)
        Torch.cat([t, filler])
      else
        t[0...max_len]
      end
    end

    Torch.stack(padded)
  end

  # Chooses the fill value from the column name; shared by the array and
  # tensor paths so both pad identically.
  def pad_value_for(key)
    name = key.to_s

    if name.include?("input_id")
      pad_token_id
    elsif name.include?("attention_mask")
      0
    elsif name.include?("label")
      # NOTE(review): -100 is presumably the loss function's ignore index; confirm.
      -100
    else
      0
    end
  end

  # Target length: max_length if set, else the longest length given, rounded
  # up to the next multiple of pad_to_multiple_of when configured.
  def compute_max_length(lengths)
    max_len = @max_length || lengths.max
    if @pad_to_multiple_of
      max_len = ((max_len + @pad_to_multiple_of - 1) / @pad_to_multiple_of) * @pad_to_multiple_of
    end
    max_len
  end

  # The tokenizer's pad token id, or 0 when the tokenizer does not expose one
  # or reports nil.
  def pad_token_id
    return 0 unless @tokenizer.respond_to?(:pad_token_id)

    @tokenizer.pad_token_id || 0
  end
end
|
|
106
|
+
|
|
107
|
+
# Collates feature hashes into a batch hash without any padding. The keys come
# from the first feature, and the type of the first feature's value decides
# how each column is converted.
class DefaultDataCollator
  # @param features [Array<Hash>] one hash per example
  # @return [Hash] same keys as the first feature; {} for an empty list
  def call(features)
    return {} if features.empty?

    features.first.keys.each_with_object({}) do |key, batch|
      column = features.map { |feature| feature[key] }

      batch[key] =
        case column.first
        when Torch::Tensor then Torch.stack(column)
        when Integer then Torch.tensor(column, dtype: :int64)
        when Float then Torch.tensor(column, dtype: :float32)
        when Array then Torch.tensor(column)
        else column # unrecognised types are passed through as a Ruby array
        end
    end
  end
end
|
|
131
|
+
end
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
|
|
4
|
+
# Thin Enumerable wrapper around a collection of examples. Indexing, size and
# iteration are all delegated to the wrapped object.
class Dataset
  include Enumerable

  attr_reader :data

  # @param data [#[], #size, #each] the wrapped collection (stored as given, not copied)
  def initialize(data)
    @data = data
  end

  # Iterates the wrapped collection; the result is whatever its #each returns.
  def each(&block)
    @data.each(&block)
  end

  # Number of examples, as reported by the wrapped collection.
  def size
    @data.size
  end
  alias length size

  # Looks up an example by delegating to the wrapped collection's #[].
  def [](index)
    @data[index]
  end
end
|
|
26
|
+
end
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module Trainers
|
|
4
|
+
# Settings for a LoRA adapter. Every option has a default in DEFAULTS and can
# be overridden by keyword when constructing the config. Keywords that are not
# in DEFAULTS are ignored.
class LoraConfig
  attr_accessor :r, :lora_alpha, :lora_dropout, :target_modules,
                :bias, :task_type

  DEFAULTS = {
    r: 8,
    lora_alpha: 16,
    lora_dropout: 0.0,
    target_modules: ["query", "value"], # or :all_linear
    bias: :none, # :none, :all, :lora_only
    task_type: :sequence_classification
  }.freeze

  # @param kwargs [Hash] overrides for any key in DEFAULTS; a caller-supplied
  #   value is stored as given (including nil)
  def initialize(**kwargs)
    DEFAULTS.each do |key, default|
      # Hand each instance its own copy of a default. Without the dup, every
      # config shared the single target_modules Array stored in DEFAULTS, so
      # mutating one config's list changed the defaults for all later configs.
      value = kwargs.fetch(key) { default.dup }
      instance_variable_set(:"@#{key}", value)
    end
  end

  # @return [Float] lora_alpha divided by r
  def scaling
    @lora_alpha.to_f / @r
  end

  # @return [Hash] the current value of every option, keyed as in DEFAULTS
  def to_h
    DEFAULTS.keys.to_h { |key| [key, send(key)] }
  end
end
|
|
34
|
+
end
|