mlx 0.30.7 → 0.30.7.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2110 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "json"
4
+ require "fileutils"
5
+ require "time"
6
+
7
+ module MLX
8
+ module DSL
9
+ class Trainer
10
      # Lifecycle events hooks may subscribe to, listed in the order they
      # fire during a single #fit run.
      HOOK_EVENTS = %i[
        before_fit
        before_epoch
        after_batch
        before_validation
        after_validation_batch
        after_validation
        after_epoch
        checkpoint
        after_fit
      ].freeze
      # Sentinel distinguishing "option not passed by caller" from an
      # explicit nil, so defaults/presets fill in only unset options.
      UNSET = Object.new.freeze
      # Canonical fit options and their fallback values. Also acts as the
      # whitelist of recognized option keys (see
      # __dsl_normalize_fit_option_keys).
      FIT_OPTION_DEFAULTS = {
        epochs: 1,
        limit: nil,
        collate: nil,
        bind: nil,
        report: false,
        reduce: :mean,
        monitor: :epoch_loss,
        metric: nil,
        validation_data: nil,
        validation_limit: nil,
        validation_reduce: nil,
        validation_collate: nil,
        validation_bind: nil,
        train_transform: nil,
        validation_transform: nil,
        checkpoint_path: nil,
        save_best: false,
        monitor_mode: :min,
        patience: nil,
        min_delta: 0.0,
        keep_losses: true,
        strict_data_reuse: false,
        resume_from: nil,
        metadata: {}
      }.freeze
48
+
49
      # Builds a trainer around +model+ and +optimizer+.
      #
      # model::          module whose parameters are trained
      # optimizer::      optimizer applied by the train step
      # clip_grad_norm:: optional gradient-norm clip value forwarded to the
      #                  train-step builder
      # compile::        forwarded to the train-step builder
      # sync::           evaluation-sync cadence; normalized by
      #                  __dsl_normalize_sync (helper not in this file view)
      # loss_block::     required block computing the loss for one batch
      #
      # Raises ArgumentError when no loss block is given.
      def initialize(model:, optimizer:, clip_grad_norm: nil, compile: false, sync: :none, &loss_block)
        raise ArgumentError, "trainer requires a loss block" unless block_given?

        # Remember the constructor arguments so __dsl_clone_trainer can
        # rebuild an equivalent trainer later.
        @__dsl_init_options = {
          model: model,
          optimizer: optimizer,
          clip_grad_norm: clip_grad_norm,
          compile: compile,
          sync: sync
        }
        @model = model
        @loss_block = loss_block
        @sync_mode = __dsl_normalize_sync(sync)
        # Compiled/step callable used by __dsl_run_batch — built by a
        # helper defined elsewhere in this class.
        @step = __dsl_build_train_step(
          model,
          optimizer: optimizer,
          clip_grad_norm: clip_grad_norm,
          compile: compile,
          &loss_block
        )
        @optimizer = optimizer
        # event => list of hook entries; @hook_order gives each entry a
        # stable registration index for deterministic ordering.
        @hooks = Hash.new { |h, k| h[k] = [] }
        @hook_order = 0
        @collate_registry = {}
        @fit_defaults = {}
        @fit_presets = {}
        @batch_schemas = { train: nil, validation: nil }
        @dataflow_profiles = {}
        @hook_packs = {}
        @metric_registry = {}
        @task_presets = __dsl_builtin_task_presets
        @artifact_policy_config = {
          checkpoint: {},
          retention: {},
          resume: nil,
          run_bundle: {}
        }
        @checkpoint_history = []
        @last_checkpoint_snapshot = nil
      end
89
+
90
+ def on(event, priority: 0, every: nil, once: false, **kwargs, &block)
91
+ raise ArgumentError, "hook registration requires a block" unless block_given?
92
+ condition = kwargs.delete(:if)
93
+ condition = kwargs.delete(:condition) if condition.nil? && kwargs.key?(:condition)
94
+ unless kwargs.empty?
95
+ raise ArgumentError, "unsupported hook option(s): #{kwargs.keys.map(&:inspect).join(', ')}"
96
+ end
97
+ every_value = nil
98
+ unless every.nil?
99
+ every_value = every.to_i
100
+ raise ArgumentError, "hook :every must be a positive integer" if every_value <= 0
101
+ end
102
+
103
+ @hooks[event.to_sym] << {
104
+ hook: block,
105
+ priority: priority.to_i,
106
+ every: every_value,
107
+ once: !!once,
108
+ if: condition,
109
+ fired: false,
110
+ invocations: 0,
111
+ order: @hook_order
112
+ }
113
+ @hook_order += 1
114
+ self
115
+ end
116
+
117
      # Define one sugar method per lifecycle event (e.g. `after_epoch do
      # ... end`) that simply forwards to #on with the event fixed.
      HOOK_EVENTS.each do |event|
        define_method(event) do |**options, &block|
          on(event, **options, &block)
        end
      end
122
+
123
      # Runs the training loop over +dataset+ for the configured number of
      # epochs, optionally validating, checkpointing, and early-stopping.
      #
      # Every keyword defaults to UNSET so that trainer-level fit defaults
      # (set via #with_fit_defaults) and FIT_OPTION_DEFAULTS can fill in
      # only the options the caller did not pass.
      #
      # Returns the collected losses array, or the full report Hash when
      # report: is truthy.
      def fit(
        dataset,
        epochs: UNSET,
        limit: UNSET,
        collate: UNSET,
        bind: UNSET,
        report: UNSET,
        reduce: UNSET,
        monitor: UNSET,
        metric: UNSET,
        validation_data: UNSET,
        validation_limit: UNSET,
        validation_reduce: UNSET,
        validation_collate: UNSET,
        validation_bind: UNSET,
        train_transform: UNSET,
        validation_transform: UNSET,
        checkpoint_path: UNSET,
        save_best: UNSET,
        monitor_mode: UNSET,
        patience: UNSET,
        min_delta: UNSET,
        keep_losses: UNSET,
        strict_data_reuse: UNSET,
        resume_from: UNSET,
        metadata: UNSET
      )
        # Capture the options exactly as passed (UNSET marks "not given").
        raw_options = {
          epochs: epochs,
          limit: limit,
          collate: collate,
          bind: bind,
          report: report,
          reduce: reduce,
          monitor: monitor,
          metric: metric,
          validation_data: validation_data,
          validation_limit: validation_limit,
          validation_reduce: validation_reduce,
          validation_collate: validation_collate,
          validation_bind: validation_bind,
          train_transform: train_transform,
          validation_transform: validation_transform,
          checkpoint_path: checkpoint_path,
          save_best: save_best,
          monitor_mode: monitor_mode,
          patience: patience,
          min_delta: min_delta,
          keep_losses: keep_losses,
          strict_data_reuse: strict_data_reuse,
          resume_from: resume_from,
          metadata: metadata
        }
        # A SplitPlan dataset may contribute its own train/validation
        # inputs for any still-unset options.
        dataset, raw_options = __dsl_expand_split_plan(dataset, raw_options)
        options = __dsl_resolve_fit_options(raw_options)
        epochs = options.fetch(:epochs)
        limit = options.fetch(:limit)
        collate = options.fetch(:collate)
        bind = options.fetch(:bind)
        report = options.fetch(:report)
        reduce = options.fetch(:reduce)
        monitor = options.fetch(:monitor)
        metric = options.fetch(:metric)
        validation_data = options.fetch(:validation_data)
        validation_limit = options.fetch(:validation_limit)
        validation_reduce = options.fetch(:validation_reduce)
        validation_collate = options.fetch(:validation_collate)
        validation_bind = options.fetch(:validation_bind)
        train_transform = options.fetch(:train_transform)
        validation_transform = options.fetch(:validation_transform)
        checkpoint_path = options.fetch(:checkpoint_path)
        save_best = options.fetch(:save_best)
        monitor_mode = options.fetch(:monitor_mode)
        patience = options.fetch(:patience)
        min_delta = options.fetch(:min_delta)
        keep_losses = options.fetch(:keep_losses)
        strict_data_reuse = options.fetch(:strict_data_reuse)
        resume_from = options.fetch(:resume_from)
        metadata = options.fetch(:metadata)
        # The configured artifact policy supplies checkpoint/resume
        # settings the caller left blank.
        policy = __dsl_resolve_artifact_policy(
          checkpoint_path: checkpoint_path,
          save_best: save_best,
          resume_from: resume_from,
          monitor_mode: monitor_mode
        )
        checkpoint_path = policy.fetch(:checkpoint_path)
        save_best = policy.fetch(:save_best)
        resume_from = policy.fetch(:resume_from)
        checkpoint_every = policy.fetch(:checkpoint_every)
        retention_keep_last_n = policy.fetch(:keep_last_n)
        run_bundle_policy = policy.fetch(:run_bundle)
        policy_payload = policy.fetch(:payload)

        keep_losses = !!keep_losses
        strict_data_reuse = !!strict_data_reuse
        @last_checkpoint_snapshot = nil
        losses = []
        epoch_rows = []
        best_metric = nil
        stale_epochs = 0
        stopped_early = false
        previous_train_batches = nil
        previous_validation_batches = nil
        monitor_name = monitor.to_s
        # Resuming restores epoch counter / best metric / staleness from a
        # previous checkpoint (helper not in this file view).
        resume_state = __dsl_resume_state(resume_from, monitor_name)
        start_epoch = resume_state.fetch(:start_epoch)
        best_metric = resume_state.fetch(:best_metric)
        stale_epochs = resume_state.fetch(:stale_epochs)
        monitor_name = resume_state.fetch(:monitor_name)
        total_epochs = epochs.to_i
        patience_value = __dsl_normalize_patience(patience)
        min_delta_value = __dsl_normalize_min_delta(min_delta)
        train_dataset_size = __dsl_dataset_size(dataset)
        validation_dataset_size = __dsl_dataset_size(validation_data)
        # Validation falls back to the training reducer unless overridden.
        validation_reducer = validation_reduce.nil? ? reduce : validation_reduce

        emit(
          :before_fit,
          {
            model: @model,
            optimizer: @optimizer,
            epochs: total_epochs,
            start_epoch: start_epoch,
            resumed_from_epoch: resume_state.fetch(:checkpoint_epoch),
            resume_from: resume_state.fetch(:path),
            best_metric: best_metric,
            stale_epochs: stale_epochs,
            dataset_size: train_dataset_size,
            validation_size: validation_dataset_size
          }
        )

        (start_epoch...total_epochs).each do |epoch|
          emit(:before_epoch, { epoch: epoch, model: @model })
          index = 0
          epoch_losses = []
          epoch_last_loss = nil
          train_limit = __dsl_resolve_loop_limit(limit, epoch: epoch, kind: :train)
          # --- training pass ---
          __dsl_dataset_for_epoch(dataset, epoch: epoch, kind: :train).each do |batch|
            break if !train_limit.nil? && index >= train_limit

            # Pipeline: resolve collate -> apply collate -> apply transform.
            batch = __dsl_apply_batch_transform(
              train_transform,
              __dsl_apply_collate(
                __dsl_effective_collate(
                  collate,
                  bind,
                  batch,
                  kind: :train
                ),
                batch,
                kind: :train,
                epoch: epoch,
                batch_index: index
              ),
              epoch: epoch,
              batch_index: index,
              kind: :train
            )
            loss = __dsl_run_batch(
              batch,
              epoch: epoch,
              batch_index: index,
              kind: :train
            )
            epoch_last_loss = loss
            losses << loss if keep_losses
            scalar = __dsl_loss_scalar(loss)
            epoch_losses << scalar unless scalar.nil?
            emit(
              :after_batch,
              {
                epoch: epoch,
                batch_index: index,
                loss: loss,
                loss_value: scalar,
                model: @model
              }
            )
            index += 1
          end
          # Optionally flags exhausted/single-pass datasets being reused.
          __dsl_validate_data_reuse!(
            strict: strict_data_reuse,
            dataset: dataset,
            kind: :train,
            epoch: epoch,
            previous_batches: previous_train_batches,
            current_batches: index
          )
          previous_train_batches = index

          epoch_metric = __dsl_reduce_values(epoch_losses, reduce)
          validation_losses = []
          validation_batch_count = 0
          val_metric = nil
          # --- validation pass (only when validation data provided) ---
          unless validation_data.nil?
            validation_epoch_limit = __dsl_resolve_loop_limit(
              validation_limit,
              epoch: epoch,
              kind: :validation
            )
            emit(
              :before_validation,
              {
                epoch: epoch,
                model: @model,
                monitor_name: monitor_name
              }
            )
            __dsl_with_eval_mode do
              __dsl_dataset_for_epoch(validation_data, epoch: epoch, kind: :validation).each do |batch|
                break if !validation_epoch_limit.nil? && validation_batch_count >= validation_epoch_limit

                batch = __dsl_apply_batch_transform(
                  validation_transform,
                  __dsl_apply_collate(
                    __dsl_effective_collate(
                      validation_collate,
                      validation_bind,
                      batch,
                      kind: :validation
                    ),
                    batch,
                    kind: :validation,
                    epoch: epoch,
                    batch_index: validation_batch_count
                  ),
                  epoch: epoch,
                  batch_index: validation_batch_count,
                  kind: :validation
                )
                loss = __dsl_run_validation_batch(
                  batch,
                  epoch: epoch,
                  batch_index: validation_batch_count,
                  kind: :validation
                )
                scalar = __dsl_loss_scalar(loss)
                validation_losses << scalar unless scalar.nil?
                emit(
                  :after_validation_batch,
                  {
                    epoch: epoch,
                    batch_index: validation_batch_count,
                    loss: loss,
                    loss_value: scalar,
                    model: @model
                  }
                )
                validation_batch_count += 1
              end
            end
            __dsl_validate_data_reuse!(
              strict: strict_data_reuse,
              dataset: validation_data,
              kind: :validation,
              epoch: epoch,
              previous_batches: previous_validation_batches,
              current_batches: validation_batch_count
            )
            previous_validation_batches = validation_batch_count
            val_metric = __dsl_reduce_values(validation_losses, validation_reducer)
            emit(
              :after_validation,
              {
                epoch: epoch,
                model: @model,
                val_loss: val_metric,
                validation_batches: validation_batch_count
              }
            )
          end

          # --- monitoring / early-stopping bookkeeping ---
          monitor_value = __dsl_monitor_value(
            metric,
            {
              epoch: epoch,
              epoch_losses: epoch_losses,
              epoch_loss: epoch_metric,
              val_loss: val_metric,
              validation_losses: validation_losses,
              losses: losses,
              batches: index,
              validation_batches: validation_batch_count,
              model: @model,
              optimizer: @optimizer
            },
            fallback: __dsl_default_monitor_value(monitor_name, epoch_metric, val_metric)
          )
          improved = __dsl_improved?(
            monitor_value,
            best_metric,
            monitor_mode,
            min_delta: min_delta_value
          )
          if improved
            best_metric = monitor_value
            stale_epochs = 0
          elsif !best_metric.nil?
            # Only count staleness once a baseline best exists.
            stale_epochs += 1
          end

          row = {
            "epoch" => epoch,
            "batches" => index,
            "epoch_loss" => epoch_metric,
            "val_loss" => val_metric,
            "validation_batches" => validation_batch_count,
            "monitor_value" => monitor_value,
            "stale_epochs" => stale_epochs,
            "improved" => improved
          }
          epoch_rows << row

          checkpoint_saved = __dsl_maybe_checkpoint(
            checkpoint_path,
            save_best: save_best,
            improved: improved,
            epoch: epoch,
            monitor_name: monitor_name,
            monitor_value: monitor_value,
            epoch_metric: epoch_metric,
            stale_epochs: stale_epochs,
            best_metric: best_metric,
            metadata: metadata,
            checkpoint_every: checkpoint_every,
            keep_last_n: retention_keep_last_n
          )

          emit(
            :after_epoch,
            {
              epoch: epoch,
              model: @model,
              epoch_loss: epoch_metric,
              val_loss: val_metric,
              monitor_name: monitor_name,
              monitor_value: monitor_value,
              validation_batches: validation_batch_count,
              stale_epochs: stale_epochs,
              improved: improved,
              best_metric: best_metric,
              checkpoint_saved: checkpoint_saved
            }
          )

          __dsl_sync_epoch(epoch_last_loss) if @sync_mode == :epoch

          # Early stop once staleness exceeds the patience budget.
          if !patience_value.nil? && stale_epochs > patience_value
            stopped_early = true
            break
          end
        end

        # --- final report ---
        payload = {
          "losses" => losses,
          "losses_kept" => keep_losses,
          "epochs" => epoch_rows,
          "monitor_name" => monitor_name,
          "epochs_target" => total_epochs,
          "epochs_completed" => [start_epoch + epoch_rows.length, total_epochs].min,
          "epochs_ran" => epoch_rows.length,
          "stopped_early" => stopped_early,
          "best_metric" => best_metric,
          "resume_from" => resume_state.fetch(:path),
          "resumed_from_epoch" => resume_state.fetch(:checkpoint_epoch),
          "start_epoch" => start_epoch,
          "artifact_policy" => policy_payload
        }
        auto_bundle_path = __dsl_auto_save_run_bundle(
          run_bundle_policy,
          payload
        )
        payload["run_bundle_path"] = auto_bundle_path unless auto_bundle_path.nil?
        emit(
          :after_fit,
          {
            model: @model,
            optimizer: @optimizer,
            epochs: epoch_rows.length,
            best_metric: best_metric,
            stopped_early: stopped_early,
            resume_from: resume_state.fetch(:path),
            resumed_from_epoch: resume_state.fetch(:checkpoint_epoch),
            report: payload
          }
        )

        return payload if report

        losses
      end
515
+
516
+ def fit_report(dataset, **kwargs)
517
+ fit(dataset, **kwargs, report: true)
518
+ end
519
+
520
+ def with_fit_defaults(**defaults)
521
+ configured = __dsl_clone_trainer
522
+ configured.instance_variable_set(
523
+ :@fit_defaults,
524
+ configured.instance_variable_get(:@fit_defaults).merge(__dsl_normalize_fit_option_keys(defaults))
525
+ )
526
+ configured
527
+ end
528
+
529
+ def register_fit_preset(name, **defaults)
530
+ @fit_presets[name.to_sym] = __dsl_normalize_fit_option_keys(defaults)
531
+ self
532
+ end
533
+
534
+ def fit_with(name, dataset, **overrides)
535
+ fit(dataset, **__dsl_merge_fit_preset(name, overrides))
536
+ end
537
+
538
+ def fit_report_with(name, dataset, **overrides)
539
+ fit_report(dataset, **__dsl_merge_fit_preset(name, overrides))
540
+ end
541
+
542
+ def register_dataflow(name, train: {}, validation: {}, extends: nil)
543
+ profile = __dsl_normalize_dataflow_profile(train: train, validation: validation)
544
+ if !extends.nil?
545
+ base_profile = { train: {}, validation: {} }
546
+ base_names = extends.is_a?(Array) ? extends : [extends]
547
+ base_names.each do |base_name|
548
+ base_key = base_name.to_sym
549
+ unless @dataflow_profiles.key?(base_key)
550
+ raise ArgumentError, "unknown dataflow profile: #{base_name.inspect}"
551
+ end
552
+
553
+ base_profile = __dsl_compose_dataflow_profile(base_profile, @dataflow_profiles.fetch(base_key))
554
+ end
555
+ profile = __dsl_compose_dataflow_profile(base_profile, profile)
556
+ end
557
+
558
+ @dataflow_profiles[name.to_sym] = profile
559
+ self
560
+ end
561
+
562
+ def use_dataflow(name, **overrides)
563
+ profile = __dsl_resolve_dataflow_profile(name)
564
+ profile_overrides, direct_overrides = __dsl_normalize_dataflow_overrides(overrides)
565
+ resolved = __dsl_compose_dataflow_profile(profile, profile_overrides)
566
+ __dsl_fit_kwargs_from_dataflow(resolved).merge(__dsl_normalize_fit_option_keys(direct_overrides))
567
+ end
568
+
569
+ def register_hook_pack(name, callable = nil, &block)
570
+ if !callable.nil? && block_given?
571
+ raise ArgumentError, "register_hook_pack accepts either a callable argument or block, not both"
572
+ end
573
+
574
+ pack = callable.nil? ? block : callable
575
+ unless pack.respond_to?(:call)
576
+ raise ArgumentError, "register_hook_pack requires a callable argument or block"
577
+ end
578
+
579
+ @hook_packs[name.to_sym] = pack
580
+ self
581
+ end
582
+
583
+ def use_hook_pack(name, **options)
584
+ key = name.to_sym
585
+ unless @hook_packs.key?(key)
586
+ raise ArgumentError, "unknown hook pack: #{name.inspect}"
587
+ end
588
+
589
+ __dsl_call_hook_pack(@hook_packs.fetch(key), options)
590
+ self
591
+ end
592
+
593
+ def register_metric(name, callable = nil, &block)
594
+ if !callable.nil? && block_given?
595
+ raise ArgumentError, "register_metric accepts either a callable argument or block, not both"
596
+ end
597
+
598
+ metric_callable = callable.nil? ? block : callable
599
+ unless metric_callable.respond_to?(:call)
600
+ raise ArgumentError, "register_metric requires a callable argument or block"
601
+ end
602
+
603
+ @metric_registry[name.to_sym] = metric_callable
604
+ self
605
+ end
606
+
607
+ def register_task(name, **defaults)
608
+ @task_presets[name.to_sym] = __dsl_normalize_fit_option_keys(defaults)
609
+ self
610
+ end
611
+
612
+ def fit_task(task, dataset, **overrides)
613
+ fit(dataset, **__dsl_task_fit_options(task, overrides))
614
+ end
615
+
616
+ def fit_task_report(task, dataset, **overrides)
617
+ fit_report(dataset, **__dsl_task_fit_options(task, overrides))
618
+ end
619
+
620
+ def artifact_policy(checkpoint: nil, retention: nil, resume: UNSET, run_bundle: nil)
621
+ if checkpoint.nil? && retention.nil? && resume.equal?(UNSET) && run_bundle.nil?
622
+ return __dsl_clone_config_value(@artifact_policy_config)
623
+ end
624
+
625
+ unless checkpoint.nil?
626
+ @artifact_policy_config[:checkpoint] = __dsl_normalize_artifact_checkpoint_policy(checkpoint)
627
+ end
628
+ unless retention.nil?
629
+ @artifact_policy_config[:retention] = __dsl_normalize_artifact_retention_policy(retention)
630
+ end
631
+ @artifact_policy_config[:resume] = resume unless resume.equal?(UNSET)
632
+ unless run_bundle.nil?
633
+ @artifact_policy_config[:run_bundle] = __dsl_normalize_artifact_run_bundle_policy(run_bundle)
634
+ end
635
+ self
636
+ end
637
+
638
      # Returns a deep copy of the recorded checkpoint history so callers
      # cannot mutate internal state.
      def checkpoint_history
        __dsl_clone_config_value(@checkpoint_history)
      end
641
+
642
      # Reads or writes the per-split batch schemas.
      #
      # Forms (mutually exclusive, enforced below):
      #   batch_schema                      -> deep copy of current schemas
      #   batch_schema(spec)                -> same schema for both splits
      #   batch_schema(x: ..., y: ...)      -> keyword mapping treated as spec
      #   batch_schema(train: s1, validation: s2) -> split-specific schemas
      #
      # Returns the schemas Hash when reading, self when writing.
      def batch_schema(spec = UNSET, train: UNSET, validation: UNSET, **schema_kwargs)
        # Bare keyword mappings are shorthand for a positional spec, but
        # only when no other form was used in the same call.
        unless schema_kwargs.empty?
          if spec.equal?(UNSET) && train.equal?(UNSET) && validation.equal?(UNSET)
            spec = schema_kwargs
          else
            raise ArgumentError, "batch_schema keyword mappings cannot be combined with positional or split-specific forms"
          end
        end

        if !spec.equal?(UNSET) && (!train.equal?(UNSET) || !validation.equal?(UNSET))
          raise ArgumentError, "batch_schema accepts either positional spec or split-specific train:/validation: overrides"
        end
        # No arguments at all: act as a reader.
        if spec.equal?(UNSET) && train.equal?(UNSET) && validation.equal?(UNSET)
          return __dsl_clone_config_value(@batch_schemas)
        end

        # Positional spec applies to both splits.
        if !spec.equal?(UNSET)
          normalized = __dsl_normalize_batch_schema_spec(spec)
          @batch_schemas[:train] = normalized
          @batch_schemas[:validation] = normalized
          return self
        end

        # Split-specific writes; untouched splits keep their schema.
        unless train.equal?(UNSET)
          @batch_schemas[:train] = __dsl_normalize_batch_schema_spec(train)
        end
        unless validation.equal?(UNSET)
          @batch_schemas[:validation] = __dsl_normalize_batch_schema_spec(validation)
        end
        self
      end
673
+
674
+ def register_collate(name, spec = nil, extends: nil, &block)
675
+ if !spec.nil? && block_given?
676
+ raise ArgumentError, "register_collate accepts either a spec argument or block, not both"
677
+ end
678
+
679
+ key = name.to_sym
680
+ resolved_spec = block_given? ? block : spec
681
+ if resolved_spec.nil?
682
+ raise ArgumentError, "register_collate requires a collate spec or block"
683
+ end
684
+
685
+ if !extends.nil?
686
+ base_keys = extends.is_a?(Array) ? extends : [extends]
687
+ base_spec = nil
688
+ base_keys.each do |base_name|
689
+ base_key = base_name.to_sym
690
+ unless @collate_registry.key?(base_key)
691
+ raise ArgumentError, "unknown base collate schema: #{base_name.inspect}"
692
+ end
693
+
694
+ current_base = @collate_registry.fetch(base_key)
695
+ base_spec = base_spec.nil? ? current_base : __dsl_compose_collate(base_spec, current_base)
696
+ end
697
+ resolved_spec = __dsl_compose_collate(base_spec, resolved_spec)
698
+ end
699
+
700
+ @collate_registry[key] = resolved_spec
701
+ self
702
+ end
703
+
704
+ def run_bundle(report:, config: {}, schema_version: "mlx_dsl_run_bundle_v1")
705
+ unless report.is_a?(Hash)
706
+ raise ArgumentError, "run_bundle requires report to be a Hash from fit_report"
707
+ end
708
+
709
+ {
710
+ "format" => schema_version.to_s,
711
+ "generated_at" => Time.now.utc.iso8601,
712
+ "trainer" => {
713
+ "monitor_name" => report["monitor_name"],
714
+ "epochs_target" => report["epochs_target"],
715
+ "epochs_completed" => report["epochs_completed"],
716
+ "resume_from" => report["resume_from"],
717
+ "resumed_from_epoch" => report["resumed_from_epoch"]
718
+ },
719
+ "config" => config || {},
720
+ "report" => report,
721
+ "checkpoint" => __dsl_deep_copy(@last_checkpoint_snapshot)
722
+ }
723
+ end
724
+
725
+ def save_run_bundle(path, report:, config: {}, schema_version: "mlx_dsl_run_bundle_v1")
726
+ bundle = run_bundle(
727
+ report: report,
728
+ config: config,
729
+ schema_version: schema_version
730
+ )
731
+ dir = File.dirname(path.to_s)
732
+ FileUtils.mkdir_p(dir) unless dir.nil? || dir.empty? || dir == "."
733
+ File.binwrite(path, JSON.pretty_generate(bundle))
734
+ path
735
+ end
736
+
737
+ def resume_payload_from_bundle(bundle_or_path)
738
+ bundle = __dsl_resolve_run_bundle(bundle_or_path)
739
+ checkpoint = bundle["checkpoint"]
740
+ unless checkpoint.is_a?(Hash)
741
+ raise ArgumentError, "run bundle does not include checkpoint snapshot"
742
+ end
743
+
744
+ metadata = checkpoint["metadata"]
745
+ unless metadata.is_a?(Hash)
746
+ raise ArgumentError, "run bundle checkpoint metadata is missing or invalid"
747
+ end
748
+
749
+ { "metadata" => __dsl_deep_copy(metadata) }
750
+ end
751
+
752
+ private
753
+
754
+ def __dsl_clone_trainer
755
+ cloned = self.class.new(**@__dsl_init_options, &@loss_block)
756
+ hooks = Hash.new { |h, k| h[k] = [] }
757
+ @hooks.each do |event, entries|
758
+ hooks[event] = entries.map(&:dup)
759
+ end
760
+ cloned.instance_variable_set(:@hooks, hooks)
761
+ cloned.instance_variable_set(:@hook_order, @hook_order)
762
+ cloned.instance_variable_set(:@collate_registry, __dsl_clone_config_value(@collate_registry))
763
+ cloned.instance_variable_set(:@fit_defaults, __dsl_clone_config_value(@fit_defaults))
764
+ cloned.instance_variable_set(:@fit_presets, __dsl_clone_config_value(@fit_presets))
765
+ cloned.instance_variable_set(:@batch_schemas, __dsl_clone_config_value(@batch_schemas))
766
+ cloned.instance_variable_set(:@dataflow_profiles, __dsl_clone_config_value(@dataflow_profiles))
767
+ cloned.instance_variable_set(:@hook_packs, __dsl_clone_config_value(@hook_packs))
768
+ cloned.instance_variable_set(:@metric_registry, __dsl_clone_config_value(@metric_registry))
769
+ cloned.instance_variable_set(:@task_presets, __dsl_clone_config_value(@task_presets))
770
+ cloned.instance_variable_set(:@artifact_policy_config, __dsl_clone_config_value(@artifact_policy_config))
771
+ cloned.instance_variable_set(:@checkpoint_history, __dsl_clone_config_value(@checkpoint_history))
772
+ cloned.instance_variable_set(:@last_checkpoint_snapshot, __dsl_clone_config_value(@last_checkpoint_snapshot))
773
+ cloned
774
+ end
775
+
776
+ def __dsl_clone_config_value(value)
777
+ case value
778
+ when Hash
779
+ value.each_with_object({}) do |(key, item), out|
780
+ out[key] = __dsl_clone_config_value(item)
781
+ end
782
+ when Array
783
+ value.map { |item| __dsl_clone_config_value(item) }
784
+ else
785
+ value
786
+ end
787
+ end
788
+
789
+ def __dsl_expand_split_plan(dataset, raw_options)
790
+ return [dataset, raw_options] unless defined?(MLX::DSL::SplitPlan) && dataset.is_a?(MLX::DSL::SplitPlan)
791
+
792
+ train_dataset, plan_options = dataset.to_fit_inputs
793
+ merged = raw_options.dup
794
+ plan_options.each do |key, value|
795
+ next unless merged.key?(key)
796
+ next unless merged.fetch(key).equal?(UNSET)
797
+
798
+ merged[key] = value
799
+ end
800
+ [train_dataset, merged]
801
+ end
802
+
803
+ def __dsl_normalize_artifact_checkpoint_policy(checkpoint)
804
+ unless checkpoint.is_a?(Hash)
805
+ raise ArgumentError, "artifact checkpoint policy must be a Hash"
806
+ end
807
+
808
+ normalized = checkpoint.each_with_object({}) do |(key, value), out|
809
+ out[key.to_sym] = value
810
+ end
811
+ unknown = normalized.keys - %i[path strategy every]
812
+ unless unknown.empty?
813
+ raise ArgumentError, "artifact checkpoint policy has unsupported key(s): #{unknown.map(&:inspect).join(', ')}"
814
+ end
815
+ if normalized.key?(:strategy)
816
+ strategy = normalized.fetch(:strategy).to_sym
817
+ unless %i[latest best every].include?(strategy)
818
+ raise ArgumentError, "artifact checkpoint strategy must be :latest, :best, or :every"
819
+ end
820
+
821
+ normalized[:strategy] = strategy
822
+ end
823
+ if normalized.key?(:every)
824
+ every = normalized.fetch(:every).to_i
825
+ raise ArgumentError, "artifact checkpoint every must be positive" if every <= 0
826
+
827
+ normalized[:every] = every
828
+ end
829
+ normalized
830
+ end
831
+
832
+ def __dsl_normalize_artifact_retention_policy(retention)
833
+ unless retention.is_a?(Hash)
834
+ raise ArgumentError, "artifact retention policy must be a Hash"
835
+ end
836
+
837
+ normalized = retention.each_with_object({}) do |(key, value), out|
838
+ out[key.to_sym] = value
839
+ end
840
+ unknown = normalized.keys - %i[keep_last_n]
841
+ unless unknown.empty?
842
+ raise ArgumentError, "artifact retention policy has unsupported key(s): #{unknown.map(&:inspect).join(', ')}"
843
+ end
844
+ if normalized.key?(:keep_last_n)
845
+ keep = normalized.fetch(:keep_last_n).to_i
846
+ raise ArgumentError, "artifact retention keep_last_n must be non-negative" if keep.negative?
847
+
848
+ normalized[:keep_last_n] = keep
849
+ end
850
+ normalized
851
+ end
852
+
853
+ def __dsl_normalize_artifact_run_bundle_policy(run_bundle)
854
+ unless run_bundle.is_a?(Hash)
855
+ raise ArgumentError, "artifact run_bundle policy must be a Hash"
856
+ end
857
+
858
+ normalized = run_bundle.each_with_object({}) do |(key, value), out|
859
+ out[key.to_sym] = value
860
+ end
861
+ unknown = normalized.keys - %i[enabled path config]
862
+ unless unknown.empty?
863
+ raise ArgumentError, "artifact run_bundle policy has unsupported key(s): #{unknown.map(&:inspect).join(', ')}"
864
+ end
865
+ normalized[:enabled] = !!normalized.fetch(:enabled, false)
866
+ normalized[:config] = {} if normalized[:config].nil?
867
+ normalized
868
+ end
869
+
870
+ def __dsl_resolve_artifact_policy(checkpoint_path:, save_best:, resume_from:, monitor_mode:)
871
+ checkpoint_policy = @artifact_policy_config.fetch(:checkpoint, {})
872
+ retention_policy = @artifact_policy_config.fetch(:retention, {})
873
+ run_bundle_policy = @artifact_policy_config.fetch(:run_bundle, {})
874
+ resume_policy = @artifact_policy_config.fetch(:resume, nil)
875
+
876
+ resolved_checkpoint_path = checkpoint_path
877
+ if (resolved_checkpoint_path.nil? || resolved_checkpoint_path.to_s.empty?) && checkpoint_policy.key?(:path)
878
+ resolved_checkpoint_path = checkpoint_policy.fetch(:path)
879
+ end
880
+
881
+ resolved_save_best = save_best
882
+ strategy = checkpoint_policy[:strategy]
883
+ case strategy
884
+ when :best
885
+ resolved_save_best = true
886
+ when :latest, :every
887
+ resolved_save_best = false
888
+ end
889
+
890
+ resolved_resume = resume_from
891
+ if (resolved_resume.nil? || resolved_resume.to_s.empty?) && !resume_policy.nil?
892
+ resolved_resume = __dsl_policy_resume_source(resume_policy, monitor_mode: monitor_mode)
893
+ end
894
+
895
+ {
896
+ checkpoint_path: resolved_checkpoint_path,
897
+ save_best: resolved_save_best,
898
+ checkpoint_every: checkpoint_policy[:every],
899
+ keep_last_n: retention_policy[:keep_last_n],
900
+ resume_from: resolved_resume,
901
+ run_bundle: __dsl_clone_config_value(run_bundle_policy),
902
+ payload: __dsl_stringify_keys(__dsl_clone_config_value(@artifact_policy_config))
903
+ }
904
+ end
905
+
906
+ def __dsl_policy_resume_source(policy_resume, monitor_mode:)
907
+ case policy_resume
908
+ when :latest
909
+ entry = @checkpoint_history.last
910
+ return nil if entry.nil?
911
+
912
+ { "metadata" => __dsl_deep_copy(entry.fetch("metadata")) }
913
+ when :best
914
+ entry = __dsl_best_checkpoint_from_history(monitor_mode)
915
+ return nil if entry.nil?
916
+
917
+ { "metadata" => __dsl_deep_copy(entry.fetch("metadata")) }
918
+ else
919
+ policy_resume
920
+ end
921
+ end
922
+
923
+ def __dsl_best_checkpoint_from_history(monitor_mode)
924
+ rows = @checkpoint_history.select do |row|
925
+ row.is_a?(Hash) && row.key?("metadata") && row.fetch("metadata").is_a?(Hash)
926
+ end
927
+ return nil if rows.empty?
928
+
929
+ comparator = monitor_mode.to_sym
930
+ case comparator
931
+ when :max
932
+ rows.max_by { |row| row.fetch("monitor_value").to_f }
933
+ else
934
+ rows.min_by { |row| row.fetch("monitor_value").to_f }
935
+ end
936
+ end
937
+
938
+ def __dsl_auto_save_run_bundle(run_bundle_policy, report_payload)
939
+ return nil unless run_bundle_policy.is_a?(Hash)
940
+ return nil unless run_bundle_policy.fetch(:enabled, false)
941
+
942
+ path = run_bundle_policy[:path]
943
+ return nil if path.nil? || path.to_s.empty?
944
+
945
+ save_run_bundle(
946
+ path,
947
+ report: report_payload,
948
+ config: run_bundle_policy.fetch(:config, {})
949
+ )
950
+ end
951
+
952
+ def __dsl_stringify_keys(value)
953
+ case value
954
+ when Hash
955
+ value.each_with_object({}) do |(key, item), out|
956
+ out[key.to_s] = __dsl_stringify_keys(item)
957
+ end
958
+ when Array
959
+ value.map { |item| __dsl_stringify_keys(item) }
960
+ else
961
+ value
962
+ end
963
+ end
964
+
965
+ def __dsl_normalize_fit_option_keys(options)
966
+ normalized = (options || {}).each_with_object({}) do |(key, value), out|
967
+ out[key.to_sym] = value
968
+ end
969
+ unknown = normalized.keys - FIT_OPTION_DEFAULTS.keys
970
+ unless unknown.empty?
971
+ raise ArgumentError, "unsupported fit option(s): #{unknown.map(&:inspect).join(', ')}"
972
+ end
973
+
974
+ normalized
975
+ end
976
+
977
      # Builds the effective fit options: explicit call-site values win, then
      # trainer-level @fit_defaults, then the class-level FIT_OPTION_DEFAULTS
      # (cloned so callers cannot mutate shared defaults). The UNSET sentinel
      # marks "not provided at the call site".
      #
      # NOTE(review): normalized_raw.fetch(key) assumes callers pre-fill every
      # option key (typically with UNSET) — verify against the fit entry point.
      def __dsl_resolve_fit_options(raw_options)
        normalized_raw = __dsl_normalize_fit_option_keys(raw_options)
        normalized_defaults = __dsl_normalize_fit_option_keys(@fit_defaults)
        FIT_OPTION_DEFAULTS.each_with_object({}) do |(key, fallback), out|
          if normalized_raw.fetch(key).equal?(UNSET)
            out[key] = if normalized_defaults.key?(key)
                         normalized_defaults.fetch(key)
                       else
                         __dsl_clone_config_value(fallback)
                       end
          else
            out[key] = normalized_raw.fetch(key)
          end
        end
      end
992
+
993
      # Resolves a named fit preset, layering trainer defaults, then the
      # preset, then call-site overrides (later layers win).
      def __dsl_merge_fit_preset(name, overrides)
        key = name.to_sym
        unless @fit_presets.key?(key)
          raise ArgumentError, "unknown fit preset: #{name.inspect}"
        end

        __dsl_normalize_fit_option_keys(@fit_defaults)
          .merge(__dsl_normalize_fit_option_keys(@fit_presets.fetch(key)))
          .merge(__dsl_normalize_fit_option_keys(overrides))
      end
1003
+
1004
      # Built-in task presets keyed by task name; each entry is a partial set
      # of fit options. :language_modeling monitors perplexity, computed as
      # exp(epoch_loss).
      def __dsl_builtin_task_presets
        {
          classification: {
            collate: :xy,
            monitor: :epoch_loss,
            monitor_mode: :min
          },
          regression: {
            collate: :xy,
            monitor: :epoch_loss,
            monitor_mode: :min
          },
          language_modeling: {
            collate: :xy,
            monitor: :perplexity,
            monitor_mode: :min,
            metric: ->(context) { Math.exp(context.fetch(:epoch_loss).to_f) }
          }
        }
      end
1024
+
1025
      # Fit options for a registered task preset, with call-site overrides
      # applied on top.
      def __dsl_task_fit_options(task, overrides)
        key = task.to_sym
        unless @task_presets.key?(key)
          raise ArgumentError, "unknown task preset: #{task.inspect}"
        end

        __dsl_normalize_fit_option_keys(@task_presets.fetch(key))
          .merge(__dsl_normalize_fit_option_keys(overrides))
      end
1034
+
1035
      # Looks up a registered dataflow profile by name, returning a clone so
      # callers cannot mutate the registry entry.
      def __dsl_resolve_dataflow_profile(name)
        key = name.to_sym
        unless @dataflow_profiles.key?(key)
          raise ArgumentError, "unknown dataflow profile: #{name.inspect}"
        end

        __dsl_clone_config_value(@dataflow_profiles.fetch(key))
      end
1043
+
1044
      # Normalizes the train and validation halves of a dataflow profile into
      # validated, symbol-keyed split specs.
      def __dsl_normalize_dataflow_profile(train:, validation:)
        {
          train: __dsl_normalize_dataflow_split(train, split: :train),
          validation: __dsl_normalize_dataflow_split(validation, split: :validation)
        }
      end
1050
+
1051
+ def __dsl_normalize_dataflow_split(spec, split:)
1052
+ return {} if spec.nil?
1053
+ unless spec.is_a?(Hash)
1054
+ raise ArgumentError, "#{split} dataflow spec must be a Hash or nil"
1055
+ end
1056
+
1057
+ normalized = spec.each_with_object({}) do |(key, value), out|
1058
+ out[key.to_sym] = value
1059
+ end
1060
+ allowed = %i[collate transform limit reduce]
1061
+ unknown = normalized.keys - allowed
1062
+ unless unknown.empty?
1063
+ raise ArgumentError, "#{split} dataflow spec has unsupported key(s): #{unknown.map(&:inspect).join(', ')}"
1064
+ end
1065
+
1066
+ normalized
1067
+ end
1068
+
1069
+ def __dsl_compose_dataflow_profile(base, overlay)
1070
+ {
1071
+ train: base.fetch(:train, {}).merge(overlay.fetch(:train, {})),
1072
+ validation: base.fetch(:validation, {}).merge(overlay.fetch(:validation, {}))
1073
+ }
1074
+ end
1075
+
1076
      # Splits override keys into per-split dataflow overrides (:train /
      # :validation sub-hashes, validated and symbolized) and direct fit
      # option overrides; returns the pair [profile_overrides, direct_overrides].
      def __dsl_normalize_dataflow_overrides(overrides)
        profile_overrides = { train: {}, validation: {} }
        direct_overrides = {}
        (overrides || {}).each do |key, value|
          key = key.to_sym
          case key
          when :train, :validation
            profile_overrides[key] = __dsl_normalize_dataflow_split(value, split: key)
          else
            direct_overrides[key] = value
          end
        end
        [profile_overrides, direct_overrides]
      end
1090
+
1091
      # Translates a dataflow profile into fit keyword arguments, emitting
      # only the keys each split actually sets (UNSET placeholders are
      # dropped so fit-level defaults still apply).
      def __dsl_fit_kwargs_from_dataflow(profile)
        {
          collate: profile.fetch(:train).fetch(:collate, UNSET),
          train_transform: profile.fetch(:train).fetch(:transform, UNSET),
          limit: profile.fetch(:train).fetch(:limit, UNSET),
          reduce: profile.fetch(:train).fetch(:reduce, UNSET),
          validation_collate: profile.fetch(:validation).fetch(:collate, UNSET),
          validation_transform: profile.fetch(:validation).fetch(:transform, UNSET),
          validation_limit: profile.fetch(:validation).fetch(:limit, UNSET),
          validation_reduce: profile.fetch(:validation).fetch(:reduce, UNSET)
        }.each_with_object({}) do |(key, value), out|
          out[key] = value unless value.equal?(UNSET)
        end
      end
1105
+
1106
      # Invokes a hook pack, adapting to its arity. Packs without
      # introspectable or declared parameters are instance_exec'd against the
      # trainer when they are Procs (so the DSL reads naturally) or called
      # bare; otherwise positional/keyword arguments are built from the
      # option values plus the trainer and options themselves.
      def __dsl_call_hook_pack(pack, options)
        values = { trainer: self, options: options || {} }
        (options || {}).each do |key, value|
          values[key.to_sym] = value
        end
        if !pack.respond_to?(:parameters) || pack.parameters.empty?
          return instance_exec(&pack) if pack.is_a?(Proc)

          return pack.call
        end

        params = pack.parameters
        args = __dsl_build_positional_args(
          params,
          values,
          [[:trainer, self], [:options, options || {}]],
          "hook pack"
        )
        kwargs = __dsl_build_keyword_args(params, values, "hook pack")
        return pack.call(*args) if kwargs.empty?

        pack.call(*args, **kwargs)
      end
1129
+
1130
+ def __dsl_normalize_sync(sync)
1131
+ mode = sync.nil? ? :none : sync.to_sym
1132
+ return mode if %i[none step epoch].include?(mode)
1133
+
1134
+ raise ArgumentError, "trainer sync must be one of :none, :step, or :epoch"
1135
+ end
1136
+
1137
      # Builds the per-batch training step by introspecting the model's
      # #train_step signature: the optional :compile and :sync keywords are
      # forwarded only when the model accepts them (directly or via **kwargs),
      # keeping compatibility with simpler model implementations.
      def __dsl_build_train_step(model, optimizer:, clip_grad_norm:, compile:, &loss_block)
        params = model.method(:train_step).parameters
        accepts_keyrest = params.any? { |type, _name| type == :keyrest }
        accepts_keyword = lambda do |key|
          accepts_keyrest || params.any? do |type, name|
            (type == :key || type == :keyreq) && name == key
          end
        end

        kwargs = {
          optimizer: optimizer,
          clip_grad_norm: clip_grad_norm
        }
        kwargs[:compile] = compile if accepts_keyword.call(:compile)
        # only :step sync is handled inside the step; :epoch sync happens in
        # __dsl_sync_epoch instead
        kwargs[:sync] = (@sync_mode == :step ? :step : :none) if accepts_keyword.call(:sync)

        model.train_step(**kwargs, &loss_block)
      end
1155
+
1156
      # Reconstructs training-resume state (start epoch, best metric, stale
      # epoch count) from a resume_from source, which may be nil/blank (fresh
      # state), a callable loader, a payload or run-bundle Hash, or a path to
      # either a run-bundle JSON file or a model checkpoint. Raises when the
      # payload is malformed or was recorded for a different monitor.
      def __dsl_resume_state(resume_from, monitor_name)
        return __dsl_empty_resume_state(monitor_name) if resume_from.nil? || resume_from.to_s.empty?

        source = resume_from
        source = __dsl_call_resume_loader(source, monitor_name) if source.respond_to?(:call)
        return __dsl_empty_resume_state(monitor_name) if source.nil?

        if source.is_a?(Hash)
          payload = if __dsl_run_bundle_hash?(source)
            resume_payload_from_bundle(source)
          else
            source
          end
          resume_path = nil
        else
          source_path = source.to_s
          # try run-bundle JSON first; fall back to a model checkpoint load
          bundle_payload = __dsl_resume_payload_from_run_bundle_path(source_path)
          if bundle_payload.nil?
            unless @model.respond_to?(:load_checkpoint)
              raise ArgumentError, "resume_from requires model to implement #load_checkpoint"
            end

            payload = __dsl_load_checkpoint_for_resume(source)
            resume_path = source_path
          else
            payload = bundle_payload
            resume_path = source_path
          end
        end
        unless payload.is_a?(Hash)
          raise ArgumentError, "resume checkpoint payload must be a Hash"
        end

        metadata = payload["metadata"]
        metadata = {} unless metadata.is_a?(Hash)

        # values may live in metadata (preferred) or at the payload top level
        checkpoint_epoch_value = __dsl_resume_state_value(payload, metadata, "epoch")
        checkpoint_epoch = checkpoint_epoch_value.nil? ? nil : checkpoint_epoch_value.to_i
        start_epoch = checkpoint_epoch.nil? ? 0 : checkpoint_epoch + 1
        best_metric = __dsl_resume_state_value(payload, metadata, "best_metric")
        stale_epochs_value = __dsl_resume_state_value(payload, metadata, "stale_epochs")
        stale_epochs = stale_epochs_value.nil? ? 0 : stale_epochs_value.to_i
        if stale_epochs.negative?
          raise ArgumentError, "resume checkpoint stale_epochs must be non-negative"
        end

        # refuse to resume under a different monitor — best_metric/stale
        # counts would be meaningless
        resume_monitor_name = __dsl_resume_state_value(payload, metadata, "monitor_name")
        if !resume_monitor_name.nil? && resume_monitor_name.to_s != monitor_name.to_s
          raise ArgumentError,
                "resume checkpoint monitor_name #{resume_monitor_name.inspect} does not match requested monitor #{monitor_name.inspect}"
        end

        {
          path: resume_path,
          checkpoint_epoch: checkpoint_epoch,
          start_epoch: start_epoch,
          best_metric: best_metric,
          stale_epochs: stale_epochs,
          monitor_name: monitor_name.to_s
        }
      end
1217
+
1218
+ def __dsl_run_bundle_hash?(value)
1219
+ value.is_a?(Hash) && value.key?("checkpoint") && value.key?("report")
1220
+ end
1221
+
1222
      # Tries to read a resume payload from a run-bundle JSON file. Returns
      # nil (so callers fall back to a model checkpoint load) when the path
      # is blank, the file is missing, the contents are not valid JSON, or
      # the JSON is not shaped like a run bundle.
      def __dsl_resume_payload_from_run_bundle_path(path)
        return nil if path.nil? || path.empty?
        return nil unless File.file?(path)

        bundle = JSON.parse(File.binread(path))
        return nil unless __dsl_run_bundle_hash?(bundle)

        resume_payload_from_bundle(bundle)
      rescue JSON::ParserError
        # non-JSON files are treated as "not a bundle", not as an error
        nil
      end
1233
+
1234
+ def __dsl_empty_resume_state(monitor_name)
1235
+ {
1236
+ path: nil,
1237
+ checkpoint_epoch: nil,
1238
+ start_epoch: 0,
1239
+ best_metric: nil,
1240
+ stale_epochs: 0,
1241
+ monitor_name: monitor_name.to_s
1242
+ }
1243
+ end
1244
+
1245
      # Invokes a callable resume_from source, supplying trainer, model,
      # optimizer, and monitor_name positionally or as keywords to match the
      # loader's signature; zero-parameter loaders are called bare.
      def __dsl_call_resume_loader(loader, monitor_name)
        values = {
          trainer: self,
          model: @model,
          optimizer: @optimizer,
          monitor_name: monitor_name.to_s
        }
        return loader.call unless loader.respond_to?(:parameters)

        params = loader.parameters
        return loader.call if params.empty?

        args = __dsl_build_positional_args(
          params,
          values,
          [[:trainer, self], [:model, @model], [:optimizer, @optimizer], [:monitor_name, monitor_name.to_s]],
          "resume loader"
        )
        kwargs = __dsl_build_keyword_args(params, values, "resume loader")
        return loader.call(*args) if kwargs.empty?

        loader.call(*args, **kwargs)
      end
1268
+
1269
      # Loads a checkpoint for resume, forwarding the :optimizer, :strict,
      # and :format keywords only when the model's #load_checkpoint accepts
      # them (directly or via **kwargs).
      def __dsl_load_checkpoint_for_resume(path)
        params = @model.method(:load_checkpoint).parameters
        accepts_keyrest = params.any? { |type, _name| type == :keyrest }
        accepts_keyword = lambda do |key|
          accepts_keyrest || params.any? do |type, name|
            (type == :key || type == :keyreq) && name == key
          end
        end

        kwargs = {}
        kwargs[:optimizer] = @optimizer if accepts_keyword.call(:optimizer)
        kwargs[:strict] = true if accepts_keyword.call(:strict)
        kwargs[:format] = nil if accepts_keyword.call(:format)
        return @model.load_checkpoint(path) if kwargs.empty?

        @model.load_checkpoint(path, **kwargs)
      end
1286
+
1287
+ def __dsl_resume_state_value(payload, metadata, key)
1288
+ return metadata[key] if metadata.key?(key)
1289
+
1290
+ payload[key]
1291
+ end
1292
+
1293
      # Dispatches +context+ to every hook registered for +event+, ordered by
      # (priority, registration order). Per-entry options are honored:
      # :every fires on every Nth emission, :once fires at most one time, and
      # :if gates on a condition evaluated against the context.
      def emit(event, context)
        @hooks[event.to_sym].sort_by { |entry| [entry.fetch(:priority), entry.fetch(:order)] }.each do |entry|
          # NOTE(review): the counter advances even when the entry is later
          # skipped by :once or :if, so :every counts emissions, not actual
          # hook executions — confirm this is intended.
          entry[:invocations] += 1
          if !entry[:every].nil? && ((entry[:invocations] - 1) % entry[:every]).nonzero?
            next
          end
          if entry[:once] && entry[:fired]
            next
          end
          unless __dsl_hook_condition_met?(entry[:if], context)
            next
          end

          entry.fetch(:hook).call(context)
          entry[:fired] = true if entry[:once]
        end
      end
1310
+
1311
+ def __dsl_hook_condition_met?(condition, context)
1312
+ return true if condition.nil?
1313
+ return !!condition unless condition.respond_to?(:call)
1314
+ return !!condition.call unless condition.respond_to?(:parameters)
1315
+
1316
+ params = condition.parameters
1317
+ return !!condition.call if params.empty?
1318
+
1319
+ if params.any? { |type, _name| type == :keyrest || type == :key || type == :keyreq }
1320
+ return !!condition.call(context: context)
1321
+ end
1322
+
1323
+ !!condition.call(context)
1324
+ end
1325
+
1326
      # Forces evaluation of lazily-computed MLX arrays at an epoch boundary:
      # the last loss, the model parameters, and the optimizer state. No-op
      # when the MLX core is unavailable or there is nothing to evaluate.
      def __dsl_sync_epoch(loss)
        return unless defined?(MLX::Core) && MLX::Core.respond_to?(:eval)

        targets = []
        targets << loss unless loss.nil?
        targets << @model.parameters if @model.respond_to?(:parameters)
        targets << @optimizer.state if @optimizer.respond_to?(:state)
        return if targets.empty?

        MLX::Core.eval(*targets)
      end
1337
+
1338
      # Applies a collate spec to a raw batch. The spec may be nil (batch
      # passthrough), a name registered in the collate registry, :auto
      # (inferred from a declared schema or the batch shape), a callable, a
      # built-in name (:xy/:x), or a Hash mapping output keys to selectors.
      def __dsl_apply_collate(collate, batch, kind:, epoch:, batch_index:)
        collate = __dsl_resolve_registered_collate(collate, kind: kind)
        collate = __dsl_auto_collate_spec(kind, batch) if collate.to_s == "auto"
        return batch if collate.nil?
        if collate.respond_to?(:call)
          return __dsl_call_collate_callable(
            collate,
            batch,
            kind: kind,
            epoch: epoch,
            batch_index: batch_index
          )
        end

        case collate
        when String, Symbol
          __dsl_apply_named_collate(collate.to_sym, batch, kind: kind)
        when Hash
          __dsl_apply_mapping_collate(
            collate,
            batch,
            kind: kind,
            epoch: epoch,
            batch_index: batch_index
          )
        else
          raise ArgumentError, "#{kind} collate must be a Proc, Symbol/String, Hash, or nil"
        end
      end
1367
+
1368
+ def __dsl_effective_collate(collate, bind, batch, kind:)
1369
+ return collate unless collate.nil?
1370
+ return nil if bind.nil?
1371
+
1372
+ __dsl_bind_to_collate(bind, batch, kind: kind)
1373
+ end
1374
+
1375
      # Converts a bind specification into a collate spec: true or :auto
      # infers a mapping from the loss/step signature and the batch shape;
      # other names and Hash mappings pass through; Procs are invoked to
      # produce the spec.
      def __dsl_bind_to_collate(bind, batch, kind:)
        case bind
        when true
          __dsl_infer_bind_mapping(kind, batch)
        when String, Symbol
          return __dsl_infer_bind_mapping(kind, batch) if bind.to_s == "auto"

          bind
        when Hash
          bind
        when Proc
          # epoch/batch_index are fixed at 0: the bind mapping is computed
          # once from the first batch, not per batch
          __dsl_call_collate_callable(
            bind,
            batch,
            kind: kind,
            epoch: 0,
            batch_index: 0,
            label: "#{kind} bind"
          )
        else
          raise ArgumentError, "#{kind} bind must be :auto/true, collate spec, or hash mapping"
        end
      end
1398
+
1399
      # Infers a collate mapping for auto-binding. When the bind target (the
      # train step for :train, otherwise the loss block) declares no keyword
      # parameters, falls back to :xy/:x based on batch shape; otherwise each
      # keyword gets a selector guessed from the batch.
      def __dsl_infer_bind_mapping(kind, batch)
        target_params = __dsl_keyword_names_for_callable(kind == :train ? @step.method(:call) : @loss_block)
        return :xy if target_params.empty? && batch.is_a?(Array) && batch.length >= 2
        return :x if target_params.empty? && !batch.is_a?(Hash)

        target_params.each_with_index.each_with_object({}) do |(name, index), out|
          out[name] = __dsl_infer_bind_selector(batch, name, index)
        end
      end
1408
+
1409
+ def __dsl_keyword_names_for_callable(callable)
1410
+ return [] unless callable.respond_to?(:parameters)
1411
+
1412
+ callable.parameters.each_with_object([]) do |(type, name), out|
1413
+ out << name if (type == :key || type == :keyreq) && !name.nil?
1414
+ end
1415
+ end
1416
+
1417
      # Picks a selector for one bound keyword: for hash batches, the first
      # matching candidate key (the name itself or a known alias, in any of
      # raw/String/Symbol form); for indexable batches, the positional index;
      # otherwise the name itself as a last resort.
      def __dsl_infer_bind_selector(batch, name, index)
        if batch.is_a?(Hash)
          candidates = __dsl_bind_candidates_for(name)
          found = candidates.find do |candidate|
            batch.key?(candidate) || batch.key?(candidate.to_s) || batch.key?(candidate.to_sym)
          end
          return found unless found.nil?

          return name
        end

        return index if batch.respond_to?(:[]) && !batch.is_a?(Hash)

        name
      end
1432
+
1433
+ def __dsl_bind_candidates_for(name)
1434
+ base = [name, name.to_s, name.to_sym]
1435
+ aliases = case name.to_sym
1436
+ when :x
1437
+ %i[input inputs feature features token tokens]
1438
+ when :y
1439
+ %i[target targets label labels output outputs]
1440
+ else
1441
+ []
1442
+ end
1443
+ base + aliases + aliases.map(&:to_s)
1444
+ end
1445
+
1446
      # Resolves :auto collate: a schema declared for this split wins;
      # otherwise the spec is inferred from the batch shape (x/y hash keys,
      # 2+ element arrays map to :xy, scalars wrap as :x).
      def __dsl_auto_collate_spec(kind, batch)
        schema = @batch_schemas.fetch(kind, nil)
        return schema unless schema.nil?

        if batch.is_a?(Hash)
          has_x = batch.key?(:x) || batch.key?("x")
          has_y = batch.key?(:y) || batch.key?("y")
          return :xy if has_x && has_y
          return { x: :x } if has_x

          # hash without x/y keys: leave the batch as-is
          return nil
        end

        return :xy if batch.is_a?(Array) && batch.length >= 2

        :x
      end
1463
+
1464
+ def __dsl_normalize_batch_schema_spec(spec)
1465
+ return nil if spec.nil?
1466
+ return spec if spec.respond_to?(:call)
1467
+
1468
+ case spec
1469
+ when String, Symbol, Hash
1470
+ spec
1471
+ else
1472
+ raise ArgumentError, "batch schema must be a collate spec (Symbol/String, Hash, Proc, or nil)"
1473
+ end
1474
+ end
1475
+
1476
      # Follows name → registry indirections until a concrete spec (or an
      # unregistered name) is reached, raising on registry cycles.
      def __dsl_resolve_registered_collate(collate, kind:)
        current = collate
        seen = []
        while current.is_a?(String) || current.is_a?(Symbol)
          key = current.to_sym
          # unregistered names (e.g. :xy, :auto) pass through for later
          # handling by the caller
          break unless @collate_registry.key?(key)
          if seen.include?(key)
            cycle = (seen + [key]).map(&:inspect).join(" -> ")
            raise ArgumentError, "cyclic #{kind} collate registry reference: #{cycle}"
          end

          seen << key
          current = @collate_registry.fetch(key)
        end
        current
      end
1492
+
1493
+ def __dsl_compose_collate(base_spec, overlay_spec)
1494
+ if base_spec.is_a?(Hash) && overlay_spec.is_a?(Hash)
1495
+ return base_spec.merge(overlay_spec)
1496
+ end
1497
+
1498
+ raise ArgumentError, "collate extends composition supports Hash schemas only"
1499
+ end
1500
+
1501
+ def __dsl_apply_named_collate(name, batch, kind:)
1502
+ case name
1503
+ when :xy
1504
+ return batch if batch.is_a?(Hash) && (batch.key?(:x) || batch.key?("x")) && (batch.key?(:y) || batch.key?("y"))
1505
+ unless batch.is_a?(Array) && batch.length >= 2
1506
+ raise ArgumentError, "#{kind} collate :xy expects a 2-item array batch"
1507
+ end
1508
+
1509
+ { x: batch[0], y: batch[1] }
1510
+ when :x
1511
+ { x: batch }
1512
+ else
1513
+ raise ArgumentError, "unknown #{kind} collate schema: #{name.inspect}"
1514
+ end
1515
+ end
1516
+
1517
      # Builds the collated hash by evaluating each output key's selector
      # against the batch.
      def __dsl_apply_mapping_collate(mapping, batch, kind:, epoch:, batch_index:)
        mapping.each_with_object({}) do |(out_key, selector), out|
          out[out_key] = __dsl_collate_select(
            batch,
            selector,
            kind: kind,
            epoch: epoch,
            batch_index: batch_index
          )
        end
      end
1528
+
1529
      # Evaluates one collate selector against the batch: an Integer index
      # (non-hash indexables only), a String/Symbol key, a Proc, or an Array
      # path of nested index/key segments.
      def __dsl_collate_select(batch, selector, kind:, epoch:, batch_index:)
        case selector
        when Integer
          unless batch.respond_to?(:[]) && !batch.is_a?(Hash)
            raise ArgumentError, "#{kind} collate integer selector requires indexable non-hash batch"
          end
          batch[selector]
        when String, Symbol
          __dsl_collate_fetch_key(batch, selector, kind: kind)
        when Proc
          __dsl_call_collate_callable(
            selector,
            batch,
            kind: kind,
            epoch: epoch,
            batch_index: batch_index,
            label: "#{kind} collate selector"
          )
        when Array
          __dsl_collate_select_path(batch, selector, kind: kind)
        else
          raise ArgumentError, "#{kind} collate selector must be Integer, String/Symbol, Proc, or Array path"
        end
      end
1553
+
1554
      # Calls a collate/bind callable. Callables without introspectable or
      # declared parameters receive just the batch; otherwise positional and
      # keyword arguments are built from batch/epoch/batch_index/kind/trainer
      # to match the callable's signature.
      def __dsl_call_collate_callable(callable, batch, kind:, epoch:, batch_index:, label: nil)
        values = {
          batch: batch,
          epoch: epoch,
          batch_index: batch_index,
          kind: kind,
          trainer: self
        }
        return callable.call(batch) unless callable.respond_to?(:parameters)

        params = callable.parameters
        return callable.call(batch) if params.empty?

        callable_label = label || "#{kind} collate"
        args = __dsl_build_positional_args(
          params,
          values,
          [[:batch, batch], [:epoch, epoch], [:batch_index, batch_index], [:kind, kind], [:trainer, self]],
          callable_label
        )
        kwargs = __dsl_build_keyword_args(params, values, callable_label)
        return callable.call(*args) if kwargs.empty?

        callable.call(*args, **kwargs)
      end
1579
+
1580
      # Walks an Array selector path segment-by-segment into a nested batch
      # structure; an empty path is invalid.
      def __dsl_collate_select_path(batch, selector_path, kind:)
        if selector_path.empty?
          raise ArgumentError, "#{kind} collate selector path must not be empty"
        end

        selector_path.each_with_index.reduce(batch) do |current, (selector, depth)|
          __dsl_collate_select_path_segment(current, selector, selector_path, depth, kind: kind)
        end
      end
1589
+
1590
      # Resolves one path segment: Integers index into non-hash indexables,
      # String/Symbol keys look up into hashes (with path context added to
      # error messages); other segment types are invalid.
      def __dsl_collate_select_path_segment(current, selector, selector_path, depth, kind:)
        case selector
        when Integer
          unless current.respond_to?(:[]) && !current.is_a?(Hash)
            raise ArgumentError, "#{kind} collate path #{selector_path.inspect} expected indexable non-hash at depth #{depth}"
          end
          current[selector]
        when String, Symbol
          __dsl_collate_fetch_key(
            current,
            selector,
            kind: kind,
            context: "in path #{selector_path.inspect} at depth #{depth}"
          )
        else
          raise ArgumentError, "#{kind} collate path #{selector_path.inspect} contains unsupported selector #{selector.inspect}"
        end
      end
1608
+
1609
+ def __dsl_collate_fetch_key(batch, selector, kind:, context: nil)
1610
+ unless batch.is_a?(Hash)
1611
+ extra = context.nil? ? "" : " #{context}"
1612
+ raise ArgumentError, "#{kind} collate key selector #{selector.inspect} requires hash batch#{extra}"
1613
+ end
1614
+
1615
+ return batch.fetch(selector) if batch.key?(selector)
1616
+
1617
+ str_key = selector.to_s
1618
+ return batch.fetch(str_key) if batch.key?(str_key)
1619
+
1620
+ sym_key = str_key.to_sym
1621
+ return batch.fetch(sym_key) if batch.key?(sym_key)
1622
+
1623
+ extra = context.nil? ? "" : " #{context}"
1624
+ raise ArgumentError, "#{kind} collate key selector #{selector.inspect} was not found in batch#{extra}"
1625
+ end
1626
+
1627
      # Executes one training step: hash batches become keyword arguments,
      # arrays splat positionally, anything else is passed as a single
      # argument. Failures are re-raised with epoch/batch context.
      def __dsl_run_batch(batch, epoch:, batch_index:, kind:)
        if batch.is_a?(Hash)
          @step.call(**__dsl_normalize_batch_kwargs(batch, label: "#{kind} batch"))
        elsif batch.is_a?(Array)
          @step.call(*batch)
        else
          @step.call(batch)
        end
      rescue StandardError => e
        __dsl_raise_batch_error!(e, kind: kind, epoch: epoch, batch_index: batch_index)
      end
1638
+
1639
+ def __dsl_loss_scalar(loss)
1640
+ return nil if loss.nil?
1641
+ return loss.to_f if loss.is_a?(Numeric)
1642
+
1643
+ return loss.item.to_f if loss.respond_to?(:item)
1644
+ return loss.to_f if loss.respond_to?(:to_f)
1645
+
1646
+ nil
1647
+ end
1648
+
1649
+ def __dsl_default_monitor_value(monitor_name, epoch_metric, val_metric)
1650
+ case monitor_name.to_s
1651
+ when "val_loss", "validation_loss"
1652
+ val_metric.nil? ? epoch_metric : val_metric
1653
+ else
1654
+ epoch_metric
1655
+ end
1656
+ end
1657
+
1658
      # Computes the monitored metric. A registered or callable metric is
      # invoked with context/trainer matched to its signature; when no metric
      # resolves, the fallback value is used.
      def __dsl_monitor_value(metric, context, fallback:)
        metric_callable = __dsl_resolve_metric_callable(metric)
        return fallback if metric_callable.nil?
        return metric_callable.call(context) unless metric_callable.respond_to?(:parameters)

        params = metric_callable.parameters
        return metric_callable.call(context) if params.empty?

        values = {
          context: context,
          trainer: self
        }
        args = __dsl_build_positional_args(
          params,
          values,
          [[:context, context], [:trainer, self]],
          "metric callable"
        )
        kwargs = __dsl_build_keyword_args(params, values, "metric callable")
        return metric_callable.call(*args) if kwargs.empty?

        metric_callable.call(*args, **kwargs)
      end
1681
+
1682
      # Resolves a metric spec to a callable: callables pass through; names
      # are looked up in @metric_registry (nil when unregistered); anything
      # else resolves to nil.
      def __dsl_resolve_metric_callable(metric)
        return nil if metric.nil?
        return metric if metric.respond_to?(:call)
        return nil unless metric.is_a?(String) || metric.is_a?(Symbol)

        @metric_registry[metric.to_sym]
      end
1689
+
1690
      # Evaluates the loss block on a validation batch — no optimizer step —
      # using the same hash/array/scalar dispatch as training batches.
      # Failures are re-raised with epoch/batch context.
      def __dsl_run_validation_batch(batch, epoch:, batch_index:, kind:)
        if batch.is_a?(Hash)
          @loss_block.call(**__dsl_normalize_batch_kwargs(batch, label: "#{kind} batch"))
        elsif batch.is_a?(Array)
          @loss_block.call(*batch)
        else
          @loss_block.call(batch)
        end
      rescue StandardError => e
        __dsl_raise_batch_error!(e, kind: kind, epoch: epoch, batch_index: batch_index)
      end
1701
+
1702
+ def __dsl_raise_batch_error!(error, kind:, epoch:, batch_index:)
1703
+ prefix = "#{kind} batch failed at epoch #{epoch}, batch #{batch_index}"
1704
+ raise error.class, "#{prefix}: #{error.message}"
1705
+ end
1706
+
1707
      # Symbolizes a hash batch's keys for keyword dispatch, rejecting
      # collisions (e.g. "x" and :x both present would silently overwrite).
      def __dsl_normalize_batch_kwargs(batch, label:)
        batch.each_with_object({}) do |(key, value), out|
          kw_key = __dsl_normalize_keyword_key(key, label: label)
          if out.key?(kw_key)
            raise ArgumentError, "#{label} contains duplicate keyword after normalization: #{kw_key.inspect}"
          end

          out[kw_key] = value
        end
      end
1717
+
1718
+ def __dsl_normalize_keyword_key(key, label:)
1719
+ return key if key.is_a?(Symbol)
1720
+ return key.to_sym if key.respond_to?(:to_sym)
1721
+
1722
+ raise ArgumentError, "#{label} key #{key.inspect} cannot be converted to keyword symbol"
1723
+ end
1724
+
1725
+ def __dsl_reduce_values(values, reducer)
1726
+ return nil if values.empty?
1727
+
1728
+ if reducer.respond_to?(:call)
1729
+ return reducer.call(values)
1730
+ end
1731
+
1732
+ case reducer.to_sym
1733
+ when :mean
1734
+ values.sum / values.length.to_f
1735
+ when :sum
1736
+ values.sum
1737
+ when :last
1738
+ values[-1]
1739
+ else
1740
+ raise ArgumentError, "unsupported reducer: #{reducer.inspect}"
1741
+ end
1742
+ end
1743
+
1744
+ def __dsl_improved?(metric, best_metric, monitor_mode, min_delta: 0.0)
1745
+ return false if metric.nil?
1746
+ return true if best_metric.nil?
1747
+
1748
+ case monitor_mode.to_sym
1749
+ when :min
1750
+ metric < (best_metric - min_delta)
1751
+ when :max
1752
+ metric > (best_metric + min_delta)
1753
+ else
1754
+ raise ArgumentError, "unsupported monitor_mode: #{monitor_mode.inspect}"
1755
+ end
1756
+ end
1757
+
1758
+ def __dsl_normalize_patience(patience)
1759
+ return nil if patience.nil?
1760
+
1761
+ value = patience.to_i
1762
+ raise ArgumentError, "patience must be non-negative" if value.negative?
1763
+
1764
+ value
1765
+ end
1766
+
1767
      # Runs an optional per-batch transform. Callables without
      # introspectable or declared parameters receive just the batch;
      # otherwise batch/epoch/batch_index/kind/trainer are matched to the
      # transform's signature. nil means "no transform".
      def __dsl_apply_batch_transform(transform, batch, epoch:, batch_index:, kind:)
        return batch if transform.nil?
        unless transform.respond_to?(:call)
          raise ArgumentError, "#{kind} transform must respond to #call"
        end

        values = {
          batch: batch,
          epoch: epoch,
          batch_index: batch_index,
          kind: kind,
          trainer: self
        }
        return transform.call(batch) unless transform.respond_to?(:parameters)

        params = transform.parameters
        return transform.call(batch) if params.empty?

        args = __dsl_build_positional_args(
          params,
          values,
          [[:batch, batch], [:epoch, epoch], [:batch_index, batch_index], [:kind, kind], [:trainer, self]],
          "batch transform"
        )
        kwargs = __dsl_build_keyword_args(params, values, "batch transform")
        return transform.call(*args) if kwargs.empty?

        transform.call(*args, **kwargs)
      end
1796
+
1797
+ def __dsl_normalize_min_delta(min_delta)
1798
+ value = min_delta.to_f
1799
+ raise ArgumentError, "min_delta must be non-negative" if value.negative?
1800
+
1801
+ value
1802
+ end
1803
+
1804
      # Saves a checkpoint when configured and due, records a snapshot in
      # @checkpoint_history (optionally pruned to keep_last_n entries), and
      # fires the :checkpoint hook. Returns true when a checkpoint was
      # written, false when skipped (no path, off-cadence, not improved under
      # save_best, or the model cannot save).
      def __dsl_maybe_checkpoint(
        path,
        save_best:,
        improved:,
        epoch:,
        monitor_name:,
        monitor_value:,
        epoch_metric:,
        stale_epochs:,
        best_metric:,
        metadata:,
        checkpoint_every: nil,
        keep_last_n: nil
      )
        return false if path.nil? || path.to_s.empty?
        unless checkpoint_every.nil?
          every = checkpoint_every.to_i
          raise ArgumentError, "checkpoint every must be positive" if every <= 0
          # epochs are zero-based; +1 aligns "every N" with human counting
          return false if ((epoch + 1) % every).nonzero?
        end
        return false if save_best && !improved
        return false unless @model.respond_to?(:save_checkpoint)

        resolved_path = __dsl_checkpoint_path(
          path,
          epoch: epoch,
          monitor_name: monitor_name,
          monitor_value: monitor_value,
          epoch_metric: epoch_metric,
          improved: improved
        )

        # enrich caller metadata with the bookkeeping needed to resume later
        merged_metadata = (metadata || {}).dup
        merged_metadata["epoch"] = epoch
        merged_metadata["epoch_loss"] = epoch_metric
        merged_metadata["monitor_name"] = monitor_name
        merged_metadata["monitor_value"] = monitor_value
        merged_metadata["stale_epochs"] = stale_epochs
        merged_metadata["best_metric"] = best_metric
        merged_metadata["next_epoch"] = epoch + 1

        @model.save_checkpoint(resolved_path, optimizer: @optimizer, metadata: merged_metadata)
        @last_checkpoint_snapshot = {
          "path" => resolved_path,
          "epoch" => epoch,
          "monitor_name" => monitor_name,
          "monitor_value" => monitor_value,
          "epoch_loss" => epoch_metric,
          "metadata" => __dsl_deep_copy(merged_metadata)
        }
        @checkpoint_history << __dsl_deep_copy(@last_checkpoint_snapshot)
        unless keep_last_n.nil?
          keep = keep_last_n.to_i
          raise ArgumentError, "artifact retention keep_last_n must be non-negative" if keep.negative?
          # prunes only the in-memory history; checkpoint files on disk are
          # not deleted here
          @checkpoint_history.shift while @checkpoint_history.length > keep
        end
        emit(
          :checkpoint,
          {
            model: @model,
            optimizer: @optimizer,
            path: resolved_path,
            epoch: epoch,
            monitor_name: monitor_name,
            monitor_value: monitor_value,
            epoch_loss: epoch_metric,
            improved: improved
          }
        )
        true
      end
1875
+
1876
      # Resolves the checkpoint destination. Callables are invoked with
      # epoch/monitor details matched to their signature and must return a
      # String-compatible (#to_str) path; string templates may interpolate
      # %{epoch}, %{next_epoch}, %{monitor}, %{monitor_name}, %{epoch_loss},
      # and %{improved}. Unknown template keys raise ArgumentError.
      def __dsl_checkpoint_path(path, epoch:, monitor_name:, monitor_value:, epoch_metric:, improved:)
        if path.respond_to?(:call)
          values = {
            epoch: epoch,
            next_epoch: epoch + 1,
            monitor: monitor_value,
            monitor_name: monitor_name,
            epoch_loss: epoch_metric,
            improved: improved,
            trainer: self,
            model: @model,
            optimizer: @optimizer
          }
          path = if !path.respond_to?(:parameters) || path.parameters.empty?
                   path.call
                 else
                   args = __dsl_build_positional_args(
                     path.parameters,
                     values,
                     [
                       [:epoch, epoch],
                       [:next_epoch, epoch + 1],
                       [:monitor, monitor_value],
                       [:monitor_name, monitor_name],
                       [:epoch_loss, epoch_metric],
                       [:improved, improved],
                       [:trainer, self],
                       [:model, @model],
                       [:optimizer, @optimizer]
                     ],
                     "checkpoint path"
                   )
                   kwargs = __dsl_build_keyword_args(path.parameters, values, "checkpoint path")
                   kwargs.empty? ? path.call(*args) : path.call(*args, **kwargs)
                 end
          unless path.respond_to?(:to_str)
            raise ArgumentError, "checkpoint path callable must return a String-compatible path"
          end

          path = path.to_str
        end

        template = path.to_s
        # fast path: plain strings without %{...} placeholders pass through
        return template unless template.include?("%{")

        template % {
          epoch: epoch,
          next_epoch: epoch + 1,
          monitor: monitor_value,
          monitor_name: monitor_name,
          epoch_loss: epoch_metric,
          improved: improved
        }
      rescue KeyError => e
        # String#% raises KeyError for unknown %{...} placeholders
        raise ArgumentError, "unsupported checkpoint template key: #{e.message}"
      end
1932
+
1933
+ def __dsl_resolve_run_bundle(bundle_or_path)
1934
+ if bundle_or_path.is_a?(Hash)
1935
+ return bundle_or_path
1936
+ end
1937
+
1938
+ path = bundle_or_path.to_s
1939
+ if path.empty?
1940
+ raise ArgumentError, "run bundle source must be a bundle hash or path"
1941
+ end
1942
+
1943
+ JSON.parse(File.binread(path))
1944
+ end
1945
+
1946
+ def __dsl_deep_copy(value)
1947
+ return nil if value.nil?
1948
+
1949
+ Marshal.load(Marshal.dump(value))
1950
+ end
1951
+
1952
+ def __dsl_dataset_size(dataset)
1953
+ return nil if dataset.nil? || dataset.respond_to?(:call)
1954
+ return dataset.size if dataset.respond_to?(:size)
1955
+ return dataset.length if dataset.respond_to?(:length)
1956
+
1957
+ nil
1958
+ end
1959
+
1960
+ def __dsl_dataset_for_epoch(dataset, epoch:, kind:)
1961
+ source = if dataset.respond_to?(:call)
1962
+ __dsl_call_dataset_factory(dataset, epoch: epoch, kind: kind)
1963
+ else
1964
+ dataset
1965
+ end
1966
+ unless source.respond_to?(:each)
1967
+ raise ArgumentError, "#{kind} dataset must respond to #each"
1968
+ end
1969
+
1970
+ if epoch.positive? && !dataset.respond_to?(:call) && source.respond_to?(:rewind)
1971
+ begin
1972
+ source.rewind
1973
+ rescue StandardError => e
1974
+ raise ArgumentError, "#{kind} dataset could not rewind for epoch #{epoch}: #{e.message}"
1975
+ end
1976
+ end
1977
+
1978
+ source
1979
+ end
1980
+
1981
+ def __dsl_call_dataset_factory(factory, epoch:, kind:)
1982
+ return factory.call unless factory.respond_to?(:parameters)
1983
+
1984
+ params = factory.parameters
1985
+ values = {
1986
+ epoch: epoch,
1987
+ trainer: self,
1988
+ kind: kind
1989
+ }
1990
+ return factory.call if params.empty?
1991
+
1992
+ args = __dsl_build_positional_args(
1993
+ params,
1994
+ values,
1995
+ [[:epoch, epoch], [:kind, kind], [:trainer, self]],
1996
+ "dataset factory"
1997
+ )
1998
+ kwargs = __dsl_build_keyword_args(params, values, "dataset factory")
1999
+ return factory.call(*args) if kwargs.empty?
2000
+
2001
+ factory.call(*args, **kwargs)
2002
+ end
2003
+
2004
+ def __dsl_resolve_loop_limit(limit, epoch:, kind:)
2005
+ return nil if limit.nil?
2006
+
2007
+ raw = if limit.respond_to?(:call)
2008
+ values = {
2009
+ epoch: epoch,
2010
+ kind: kind,
2011
+ trainer: self
2012
+ }
2013
+ return limit.call if !limit.respond_to?(:parameters) || limit.parameters.empty?
2014
+
2015
+ args = __dsl_build_positional_args(
2016
+ limit.parameters,
2017
+ values,
2018
+ [[:epoch, epoch], [:kind, kind], [:trainer, self]],
2019
+ "#{kind} limit"
2020
+ )
2021
+ kwargs = __dsl_build_keyword_args(limit.parameters, values, "#{kind} limit")
2022
+ kwargs.empty? ? limit.call(*args) : limit.call(*args, **kwargs)
2023
+ else
2024
+ limit
2025
+ end
2026
+ return nil if raw.nil?
2027
+ unless raw.respond_to?(:to_int)
2028
+ raise ArgumentError, "#{kind} limit must be an Integer, nil, or callable returning one"
2029
+ end
2030
+
2031
+ value = raw.to_int
2032
+ raise ArgumentError, "#{kind} limit must be non-negative" if value.negative?
2033
+
2034
+ value
2035
+ end
2036
+
2037
+ def __dsl_build_positional_args(params, values, fallback_pairs, label)
2038
+ queue = fallback_pairs.dup
2039
+ args = []
2040
+ params.each do |type, name|
2041
+ next unless type == :req || type == :opt
2042
+
2043
+ if !name.nil? && values.key?(name)
2044
+ args << values.fetch(name)
2045
+ queue.reject! { |key, _value| key == name }
2046
+ next
2047
+ end
2048
+
2049
+ if queue.empty?
2050
+ raise ArgumentError, "#{label} has unsupported required positional argument: #{name.inspect}" if type == :req
2051
+ break
2052
+ end
2053
+
2054
+ _key, value = queue.shift
2055
+ args << value
2056
+ end
2057
+ args
2058
+ end
2059
+
2060
+ def __dsl_build_keyword_args(params, values, label)
2061
+ return values.dup if params.any? { |type, _name| type == :keyrest }
2062
+
2063
+ required_keys = params.each_with_object([]) do |(type, name), out|
2064
+ out << name if type == :keyreq
2065
+ end
2066
+ missing = required_keys.reject { |name| values.key?(name) }
2067
+ unless missing.empty?
2068
+ raise ArgumentError, "#{label} requires unsupported keyword argument(s): #{missing.map(&:inspect).join(", ")}"
2069
+ end
2070
+
2071
+ accepted_keys = params.each_with_object([]) do |(type, name), out|
2072
+ out << name if type == :key || type == :keyreq
2073
+ end
2074
+
2075
+ values.each_with_object({}) do |(name, value), out|
2076
+ out[name] = value if accepted_keys.include?(name)
2077
+ end
2078
+ end
2079
+
2080
+ def __dsl_validate_data_reuse!(strict:, dataset:, kind:, epoch:, previous_batches:, current_batches:)
2081
+ return unless strict
2082
+ return if dataset.nil? || dataset.respond_to?(:call)
2083
+ return unless epoch.positive?
2084
+ return unless previous_batches.to_i.positive? && current_batches.to_i.zero?
2085
+ return if dataset.respond_to?(:rewind)
2086
+
2087
+ raise ArgumentError,
2088
+ "#{kind} dataset appears exhausted across epochs; pass a factory like ->(epoch:) { ... }"
2089
+ end
2090
+
2091
+ def __dsl_with_eval_mode
2092
+ if @model.respond_to?(:eval_mode)
2093
+ @model.eval_mode { yield }
2094
+ return
2095
+ end
2096
+
2097
+ unless @model.respond_to?(:eval) && @model.respond_to?(:train) && @model.respond_to?(:training)
2098
+ yield
2099
+ return
2100
+ end
2101
+
2102
+ previous = @model.training
2103
+ @model.eval
2104
+ yield
2105
+ ensure
2106
+ @model.train(previous) unless previous.nil?
2107
+ end
2108
+ end
2109
+ end
2110
+ end