mlx 0.30.7 → 0.30.7.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: ce770b59d71e1d5fbf8697bfebde05006fa52a770f8526575f94944cf161d05b
-  data.tar.gz: 814aa4e6c063b3b36c3b1d3c4c26d8f5fa3952a66947d294a8edb3864b07ece2
+  metadata.gz: 25d582e4816d69b27713a4027534b75cd00ca72557e69681daf07146d3e79ef2
+  data.tar.gz: c010252aa355370a531fa4f3b9bf8cc729876d2f7fb9ae8b8e0d6a1eb6cb57c4
 SHA512:
-  metadata.gz: 3702f6d4445ea4af978ffcebb71dcef91f8e28d4f4d79a1d16a49b66848023a84b7b8ff2a614fa7b2c236d324dfb92b8ccb01e70ff59c53234ac28ee6fc09b39
-  data.tar.gz: 52bf528fee068f422dac611fdca22fae8089a7ea9a3a4d0cd42a21aecc53b5219b0dcf83d79f5fa907a1c99f97acb831d952474963071b566b3483e91c4a4a72
+  metadata.gz: 53e629e845342f173c04c7c6d9d976a29dd5492ae945239897d3168a586288ec958ba58753345317f301220ac5f4b91a22f97731ab799fcea5d59f3d19e48214
+  data.tar.gz: 5b04f2e63e3dcdb6a0282184a310600f4fb72e606b45e8e7a27a7b9461abef3a598afe2eb5525e66457da45b8a84a6c2f07c871c955e9cddd4308971496c7fd1
data/ext/mlx/native.cpp CHANGED
@@ -6625,7 +6625,8 @@ static VALUE core_clear_cache(VALUE) {
 
 static VALUE core_metal_is_available(VALUE) {
   try {
-    return mxmetal::is_available() ? Qtrue : Qfalse;
+    const mx::Device gpu_device(mx::Device::gpu, 0);
+    return mx::is_available(gpu_device) ? Qtrue : Qfalse;
   } catch (const std::exception& error) {
     raise_std_exception(error);
     return Qnil;
@@ -6654,7 +6655,12 @@ static VALUE core_metal_stop_capture(VALUE) {
 
 static VALUE core_metal_device_info(VALUE) {
   try {
-    const auto& info = mxmetal::device_info();
+    const mx::Device gpu_device(mx::Device::gpu, 0);
+    if (!mx::is_available(gpu_device)) {
+      rb_raise(rb_eRuntimeError, "[metal_device_info] Metal GPU device is not available");
+    }
+
+    const auto& info = mx::device_info(gpu_device);
     VALUE hash = rb_hash_new();
     for (const auto& [key, value] : info) {
       VALUE ruby_key = rb_utf8_str_new(key.c_str(), static_cast<long>(key.size()));
@@ -7884,9 +7890,6 @@ extern "C" void Init_native(void) {
       "scaled_dot_product_attention",
       RUBY_METHOD_FUNC(core_scaled_dot_product_attention),
       -1);
-  rb_define_singleton_method(
-      mCore, "scaled_dot_product_attention", RUBY_METHOD_FUNC(core_scaled_dot_product_attention), -1);
-  rb_define_singleton_method(mCore, "scaled_dot_product_attention", RUBY_METHOD_FUNC(core_scaled_dot_product_attention), -1);
   rb_define_singleton_method(mCore, "arange", RUBY_METHOD_FUNC(core_arange), -1);
   rb_define_singleton_method(mCore, "linspace", RUBY_METHOD_FUNC(core_linspace), -1);
   rb_define_singleton_method(mCore, "zeros", RUBY_METHOD_FUNC(core_zeros), -1);
@@ -8023,5 +8026,4 @@ extern "C" void Init_native(void) {
       "precompiled_cuda_kernel",
       RUBY_METHOD_FUNC(core_precompiled_cuda_kernel),
       -1);
-  rb_define_singleton_method(mCore, "precompiled_cuda_kernel", RUBY_METHOD_FUNC(core_precompiled_cuda_kernel), -1);
 }
data/lib/mlx/core.rb CHANGED
@@ -335,6 +335,12 @@ module MLX
     alias_method :native_export_to_dot,
                  :export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot)
 
+    %i[savez savez_compressed].each do |method_name|
+      if method_defined?(method_name) && instance_method(method_name).owner == self
+        remove_method(method_name)
+      end
+    end
+
     ARRAY_LEAF = :__mlx_array_leaf__
 
     def load(file, format = nil, return_metadata = false)
@@ -963,7 +969,8 @@ module MLX
       end
     end
 
-    alias eql? ==
+    remove_method(:eql?) if method_defined?(:eql?) && instance_method(:eql?).owner == self
+    alias_method :eql?, :==
   end
 
   class Array
@@ -5,7 +5,7 @@ require "json"
 
 module MLX
   module DistributedUtils
-    Host = Struct.new(:rank, :ssh_hostname, :ips, :rdma, keyword_init: true)
+    Host = Data.define(:rank, :ssh_hostname, :ips, :rdma)
 
     class Hostfile
       attr_accessor :hosts, :backend, :envs
@@ -8,13 +8,14 @@ require "shellwords"
 
 module MLX
   module DistributedUtils
-    SSHInfo = Struct.new(:can_ssh, :has_sudo, keyword_init: true) do
+    SSHInfo = Data.define(:can_ssh, :has_sudo) do
       def to_bool
        can_ssh
      end
    end
-    ThunderboltPort = Struct.new(:iface, :uuid, :connected_to, keyword_init: true)
-    ThunderboltHost = Struct.new(:name, :ports, keyword_init: true)
+    ThunderboltPort = Data.define(:iface, :uuid, :connected_to)
+    ThunderboltHost = Data.define(:name, :ports)
+    CommandResult = Data.define(:stdout, :stderr, :status)
 
     class IPConfigurator
       attr_reader :ips, :hosts, :tb_hosts
@@ -509,6 +510,8 @@ module MLX
     end
 
     def config_main(argv = ARGV, runner: nil)
+      Process.warmup if Process.respond_to?(:warmup)
+
       opts = {
         verbose: false,
         hosts: "127.0.0.1",
@@ -577,7 +580,7 @@ module MLX
       return runner.call(cmd) unless runner.nil?
 
       stdout, stderr, status = Open3.capture3(*cmd)
-      Struct.new(:stdout, :stderr, :status, keyword_init: true).new(stdout: stdout, stderr: stderr, status: status)
+      CommandResult.new(stdout: stdout, stderr: stderr, status: status)
     end
 
     def stdout_for(result)
@@ -314,6 +314,8 @@ module MLX
     end
 
     def main(argv = ARGV)
+      Process.warmup if Process.respond_to?(:warmup)
+
       opts = {
         print_python: false,
         verbose: false,
@@ -373,12 +375,18 @@ module MLX
       opts[:env] = hostfile.envs + opts[:env]
 
       command = rest.dup
-      script = Pathname.new(command.first)
-      if script.file?
+      command_name = command.first.to_s
+      script = Pathname.new(command_name)
+      explicit_path = command_name.include?(File::SEPARATOR) || command_name.start_with?(".", "~")
+
+      if explicit_path && script.file?
         command[0] = opts[:python]
         command.insert(1, script.realpath.to_s)
-      elsif (resolved = find_executable(command.first))
+      elsif (resolved = find_executable(command_name))
         command[0] = resolved
+      elsif script.file?
+        command[0] = opts[:python]
+        command.insert(1, script.realpath.to_s)
       elsif opts[:verify_script]
         raise ArgumentError, "Invalid script or command #{command.first}"
       end
@@ -0,0 +1,132 @@
+# frozen_string_literal: true
+
+module MLX
+  module DSL
+    class Attention < MLX::NN::Module
+      def initialize(
+        dims:,
+        num_heads:,
+        kv_heads: nil,
+        qkv_bias: false,
+        backend: :sdpa,
+        rope: nil,
+        cache: false
+      )
+        super()
+
+        @dims = Integer(dims)
+        @num_heads = Integer(num_heads)
+        @kv_heads = kv_heads.nil? ? @num_heads : Integer(kv_heads)
+        if (@dims % @num_heads) != 0
+          raise ArgumentError, "dims must be divisible by num_heads"
+        end
+        if (@num_heads % @kv_heads) != 0
+          raise ArgumentError, "num_heads must be divisible by kv_heads"
+        end
+
+        @head_dim = @dims / @num_heads
+        @kv_repeats = @num_heads / @kv_heads
+        @backend = backend.to_sym
+        @cache_enabled = !!cache
+        @scale = Math.sqrt(1.0 / @head_dim)
+
+        self.query_proj = MLX::NN::Linear.new(@dims, @num_heads * @head_dim, bias: qkv_bias)
+        self.key_proj = MLX::NN::Linear.new(@dims, @kv_heads * @head_dim, bias: qkv_bias)
+        self.value_proj = MLX::NN::Linear.new(@dims, @kv_heads * @head_dim, bias: qkv_bias)
+        self.out_proj = MLX::NN::Linear.new(@num_heads * @head_dim, @dims, bias: qkv_bias)
+        self.rope = __dsl_build_rope(rope)
+      end
+
+      def call(queries, keys = nil, values = nil, mask: nil, cache: nil)
+        keys ||= queries
+        values ||= keys
+        q_was_2d = queries.ndim == 2
+
+        queries = MLX::Core.expand_dims(queries, 0) if q_was_2d
+        keys = MLX::Core.expand_dims(keys, 0) if keys.ndim == 2
+        values = MLX::Core.expand_dims(values, 0) if values.ndim == 2
+
+        batch_size, q_len, = queries.shape
+
+        q = __dsl_pack_heads(query_proj.call(queries), @num_heads)
+        k = __dsl_pack_heads(key_proj.call(keys), @kv_heads)
+        v = __dsl_pack_heads(value_proj.call(values), @kv_heads)
+
+        offset = cache.nil? ? 0 : cache[0].shape[2]
+        if !rope.nil?
+          if offset.zero?
+            q = rope.call(q)
+            k = rope.call(k)
+          else
+            q = rope.call(q, offset: offset)
+            k = rope.call(k, offset: offset)
+          end
+        end
+
+        unless cache.nil?
+          key_cache, value_cache = cache
+          k = MLX::Core.concatenate([key_cache, k], 2)
+          v = MLX::Core.concatenate([value_cache, v], 2)
+        end
+        next_cache = [k, v]
+
+        k_for_attn = __dsl_repeat_kv(k)
+        v_for_attn = __dsl_repeat_kv(v)
+        out = __dsl_attention(q, k_for_attn, v_for_attn, mask)
+        out = MLX::Core.transpose(out, [0, 2, 1, 3])
+        out = MLX::Core.reshape(out, [batch_size, q_len, @num_heads * @head_dim])
+        out = out_proj.call(out)
+        out = MLX::Core.squeeze(out, 0) if q_was_2d
+
+        if @cache_enabled || !cache.nil?
+          [out, next_cache]
+        else
+          out
+        end
+      end
+
+      private
+
+      def __dsl_build_rope(config)
+        return nil if config.nil?
+
+        opts = config.transform_keys(&:to_sym)
+        rope_kwargs = {
+          traditional: opts.fetch(:traditional, false),
+          base: opts.fetch(:base, 10_000.0)
+        }
+        rope_kwargs[:scale] = opts[:scale] if opts.key?(:scale)
+        MLX::NN::RoPE.new(@head_dim, **rope_kwargs)
+      end
+
+      def __dsl_pack_heads(x, heads)
+        batch, length, = x.shape
+        x = MLX::Core.reshape(x, [batch, length, heads, @head_dim])
+        MLX::Core.transpose(x, [0, 2, 1, 3])
+      end
+
+      def __dsl_repeat_kv(x)
+        return x if @kv_repeats == 1
+
+        batch, _heads, length, dim = x.shape
+        expanded = MLX::Core.expand_dims(x, 2)
+        repeated = MLX::Core.concatenate(Array.new(@kv_repeats, expanded), 2)
+        MLX::Core.reshape(repeated, [batch, @num_heads, length, dim])
+      end
+
+      def __dsl_attention(q, k, v, mask)
+        if @backend == :sdpa && MLX::Core.respond_to?(:scaled_dot_product_attention)
+          return MLX::Core.scaled_dot_product_attention(q, k, v, @scale, mask)
+        end
+
+        scores = MLX::Core.matmul(
+          MLX::Core.multiply(q, @scale),
+          MLX::Core.transpose(k, [0, 1, 3, 2])
+        )
+        scores = MLX::Core.add(scores, mask.astype(scores.dtype)) unless mask.nil?
+        probs = MLX::Core.softmax(scores.astype(MLX::Core.float32), -1).astype(scores.dtype)
+        MLX::Core.matmul(probs, v)
+      end
+    end
+  end
+end
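The hunk above adds a grouped-query attention module under MLX::DSL. As orientation, here is a minimal usage sketch, not taken from the package: the [batch, length, dims] input shape follows the reshapes in call, while MLX::Core.zeros accepting a shape array and the exact KV-cache round trip are assumptions.

# Sketch only: MLX::Core.zeros([batch, length, dims]) and the cache flow are assumptions.
attn = MLX::DSL::Attention.new(
  dims: 64,
  num_heads: 8,
  kv_heads: 4,                # grouped-query attention; KV heads are repeated up to num_heads
  rope: { base: 10_000.0 },   # builds an MLX::NN::RoPE over each head_dim slice
  cache: true                 # call returns [output, [keys, values]]
)

x = MLX::Core.zeros([1, 16, 64])             # [batch, length, dims]
out, kv = attn.call(x)                       # self-attention: keys/values default to queries
x_next = MLX::Core.zeros([1, 1, 64])         # one new token
out_next, kv = attn.call(x_next, cache: kv)  # RoPE offset and key/value concatenation come from the cache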
@@ -0,0 +1,385 @@
+# frozen_string_literal: true
+
+module MLX
+  module DSL
+    class Builder
+      def initialize(owner = nil)
+        @owner = owner
+        @collector = nil
+      end
+
+      def build(&block)
+        raise ArgumentError, "builder requires a block" unless block_given?
+
+        instance_eval(&block)
+      end
+
+      def sequential(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        push(MLX::NN::Sequential.new(*collected))
+      end
+
+      def layer(entry = nil, *args, **kwargs, &block)
+        if !entry.nil? && block_given?
+          raise ArgumentError, "layer accepts either a module entry or block, not both"
+        end
+
+        if block_given?
+          return push(MLX::DSL::Callable.new(&block))
+        end
+
+        if entry.nil?
+          raise ArgumentError, "layer requires a module entry or block"
+        end
+
+        if entry.is_a?(MLX::NN::Module)
+          __dsl_reject_layer_constructor_args!(args, kwargs, entry.class)
+          return push(entry)
+        end
+
+        if entry.is_a?(Class)
+          unless entry <= MLX::NN::Module
+            raise TypeError, "layer class must inherit from MLX::NN::Module"
+          end
+          return push(entry.new(*args, **kwargs))
+        end
+
+        if entry.respond_to?(:call)
+          __dsl_reject_layer_constructor_args!(args, kwargs, entry.class)
+          return push(MLX::DSL::Callable.new(entry))
+        end
+
+        raise TypeError, "layer requires an MLX::NN::Module instance, MLX::NN::Module class, callable, or block"
+      end
+
+      def residual(module_obj = nil, &block)
+        modules = __dsl_modules_from(module_obj.nil? ? [] : [module_obj], &block)
+        raise ArgumentError, "residual requires at least one module" if modules.empty?
+
+        target = if modules.length == 1
+          modules[0]
+        else
+          MLX::NN::Sequential.new(*modules)
+        end
+        push(MLX::DSL::Residual.new(target))
+      end
+
+      def branch(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "branch requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Parallel.new(*collected))
+      end
+
+      def concat(*modules, axis: -1, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "concat requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Concat.new(*collected, axis: axis))
+      end
+
+      def sum(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "sum requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Reduce.new(*collected, mode: :sum))
+      end
+
+      def fn(callable = nil, &block)
+        push(MLX::DSL::Callable.new(callable, &block))
+      end
+      alias_method :lambda_layer, :fn
+
+      def repeat_layers(count, &block)
+        entries = __dsl_collect_repeated_entries(count, &block)
+        layers = entries.map { |entry| __dsl_normalize_module_entry(entry) }
+        layers.each { |layer| push(layer) }
+        layers
+      end
+
+      def stack(count, layer_class = nil, *args, **kwargs, &block)
+        if !layer_class.nil? && block_given?
+          raise ArgumentError, "stack accepts either a layer class or block, not both"
+        end
+
+        layers = if layer_class.nil?
+          __dsl_collect_repeated_entries(count, &block).map { |entry| __dsl_normalize_module_entry(entry) }
+        else
+          __dsl_build_class_stack_layers(count, layer_class, args, kwargs)
+        end
+        push(MLX::NN::Sequential.new(*layers))
+      end
+
+      def identity(*args, **kwargs)
+        push(MLX::NN::Identity.new(*args, **kwargs))
+      end
+
+      def embedding(*args, **kwargs)
+        push(MLX::NN::Embedding.new(*args, **kwargs))
+      end
+
+      def linear(*args, **kwargs)
+        push(MLX::NN::Linear.new(*args, **kwargs))
+      end
+
+      def bilinear(*args, **kwargs)
+        push(MLX::NN::Bilinear.new(*args, **kwargs))
+      end
+
+      def relu
+        push(MLX::NN::ReLU.new)
+      end
+
+      def relu6
+        push(MLX::NN::ReLU6.new)
+      end
+
+      def leaky_relu(*args)
+        push(MLX::NN::LeakyReLU.new(*args))
+      end
+
+      def gelu(*args, **kwargs)
+        push(MLX::NN::GELU.new(*args, **kwargs))
+      end
+
+      def tanh
+        push(MLX::NN::Tanh.new)
+      end
+
+      def sigmoid
+        push(MLX::NN::Sigmoid.new)
+      end
+
+      def dropout(*args)
+        push(MLX::NN::Dropout.new(*args))
+      end
+
+      def dropout2d(*args)
+        push(MLX::NN::Dropout2d.new(*args))
+      end
+
+      def dropout3d(*args)
+        push(MLX::NN::Dropout3d.new(*args))
+      end
+
+      def conv1d(*args, **kwargs)
+        push(MLX::NN::Conv1d.new(*args, **kwargs))
+      end
+
+      def conv2d(*args, **kwargs)
+        push(MLX::NN::Conv2d.new(*args, **kwargs))
+      end
+
+      def conv3d(*args, **kwargs)
+        push(MLX::NN::Conv3d.new(*args, **kwargs))
+      end
+
+      def conv_transpose1d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose1d.new(*args, **kwargs))
+      end
+
+      def conv_transpose2d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose2d.new(*args, **kwargs))
+      end
+
+      def conv_transpose3d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose3d.new(*args, **kwargs))
+      end
+
+      def layer_norm(*args, **kwargs)
+        push(MLX::NN::LayerNorm.new(*args, **kwargs))
+      end
+
+      def rms_norm(*args, **kwargs)
+        push(MLX::NN::RMSNorm.new(*args, **kwargs))
+      end
+
+      def batch_norm(*args, **kwargs)
+        push(MLX::NN::BatchNorm.new(*args, **kwargs))
+      end
+
+      def instance_norm(*args, **kwargs)
+        push(MLX::NN::InstanceNorm.new(*args, **kwargs))
+      end
+
+      def group_norm(*args, **kwargs)
+        push(MLX::NN::GroupNorm.new(*args, **kwargs))
+      end
+
+      def max_pool2d(*args, **kwargs)
+        push(MLX::NN::MaxPool2d.new(*args, **kwargs))
+      end
+
+      def avg_pool2d(*args, **kwargs)
+        push(MLX::NN::AvgPool2d.new(*args, **kwargs))
+      end
+
+      def max_pool1d(*args, **kwargs)
+        push(MLX::NN::MaxPool1d.new(*args, **kwargs))
+      end
+
+      def avg_pool1d(*args, **kwargs)
+        push(MLX::NN::AvgPool1d.new(*args, **kwargs))
+      end
+
+      def max_pool3d(*args, **kwargs)
+        push(MLX::NN::MaxPool3d.new(*args, **kwargs))
+      end
+
+      def avg_pool3d(*args, **kwargs)
+        push(MLX::NN::AvgPool3d.new(*args, **kwargs))
+      end
+
+      def rnn(*args, **kwargs)
+        push(MLX::NN::RNN.new(*args, **kwargs))
+      end
+
+      def gru(*args, **kwargs)
+        push(MLX::NN::GRU.new(*args, **kwargs))
+      end
+
+      def lstm(*args, **kwargs)
+        push(MLX::NN::LSTM.new(*args, **kwargs))
+      end
+
+      def multi_head_attention(*args, **kwargs)
+        push(MLX::NN::MultiHeadAttention.new(*args, **kwargs))
+      end
+
+      def transformer_encoder_layer(*args, **kwargs)
+        push(MLX::NN::TransformerEncoderLayer.new(*args, **kwargs))
+      end
+
+      def transformer_encoder(*args, **kwargs)
+        push(MLX::NN::TransformerEncoder.new(*args, **kwargs))
+      end
+
+      def transformer_decoder_layer(*args, **kwargs)
+        push(MLX::NN::TransformerDecoderLayer.new(*args, **kwargs))
+      end
+
+      def transformer_decoder(*args, **kwargs)
+        push(MLX::NN::TransformerDecoder.new(*args, **kwargs))
+      end
+
+      def transformer(*args, **kwargs)
+        push(MLX::NN::Transformer.new(*args, **kwargs))
+      end
+
+      def attention(*args, **kwargs)
+        push(MLX::DSL::Attention.new(*args, **kwargs))
+      end
+
+      def transformer_block(*args, **kwargs)
+        push(MLX::DSL::TransformerBlock.new(*args, **kwargs))
+      end
+
+      def rope(*args, **kwargs)
+        push(MLX::NN::RoPE.new(*args, **kwargs))
+      end
+
+      def sinusoidal_positional_encoding(*args, **kwargs)
+        push(MLX::NN::SinusoidalPositionalEncoding.new(*args, **kwargs))
+      end
+
+      def alibi(*args, **kwargs)
+        push(MLX::NN::ALiBi.new(*args, **kwargs))
+      end
+
+      def upsample(*args, **kwargs)
+        push(MLX::NN::Upsample.new(*args, **kwargs))
+      end
+
+      def method_missing(name, *args, **kwargs, &block)
+        if !@owner.nil? && @owner.respond_to?(name)
+          @owner.public_send(name, *args, **kwargs, &block)
+        else
+          super
+        end
+      end
+
+      def respond_to_missing?(name, include_private = false)
+        (!@owner.nil? && @owner.respond_to?(name, include_private)) || super
+      end
+
+      private
+
+      def collect_modules(&block)
+        previous = @collector
+        @collector = []
+        returned = instance_eval(&block)
+        collected = @collector.dup
+        if collected.empty? && !returned.nil?
+          collected << returned
+        end
+        collected
+      ensure
+        @collector = previous
+      end
+
+      def push(module_obj)
+        @collector << module_obj unless @collector.nil?
+        module_obj
+      end
+
+      def __dsl_modules_from(existing, &block)
+        out = existing.dup
+        out.concat(collect_modules(&block)) if block_given?
+        out.map { |entry| __dsl_normalize_module_entry(entry) }
+      end
+
+      def __dsl_normalize_module_entry(entry)
+        return entry if entry.is_a?(MLX::NN::Module)
+
+        if entry.is_a?(Class)
+          return entry.new if entry <= MLX::NN::Module
+
+          raise TypeError, "builder entries must be MLX::NN::Module instances, MLX::NN::Module classes, or callables"
+        end
+
+        return MLX::DSL::Callable.new(entry) if entry.respond_to?(:call)
+
+        raise TypeError, "builder entries must be MLX::NN::Module instances, MLX::NN::Module classes, or callables"
+      end
+
+      def __dsl_reject_layer_constructor_args!(args, kwargs, entry_type)
+        return if args.empty? && kwargs.empty?
+
+        raise ArgumentError, "layer entry #{entry_type} does not accept constructor arguments"
+      end
+
+      def __dsl_collect_repeated_entries(count, &block)
+        raise ArgumentError, "repeat requires a block" unless block_given?
+
+        repeats = count.to_i
+        raise ArgumentError, "repeat count must be non-negative" if repeats.negative?
+
+        out = []
+        repeats.times do |index|
+          out.concat(
+            collect_modules do
+              __dsl_call_repeat_block(block, index)
+            end
+          )
+        end
+        out
+      end
+
+      def __dsl_call_repeat_block(block, index)
+        return instance_eval(&block) if block.arity.zero?
+
+        block.call(index)
+      end
+
+      def __dsl_build_class_stack_layers(count, layer_class, args, kwargs)
+        repeats = count.to_i
+        raise ArgumentError, "stack count must be non-negative" if repeats.negative?
+        unless layer_class.is_a?(Class) && layer_class <= MLX::NN::Module
+          raise TypeError, "stack layer class must inherit from MLX::NN::Module"
+        end
+
+        Array.new(repeats) { layer_class.new(*args, **kwargs) }
+      end
+    end
+  end
+end
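The builder added above is a thin DSL over MLX::NN: each helper constructs a module and, when a collect_modules block is active, appends it to the collector. A minimal sketch of driving it follows; it is not taken from the package. Linear's (input_dims, output_dims) constructor matches its use in the Attention file, while the Dropout.new(p) argument and the layer sizes are assumptions.

# Sketch only: layer sizes and Dropout.new(p) are illustrative.
builder = MLX::DSL::Builder.new

# `build` instance_evals the block; `sequential` collects every pushed layer
# into a single MLX::NN::Sequential.
model = builder.build do
  sequential do
    linear(64, 128)
    relu
    dropout(0.1)
    linear(128, 10)
  end
end

# `stack` builds N instances of a module class and wraps them in a Sequential.
head = builder.build do
  stack(2, MLX::NN::Linear, 128, 128)   # two Linear(128, 128) layers
end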