mlx 0.30.7 → 0.30.7.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: ce770b59d71e1d5fbf8697bfebde05006fa52a770f8526575f94944cf161d05b
4
- data.tar.gz: 814aa4e6c063b3b36c3b1d3c4c26d8f5fa3952a66947d294a8edb3864b07ece2
3
+ metadata.gz: 215a912d2353fd5edaa60e320a5b857aa13009e55f3a190b21b2ffe5735f37af
4
+ data.tar.gz: 1c9b4279f8077e3cd067354ea692e1b24248afe7262d88d4569c566c52f5a158
5
5
  SHA512:
6
- metadata.gz: 3702f6d4445ea4af978ffcebb71dcef91f8e28d4f4d79a1d16a49b66848023a84b7b8ff2a614fa7b2c236d324dfb92b8ccb01e70ff59c53234ac28ee6fc09b39
7
- data.tar.gz: 52bf528fee068f422dac611fdca22fae8089a7ea9a3a4d0cd42a21aecc53b5219b0dcf83d79f5fa907a1c99f97acb831d952474963071b566b3483e91c4a4a72
6
+ metadata.gz: 66abcbd58ccfc04186df11b0d2b6445c7d1e0ab4a36451742755f6fcf41022363403536c3d27640935364c58231c4c4a03e39ceba97617a59f2ad69acf23dc16
7
+ data.tar.gz: ba7ad07ccd31e94bdf3fdee73117f069c6ee22c11cdfc3f2470eedee9ee0bc9e976970bab21f156c003b658723a0e3fb2f63f62d0e11d33de3a61fc1d9121711
data/ext/mlx/native.cpp CHANGED
@@ -7884,9 +7884,6 @@ extern "C" void Init_native(void) {
7884
7884
  "scaled_dot_product_attention",
7885
7885
  RUBY_METHOD_FUNC(core_scaled_dot_product_attention),
7886
7886
  -1);
7887
- rb_define_singleton_method(
7888
- mCore, "scaled_dot_product_attention", RUBY_METHOD_FUNC(core_scaled_dot_product_attention), -1);
7889
- rb_define_singleton_method(mCore, "scaled_dot_product_attention", RUBY_METHOD_FUNC(core_scaled_dot_product_attention), -1);
7890
7887
  rb_define_singleton_method(mCore, "arange", RUBY_METHOD_FUNC(core_arange), -1);
7891
7888
  rb_define_singleton_method(mCore, "linspace", RUBY_METHOD_FUNC(core_linspace), -1);
7892
7889
  rb_define_singleton_method(mCore, "zeros", RUBY_METHOD_FUNC(core_zeros), -1);
@@ -8023,5 +8020,4 @@ extern "C" void Init_native(void) {
8023
8020
  "precompiled_cuda_kernel",
8024
8021
  RUBY_METHOD_FUNC(core_precompiled_cuda_kernel),
8025
8022
  -1);
8026
- rb_define_singleton_method(mCore, "precompiled_cuda_kernel", RUBY_METHOD_FUNC(core_precompiled_cuda_kernel), -1);
8027
8023
  }
data/lib/mlx/core.rb CHANGED
@@ -335,6 +335,12 @@ module MLX
335
335
  alias_method :native_export_to_dot,
336
336
  :export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot)
337
337
 
338
+ %i[savez savez_compressed].each do |method_name|
339
+ if method_defined?(method_name) && instance_method(method_name).owner == self
340
+ remove_method(method_name)
341
+ end
342
+ end
343
+
338
344
  ARRAY_LEAF = :__mlx_array_leaf__
339
345
 
340
346
  def load(file, format = nil, return_metadata = false)
@@ -963,7 +969,8 @@ module MLX
963
969
  end
964
970
  end
965
971
 
966
- alias eql? ==
972
+ remove_method(:eql?) if method_defined?(:eql?) && instance_method(:eql?).owner == self
973
+ alias_method :eql?, :==
967
974
  end
968
975
 
969
976
  class Array
@@ -373,12 +373,18 @@ module MLX
373
373
  opts[:env] = hostfile.envs + opts[:env]
374
374
 
375
375
  command = rest.dup
376
- script = Pathname.new(command.first)
377
- if script.file?
376
+ command_name = command.first.to_s
377
+ script = Pathname.new(command_name)
378
+ explicit_path = command_name.include?(File::SEPARATOR) || command_name.start_with?(".", "~")
379
+
380
+ if explicit_path && script.file?
378
381
  command[0] = opts[:python]
379
382
  command.insert(1, script.realpath.to_s)
380
- elsif (resolved = find_executable(command.first))
383
+ elsif (resolved = find_executable(command_name))
381
384
  command[0] = resolved
385
+ elsif script.file?
386
+ command[0] = opts[:python]
387
+ command.insert(1, script.realpath.to_s)
382
388
  elsif opts[:verify_script]
383
389
  raise ArgumentError, "Invalid script or command #{command.first}"
384
390
  end
@@ -0,0 +1,377 @@
1
# frozen_string_literal: true

module MLX
  module DSL
    # Block-based builder for composing MLX::NN module graphs.
    #
    # Builder methods such as +linear+ or +relu+ construct a module and "push"
    # it. Inside a collecting block (+sequential+, +residual+, +branch+,
    # +concat+, +sum+, +repeat_layers+, +stack+), pushed modules are gathered
    # into the enclosing container. At the top level, the module is simply
    # returned. Calls the builder does not define are forwarded to +owner+ when
    # the owner publicly responds to them.
    class Builder
      # @param owner [Object, nil] receiver for calls the builder does not
      #   handle itself (see #method_missing)
      def initialize(owner = nil)
        @owner = owner
        # nil outside a collecting block; otherwise the list that #push appends to.
        @collector = nil
      end

      # Evaluates +block+ with the builder as +self+.
      #
      # @return [Object] the block's last value
      # @raise [ArgumentError] if no block is given
      def build(&block)
        raise ArgumentError, "builder requires a block" unless block_given?

        instance_eval(&block)
      end

      # Pushes an MLX::NN::Sequential made of +modules+ followed by any modules
      # collected from +block+.
      def sequential(*modules, &block)
        collected = __dsl_modules_from(modules, &block)
        push(MLX::NN::Sequential.new(*collected))
      end

      # Pushes a single layer.
      #
      # +entry+ may be:
      # - an MLX::NN::Module instance;
      # - an MLX::NN::Module subclass, instantiated with +args+/+kwargs+;
      # - any object responding to +call+, wrapped in MLX::DSL::Callable.
      #
      # Alternatively, pass a block (and no +entry+); it is wrapped in
      # MLX::DSL::Callable.
      #
      # @raise [ArgumentError] if both or neither of +entry+ and a block are
      #   given, or if constructor arguments accompany a non-class entry
      # @raise [TypeError] if +entry+ is unsupported, or is a class that does
      #   not inherit from MLX::NN::Module
      def layer(entry = nil, *args, **kwargs, &block)
        if !entry.nil? && block_given?
          raise ArgumentError, "layer accepts either a module entry or block, not both"
        end

        return push(MLX::DSL::Callable.new(&block)) if block_given?

        raise ArgumentError, "layer requires a module entry or block" if entry.nil?

        if entry.is_a?(MLX::NN::Module)
          __dsl_reject_layer_constructor_args!(args, kwargs, entry.class)
          return push(entry)
        end

        if entry.is_a?(Class)
          raise TypeError, "layer class must inherit from MLX::NN::Module" unless entry <= MLX::NN::Module

          return push(entry.new(*args, **kwargs))
        end

        if entry.respond_to?(:call)
          __dsl_reject_layer_constructor_args!(args, kwargs, entry.class)
          return push(MLX::DSL::Callable.new(entry))
        end

        raise TypeError, "layer requires an MLX::NN::Module instance, MLX::NN::Module class, callable, or block"
      end

      # Pushes an MLX::DSL::Residual wrapping +module_obj+ and/or the block's
      # modules. Several modules are first combined into a Sequential.
      #
      # @raise [ArgumentError] if no module is supplied
      def residual(module_obj = nil, &block)
        modules = __dsl_modules_from(module_obj.nil? ? [] : [module_obj], &block)
        raise ArgumentError, "residual requires at least one module" if modules.empty?

        target = modules.length == 1 ? modules[0] : MLX::NN::Sequential.new(*modules)
        push(MLX::DSL::Residual.new(target))
      end

      # Pushes an MLX::DSL::Parallel over the given and collected modules.
      #
      # @raise [ArgumentError] if no module is supplied
      def branch(*modules, &block)
        collected = __dsl_modules_from(modules, &block)
        raise ArgumentError, "branch requires at least one module" if collected.empty?

        push(MLX::DSL::Parallel.new(*collected))
      end

      # Pushes an MLX::DSL::Concat over the given and collected modules,
      # forwarding +axis+.
      #
      # @raise [ArgumentError] if no module is supplied
      def concat(*modules, axis: -1, &block)
        collected = __dsl_modules_from(modules, &block)
        raise ArgumentError, "concat requires at least one module" if collected.empty?

        push(MLX::DSL::Concat.new(*collected, axis: axis))
      end

      # Pushes an MLX::DSL::Reduce (mode: :sum) over the given and collected
      # modules.
      #
      # @raise [ArgumentError] if no module is supplied
      def sum(*modules, &block)
        collected = __dsl_modules_from(modules, &block)
        raise ArgumentError, "sum requires at least one module" if collected.empty?

        push(MLX::DSL::Reduce.new(*collected, mode: :sum))
      end

      # Pushes an MLX::DSL::Callable built from +callable+ and/or +block+.
      def fn(callable = nil, &block)
        push(MLX::DSL::Callable.new(callable, &block))
      end
      alias_method :lambda_layer, :fn

      # Runs +block+ +count+ times and pushes every module it produces.
      #
      # A block with no parameters is evaluated with the builder as +self+.
      # A block that declares parameters is called with the zero-based
      # iteration index.
      #
      # @return [Array] the normalized layers that were pushed
      # @raise [ArgumentError] if no block is given or +count+ is negative
      def repeat_layers(count, &block)
        entries = __dsl_collect_repeated_entries(count, &block)
        layers = entries.map { |entry| __dsl_normalize_module_entry(entry) }
        layers.each { |layer| push(layer) }
        layers
      end

      # Pushes an MLX::NN::Sequential of +count+ layers.
      #
      # The layers are either +count+ fresh instances of +layer_class+ (built
      # with +args+/+kwargs+), or the modules produced by running +block+
      # +count+ times (same block rules as #repeat_layers).
      #
      # @raise [ArgumentError] if both or neither of +layer_class+ and a block
      #   are given, or if +count+ is negative
      # @raise [TypeError] if +layer_class+ does not inherit from MLX::NN::Module
      def stack(count, layer_class = nil, *args, **kwargs, &block)
        if !layer_class.nil? && block_given?
          raise ArgumentError, "stack accepts either a layer class or block, not both"
        end
        if layer_class.nil? && !block_given?
          raise ArgumentError, "stack requires a layer class or block"
        end

        layers = if layer_class.nil?
                   __dsl_collect_repeated_entries(count, "stack", &block).map do |entry|
                     __dsl_normalize_module_entry(entry)
                   end
                 else
                   __dsl_build_class_stack_layers(count, layer_class, args, kwargs)
                 end
        push(MLX::NN::Sequential.new(*layers))
      end

      # Layer shortcuts. Each one instantiates the named MLX::NN class with the
      # given arguments unchanged and pushes the result.

      def identity(*args, **kwargs)
        push(MLX::NN::Identity.new(*args, **kwargs))
      end

      def embedding(*args, **kwargs)
        push(MLX::NN::Embedding.new(*args, **kwargs))
      end

      def linear(*args, **kwargs)
        push(MLX::NN::Linear.new(*args, **kwargs))
      end

      def bilinear(*args, **kwargs)
        push(MLX::NN::Bilinear.new(*args, **kwargs))
      end

      def relu
        push(MLX::NN::ReLU.new)
      end

      def relu6
        push(MLX::NN::ReLU6.new)
      end

      def leaky_relu(*args)
        push(MLX::NN::LeakyReLU.new(*args))
      end

      def gelu(*args, **kwargs)
        push(MLX::NN::GELU.new(*args, **kwargs))
      end

      def tanh
        push(MLX::NN::Tanh.new)
      end

      def sigmoid
        push(MLX::NN::Sigmoid.new)
      end

      def dropout(*args)
        push(MLX::NN::Dropout.new(*args))
      end

      def dropout2d(*args)
        push(MLX::NN::Dropout2d.new(*args))
      end

      def dropout3d(*args)
        push(MLX::NN::Dropout3d.new(*args))
      end

      def conv1d(*args, **kwargs)
        push(MLX::NN::Conv1d.new(*args, **kwargs))
      end

      def conv2d(*args, **kwargs)
        push(MLX::NN::Conv2d.new(*args, **kwargs))
      end

      def conv3d(*args, **kwargs)
        push(MLX::NN::Conv3d.new(*args, **kwargs))
      end

      def conv_transpose1d(*args, **kwargs)
        push(MLX::NN::ConvTranspose1d.new(*args, **kwargs))
      end

      def conv_transpose2d(*args, **kwargs)
        push(MLX::NN::ConvTranspose2d.new(*args, **kwargs))
      end

      def conv_transpose3d(*args, **kwargs)
        push(MLX::NN::ConvTranspose3d.new(*args, **kwargs))
      end

      def layer_norm(*args, **kwargs)
        push(MLX::NN::LayerNorm.new(*args, **kwargs))
      end

      def rms_norm(*args, **kwargs)
        push(MLX::NN::RMSNorm.new(*args, **kwargs))
      end

      def batch_norm(*args, **kwargs)
        push(MLX::NN::BatchNorm.new(*args, **kwargs))
      end

      def instance_norm(*args, **kwargs)
        push(MLX::NN::InstanceNorm.new(*args, **kwargs))
      end

      def group_norm(*args, **kwargs)
        push(MLX::NN::GroupNorm.new(*args, **kwargs))
      end

      def max_pool2d(*args, **kwargs)
        push(MLX::NN::MaxPool2d.new(*args, **kwargs))
      end

      def avg_pool2d(*args, **kwargs)
        push(MLX::NN::AvgPool2d.new(*args, **kwargs))
      end

      def max_pool1d(*args, **kwargs)
        push(MLX::NN::MaxPool1d.new(*args, **kwargs))
      end

      def avg_pool1d(*args, **kwargs)
        push(MLX::NN::AvgPool1d.new(*args, **kwargs))
      end

      def max_pool3d(*args, **kwargs)
        push(MLX::NN::MaxPool3d.new(*args, **kwargs))
      end

      def avg_pool3d(*args, **kwargs)
        push(MLX::NN::AvgPool3d.new(*args, **kwargs))
      end

      def rnn(*args, **kwargs)
        push(MLX::NN::RNN.new(*args, **kwargs))
      end

      def gru(*args, **kwargs)
        push(MLX::NN::GRU.new(*args, **kwargs))
      end

      def lstm(*args, **kwargs)
        push(MLX::NN::LSTM.new(*args, **kwargs))
      end

      def multi_head_attention(*args, **kwargs)
        push(MLX::NN::MultiHeadAttention.new(*args, **kwargs))
      end

      def transformer_encoder_layer(*args, **kwargs)
        push(MLX::NN::TransformerEncoderLayer.new(*args, **kwargs))
      end

      def transformer_encoder(*args, **kwargs)
        push(MLX::NN::TransformerEncoder.new(*args, **kwargs))
      end

      def transformer_decoder_layer(*args, **kwargs)
        push(MLX::NN::TransformerDecoderLayer.new(*args, **kwargs))
      end

      def transformer_decoder(*args, **kwargs)
        push(MLX::NN::TransformerDecoder.new(*args, **kwargs))
      end

      def transformer(*args, **kwargs)
        push(MLX::NN::Transformer.new(*args, **kwargs))
      end

      def rope(*args, **kwargs)
        push(MLX::NN::RoPE.new(*args, **kwargs))
      end

      def sinusoidal_positional_encoding(*args, **kwargs)
        push(MLX::NN::SinusoidalPositionalEncoding.new(*args, **kwargs))
      end

      def alibi(*args, **kwargs)
        push(MLX::NN::ALiBi.new(*args, **kwargs))
      end

      def upsample(*args, **kwargs)
        push(MLX::NN::Upsample.new(*args, **kwargs))
      end

      # Forwards calls the builder does not define to +owner+, but only when
      # the owner publicly responds to them (dispatch uses public_send).
      def method_missing(name, *args, **kwargs, &block)
        if !@owner.nil? && @owner.respond_to?(name)
          @owner.public_send(name, *args, **kwargs, &block)
        else
          super
        end
      end

      # Mirrors #method_missing: only the owner's *public* methods are
      # reported, even when +include_private+ is true. Otherwise
      # respond_to?(name, true) would promise a private owner method that
      # method_missing's public_send cannot reach.
      def respond_to_missing?(name, include_private = false)
        (!@owner.nil? && @owner.respond_to?(name)) || super
      end

      private

      # Evaluates +block+ with the builder as +self+ and returns the modules
      # pushed during it. If nothing was pushed, the block's non-nil return
      # value is used as the single entry. The previous collector is restored
      # afterwards, so collecting blocks can nest.
      def collect_modules(&block)
        previous = @collector
        @collector = []
        returned = instance_eval(&block)
        collected = @collector.dup
        collected << returned if collected.empty? && !returned.nil?
        collected
      ensure
        @collector = previous
      end

      # Appends +module_obj+ to the active collector (if any) and returns it.
      def push(module_obj)
        @collector << module_obj unless @collector.nil?
        module_obj
      end

      # Concatenates +existing+ with the block's collected modules and
      # normalizes every entry.
      def __dsl_modules_from(existing, &block)
        out = existing.dup
        out.concat(collect_modules(&block)) if block_given?
        out.map { |entry| __dsl_normalize_module_entry(entry) }
      end

      # Converts a builder entry into a module. Instances pass through,
      # MLX::NN::Module subclasses are instantiated with no arguments, and
      # callables are wrapped in MLX::DSL::Callable.
      #
      # @raise [TypeError] for any other entry
      def __dsl_normalize_module_entry(entry)
        return entry if entry.is_a?(MLX::NN::Module)

        if entry.is_a?(Class)
          return entry.new if entry <= MLX::NN::Module

          raise TypeError, "builder entries must be MLX::NN::Module instances, MLX::NN::Module classes, or callables"
        end

        return MLX::DSL::Callable.new(entry) if entry.respond_to?(:call)

        raise TypeError, "builder entries must be MLX::NN::Module instances, MLX::NN::Module classes, or callables"
      end

      # @raise [ArgumentError] if any constructor arguments were supplied
      def __dsl_reject_layer_constructor_args!(args, kwargs, entry_type)
        return if args.empty? && kwargs.empty?

        raise ArgumentError, "layer entry #{entry_type} does not accept constructor arguments"
      end

      # Runs +block+ +count+ times (each run in its own collecting scope) and
      # concatenates the entries. +label+ names the calling API in error
      # messages.
      #
      # @raise [ArgumentError] if no block is given or +count+ is negative
      def __dsl_collect_repeated_entries(count, label = "repeat", &block)
        raise ArgumentError, "#{label} requires a block" unless block_given?

        repeats = count.to_i
        raise ArgumentError, "#{label} count must be non-negative" if repeats.negative?

        repeats.times.flat_map do |index|
          collect_modules { __dsl_call_repeat_block(block, index) }
        end
      end

      # Zero-arity blocks run with the builder as +self+; other blocks receive
      # the iteration index via a plain call.
      def __dsl_call_repeat_block(block, index)
        # instance_exec with no arguments: instance_eval would pass the receiver
        # as a block argument, which a zero-arity lambda rejects.
        return instance_exec(&block) if block.arity.zero?

        block.call(index)
      end

      # Builds +count+ independent instances of +layer_class+.
      #
      # @raise [ArgumentError] if +count+ is negative
      # @raise [TypeError] if +layer_class+ does not inherit from MLX::NN::Module
      def __dsl_build_class_stack_layers(count, layer_class, args, kwargs)
        repeats = count.to_i
        raise ArgumentError, "stack count must be non-negative" if repeats.negative?
        unless layer_class.is_a?(Class) && layer_class <= MLX::NN::Module
          raise TypeError, "stack layer class must inherit from MLX::NN::Module"
        end

        # Deliberately not `Array.new(repeats)`: inside `module MLX` the bare
        # constant `Array` can resolve to the MLX array class instead of ::Array.
        repeats.times.map { layer_class.new(*args, **kwargs) }
      end
    end
  end
end