mlx 0.30.7 → 0.30.7.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/native.cpp +0 -4
- data/lib/mlx/core.rb +8 -1
- data/lib/mlx/distributed_utils/launch.rb +9 -3
- data/lib/mlx/dsl/builder.rb +377 -0
- data/lib/mlx/dsl/data_pipeline.rb +284 -0
- data/lib/mlx/dsl/experiment.rb +154 -0
- data/lib/mlx/dsl/graph_modules.rb +91 -0
- data/lib/mlx/dsl/model.rb +9 -0
- data/lib/mlx/dsl/model_mixin.rb +706 -0
- data/lib/mlx/dsl/split_plan.rb +85 -0
- data/lib/mlx/dsl/train_step.rb +197 -0
- data/lib/mlx/dsl/trainer.rb +2110 -0
- data/lib/mlx/dsl.rb +16 -0
- data/lib/mlx/nn/layers/containers.rb +21 -4
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx.rb +1 -0
- data/mlx/CMakeLists.txt +4 -16
- metadata +12 -2
|
@@ -0,0 +1,2110 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "json"
|
|
4
|
+
require "fileutils"
|
|
5
|
+
require "time"
|
|
6
|
+
|
|
7
|
+
module MLX
|
|
8
|
+
module DSL
|
|
9
|
+
class Trainer
|
|
10
|
+
# Lifecycle events a hook may subscribe to, listed in the order they fire
# during a single call to #fit.
HOOK_EVENTS = %i[
  before_fit
  before_epoch
  after_batch
  before_validation
  after_validation_batch
  after_validation
  after_epoch
  checkpoint
  after_fit
].freeze
# Sentinel distinguishing "caller omitted this option" from an explicit nil.
UNSET = Object.new.freeze
# Baseline values used when neither the caller, the trainer's fit defaults,
# a preset, nor a task preset supplies a given fit option.
FIT_OPTION_DEFAULTS = {
  epochs: 1,
  limit: nil,
  collate: nil,
  bind: nil,
  report: false,
  reduce: :mean,
  monitor: :epoch_loss,
  metric: nil,
  validation_data: nil,
  validation_limit: nil,
  validation_reduce: nil,
  validation_collate: nil,
  validation_bind: nil,
  train_transform: nil,
  validation_transform: nil,
  checkpoint_path: nil,
  save_best: false,
  monitor_mode: :min,
  patience: nil,
  min_delta: 0.0,
  keep_losses: true,
  strict_data_reuse: false,
  resume_from: nil,
  metadata: {}
}.freeze
|
|
48
|
+
|
|
49
|
+
# Builds a trainer around a model/optimizer pair plus the mandatory loss
# block, which defines how a batch is turned into a loss inside the train
# step built by __dsl_build_train_step.
#
# Raises ArgumentError when no loss block is supplied.
def initialize(model:, optimizer:, clip_grad_norm: nil, compile: false, sync: :none, &loss_block)
  raise ArgumentError, "trainer requires a loss block" unless block_given?

  # Constructor arguments are remembered so the trainer can be cloned later
  # (see __dsl_clone_trainer / with_fit_defaults).
  @__dsl_init_options = {
    model: model,
    optimizer: optimizer,
    clip_grad_norm: clip_grad_norm,
    compile: compile,
    sync: sync
  }
  @model = model
  @loss_block = loss_block
  @sync_mode = __dsl_normalize_sync(sync)
  @step = __dsl_build_train_step(
    model,
    optimizer: optimizer,
    clip_grad_norm: clip_grad_norm,
    compile: compile,
    &loss_block
  )
  @optimizer = optimizer
  # Per-event hook entries; the default block gives each event its own array.
  @hooks = Hash.new { |h, k| h[k] = [] }
  @hook_order = 0
  @collate_registry = {}
  @fit_defaults = {}
  @fit_presets = {}
  @batch_schemas = { train: nil, validation: nil }
  @dataflow_profiles = {}
  @hook_packs = {}
  @metric_registry = {}
  @task_presets = __dsl_builtin_task_presets
  @artifact_policy_config = {
    checkpoint: {},
    retention: {},
    resume: nil,
    run_bundle: {}
  }
  @checkpoint_history = []
  @last_checkpoint_snapshot = nil
end
|
|
89
|
+
|
|
90
|
+
# Registers +block+ as a hook for +event+.
#
# priority: higher priority hooks run earlier (integer, default 0).
# every:    when set, the hook only fires every N invocations (positive int).
# once:     when true, the hook fires at most one time.
# Extra keywords: :if / :condition supply a callable gate; :if wins when
# both are given. Any other keyword raises ArgumentError.
#
# Returns self so registrations can be chained.
def on(event, priority: 0, every: nil, once: false, **kwargs, &block)
  raise ArgumentError, "hook registration requires a block" unless block_given?

  # Accept both :if and its alias :condition.
  condition = kwargs.delete(:if)
  condition = kwargs.delete(:condition) if condition.nil? && kwargs.key?(:condition)
  unless kwargs.empty?
    unsupported = kwargs.keys.map(&:inspect).join(', ')
    raise ArgumentError, "unsupported hook option(s): #{unsupported}"
  end

  every_value = nil
  unless every.nil?
    every_value = every.to_i
    raise ArgumentError, "hook :every must be a positive integer" if every_value <= 0
  end

  entry = {
    hook: block,
    priority: priority.to_i,
    every: every_value,
    once: !!once,
    if: condition,
    fired: false,
    invocations: 0,
    order: @hook_order # registration order breaks priority ties deterministically
  }
  @hooks[event.to_sym] << entry
  @hook_order += 1
  self
end
|
|
116
|
+
|
|
117
|
+
# Convenience DSL: trainer.after_epoch { ... } is sugar for
# trainer.on(:after_epoch) { ... }, for every event in HOOK_EVENTS.
HOOK_EVENTS.each do |event|
  define_method(event) do |**options, &block|
    on(event, **options, &block)
  end
end
|
|
122
|
+
|
|
123
|
+
# Runs the full training loop over +dataset+.
#
# All keyword arguments default to the UNSET sentinel so that explicit nils
# can be told apart from omissions when merging with fit defaults, presets,
# and SplitPlan-provided options (__dsl_resolve_fit_options).
#
# Returns the collected per-batch losses, or the full report hash when the
# resolved :report option is truthy (see #fit_report).
def fit(
  dataset,
  epochs: UNSET,
  limit: UNSET,
  collate: UNSET,
  bind: UNSET,
  report: UNSET,
  reduce: UNSET,
  monitor: UNSET,
  metric: UNSET,
  validation_data: UNSET,
  validation_limit: UNSET,
  validation_reduce: UNSET,
  validation_collate: UNSET,
  validation_bind: UNSET,
  train_transform: UNSET,
  validation_transform: UNSET,
  checkpoint_path: UNSET,
  save_best: UNSET,
  monitor_mode: UNSET,
  patience: UNSET,
  min_delta: UNSET,
  keep_losses: UNSET,
  strict_data_reuse: UNSET,
  resume_from: UNSET,
  metadata: UNSET
)
  # Capture the raw (possibly UNSET) keywords so option resolution can layer
  # defaults underneath only the omitted ones.
  raw_options = {
    epochs: epochs,
    limit: limit,
    collate: collate,
    bind: bind,
    report: report,
    reduce: reduce,
    monitor: monitor,
    metric: metric,
    validation_data: validation_data,
    validation_limit: validation_limit,
    validation_reduce: validation_reduce,
    validation_collate: validation_collate,
    validation_bind: validation_bind,
    train_transform: train_transform,
    validation_transform: validation_transform,
    checkpoint_path: checkpoint_path,
    save_best: save_best,
    monitor_mode: monitor_mode,
    patience: patience,
    min_delta: min_delta,
    keep_losses: keep_losses,
    strict_data_reuse: strict_data_reuse,
    resume_from: resume_from,
    metadata: metadata
  }
  # A SplitPlan dataset expands into a train dataset plus fit options.
  dataset, raw_options = __dsl_expand_split_plan(dataset, raw_options)
  options = __dsl_resolve_fit_options(raw_options)
  epochs = options.fetch(:epochs)
  limit = options.fetch(:limit)
  collate = options.fetch(:collate)
  bind = options.fetch(:bind)
  report = options.fetch(:report)
  reduce = options.fetch(:reduce)
  monitor = options.fetch(:monitor)
  metric = options.fetch(:metric)
  validation_data = options.fetch(:validation_data)
  validation_limit = options.fetch(:validation_limit)
  validation_reduce = options.fetch(:validation_reduce)
  validation_collate = options.fetch(:validation_collate)
  validation_bind = options.fetch(:validation_bind)
  train_transform = options.fetch(:train_transform)
  validation_transform = options.fetch(:validation_transform)
  checkpoint_path = options.fetch(:checkpoint_path)
  save_best = options.fetch(:save_best)
  monitor_mode = options.fetch(:monitor_mode)
  patience = options.fetch(:patience)
  min_delta = options.fetch(:min_delta)
  keep_losses = options.fetch(:keep_losses)
  strict_data_reuse = options.fetch(:strict_data_reuse)
  resume_from = options.fetch(:resume_from)
  metadata = options.fetch(:metadata)
  # The trainer-level artifact policy may override checkpoint/resume options.
  policy = __dsl_resolve_artifact_policy(
    checkpoint_path: checkpoint_path,
    save_best: save_best,
    resume_from: resume_from,
    monitor_mode: monitor_mode
  )
  checkpoint_path = policy.fetch(:checkpoint_path)
  save_best = policy.fetch(:save_best)
  resume_from = policy.fetch(:resume_from)
  checkpoint_every = policy.fetch(:checkpoint_every)
  retention_keep_last_n = policy.fetch(:keep_last_n)
  run_bundle_policy = policy.fetch(:run_bundle)
  policy_payload = policy.fetch(:payload)

  keep_losses = !!keep_losses
  strict_data_reuse = !!strict_data_reuse
  @last_checkpoint_snapshot = nil
  losses = []
  epoch_rows = []
  best_metric = nil
  stale_epochs = 0
  stopped_early = false
  previous_train_batches = nil
  previous_validation_batches = nil
  monitor_name = monitor.to_s
  # Resuming restores epoch counter, best metric, and early-stop state.
  resume_state = __dsl_resume_state(resume_from, monitor_name)
  start_epoch = resume_state.fetch(:start_epoch)
  best_metric = resume_state.fetch(:best_metric)
  stale_epochs = resume_state.fetch(:stale_epochs)
  monitor_name = resume_state.fetch(:monitor_name)
  total_epochs = epochs.to_i
  patience_value = __dsl_normalize_patience(patience)
  min_delta_value = __dsl_normalize_min_delta(min_delta)
  train_dataset_size = __dsl_dataset_size(dataset)
  validation_dataset_size = __dsl_dataset_size(validation_data)
  # Validation falls back to the training reducer when none is given.
  validation_reducer = validation_reduce.nil? ? reduce : validation_reduce

  emit(
    :before_fit,
    {
      model: @model,
      optimizer: @optimizer,
      epochs: total_epochs,
      start_epoch: start_epoch,
      resumed_from_epoch: resume_state.fetch(:checkpoint_epoch),
      resume_from: resume_state.fetch(:path),
      best_metric: best_metric,
      stale_epochs: stale_epochs,
      dataset_size: train_dataset_size,
      validation_size: validation_dataset_size
    }
  )

  (start_epoch...total_epochs).each do |epoch|
    emit(:before_epoch, { epoch: epoch, model: @model })
    index = 0
    epoch_losses = []
    epoch_last_loss = nil
    train_limit = __dsl_resolve_loop_limit(limit, epoch: epoch, kind: :train)
    # --- training batches ---
    __dsl_dataset_for_epoch(dataset, epoch: epoch, kind: :train).each do |batch|
      break if !train_limit.nil? && index >= train_limit

      # Collate first, then apply the train transform.
      batch = __dsl_apply_batch_transform(
        train_transform,
        __dsl_apply_collate(
          __dsl_effective_collate(
            collate,
            bind,
            batch,
            kind: :train
          ),
          batch,
          kind: :train,
          epoch: epoch,
          batch_index: index
        ),
        epoch: epoch,
        batch_index: index,
        kind: :train
      )
      loss = __dsl_run_batch(
        batch,
        epoch: epoch,
        batch_index: index,
        kind: :train
      )
      epoch_last_loss = loss
      losses << loss if keep_losses
      scalar = __dsl_loss_scalar(loss)
      epoch_losses << scalar unless scalar.nil?
      emit(
        :after_batch,
        {
          epoch: epoch,
          batch_index: index,
          loss: loss,
          loss_value: scalar,
          model: @model
        }
      )
      index += 1
    end
    # Optionally reject single-pass iterators that were exhausted last epoch.
    __dsl_validate_data_reuse!(
      strict: strict_data_reuse,
      dataset: dataset,
      kind: :train,
      epoch: epoch,
      previous_batches: previous_train_batches,
      current_batches: index
    )
    previous_train_batches = index

    epoch_metric = __dsl_reduce_values(epoch_losses, reduce)
    validation_losses = []
    validation_batch_count = 0
    val_metric = nil
    # --- validation pass (optional) ---
    unless validation_data.nil?
      validation_epoch_limit = __dsl_resolve_loop_limit(
        validation_limit,
        epoch: epoch,
        kind: :validation
      )
      emit(
        :before_validation,
        {
          epoch: epoch,
          model: @model,
          monitor_name: monitor_name
        }
      )
      __dsl_with_eval_mode do
        __dsl_dataset_for_epoch(validation_data, epoch: epoch, kind: :validation).each do |batch|
          break if !validation_epoch_limit.nil? && validation_batch_count >= validation_epoch_limit

          batch = __dsl_apply_batch_transform(
            validation_transform,
            __dsl_apply_collate(
              __dsl_effective_collate(
                validation_collate,
                validation_bind,
                batch,
                kind: :validation
              ),
              batch,
              kind: :validation,
              epoch: epoch,
              batch_index: validation_batch_count
            ),
            epoch: epoch,
            batch_index: validation_batch_count,
            kind: :validation
          )
          loss = __dsl_run_validation_batch(
            batch,
            epoch: epoch,
            batch_index: validation_batch_count,
            kind: :validation
          )
          scalar = __dsl_loss_scalar(loss)
          validation_losses << scalar unless scalar.nil?
          emit(
            :after_validation_batch,
            {
              epoch: epoch,
              batch_index: validation_batch_count,
              loss: loss,
              loss_value: scalar,
              model: @model
            }
          )
          validation_batch_count += 1
        end
      end
      __dsl_validate_data_reuse!(
        strict: strict_data_reuse,
        dataset: validation_data,
        kind: :validation,
        epoch: epoch,
        previous_batches: previous_validation_batches,
        current_batches: validation_batch_count
      )
      previous_validation_batches = validation_batch_count
      val_metric = __dsl_reduce_values(validation_losses, validation_reducer)
      emit(
        :after_validation,
        {
          epoch: epoch,
          model: @model,
          val_loss: val_metric,
          validation_batches: validation_batch_count
        }
      )
    end

    # --- monitoring / early stopping ---
    monitor_value = __dsl_monitor_value(
      metric,
      {
        epoch: epoch,
        epoch_losses: epoch_losses,
        epoch_loss: epoch_metric,
        val_loss: val_metric,
        validation_losses: validation_losses,
        losses: losses,
        batches: index,
        validation_batches: validation_batch_count,
        model: @model,
        optimizer: @optimizer
      },
      fallback: __dsl_default_monitor_value(monitor_name, epoch_metric, val_metric)
    )
    improved = __dsl_improved?(
      monitor_value,
      best_metric,
      monitor_mode,
      min_delta: min_delta_value
    )
    if improved
      best_metric = monitor_value
      stale_epochs = 0
    elsif !best_metric.nil?
      # Only count staleness once a baseline exists.
      stale_epochs += 1
    end

    row = {
      "epoch" => epoch,
      "batches" => index,
      "epoch_loss" => epoch_metric,
      "val_loss" => val_metric,
      "validation_batches" => validation_batch_count,
      "monitor_value" => monitor_value,
      "stale_epochs" => stale_epochs,
      "improved" => improved
    }
    epoch_rows << row

    # --- checkpointing (policy-driven) ---
    checkpoint_saved = __dsl_maybe_checkpoint(
      checkpoint_path,
      save_best: save_best,
      improved: improved,
      epoch: epoch,
      monitor_name: monitor_name,
      monitor_value: monitor_value,
      epoch_metric: epoch_metric,
      stale_epochs: stale_epochs,
      best_metric: best_metric,
      metadata: metadata,
      checkpoint_every: checkpoint_every,
      keep_last_n: retention_keep_last_n
    )

    emit(
      :after_epoch,
      {
        epoch: epoch,
        model: @model,
        epoch_loss: epoch_metric,
        val_loss: val_metric,
        monitor_name: monitor_name,
        monitor_value: monitor_value,
        validation_batches: validation_batch_count,
        stale_epochs: stale_epochs,
        improved: improved,
        best_metric: best_metric,
        checkpoint_saved: checkpoint_saved
      }
    )

    __dsl_sync_epoch(epoch_last_loss) if @sync_mode == :epoch

    if !patience_value.nil? && stale_epochs > patience_value
      stopped_early = true
      break
    end
  end

  # --- final report ---
  payload = {
    "losses" => losses,
    "losses_kept" => keep_losses,
    "epochs" => epoch_rows,
    "monitor_name" => monitor_name,
    "epochs_target" => total_epochs,
    "epochs_completed" => [start_epoch + epoch_rows.length, total_epochs].min,
    "epochs_ran" => epoch_rows.length,
    "stopped_early" => stopped_early,
    "best_metric" => best_metric,
    "resume_from" => resume_state.fetch(:path),
    "resumed_from_epoch" => resume_state.fetch(:checkpoint_epoch),
    "start_epoch" => start_epoch,
    "artifact_policy" => policy_payload
  }
  auto_bundle_path = __dsl_auto_save_run_bundle(
    run_bundle_policy,
    payload
  )
  payload["run_bundle_path"] = auto_bundle_path unless auto_bundle_path.nil?
  emit(
    :after_fit,
    {
      model: @model,
      optimizer: @optimizer,
      epochs: epoch_rows.length,
      best_metric: best_metric,
      stopped_early: stopped_early,
      resume_from: resume_state.fetch(:path),
      resumed_from_epoch: resume_state.fetch(:checkpoint_epoch),
      report: payload
    }
  )

  return payload if report

  losses
end
|
|
515
|
+
|
|
516
|
+
# Runs #fit but always returns the full report hash instead of the loss list.
def fit_report(dataset, **kwargs)
  fit(dataset, **kwargs.merge(report: true))
end
|
|
519
|
+
|
|
520
|
+
# Returns a cloned trainer whose fit defaults are the current defaults
# merged with +defaults+ (normalized). The receiver is left untouched.
def with_fit_defaults(**defaults)
  clone = __dsl_clone_trainer
  merged = clone.instance_variable_get(:@fit_defaults)
                .merge(__dsl_normalize_fit_option_keys(defaults))
  clone.instance_variable_set(:@fit_defaults, merged)
  clone
end
|
|
528
|
+
|
|
529
|
+
# Stores a named bundle of fit options for later use via #fit_with.
# Returns self for chaining.
def register_fit_preset(name, **defaults)
  normalized = __dsl_normalize_fit_option_keys(defaults)
  @fit_presets[name.to_sym] = normalized
  self
end
|
|
533
|
+
|
|
534
|
+
# Runs #fit with the named preset's options, +overrides+ taking precedence.
def fit_with(name, dataset, **overrides)
  resolved = __dsl_merge_fit_preset(name, overrides)
  fit(dataset, **resolved)
end
|
|
537
|
+
|
|
538
|
+
# Runs #fit_report with the named preset's options, +overrides+ winning.
def fit_report_with(name, dataset, **overrides)
  resolved = __dsl_merge_fit_preset(name, overrides)
  fit_report(dataset, **resolved)
end
|
|
541
|
+
|
|
542
|
+
# Registers a named train/validation dataflow profile. When +extends+ names
# one or more existing profiles, they are composed (in order) underneath the
# new settings, so later entries and the new profile win on conflicts.
# Raises ArgumentError for an unknown parent profile. Returns self.
def register_dataflow(name, train: {}, validation: {}, extends: nil)
  profile = __dsl_normalize_dataflow_profile(train: train, validation: validation)
  unless extends.nil?
    parents = extends.is_a?(Array) ? extends : [extends]
    base = parents.reduce({ train: {}, validation: {} }) do |acc, parent|
      parent_key = parent.to_sym
      unless @dataflow_profiles.key?(parent_key)
        raise ArgumentError, "unknown dataflow profile: #{parent.inspect}"
      end

      __dsl_compose_dataflow_profile(acc, @dataflow_profiles.fetch(parent_key))
    end
    profile = __dsl_compose_dataflow_profile(base, profile)
  end

  @dataflow_profiles[name.to_sym] = profile
  self
end
|
|
561
|
+
|
|
562
|
+
# Resolves a registered dataflow profile into #fit keyword arguments.
# Overrides are split into profile-shaped parts (composed onto the profile)
# and direct fit options (merged last, so they win).
def use_dataflow(name, **overrides)
  base = __dsl_resolve_dataflow_profile(name)
  profile_part, direct_part = __dsl_normalize_dataflow_overrides(overrides)
  composed = __dsl_compose_dataflow_profile(base, profile_part)
  __dsl_fit_kwargs_from_dataflow(composed).merge(__dsl_normalize_fit_option_keys(direct_part))
end
|
|
568
|
+
|
|
569
|
+
# Registers a reusable bundle of hook registrations under +name+.
# Exactly one of a callable argument or a block must be given; passing both
# (or neither, or a non-callable) raises ArgumentError. Returns self.
def register_hook_pack(name, callable = nil, &block)
  if !callable.nil? && block_given?
    raise ArgumentError, "register_hook_pack accepts either a callable argument or block, not both"
  end

  pack = callable.nil? ? block : callable
  raise ArgumentError, "register_hook_pack requires a callable argument or block" unless pack.respond_to?(:call)

  @hook_packs[name.to_sym] = pack
  self
end
|
|
582
|
+
|
|
583
|
+
# Applies a previously registered hook pack, forwarding +options+ to it.
# Raises ArgumentError when no pack is registered under +name+. Returns self.
def use_hook_pack(name, **options)
  key = name.to_sym
  raise ArgumentError, "unknown hook pack: #{name.inspect}" unless @hook_packs.key?(key)

  __dsl_call_hook_pack(@hook_packs.fetch(key), options)
  self
end
|
|
592
|
+
|
|
593
|
+
# Registers a named metric callable for use as a fit :metric.
# Exactly one of a callable argument or a block must be given; passing both
# (or neither, or a non-callable) raises ArgumentError. Returns self.
def register_metric(name, callable = nil, &block)
  if !callable.nil? && block_given?
    raise ArgumentError, "register_metric accepts either a callable argument or block, not both"
  end

  metric = callable.nil? ? block : callable
  raise ArgumentError, "register_metric requires a callable argument or block" unless metric.respond_to?(:call)

  @metric_registry[name.to_sym] = metric
  self
end
|
|
606
|
+
|
|
607
|
+
# Registers (or replaces) a named task preset of fit options.
# Returns self for chaining.
def register_task(name, **defaults)
  normalized = __dsl_normalize_fit_option_keys(defaults)
  @task_presets[name.to_sym] = normalized
  self
end
|
|
611
|
+
|
|
612
|
+
# Runs #fit with options resolved from the named task preset; per-call
# +overrides+ take precedence over the preset.
def fit_task(task, dataset, **overrides)
  resolved = __dsl_task_fit_options(task, overrides)
  fit(dataset, **resolved)
end
|
|
615
|
+
|
|
616
|
+
# Runs #fit_report with options resolved from the named task preset;
# per-call +overrides+ take precedence over the preset.
def fit_task_report(task, dataset, **overrides)
  resolved = __dsl_task_fit_options(task, overrides)
  fit_report(dataset, **resolved)
end
|
|
619
|
+
|
|
620
|
+
# Reads or updates the trainer-level artifact policy.
#
# Called with no arguments, returns a deep copy of the current policy.
# Otherwise each provided section is normalized and stored, and self is
# returned. +resume+ defaults to the UNSET sentinel so an explicit nil can
# clear the resume setting.
def artifact_policy(checkpoint: nil, retention: nil, resume: UNSET, run_bundle: nil)
  reading = checkpoint.nil? && retention.nil? && resume.equal?(UNSET) && run_bundle.nil?
  return __dsl_clone_config_value(@artifact_policy_config) if reading

  @artifact_policy_config[:checkpoint] = __dsl_normalize_artifact_checkpoint_policy(checkpoint) unless checkpoint.nil?
  @artifact_policy_config[:retention] = __dsl_normalize_artifact_retention_policy(retention) unless retention.nil?
  @artifact_policy_config[:resume] = resume unless resume.equal?(UNSET)
  @artifact_policy_config[:run_bundle] = __dsl_normalize_artifact_run_bundle_policy(run_bundle) unless run_bundle.nil?
  self
end
|
|
637
|
+
|
|
638
|
+
# Returns a deep copy of the recorded checkpoint events so callers cannot
# mutate the trainer's internal history.
def checkpoint_history
  __dsl_clone_config_value(@checkpoint_history)
end
|
|
641
|
+
|
|
642
|
+
# Reads or writes the batch schema configuration.
#
# Forms:
#   batch_schema                      -> deep copy of current schemas
#   batch_schema(spec)                -> same schema for train and validation
#   batch_schema(key: ..., ...)       -> keyword shorthand for the positional spec
#   batch_schema(train:, validation:) -> split-specific overrides
# Mixing the keyword-shorthand form with the others, or the positional form
# with split overrides, raises ArgumentError. Setters return self.
def batch_schema(spec = UNSET, train: UNSET, validation: UNSET, **schema_kwargs)
  unless schema_kwargs.empty?
    if spec.equal?(UNSET) && train.equal?(UNSET) && validation.equal?(UNSET)
      spec = schema_kwargs
    else
      raise ArgumentError, "batch_schema keyword mappings cannot be combined with positional or split-specific forms"
    end
  end

  split_given = !train.equal?(UNSET) || !validation.equal?(UNSET)
  if !spec.equal?(UNSET) && split_given
    raise ArgumentError, "batch_schema accepts either positional spec or split-specific train:/validation: overrides"
  end
  return __dsl_clone_config_value(@batch_schemas) if spec.equal?(UNSET) && !split_given

  unless spec.equal?(UNSET)
    shared = __dsl_normalize_batch_schema_spec(spec)
    @batch_schemas[:train] = shared
    @batch_schemas[:validation] = shared
    return self
  end

  @batch_schemas[:train] = __dsl_normalize_batch_schema_spec(train) unless train.equal?(UNSET)
  @batch_schemas[:validation] = __dsl_normalize_batch_schema_spec(validation) unless validation.equal?(UNSET)
  self
end
|
|
673
|
+
|
|
674
|
+
# Registers a named collate spec. Accepts either a spec object or a block
# (not both), and optionally composes it on top of one or more previously
# registered schemas via +extends+. Returns self.
def register_collate(name, spec = nil, extends: nil, &block)
  if !spec.nil? && block_given?
    raise ArgumentError, "register_collate accepts either a spec argument or block, not both"
  end

  key = name.to_sym
  resolved = block_given? ? block : spec
  raise ArgumentError, "register_collate requires a collate spec or block" if resolved.nil?

  unless extends.nil?
    parents = extends.is_a?(Array) ? extends : [extends]
    base = nil
    parents.each do |parent|
      parent_key = parent.to_sym
      unless @collate_registry.key?(parent_key)
        raise ArgumentError, "unknown base collate schema: #{parent.inspect}"
      end

      step = @collate_registry.fetch(parent_key)
      base = base.nil? ? step : __dsl_compose_collate(base, step)
    end
    resolved = __dsl_compose_collate(base, resolved)
  end

  @collate_registry[key] = resolved
  self
end
|
|
703
|
+
|
|
704
|
+
# Builds a serializable run-bundle hash combining a #fit_report result, the
# caller's config, and a deep copy of the last checkpoint snapshot.
# Raises ArgumentError when +report+ is not a Hash.
def run_bundle(report:, config: {}, schema_version: "mlx_dsl_run_bundle_v1")
  raise ArgumentError, "run_bundle requires report to be a Hash from fit_report" unless report.is_a?(Hash)

  trainer_summary = {
    "monitor_name" => report["monitor_name"],
    "epochs_target" => report["epochs_target"],
    "epochs_completed" => report["epochs_completed"],
    "resume_from" => report["resume_from"],
    "resumed_from_epoch" => report["resumed_from_epoch"]
  }
  {
    "format" => schema_version.to_s,
    "generated_at" => Time.now.utc.iso8601,
    "trainer" => trainer_summary,
    "config" => config || {},
    "report" => report,
    "checkpoint" => __dsl_deep_copy(@last_checkpoint_snapshot)
  }
end
|
|
724
|
+
|
|
725
|
+
# Serializes #run_bundle as pretty-printed JSON to +path+, creating parent
# directories as needed. Returns the path.
def save_run_bundle(path, report:, config: {}, schema_version: "mlx_dsl_run_bundle_v1")
  bundle = run_bundle(report: report, config: config, schema_version: schema_version)
  parent = File.dirname(path.to_s)
  FileUtils.mkdir_p(parent) unless parent.nil? || parent.empty? || parent == "."
  File.binwrite(path, JSON.pretty_generate(bundle))
  path
end
|
|
736
|
+
|
|
737
|
+
# Extracts a resume payload (the checkpoint metadata) from a run-bundle hash
# or a path to a saved bundle. Raises ArgumentError when the bundle has no
# checkpoint snapshot or its metadata is missing/invalid.
def resume_payload_from_bundle(bundle_or_path)
  bundle = __dsl_resolve_run_bundle(bundle_or_path)
  snapshot = bundle["checkpoint"]
  raise ArgumentError, "run bundle does not include checkpoint snapshot" unless snapshot.is_a?(Hash)

  meta = snapshot["metadata"]
  raise ArgumentError, "run bundle checkpoint metadata is missing or invalid" unless meta.is_a?(Hash)

  { "metadata" => __dsl_deep_copy(meta) }
end
|
|
751
|
+
|
|
752
|
+
private
|
|
753
|
+
|
|
754
|
+
# Builds an independent copy of this trainer: same constructor arguments and
# loss block, with hooks duplicated and all registry/config state deep-copied.
def __dsl_clone_trainer
  cloned = self.class.new(**@__dsl_init_options, &@loss_block)

  # Hook entries are mutable hashes (fired/invocations counters), so each
  # entry is duplicated rather than shared.
  copied_hooks = Hash.new { |h, k| h[k] = [] }
  @hooks.each { |event, entries| copied_hooks[event] = entries.map(&:dup) }
  cloned.instance_variable_set(:@hooks, copied_hooks)
  cloned.instance_variable_set(:@hook_order, @hook_order)

  # All remaining configuration state is deep-copied structurally.
  %i[
    @collate_registry @fit_defaults @fit_presets @batch_schemas
    @dataflow_profiles @hook_packs @metric_registry @task_presets
    @artifact_policy_config @checkpoint_history @last_checkpoint_snapshot
  ].each do |ivar|
    cloned.instance_variable_set(ivar, __dsl_clone_config_value(instance_variable_get(ivar)))
  end
  cloned
end
|
|
775
|
+
|
|
776
|
+
# Recursively copies nested Hash/Array structures. Leaf values (including
# procs, symbols, and strings) are shared by reference, not duplicated.
def __dsl_clone_config_value(value)
  if value.is_a?(Hash)
    copied = {}
    value.each { |key, item| copied[key] = __dsl_clone_config_value(item) }
    copied
  elsif value.is_a?(Array)
    value.map { |item| __dsl_clone_config_value(item) }
  else
    value
  end
end
|
|
788
|
+
|
|
789
|
+
# When +dataset+ is a SplitPlan, unpacks it into a train dataset plus fit
# options. Plan-provided options only fill slots the caller left UNSET, so
# explicit fit arguments always win. Returns [dataset, options].
def __dsl_expand_split_plan(dataset, raw_options)
  is_plan = defined?(MLX::DSL::SplitPlan) && dataset.is_a?(MLX::DSL::SplitPlan)
  return [dataset, raw_options] unless is_plan

  train_dataset, plan_options = dataset.to_fit_inputs
  merged = raw_options.dup
  plan_options.each do |key, value|
    merged[key] = value if merged.key?(key) && merged.fetch(key).equal?(UNSET)
  end
  [train_dataset, merged]
end
|
|
802
|
+
|
|
803
|
+
def __dsl_normalize_artifact_checkpoint_policy(checkpoint)
|
|
804
|
+
unless checkpoint.is_a?(Hash)
|
|
805
|
+
raise ArgumentError, "artifact checkpoint policy must be a Hash"
|
|
806
|
+
end
|
|
807
|
+
|
|
808
|
+
normalized = checkpoint.each_with_object({}) do |(key, value), out|
|
|
809
|
+
out[key.to_sym] = value
|
|
810
|
+
end
|
|
811
|
+
unknown = normalized.keys - %i[path strategy every]
|
|
812
|
+
unless unknown.empty?
|
|
813
|
+
raise ArgumentError, "artifact checkpoint policy has unsupported key(s): #{unknown.map(&:inspect).join(', ')}"
|
|
814
|
+
end
|
|
815
|
+
if normalized.key?(:strategy)
|
|
816
|
+
strategy = normalized.fetch(:strategy).to_sym
|
|
817
|
+
unless %i[latest best every].include?(strategy)
|
|
818
|
+
raise ArgumentError, "artifact checkpoint strategy must be :latest, :best, or :every"
|
|
819
|
+
end
|
|
820
|
+
|
|
821
|
+
normalized[:strategy] = strategy
|
|
822
|
+
end
|
|
823
|
+
if normalized.key?(:every)
|
|
824
|
+
every = normalized.fetch(:every).to_i
|
|
825
|
+
raise ArgumentError, "artifact checkpoint every must be positive" if every <= 0
|
|
826
|
+
|
|
827
|
+
normalized[:every] = every
|
|
828
|
+
end
|
|
829
|
+
normalized
|
|
830
|
+
end
|
|
831
|
+
|
|
832
|
+
def __dsl_normalize_artifact_retention_policy(retention)
|
|
833
|
+
unless retention.is_a?(Hash)
|
|
834
|
+
raise ArgumentError, "artifact retention policy must be a Hash"
|
|
835
|
+
end
|
|
836
|
+
|
|
837
|
+
normalized = retention.each_with_object({}) do |(key, value), out|
|
|
838
|
+
out[key.to_sym] = value
|
|
839
|
+
end
|
|
840
|
+
unknown = normalized.keys - %i[keep_last_n]
|
|
841
|
+
unless unknown.empty?
|
|
842
|
+
raise ArgumentError, "artifact retention policy has unsupported key(s): #{unknown.map(&:inspect).join(', ')}"
|
|
843
|
+
end
|
|
844
|
+
if normalized.key?(:keep_last_n)
|
|
845
|
+
keep = normalized.fetch(:keep_last_n).to_i
|
|
846
|
+
raise ArgumentError, "artifact retention keep_last_n must be non-negative" if keep.negative?
|
|
847
|
+
|
|
848
|
+
normalized[:keep_last_n] = keep
|
|
849
|
+
end
|
|
850
|
+
normalized
|
|
851
|
+
end
|
|
852
|
+
|
|
853
|
+
def __dsl_normalize_artifact_run_bundle_policy(run_bundle)
|
|
854
|
+
unless run_bundle.is_a?(Hash)
|
|
855
|
+
raise ArgumentError, "artifact run_bundle policy must be a Hash"
|
|
856
|
+
end
|
|
857
|
+
|
|
858
|
+
normalized = run_bundle.each_with_object({}) do |(key, value), out|
|
|
859
|
+
out[key.to_sym] = value
|
|
860
|
+
end
|
|
861
|
+
unknown = normalized.keys - %i[enabled path config]
|
|
862
|
+
unless unknown.empty?
|
|
863
|
+
raise ArgumentError, "artifact run_bundle policy has unsupported key(s): #{unknown.map(&:inspect).join(', ')}"
|
|
864
|
+
end
|
|
865
|
+
normalized[:enabled] = !!normalized.fetch(:enabled, false)
|
|
866
|
+
normalized[:config] = {} if normalized[:config].nil?
|
|
867
|
+
normalized
|
|
868
|
+
end
|
|
869
|
+
|
|
870
|
+
      # Merges trainer-level artifact policy (@artifact_policy_config) with
      # per-fit arguments. Explicit non-empty arguments win; policy values fill
      # the gaps. Returns a Hash of resolved checkpoint/retention/resume/bundle
      # settings plus a stringified copy of the whole policy (:payload).
      def __dsl_resolve_artifact_policy(checkpoint_path:, save_best:, resume_from:, monitor_mode:)
        checkpoint_policy = @artifact_policy_config.fetch(:checkpoint, {})
        retention_policy = @artifact_policy_config.fetch(:retention, {})
        run_bundle_policy = @artifact_policy_config.fetch(:run_bundle, {})
        resume_policy = @artifact_policy_config.fetch(:resume, nil)

        resolved_checkpoint_path = checkpoint_path
        if (resolved_checkpoint_path.nil? || resolved_checkpoint_path.to_s.empty?) && checkpoint_policy.key?(:path)
          resolved_checkpoint_path = checkpoint_policy.fetch(:path)
        end

        # Strategy overrides the save_best flag: :best forces it on,
        # :latest/:every force it off; no strategy leaves it as passed.
        resolved_save_best = save_best
        strategy = checkpoint_policy[:strategy]
        case strategy
        when :best
          resolved_save_best = true
        when :latest, :every
          resolved_save_best = false
        end

        resolved_resume = resume_from
        if (resolved_resume.nil? || resolved_resume.to_s.empty?) && !resume_policy.nil?
          resolved_resume = __dsl_policy_resume_source(resume_policy, monitor_mode: monitor_mode)
        end

        {
          checkpoint_path: resolved_checkpoint_path,
          save_best: resolved_save_best,
          checkpoint_every: checkpoint_policy[:every],
          keep_last_n: retention_policy[:keep_last_n],
          resume_from: resolved_resume,
          run_bundle: __dsl_clone_config_value(run_bundle_policy),
          payload: __dsl_stringify_keys(__dsl_clone_config_value(@artifact_policy_config))
        }
      end

      # Translates a policy-level resume directive into a resume source:
      # :latest/:best build a metadata-only payload from @checkpoint_history;
      # any other value passes through untouched (e.g. an explicit path).
      def __dsl_policy_resume_source(policy_resume, monitor_mode:)
        case policy_resume
        when :latest
          entry = @checkpoint_history.last
          return nil if entry.nil?

          { "metadata" => __dsl_deep_copy(entry.fetch("metadata")) }
        when :best
          entry = __dsl_best_checkpoint_from_history(monitor_mode)
          return nil if entry.nil?

          { "metadata" => __dsl_deep_copy(entry.fetch("metadata")) }
        else
          policy_resume
        end
      end

      # Finds the best history entry by "monitor_value": max_by for :max mode,
      # min_by otherwise. Rows without a Hash "metadata" are ignored.
      def __dsl_best_checkpoint_from_history(monitor_mode)
        rows = @checkpoint_history.select do |row|
          row.is_a?(Hash) && row.key?("metadata") && row.fetch("metadata").is_a?(Hash)
        end
        return nil if rows.empty?

        comparator = monitor_mode.to_sym
        case comparator
        when :max
          rows.max_by { |row| row.fetch("monitor_value").to_f }
        else
          rows.min_by { |row| row.fetch("monitor_value").to_f }
        end
      end

      # Saves a run bundle at the policy-configured path when the policy is
      # enabled and a path is present; otherwise a silent no-op (returns nil).
      def __dsl_auto_save_run_bundle(run_bundle_policy, report_payload)
        return nil unless run_bundle_policy.is_a?(Hash)
        return nil unless run_bundle_policy.fetch(:enabled, false)

        path = run_bundle_policy[:path]
        return nil if path.nil? || path.to_s.empty?

        save_run_bundle(
          path,
          report: report_payload,
          config: run_bundle_policy.fetch(:config, {})
        )
      end
|
|
951
|
+
|
|
952
|
+
def __dsl_stringify_keys(value)
|
|
953
|
+
case value
|
|
954
|
+
when Hash
|
|
955
|
+
value.each_with_object({}) do |(key, item), out|
|
|
956
|
+
out[key.to_s] = __dsl_stringify_keys(item)
|
|
957
|
+
end
|
|
958
|
+
when Array
|
|
959
|
+
value.map { |item| __dsl_stringify_keys(item) }
|
|
960
|
+
else
|
|
961
|
+
value
|
|
962
|
+
end
|
|
963
|
+
end
|
|
964
|
+
|
|
965
|
+
      # Symbolizes fit option keys and rejects anything outside
      # FIT_OPTION_DEFAULTS. nil options are treated as {}.
      def __dsl_normalize_fit_option_keys(options)
        normalized = (options || {}).each_with_object({}) do |(key, value), out|
          out[key.to_sym] = value
        end
        unknown = normalized.keys - FIT_OPTION_DEFAULTS.keys
        unless unknown.empty?
          raise ArgumentError, "unsupported fit option(s): #{unknown.map(&:inspect).join(', ')}"
        end

        normalized
      end

      # Resolves effective fit options: explicit (non-UNSET) caller values win,
      # then trainer-level @fit_defaults, then a cloned built-in fallback.
      # NOTE(review): normalized_raw.fetch(key) assumes raw_options carries
      # every FIT_OPTION_DEFAULTS key (fit builds them with UNSET sentinels) —
      # a sparse hash here would raise KeyError; confirm against callers.
      def __dsl_resolve_fit_options(raw_options)
        normalized_raw = __dsl_normalize_fit_option_keys(raw_options)
        normalized_defaults = __dsl_normalize_fit_option_keys(@fit_defaults)
        FIT_OPTION_DEFAULTS.each_with_object({}) do |(key, fallback), out|
          if normalized_raw.fetch(key).equal?(UNSET)
            out[key] = if normalized_defaults.key?(key)
              normalized_defaults.fetch(key)
            else
              # Clone so callers can't mutate the shared default.
              __dsl_clone_config_value(fallback)
            end
          else
            out[key] = normalized_raw.fetch(key)
          end
        end
      end

      # Layers defaults < named preset < overrides into one option Hash.
      # Raises ArgumentError when the preset name is not registered.
      def __dsl_merge_fit_preset(name, overrides)
        key = name.to_sym
        unless @fit_presets.key?(key)
          raise ArgumentError, "unknown fit preset: #{name.inspect}"
        end

        __dsl_normalize_fit_option_keys(@fit_defaults)
          .merge(__dsl_normalize_fit_option_keys(@fit_presets.fetch(key)))
          .merge(__dsl_normalize_fit_option_keys(overrides))
      end
|
|
1003
|
+
|
|
1004
|
+
def __dsl_builtin_task_presets
|
|
1005
|
+
{
|
|
1006
|
+
classification: {
|
|
1007
|
+
collate: :xy,
|
|
1008
|
+
monitor: :epoch_loss,
|
|
1009
|
+
monitor_mode: :min
|
|
1010
|
+
},
|
|
1011
|
+
regression: {
|
|
1012
|
+
collate: :xy,
|
|
1013
|
+
monitor: :epoch_loss,
|
|
1014
|
+
monitor_mode: :min
|
|
1015
|
+
},
|
|
1016
|
+
language_modeling: {
|
|
1017
|
+
collate: :xy,
|
|
1018
|
+
monitor: :perplexity,
|
|
1019
|
+
monitor_mode: :min,
|
|
1020
|
+
metric: ->(context) { Math.exp(context.fetch(:epoch_loss).to_f) }
|
|
1021
|
+
}
|
|
1022
|
+
}
|
|
1023
|
+
end
|
|
1024
|
+
|
|
1025
|
+
      # Builds fit options for a named task preset from @task_presets, with
      # caller overrides layered on top. Raises for unknown tasks.
      def __dsl_task_fit_options(task, overrides)
        key = task.to_sym
        unless @task_presets.key?(key)
          raise ArgumentError, "unknown task preset: #{task.inspect}"
        end

        __dsl_normalize_fit_option_keys(@task_presets.fetch(key))
          .merge(__dsl_normalize_fit_option_keys(overrides))
      end
|
|
1034
|
+
|
|
1035
|
+
      # Returns a deep copy of a registered dataflow profile so callers can
      # overlay on it without mutating the registry. Raises for unknown names.
      def __dsl_resolve_dataflow_profile(name)
        key = name.to_sym
        unless @dataflow_profiles.key?(key)
          raise ArgumentError, "unknown dataflow profile: #{name.inspect}"
        end

        __dsl_clone_config_value(@dataflow_profiles.fetch(key))
      end

      # Normalizes both splits of a dataflow profile into symbol-keyed Hashes.
      def __dsl_normalize_dataflow_profile(train:, validation:)
        {
          train: __dsl_normalize_dataflow_split(train, split: :train),
          validation: __dsl_normalize_dataflow_split(validation, split: :validation)
        }
      end
|
|
1050
|
+
|
|
1051
|
+
def __dsl_normalize_dataflow_split(spec, split:)
|
|
1052
|
+
return {} if spec.nil?
|
|
1053
|
+
unless spec.is_a?(Hash)
|
|
1054
|
+
raise ArgumentError, "#{split} dataflow spec must be a Hash or nil"
|
|
1055
|
+
end
|
|
1056
|
+
|
|
1057
|
+
normalized = spec.each_with_object({}) do |(key, value), out|
|
|
1058
|
+
out[key.to_sym] = value
|
|
1059
|
+
end
|
|
1060
|
+
allowed = %i[collate transform limit reduce]
|
|
1061
|
+
unknown = normalized.keys - allowed
|
|
1062
|
+
unless unknown.empty?
|
|
1063
|
+
raise ArgumentError, "#{split} dataflow spec has unsupported key(s): #{unknown.map(&:inspect).join(', ')}"
|
|
1064
|
+
end
|
|
1065
|
+
|
|
1066
|
+
normalized
|
|
1067
|
+
end
|
|
1068
|
+
|
|
1069
|
+
def __dsl_compose_dataflow_profile(base, overlay)
|
|
1070
|
+
{
|
|
1071
|
+
train: base.fetch(:train, {}).merge(overlay.fetch(:train, {})),
|
|
1072
|
+
validation: base.fetch(:validation, {}).merge(overlay.fetch(:validation, {}))
|
|
1073
|
+
}
|
|
1074
|
+
end
|
|
1075
|
+
|
|
1076
|
+
def __dsl_normalize_dataflow_overrides(overrides)
|
|
1077
|
+
profile_overrides = { train: {}, validation: {} }
|
|
1078
|
+
direct_overrides = {}
|
|
1079
|
+
(overrides || {}).each do |key, value|
|
|
1080
|
+
key = key.to_sym
|
|
1081
|
+
case key
|
|
1082
|
+
when :train, :validation
|
|
1083
|
+
profile_overrides[key] = __dsl_normalize_dataflow_split(value, split: key)
|
|
1084
|
+
else
|
|
1085
|
+
direct_overrides[key] = value
|
|
1086
|
+
end
|
|
1087
|
+
end
|
|
1088
|
+
[profile_overrides, direct_overrides]
|
|
1089
|
+
end
|
|
1090
|
+
|
|
1091
|
+
def __dsl_fit_kwargs_from_dataflow(profile)
|
|
1092
|
+
{
|
|
1093
|
+
collate: profile.fetch(:train).fetch(:collate, UNSET),
|
|
1094
|
+
train_transform: profile.fetch(:train).fetch(:transform, UNSET),
|
|
1095
|
+
limit: profile.fetch(:train).fetch(:limit, UNSET),
|
|
1096
|
+
reduce: profile.fetch(:train).fetch(:reduce, UNSET),
|
|
1097
|
+
validation_collate: profile.fetch(:validation).fetch(:collate, UNSET),
|
|
1098
|
+
validation_transform: profile.fetch(:validation).fetch(:transform, UNSET),
|
|
1099
|
+
validation_limit: profile.fetch(:validation).fetch(:limit, UNSET),
|
|
1100
|
+
validation_reduce: profile.fetch(:validation).fetch(:reduce, UNSET)
|
|
1101
|
+
}.each_with_object({}) do |(key, value), out|
|
|
1102
|
+
out[key] = value unless value.equal?(UNSET)
|
|
1103
|
+
end
|
|
1104
|
+
end
|
|
1105
|
+
|
|
1106
|
+
      # Invokes a hook-pack callable. Zero-arity Procs are run via
      # instance_exec so they can use trainer DSL methods; other callables have
      # positional/keyword arguments matched against their signature from the
      # available values (:trainer, :options, plus every option key).
      def __dsl_call_hook_pack(pack, options)
        values = { trainer: self, options: options || {} }
        (options || {}).each do |key, value|
          values[key.to_sym] = value
        end
        if !pack.respond_to?(:parameters) || pack.parameters.empty?
          # Procs get trainer context via instance_exec; other callables don't.
          return instance_exec(&pack) if pack.is_a?(Proc)

          return pack.call
        end

        params = pack.parameters
        args = __dsl_build_positional_args(
          params,
          values,
          [[:trainer, self], [:options, options || {}]],
          "hook pack"
        )
        kwargs = __dsl_build_keyword_args(params, values, "hook pack")
        # Avoid passing an empty **{} to callables that reject keywords.
        return pack.call(*args) if kwargs.empty?

        pack.call(*args, **kwargs)
      end
|
|
1129
|
+
|
|
1130
|
+
def __dsl_normalize_sync(sync)
|
|
1131
|
+
mode = sync.nil? ? :none : sync.to_sym
|
|
1132
|
+
return mode if %i[none step epoch].include?(mode)
|
|
1133
|
+
|
|
1134
|
+
raise ArgumentError, "trainer sync must be one of :none, :step, or :epoch"
|
|
1135
|
+
end
|
|
1136
|
+
|
|
1137
|
+
      # Builds the train-step callable from the model, forwarding only the
      # keyword arguments the model's #train_step signature actually accepts
      # (so older/simpler models keep working).
      def __dsl_build_train_step(model, optimizer:, clip_grad_norm:, compile:, &loss_block)
        params = model.method(:train_step).parameters
        accepts_keyrest = params.any? { |type, _name| type == :keyrest }
        accepts_keyword = lambda do |key|
          accepts_keyrest || params.any? do |type, name|
            (type == :key || type == :keyreq) && name == key
          end
        end

        kwargs = {
          optimizer: optimizer,
          clip_grad_norm: clip_grad_norm
        }
        kwargs[:compile] = compile if accepts_keyword.call(:compile)
        # Per-step sync is only requested when the trainer is in :step mode.
        kwargs[:sync] = (@sync_mode == :step ? :step : :none) if accepts_keyword.call(:sync)

        model.train_step(**kwargs, &loss_block)
      end
|
|
1155
|
+
|
|
1156
|
+
      # Resolves resume_from — a checkpoint path, a payload Hash, a run-bundle
      # Hash/path, or a loader callable — into a normalized resume-state Hash
      # (:path, :checkpoint_epoch, :start_epoch, :best_metric, :stale_epochs,
      # :monitor_name) used by #fit. Returns the empty state when there is
      # nothing to resume.
      def __dsl_resume_state(resume_from, monitor_name)
        return __dsl_empty_resume_state(monitor_name) if resume_from.nil? || resume_from.to_s.empty?

        source = resume_from
        source = __dsl_call_resume_loader(source, monitor_name) if source.respond_to?(:call)
        return __dsl_empty_resume_state(monitor_name) if source.nil?

        if source.is_a?(Hash)
          # In-memory source: either a full run bundle or a raw payload.
          payload = if __dsl_run_bundle_hash?(source)
            resume_payload_from_bundle(source)
          else
            source
          end
          resume_path = nil
        else
          # Path-like source: try it as a run bundle first, then fall back to
          # loading it as a model checkpoint.
          source_path = source.to_s
          bundle_payload = __dsl_resume_payload_from_run_bundle_path(source_path)
          if bundle_payload.nil?
            unless @model.respond_to?(:load_checkpoint)
              raise ArgumentError, "resume_from requires model to implement #load_checkpoint"
            end

            payload = __dsl_load_checkpoint_for_resume(source)
            resume_path = source_path
          else
            payload = bundle_payload
            resume_path = source_path
          end
        end
        unless payload.is_a?(Hash)
          raise ArgumentError, "resume checkpoint payload must be a Hash"
        end

        metadata = payload["metadata"]
        metadata = {} unless metadata.is_a?(Hash)

        # Scalar state is read from metadata first, then the payload root.
        checkpoint_epoch_value = __dsl_resume_state_value(payload, metadata, "epoch")
        checkpoint_epoch = checkpoint_epoch_value.nil? ? nil : checkpoint_epoch_value.to_i
        # Training resumes at the epoch after the checkpointed one.
        start_epoch = checkpoint_epoch.nil? ? 0 : checkpoint_epoch + 1
        best_metric = __dsl_resume_state_value(payload, metadata, "best_metric")
        stale_epochs_value = __dsl_resume_state_value(payload, metadata, "stale_epochs")
        stale_epochs = stale_epochs_value.nil? ? 0 : stale_epochs_value.to_i
        if stale_epochs.negative?
          raise ArgumentError, "resume checkpoint stale_epochs must be non-negative"
        end

        # Guard against resuming a run that monitored a different metric.
        resume_monitor_name = __dsl_resume_state_value(payload, metadata, "monitor_name")
        if !resume_monitor_name.nil? && resume_monitor_name.to_s != monitor_name.to_s
          raise ArgumentError,
                "resume checkpoint monitor_name #{resume_monitor_name.inspect} does not match requested monitor #{monitor_name.inspect}"
        end

        {
          path: resume_path,
          checkpoint_epoch: checkpoint_epoch,
          start_epoch: start_epoch,
          best_metric: best_metric,
          stale_epochs: stale_epochs,
          monitor_name: monitor_name.to_s
        }
      end
|
|
1217
|
+
|
|
1218
|
+
def __dsl_run_bundle_hash?(value)
|
|
1219
|
+
value.is_a?(Hash) && value.key?("checkpoint") && value.key?("report")
|
|
1220
|
+
end
|
|
1221
|
+
|
|
1222
|
+
      # Reads +path+ as a JSON run bundle and extracts its resume payload.
      # Returns nil (no error) when the file is absent, is not valid JSON, or
      # does not look like a run bundle — callers then treat the path as a
      # plain model checkpoint instead.
      def __dsl_resume_payload_from_run_bundle_path(path)
        return nil if path.nil? || path.empty?
        return nil unless File.file?(path)

        bundle = JSON.parse(File.binread(path))
        return nil unless __dsl_run_bundle_hash?(bundle)

        resume_payload_from_bundle(bundle)
      rescue JSON::ParserError
        nil
      end
|
|
1233
|
+
|
|
1234
|
+
      # Baseline resume state used when nothing is being resumed: training
      # starts at epoch 0 with no best metric and no stale-epoch count.
      def __dsl_empty_resume_state(monitor_name)
        {
          path: nil,
          checkpoint_epoch: nil,
          start_epoch: 0,
          best_metric: nil,
          stale_epochs: 0,
          monitor_name: monitor_name.to_s
        }
      end
|
|
1244
|
+
|
|
1245
|
+
      # Invokes a resume_from callable, matching its signature against the
      # available values (:trainer, :model, :optimizer, :monitor_name) so
      # loaders may declare any subset of them positionally or by keyword.
      def __dsl_call_resume_loader(loader, monitor_name)
        values = {
          trainer: self,
          model: @model,
          optimizer: @optimizer,
          monitor_name: monitor_name.to_s
        }
        return loader.call unless loader.respond_to?(:parameters)

        params = loader.parameters
        return loader.call if params.empty?

        args = __dsl_build_positional_args(
          params,
          values,
          [[:trainer, self], [:model, @model], [:optimizer, @optimizer], [:monitor_name, monitor_name.to_s]],
          "resume loader"
        )
        kwargs = __dsl_build_keyword_args(params, values, "resume loader")
        # Avoid passing an empty **{} to callables that reject keywords.
        return loader.call(*args) if kwargs.empty?

        loader.call(*args, **kwargs)
      end
|
|
1268
|
+
|
|
1269
|
+
      # Loads a checkpoint for resuming, forwarding only the keyword arguments
      # the model's #load_checkpoint signature accepts (:optimizer, :strict,
      # :format), so simpler implementations keep working.
      def __dsl_load_checkpoint_for_resume(path)
        params = @model.method(:load_checkpoint).parameters
        accepts_keyrest = params.any? { |type, _name| type == :keyrest }
        accepts_keyword = lambda do |key|
          accepts_keyrest || params.any? do |type, name|
            (type == :key || type == :keyreq) && name == key
          end
        end

        kwargs = {}
        kwargs[:optimizer] = @optimizer if accepts_keyword.call(:optimizer)
        kwargs[:strict] = true if accepts_keyword.call(:strict)
        kwargs[:format] = nil if accepts_keyword.call(:format)
        return @model.load_checkpoint(path) if kwargs.empty?

        @model.load_checkpoint(path, **kwargs)
      end
|
|
1286
|
+
|
|
1287
|
+
def __dsl_resume_state_value(payload, metadata, key)
|
|
1288
|
+
return metadata[key] if metadata.key?(key)
|
|
1289
|
+
|
|
1290
|
+
payload[key]
|
|
1291
|
+
end
|
|
1292
|
+
|
|
1293
|
+
      # Fires every hook registered for +event+ in (priority, registration
      # order), passing +context+ to each. Honors per-hook options:
      # :every (run only every Nth emission), :once (run at most once), and
      # :if (a condition evaluated against the context).
      def emit(event, context)
        @hooks[event.to_sym].sort_by { |entry| [entry.fetch(:priority), entry.fetch(:order)] }.each do |entry|
          # The invocation counter advances even when the hook is skipped, so
          # :every is measured against emissions, not executions.
          entry[:invocations] += 1
          if !entry[:every].nil? && ((entry[:invocations] - 1) % entry[:every]).nonzero?
            next
          end
          if entry[:once] && entry[:fired]
            next
          end
          unless __dsl_hook_condition_met?(entry[:if], context)
            next
          end

          entry.fetch(:hook).call(context)
          entry[:fired] = true if entry[:once]
        end
      end
|
|
1310
|
+
|
|
1311
|
+
def __dsl_hook_condition_met?(condition, context)
|
|
1312
|
+
return true if condition.nil?
|
|
1313
|
+
return !!condition unless condition.respond_to?(:call)
|
|
1314
|
+
return !!condition.call unless condition.respond_to?(:parameters)
|
|
1315
|
+
|
|
1316
|
+
params = condition.parameters
|
|
1317
|
+
return !!condition.call if params.empty?
|
|
1318
|
+
|
|
1319
|
+
if params.any? { |type, _name| type == :keyrest || type == :key || type == :keyreq }
|
|
1320
|
+
return !!condition.call(context: context)
|
|
1321
|
+
end
|
|
1322
|
+
|
|
1323
|
+
!!condition.call(context)
|
|
1324
|
+
end
|
|
1325
|
+
|
|
1326
|
+
      # Forces evaluation of the epoch's pending lazy computation — the final
      # loss, model parameters, and optimizer state — via MLX::Core.eval.
      # No-op when MLX::Core.eval is unavailable or nothing is evaluable.
      def __dsl_sync_epoch(loss)
        return unless defined?(MLX::Core) && MLX::Core.respond_to?(:eval)

        targets = []
        targets << loss unless loss.nil?
        targets << @model.parameters if @model.respond_to?(:parameters)
        targets << @optimizer.state if @optimizer.respond_to?(:state)
        return if targets.empty?

        MLX::Core.eval(*targets)
      end
|
|
1337
|
+
|
|
1338
|
+
      # Applies a collate spec to a raw batch. Resolution order matters:
      # registry names are chased first, then "auto" is inferred from the
      # batch shape, then nil passes the batch through, callables are invoked,
      # and finally named schemas / hash mappings are dispatched.
      def __dsl_apply_collate(collate, batch, kind:, epoch:, batch_index:)
        collate = __dsl_resolve_registered_collate(collate, kind: kind)
        collate = __dsl_auto_collate_spec(kind, batch) if collate.to_s == "auto"
        return batch if collate.nil?
        if collate.respond_to?(:call)
          return __dsl_call_collate_callable(
            collate,
            batch,
            kind: kind,
            epoch: epoch,
            batch_index: batch_index
          )
        end

        case collate
        when String, Symbol
          __dsl_apply_named_collate(collate.to_sym, batch, kind: kind)
        when Hash
          __dsl_apply_mapping_collate(
            collate,
            batch,
            kind: kind,
            epoch: epoch,
            batch_index: batch_index
          )
        else
          raise ArgumentError, "#{kind} collate must be a Proc, Symbol/String, Hash, or nil"
        end
      end
|
|
1367
|
+
|
|
1368
|
+
def __dsl_effective_collate(collate, bind, batch, kind:)
|
|
1369
|
+
return collate unless collate.nil?
|
|
1370
|
+
return nil if bind.nil?
|
|
1371
|
+
|
|
1372
|
+
__dsl_bind_to_collate(bind, batch, kind: kind)
|
|
1373
|
+
end
|
|
1374
|
+
|
|
1375
|
+
      # Converts a bind spec into a collate spec: true/"auto" infer a mapping
      # from the target callable's keyword parameters; String/Symbol/Hash specs
      # pass through; Procs are invoked with batch context.
      def __dsl_bind_to_collate(bind, batch, kind:)
        case bind
        when true
          __dsl_infer_bind_mapping(kind, batch)
        when String, Symbol
          return __dsl_infer_bind_mapping(kind, batch) if bind.to_s == "auto"

          bind
        when Hash
          bind
        when Proc
          # Bind procs run outside the batch loop, hence epoch/index 0.
          __dsl_call_collate_callable(
            bind,
            batch,
            kind: kind,
            epoch: 0,
            batch_index: 0,
            label: "#{kind} bind"
          )
        else
          raise ArgumentError, "#{kind} bind must be :auto/true, collate spec, or hash mapping"
        end
      end
|
|
1398
|
+
|
|
1399
|
+
      # Infers a bind mapping from the keyword parameters of the train step
      # (:train) or the loss block (otherwise). With no keyword parameters the
      # batch shape decides: 2+-element arrays bind as :xy, non-hashes as :x.
      def __dsl_infer_bind_mapping(kind, batch)
        target_params = __dsl_keyword_names_for_callable(kind == :train ? @step.method(:call) : @loss_block)
        return :xy if target_params.empty? && batch.is_a?(Array) && batch.length >= 2
        return :x if target_params.empty? && !batch.is_a?(Hash)

        target_params.each_with_index.each_with_object({}) do |(name, index), out|
          out[name] = __dsl_infer_bind_selector(batch, name, index)
        end
      end
|
|
1408
|
+
|
|
1409
|
+
def __dsl_keyword_names_for_callable(callable)
|
|
1410
|
+
return [] unless callable.respond_to?(:parameters)
|
|
1411
|
+
|
|
1412
|
+
callable.parameters.each_with_object([]) do |(type, name), out|
|
|
1413
|
+
out << name if (type == :key || type == :keyreq) && !name.nil?
|
|
1414
|
+
end
|
|
1415
|
+
end
|
|
1416
|
+
|
|
1417
|
+
      # Chooses a selector for keyword +name+: for hash batches the first
      # matching alias key (or the name itself when none match); for indexable
      # non-hash batches the positional index; otherwise the name.
      def __dsl_infer_bind_selector(batch, name, index)
        if batch.is_a?(Hash)
          candidates = __dsl_bind_candidates_for(name)
          found = candidates.find do |candidate|
            batch.key?(candidate) || batch.key?(candidate.to_s) || batch.key?(candidate.to_sym)
          end
          return found unless found.nil?

          return name
        end

        return index if batch.respond_to?(:[]) && !batch.is_a?(Hash)

        name
      end
|
|
1432
|
+
|
|
1433
|
+
def __dsl_bind_candidates_for(name)
|
|
1434
|
+
base = [name, name.to_s, name.to_sym]
|
|
1435
|
+
aliases = case name.to_sym
|
|
1436
|
+
when :x
|
|
1437
|
+
%i[input inputs feature features token tokens]
|
|
1438
|
+
when :y
|
|
1439
|
+
%i[target targets label labels output outputs]
|
|
1440
|
+
else
|
|
1441
|
+
[]
|
|
1442
|
+
end
|
|
1443
|
+
base + aliases + aliases.map(&:to_s)
|
|
1444
|
+
end
|
|
1445
|
+
|
|
1446
|
+
      # Infers a collate spec for "auto": a registered batch schema for this
      # kind wins; hash batches map by :x/:y presence (nil when neither is
      # present, leaving the batch untouched); 2+-element arrays become :xy;
      # everything else becomes :x.
      def __dsl_auto_collate_spec(kind, batch)
        schema = @batch_schemas.fetch(kind, nil)
        return schema unless schema.nil?

        if batch.is_a?(Hash)
          has_x = batch.key?(:x) || batch.key?("x")
          has_y = batch.key?(:y) || batch.key?("y")
          return :xy if has_x && has_y
          return { x: :x } if has_x

          return nil
        end

        return :xy if batch.is_a?(Array) && batch.length >= 2

        :x
      end
|
|
1463
|
+
|
|
1464
|
+
def __dsl_normalize_batch_schema_spec(spec)
|
|
1465
|
+
return nil if spec.nil?
|
|
1466
|
+
return spec if spec.respond_to?(:call)
|
|
1467
|
+
|
|
1468
|
+
case spec
|
|
1469
|
+
when String, Symbol, Hash
|
|
1470
|
+
spec
|
|
1471
|
+
else
|
|
1472
|
+
raise ArgumentError, "batch schema must be a collate spec (Symbol/String, Hash, Proc, or nil)"
|
|
1473
|
+
end
|
|
1474
|
+
end
|
|
1475
|
+
|
|
1476
|
+
      # Follows Symbol/String references through @collate_registry until a
      # non-registered spec is reached; unregistered names stop the chase and
      # are returned as-is. Cyclic registry references raise with the full
      # reference chain in the message.
      def __dsl_resolve_registered_collate(collate, kind:)
        current = collate
        seen = []
        while current.is_a?(String) || current.is_a?(Symbol)
          key = current.to_sym
          break unless @collate_registry.key?(key)
          if seen.include?(key)
            cycle = (seen + [key]).map(&:inspect).join(" -> ")
            raise ArgumentError, "cyclic #{kind} collate registry reference: #{cycle}"
          end

          seen << key
          current = @collate_registry.fetch(key)
        end
        current
      end
|
|
1492
|
+
|
|
1493
|
+
def __dsl_compose_collate(base_spec, overlay_spec)
|
|
1494
|
+
if base_spec.is_a?(Hash) && overlay_spec.is_a?(Hash)
|
|
1495
|
+
return base_spec.merge(overlay_spec)
|
|
1496
|
+
end
|
|
1497
|
+
|
|
1498
|
+
raise ArgumentError, "collate extends composition supports Hash schemas only"
|
|
1499
|
+
end
|
|
1500
|
+
|
|
1501
|
+
def __dsl_apply_named_collate(name, batch, kind:)
|
|
1502
|
+
case name
|
|
1503
|
+
when :xy
|
|
1504
|
+
return batch if batch.is_a?(Hash) && (batch.key?(:x) || batch.key?("x")) && (batch.key?(:y) || batch.key?("y"))
|
|
1505
|
+
unless batch.is_a?(Array) && batch.length >= 2
|
|
1506
|
+
raise ArgumentError, "#{kind} collate :xy expects a 2-item array batch"
|
|
1507
|
+
end
|
|
1508
|
+
|
|
1509
|
+
{ x: batch[0], y: batch[1] }
|
|
1510
|
+
when :x
|
|
1511
|
+
{ x: batch }
|
|
1512
|
+
else
|
|
1513
|
+
raise ArgumentError, "unknown #{kind} collate schema: #{name.inspect}"
|
|
1514
|
+
end
|
|
1515
|
+
end
|
|
1516
|
+
|
|
1517
|
+
      # Builds the collated batch by evaluating each selector in +mapping+
      # against the raw batch; output keys are the mapping's keys.
      def __dsl_apply_mapping_collate(mapping, batch, kind:, epoch:, batch_index:)
        mapping.each_with_object({}) do |(out_key, selector), out|
          out[out_key] = __dsl_collate_select(
            batch,
            selector,
            kind: kind,
            epoch: epoch,
            batch_index: batch_index
          )
        end
      end
|
|
1528
|
+
|
|
1529
|
+
      # Evaluates a single collate selector against the batch:
      #   Integer        — positional index into an indexable non-hash batch
      #   String/Symbol  — key lookup in a hash batch (with coercion)
      #   Proc           — invoked with batch context
      #   Array          — path of nested selectors
      def __dsl_collate_select(batch, selector, kind:, epoch:, batch_index:)
        case selector
        when Integer
          unless batch.respond_to?(:[]) && !batch.is_a?(Hash)
            raise ArgumentError, "#{kind} collate integer selector requires indexable non-hash batch"
          end
          batch[selector]
        when String, Symbol
          __dsl_collate_fetch_key(batch, selector, kind: kind)
        when Proc
          __dsl_call_collate_callable(
            selector,
            batch,
            kind: kind,
            epoch: epoch,
            batch_index: batch_index,
            label: "#{kind} collate selector"
          )
        when Array
          __dsl_collate_select_path(batch, selector, kind: kind)
        else
          raise ArgumentError, "#{kind} collate selector must be Integer, String/Symbol, Proc, or Array path"
        end
      end
|
|
1553
|
+
|
|
1554
|
+
      # Invokes a collate/bind callable, matching its signature against the
      # available values (:batch, :epoch, :batch_index, :kind, :trainer).
      # Callables without an introspectable or non-empty signature just get
      # the batch as a single positional argument.
      def __dsl_call_collate_callable(callable, batch, kind:, epoch:, batch_index:, label: nil)
        values = {
          batch: batch,
          epoch: epoch,
          batch_index: batch_index,
          kind: kind,
          trainer: self
        }
        return callable.call(batch) unless callable.respond_to?(:parameters)

        params = callable.parameters
        return callable.call(batch) if params.empty?

        callable_label = label || "#{kind} collate"
        args = __dsl_build_positional_args(
          params,
          values,
          [[:batch, batch], [:epoch, epoch], [:batch_index, batch_index], [:kind, kind], [:trainer, self]],
          callable_label
        )
        kwargs = __dsl_build_keyword_args(params, values, callable_label)
        # Avoid passing an empty **{} to callables that reject keywords.
        return callable.call(*args) if kwargs.empty?

        callable.call(*args, **kwargs)
      end
|
|
1579
|
+
|
|
1580
|
+
      # Walks an Array selector path through nested batch structures, one
      # segment at a time. An empty path is an error.
      def __dsl_collate_select_path(batch, selector_path, kind:)
        if selector_path.empty?
          raise ArgumentError, "#{kind} collate selector path must not be empty"
        end

        selector_path.each_with_index.reduce(batch) do |current, (selector, depth)|
          __dsl_collate_select_path_segment(current, selector, selector_path, depth, kind: kind)
        end
      end

      # Resolves one path segment: Integer index into an indexable non-hash,
      # or String/Symbol key lookup in a hash. Errors include the full path
      # and depth for diagnostics.
      def __dsl_collate_select_path_segment(current, selector, selector_path, depth, kind:)
        case selector
        when Integer
          unless current.respond_to?(:[]) && !current.is_a?(Hash)
            raise ArgumentError, "#{kind} collate path #{selector_path.inspect} expected indexable non-hash at depth #{depth}"
          end
          current[selector]
        when String, Symbol
          __dsl_collate_fetch_key(
            current,
            selector,
            kind: kind,
            context: "in path #{selector_path.inspect} at depth #{depth}"
          )
        else
          raise ArgumentError, "#{kind} collate path #{selector_path.inspect} contains unsupported selector #{selector.inspect}"
        end
      end
|
|
1608
|
+
|
|
1609
|
+
# Fetches +selector+ out of a Hash batch, tolerating string/symbol key
# mismatches: the key is tried as given, then as a String, then as a
# Symbol. Raises ArgumentError when the batch is not a Hash or no
# candidate key exists; +context+ (when given) is appended to the error.
def __dsl_collate_fetch_key(batch, selector, kind:, context: nil)
  suffix = context.nil? ? "" : " #{context}"
  unless batch.is_a?(Hash)
    raise ArgumentError, "#{kind} collate key selector #{selector.inspect} requires hash batch#{suffix}"
  end

  # Try the selector verbatim, then its String and Symbol forms.
  [selector, selector.to_s, selector.to_s.to_sym].each do |candidate|
    return batch.fetch(candidate) if batch.key?(candidate)
  end

  raise ArgumentError, "#{kind} collate key selector #{selector.inspect} was not found in batch#{suffix}"
end
|
|
1626
|
+
|
|
1627
|
+
# Dispatches one training batch to @step: Hash batches are splatted as
# keyword arguments (keys normalized to symbols), Array batches as
# positionals, anything else as a single argument. Failures are
# re-raised with epoch/batch context attached.
def __dsl_run_batch(batch, epoch:, batch_index:, kind:)
  case batch
  when Hash
    @step.call(**__dsl_normalize_batch_kwargs(batch, label: "#{kind} batch"))
  when Array
    @step.call(*batch)
  else
    @step.call(batch)
  end
rescue StandardError => e
  __dsl_raise_batch_error!(e, kind: kind, epoch: epoch, batch_index: batch_index)
end
|
|
1638
|
+
|
|
1639
|
+
# Coerces a loss value to a Float scalar.
#
# Numerics convert directly; objects exposing #item (e.g. array-like
# scalars) are unwrapped first; otherwise #to_f is used when available.
# Returns nil when +loss+ is nil or cannot be coerced.
def __dsl_loss_scalar(loss)
  case
  when loss.nil?                 then nil
  when loss.is_a?(Numeric)       then loss.to_f
  when loss.respond_to?(:item)   then loss.item.to_f
  when loss.respond_to?(:to_f)   then loss.to_f
  end
end
|
|
1648
|
+
|
|
1649
|
+
# Picks the default metric for a monitor name: validation monitors
# ("val_loss"/"validation_loss") prefer the validation metric, falling
# back to the epoch metric when no validation ran; every other monitor
# uses the epoch metric.
def __dsl_default_monitor_value(monitor_name, epoch_metric, val_metric)
  wants_validation = %w[val_loss validation_loss].include?(monitor_name.to_s)
  return epoch_metric unless wants_validation

  val_metric.nil? ? epoch_metric : val_metric
end
|
|
1657
|
+
|
|
1658
|
+
# Computes the monitored metric value: resolves +metric+ to a callable
# (via the registry or directly), invokes it with whatever subset of
# (context, trainer) its signature accepts, and returns +fallback+ when
# no callable can be resolved.
def __dsl_monitor_value(metric, context, fallback:)
  callable = __dsl_resolve_metric_callable(metric)
  return fallback if callable.nil?

  # Callables without introspectable/declared parameters just receive the context.
  params = callable.respond_to?(:parameters) ? callable.parameters : nil
  return callable.call(context) if params.nil? || params.empty?

  values = { context: context, trainer: self }
  args = __dsl_build_positional_args(
    params,
    values,
    [[:context, context], [:trainer, self]],
    "metric callable"
  )
  kwargs = __dsl_build_keyword_args(params, values, "metric callable")
  kwargs.empty? ? callable.call(*args) : callable.call(*args, **kwargs)
end
|
|
1681
|
+
|
|
1682
|
+
# Resolves +metric+ to a callable: callables pass through unchanged,
# String/Symbol names are looked up in @metric_registry, and anything
# else (including nil) resolves to nil.
def __dsl_resolve_metric_callable(metric)
  if metric.respond_to?(:call)
    metric
  elsif metric.is_a?(String) || metric.is_a?(Symbol)
    @metric_registry[metric.to_sym]
  end
end
|
|
1689
|
+
|
|
1690
|
+
# Dispatches one validation batch to @loss_block, mirroring
# __dsl_run_batch's splat rules: Hash -> keywords, Array -> positionals,
# otherwise a single argument. Failures are re-raised with
# epoch/batch context attached.
def __dsl_run_validation_batch(batch, epoch:, batch_index:, kind:)
  case batch
  when Hash
    @loss_block.call(**__dsl_normalize_batch_kwargs(batch, label: "#{kind} batch"))
  when Array
    @loss_block.call(*batch)
  else
    @loss_block.call(batch)
  end
rescue StandardError => e
  __dsl_raise_batch_error!(e, kind: kind, epoch: epoch, batch_index: batch_index)
end
|
|
1701
|
+
|
|
1702
|
+
# Re-raises +error+ with epoch/batch context prefixed to its message.
#
# Fix: the previous two-argument raise discarded the original backtrace,
# so the reported failure site became this helper instead of the line
# inside the user's step/loss block that actually failed. The
# three-argument form of Kernel#raise preserves the original backtrace.
def __dsl_raise_batch_error!(error, kind:, epoch:, batch_index:)
  prefix = "#{kind} batch failed at epoch #{epoch}, batch #{batch_index}"
  raise error.class, "#{prefix}: #{error.message}", error.backtrace
end
|
|
1706
|
+
|
|
1707
|
+
# Converts a Hash batch into a keyword-argument hash with Symbol keys.
#
# Raises ArgumentError when two distinct keys normalize to the same
# symbol (e.g. "x" and :x), since that would silently drop data.
def __dsl_normalize_batch_kwargs(batch, label:)
  normalized = {}
  batch.each do |key, value|
    kw_key = __dsl_normalize_keyword_key(key, label: label)
    if normalized.key?(kw_key)
      raise ArgumentError, "#{label} contains duplicate keyword after normalization: #{kw_key.inspect}"
    end

    normalized[kw_key] = value
  end
  normalized
end
|
|
1717
|
+
|
|
1718
|
+
# Normalizes a batch key to a Symbol suitable for use as a keyword
# argument. Raises ArgumentError for keys that cannot be symbolized
# (e.g. Integers).
def __dsl_normalize_keyword_key(key, label:)
  return key if key.is_a?(Symbol)

  unless key.respond_to?(:to_sym)
    raise ArgumentError, "#{label} key #{key.inspect} cannot be converted to keyword symbol"
  end

  key.to_sym
end
|
|
1724
|
+
|
|
1725
|
+
# Reduces a list of per-batch values with either a custom callable or a
# named reducer (:mean, :sum, :last). Returns nil for an empty list;
# unknown reducer names raise ArgumentError.
def __dsl_reduce_values(values, reducer)
  return nil if values.empty?
  return reducer.call(values) if reducer.respond_to?(:call)

  case reducer.to_sym
  when :mean then values.sum / values.length.to_f
  when :sum  then values.sum
  when :last then values.last
  else
    raise ArgumentError, "unsupported reducer: #{reducer.inspect}"
  end
end
|
|
1743
|
+
|
|
1744
|
+
# Reports whether +metric+ beats +best_metric+ under the monitor mode.
#
# A nil metric never counts as an improvement; a nil best always does
# (first observation). :min requires the metric to drop below
# best - min_delta, :max to exceed best + min_delta; any other mode
# raises ArgumentError.
def __dsl_improved?(metric, best_metric, monitor_mode, min_delta: 0.0)
  return false if metric.nil?
  return true if best_metric.nil?

  mode = monitor_mode.to_sym
  if mode == :min
    metric < best_metric - min_delta
  elsif mode == :max
    metric > best_metric + min_delta
  else
    raise ArgumentError, "unsupported monitor_mode: #{monitor_mode.inspect}"
  end
end
|
|
1757
|
+
|
|
1758
|
+
# Normalizes an early-stopping patience option: nil stays nil (patience
# disabled); anything else is truncated to an Integer and must be
# non-negative.
def __dsl_normalize_patience(patience)
  return nil if patience.nil?

  normalized = patience.to_i
  raise ArgumentError, "patience must be non-negative" if normalized < 0

  normalized
end
|
|
1766
|
+
|
|
1767
|
+
# Applies an optional batch transform. nil passes the batch through
# unchanged; non-callables raise ArgumentError. Callables whose
# parameters cannot be introspected (or that declare none) receive the
# batch alone; otherwise positional/keyword arguments are matched
# against (batch, epoch, batch_index, kind, trainer) by name.
def __dsl_apply_batch_transform(transform, batch, epoch:, batch_index:, kind:)
  return batch if transform.nil?

  raise ArgumentError, "#{kind} transform must respond to #call" unless transform.respond_to?(:call)

  params = transform.respond_to?(:parameters) ? transform.parameters : []
  return transform.call(batch) if params.empty?

  values = {
    batch: batch,
    epoch: epoch,
    batch_index: batch_index,
    kind: kind,
    trainer: self
  }
  fallback = [[:batch, batch], [:epoch, epoch], [:batch_index, batch_index], [:kind, kind], [:trainer, self]]
  args = __dsl_build_positional_args(params, values, fallback, "batch transform")
  kwargs = __dsl_build_keyword_args(params, values, "batch transform")
  kwargs.empty? ? transform.call(*args) : transform.call(*args, **kwargs)
end
|
|
1796
|
+
|
|
1797
|
+
# Normalizes a min_delta option to a non-negative Float
# (nil coerces to 0.0 via #to_f).
def __dsl_normalize_min_delta(min_delta)
  normalized = min_delta.to_f
  raise ArgumentError, "min_delta must be non-negative" if normalized < 0

  normalized
end
|
|
1803
|
+
|
|
1804
|
+
# Decides whether to write a checkpoint for this epoch and, if so,
# saves it, records it in the trainer's checkpoint history, prunes old
# history entries, and fires the :checkpoint hook (one of HOOK_EVENTS).
#
# Returns true when a checkpoint was written, false when any gate
# (empty path, cadence, save_best without improvement, or a model
# lacking #save_checkpoint) suppressed it.
#
# Gates, in order:
#   - path nil/empty  -> checkpointing disabled
#   - checkpoint_every -> only fire on epochs where (epoch+1) is a
#     multiple of the cadence; a non-positive cadence is an error
#   - save_best        -> skip unless this epoch improved the monitor
#   - model must expose #save_checkpoint (duck-typed)
def __dsl_maybe_checkpoint(
  path,
  save_best:,
  improved:,
  epoch:,
  monitor_name:,
  monitor_value:,
  epoch_metric:,
  stale_epochs:,
  best_metric:,
  metadata:,
  checkpoint_every: nil,
  keep_last_n: nil
)
  return false if path.nil? || path.to_s.empty?
  unless checkpoint_every.nil?
    every = checkpoint_every.to_i
    raise ArgumentError, "checkpoint every must be positive" if every <= 0
    # epoch is 0-based, so epoch+1 counts completed epochs.
    return false if ((epoch + 1) % every).nonzero?
  end
  return false if save_best && !improved
  return false unless @model.respond_to?(:save_checkpoint)

  # Expand templates / invoke callables to get the concrete file path.
  resolved_path = __dsl_checkpoint_path(
    path,
    epoch: epoch,
    monitor_name: monitor_name,
    monitor_value: monitor_value,
    epoch_metric: epoch_metric,
    improved: improved
  )

  # Caller metadata is preserved; trainer bookkeeping uses string keys
  # and overwrites any colliding caller entries.
  merged_metadata = (metadata || {}).dup
  merged_metadata["epoch"] = epoch
  merged_metadata["epoch_loss"] = epoch_metric
  merged_metadata["monitor_name"] = monitor_name
  merged_metadata["monitor_value"] = monitor_value
  merged_metadata["stale_epochs"] = stale_epochs
  merged_metadata["best_metric"] = best_metric
  merged_metadata["next_epoch"] = epoch + 1

  @model.save_checkpoint(resolved_path, optimizer: @optimizer, metadata: merged_metadata)
  # Snapshot is deep-copied so later mutation of metadata cannot
  # retroactively alter recorded history.
  @last_checkpoint_snapshot = {
    "path" => resolved_path,
    "epoch" => epoch,
    "monitor_name" => monitor_name,
    "monitor_value" => monitor_value,
    "epoch_loss" => epoch_metric,
    "metadata" => __dsl_deep_copy(merged_metadata)
  }
  @checkpoint_history << __dsl_deep_copy(@last_checkpoint_snapshot)
  unless keep_last_n.nil?
    keep = keep_last_n.to_i
    raise ArgumentError, "artifact retention keep_last_n must be non-negative" if keep.negative?
    # Drop oldest history entries beyond the retention window.
    # NOTE(review): only the in-memory history is pruned here — files on
    # disk are presumably not deleted; confirm against callers.
    @checkpoint_history.shift while @checkpoint_history.length > keep
  end
  emit(
    :checkpoint,
    {
      model: @model,
      optimizer: @optimizer,
      path: resolved_path,
      epoch: epoch,
      monitor_name: monitor_name,
      monitor_value: monitor_value,
      epoch_loss: epoch_metric,
      improved: improved
    }
  )
  true
end
|
|
1875
|
+
|
|
1876
|
+
# Resolves a checkpoint +path+ option into a concrete path string.
#
# +path+ may be:
#   - a callable: invoked with whatever subset of (epoch, next_epoch,
#     monitor, monitor_name, epoch_loss, improved, trainer, model,
#     optimizer) its signature accepts; it must return something
#     String-compatible (#to_str), else ArgumentError is raised.
#   - a string (or #to_s-able value): used directly.
#
# In either case, if the resulting string contains "%{" it is treated
# as a Kernel#format template and interpolated with the epoch/monitor
# values; an unknown template key surfaces as ArgumentError (the
# method-level rescue below converts KeyError).
def __dsl_checkpoint_path(path, epoch:, monitor_name:, monitor_value:, epoch_metric:, improved:)
  if path.respond_to?(:call)
    values = {
      epoch: epoch,
      next_epoch: epoch + 1,
      monitor: monitor_value,
      monitor_name: monitor_name,
      epoch_loss: epoch_metric,
      improved: improved,
      trainer: self,
      model: @model,
      optimizer: @optimizer
    }
    # Zero-arity (or non-introspectable) callables are invoked bare;
    # otherwise arguments are matched by parameter name.
    path = if !path.respond_to?(:parameters) || path.parameters.empty?
      path.call
    else
      args = __dsl_build_positional_args(
        path.parameters,
        values,
        [
          [:epoch, epoch],
          [:next_epoch, epoch + 1],
          [:monitor, monitor_value],
          [:monitor_name, monitor_name],
          [:epoch_loss, epoch_metric],
          [:improved, improved],
          [:trainer, self],
          [:model, @model],
          [:optimizer, @optimizer]
        ],
        "checkpoint path"
      )
      kwargs = __dsl_build_keyword_args(path.parameters, values, "checkpoint path")
      kwargs.empty? ? path.call(*args) : path.call(*args, **kwargs)
    end
    unless path.respond_to?(:to_str)
      raise ArgumentError, "checkpoint path callable must return a String-compatible path"
    end

    path = path.to_str
  end

  template = path.to_s
  # No "%{" means no template placeholders — return as-is.
  return template unless template.include?("%{")

  template % {
    epoch: epoch,
    next_epoch: epoch + 1,
    monitor: monitor_value,
    monitor_name: monitor_name,
    epoch_loss: epoch_metric,
    improved: improved
  }
rescue KeyError => e
  # String#% raises KeyError for unknown %{...} keys; re-raise as a
  # user-facing configuration error.
  raise ArgumentError, "unsupported checkpoint template key: #{e.message}"
end
|
|
1932
|
+
|
|
1933
|
+
# Resolves a run bundle: Hashes pass through untouched; anything else
# is treated as a file path whose JSON contents are parsed. An empty
# path (including nil) raises ArgumentError.
def __dsl_resolve_run_bundle(bundle_or_path)
  return bundle_or_path if bundle_or_path.is_a?(Hash)

  path = bundle_or_path.to_s
  raise ArgumentError, "run bundle source must be a bundle hash or path" if path.empty?

  JSON.parse(File.binread(path))
end
|
|
1945
|
+
|
|
1946
|
+
# Deep-copies +value+ via a Marshal round-trip (nil maps to nil).
# Only used on trainer-produced metadata, which is Marshal-safe.
def __dsl_deep_copy(value)
  value.nil? ? nil : Marshal.load(Marshal.dump(value))
end
|
|
1951
|
+
|
|
1952
|
+
# Best-effort dataset size: nil for nil datasets and dataset factories
# (callables, whose size is unknown until invoked); otherwise #size,
# then #length, then nil.
def __dsl_dataset_size(dataset)
  return nil if dataset.nil? || dataset.respond_to?(:call)

  if dataset.respond_to?(:size)
    dataset.size
  elsif dataset.respond_to?(:length)
    dataset.length
  end
end
|
|
1959
|
+
|
|
1960
|
+
# Materializes the dataset for one epoch. Callable datasets are treated
# as factories producing a fresh source per epoch; static datasets are
# reused and, from the second epoch on, rewound when they support it.
# The resulting source must respond to #each.
def __dsl_dataset_for_epoch(dataset, epoch:, kind:)
  factory = dataset.respond_to?(:call)
  source = factory ? __dsl_call_dataset_factory(dataset, epoch: epoch, kind: kind) : dataset
  raise ArgumentError, "#{kind} dataset must respond to #each" unless source.respond_to?(:each)

  # Factory-produced sources are fresh each epoch and need no rewind.
  if !factory && epoch.positive? && source.respond_to?(:rewind)
    begin
      source.rewind
    rescue StandardError => e
      raise ArgumentError, "#{kind} dataset could not rewind for epoch #{epoch}: #{e.message}"
    end
  end

  source
end
|
|
1980
|
+
|
|
1981
|
+
# Invokes a dataset factory, passing whatever subset of
# (epoch, kind, trainer) its signature accepts; zero-arity or
# non-introspectable factories are called bare.
def __dsl_call_dataset_factory(factory, epoch:, kind:)
  params = factory.respond_to?(:parameters) ? factory.parameters : []
  return factory.call if params.empty?

  values = { epoch: epoch, trainer: self, kind: kind }
  args = __dsl_build_positional_args(
    params,
    values,
    [[:epoch, epoch], [:kind, kind], [:trainer, self]],
    "dataset factory"
  )
  kwargs = __dsl_build_keyword_args(params, values, "dataset factory")
  kwargs.empty? ? factory.call(*args) : factory.call(*args, **kwargs)
end
|
|
2003
|
+
|
|
2004
|
+
# Resolves a loop limit option to a non-negative Integer, or nil for
# "no limit".
#
# +limit+ may be nil, an Integer-compatible value (#to_int), or a
# callable returning one; callables may accept any subset of
# (epoch, kind, trainer) by name.
#
# Fix: a zero-arity callable previously hit `return limit.call`,
# handing its raw result straight back and bypassing the #to_int
# coercion and non-negativity check applied to every other limit form —
# even though the error message promises "callable returning one" is
# validated. The callable's result now flows through the same
# validation as literal limits.
def __dsl_resolve_loop_limit(limit, epoch:, kind:)
  return nil if limit.nil?

  raw = if limit.respond_to?(:call)
    if !limit.respond_to?(:parameters) || limit.parameters.empty?
      limit.call
    else
      values = {
        epoch: epoch,
        kind: kind,
        trainer: self
      }
      args = __dsl_build_positional_args(
        limit.parameters,
        values,
        [[:epoch, epoch], [:kind, kind], [:trainer, self]],
        "#{kind} limit"
      )
      kwargs = __dsl_build_keyword_args(limit.parameters, values, "#{kind} limit")
      kwargs.empty? ? limit.call(*args) : limit.call(*args, **kwargs)
    end
  else
    limit
  end
  return nil if raw.nil?
  unless raw.respond_to?(:to_int)
    raise ArgumentError, "#{kind} limit must be an Integer, nil, or callable returning one"
  end

  value = raw.to_int
  raise ArgumentError, "#{kind} limit must be non-negative" if value.negative?

  value
end
|
|
2036
|
+
|
|
2037
|
+
# Builds positional arguments for a user callable.
#
# For each :req/:opt parameter: a value from +values+ matching the
# parameter's name wins (and removes that name from the fallback queue);
# otherwise the next fallback pair is consumed positionally. A required
# parameter with no named value and an exhausted queue raises
# ArgumentError; an optional one stops argument collection.
def __dsl_build_positional_args(params, values, fallback_pairs, label)
  pending = fallback_pairs.dup
  positional = []

  params.each do |type, name|
    next unless type == :req || type == :opt

    if name && values.key?(name)
      positional << values.fetch(name)
      pending.delete_if { |key, _| key == name }
    elsif pending.empty?
      raise ArgumentError, "#{label} has unsupported required positional argument: #{name.inspect}" if type == :req

      break
    else
      _unused_key, fallback_value = pending.shift
      positional << fallback_value
    end
  end

  positional
end
|
|
2059
|
+
|
|
2060
|
+
# Builds keyword arguments for a user callable from the supported
# +values+. A **kwargs parameter (:keyrest) receives every supported
# value; otherwise only declared :key/:keyreq names are passed, and a
# required keyword outside the supported set raises ArgumentError.
def __dsl_build_keyword_args(params, values, label)
  return values.dup if params.any? { |type, _name| type == :keyrest }

  required = params.select { |type, _name| type == :keyreq }.map { |_type, name| name }
  missing = required.reject { |name| values.key?(name) }
  unless missing.empty?
    raise ArgumentError, "#{label} requires unsupported keyword argument(s): #{missing.map(&:inspect).join(", ")}"
  end

  accepted = params.select { |type, _name| type == :key || type == :keyreq }.map { |_type, name| name }
  values.select { |name, _value| accepted.include?(name) }
end
|
|
2079
|
+
|
|
2080
|
+
# In strict mode, detects a one-shot enumerable that was exhausted in a
# previous epoch: a static, non-rewindable dataset that yielded batches
# before but produced none this epoch. Raises ArgumentError suggesting
# a dataset factory; otherwise a silent no-op.
def __dsl_validate_data_reuse!(strict:, dataset:, kind:, epoch:, previous_batches:, current_batches:)
  return unless strict
  # Factories produce fresh data and rewindables can be reset — neither can be "stuck exhausted".
  return if dataset.nil? || dataset.respond_to?(:call) || dataset.respond_to?(:rewind)
  return unless epoch.positive?

  exhausted = previous_batches.to_i.positive? && current_batches.to_i.zero?
  return unless exhausted

  raise ArgumentError,
        "#{kind} dataset appears exhausted across epochs; pass a factory like ->(epoch:) { ... }"
end
|
|
2090
|
+
|
|
2091
|
+
# Runs the block with @model in evaluation mode, then restores the
# prior training flag. Three model shapes are supported, duck-typed:
#   - #eval_mode taking a block: delegated entirely (returns nil)
#   - #eval / #train / #training trio: flag saved, eval entered, flag
#     restored in ensure (returns the block's value)
#   - no mode API at all: the block just runs (returns nil)
def __dsl_with_eval_mode
  model = @model
  if model.respond_to?(:eval_mode)
    model.eval_mode { yield }
    nil
  elsif model.respond_to?(:eval) && model.respond_to?(:train) && model.respond_to?(:training)
    previous = model.training
    begin
      model.eval
      yield
    ensure
      # Restore only when a prior flag was captured.
      model.train(previous) unless previous.nil?
    end
  else
    yield
    nil
  end
end
|
|
2108
|
+
end
|
|
2109
|
+
end
|
|
2110
|
+
end
|