mlx 0.30.7 → 0.30.7.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/native.cpp +0 -4
- data/lib/mlx/core.rb +8 -1
- data/lib/mlx/distributed_utils/launch.rb +9 -3
- data/lib/mlx/dsl/builder.rb +377 -0
- data/lib/mlx/dsl/data_pipeline.rb +284 -0
- data/lib/mlx/dsl/experiment.rb +154 -0
- data/lib/mlx/dsl/graph_modules.rb +91 -0
- data/lib/mlx/dsl/model.rb +9 -0
- data/lib/mlx/dsl/model_mixin.rb +706 -0
- data/lib/mlx/dsl/split_plan.rb +85 -0
- data/lib/mlx/dsl/train_step.rb +197 -0
- data/lib/mlx/dsl/trainer.rb +2110 -0
- data/lib/mlx/dsl.rb +16 -0
- data/lib/mlx/nn/layers/containers.rb +21 -4
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx.rb +1 -0
- data/mlx/CMakeLists.txt +4 -16
- metadata +12 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 215a912d2353fd5edaa60e320a5b857aa13009e55f3a190b21b2ffe5735f37af
+  data.tar.gz: 1c9b4279f8077e3cd067354ea692e1b24248afe7262d88d4569c566c52f5a158
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 66abcbd58ccfc04186df11b0d2b6445c7d1e0ab4a36451742755f6fcf41022363403536c3d27640935364c58231c4c4a03e39ceba97617a59f2ad69acf23dc16
+  data.tar.gz: ba7ad07ccd31e94bdf3fdee73117f069c6ee22c11cdfc3f2470eedee9ee0bc9e976970bab21f156c003b658723a0e3fb2f63f62d0e11d33de3a61fc1d9121711
data/ext/mlx/native.cpp
CHANGED
@@ -7884,9 +7884,6 @@ extern "C" void Init_native(void) {
       "scaled_dot_product_attention",
       RUBY_METHOD_FUNC(core_scaled_dot_product_attention),
       -1);
-  rb_define_singleton_method(
-      mCore, "scaled_dot_product_attention", RUBY_METHOD_FUNC(core_scaled_dot_product_attention), -1);
-  rb_define_singleton_method(mCore, "scaled_dot_product_attention", RUBY_METHOD_FUNC(core_scaled_dot_product_attention), -1);
   rb_define_singleton_method(mCore, "arange", RUBY_METHOD_FUNC(core_arange), -1);
   rb_define_singleton_method(mCore, "linspace", RUBY_METHOD_FUNC(core_linspace), -1);
   rb_define_singleton_method(mCore, "zeros", RUBY_METHOD_FUNC(core_zeros), -1);

@@ -8023,5 +8020,4 @@ extern "C" void Init_native(void) {
       "precompiled_cuda_kernel",
       RUBY_METHOD_FUNC(core_precompiled_cuda_kernel),
       -1);
-  rb_define_singleton_method(mCore, "precompiled_cuda_kernel", RUBY_METHOD_FUNC(core_precompiled_cuda_kernel), -1);
 }
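The removed lines were redundant duplicate registrations: like Ruby's own define_singleton_method, rb_define_singleton_method replaces any earlier definition of the same name, so registering scaled_dot_product_attention three times behaved identically to registering it once. The same overwrite semantics, shown from plain Ruby as a minimal illustration (not the gem's code):

    m = Module.new
    m.define_singleton_method(:answer) { 41 }
    m.define_singleton_method(:answer) { 42 } # replaces the first definition

    m.answer # => 42 (duplicates overwrite; nothing accumulates)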
data/lib/mlx/core.rb
CHANGED
@@ -335,6 +335,12 @@ module MLX
     alias_method :native_export_to_dot,
                  :export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot)

+    %i[savez savez_compressed].each do |method_name|
+      if method_defined?(method_name) && instance_method(method_name).owner == self
+        remove_method(method_name)
+      end
+    end
+
     ARRAY_LEAF = :__mlx_array_leaf__

     def load(file, format = nil, return_metadata = false)

@@ -963,7 +969,8 @@ module MLX
       end
     end

-
+    remove_method(:eql?) if method_defined?(:eql?) && instance_method(:eql?).owner == self
+    alias_method :eql?, :==
   end

   class Array
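Both hunks apply the same idempotent-redefinition guard: a method is removed only when this class itself, rather than an ancestor or the native extension, owns the definition, so reloading the file never raises a NameError and never clobbers a native implementation. A standalone sketch of the pattern (the Tensor class here is purely illustrative):

    class Tensor
      def savez(*args)
        :ruby_fallback # stand-in for an old pure-Ruby definition
      end

      # Drop Ruby-level fallbacks only if Tensor itself defines them; methods
      # owned elsewhere are untouched, and running this twice is a no-op.
      %i[savez savez_compressed].each do |name|
        remove_method(name) if method_defined?(name) && instance_method(name).owner == self
      end

      def ==(other)
        equal?(other) # placeholder equality for the sketch
      end

      # Re-point eql? at ==, guarding the removal the same way.
      remove_method(:eql?) if method_defined?(:eql?) && instance_method(:eql?).owner == self
      alias_method :eql?, :==
    end

    Tensor.method_defined?(:savez) # => false
    Tensor.new.eql?(Tensor.new)    # => false (delegates to ==)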
data/lib/mlx/distributed_utils/launch.rb
CHANGED

@@ -373,12 +373,18 @@ module MLX
     opts[:env] = hostfile.envs + opts[:env]

     command = rest.dup
-
-
+    command_name = command.first.to_s
+    script = Pathname.new(command_name)
+    explicit_path = command_name.include?(File::SEPARATOR) || command_name.start_with?(".", "~")
+
+    if explicit_path && script.file?
       command[0] = opts[:python]
       command.insert(1, script.realpath.to_s)
-    elsif (resolved = find_executable(
+    elsif (resolved = find_executable(command_name))
       command[0] = resolved
+    elsif script.file?
+      command[0] = opts[:python]
+      command.insert(1, script.realpath.to_s)
     elsif opts[:verify_script]
       raise ArgumentError, "Invalid script or command #{command.first}"
     end
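The rewritten branch gives an explicit path (one containing a separator or starting with "." or "~") priority, then falls back to a PATH lookup, and only then treats a bare filename in the working directory as a script, raising only when verify_script is set. A self-contained sketch of that resolution order (resolve_command and this find_executable stand-in are illustrative, not the gem's API):

    require "pathname"

    # Stand-in for the launcher's executable lookup: first match on PATH.
    def find_executable(name)
      ENV.fetch("PATH", "").split(File::PATH_SEPARATOR)
         .map { |dir| File.join(dir, name) }
         .find { |path| File.file?(path) && File.executable?(path) }
    end

    def resolve_command(command, python: "python3", verify_script: true)
      command = command.dup
      command_name = command.first.to_s
      script = Pathname.new(command_name)
      explicit_path = command_name.include?(File::SEPARATOR) || command_name.start_with?(".", "~")

      if explicit_path && script.file?                 # e.g. "./train.py"
        command[0] = python
        command.insert(1, script.realpath.to_s)
      elsif (resolved = find_executable(command_name)) # e.g. "hostname"
        command[0] = resolved
      elsif script.file?                               # e.g. "train.py" in cwd
        command[0] = python
        command.insert(1, script.realpath.to_s)
      elsif verify_script
        raise ArgumentError, "Invalid script or command #{command.first}"
      end

      command
    end

    resolve_command(["./train.py", "--epochs", "3"])
    # => ["python3", "<realpath of train.py>", "--epochs", "3"]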
data/lib/mlx/dsl/builder.rb
ADDED

@@ -0,0 +1,377 @@
+# frozen_string_literal: true
+
+module MLX
+  module DSL
+    class Builder
+      def initialize(owner = nil)
+        @owner = owner
+        @collector = nil
+      end
+
+      def build(&block)
+        raise ArgumentError, "builder requires a block" unless block_given?
+
+        instance_eval(&block)
+      end
+
+      def sequential(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        push(MLX::NN::Sequential.new(*collected))
+      end
+
+      def layer(entry = nil, *args, **kwargs, &block)
+        if !entry.nil? && block_given?
+          raise ArgumentError, "layer accepts either a module entry or block, not both"
+        end
+
+        if block_given?
+          return push(MLX::DSL::Callable.new(&block))
+        end
+
+        if entry.nil?
+          raise ArgumentError, "layer requires a module entry or block"
+        end
+
+        if entry.is_a?(MLX::NN::Module)
+          __dsl_reject_layer_constructor_args!(args, kwargs, entry.class)
+          return push(entry)
+        end
+
+        if entry.is_a?(Class)
+          unless entry <= MLX::NN::Module
+            raise TypeError, "layer class must inherit from MLX::NN::Module"
+          end
+          return push(entry.new(*args, **kwargs))
+        end
+
+        if entry.respond_to?(:call)
+          __dsl_reject_layer_constructor_args!(args, kwargs, entry.class)
+          return push(MLX::DSL::Callable.new(entry))
+        end
+
+        raise TypeError, "layer requires an MLX::NN::Module instance, MLX::NN::Module class, callable, or block"
+      end
+
+      def residual(module_obj = nil, &block)
+        modules = __dsl_modules_from(module_obj.nil? ? [] : [module_obj], &block)
+        raise ArgumentError, "residual requires at least one module" if modules.empty?
+
+        target = if modules.length == 1
+          modules[0]
+        else
+          MLX::NN::Sequential.new(*modules)
+        end
+        push(MLX::DSL::Residual.new(target))
+      end
+
+      def branch(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "branch requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Parallel.new(*collected))
+      end
+
+      def concat(*modules, axis: -1, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "concat requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Concat.new(*collected, axis: axis))
+      end
+
+      def sum(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "sum requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Reduce.new(*collected, mode: :sum))
+      end
+
+      def fn(callable = nil, &block)
+        push(MLX::DSL::Callable.new(callable, &block))
+      end
+      alias_method :lambda_layer, :fn
+
+      def repeat_layers(count, &block)
+        entries = __dsl_collect_repeated_entries(count, &block)
+        layers = entries.map { |entry| __dsl_normalize_module_entry(entry) }
+        layers.each { |layer| push(layer) }
+        layers
+      end
+
+      def stack(count, layer_class = nil, *args, **kwargs, &block)
+        if !layer_class.nil? && block_given?
+          raise ArgumentError, "stack accepts either a layer class or block, not both"
+        end
+
+        layers = if layer_class.nil?
+          __dsl_collect_repeated_entries(count, &block).map { |entry| __dsl_normalize_module_entry(entry) }
+        else
+          __dsl_build_class_stack_layers(count, layer_class, args, kwargs)
+        end
+        push(MLX::NN::Sequential.new(*layers))
+      end
+
+      def identity(*args, **kwargs)
+        push(MLX::NN::Identity.new(*args, **kwargs))
+      end
+
+      def embedding(*args, **kwargs)
+        push(MLX::NN::Embedding.new(*args, **kwargs))
+      end
+
+      def linear(*args, **kwargs)
+        push(MLX::NN::Linear.new(*args, **kwargs))
+      end
+
+      def bilinear(*args, **kwargs)
+        push(MLX::NN::Bilinear.new(*args, **kwargs))
+      end
+
+      def relu
+        push(MLX::NN::ReLU.new)
+      end
+
+      def relu6
+        push(MLX::NN::ReLU6.new)
+      end
+
+      def leaky_relu(*args)
+        push(MLX::NN::LeakyReLU.new(*args))
+      end
+
+      def gelu(*args, **kwargs)
+        push(MLX::NN::GELU.new(*args, **kwargs))
+      end
+
+      def tanh
+        push(MLX::NN::Tanh.new)
+      end
+
+      def sigmoid
+        push(MLX::NN::Sigmoid.new)
+      end
+
+      def dropout(*args)
+        push(MLX::NN::Dropout.new(*args))
+      end
+
+      def dropout2d(*args)
+        push(MLX::NN::Dropout2d.new(*args))
+      end
+
+      def dropout3d(*args)
+        push(MLX::NN::Dropout3d.new(*args))
+      end
+
+      def conv1d(*args, **kwargs)
+        push(MLX::NN::Conv1d.new(*args, **kwargs))
+      end
+
+      def conv2d(*args, **kwargs)
+        push(MLX::NN::Conv2d.new(*args, **kwargs))
+      end
+
+      def conv3d(*args, **kwargs)
+        push(MLX::NN::Conv3d.new(*args, **kwargs))
+      end
+
+      def conv_transpose1d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose1d.new(*args, **kwargs))
+      end
+
+      def conv_transpose2d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose2d.new(*args, **kwargs))
+      end
+
+      def conv_transpose3d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose3d.new(*args, **kwargs))
+      end
+
+      def layer_norm(*args, **kwargs)
+        push(MLX::NN::LayerNorm.new(*args, **kwargs))
+      end
+
+      def rms_norm(*args, **kwargs)
+        push(MLX::NN::RMSNorm.new(*args, **kwargs))
+      end
+
+      def batch_norm(*args, **kwargs)
+        push(MLX::NN::BatchNorm.new(*args, **kwargs))
+      end
+
+      def instance_norm(*args, **kwargs)
+        push(MLX::NN::InstanceNorm.new(*args, **kwargs))
+      end
+
+      def group_norm(*args, **kwargs)
+        push(MLX::NN::GroupNorm.new(*args, **kwargs))
+      end
+
+      def max_pool2d(*args, **kwargs)
+        push(MLX::NN::MaxPool2d.new(*args, **kwargs))
+      end
+
+      def avg_pool2d(*args, **kwargs)
+        push(MLX::NN::AvgPool2d.new(*args, **kwargs))
+      end
+
+      def max_pool1d(*args, **kwargs)
+        push(MLX::NN::MaxPool1d.new(*args, **kwargs))
+      end
+
+      def avg_pool1d(*args, **kwargs)
+        push(MLX::NN::AvgPool1d.new(*args, **kwargs))
+      end
+
+      def max_pool3d(*args, **kwargs)
+        push(MLX::NN::MaxPool3d.new(*args, **kwargs))
+      end
+
+      def avg_pool3d(*args, **kwargs)
+        push(MLX::NN::AvgPool3d.new(*args, **kwargs))
+      end
+
+      def rnn(*args, **kwargs)
+        push(MLX::NN::RNN.new(*args, **kwargs))
+      end
+
+      def gru(*args, **kwargs)
+        push(MLX::NN::GRU.new(*args, **kwargs))
+      end
+
+      def lstm(*args, **kwargs)
+        push(MLX::NN::LSTM.new(*args, **kwargs))
+      end
+
+      def multi_head_attention(*args, **kwargs)
+        push(MLX::NN::MultiHeadAttention.new(*args, **kwargs))
+      end
+
+      def transformer_encoder_layer(*args, **kwargs)
+        push(MLX::NN::TransformerEncoderLayer.new(*args, **kwargs))
+      end
+
+      def transformer_encoder(*args, **kwargs)
+        push(MLX::NN::TransformerEncoder.new(*args, **kwargs))
+      end
+
+      def transformer_decoder_layer(*args, **kwargs)
+        push(MLX::NN::TransformerDecoderLayer.new(*args, **kwargs))
+      end
+
+      def transformer_decoder(*args, **kwargs)
+        push(MLX::NN::TransformerDecoder.new(*args, **kwargs))
+      end
+
+      def transformer(*args, **kwargs)
+        push(MLX::NN::Transformer.new(*args, **kwargs))
+      end
+
+      def rope(*args, **kwargs)
+        push(MLX::NN::RoPE.new(*args, **kwargs))
+      end
+
+      def sinusoidal_positional_encoding(*args, **kwargs)
+        push(MLX::NN::SinusoidalPositionalEncoding.new(*args, **kwargs))
+      end
+
+      def alibi(*args, **kwargs)
+        push(MLX::NN::ALiBi.new(*args, **kwargs))
+      end
+
+      def upsample(*args, **kwargs)
+        push(MLX::NN::Upsample.new(*args, **kwargs))
+      end
+
+      def method_missing(name, *args, **kwargs, &block)
+        if !@owner.nil? && @owner.respond_to?(name)
+          @owner.public_send(name, *args, **kwargs, &block)
+        else
+          super
+        end
+      end
+
+      def respond_to_missing?(name, include_private = false)
+        (!@owner.nil? && @owner.respond_to?(name, include_private)) || super
+      end
+
+      private
+
+      def collect_modules(&block)
+        previous = @collector
+        @collector = []
+        returned = instance_eval(&block)
+        collected = @collector.dup
+        if collected.empty? && !returned.nil?
+          collected << returned
+        end
+        collected
+      ensure
+        @collector = previous
+      end
+
+      def push(module_obj)
+        @collector << module_obj unless @collector.nil?
+        module_obj
+      end
+
+      def __dsl_modules_from(existing, &block)
+        out = existing.dup
+        out.concat(collect_modules(&block)) if block_given?
+        out.map { |entry| __dsl_normalize_module_entry(entry) }
+      end
+
+      def __dsl_normalize_module_entry(entry)
+        return entry if entry.is_a?(MLX::NN::Module)
+
+        if entry.is_a?(Class)
+          return entry.new if entry <= MLX::NN::Module
+
+          raise TypeError, "builder entries must be MLX::NN::Module instances, MLX::NN::Module classes, or callables"
+        end
+
+        return MLX::DSL::Callable.new(entry) if entry.respond_to?(:call)
+
+        raise TypeError, "builder entries must be MLX::NN::Module instances, MLX::NN::Module classes, or callables"
+      end
+
+      def __dsl_reject_layer_constructor_args!(args, kwargs, entry_type)
+        return if args.empty? && kwargs.empty?
+
+        raise ArgumentError, "layer entry #{entry_type} does not accept constructor arguments"
+      end
+
+      def __dsl_collect_repeated_entries(count, &block)
+        raise ArgumentError, "repeat requires a block" unless block_given?
+
+        repeats = count.to_i
+        raise ArgumentError, "repeat count must be non-negative" if repeats.negative?
+
+        out = []
+        repeats.times do |index|
+          out.concat(
+            collect_modules do
+              __dsl_call_repeat_block(block, index)
+            end
+          )
+        end
+        out
+      end
+
+      def __dsl_call_repeat_block(block, index)
+        return instance_eval(&block) if block.arity.zero?
+
+        block.call(index)
+      end
+
+      def __dsl_build_class_stack_layers(count, layer_class, args, kwargs)
+        repeats = count.to_i
+        raise ArgumentError, "stack count must be non-negative" if repeats.negative?
+        unless layer_class.is_a?(Class) && layer_class <= MLX::NN::Module
+          raise TypeError, "stack layer class must inherit from MLX::NN::Module"
+        end
+
+        Array.new(repeats) { layer_class.new(*args, **kwargs) }
+      end
+    end
+  end
+end
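For orientation, a usage sketch of the new builder: sequential collects every module pushed inside its block via collect_modules and returns the wrapping MLX::NN::Sequential, which build then returns as the block's value. Layer sizes here are arbitrary, and the Linear.new(input_dims, output_dims) argument order is assumed to mirror Python's mlx.nn.Linear.

    require "mlx"

    model = MLX::DSL::Builder.new.build do
      sequential do
        linear(784, 256)              # each helper pushes a module into the collector
        relu
        dropout(0.1)
        residual { linear(256, 256) } # wraps the inner module in MLX::DSL::Residual
        linear(256, 10)
      end
    end

Because method_missing forwards unknown calls to the optional owner, the same block can also reference helper methods defined on a surrounding model object.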