trainers-rb 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,340 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Trainers
4
+ class Trainer
5
+ attr_reader :model, :args, :train_dataset, :eval_dataset, :tokenizer,
6
+ :data_collator, :optimizer, :lr_scheduler, :state, :control
7
+
8
# Builds a trainer around +model+.
#
# @param model the model to train (stored as-is)
# @param args training configuration; a fresh TrainingArguments is used when nil
# @param train_dataset dataset used by #train
# @param eval_dataset default dataset used by #evaluate
# @param tokenizer tokenizer handed to SaveUtils when the model is saved
# @param data_collator callable that turns a list of features into a batch;
#   a DefaultDataCollator is used when nil
# @param compute_metrics optional callable invoked with an EvalPrediction
# @param callbacks [Array] extra callbacks, appended after the built-in PrinterCallback
def initialize(model:, args: nil, train_dataset: nil, eval_dataset: nil,
               tokenizer: nil, data_collator: nil, compute_metrics: nil, callbacks: [])
  @model = model
  @train_dataset = train_dataset
  @eval_dataset = eval_dataset
  @tokenizer = tokenizer
  @compute_metrics = compute_metrics

  # Collaborators are created in the same order as before: arguments, collator,
  # state, control, then the callback chain.
  @args = args || TrainingArguments.new
  @data_collator = data_collator || DefaultDataCollator.new
  @state = TrainerState.new
  @control = TrainerControl.new

  # PrinterCallback always comes first so user callbacks run after it.
  @callback_handler = CallbackHandler.new([PrinterCallback.new].concat(callbacks))
end
31
+
32
# Runs the training loop over +@train_dataset+ for +@args.num_train_epochs+ epochs.
#
# Per batch: fires +:on_step_begin+, computes the loss, back-propagates it
# (divided by the accumulation factor when that factor is > 1), increments
# +@state.global_step+, applies an optimizer update every
# +gradient_accumulation_steps+ micro-batches, then handles step-based logging,
# evaluation and checkpointing before firing +:on_step_end+.
#
# Per epoch: appends an epoch-level log entry, runs epoch-based evaluation and
# checkpointing when configured, and fires +:on_epoch_end+.
#
# Notes on accumulation:
# * +@state.global_step+ / +@state.max_steps+ count micro-batches.
# * The LR scheduler is stepped once per optimizer update, so it is created
#   with the number of optimizer updates rather than the number of micro-batches.
# * Gradients still pending when training finishes (an incomplete accumulation
#   window) are applied instead of being silently discarded.
#
# @return [TrainerState] the trainer state after training
# @raise [ArgumentError] when no train_dataset was supplied
def train
  raise ArgumentError, "No train_dataset provided" unless @train_dataset

  device = @args.resolved_device
  @model.to(device)
  @model.train

  batch_size = @args.per_device_train_batch_size
  # nil/0 would make the modulo below raise; treat them as "no accumulation".
  accumulation = [@args.gradient_accumulation_steps.to_i, 1].max
  steps_per_epoch = (@train_dataset.size.to_f / batch_size).ceil
  total_steps = steps_per_epoch * @args.num_train_epochs

  @state.max_steps = total_steps
  @state.num_train_epochs = @args.num_train_epochs

  @optimizer = create_optimizer
  @lr_scheduler = create_scheduler((total_steps.to_f / accumulation).ceil)

  @callback_handler.fire(:on_train_begin, @args, @state, @control)

  pending_update = false # true while back-propagated gradients await an optimizer step

  @args.num_train_epochs.times do |epoch|
    @state.epoch = epoch + 1
    @callback_handler.fire(:on_epoch_begin, @args, @state, @control)
    @model.train # evaluation may have been run since the previous epoch

    epoch_loss = 0.0
    epoch_steps = 0

    each_batch(@train_dataset, batch_size, shuffle: true) do |batch|
      @callback_handler.fire(:on_step_begin, @args, @state, @control)

      loss = compute_loss(move_to_device(batch, device))
      (accumulation > 1 ? loss / accumulation : loss).backward
      pending_update = true

      # The running average uses the unscaled loss.
      epoch_loss += loss.item
      epoch_steps += 1
      @state.global_step += 1

      if (@state.global_step % accumulation).zero?
        apply_optimizer_update
        pending_update = false
      end

      record_training_log(epoch_loss / epoch_steps) if should_log?

      # Step-based evaluation (skipped without an eval dataset, like the epoch-based path)
      if @args.eval_strategy == :steps && @eval_dataset && step_interval_reached?(@args.eval_steps)
        metrics = evaluate
        @callback_handler.fire(:on_evaluate, @args, @state, @control, metrics: metrics)
      end

      # Step-based saving
      if @args.save_strategy == :steps && step_interval_reached?(@args.save_steps)
        save_checkpoint
        @callback_handler.fire(:on_save, @args, @state, @control)
      end

      @callback_handler.fire(:on_step_end, @args, @state, @control)
      break if @control.should_training_stop || @control.should_epoch_stop
    end

    # Flush an incomplete accumulation window before the final log/eval/save
    # so the last checkpoint contains every computed gradient.
    final_epoch = epoch + 1 == @args.num_train_epochs
    if pending_update && (final_epoch || @control.should_training_stop)
      apply_optimizer_update
      pending_update = false
    end

    # Epoch-level logging
    record_training_log(epoch_steps > 0 ? epoch_loss / epoch_steps : 0.0)

    # Epoch-based evaluation
    if @args.eval_strategy == :epoch && @eval_dataset
      metrics = evaluate
      @callback_handler.fire(:on_evaluate, @args, @state, @control, metrics: metrics)
    end

    # Epoch-based saving
    if @args.save_strategy == :epoch
      save_checkpoint
      @callback_handler.fire(:on_save, @args, @state, @control)
    end

    @callback_handler.fire(:on_epoch_end, @args, @state, @control)
    @control.should_epoch_stop = false
    break if @control.should_training_stop
  end

  # Covers a stop requested by an +:on_epoch_end+ callback of a non-final epoch.
  apply_optimizer_update if pending_update

  @callback_handler.fire(:on_train_end, @args, @state, @control)
  @state
end

# Clips gradients, then steps the optimizer and the LR scheduler and clears gradients.
private def apply_optimizer_update
  clip_grad_norm!(@model.parameters, @args.max_grad_norm)
  @optimizer.step
  @lr_scheduler.step
  @optimizer.zero_grad
end

# True when +interval+ is a positive number that divides the current global step.
private def step_interval_reached?(interval)
  interval && interval.positive? && (@state.global_step % interval).zero?
end

# Appends a log entry to the state's history and fires +:on_log+ with it.
private def record_training_log(average_loss)
  logs = { loss: average_loss, learning_rate: current_lr, epoch: @state.epoch }
  @state.log_history << logs.merge(step: @state.global_step)
  @callback_handler.fire(:on_log, @args, @state, @control, logs: logs)
end
138
+
139
# Evaluates the model on +eval_dataset+ (or the dataset given at construction).
#
# The model is put in eval mode for the pass and switched back to train mode
# afterwards. For every batch the labels (key +:labels+ or +"labels"+) are
# removed before the forward pass; when present, a cross-entropy loss between
# the logits and the labels is accumulated.
#
# @param eval_dataset optional dataset overriding +@eval_dataset+
# @return [Hash] metrics; contains +:eval_loss+ only when at least one batch
#   carried labels (averaged over those batches), merged with the Hash returned
#   by the +compute_metrics+ callable when one is configured
# @raise [ArgumentError] when no dataset is available
def evaluate(eval_dataset: nil)
  dataset = eval_dataset || @eval_dataset
  raise ArgumentError, "No eval_dataset provided" unless dataset

  device = @args.resolved_device
  # Batches are moved to +device+ below, so the model must live there too even
  # when #evaluate is called without a preceding #train.
  @model.to(device)
  @model.eval

  all_preds = []
  all_labels = []
  total_loss = 0.0
  loss_batches = 0 # only batches that actually contributed a loss

  Torch.no_grad do
    each_batch(dataset, @args.per_device_eval_batch_size) do |batch|
      batch = move_to_device(batch, device)
      labels = batch.delete(:labels) || batch.delete("labels")

      output = forward(batch)
      logits = output.respond_to?(:logits) ? output.logits : output

      if labels
        total_loss += Torch::NN::F.cross_entropy(logits, labels).item
        loss_batches += 1
        all_labels << labels.detach.cpu
      end

      all_preds << logits.detach.cpu
    end
  end

  @model.train

  metrics = {}
  # Averaging over unlabeled batches would report a misleading 0.0 loss.
  metrics[:eval_loss] = total_loss / loss_batches if loss_batches > 0

  if @compute_metrics && all_preds.any? && all_labels.any?
    eval_pred = EvalPrediction.new(predictions: Torch.cat(all_preds), label_ids: Torch.cat(all_labels))
    custom_metrics = @compute_metrics.call(eval_pred)
    metrics.merge!(custom_metrics) if custom_metrics
  end

  metrics
end
186
+
187
# Runs inference over +test_dataset+ and returns the concatenated logits.
#
# The model is switched to eval mode (and left there). Label entries
# (+:labels+ / +"labels"+) are dropped from each batch before the forward pass
# so a labelled dataset can be used for prediction, mirroring #evaluate.
#
# @param test_dataset dataset to predict on
# @return the result of +Torch.cat+ over the per-batch logits (detached, on CPU)
def predict(test_dataset)
  device = @args.resolved_device
  # Batches are moved to +device+, so make sure the model is there as well.
  @model.to(device)
  @model.eval

  all_preds = []
  Torch.no_grad do
    each_batch(test_dataset, @args.per_device_eval_batch_size) do |batch|
      inputs = move_to_device(batch, device).reject { |key, _| key == :labels || key == "labels" }
      output = forward(inputs)
      logits = output.respond_to?(:logits) ? output.logits : output
      all_preds << logits.detach.cpu
    end
  end

  Torch.cat(all_preds)
end
203
+
204
# Saves the model, tokenizer and training arguments through SaveUtils.
#
# @param output_dir [String, nil] target directory; falls back to +@args.output_dir+
# @return whatever +SaveUtils.save_pretrained+ returns
def save_model(output_dir = nil)
  target_dir = output_dir || @args.output_dir
  SaveUtils.save_pretrained(@model, @tokenizer, target_dir, training_args: @args)
end
208
+
209
+ private
210
+
211
# Computes the training loss for one batch.
#
# Labels are looked up under +:labels+ or +"labels"+. The model is first called
# with the labels included (some models compute the loss internally); if that
# call raises an error whose message mentions "Todo" or "not implemented", the
# model is called again without labels and the loss is computed externally.
#
# The caller's +batch+ is left untouched (labels are neither removed nor re-keyed).
#
# @param batch [Hash] collated batch, optionally containing labels
# @return the model's own loss when it provides one, otherwise the cross-entropy
#   between the logits and the labels
# @raise [RuntimeError] when the model returns no loss and the batch has no labels
def compute_loss(batch)
  labels = batch[:labels] || batch["labels"]
  inputs = batch.reject { |key, _| key == :labels || key == "labels" }

  # Try passing labels to the model (some models compute loss internally)
  output = begin
    forward(labels ? inputs.merge(labels: labels) : inputs)
  rescue => e
    # Only the "labels kwarg unsupported" failure (e.g. transformers-rb Todo) is
    # recoverable; anything else is a genuine error and is re-raised unchanged.
    raise unless e.message.include?("Todo") || e.message.include?("not implemented")

    forward(inputs)
  end

  if output.respond_to?(:loss) && output.loss
    output.loss
  elsif labels
    logits = output.respond_to?(:logits) ? output.logits : output
    Torch::NN::F.cross_entropy(logits, labels)
  else
    raise "Model did not return a loss and no labels found in batch. " \
          "Either pass labels in your dataset or use a model that computes loss."
  end
end
240
+
241
# Calls the model with +batch+.
#
# A Hash batch is expanded into keyword arguments; its keys are converted to
# symbols first so batches keyed by strings (e.g. "input_ids") reach the model's
# keyword parameters. Any other batch object is passed as a single positional
# argument.
#
# @param batch [Hash, Object]
# @return the model output
def forward(batch)
  return @model.call(batch) unless batch.is_a?(Hash)

  keyword_inputs = batch.transform_keys { |key| key.respond_to?(:to_sym) ? key.to_sym : key }
  @model.call(**keyword_inputs)
end
249
+
250
# Builds the optimizer for the current model and training arguments by
# delegating to +Optimization.create_optimizer+.
def create_optimizer
  Optimization.create_optimizer(@model, @args)
end
253
+
254
# Builds the learning-rate scheduler for +@optimizer+.
#
# Warmup length: an explicit positive +warmup_steps+ wins; otherwise a positive
# +warmup_ratio+ is applied to +total_steps+ (truncated to an integer);
# otherwise there is no warmup.
#
# @param total_steps [Integer] number of training steps handed to the scheduler
# @return whatever +Optimization.create_scheduler+ returns
def create_scheduler(total_steps)
  warmup = 0
  if @args.warmup_steps > 0
    warmup = @args.warmup_steps
  elsif @args.warmup_ratio > 0
    warmup = (total_steps * @args.warmup_ratio).to_i
  end

  Optimization.create_scheduler(@args.lr_scheduler_type, @optimizer,
                                num_warmup_steps: warmup, num_training_steps: total_steps)
end
270
+
271
# Yields collated batches of +dataset+.
#
# Items are fetched with +dataset[i]+ for each index in +0...dataset.size+,
# grouped into slices of at most +batch_size+ (the last slice may be smaller),
# and each slice is passed through +@data_collator+ before being yielded.
#
# @param dataset object responding to +size+ and +[]+
# @param batch_size [Integer] maximum number of items per batch
# @param shuffle [Boolean] visit the items in random order when true
# @yieldparam batch the collator's output for one slice
# @raise [ArgumentError] when +batch_size+ is not a positive number
def each_batch(dataset, batch_size, shuffle: false)
  unless batch_size.respond_to?(:positive?) && batch_size.positive?
    raise ArgumentError, "batch_size must be positive (got #{batch_size.inspect})"
  end

  indices = (0...dataset.size).to_a
  indices.shuffle! if shuffle

  # each_slice never produces an empty slice, so no emptiness guard is needed.
  indices.each_slice(batch_size) do |batch_indices|
    yield @data_collator.call(batch_indices.map { |index| dataset[index] })
  end
end
284
+
285
# Returns a new Hash with the same keys as +batch+ in which every
# +Torch::Tensor+ value has been moved to +device+; other values are kept as-is.
#
# @param batch [Hash] key/value pairs of a collated batch
# @param device target device passed to +Tensor#to+
# @return [Hash]
def move_to_device(batch, device)
  batch.to_h do |name, entry|
    moved = entry.is_a?(Torch::Tensor) ? entry.to(device) : entry
    [name, moved]
  end
end
294
+
295
# Scales gradients in place so that their combined L2 norm does not exceed +max_norm+.
#
# Only parameters that currently have a gradient are considered. A +nil+ or
# non-positive +max_norm+ disables clipping: the norm is still computed and
# returned but gradients are left untouched (previously +nil+ raised and +0+
# multiplied every gradient by zero).
#
# @param parameters [Enumerable] model parameters responding to +grad+
# @param max_norm [Numeric, nil] maximum allowed total gradient norm
# @return [Float] the total gradient norm before clipping (0.0 without gradients)
def clip_grad_norm!(parameters, max_norm)
  with_grad = parameters.select(&:grad)
  return 0.0 if with_grad.empty?

  squared_norms = with_grad.map { |param| param.grad.data.norm(2).item**2 }
  total_norm = Math.sqrt(squared_norms.sum(0.0))
  return total_norm unless max_norm && max_norm.positive?

  # The small epsilon keeps the division finite when all gradients are zero.
  clip_coef = max_norm / (total_norm + 1e-6)
  with_grad.each { |param| param.grad.data.mul!(clip_coef) } if clip_coef < 1.0

  total_norm
end
312
+
313
# Learning rate currently stored under +:lr+ in the optimizer's first parameter group.
def current_lr
  first_group = @optimizer.param_groups.first
  first_group[:lr]
end
316
+
317
# Whether a log entry should be emitted for the current global step.
#
# True on the very first step when +logging_first_step+ is set, and on every
# step that is a multiple of +logging_steps+. A +nil+ or non-positive
# +logging_steps+ disables interval logging instead of raising.
#
# @return [Boolean]
def should_log?
  step = @state.global_step
  return true if @args.logging_first_step && step == 1

  interval = @args.logging_steps
  return false unless interval&.positive?

  (step % interval).zero?
end
321
+
322
# Saves the model into "<output_dir>/checkpoint-<global_step>" and, when a
# +save_total_limit+ is configured, prunes older checkpoints afterwards.
def save_checkpoint
  checkpoint_dir = File.join(@args.output_dir, "checkpoint-#{@state.global_step}")
  save_model(checkpoint_dir)
  return unless @args.save_total_limit

  cleanup_checkpoints
end
327
+
328
# Deletes the oldest step checkpoints so that at most +save_total_limit+ remain.
#
# Only directories directly inside +output_dir+ whose name is exactly
# "checkpoint-<digits>" are considered; they are ordered by that step number.
# Matching on the entry name (not the full path) keeps a "checkpoint-N" segment
# in +output_dir+ itself from corrupting the ordering, and entries such as
# "checkpoint-best" are never treated as step 0 and deleted first.
#
# Does nothing when no limit is configured.
def cleanup_checkpoints
  limit = @args.save_total_limit
  return unless limit

  output_dir = @args.output_dir
  # `base:` keeps glob metacharacters in output_dir from being interpreted as a pattern.
  numbered = Dir.glob("checkpoint-*", base: output_dir).filter_map do |name|
    step = name[/\Acheckpoint-(\d+)\z/, 1]
    path = File.join(output_dir, name)
    [step.to_i, path] if step && File.directory?(path)
  end

  excess = numbered.length - limit
  return unless excess.positive?

  oldest_first = numbered.sort_by(&:first).map(&:last)
  FileUtils.rm_rf(oldest_first.first(excess))
end
339
+ end
340
+ end
@@ -0,0 +1,30 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Trainers
4
# Named values for an evaluation schedule: never, once per epoch, or every N steps.
module EvalStrategy
  NO, EPOCH, STEPS = :no, :epoch, :steps
end
9
+
10
# Named values for a checkpoint schedule: never, once per epoch, or every N steps.
module SaveStrategy
  NO, EPOCH, STEPS = :no, :epoch, :steps
end
15
+
16
# Named values for the learning-rate scheduler kinds.
module SchedulerType
  LINEAR, COSINE, CONSTANT = :linear, :cosine, :constant
end
21
+
22
# Read-only pair of model predictions and the matching reference labels.
class EvalPrediction
  # Predictions as passed to the constructor
  attr_reader :predictions
  # Reference labels as passed to the constructor
  attr_reader :label_ids

  # @param predictions model outputs
  # @param label_ids reference labels
  def initialize(predictions:, label_ids:)
    @predictions, @label_ids = predictions, label_ids
  end
end
30
+ end
@@ -0,0 +1,64 @@
1
+ # frozen_string_literal: true
2
+
3
+ module Trainers
4
# Configuration for a training run.
#
# Every key of DEFAULTS becomes a read/write attribute. Unknown keyword
# arguments are rejected so typos do not go unnoticed.
class TrainingArguments
  DEFAULTS = {
    output_dir: "./output",
    num_train_epochs: 3,
    per_device_train_batch_size: 8,
    per_device_eval_batch_size: 8,
    learning_rate: 5e-5,
    weight_decay: 0.0,
    adam_beta1: 0.9,
    adam_beta2: 0.999,
    adam_epsilon: 1e-8,
    max_grad_norm: 1.0,
    gradient_accumulation_steps: 1,
    warmup_steps: 0,
    warmup_ratio: 0.0,
    lr_scheduler_type: :linear,
    eval_strategy: :no,
    eval_steps: nil,
    save_strategy: :epoch,
    save_steps: 500,
    save_total_limit: nil,
    logging_steps: 500,
    logging_first_step: false,
    seed: 42,
    device: nil,
    no_mps: false,
    dataloader_drop_last: false,
    # Frozen so the shared default can never be mutated through DEFAULTS.
    label_names: ["labels"].freeze
  }.freeze

  # Options whose values are symbols; String input is converted on construction.
  SYMBOL_OPTIONS = %i[lr_scheduler_type eval_strategy save_strategy].freeze
  private_constant :SYMBOL_OPTIONS

  attr_accessor(*DEFAULTS.keys)

  # @param kwargs any subset of the DEFAULTS keys; missing keys take their default
  # @raise [ArgumentError] when an unknown key is given
  def initialize(**kwargs)
    unknown = kwargs.keys - DEFAULTS.keys
    raise ArgumentError, "Unknown arguments: #{unknown.join(', ')}" unless unknown.empty?

    DEFAULTS.each do |key, default|
      # Each instance gets its own copy of a default, so mutating e.g.
      # +label_names+ on one instance cannot leak into other instances.
      value = kwargs.fetch(key) { default.dup }
      # "epoch" and :epoch should behave the same: strategy values are compared against symbols.
      value = value.to_sym if value.is_a?(String) && SYMBOL_OPTIONS.include?(key)
      instance_variable_set(:"@#{key}", value)
    end
  end

  # Device to train on: the explicitly configured +device+ when set, otherwise
  # MPS when it is available and not disabled via +no_mps+, otherwise CPU.
  def resolved_device
    return @device if @device

    if !@no_mps && defined?(Torch::Backends::MPS) && Torch::Backends::MPS.available?
      Torch.device("mps")
    else
      Torch.device("cpu")
    end
  end

  # @return [Hash] current value of every option, keyed like DEFAULTS
  def to_h
    DEFAULTS.each_key.to_h { |key| [key, public_send(key)] }
  end
end
64
+ end
@@ -0,0 +1,5 @@
1
+ # frozen_string_literal: true
2
+
3
module Trainers
  # Version of the gem (matches the gemspec version).
  VERSION = "0.1.0".freeze
end
@@ -0,0 +1 @@
1
+ require_relative "trainers"
data/lib/trainers.rb ADDED
@@ -0,0 +1,43 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "torch"
4
+ require "json"
5
+ require "fileutils"
6
+
7
+ require_relative "trainers/version"
8
+ require_relative "trainers/trainer_utils"
9
+ require_relative "trainers/training_arguments"
10
+ require_relative "trainers/data/dataset"
11
+ require_relative "trainers/data/data_collator"
12
+ require_relative "trainers/optimization/optimizer"
13
+ require_relative "trainers/optimization/scheduler"
14
+ require_relative "trainers/callbacks"
15
+ require_relative "trainers/save_utils"
16
+ require_relative "trainers/lora/lora_config"
17
+ require_relative "trainers/lora/lora_linear"
18
+ require_relative "trainers/lora/lora_utils"
19
+ require_relative "trainers/lora/lora_model"
20
+ require_relative "trainers/trainer"
21
+
22
module Trainers
  # Convenience loader: fetches a pretrained model and its tokenizer.
  #
  # transformers-rb is required lazily, on the first call.
  #
  # @param model_name name handed to the +from_pretrained+ loaders
  # @param task [Symbol] +:sequence_classification+, +:token_classification+ or
  #   +:question_answering+; any other value selects +Transformers::AutoModel+
  # @param num_labels [Integer] forwarded to the model's +from_pretrained+
  # @return [Array] two elements: the model and the tokenizer
  def self.from_pretrained(model_name, task: :sequence_classification, num_labels: 2)
    require "transformers-rb"

    auto_class =
      case task
      when :sequence_classification then Transformers::AutoModelForSequenceClassification
      when :token_classification then Transformers::AutoModelForTokenClassification
      when :question_answering then Transformers::AutoModelForQuestionAnswering
      else Transformers::AutoModel
      end

    [
      auto_class.from_pretrained(model_name, num_labels: num_labels),
      Transformers::AutoTokenizer.from_pretrained(model_name)
    ]
  end
end
+ end
metadata ADDED
@@ -0,0 +1,149 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: trainers-rb
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Vishwajeetsingh Desurkar
8
+ bindir: bin
9
+ cert_chain: []
10
+ date: 1980-01-02 00:00:00.000000000 Z
11
+ dependencies:
12
+ - !ruby/object:Gem::Dependency
13
+ name: torch-rb
14
+ requirement: !ruby/object:Gem::Requirement
15
+ requirements:
16
+ - - ">="
17
+ - !ruby/object:Gem::Version
18
+ version: 0.17.1
19
+ type: :runtime
20
+ prerelease: false
21
+ version_requirements: !ruby/object:Gem::Requirement
22
+ requirements:
23
+ - - ">="
24
+ - !ruby/object:Gem::Version
25
+ version: 0.17.1
26
+ - !ruby/object:Gem::Dependency
27
+ name: transformers-rb
28
+ requirement: !ruby/object:Gem::Requirement
29
+ requirements:
30
+ - - ">="
31
+ - !ruby/object:Gem::Version
32
+ version: 0.2.0
33
+ type: :runtime
34
+ prerelease: false
35
+ version_requirements: !ruby/object:Gem::Requirement
36
+ requirements:
37
+ - - ">="
38
+ - !ruby/object:Gem::Version
39
+ version: 0.2.0
40
+ - !ruby/object:Gem::Dependency
41
+ name: safetensors
42
+ requirement: !ruby/object:Gem::Requirement
43
+ requirements:
44
+ - - ">="
45
+ - !ruby/object:Gem::Version
46
+ version: 0.1.1
47
+ type: :runtime
48
+ prerelease: false
49
+ version_requirements: !ruby/object:Gem::Requirement
50
+ requirements:
51
+ - - ">="
52
+ - !ruby/object:Gem::Version
53
+ version: 0.1.1
54
+ - !ruby/object:Gem::Dependency
55
+ name: tokenizers
56
+ requirement: !ruby/object:Gem::Requirement
57
+ requirements:
58
+ - - ">="
59
+ - !ruby/object:Gem::Version
60
+ version: 0.5.3
61
+ type: :runtime
62
+ prerelease: false
63
+ version_requirements: !ruby/object:Gem::Requirement
64
+ requirements:
65
+ - - ">="
66
+ - !ruby/object:Gem::Version
67
+ version: 0.5.3
68
+ - !ruby/object:Gem::Dependency
69
+ name: rake
70
+ requirement: !ruby/object:Gem::Requirement
71
+ requirements:
72
+ - - "~>"
73
+ - !ruby/object:Gem::Version
74
+ version: '13.0'
75
+ type: :development
76
+ prerelease: false
77
+ version_requirements: !ruby/object:Gem::Requirement
78
+ requirements:
79
+ - - "~>"
80
+ - !ruby/object:Gem::Version
81
+ version: '13.0'
82
+ - !ruby/object:Gem::Dependency
83
+ name: minitest
84
+ requirement: !ruby/object:Gem::Requirement
85
+ requirements:
86
+ - - "~>"
87
+ - !ruby/object:Gem::Version
88
+ version: '5.0'
89
+ type: :development
90
+ prerelease: false
91
+ version_requirements: !ruby/object:Gem::Requirement
92
+ requirements:
93
+ - - "~>"
94
+ - !ruby/object:Gem::Version
95
+ version: '5.0'
96
+ description: Training loop, LoRA, and optimization utilities for fine-tuning HuggingFace
97
+ transformer models using torch-rb and transformers-rb. Supports full fine-tuning,
98
+ LoRA adapters, learning rate scheduling, callbacks, and model serialization via
99
+ safetensors.
100
+ email:
101
+ - selectus2@users.noreply.rubygems.org
102
+ executables: []
103
+ extensions: []
104
+ extra_rdoc_files: []
105
+ files:
106
+ - CHANGELOG.md
107
+ - LICENSE.txt
108
+ - README.md
109
+ - lib/trainers-rb.rb
110
+ - lib/trainers.rb
111
+ - lib/trainers/callbacks.rb
112
+ - lib/trainers/data/data_collator.rb
113
+ - lib/trainers/data/dataset.rb
114
+ - lib/trainers/lora/lora_config.rb
115
+ - lib/trainers/lora/lora_linear.rb
116
+ - lib/trainers/lora/lora_model.rb
117
+ - lib/trainers/lora/lora_utils.rb
118
+ - lib/trainers/optimization/optimizer.rb
119
+ - lib/trainers/optimization/scheduler.rb
120
+ - lib/trainers/save_utils.rb
121
+ - lib/trainers/trainer.rb
122
+ - lib/trainers/trainer_utils.rb
123
+ - lib/trainers/training_arguments.rb
124
+ - lib/trainers/version.rb
125
+ homepage: https://github.com/trainers-rb/trainers-rb
126
+ licenses:
127
+ - MIT
128
+ metadata:
129
+ homepage_uri: https://github.com/trainers-rb/trainers-rb
130
+ source_code_uri: https://github.com/trainers-rb/trainers-rb
131
+ changelog_uri: https://github.com/trainers-rb/trainers-rb/blob/main/CHANGELOG.md
132
+ rdoc_options: []
133
+ require_paths:
134
+ - lib
135
+ required_ruby_version: !ruby/object:Gem::Requirement
136
+ requirements:
137
+ - - ">="
138
+ - !ruby/object:Gem::Version
139
+ version: 3.1.0
140
+ required_rubygems_version: !ruby/object:Gem::Requirement
141
+ requirements:
142
+ - - ">="
143
+ - !ruby/object:Gem::Version
144
+ version: '0'
145
+ requirements: []
146
+ rubygems_version: 4.0.5
147
+ specification_version: 4
148
+ summary: Fine-tune transformer models in Ruby
149
+ test_files: []