mlx 0.30.7 → 0.30.7.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/native.cpp +0 -4
- data/lib/mlx/core.rb +8 -1
- data/lib/mlx/distributed_utils/launch.rb +9 -3
- data/lib/mlx/dsl/builder.rb +377 -0
- data/lib/mlx/dsl/data_pipeline.rb +284 -0
- data/lib/mlx/dsl/experiment.rb +154 -0
- data/lib/mlx/dsl/graph_modules.rb +91 -0
- data/lib/mlx/dsl/model.rb +9 -0
- data/lib/mlx/dsl/model_mixin.rb +706 -0
- data/lib/mlx/dsl/split_plan.rb +85 -0
- data/lib/mlx/dsl/train_step.rb +197 -0
- data/lib/mlx/dsl/trainer.rb +2110 -0
- data/lib/mlx/dsl.rb +16 -0
- data/lib/mlx/nn/layers/containers.rb +21 -4
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx.rb +1 -0
- data/mlx/CMakeLists.txt +4 -16
- metadata +12 -2
|
@@ -0,0 +1,706 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
require "set"
|
|
4
|
+
require "json"
|
|
5
|
+
require "fileutils"
|
|
6
|
+
|
|
7
|
+
module MLX
|
|
8
|
+
module DSL
|
|
9
|
+
module ModelMixin
|
|
10
|
+
UNSET = Object.new.freeze
|
|
11
|
+
|
|
12
|
+
# Collects the optimizer groups declared inside an `optimizer_groups` block
# and combines them into a single optimizer.
#
# Each `group` call pairs an optimizer (returned by the block) with a
# matcher that is turned into a `(path, grad)` predicate.
class OptimizerGroupsBuilder
  def initialize
    @groups = []
  end

  # Registers one optimizer group.
  #
  # @param matcher [nil, Regexp, String, Symbol, Array, Proc] selects the
  #   paths handled by this group: nil matches everything, a Regexp is
  #   tested against the path, a String/Symbol must equal the path, an Array
  #   lists exact paths, and a Proc is called with `(path, grad)` (or with
  #   just `path` when it is a one-argument lambda).
  # @yieldreturn [MLX::Optimizers::Optimizer] the optimizer for this group
  # @return [MLX::Optimizers::Optimizer] the optimizer produced by the block
  # @raise [ArgumentError] without a block, or for an unsupported matcher
  # @raise [TypeError] when the block does not return an optimizer
  def group(matcher = nil, &factory)
    raise ArgumentError, "group requires an optimizer block" unless block_given?

    optimizer = factory.call
    unless optimizer.is_a?(MLX::Optimizers::Optimizer)
      raise TypeError, "group block must return an MLX::Optimizers::Optimizer"
    end

    @groups << { optimizer: optimizer, filter: matcher_lambda(matcher) }
    optimizer
  end

  # Produces the final optimizer.
  #
  # A single group returns its optimizer unchanged. Several groups are
  # wrapped in a MultiOptimizer; only the filters of all groups but the last
  # are passed along, so the last group's matcher is not used here.
  #
  # @raise [ArgumentError] when no group was declared
  def build
    if @groups.empty?
      raise ArgumentError, "optimizer_groups requires at least one group"
    end

    return @groups.first[:optimizer] if @groups.length == 1

    MLX::Optimizers::MultiOptimizer.new(
      @groups.map { |entry| entry[:optimizer] },
      filters: @groups[0...-1].map { |entry| entry[:filter] }
    )
  end

  private

  # Normalizes a matcher into a two-argument `(path, grad)` predicate.
  def matcher_lambda(matcher)
    case matcher
    when nil
      ->(_path, _grad) { true }
    when Regexp
      ->(path, _grad) { matcher.match?(path.to_s) }
    when String, Symbol
      target = matcher.to_s
      ->(path, _grad) { path.to_s == target }
    when Array
      # Same exact-path list semantics as the Array branch of __dsl_path_matcher.
      targets = matcher.map(&:to_s)
      ->(path, _grad) { targets.include?(path.to_s) }
    when Proc
      # A strict one-argument lambda would raise if handed (path, grad).
      path_only = matcher.lambda? && matcher.arity == 1
      lambda do |path, grad|
        path_only ? matcher.call(path.to_s) : matcher.call(path.to_s, grad)
      end
    else
      raise ArgumentError, "unsupported group matcher: #{matcher.class}"
    end
  end
end
|
|
65
|
+
|
|
66
|
+
# Inclusion hook: extends the including class with ClassMethods and
# prepends Initializer to it.
#
# @param base [Class] the class including this mixin
# @return [Class] base
def self.included(base)
  base.tap do |klass|
    klass.extend(ClassMethods)
    klass.prepend(Initializer)
  end
end
|
|
70
|
+
|
|
71
|
+
# Prepended constructor wrapper. It removes declared DSL options from the
# keyword arguments, rejects keywords the wrapped constructor cannot accept,
# calls the wrapped constructor, then applies option values and
# materializes the class-level declarations.
module Initializer
  # @raise [ArgumentError] listing keyword names (sorted) that are neither
  #   declared options nor accepted by the wrapped initializer
  def initialize(*args, **kwargs, &block)
    declared = __dsl_extract_declared_options(kwargs)
    wrapped = method(__method__).super_method
    rejected = __dsl_unknown_kwargs_for_super(kwargs, wrapped)
    raise ArgumentError, "unknown option(s): #{rejected.map(&:to_s).sort.join(", ")}" unless rejected.empty?

    super(*args, **kwargs, &block)
    __dsl_apply_option_values(declared)
    __dsl_materialize_declarations!
  end
end
|
|
85
|
+
|
|
86
|
+
# Class-level declaration macros. Options are stored in a hash keyed by
# name; layers, params and buffers are appended to an ordered list of
# declaration hashes. Both are copied into subclasses when they are created.
module ClassMethods
  # Declares a constructor option.
  #
  # When `required` is not given, the option is required exactly when no
  # default was supplied.
  def option(name, default: UNSET, required: UNSET)
    must_supply = required.equal?(UNSET) ? default.equal?(UNSET) : required
    dsl_option_definitions[name.to_s] = { default: default, required: must_supply ? true : false }
  end

  # Declares a layer built either from `factory` (with optional arguments)
  # or from a block, never both.
  #
  # @raise [ArgumentError] when neither or both of factory/block are given
  def layer(name, factory = nil, *factory_args, **factory_kwargs, &block)
    has_factory = !factory.nil?
    raise ArgumentError, "layer requires either a factory or block" unless has_factory || block_given?
    raise ArgumentError, "layer cannot accept both a factory and block" if has_factory && block_given?

    dsl_declarations.push(
      {
        kind: :layer,
        name: name.to_s,
        factory: factory,
        factory_args: factory_args,
        factory_kwargs: factory_kwargs,
        block: block
      }
    )
  end

  # Alias-style entry point: forwards everything to `layer`.
  def network(name, factory = nil, *factory_args, **factory_kwargs, &block)
    layer(name, factory, *factory_args, **factory_kwargs, &block)
  end

  # Declares an array recorded with kind :param.
  def param(name, shape:, init: nil, dtype: UNSET)
    __dsl_declare_array(:param, name, shape, init, dtype)
  end

  # Declares an array recorded with kind :buffer.
  def buffer(name, shape:, init: nil, dtype: UNSET)
    __dsl_declare_array(:buffer, name, shape, init, dtype)
  end

  # Lazily created registry of option specs for this class.
  def dsl_option_definitions
    @dsl_option_definitions ||= {}
  end

  # Lazily created ordered list of declarations for this class.
  def dsl_declarations
    @dsl_declarations ||= []
  end

  # Gives each subclass its own shallow copies of the option specs and
  # declarations known at the time the subclass is created.
  def inherited(subclass)
    super
    option_copies = dsl_option_definitions.transform_values(&:dup)
    declaration_copies = dsl_declarations.map(&:dup)
    subclass.instance_variable_set(:@dsl_option_definitions, option_copies)
    subclass.instance_variable_set(:@dsl_declarations, declaration_copies)
  end

  private

  # Shared recorder for param/buffer declarations.
  def __dsl_declare_array(kind, name, shape, init, dtype)
    dsl_declarations.push({ kind: kind, name: name.to_s, shape: shape, init: init, dtype: dtype })
  end
end
|
|
153
|
+
|
|
154
|
+
include TrainStepMethods
|
|
155
|
+
|
|
156
|
+
# Evaluates the block against a fresh OptimizerGroupsBuilder and returns
# whatever the builder's `build` produces.
#
# @raise [ArgumentError] when called without a block
def optimizer_groups(&block)
  raise ArgumentError, "optimizer_groups requires a block" unless block_given?

  OptimizerGroupsBuilder.new.tap { |groups| groups.instance_eval(&block) }.build
end
|
|
163
|
+
|
|
164
|
+
# Builds an MLX::DSL::Trainer for this model, forwarding the given settings
# and the loss block unchanged.
def trainer(optimizer:, clip_grad_norm: nil, compile: false, sync: :none, &loss_block)
  settings = {
    model: self,
    optimizer: optimizer,
    clip_grad_norm: clip_grad_norm,
    compile: compile,
    sync: sync
  }
  MLX::DSL::Trainer.new(**settings, &loss_block)
end
|
|
174
|
+
|
|
175
|
+
# Saves model parameters (and optionally optimizer state and metadata) to a
# checkpoint.
#
# The format comes from `format`, or from the path extension when `format`
# is nil. Marshal checkpoints are written here as one file; every other
# format is delegated to __dsl_save_native_checkpoint.
#
# @return the given path for marshal checkpoints, otherwise the delegate's
#   return value
def save_checkpoint(path, optimizer: nil, metadata: {}, format: nil)
  checkpoint_format = __dsl_checkpoint_format(path, format)
  unless checkpoint_format == :marshal
    return __dsl_save_native_checkpoint(path, checkpoint_format, optimizer: optimizer, metadata: metadata)
  end

  __dsl_ensure_parent_dir!(path)
  snapshot = {
    "format" => "mlx_dsl_checkpoint_v1",
    "model" => __dsl_serialize_tree(parameters),
    "metadata" => metadata || {}
  }
  # Optimizer state is only stored when an optimizer was passed in.
  snapshot["optimizer"] = __dsl_serialize_tree(optimizer.state) unless optimizer.nil?

  File.binwrite(path, Marshal.dump(snapshot))
  path
end
|
|
192
|
+
|
|
193
|
+
# Restores model parameters (and optionally optimizer state) from a
# checkpoint and returns the checkpoint payload hash.
#
# Marshal checkpoints carry the model tree inside the payload. Other formats
# load weights through `load_weights` and read the payload from the sidecar
# helper.
#
# @raise [ArgumentError] when a marshal payload is not a v1 checkpoint hash
def load_checkpoint(path, optimizer: nil, strict: true, format: nil)
  resolved_path = __dsl_resolve_load_checkpoint_path(path, format)
  checkpoint_format = __dsl_checkpoint_format(resolved_path, format)

  if checkpoint_format == :marshal
    # NOTE(review): Marshal.load can instantiate arbitrary objects; only
    # load marshal checkpoints that come from a trusted source.
    payload = Marshal.load(File.binread(resolved_path))
    recognized = payload.is_a?(Hash) && payload["format"] == "mlx_dsl_checkpoint_v1"
    raise ArgumentError, "unsupported checkpoint format" unless recognized

    update(__dsl_deserialize_tree(payload.fetch("model")), strict: strict)
  else
    weights_path = __dsl_checkpoint_weights_path(resolved_path, checkpoint_format)
    load_weights(weights_path, strict: strict)
    payload = __dsl_load_native_checkpoint_payload(weights_path)
  end

  # Optimizer state is restored only when the caller passed an optimizer
  # and the payload actually contains one.
  restore_optimizer = !optimizer.nil? && payload.key?("optimizer")
  optimizer.state = __dsl_deserialize_tree(payload["optimizer"]) if restore_optimizer

  payload
end
|
|
222
|
+
|
|
223
|
+
# Switches the model into training mode.
#
# Without a block the switch is permanent and `self` is returned. With a
# block, the block is yielded `self`, its value is returned, and the
# previous mode is restored afterwards (also when the block raises).
#
# The restore lives in an inner begin/ensure so that it only applies to the
# block form; a method-level `ensure` would also run on the block-less
# early return and undo the switch immediately.
def train_mode
  previous = training
  train(true)
  return self unless block_given?

  begin
    yield(self)
  ensure
    train(previous) unless previous.nil?
  end
end
|
|
232
|
+
|
|
233
|
+
# Switches the model into evaluation mode by calling `eval`.
#
# Without a block the switch is permanent and `self` is returned. With a
# block, the block is yielded `self`, its value is returned, and the
# previous mode is restored afterwards (also when the block raises).
#
# The restore lives in an inner begin/ensure so that it only applies to the
# block form; a method-level `ensure` would also run on the block-less
# early return and undo the switch immediately.
def eval_mode
  previous = training
  eval
  return self unless block_given?

  begin
    yield(self)
  ensure
    train(previous) unless previous.nil?
  end
end
|
|
242
|
+
|
|
243
|
+
# Freezes every parameter path selected by `matcher`.
#
# @param strict [Boolean] raise when nothing matches
# @raise [KeyError] in strict mode when no path matched
# @return the result of __dsl_toggle_paths
def freeze_paths!(matcher, strict: false)
  matched = __dsl_select_paths(matcher)
  raise KeyError, "no parameter paths matched #{matcher.inspect}" if strict && matched.empty?

  __dsl_toggle_paths(matched, freeze: true)
end
|
|
251
|
+
|
|
252
|
+
# Unfreezes every parameter path selected by `matcher`.
#
# @param strict [Boolean] raise when nothing matches
# @raise [KeyError] in strict mode when no path matched
# @return the result of __dsl_toggle_paths
def unfreeze_paths!(matcher, strict: false)
  matched = __dsl_select_paths(matcher)
  raise KeyError, "no parameter paths matched #{matcher.inspect}" if strict && matched.empty?

  __dsl_toggle_paths(matched, freeze: false)
end
|
|
260
|
+
|
|
261
|
+
# Lists the flattened parameter paths in sorted order, optionally narrowed
# to those accepted by `matcher`.
def parameter_paths(matcher: nil)
  paths = MLX::Utils.tree_flatten(parameters, destination: {}).keys.sort
  unless matcher.nil?
    accepts = __dsl_path_matcher(matcher)
    paths = paths.select { |candidate| accepts.call(candidate) }
  end
  paths
end
|
|
268
|
+
|
|
269
|
+
# Total element count across the tree returned by `parameters`.
def parameter_count
  tree = parameters
  __dsl_count_parameters(tree)
end
|
|
272
|
+
|
|
273
|
+
# Total element count across the tree returned by `trainable_parameters`.
def trainable_parameter_count
  tree = trainable_parameters
  __dsl_count_parameters(tree)
end
|
|
276
|
+
|
|
277
|
+
# Reports parameter statistics for the model.
#
# @param as [#to_sym] :hash for the raw report, :text for the rendered form
# @raise [ArgumentError] for any other `as` value
def summary(as: :hash)
  total = parameter_count
  trainable = trainable_parameter_count
  report = {
    "model_class" => self.class.name.to_s,
    "total_parameters" => total,
    "trainable_parameters" => trainable,
    "frozen_parameters" => total - trainable,
    "parameter_paths" => parameter_paths
  }

  style = as.to_sym
  return report if style == :hash
  return __dsl_summary_text(report) if style == :text

  raise ArgumentError, "summary :as must be :hash or :text"
end
|
|
296
|
+
|
|
297
|
+
private
|
|
298
|
+
|
|
299
|
+
# Renders a summary report as newline-separated `key=value` lines; the path
# list is joined with commas.
def __dsl_summary_text(payload)
  lines = ["model=#{payload.fetch('model_class')}"]
  %w[total_parameters trainable_parameters frozen_parameters].each do |field|
    lines << "#{field}=#{payload.fetch(field)}"
  end
  lines << "parameter_paths=#{payload.fetch('parameter_paths').join(',')}"
  lines.join("\n")
end
|
|
308
|
+
|
|
309
|
+
# Flattens the tree and adds up the element count of every leaf.
def __dsl_count_parameters(tree)
  leaves = MLX::Utils.tree_flatten(tree, destination: {}).values
  leaves.inject(0) { |total, leaf| total + __dsl_leaf_numel(leaf) }
end
|
|
313
|
+
|
|
314
|
+
# Number of elements represented by one flattened leaf.
#
# - a plain number counts as a single element,
# - a value with an Array `shape` counts the product of its dimensions
#   (an empty shape yields 1),
# - otherwise `size` is used when available,
# - anything else counts as 1.
def __dsl_leaf_numel(value)
  # Integer#size is the byte width of the number, not an element count, so
  # numeric leaves must not reach the generic `size` fallback below.
  return 1 if value.is_a?(Numeric)

  if value.respond_to?(:shape)
    shape = value.shape
    return shape.reduce(1) { |acc, dim| acc * dim.to_i } if shape.is_a?(Array)
  end

  return value.size.to_i if value.respond_to?(:size)

  1
end
|
|
327
|
+
|
|
328
|
+
# Resolves the checkpoint format symbol (:marshal, :npz or :safetensors).
#
# An explicit `format` wins. Otherwise the lower-cased path extension is
# used, and a path without extension means marshal.
#
# @raise [ArgumentError] for an unrecognized format name or extension
def __dsl_checkpoint_format(path, format)
  raw = format.nil? ? nil : format.to_s.downcase
  if raw.nil?
    extension = File.extname(path.to_s).delete_prefix(".").downcase
    raw = extension.empty? ? "marshal" : extension
  end

  known = {
    "marshal" => :marshal,
    "bin" => :marshal,
    "legacy" => :marshal,
    "dsl_v1" => :marshal,
    "npz" => :npz,
    "safetensors" => :safetensors
  }
  known.fetch(raw) { raise ArgumentError, "unsupported checkpoint format: #{raw.inspect}" }
end
|
|
347
|
+
|
|
348
|
+
# Returns the weights file path for a native checkpoint, appending the
# format's extension (".npz" / ".safetensors") when the path does not
# already end with it. Other formats return the path unchanged.
#
# The suffix check ignores case because __dsl_checkpoint_format lower-cases
# the extension when detecting the format; a path such as "model.NPZ" is
# therefore already an npz path and must not become "model.NPZ.npz".
def __dsl_checkpoint_weights_path(path, checkpoint_format)
  target = path.to_s
  suffix =
    case checkpoint_format
    when :npz then ".npz"
    when :safetensors then ".safetensors"
    end
  return target if suffix.nil? || target.downcase.end_with?(suffix)

  "#{target}#{suffix}"
end
|
|
359
|
+
|
|
360
|
+
# Path of the JSON sidecar stored next to a weights file: the weights path
# with ".mlxmeta.json" appended.
def __dsl_checkpoint_sidecar_path(weights_path)
  weights_path.to_s + ".mlxmeta.json"
end
|
|
363
|
+
|
|
364
|
+
# Picks the file to load for a checkpoint path.
#
# The path is used as given when a format was passed, when it has an
# extension, or when it exists. Otherwise the first existing candidate of
# "<path>.npz" and "<path>.safetensors" is used, falling back to the path.
def __dsl_resolve_load_checkpoint_path(path, format)
  target = path.to_s
  use_as_given = !format.nil? || !File.extname(target).empty? || File.exist?(target)
  return target if use_as_given

  %w[.npz .safetensors].map { |suffix| "#{target}#{suffix}" }.find { |candidate| File.exist?(candidate) } || target
end
|
|
374
|
+
|
|
375
|
+
# Writes a native checkpoint: weights through `save_weights`, plus a JSON
# sidecar holding format information, metadata and (when an optimizer is
# given) the serialized optimizer state.
#
# @return [String] the weights path that was written
def __dsl_save_native_checkpoint(path, checkpoint_format, optimizer:, metadata:)
  weights_path = __dsl_checkpoint_weights_path(path, checkpoint_format)
  __dsl_ensure_parent_dir!(weights_path)
  save_weights(weights_path)

  sidecar = {
    "format" => "mlx_dsl_checkpoint_v2_native",
    "weights_format" => checkpoint_format.to_s,
    "metadata" => metadata || {}
  }
  sidecar["optimizer"] = __dsl_serialize_tree(optimizer.state) unless optimizer.nil?

  sidecar_path = __dsl_checkpoint_sidecar_path(weights_path)
  File.binwrite(sidecar_path, JSON.generate(sidecar))
  weights_path
end
|
|
390
|
+
|
|
391
|
+
# Creates the directory that will contain `path`, unless that directory is
# the current directory.
def __dsl_ensure_parent_dir!(path)
  parent = File.dirname(path.to_s)
  FileUtils.mkdir_p(parent) unless parent.nil? || parent.empty? || parent == "."
end
|
|
397
|
+
|
|
398
|
+
# Reads the JSON sidecar belonging to a weights file (an empty hash when
# there is none) and fills in defaults for "format", "weights_format" and
# "metadata".
#
# @raise [ArgumentError] when the sidecar does not contain a JSON object
def __dsl_load_native_checkpoint_payload(weights_path)
  sidecar_path = __dsl_checkpoint_sidecar_path(weights_path)
  payload = File.exist?(sidecar_path) ? JSON.parse(File.binread(sidecar_path)) : {}
  raise ArgumentError, "invalid native checkpoint sidecar payload" unless payload.is_a?(Hash)

  payload["format"] = "mlx_dsl_checkpoint_v2_native" unless payload["format"]
  # Falls back to the weights file extension without its leading dot.
  payload["weights_format"] = File.extname(weights_path).delete_prefix(".") unless payload["weights_format"]
  payload["metadata"] = {} unless payload["metadata"]
  payload
end
|
|
414
|
+
|
|
415
|
+
# Returns the keyword keys that `super_method` cannot accept.
#
# No keywords means nothing to reject; no super method rejects everything;
# a `**rest` parameter accepts everything. Otherwise keys are compared by
# name against the method's declared keyword parameters.
def __dsl_unknown_kwargs_for_super(kwargs, super_method)
  return [] if kwargs.empty?
  return kwargs.keys unless super_method

  signature = super_method.parameters
  return [] if signature.any? { |kind, _name| kind == :keyrest }

  known = Set.new
  signature.each do |kind, name|
    known << name.to_s if %i[key keyreq].include?(kind) && !name.nil?
  end

  kwargs.keys.reject { |key| known.include?(key.to_s) }
end
|
|
428
|
+
|
|
429
|
+
# Removes declared option values from `kwargs` (mutating it) and returns
# them keyed by option name. For each option the symbol key is preferred
# over the string key.
def __dsl_extract_declared_options(kwargs)
  option_defs = self.class.dsl_option_definitions
  return {} if option_defs.empty? || kwargs.empty?

  option_defs.each_key.with_object({}) do |name, taken|
    present_key = [name.to_sym, name].find { |candidate| kwargs.key?(candidate) }
    taken[name] = kwargs.delete(present_key) unless present_key.nil?
  end
end
|
|
444
|
+
|
|
445
|
+
# Assigns every declared option through its `<name>=` writer.
#
# A provided value wins; otherwise the default is resolved (callables are
# invoked). Options with neither a provided value nor a default raise when
# required and are skipped when optional.
#
# @raise [ArgumentError] when a required option is missing
def __dsl_apply_option_values(provided)
  self.class.dsl_option_definitions.each do |name, spec|
    has_default = !spec[:default].equal?(UNSET)
    value =
      if provided.key?(name)
        provided[name]
      elsif has_default
        __dsl_resolve_callable(spec[:default])
      elsif spec[:required]
        raise ArgumentError, "missing required option: #{name}"
      end
    # Optional option without default and without a (non-nil) value: leave unset.
    next if value.nil? && !has_default && !spec[:required]

    public_send("#{name}=", value)
  end
end
|
|
461
|
+
|
|
462
|
+
# Builds every class-level declaration once per instance and assigns it via
# the `<name>=` writer. Buffers are additionally added to `no_grad`.
#
# @raise [ArgumentError] for a declaration with an unknown kind
def __dsl_materialize_declarations!
  return if defined?(@__dsl_materialized) && @__dsl_materialized

  self.class.dsl_declarations.each do |decl|
    kind = decl[:kind]
    built =
      case kind
      when :layer then __dsl_build_layer(decl)
      when :param then __dsl_build_array(decl, default_fill: :uniform)
      when :buffer then __dsl_build_array(decl, default_fill: :zeros)
      else raise ArgumentError, "unknown declaration kind: #{decl[:kind]}"
      end

    public_send("#{decl[:name]}=", built)
    __send__(:no_grad).add(decl[:name]) if kind == :buffer
  end

  # Guard flag so a second call becomes a no-op.
  @__dsl_materialized = true
end
|
|
481
|
+
|
|
482
|
+
# Builds the module for one layer declaration.
#
# Block declarations are evaluated through MLX::DSL::Builder. Factory
# declarations instantiate a Class, call a callable (with the resolved
# arguments when any were declared), or use a non-callable value as is.
# The result is validated before it is returned.
#
# @raise [ArgumentError] when a declaration block returns nil
def __dsl_build_layer(decl)
  factory = decl[:factory]

  if factory.nil?
    built = MLX::DSL::Builder.new(self).build(&decl[:block])
    raise ArgumentError, "layer #{decl[:name]} block returned nil" if built.nil?

    return __dsl_validate_layer_value(built, decl[:name])
  end

  args = __dsl_resolve_factory_args(decl.fetch(:factory_args, []))
  kwargs = __dsl_resolve_factory_kwargs(decl.fetch(:factory_kwargs, {}))
  built =
    if factory.is_a?(Class)
      factory.new(*args, **kwargs)
    elsif !factory.respond_to?(:call)
      factory
    elsif args.empty? && kwargs.empty?
      __dsl_resolve_callable(factory)
    elsif kwargs.empty?
      factory.call(*args)
    else
      factory.call(*args, **kwargs)
    end

  __dsl_validate_layer_value(built, decl[:name])
end
|
|
508
|
+
|
|
509
|
+
# Returns `value` when it is an MLX::NN::Module.
#
# @raise [TypeError] naming the layer and the offending class otherwise
def __dsl_validate_layer_value(value, name)
  return value if value.is_a?(MLX::NN::Module)

  raise TypeError, "layer #{name} must build an MLX::NN::Module, got #{value.class}"
end
|
|
516
|
+
|
|
517
|
+
# Builds the array for a param/buffer declaration.
#
# With an `init`, the initializer decides the value. Without one, the array
# is filled from random_uniform(-0.05, 0.05) when `default_fill` is
# :uniform and with zeros otherwise. Results that are not already an
# MLX::Core::Array are converted with the resolved dtype.
def __dsl_build_array(decl, default_fill:)
  shape = __dsl_resolve_shape(decl[:shape])
  dtype = __dsl_resolve_dtype(decl[:dtype])
  init = decl[:init]

  value =
    if !init.nil?
      __dsl_call_initializer(init, shape, dtype)
    elsif default_fill == :uniform
      MLX::Core.random_uniform(shape, -0.05, 0.05, dtype)
    else
      MLX::Core.zeros(shape, dtype)
    end

  return value if value.is_a?(MLX::Core::Array)

  MLX::Core.array(value, dtype)
end
|
|
534
|
+
|
|
535
|
+
# Resolves a declared dtype: the default dtype when none was declared,
# otherwise the declared value (callables are invoked).
def __dsl_resolve_dtype(dtype)
  dtype.equal?(UNSET) ? __dsl_default_dtype : __dsl_resolve_callable(dtype)
end
|
|
540
|
+
|
|
541
|
+
# Returns MLX::Core.float32 when it is available.
#
# @raise [MLX::Core::NativeUnavailableError, RuntimeError] when float32
#   cannot be obtained; the MLX-specific error class is used when defined
def __dsl_default_dtype
  core_loaded = defined?(MLX::Core)
  return MLX::Core.float32 if core_loaded && MLX::Core.respond_to?(:float32)

  error_class = RuntimeError
  error_class = MLX::Core::NativeUnavailableError if core_loaded && defined?(MLX::Core::NativeUnavailableError)
  raise error_class, "MLX native extension is required to initialize DSL params/buffers"
end
|
|
553
|
+
|
|
554
|
+
# Resolves a declared shape (callables are invoked) into an array of
# non-negative integers; a single Integer becomes a one-element array.
#
# @raise [ArgumentError] for anything else
def __dsl_resolve_shape(shape)
  dims = __dsl_resolve_callable(shape)
  dims = [dims] if dims.is_a?(Integer)

  well_formed = dims.is_a?(Array) && dims.none? { |dim| !dim.is_a?(Integer) || dim.negative? }
  raise ArgumentError, "shape must resolve to an array of non-negative integers" unless well_formed

  dims
end
|
|
563
|
+
|
|
564
|
+
# Produces the initial value for a param/buffer from `init`.
#
# Non-callables are returned as they are. Procs are dispatched on arity:
# zero-arity procs run in the context of this object, one-arity procs get
# the shape, everything else gets shape and dtype. Other callables are
# always called with shape and dtype.
def __dsl_call_initializer(init, shape, dtype)
  return init unless init.respond_to?(:call)
  return init.call(shape, dtype) unless init.is_a?(Proc)

  case init.arity
  when 0 then instance_exec(&init)
  when 1 then init.call(shape)
  else init.call(shape, dtype)
  end
end
|
|
582
|
+
|
|
583
|
+
# Resolves each positional factory argument (callables are invoked).
def __dsl_resolve_factory_args(values)
  values.each_with_object([]) do |raw, resolved|
    resolved << __dsl_resolve_callable(raw)
  end
end
|
|
586
|
+
|
|
587
|
+
# Resolves each keyword factory argument (callables are invoked), keeping
# the keys unchanged.
def __dsl_resolve_factory_kwargs(values)
  values.to_h { |key, raw| [key, __dsl_resolve_callable(raw)] }
end
|
|
592
|
+
|
|
593
|
+
# Resolves a possibly-callable declaration value.
#
# Non-callables are returned unchanged. A one-argument Proc receives this
# object; any other Proc is evaluated in the context of this object. Other
# callables are called without arguments.
def __dsl_resolve_callable(value)
  return value unless value.respond_to?(:call)
  return value.call unless value.is_a?(Proc)

  value.arity == 1 ? value.call(self) : instance_exec(&value)
end
|
|
606
|
+
|
|
607
|
+
# Returns the flattened parameter paths accepted by `matcher`, in the order
# the flattened tree yields them.
def __dsl_select_paths(matcher)
  accepts = __dsl_path_matcher(matcher)
  flattened = MLX::Utils.tree_flatten(parameters, destination: {})
  flattened.each_key.select { |candidate| accepts.call(candidate) }
end
|
|
612
|
+
|
|
613
|
+
# Turns a matcher into a one-argument path predicate.
#
# Regexps are tested against the path, Strings/Symbols must equal it,
# Arrays list exact paths, and Procs are called with the path.
#
# @raise [ArgumentError] for any other matcher type
def __dsl_path_matcher(matcher)
  case matcher
  when Regexp
    ->(path) { matcher.match?(path) }
  when Proc
    ->(path) { matcher.call(path) }
  when Array
    allowed = matcher.map(&:to_s)
    ->(path) { allowed.include?(path) }
  when String, Symbol
    wanted = matcher.to_s
    ->(path) { path == wanted }
  else
    raise ArgumentError, "unsupported matcher: #{matcher.class}"
  end
end
|
|
629
|
+
|
|
630
|
+
# Adds (freeze) or removes (unfreeze) each path's local key in the `no_grad`
# collection of the module that owns it. Paths whose local key is nil or
# empty are skipped.
#
# @return [self]
def __dsl_toggle_paths(paths, freeze:)
  module_map = named_modules.to_h
  action = freeze ? :add : :delete

  paths.each do |path|
    owner, local_key = __dsl_find_module_for_path(path, module_map)
    next if local_key.nil? || local_key.empty?

    # no_grad is reached through __send__, as in the rest of this mixin.
    owner.__send__(:no_grad).public_send(action, local_key)
  end

  self
end
|
|
644
|
+
|
|
645
|
+
# Recursively converts a tree for storage: MLX arrays become
# `{"__mlx_array__" => state}`, Arrays and Hashes are walked (hash keys are
# stringified), and every other value is kept unchanged.
def __dsl_serialize_tree(value)
  case value
  when MLX::Core::Array
    { "__mlx_array__" => value.__getstate__ }
  when Array
    value.map { |entry| __dsl_serialize_tree(entry) }
  when Hash
    value.to_h { |key, entry| [key.to_s, __dsl_serialize_tree(entry)] }
  else
    value
  end
end
|
|
660
|
+
|
|
661
|
+
# Inverse walk of __dsl_serialize_tree: hashes carrying "__mlx_array__" are
# rebuilt into arrays, other Hashes and Arrays are walked recursively (keys
# are kept as they are), and every other value is kept unchanged.
def __dsl_deserialize_tree(value)
  case value
  when Hash
    return __dsl_array_from_state(value.fetch("__mlx_array__")) if value.key?("__mlx_array__")

    value.to_h { |key, entry| [key, __dsl_deserialize_tree(entry)] }
  when Array
    value.map { |entry| __dsl_deserialize_tree(entry) }
  else
    value
  end
end
|
|
676
|
+
|
|
677
|
+
# Rebuilds an MLX array from a serialized state hash. Both string and
# symbol keys are accepted for "values" and "dtype".
#
# The dtype is applied only when MLX::Core responds to a method with the
# stored dtype name; otherwise the array is built from the values alone.
def __dsl_array_from_state(state)
  values = state["values"] || state[:values]
  dtype_name = state["dtype"] || state[:dtype]

  dtype_known = !dtype_name.nil? && MLX::Core.respond_to?(dtype_name.to_sym)
  return MLX::Core.array(values) unless dtype_known

  MLX::Core.array(values, MLX::Core.public_send(dtype_name.to_sym))
end
|
|
686
|
+
|
|
687
|
+
# Finds the module owning a parameter path: the longest non-empty key of
# `module_map` that is a dotted prefix of `path`.
#
# @return [Array] `[module, local_key]`, where local_key is the part of the
#   path after the prefix and its dot; `[self, path]` when no key matches
def __dsl_find_module_for_path(path, module_map)
  owner_prefix = module_map.keys
    .reject { |prefix| prefix.nil? || prefix.empty? || prefix == path }
    .select { |prefix| path.start_with?(prefix + ".") }
    .max_by(&:length)

  return [self, path] if owner_prefix.nil?

  [module_map.fetch(owner_prefix), path[(owner_prefix.length + 1)..]]
end
|
|
704
|
+
end
|
|
705
|
+
end
|
|
706
|
+
end
|