mlx 0.30.7 → 0.30.7.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,706 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "set"
4
+ require "json"
5
+ require "fileutils"
6
+
7
+ module MLX
8
+ module DSL
9
+ module ModelMixin
10
# Sentinel meaning "argument not supplied", so an explicit nil can be told
# apart from an omitted keyword in the DSL declaration methods.
UNSET = Object.new.freeze

# Collects the groups declared inside `optimizer_groups { ... }` and turns them
# into a single optimizer.
#
# With one group, that group's optimizer is returned directly. With several,
# an MLX::Optimizers::MultiOptimizer is built whose filters come from every
# group except the last; the last group's optimizer is passed without a filter,
# so the matcher given to the final group is never consulted.
class OptimizerGroupsBuilder
  def initialize
    @groups = []
  end

  # Declares one optimizer group.
  #
  # @param matcher [nil, Regexp, String, Symbol, Array, #call] selects the
  #   parameter paths for this group; nil matches every path. A callable is
  #   given `(path)` when it takes exactly one argument, otherwise `(path, grad)`.
  # @yieldreturn [MLX::Optimizers::Optimizer] the optimizer for this group
  # @return [MLX::Optimizers::Optimizer] the optimizer returned by the block
  # @raise [ArgumentError] when no block is given or the matcher is unsupported
  # @raise [TypeError] when the block does not return an optimizer
  def group(matcher = nil, &factory)
    raise ArgumentError, "group requires an optimizer block" unless block_given?

    optimizer = factory.call
    unless optimizer.is_a?(MLX::Optimizers::Optimizer)
      raise TypeError, "group block must return an MLX::Optimizers::Optimizer"
    end

    @groups << {
      optimizer: optimizer,
      filter: matcher_lambda(matcher)
    }
    optimizer
  end

  # @return [Object] the single optimizer, or a MultiOptimizer over all groups
  # @raise [ArgumentError] when no group was declared
  def build
    if @groups.empty?
      raise ArgumentError, "optimizer_groups requires at least one group"
    end

    return @groups.first[:optimizer] if @groups.length == 1

    MLX::Optimizers::MultiOptimizer.new(
      @groups.map { |entry| entry[:optimizer] },
      filters: @groups[0...-1].map { |entry| entry[:filter] }
    )
  end

  private

  # Normalizes a group matcher into a `(path, grad) -> Boolean` filter.
  def matcher_lambda(matcher)
    case matcher
    when nil
      lambda { |_path, _grad| true }
    when Regexp
      lambda { |path, _grad| matcher.match?(path.to_s) }
    when String, Symbol
      target = matcher.to_s
      lambda { |path, _grad| path.to_s == target }
    when Array
      # Regexp entries match by pattern; every other entry by its string form.
      patterns = matcher.map { |entry| entry.is_a?(Regexp) ? entry : entry.to_s }
      lambda do |path, _grad|
        key = path.to_s
        patterns.any? { |pattern| pattern.is_a?(Regexp) ? pattern.match?(key) : pattern == key }
      end
    else
      unless matcher.respond_to?(:call)
        raise ArgumentError, "unsupported group matcher: #{matcher.class}"
      end

      # A strict one-argument lambda/Method raises ArgumentError when also
      # handed the gradient, so only pass the path in that case.
      if matcher.respond_to?(:arity) && matcher.arity == 1
        lambda { |path, _grad| matcher.call(path.to_s) }
      else
        lambda { |path, grad| matcher.call(path.to_s, grad) }
      end
    end
  end
end
65
+
66
# Inclusion hook: gives the including class the DSL declaration methods as
# class methods (ClassMethods), and prepends Initializer so its #initialize
# wraps the class's own constructor. Returns `base`.
def self.included(base)
  base.extend(ClassMethods)
      .prepend(Initializer)
end
70
+
71
module Initializer
  # Constructor wrapper prepended to every class that includes ModelMixin.
  #
  # Steps:
  # 1. Remove the keyword arguments that match declared `option`s.
  # 2. Reject leftover keywords that the wrapped #initialize does not accept.
  # 3. Call the wrapped constructor with the remaining arguments.
  # 4. Apply the option values.
  # 5. Build the declared layers, params and buffers.
  #
  # @raise [ArgumentError] listing (sorted) any unknown option names
  def initialize(*args, **kwargs, &block)
    declared = __dsl_extract_declared_options(kwargs)

    # NOTE(review): method(:initialize) resolves from the receiver's class, so
    # if a subclass defines its own #initialize, super_method appears to be this
    # wrapper itself (which takes **kwargs) and the unknown-option check is
    # effectively skipped. Confirm before relying on it for subclasses.
    wrapped_init = method(:initialize).super_method
    leftovers = __dsl_unknown_kwargs_for_super(kwargs, wrapped_init)
    unless leftovers.empty?
      raise ArgumentError, "unknown option(s): #{leftovers.map(&:to_s).sort.join(', ')}"
    end

    super(*args, **kwargs, &block)
    __dsl_apply_option_values(declared)
    __dsl_materialize_declarations!
  end
end
85
+
86
module ClassMethods
  # Declares a constructor keyword option.
  #
  # If `required:` is omitted, the option is required exactly when no
  # `default:` was supplied.
  #
  # @param name [String, Symbol] option name (stored as a String)
  # @param default [Object] fallback value; a callable default is resolved per
  #   instance when the option is applied
  # @param required [Boolean] whether omitting the option raises at construction
  def option(name, default: UNSET, required: UNSET)
    required_flag = required.equal?(UNSET) ? default.equal?(UNSET) : required
    dsl_option_definitions[name.to_s] = { default: default, required: required_flag ? true : false }
  end

  # Declares a sub-module that is built for every new instance.
  #
  # The module comes either from `factory` (a class, a callable, or a plain
  # value) together with its arguments, or from a block. Exactly one of the
  # two must be given.
  def layer(name, factory = nil, *factory_args, **factory_kwargs, &block)
    factory_given = !factory.nil?
    if factory_given == block_given?
      message = factory_given ? "layer cannot accept both a factory and block" : "layer requires either a factory or block"
      raise ArgumentError, message
    end

    dsl_declarations << {
      kind: :layer,
      name: name.to_s,
      factory: factory,
      factory_args: factory_args,
      factory_kwargs: factory_kwargs,
      block: block
    }
  end

  # Synonym for #layer.
  def network(name, factory = nil, *factory_args, **factory_kwargs, &block)
    layer(name, factory, *factory_args, **factory_kwargs, &block)
  end

  # Declares a trainable array of the given shape.
  # `init` is an optional initializer; `dtype` defaults at build time.
  def param(name, shape:, init: nil, dtype: UNSET)
    __dsl_array_declaration(:param, name, shape, init, dtype)
  end

  # Declares a non-trainable array of the given shape. With no `init`, it is
  # filled with zeros at build time.
  def buffer(name, shape:, init: nil, dtype: UNSET)
    __dsl_array_declaration(:buffer, name, shape, init, dtype)
  end

  # Option specs for this class, keyed by option name (String).
  def dsl_option_definitions
    @dsl_option_definitions ||= {}
  end

  # Ordered list of layer/param/buffer declarations for this class.
  def dsl_declarations
    @dsl_declarations ||= []
  end

  # Gives each subclass its own shallow copies of the parent's options and
  # declarations. Subclass additions therefore do not leak into the parent.
  def inherited(subclass)
    super
    subclass.instance_variable_set(:@dsl_option_definitions, dsl_option_definitions.transform_values(&:dup))
    subclass.instance_variable_set(:@dsl_declarations, dsl_declarations.map(&:dup))
  end

  private

  # Records a param/buffer declaration; shared by #param and #buffer.
  def __dsl_array_declaration(kind, name, shape, init, dtype)
    dsl_declarations << {
      kind: kind,
      name: name.to_s,
      shape: shape,
      init: init,
      dtype: dtype
    }
  end
end
153
+
154
# Mixes in TrainStepMethods (defined elsewhere; presumably train-step helpers).
include TrainStepMethods

# Builds an optimizer from per-path groups declared in the block, which is
# evaluated with OptimizerGroupsBuilder as self. For example:
#
#   optimizer_groups do
#     group(/\Aencoder\./) { encoder_optimizer }
#     group { default_optimizer }
#   end
#
# @raise [ArgumentError] when no block is given
def optimizer_groups(&block)
  raise ArgumentError, "optimizer_groups requires a block" unless block

  OptimizerGroupsBuilder.new.tap { |builder| builder.instance_eval(&block) }.build
end
163
+
164
# Creates an MLX::DSL::Trainer bound to this model. The block is forwarded
# as the trainer's loss block.
def trainer(optimizer:, clip_grad_norm: nil, compile: false, sync: :none, &loss_block)
  settings = {
    model: self,
    optimizer: optimizer,
    clip_grad_norm: clip_grad_norm,
    compile: compile,
    sync: sync
  }
  MLX::DSL::Trainer.new(**settings, &loss_block)
end
174
+
175
# Writes a checkpoint of the model parameters, plus optional optimizer state
# and user metadata.
#
# The format comes from `format:`, or otherwise from the file extension; no
# extension means the Marshal-based "mlx_dsl_checkpoint_v1" payload. For
# npz/safetensors, the weights go through #save_weights and the rest is written
# to a JSON sidecar.
#
# @return [Object] `path` as given for Marshal checkpoints, otherwise the
#   weights path (which may have gained an extension)
def save_checkpoint(path, optimizer: nil, metadata: {}, format: nil)
  checkpoint_format = __dsl_checkpoint_format(path, format)
  unless checkpoint_format == :marshal
    return __dsl_save_native_checkpoint(path, checkpoint_format, optimizer: optimizer, metadata: metadata)
  end

  __dsl_ensure_parent_dir!(path)
  snapshot = {
    "format" => "mlx_dsl_checkpoint_v1",
    "model" => __dsl_serialize_tree(parameters),
    "metadata" => metadata || {}
  }
  snapshot["optimizer"] = __dsl_serialize_tree(optimizer.state) unless optimizer.nil?
  File.binwrite(path, Marshal.dump(snapshot))
  path
end
192
+
193
# Restores model weights, and optionally optimizer state, from a checkpoint
# written by #save_checkpoint.
#
# @param path [String, #to_s] checkpoint location. If it has no extension, no
#   format is given and the bare path does not exist, "<path>.npz" and then
#   "<path>.safetensors" are tried.
# @param optimizer [Object, nil] receives the saved state via #state= when the
#   checkpoint contains optimizer state
# @param strict [Boolean] forwarded to #update (Marshal) or #load_weights (native)
# @param format [String, Symbol, nil] explicit checkpoint format
# @return [Hash] the checkpoint payload (format marker, metadata, ...)
# @raise [ArgumentError] for an unsupported format or a malformed payload/sidecar
def load_checkpoint(path, optimizer: nil, strict: true, format: nil)
  resolved_path = __dsl_resolve_load_checkpoint_path(path, format)
  checkpoint_format = __dsl_checkpoint_format(resolved_path, format)

  payload =
    if checkpoint_format == :marshal
      # SECURITY: Marshal.load can instantiate arbitrary Ruby objects; only
      # load Marshal checkpoints that come from a trusted source.
      loaded = Marshal.load(File.binread(resolved_path))
      unless loaded.is_a?(Hash) && loaded["format"] == "mlx_dsl_checkpoint_v1"
        raise ArgumentError, "unsupported checkpoint format"
      end

      update(__dsl_deserialize_tree(loaded.fetch("model")), strict: strict)
      loaded
    else
      weights_path = __dsl_checkpoint_weights_path(resolved_path, checkpoint_format)
      # Parse and validate the sidecar before touching the model, so a
      # malformed sidecar does not leave the weights half-restored.
      sidecar = __dsl_load_native_checkpoint_payload(weights_path)
      load_weights(weights_path, strict: strict)
      sidecar
    end

  if !optimizer.nil? && payload.key?("optimizer")
    optimizer.state = __dsl_deserialize_tree(payload["optimizer"])
  end

  payload
end
222
+
223
# Switches the model to training mode.
#
# Without a block, the switch persists and `self` is returned, so calls can be
# chained. With a block, the block runs in training mode and receives `self`.
# Afterwards the previous mode is restored, even if the block raises, and the
# block's value is returned.
#
# The original implementation restored the previous mode in a method-level
# `ensure`, which also fired on `return self`, so the blockless form was a no-op.
def train_mode
  previous = training
  unless block_given?
    train(true)
    return self
  end

  begin
    train(true)
    yield(self)
  ensure
    train(previous) unless previous.nil?
  end
end
232
+
233
# Switches the model to evaluation mode.
#
# Without a block, the switch persists and `self` is returned. With a block,
# the block runs in eval mode and receives `self`. Afterwards the previous mode
# is restored, even if the block raises, and the block's value is returned.
#
# The original implementation restored the previous mode in a method-level
# `ensure`, which also fired on `return self`, so the blockless form was a no-op.
def eval_mode
  previous = training
  unless block_given?
    eval
    return self
  end

  begin
    eval
    yield(self)
  ensure
    train(previous) unless previous.nil?
  end
end
242
+
243
# Excludes the parameters whose flattened paths match `matcher` from training.
#
# @param strict [Boolean] raise when nothing matched, instead of doing nothing
# @return [self]
# @raise [KeyError] when `strict` and no path matched
def freeze_paths!(matcher, strict: false)
  matched = __dsl_select_paths(matcher)
  raise KeyError, "no parameter paths matched #{matcher.inspect}" if strict && matched.empty?

  __dsl_toggle_paths(matched, freeze: true)
end
251
+
252
# Makes the parameters whose flattened paths match `matcher` trainable again.
#
# @param strict [Boolean] raise when nothing matched, instead of doing nothing
# @return [self]
# @raise [KeyError] when `strict` and no path matched
def unfreeze_paths!(matcher, strict: false)
  matched = __dsl_select_paths(matcher)
  raise KeyError, "no parameter paths matched #{matcher.inspect}" if strict && matched.empty?

  __dsl_toggle_paths(matched, freeze: false)
end
260
+
261
# Sorted flattened parameter paths, optionally narrowed by a matcher
# (Regexp, String/Symbol, Array, or Proc).
def parameter_paths(matcher: nil)
  paths = MLX::Utils.tree_flatten(parameters, destination: {}).keys.sort
  return paths if matcher.nil?

  paths.select(&__dsl_path_matcher(matcher))
end
268
+
269
# Number of scalar elements across every parameter leaf, frozen or not.
def parameter_count
  tree = parameters
  __dsl_count_parameters(tree)
end

# Number of scalar elements across the trainable parameter leaves only.
def trainable_parameter_count
  tree = trainable_parameters
  __dsl_count_parameters(tree)
end
276
+
277
# Describes the model: class name, total/trainable/frozen element counts, and
# the sorted parameter paths.
#
# @param as [Symbol, String] :hash (default) or :text
# @return [Hash, String] the summary Hash, or its "key=value" line rendering
# @raise [ArgumentError] when `as` is not :hash or :text. Values that cannot be
#   symbolized (e.g. nil, 5) used to raise NoMethodError, and only after all
#   counts had been computed.
def summary(as: :hash)
  mode = as.respond_to?(:to_sym) ? as.to_sym : nil
  unless %i[hash text].include?(mode)
    raise ArgumentError, "summary :as must be :hash or :text"
  end

  total = parameter_count
  trainable = trainable_parameter_count
  payload = {
    "model_class" => self.class.name.to_s,
    "total_parameters" => total,
    "trainable_parameters" => trainable,
    "frozen_parameters" => total - trainable,
    "parameter_paths" => parameter_paths
  }
  mode == :text ? __dsl_summary_text(payload) : payload
end
296
+
297
private

# Renders a #summary payload as newline-separated "key=value" lines.
# Parameter paths are comma-joined.
def __dsl_summary_text(payload)
  fields = {
    "model" => payload.fetch("model_class"),
    "total_parameters" => payload.fetch("total_parameters"),
    "trainable_parameters" => payload.fetch("trainable_parameters"),
    "frozen_parameters" => payload.fetch("frozen_parameters"),
    "parameter_paths" => payload.fetch("parameter_paths").join(",")
  }
  fields.map { |label, value| "#{label}=#{value}" }.join("\n")
end
308
+
309
# Sums the element counts of every leaf in a (nested) parameter tree.
def __dsl_count_parameters(tree)
  leaves = MLX::Utils.tree_flatten(tree, destination: {}).values
  leaves.inject(0) { |total, leaf| total + __dsl_leaf_numel(leaf) }
end
313
+
314
# Number of scalar elements in one parameter-tree leaf.
#
# - Objects exposing an Array `shape` count as the product of their dimensions
#   (an empty shape is one scalar).
# - Other objects with #size use that size.
# - Anything else counts as 1.
#
# Ruby numbers are always one scalar: Integer#size reports the machine-word
# byte width (e.g. 8), which was previously miscounted as 8 elements.
def __dsl_leaf_numel(value)
  return 1 if value.is_a?(Numeric)

  if value.respond_to?(:shape)
    shape = value.shape
    return shape.reduce(1) { |acc, dim| acc * dim.to_i } if shape.is_a?(Array)
  end
  return value.size.to_i if value.respond_to?(:size)

  1
end
327
+
328
+ def __dsl_checkpoint_format(path, format)
329
+ raw = if format.nil?
330
+ ext = File.extname(path.to_s).delete_prefix(".").downcase
331
+ ext.empty? ? "marshal" : ext
332
+ else
333
+ format.to_s.downcase
334
+ end
335
+
336
+ case raw
337
+ when "marshal", "bin", "legacy", "dsl_v1"
338
+ :marshal
339
+ when "npz"
340
+ :npz
341
+ when "safetensors"
342
+ :safetensors
343
+ else
344
+ raise ArgumentError, "unsupported checkpoint format: #{raw.inspect}"
345
+ end
346
+ end
347
+
348
# Weights-file path for a native format: appends ".npz"/".safetensors" unless
# the path already ends with it. For any other format the path is returned as
# a String.
def __dsl_checkpoint_weights_path(path, checkpoint_format)
  target = path.to_s
  suffix = { npz: ".npz", safetensors: ".safetensors" }[checkpoint_format]
  return target if suffix.nil? || target.end_with?(suffix)

  target + suffix
end
359
+
360
# Path of the JSON sidecar stored next to a native weights file.
def __dsl_checkpoint_sidecar_path(weights_path)
  weights_path.to_s + ".mlxmeta.json"
end
363
+
364
# Chooses the file to load from.
#
# The path is kept as-is if a format was given, if it has an extension, or if
# it exists. Otherwise the first existing of "<path>.npz" and
# "<path>.safetensors" is used, falling back to the original path.
def __dsl_resolve_load_checkpoint_path(path, format)
  target = path.to_s
  should_probe = format.nil? && File.extname(target).empty? && !File.exist?(target)
  return target unless should_probe

  %w[.npz .safetensors].map { |suffix| target + suffix }.find { |candidate| File.exist?(candidate) } || target
end
374
+
375
# Saves weights with #save_weights, then writes a JSON sidecar holding the
# format marker, weights format, metadata and (optionally) optimizer state.
#
# @return [String] the weights path
def __dsl_save_native_checkpoint(path, checkpoint_format, optimizer:, metadata:)
  weights_path = __dsl_checkpoint_weights_path(path, checkpoint_format)
  __dsl_ensure_parent_dir!(weights_path)
  save_weights(weights_path)

  sidecar = {
    "format" => "mlx_dsl_checkpoint_v2_native",
    "weights_format" => checkpoint_format.to_s,
    "metadata" => metadata || {}
  }.tap do |doc|
    doc["optimizer"] = __dsl_serialize_tree(optimizer.state) unless optimizer.nil?
  end

  File.binwrite(__dsl_checkpoint_sidecar_path(weights_path), JSON.generate(sidecar))
  weights_path
end
390
+
391
# Creates the parent directory of `path` (recursively) unless it is the
# current directory.
def __dsl_ensure_parent_dir!(path)
  parent = File.dirname(path.to_s)
  current_dir = parent.nil? || parent.empty? || parent == "."
  FileUtils.mkdir_p(parent) unless current_dir
end
397
+
398
# Reads the JSON sidecar for a native weights file. A missing sidecar counts
# as {}. Defaults are filled in for "format", "weights_format" (from the
# weights extension) and "metadata".
#
# @raise [ArgumentError] when the sidecar does not hold a JSON object
def __dsl_load_native_checkpoint_payload(weights_path)
  sidecar_path = __dsl_checkpoint_sidecar_path(weights_path)
  payload = File.exist?(sidecar_path) ? JSON.parse(File.binread(sidecar_path)) : {}
  raise ArgumentError, "invalid native checkpoint sidecar payload" unless payload.is_a?(Hash)

  payload.tap do |doc|
    doc["format"] ||= "mlx_dsl_checkpoint_v2_native"
    doc["weights_format"] ||= File.extname(weights_path).delete_prefix(".")
    doc["metadata"] ||= {}
  end
end
414
+
415
# Returns the keys of `kwargs` that `super_method` would not accept as
# keyword arguments.
#
# - No method means every key is unknown.
# - A method taking **kwargs accepts everything.
# - Otherwise only its named (optional or required) keywords are accepted.
def __dsl_unknown_kwargs_for_super(kwargs, super_method)
  return [] if kwargs.empty?
  return kwargs.keys unless super_method

  signature = super_method.parameters
  return [] if signature.any? { |kind, _| kind == :keyrest }

  named = Set.new
  signature.each do |kind, param_name|
    next if param_name.nil?

    named << param_name.to_s if %i[key keyreq].include?(kind)
  end
  kwargs.keys.reject { |key| named.include?(key.to_s) }
end
428
+
429
# Moves declared option values out of `kwargs` (mutating it) and returns them
# keyed by option name (String). For each option, the Symbol key is checked
# before the String key; only the first match is removed.
def __dsl_extract_declared_options(kwargs)
  definitions = self.class.dsl_option_definitions
  extracted = {}
  return extracted if definitions.empty? || kwargs.empty?

  definitions.each_key do |name|
    [name.to_sym, name].each do |candidate|
      next unless kwargs.key?(candidate)

      extracted[name] = kwargs.delete(candidate)
      break
    end
  end
  extracted
end
444
+
445
# Assigns every declared option through its "name=" writer.
#
# The value comes from `provided` if present, otherwise from the (resolved)
# default.
#
# @raise [ArgumentError] when a required option has neither a provided value
#   nor a default
def __dsl_apply_option_values(provided)
  self.class.dsl_option_definitions.each do |name, spec|
    default = spec[:default]
    no_default = default.equal?(UNSET)

    if provided.key?(name)
      value = provided[name]
    elsif no_default
      raise ArgumentError, "missing required option: #{name}" if spec[:required]

      value = nil
    else
      value = __dsl_resolve_callable(default)
    end

    # Optional options without a default are left unassigned when no value
    # (or nil) was given.
    next if value.nil? && no_default && !spec[:required]

    public_send("#{name}=", value)
  end
end
461
+
462
# Builds every class-level layer/param/buffer declaration, in declaration
# order, and assigns it through its "name=" writer. It runs at most once per
# instance.
#
# @raise [ArgumentError] for an unknown declaration kind
def __dsl_materialize_declarations!
  return if instance_variable_defined?(:@__dsl_materialized) && @__dsl_materialized

  self.class.dsl_declarations.each do |decl|
    kind = decl[:kind]
    value =
      case kind
      when :layer then __dsl_build_layer(decl)
      when :param then __dsl_build_array(decl, default_fill: :uniform)
      when :buffer then __dsl_build_array(decl, default_fill: :zeros)
      else raise ArgumentError, "unknown declaration kind: #{decl[:kind]}"
      end
    public_send("#{decl[:name]}=", value)
    # Buffers also go into no_grad, the set that freezing adds names to.
    # Presumably this keeps them out of the trainable parameters; confirm
    # against MLX::NN::Module.
    __send__(:no_grad).add(decl[:name]) if kind == :buffer
  end

  @__dsl_materialized = true
end
481
+
482
# Builds the module for a `layer` declaration.
#
# Factory handling:
# - Class: instantiated with the resolved args/kwargs.
# - Prebuilt MLX::NN::Module instance: used as-is. Since the declaration lives
#   on the class, that same instance is assigned to every model built from it.
# - Callable: invoked (through __dsl_resolve_callable when there are no
#   arguments).
# - Any other value: used as-is.
#
# Without a factory, the block is run through MLX::DSL::Builder.
#
# Module instances are checked before the callable branch because they may
# respond to #call (presumably their forward pass). Treating them as
# factories would have invoked them with no inputs instead of returning them.
#
# @raise [ArgumentError] when a block returns nil
# @raise [TypeError] when the result is not an MLX::NN::Module
def __dsl_build_layer(decl)
  factory = decl[:factory]
  if factory.nil?
    value = MLX::DSL::Builder.new(self).build(&decl[:block])
    raise ArgumentError, "layer #{decl[:name]} block returned nil" if value.nil?

    return __dsl_validate_layer_value(value, decl[:name])
  end

  args = __dsl_resolve_factory_args(decl.fetch(:factory_args, []))
  kwargs = __dsl_resolve_factory_kwargs(decl.fetch(:factory_kwargs, {}))
  value =
    if factory.is_a?(Class)
      factory.new(*args, **kwargs)
    elsif factory.is_a?(MLX::NN::Module)
      factory
    elsif factory.respond_to?(:call)
      if args.empty? && kwargs.empty?
        __dsl_resolve_callable(factory)
      elsif kwargs.empty?
        factory.call(*args)
      else
        factory.call(*args, **kwargs)
      end
    else
      factory
    end
  __dsl_validate_layer_value(value, decl[:name])
end
508
+
509
# Returns `value` when it is an MLX::NN::Module.
#
# @raise [TypeError] naming the layer and the offending class otherwise
def __dsl_validate_layer_value(value, name)
  return value if value.is_a?(MLX::NN::Module)

  raise TypeError, "layer #{name} must build an MLX::NN::Module, got #{value.class}"
end
516
+
517
# Builds the array for a param/buffer declaration.
#
# - If the declaration has an `init`, that initializer supplies the value.
# - Otherwise params get uniform values in [-0.05, 0.05] and buffers zeros.
#
# Non-MLX results are converted with MLX::Core.array.
def __dsl_build_array(decl, default_fill:)
  shape = __dsl_resolve_shape(decl[:shape])
  dtype = __dsl_resolve_dtype(decl[:dtype])
  init = decl[:init]

  raw =
    if !init.nil?
      __dsl_call_initializer(init, shape, dtype)
    elsif default_fill == :uniform
      MLX::Core.random_uniform(shape, -0.05, 0.05, dtype)
    else
      MLX::Core.zeros(shape, dtype)
    end
  return raw if raw.is_a?(MLX::Core::Array)

  MLX::Core.array(raw, dtype)
end
534
+
535
# Declared dtype, with callables resolved; an omitted dtype (UNSET) falls
# back to the default dtype.
def __dsl_resolve_dtype(dtype)
  dtype.equal?(UNSET) ? __dsl_default_dtype : __dsl_resolve_callable(dtype)
end
540
+
541
# MLX::Core.float32 when the native core exposes it.
#
# @raise [MLX::Core::NativeUnavailableError, RuntimeError] when it does not;
#   the first is used when that error class is defined
def __dsl_default_dtype
  core_loaded = defined?(MLX::Core)
  return MLX::Core.float32 if core_loaded && MLX::Core.respond_to?(:float32)

  error_class =
    if core_loaded && defined?(MLX::Core::NativeUnavailableError)
      MLX::Core::NativeUnavailableError
    else
      RuntimeError
    end
  raise error_class, "MLX native extension is required to initialize DSL params/buffers"
end
553
+
554
# Resolves a declared shape (possibly a callable) to an Array of non-negative
# Integers. A bare Integer becomes a 1-D shape.
#
# @raise [ArgumentError] for anything else
def __dsl_resolve_shape(shape)
  dims = __dsl_resolve_callable(shape)
  dims = [dims] if dims.is_a?(Integer)
  valid = dims.is_a?(Array) && dims.all? { |dim| dim.is_a?(Integer) && !dim.negative? }
  raise ArgumentError, "shape must resolve to an array of non-negative integers" unless valid

  dims
end
563
+
564
# Produces an initial value from a param/buffer `init`.
#
# - Non-callables are returned unchanged.
# - A zero-arity Proc is instance_exec'd on the model.
# - A one-arity Proc receives the shape.
# - Any other callable receives (shape, dtype).
def __dsl_call_initializer(init, shape, dtype)
  return init unless init.respond_to?(:call)
  return init.call(shape, dtype) unless init.is_a?(Proc)

  case init.arity
  when 0 then instance_exec(&init)
  when 1 then init.call(shape)
  else init.call(shape, dtype)
  end
end
582
+
583
# Resolves each positional factory argument (callables are evaluated).
def __dsl_resolve_factory_args(values)
  values.map(&method(:__dsl_resolve_callable))
end

# Resolves each keyword factory argument value, keeping the keys.
def __dsl_resolve_factory_kwargs(values)
  values.to_h { |key, value| [key, __dsl_resolve_callable(value)] }
end
592
+
593
# Evaluates DSL values that may be deferred.
#
# - Non-callables are returned unchanged.
# - A one-arity Proc receives the model.
# - Other Procs are instance_exec'd on the model.
# - Other callables (e.g. Method objects) are called with no arguments.
def __dsl_resolve_callable(value)
  return value unless value.respond_to?(:call)
  return value.call unless value.is_a?(Proc)

  value.arity == 1 ? value.call(self) : instance_exec(&value)
end
606
+
607
# Flattened parameter paths (unsorted) that satisfy `matcher`.
def __dsl_select_paths(matcher)
  matches = __dsl_path_matcher(matcher)
  MLX::Utils.tree_flatten(parameters, destination: {}).keys.select { |path| matches.call(path) }
end
612
+
613
# Turns a user matcher into a `path -> Boolean` predicate over flattened
# parameter paths.
#
# Supported matchers:
# - Regexp: pattern match.
# - String/Symbol: exact match.
# - Array: any element matches. Regexp and callable elements keep their own
#   semantics; any other element is compared by its string form. Previously a
#   Regexp inside an Array was stringified and silently matched nothing.
# - Proc or any other object responding to #call: called with the path.
#
# @raise [ArgumentError] for unsupported matchers
def __dsl_path_matcher(matcher)
  case matcher
  when Regexp
    lambda { |path| matcher.match?(path) }
  when String, Symbol
    target = matcher.to_s
    lambda { |path| path == target }
  when Array
    element_matchers = matcher.map do |entry|
      entry_matcher = entry.is_a?(Regexp) || entry.respond_to?(:call) ? entry : entry.to_s
      __dsl_path_matcher(entry_matcher)
    end
    lambda { |path| element_matchers.any? { |element| element.call(path) } }
  when Proc
    lambda { |path| matcher.call(path) }
  else
    raise ArgumentError, "unsupported matcher: #{matcher.class}" unless matcher.respond_to?(:call)

    lambda { |path| matcher.call(path) }
  end
end
629
+
630
# Adds each path's local key to (freeze: true) or removes it from
# (freeze: false) the no_grad set of the deepest module owning that path.
# Paths without a local key are skipped.
#
# @return [self]
def __dsl_toggle_paths(paths, freeze:)
  modules_by_prefix = named_modules.to_h
  action = freeze ? :add : :delete
  paths.each do |path|
    owner, local_key = __dsl_find_module_for_path(path, modules_by_prefix)
    next if local_key.nil? || local_key.empty?

    owner.__send__(:no_grad).public_send(action, local_key)
  end
  self
end
644
+
645
# Converts a nested tree into plain data for Marshal/JSON.
#
# - MLX arrays become {"__mlx_array__" => state}, where state comes from
#   #__getstate__.
# - Arrays and Hashes are converted recursively, with Hash keys stringified.
# - Other values pass through unchanged.
def __dsl_serialize_tree(value)
  case value
  when MLX::Core::Array
    { "__mlx_array__" => value.__getstate__ }
  when Array
    value.map { |entry| __dsl_serialize_tree(entry) }
  when Hash
    value.to_h { |key, entry| [key.to_s, __dsl_serialize_tree(entry)] }
  else
    value
  end
end
660
+
661
# Inverse of __dsl_serialize_tree.
#
# - {"__mlx_array__" => state} becomes an MLX array again.
# - Arrays and Hashes are rebuilt recursively, with Hash keys kept as they are.
# - Other values pass through unchanged.
def __dsl_deserialize_tree(value)
  case value
  when Array
    value.map { |entry| __dsl_deserialize_tree(entry) }
  when Hash
    return __dsl_array_from_state(value.fetch("__mlx_array__")) if value.key?("__mlx_array__")

    value.to_h { |key, entry| [key, __dsl_deserialize_tree(entry)] }
  else
    value
  end
end
676
+
677
# Rebuilds an MLX array from a serialized state Hash. "values"/"dtype" may
# use String or Symbol keys. The dtype is only applied when MLX::Core has a
# public method of that name.
def __dsl_array_from_state(state)
  lookup = ->(key) { state[key.to_s] || state[key.to_sym] }
  values = lookup.call(:values)
  dtype_method = lookup.call(:dtype)&.to_sym

  # NOTE(review): the dtype name comes from checkpoint data, so public_send
  # will invoke any zero-argument public MLX::Core method with that name.
  # Consider an allowlist of dtype names if checkpoints may be untrusted.
  if dtype_method && MLX::Core.respond_to?(dtype_method)
    MLX::Core.array(values, MLX::Core.public_send(dtype_method))
  else
    MLX::Core.array(values)
  end
end
686
+
687
# Finds the deepest module owning a flattened parameter path.
#
# The owner is the module whose non-empty prefix is the longest one followed
# by "." in `path`.
#
# @return [Array(Object, String)] [owner, key relative to owner], or
#   [self, path] when no prefix applies
def __dsl_find_module_for_path(path, module_map)
  owners = module_map.each_key.select do |prefix|
    !(prefix.nil? || prefix.empty?) && prefix != path && path.start_with?(prefix + ".")
  end
  best_prefix = owners.max_by(&:length)
  return [self, path] if best_prefix.nil?

  [module_map.fetch(best_prefix), path[(best_prefix.length + 1)..]]
end
704
+ end
705
+ end
706
+ end