mlx 0.30.7.2 → 0.30.7.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 215a912d2353fd5edaa60e320a5b857aa13009e55f3a190b21b2ffe5735f37af
- data.tar.gz: 1c9b4279f8077e3cd067354ea692e1b24248afe7262d88d4569c566c52f5a158
+ metadata.gz: 25d582e4816d69b27713a4027534b75cd00ca72557e69681daf07146d3e79ef2
+ data.tar.gz: c010252aa355370a531fa4f3b9bf8cc729876d2f7fb9ae8b8e0d6a1eb6cb57c4
  SHA512:
- metadata.gz: 66abcbd58ccfc04186df11b0d2b6445c7d1e0ab4a36451742755f6fcf41022363403536c3d27640935364c58231c4c4a03e39ceba97617a59f2ad69acf23dc16
- data.tar.gz: ba7ad07ccd31e94bdf3fdee73117f069c6ee22c11cdfc3f2470eedee9ee0bc9e976970bab21f156c003b658723a0e3fb2f63f62d0e11d33de3a61fc1d9121711
+ metadata.gz: 53e629e845342f173c04c7c6d9d976a29dd5492ae945239897d3168a586288ec958ba58753345317f301220ac5f4b91a22f97731ab799fcea5d59f3d19e48214
+ data.tar.gz: 5b04f2e63e3dcdb6a0282184a310600f4fb72e606b45e8e7a27a7b9461abef3a598afe2eb5525e66457da45b8a84a6c2f07c871c955e9cddd4308971496c7fd1
data/ext/mlx/native.cpp CHANGED
@@ -6625,7 +6625,8 @@ static VALUE core_clear_cache(VALUE) {
 
  static VALUE core_metal_is_available(VALUE) {
  try {
- return mxmetal::is_available() ? Qtrue : Qfalse;
+ const mx::Device gpu_device(mx::Device::gpu, 0);
+ return mx::is_available(gpu_device) ? Qtrue : Qfalse;
  } catch (const std::exception& error) {
  raise_std_exception(error);
  return Qnil;
@@ -6654,7 +6655,12 @@ static VALUE core_metal_stop_capture(VALUE) {
 
  static VALUE core_metal_device_info(VALUE) {
  try {
- const auto& info = mxmetal::device_info();
+ const mx::Device gpu_device(mx::Device::gpu, 0);
+ if (!mx::is_available(gpu_device)) {
+ rb_raise(rb_eRuntimeError, "[metal_device_info] Metal GPU device is not available");
+ }
+
+ const auto& info = mx::device_info(gpu_device);
  VALUE hash = rb_hash_new();
  for (const auto& [key, value] : info) {
  VALUE ruby_key = rb_utf8_str_new(key.c_str(), static_cast<long>(key.size()));
@@ -5,7 +5,7 @@ require "json"
 
  module MLX
  module DistributedUtils
- Host = Struct.new(:rank, :ssh_hostname, :ips, :rdma, keyword_init: true)
+ Host = Data.define(:rank, :ssh_hostname, :ips, :rdma)
 
  class Hostfile
  attr_accessor :hosts, :backend, :envs
@@ -8,13 +8,14 @@ require "shellwords"
 
  module MLX
  module DistributedUtils
- SSHInfo = Struct.new(:can_ssh, :has_sudo, keyword_init: true) do
+ SSHInfo = Data.define(:can_ssh, :has_sudo) do
  def to_bool
  can_ssh
  end
  end
- ThunderboltPort = Struct.new(:iface, :uuid, :connected_to, keyword_init: true)
- ThunderboltHost = Struct.new(:name, :ports, keyword_init: true)
+ ThunderboltPort = Data.define(:iface, :uuid, :connected_to)
+ ThunderboltHost = Data.define(:name, :ports)
+ CommandResult = Data.define(:stdout, :stderr, :status)
 
  class IPConfigurator
  attr_reader :ips, :hosts, :tb_hosts
@@ -509,6 +510,8 @@ module MLX
  end
 
  def config_main(argv = ARGV, runner: nil)
+ Process.warmup if Process.respond_to?(:warmup)
+
  opts = {
  verbose: false,
  hosts: "127.0.0.1",
@@ -577,7 +580,7 @@ module MLX
  return runner.call(cmd) unless runner.nil?
 
  stdout, stderr, status = Open3.capture3(*cmd)
- Struct.new(:stdout, :stderr, :status, keyword_init: true).new(stdout: stdout, stderr: stderr, status: status)
+ CommandResult.new(stdout: stdout, stderr: stderr, status: status)
  end
 
  def stdout_for(result)
@@ -314,6 +314,8 @@ module MLX
  end
 
  def main(argv = ARGV)
+ Process.warmup if Process.respond_to?(:warmup)
+
  opts = {
  print_python: false,
  verbose: false,
@@ -0,0 +1,132 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ class Attention < MLX::NN::Module
+ def initialize(
+ dims:,
+ num_heads:,
+ kv_heads: nil,
+ qkv_bias: false,
+ backend: :sdpa,
+ rope: nil,
+ cache: false
+ )
+ super()
+
+ @dims = Integer(dims)
+ @num_heads = Integer(num_heads)
+ @kv_heads = kv_heads.nil? ? @num_heads : Integer(kv_heads)
+ if (@dims % @num_heads) != 0
+ raise ArgumentError, "dims must be divisible by num_heads"
+ end
+ if (@num_heads % @kv_heads) != 0
+ raise ArgumentError, "num_heads must be divisible by kv_heads"
+ end
+
+ @head_dim = @dims / @num_heads
+ @kv_repeats = @num_heads / @kv_heads
+ @backend = backend.to_sym
+ @cache_enabled = !!cache
+ @scale = Math.sqrt(1.0 / @head_dim)
+
+ self.query_proj = MLX::NN::Linear.new(@dims, @num_heads * @head_dim, bias: qkv_bias)
+ self.key_proj = MLX::NN::Linear.new(@dims, @kv_heads * @head_dim, bias: qkv_bias)
+ self.value_proj = MLX::NN::Linear.new(@dims, @kv_heads * @head_dim, bias: qkv_bias)
+ self.out_proj = MLX::NN::Linear.new(@num_heads * @head_dim, @dims, bias: qkv_bias)
+ self.rope = __dsl_build_rope(rope)
+ end
+
+ def call(queries, keys = nil, values = nil, mask: nil, cache: nil)
+ keys ||= queries
+ values ||= keys
+ q_was_2d = queries.ndim == 2
+
+ queries = MLX::Core.expand_dims(queries, 0) if q_was_2d
+ keys = MLX::Core.expand_dims(keys, 0) if keys.ndim == 2
+ values = MLX::Core.expand_dims(values, 0) if values.ndim == 2
+
+ batch_size, q_len, = queries.shape
+
+ q = __dsl_pack_heads(query_proj.call(queries), @num_heads)
+ k = __dsl_pack_heads(key_proj.call(keys), @kv_heads)
+ v = __dsl_pack_heads(value_proj.call(values), @kv_heads)
+
+ offset = cache.nil? ? 0 : cache[0].shape[2]
+ if !rope.nil?
+ if offset.zero?
+ q = rope.call(q)
+ k = rope.call(k)
+ else
+ q = rope.call(q, offset: offset)
+ k = rope.call(k, offset: offset)
+ end
+ end
+
+ unless cache.nil?
+ key_cache, value_cache = cache
+ k = MLX::Core.concatenate([key_cache, k], 2)
+ v = MLX::Core.concatenate([value_cache, v], 2)
+ end
+ next_cache = [k, v]
+
+ k_for_attn = __dsl_repeat_kv(k)
+ v_for_attn = __dsl_repeat_kv(v)
+ out = __dsl_attention(q, k_for_attn, v_for_attn, mask)
+ out = MLX::Core.transpose(out, [0, 2, 1, 3])
+ out = MLX::Core.reshape(out, [batch_size, q_len, @num_heads * @head_dim])
+ out = out_proj.call(out)
+ out = MLX::Core.squeeze(out, 0) if q_was_2d
+
+ if @cache_enabled || !cache.nil?
+ [out, next_cache]
+ else
+ out
+ end
+ end
+
+ private
+
+ def __dsl_build_rope(config)
+ return nil if config.nil?
+
+ opts = config.transform_keys(&:to_sym)
+ rope_kwargs = {
+ traditional: opts.fetch(:traditional, false),
+ base: opts.fetch(:base, 10_000.0)
+ }
+ rope_kwargs[:scale] = opts[:scale] if opts.key?(:scale)
+ MLX::NN::RoPE.new(@head_dim, **rope_kwargs)
+ end
+
+ def __dsl_pack_heads(x, heads)
+ batch, length, = x.shape
+ x = MLX::Core.reshape(x, [batch, length, heads, @head_dim])
+ MLX::Core.transpose(x, [0, 2, 1, 3])
+ end
+
+ def __dsl_repeat_kv(x)
+ return x if @kv_repeats == 1
+
+ batch, _heads, length, dim = x.shape
+ expanded = MLX::Core.expand_dims(x, 2)
+ repeated = MLX::Core.concatenate(Array.new(@kv_repeats, expanded), 2)
+ MLX::Core.reshape(repeated, [batch, @num_heads, length, dim])
+ end
+
+ def __dsl_attention(q, k, v, mask)
+ if @backend == :sdpa && MLX::Core.respond_to?(:scaled_dot_product_attention)
+ return MLX::Core.scaled_dot_product_attention(q, k, v, @scale, mask)
+ end
+
+ scores = MLX::Core.matmul(
+ MLX::Core.multiply(q, @scale),
+ MLX::Core.transpose(k, [0, 1, 3, 2])
+ )
+ scores = MLX::Core.add(scores, mask.astype(scores.dtype)) unless mask.nil?
+ probs = MLX::Core.softmax(scores.astype(MLX::Core.float32), -1).astype(scores.dtype)
+ MLX::Core.matmul(probs, v)
+ end
+ end
+ end
+ end
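The new Attention module above is self-contained, so a minimal usage sketch looks like the following. This is illustrative only: the shapes and dimensions are made up, and it assumes the mlx gem 0.30.7.3 is installed so that require "mlx" loads MLX::DSL.

  require "mlx"

  # Grouped-query attention with RoPE, cache: true so call returns [output, [keys, values]].
  attn = MLX::DSL::Attention.new(dims: 64, num_heads: 8, kv_heads: 2, cache: true, rope: { base: 10_000.0 })

  # Arbitrary [batch, sequence, dims] input built from ops shown in this diff.
  x = MLX::Core.reshape(MLX::Core.arange(0, 16 * 64, 1), [1, 16, 64]).astype(MLX::Core.float32)
  out, kv = attn.call(x)                 # prefill: keys/values cover all 16 positions

  step = MLX::Core.reshape(MLX::Core.arange(0, 64, 1), [1, 1, 64]).astype(MLX::Core.float32)
  out2, kv = attn.call(step, cache: kv)  # incremental step: cache grows along axis 2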
@@ -266,6 +266,14 @@ module MLX
  push(MLX::NN::Transformer.new(*args, **kwargs))
  end
 
+ def attention(*args, **kwargs)
+ push(MLX::DSL::Attention.new(*args, **kwargs))
+ end
+
+ def transformer_block(*args, **kwargs)
+ push(MLX::DSL::TransformerBlock.new(*args, **kwargs))
+ end
+
  def rope(*args, **kwargs)
  push(MLX::NN::RoPE.new(*args, **kwargs))
  end
@@ -0,0 +1,133 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ module ConfigSchema
+ UNSET = Object.new.freeze
+ class DefaultContext
+ def initialize(values)
+ @values = values
+ end
+
+ def method_missing(name, *args, &block)
+ if args.empty? && block.nil? && @values.key?(name.to_sym)
+ return @values[name.to_sym]
+ end
+
+ super
+ end
+
+ def respond_to_missing?(name, include_private = false)
+ @values.key?(name.to_sym) || super
+ end
+ end
+
+ def self.included(base)
+ base.extend(ClassMethods)
+ end
+
+ module ClassMethods
+ def field(name, type = nil, required: false, default: UNSET, &validator)
+ key = name.to_sym
+ config_schema_fields[key] = {
+ type: type,
+ required: !!required,
+ default: default,
+ validator: validator
+ }
+ attr_accessor key unless method_defined?(key) && method_defined?(:"#{key}=")
+ end
+
+ def config_schema_fields
+ @config_schema_fields ||= {}
+ end
+
+ def inherited(subclass)
+ super
+ copied = config_schema_fields.each_with_object({}) do |(key, value), out|
+ out[key] = value.dup
+ end
+ subclass.instance_variable_set(:@config_schema_fields, copied)
+ end
+
+ def from_hash(raw)
+ source = (raw || {}).each_with_object({}) do |(key, value), out|
+ out[key.to_sym] = value
+ end
+ new(**source)
+ end
+
+ private
+
+ def __dsl_call_default(default, resolved)
+ context = DefaultContext.new(resolved)
+ return default unless default.respond_to?(:call)
+ return default.call(context) if default.is_a?(Proc) && default.arity == 1
+ return default.call if !default.is_a?(Proc)
+ return default.call if default.arity.zero?
+
+ default.call(context)
+ end
+
+ def __dsl_validate_field(name, value, spec)
+ type = spec.fetch(:type)
+ if !type.nil? && !value.nil?
+ allowed_types = type.is_a?(Array) ? type : [type]
+ unless allowed_types.any? { |klass| value.is_a?(klass) }
+ raise TypeError,
+ "config field #{name} must be #{allowed_types.map(&:to_s).join(' or ')}, got #{value.class}"
+ end
+ end
+
+ validator = spec.fetch(:validator)
+ unless validator.nil?
+ if validator.arity == 2
+ validator.call(value, name)
+ else
+ validator.call(value)
+ end
+ end
+
+ value
+ end
+ end
+
+ def initialize(**kwargs)
+ source = kwargs.each_with_object({}) do |(key, value), out|
+ out[key.to_sym] = value
+ end
+ resolved = {}
+ unknown = source.keys - self.class.config_schema_fields.keys
+ unless unknown.empty?
+ names = unknown.map(&:to_s).sort.join(", ")
+ raise ArgumentError, "unknown config field(s): #{names}"
+ end
+
+ self.class.config_schema_fields.each do |name, spec|
+ if source.key?(name)
+ value = source.fetch(name)
+ else
+ default = spec.fetch(:default)
+ if default.equal?(UNSET)
+ if spec.fetch(:required)
+ raise ArgumentError, "missing required config field: #{name}"
+ end
+ next
+ end
+ value = self.class.send(:__dsl_call_default, default, resolved)
+ end
+
+ value = self.class.send(:__dsl_validate_field, name, value, spec)
+ resolved[name] = value
+ public_send(:"#{name}=", value)
+ end
+ end
+
+ def to_h
+ self.class.config_schema_fields.keys.each_with_object({}) do |name, out|
+ out[name.to_s] = public_send(name)
+ end
+ end
+ end
+ end
+ end
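A minimal sketch of how ConfigSchema might be used, based only on the methods shown above. The ModelConfig class and its fields are hypothetical, not part of the gem.

  class ModelConfig
    include MLX::DSL::ConfigSchema

    field :dims, Integer, required: true
    field :num_heads, Integer, default: 8
    field :ffn_dims, Integer, default: ->(c) { c.dims * 4 }   # defaults may read earlier fields
    field :dropout, Float, default: 0.0 do |value|
      raise ArgumentError, "dropout must be in [0, 1)" unless (0.0...1.0).cover?(value)
    end
  end

  config = ModelConfig.from_hash("dims" => 512, "num_heads" => 16)
  config.ffn_dims  # => 2048 (derived from dims via the lambda default)
  config.to_h      # => {"dims"=>512, "num_heads"=>16, "ffn_dims"=>2048, "dropout"=>0.0}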
@@ -0,0 +1,193 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ class Generate
+ def initialize(
+ model:,
+ tokenizer: nil,
+ eos_id: nil,
+ sampler: nil,
+ mode: :decoder_only,
+ decoder_start_id: nil
+ )
+ @model = model
+ @tokenizer = tokenizer
+ @eos_id = eos_id
+ @sampler = { strategy: :argmax }.merge((sampler || {}).transform_keys(&:to_sym))
+ @mode = mode.to_sym
+ @decoder_start_id = decoder_start_id
+ end
+
+ def each_token(prompt: nil, input_ids: nil, max_tokens: 128, **kwargs)
+ return enum_for(__method__, prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) unless block_given?
+
+ case @mode
+ when :decoder_only
+ __dsl_each_decoder_only(prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) do |id, chunk|
+ yield id, chunk
+ end
+ when :encoder_decoder
+ __dsl_each_encoder_decoder(prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) do |id, chunk|
+ yield id, chunk
+ end
+ else
+ raise ArgumentError, "unsupported generation mode: #{@mode.inspect}"
+ end
+
+ self
+ end
+
+ private
+
+ def __dsl_each_decoder_only(prompt:, input_ids:, max_tokens:, **kwargs)
+ tokens = input_ids.nil? ? __dsl_encode(prompt) : input_ids
+ model_input = __dsl_input_array(tokens)
+ logits, cache = __dsl_decode_step(model_input, cache: nil, **kwargs)
+
+ max_tokens.to_i.times do
+ token = __dsl_sample(__dsl_last_logits(logits))
+ token_id = __dsl_token_id(token)
+ chunk = __dsl_decode_token(token_id)
+ yield token_id, chunk
+ break if !@eos_id.nil? && token_id == @eos_id
+
+ next_input = MLX::Core.array([[token_id]], MLX::Core.int32)
+ logits, cache = __dsl_decode_step(next_input, cache: cache, **kwargs)
+ end
+ end
+
+ def __dsl_each_encoder_decoder(prompt:, input_ids:, max_tokens:, **kwargs)
+ tokens = input_ids.nil? ? __dsl_encode(prompt) : input_ids
+ source = __dsl_input_array(tokens)
+
+ if @model.respond_to?(:encode) && @model.respond_to?(:decode)
+ memory = @model.encode(source)
+ start_id = __dsl_decoder_start_id
+ decoder_input = MLX::Core.array([[start_id]], MLX::Core.int32)
+ cache = nil
+
+ max_tokens.to_i.times do
+ decoded = @model.decode(decoder_input, memory, cache: cache, **kwargs)
+ logits, cache = __dsl_split_logits_and_cache(decoded, cache)
+ token = __dsl_sample(__dsl_last_logits(logits))
+ token_id = __dsl_token_id(token)
+ chunk = __dsl_decode_token(token_id)
+ yield token_id, chunk
+ break if !@eos_id.nil? && token_id == @eos_id
+
+ decoder_input = MLX::Core.array([[token_id]], MLX::Core.int32)
+ end
+ return
+ end
+
+ # Fallback path for model.call-style APIs.
+ __dsl_each_decoder_only(prompt: prompt, input_ids: tokens, max_tokens: max_tokens, **kwargs) do |id, chunk|
+ yield id, chunk
+ end
+ end
+
+ def __dsl_decode_step(input_ids, cache:, **kwargs)
+ output = @model.call(input_ids, cache: cache, **kwargs)
+ __dsl_split_logits_and_cache(output, cache)
+ end
+
+ def __dsl_split_logits_and_cache(output, fallback_cache)
+ if output.is_a?(Array) && output.length == 2
+ [output[0], output[1]]
+ else
+ [output, fallback_cache]
+ end
+ end
+
+ def __dsl_last_logits(logits)
+ return logits if logits.ndim == 2
+ return logits if logits.ndim == 1
+
+ index = MLX::Core.array([logits.shape[1] - 1], MLX::Core.int32)
+ MLX::Core.squeeze(MLX::Core.take(logits, index, 1), 1)
+ end
+
+ def __dsl_sample(logits)
+ strategy = @sampler.fetch(:strategy, :argmax).to_sym
+ temperature = @sampler.fetch(:temperature, 1.0).to_f
+
+ return MLX::Core.argmax(logits, -1) if strategy == :argmax || temperature.zero?
+
+ case strategy
+ when :top_k
+ __dsl_top_k_sample(logits, k: Integer(@sampler.fetch(:k, 40)), temperature: temperature)
+ when :temperature, :categorical
+ __dsl_temperature_sample(logits, temperature: temperature)
+ else
+ raise ArgumentError, "unsupported sampler strategy: #{strategy.inspect}"
+ end
+ end
+
+ def __dsl_temperature_sample(logits, temperature:)
+ scaled = if temperature == 1.0
+ logits
+ else
+ MLX::Core.multiply(logits, 1.0 / temperature)
+ end
+ MLX::Core.categorical(scaled)
+ end
+
+ def __dsl_top_k_sample(logits, k:, temperature:)
+ rows = logits.ndim == 1 ? [logits.to_a] : logits.to_a
+ masked = rows.map do |row|
+ pairs = row.each_with_index.sort_by { |(value, _index)| -value }
+ keep = pairs.first([k, row.length].min).map(&:last)
+ filtered = Array.new(row.length, -Float::INFINITY)
+ keep.each { |idx| filtered[idx] = row[idx] }
+ filtered
+ end
+ masked_logits = MLX::Core.array(masked, logits.dtype)
+ __dsl_temperature_sample(masked_logits, temperature: temperature)
+ end
+
+ def __dsl_encode(prompt)
+ raise ArgumentError, "prompt/input_ids required when tokenizer is unavailable" if @tokenizer.nil?
+
+ @tokenizer.encode(prompt.to_s)
+ end
+
+ def __dsl_input_array(tokens)
+ if tokens.is_a?(MLX::Core::Array)
+ return tokens if tokens.ndim > 1
+
+ return MLX::Core.expand_dims(tokens.astype(MLX::Core.int32), 0)
+ end
+
+ arr = tokens.to_a
+ nested = arr.empty? ? [[]] : (arr.first.is_a?(Array) ? arr : [arr])
+ MLX::Core.array(nested, MLX::Core.int32)
+ end
+
+ def __dsl_token_id(token)
+ return token.item.to_i if token.respond_to?(:item)
+
+ value = token.to_a
+ if value.is_a?(Array)
+ first = value.first
+ return first.first.to_i if first.is_a?(Array)
+ return first.to_i
+ end
+ value.to_i
+ end
+
+ def __dsl_decode_token(token_id)
+ return nil if @tokenizer.nil? || !@tokenizer.respond_to?(:decode)
+
+ @tokenizer.decode([token_id])
+ end
+
+ def __dsl_decoder_start_id
+ return @decoder_start_id unless @decoder_start_id.nil?
+ return @tokenizer.decoder_start_id if !@tokenizer.nil? && @tokenizer.respond_to?(:decoder_start_id)
+
+ raise ArgumentError, "decoder_start_id is required for encoder-decoder mode"
+ end
+ end
+ end
+ end
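A rough sketch of driving Generate in decoder-only mode, using only the keywords defined above. The model and tokenizer variables are placeholders: any objects responding to call(input_ids, cache:) and encode/decode with these shapes would do.

  generator = MLX::DSL::Generate.new(
    model: model,          # hypothetical: returns logits or [logits, cache]
    tokenizer: tokenizer,  # hypothetical: responds to encode/decode
    eos_id: 2,
    sampler: { strategy: :top_k, k: 40, temperature: 0.8 }
  )

  generator.each_token(prompt: "Hello", max_tokens: 32) do |token_id, chunk|
    print chunk unless chunk.nil?
  end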
@@ -0,0 +1,96 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ class KVCache
+ attr_reader :num_layers
+
+ def initialize(num_layers:)
+ @num_layers = Integer(num_layers)
+ raise ArgumentError, "num_layers must be non-negative" if @num_layers.negative?
+
+ @layers = Array.new(@num_layers)
+ end
+
+ def layer(index)
+ @layers.fetch(__dsl_index(index))
+ end
+
+ def []=(index, value)
+ @layers[__dsl_index(index)] = value
+ end
+
+ def offset(layer:)
+ state = self.layer(layer)
+ return 0 if state.nil?
+
+ keys, = state
+ keys.shape[2]
+ end
+
+ def append(layer:, keys:, values:)
+ idx = __dsl_index(layer)
+ current = @layers[idx]
+ if current.nil?
+ @layers[idx] = [keys, values]
+ return @layers[idx]
+ end
+
+ key_cache, value_cache = current
+ next_keys = MLX::Core.concatenate([key_cache, keys], 2)
+ next_values = MLX::Core.concatenate([value_cache, values], 2)
+ @layers[idx] = [next_keys, next_values]
+ end
+
+ def truncate!(tokens:, layer: nil)
+ keep = Integer(tokens)
+ if layer.nil?
+ @layers.each_index { |idx| __dsl_truncate_layer!(idx, keep) }
+ else
+ __dsl_truncate_layer!(__dsl_index(layer), keep)
+ end
+ self
+ end
+
+ def reset!(layer: nil)
+ if layer.nil?
+ @layers.map! { nil }
+ else
+ @layers[__dsl_index(layer)] = nil
+ end
+ self
+ end
+
+ private
+
+ def __dsl_index(index)
+ idx = Integer(index)
+ if idx.negative? || idx >= @num_layers
+ raise IndexError, "layer index #{idx} out of range (0...#{@num_layers})"
+ end
+
+ idx
+ end
+
+ def __dsl_truncate_layer!(idx, keep)
+ state = @layers[idx]
+ return if state.nil?
+
+ if keep <= 0
+ @layers[idx] = nil
+ return
+ end
+
+ keys, values = state
+ total = keys.shape[2]
+ return if keep >= total
+
+ start = total - keep
+ indices = MLX::Core.arange(start, total, 1, MLX::Core.int32)
+ trimmed_keys = MLX::Core.take(keys, indices, 2)
+ trimmed_values = MLX::Core.take(values, indices, 2)
+ @layers[idx] = [trimmed_keys, trimmed_values]
+ end
+ end
+ end
+ end
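A small sketch of the KVCache bookkeeping in isolation. The key/value tensors follow the [batch, heads, sequence, head_dim] layout that the Attention module above concatenates along axis 2; the concrete numbers here are arbitrary.

  cache = MLX::DSL::KVCache.new(num_layers: 2)

  # Dummy keys/values of shape [1, 2, 4, 8] built from ops shown in this diff.
  k = MLX::Core.reshape(MLX::Core.arange(0, 1 * 2 * 4 * 8, 1), [1, 2, 4, 8]).astype(MLX::Core.float32)
  v = k

  cache.append(layer: 0, keys: k, values: v)
  cache.offset(layer: 0)        # => 4 (sequence positions stored so far)
  cache.truncate!(tokens: 2)    # keep only the last 2 positions in every layer
  cache.reset!(layer: 0)        # drop layer 0, leave the rest untouched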
@@ -0,0 +1,32 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ module Masks
+ module_function
+
+ def causal(length:, offset: 0, dtype: MLX::Core.float32)
+ length = Integer(length)
+ offset = Integer(offset)
+ raise ArgumentError, "length must be non-negative" if length.negative?
+ raise ArgumentError, "offset must be non-negative" if offset.negative?
+
+ rinds = MLX::Core.arange(0, offset + length, 1)
+ linds = if offset.zero?
+ rinds
+ else
+ MLX::Core.arange(offset, offset + length, 1)
+ end
+ lhs = MLX::Core.expand_dims(linds, 1)
+ rhs = MLX::Core.expand_dims(rinds, 0)
+ mask = MLX::Core.less(lhs, rhs).astype(dtype)
+ min_value = if MLX::Core.respond_to?(:finfo)
+ MLX::Core.finfo(dtype).min
+ else
+ -1e9
+ end
+ MLX::Core.multiply(mask, min_value)
+ end
+ end
+ end
+ end
@@ -0,0 +1,35 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ module Positions
+ module_function
+
+ def ids_like(input_ids, offset: 0, dtype: nil)
+ shape = input_ids.shape
+ seq_len = shape[-1]
+ dtype ||= input_ids.respond_to?(:dtype) ? input_ids.dtype : MLX::Core.int32
+
+ base = MLX::Core.arange(offset.to_i, offset.to_i + seq_len, 1, dtype)
+ return base if shape.length == 1
+
+ reshape_dims = Array.new(shape.length, 1)
+ reshape_dims[-1] = seq_len
+ expanded = MLX::Core.reshape(base, reshape_dims)
+ MLX::Core.broadcast_to(expanded, shape)
+ end
+
+ def offset_from_cache(cache, layer: 0)
+ return 0 if cache.nil?
+ return cache.offset(layer: layer) if cache.respond_to?(:offset)
+
+ if cache.respond_to?(:[]) && !cache[layer].nil?
+ keys, = cache[layer]
+ return keys.shape[2]
+ end
+
+ 0
+ end
+ end
+ end
+ end
@@ -0,0 +1,68 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ def self.run_stack(layers, input, cache: nil, **kwargs)
+ modules = layers.to_a
+ if cache.is_a?(MLX::DSL::KVCache)
+ hidden = input
+ modules.each_with_index do |layer, index|
+ hidden, next_cache = __dsl_run_stack_layer(
+ layer,
+ hidden,
+ kwargs,
+ cache: cache.layer(index),
+ use_cache: true
+ )
+ cache[index] = next_cache
+ end
+ return [hidden, cache]
+ end
+
+ use_cache = !cache.nil?
+ cache_state = if use_cache
+ entries = cache.to_a
+ entries.length < modules.length ? entries + Array.new(modules.length - entries.length) : entries.dup
+ else
+ nil
+ end
+
+ hidden = input
+ modules.each_with_index do |layer, index|
+ layer_cache = use_cache ? cache_state[index] : nil
+ hidden, next_cache = __dsl_run_stack_layer(
+ layer,
+ hidden,
+ kwargs,
+ cache: layer_cache,
+ use_cache: use_cache
+ )
+ cache_state[index] = next_cache if use_cache
+ end
+
+ use_cache ? [hidden, cache_state] : hidden
+ end
+
+ def self.__dsl_run_stack_layer(layer, hidden, kwargs, cache:, use_cache:)
+ call_kwargs = kwargs.dup
+ call_kwargs[:cache] = cache if use_cache
+ result = layer.call(hidden, **call_kwargs)
+
+ if use_cache && result.is_a?(Array) && result.length == 2
+ [result[0], result[1]]
+ else
+ [result, cache]
+ end
+ rescue ArgumentError => e
+ if use_cache && e.message.include?("unknown keyword: :cache")
+ result = layer.call(hidden, **kwargs)
+ if result.is_a?(Array) && result.length == 2
+ return [result[0], result[1]]
+ end
+ return [result, cache]
+ end
+ raise
+ end
+ private_class_method :__dsl_run_stack_layer
+ end
+ end
@@ -0,0 +1,126 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ module Tensor
+ module_function
+
+ def scatter_rows(base:, row_indices:, values:, axis: nil)
+ axis = __dsl_default_scatter_axis(base, axis)
+ row_indices = __dsl_to_index_array(row_indices)
+ values = __dsl_to_array(values, dtype: base.dtype)
+
+ if base.ndim != values.ndim
+ raise ArgumentError, "base and values must have the same rank"
+ end
+
+ unless base.shape.each_with_index.all? { |dim, idx| idx == axis || dim == values.shape[idx] }
+ raise ArgumentError, "values shape must match base shape except along axis #{axis}"
+ end
+
+ indices = __dsl_expand_indices_for_axis(
+ row_indices: row_indices,
+ values_shape: values.shape,
+ axis: axis
+ )
+ MLX::Core.put_along_axis(base, indices, values, axis)
+ end
+
+ def where_labels(base:, labels:, mapping:, mode: :add_or_replace)
+ mode = mode.to_sym
+ unless [:add_or_replace, :replace].include?(mode)
+ raise ArgumentError, "mode must be :add_or_replace or :replace"
+ end
+
+ out = base
+ mapping.each do |label_value, mapped_value|
+ mask = MLX::Core.equal(labels, label_value)
+ mask = __dsl_expand_trailing_dims(mask, out.ndim)
+ mapped = __dsl_broadcast_mapping(mapped_value, base: out, labels_ndim: labels.ndim)
+ replacement = if mode == :replace
+ mapped
+ else
+ MLX::Core.add(out, mapped)
+ end
+ out = MLX::Core.where(mask, replacement, out)
+ end
+ out
+ end
+
+ def __dsl_default_scatter_axis(base, axis)
+ return axis.to_i unless axis.nil?
+ return 1 if base.ndim == 3
+
+ 0
+ end
+ private_class_method :__dsl_default_scatter_axis
+
+ def __dsl_to_index_array(indices)
+ if indices.is_a?(MLX::Core::Array)
+ return indices.astype(MLX::Core.int32)
+ end
+
+ MLX::Core.array(indices, MLX::Core.int32)
+ end
+ private_class_method :__dsl_to_index_array
+
+ def __dsl_to_array(value, dtype:)
+ if value.is_a?(MLX::Core::Array)
+ return value.astype(dtype)
+ end
+
+ MLX::Core.array(value, dtype)
+ end
+ private_class_method :__dsl_to_array
+
+ def __dsl_expand_indices_for_axis(row_indices:, values_shape:, axis:)
+ if row_indices.ndim == 1
+ expected = values_shape[axis]
+ if row_indices.shape[0] != expected
+ raise ArgumentError, "row_indices length must match values shape at axis #{axis}"
+ end
+
+ reshape = Array.new(values_shape.length, 1)
+ reshape[axis] = expected
+ base = MLX::Core.reshape(row_indices, reshape)
+ return MLX::Core.broadcast_to(base, values_shape)
+ end
+
+ if row_indices.shape == values_shape
+ return row_indices
+ end
+
+ raise ArgumentError, "row_indices must be rank-1 or match values shape"
+ end
+ private_class_method :__dsl_expand_indices_for_axis
+
+ def __dsl_expand_trailing_dims(array, target_ndim)
+ out = array
+ while out.ndim < target_ndim
+ out = MLX::Core.expand_dims(out, out.ndim)
+ end
+ out
+ end
+ private_class_method :__dsl_expand_trailing_dims
+
+ def __dsl_broadcast_mapping(value, base:, labels_ndim:)
+ mapped = __dsl_to_array(value, dtype: base.dtype)
+ return mapped if mapped.shape == base.shape
+
+ if mapped.ndim == 1 && mapped.shape[0] == base.shape[-1]
+ reshape = Array.new(labels_ndim, 1) + [mapped.shape[0]]
+ mapped = MLX::Core.reshape(mapped, reshape)
+ return MLX::Core.broadcast_to(mapped, base.shape)
+ end
+
+ if mapped.ndim == labels_ndim && mapped.shape == base.shape[0...labels_ndim]
+ mapped = __dsl_expand_trailing_dims(mapped, base.ndim)
+ return MLX::Core.broadcast_to(mapped, base.shape)
+ end
+
+ MLX::Core.broadcast_to(mapped, base.shape)
+ end
+ private_class_method :__dsl_broadcast_mapping
+ end
+ end
+ end
@@ -0,0 +1,113 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ class FeedForward < MLX::NN::Module
+ def initialize(dims:, hidden_dims:, kind: :gelu, bias: false)
+ super()
+ @kind = kind.to_sym
+ @dims = Integer(dims)
+ @hidden_dims = Integer(hidden_dims)
+
+ case @kind
+ when :swiglu
+ self.gate_proj = MLX::NN::Linear.new(@dims, @hidden_dims, bias: bias)
+ self.up_proj = MLX::NN::Linear.new(@dims, @hidden_dims, bias: bias)
+ self.down_proj = MLX::NN::Linear.new(@hidden_dims, @dims, bias: bias)
+ else
+ self.in_proj = MLX::NN::Linear.new(@dims, @hidden_dims, bias: bias)
+ self.out_proj = MLX::NN::Linear.new(@hidden_dims, @dims, bias: bias)
+ end
+ end
+
+ def call(x)
+ case @kind
+ when :swiglu
+ gated = MLX::NN.silu(gate_proj.call(x))
+ down_proj.call(MLX::Core.multiply(gated, up_proj.call(x)))
+ when :relu
+ out_proj.call(MLX::NN.relu(in_proj.call(x)))
+ else
+ out_proj.call(MLX::NN.gelu(in_proj.call(x)))
+ end
+ end
+ end
+
+ class TransformerBlock < MLX::NN::Module
+ def initialize(
+ dims:,
+ num_heads:,
+ kv_heads: nil,
+ ffn_dims: nil,
+ norm: :rms,
+ norm_eps: 1e-5,
+ ffn: nil,
+ rope: nil,
+ qkv_bias: false,
+ backend: :sdpa,
+ cache: false
+ )
+ super()
+
+ ffn_config = (ffn || {}).transform_keys(&:to_sym)
+ ffn_kind = ffn_config.fetch(:kind, :gelu)
+ hidden_dims = ffn_dims || ffn_config.fetch(:hidden_dims, Integer(dims) * 4)
+ ffn_bias = ffn_config.fetch(:bias, false)
+
+ self.attention_norm = __dsl_build_norm(norm, Integer(dims), norm_eps)
+ self.attention = MLX::DSL::Attention.new(
+ dims: dims,
+ num_heads: num_heads,
+ kv_heads: kv_heads,
+ qkv_bias: qkv_bias,
+ backend: backend,
+ rope: rope,
+ cache: cache
+ )
+ self.ffn_norm = __dsl_build_norm(norm, Integer(dims), norm_eps)
+ self.feed_forward = MLX::DSL::FeedForward.new(
+ dims: dims,
+ hidden_dims: hidden_dims,
+ kind: ffn_kind,
+ bias: ffn_bias
+ )
+
+ @cache_enabled = !!cache
+ end
+
+ def call(x, mask: nil, cache: nil, **_kwargs)
+ attn_input = attention_norm.call(x)
+ attn_result = attention.call(attn_input, attn_input, attn_input, mask: mask, cache: cache)
+ if attn_result.is_a?(Array) && attn_result.length == 2
+ attn_out, next_cache = attn_result
+ else
+ attn_out = attn_result
+ next_cache = cache
+ end
+
+ hidden = MLX::Core.add(x, attn_out)
+ ffn_out = feed_forward.call(ffn_norm.call(hidden))
+ output = MLX::Core.add(hidden, ffn_out)
+
+ if @cache_enabled || !cache.nil?
+ [output, next_cache]
+ else
+ output
+ end
+ end
+
+ private
+
+ def __dsl_build_norm(kind, dims, eps)
+ case kind.to_sym
+ when :layer, :layer_norm
+ MLX::NN::LayerNorm.new(dims, eps: eps)
+ when :rms, :rms_norm
+ MLX::NN::RMSNorm.new(dims, eps: eps)
+ else
+ raise ArgumentError, "unsupported norm kind: #{kind.inspect}"
+ end
+ end
+ end
+ end
+ end
@@ -0,0 +1,140 @@
+ # frozen_string_literal: true
+
+ module MLX
+ module DSL
+ def self.weight_map(&block)
+ mapper = WeightMap.new
+ mapper.instance_eval(&block) if block_given?
+ mapper
+ end
+
+ class WeightMap
+ def initialize
+ @rules = []
+ end
+
+ def strip_prefix(prefix)
+ @rules << [:strip_prefix, prefix.to_s]
+ self
+ end
+
+ def rename(from = nil, to = nil)
+ if from.is_a?(Hash)
+ from.each do |src, dst|
+ @rules << [:rename, src.to_s, dst.to_s]
+ end
+ return self
+ end
+
+ if from.nil? || to.nil?
+ raise ArgumentError, "rename requires either a mapping hash or from/to arguments"
+ end
+
+ @rules << [:rename, from.to_s, to.to_s]
+ self
+ end
+
+ def regex(pattern, replacement = nil, &block)
+ if pattern.nil?
+ raise ArgumentError, "regex requires a Regexp pattern"
+ end
+ if replacement.nil? && !block_given?
+ raise ArgumentError, "regex requires a replacement argument or block"
+ end
+
+ @rules << [:regex, pattern, replacement, block]
+ self
+ end
+
+ def split_qkv(source, into:, axis: 0)
+ names = into.to_a.map(&:to_s)
+ if names.empty? || names.length < 2
+ raise ArgumentError, "split_qkv :into must include at least two output names"
+ end
+
+ @rules << [:split, source.to_s, names, axis.to_i]
+ self
+ end
+
+ def transpose_if(rank:, order:)
+ @rules << [:transpose_if, rank.to_i, order.to_a]
+ self
+ end
+
+ def apply(weights)
+ entries = if weights.is_a?(Hash)
+ weights.to_a
+ elsif weights.respond_to?(:to_a)
+ weights.to_a
+ else
+ raise ArgumentError, "weights must be a Hash or array-like key/value collection"
+ end
+
+ out = {}
+ entries.each do |entry|
+ key, value = entry
+ __dsl_apply_rules_for_entry(key.to_s, value).each do |mapped_key, mapped_value|
+ out[mapped_key] = mapped_value
+ end
+ end
+ out
+ end
+
+ private
+
+ def __dsl_apply_rules_for_entry(key, value)
+ current_entries = [[key, value]]
+ @rules.each do |rule|
+ current_entries = current_entries.flat_map do |curr_key, curr_value|
+ __dsl_apply_rule(rule, curr_key, curr_value)
+ end
+ end
+ current_entries
+ end
+
+ def __dsl_apply_rule(rule, key, value)
+ kind = rule[0]
+ case kind
+ when :strip_prefix
+ prefix = rule[1]
+ if key.start_with?(prefix)
+ [[key[prefix.length..], value]]
+ else
+ [[key, value]]
+ end
+ when :rename
+ from = rule[1]
+ to = rule[2]
+ [[key.gsub(from, to), value]]
+ when :regex
+ pattern = rule[1]
+ replacement = rule[2]
+ block = rule[3]
+ if block.nil?
+ [[key.gsub(pattern, replacement.to_s), value]]
+ else
+ [[key.gsub(pattern, &block), value]]
+ end
+ when :split
+ source = rule[1]
+ targets = rule[2]
+ axis = rule[3]
+ return [[key, value]] unless key == source
+
+ parts = MLX::Core.split(value, targets.length, axis)
+ targets.each_with_index.map { |target, index| [target, parts[index]] }
+ when :transpose_if
+ target_rank = rule[1]
+ order = rule[2]
+ if value.respond_to?(:shape) && value.shape.length == target_rank
+ [[key, MLX::Core.transpose(value, order)]]
+ else
+ [[key, value]]
+ end
+ else
+ [[key, value]]
+ end
+ end
+ end
+ end
+ end
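A short sketch of the WeightMap rule DSL defined above; rules are applied to each key in declaration order. The source key names and the sample tensor are made up for illustration.

  mapper = MLX::DSL.weight_map do
    strip_prefix "model."
    rename "self_attn" => "attention", "mlp" => "feed_forward"
    regex(/layers\.(\d+)/) { |m| "blocks.#{m[/\d+/]}" }   # block receives the matched substring
    split_qkv "attention.qkv.weight", into: [:q, :k, :v], axis: 0
  end

  weights = { "model.layers.0.self_attn.o_proj.weight" => MLX::Core.array([[1, 2], [3, 4]], MLX::Core.float32) }
  mapper.apply(weights)
  # => { "blocks.0.attention.o_proj.weight" => <same array> }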
data/lib/mlx/dsl.rb CHANGED
@@ -10,6 +10,16 @@ require_relative "dsl/data_pipeline"
  require_relative "dsl/experiment"
  require_relative "dsl/split_plan"
  require_relative "dsl/builder"
+ require_relative "dsl/config_schema"
+ require_relative "dsl/weight_map"
+ require_relative "dsl/kv_cache"
+ require_relative "dsl/masks"
+ require_relative "dsl/positions"
+ require_relative "dsl/tensor"
+ require_relative "dsl/run_stack"
+ require_relative "dsl/attention"
+ require_relative "dsl/transformer_block"
+ require_relative "dsl/generate"
  require_relative "dsl/train_step"
  require_relative "dsl/model_mixin"
  require_relative "dsl/model"
data/lib/mlx/version.rb CHANGED
@@ -1,5 +1,5 @@
  # frozen_string_literal: true
 
  module MLX
- VERSION = "0.30.7.2"
+ VERSION = "0.30.7.3"
  end
metadata CHANGED
@@ -1,15 +1,57 @@
  --- !ruby/object:Gem::Specification
  name: mlx
  version: !ruby/object:Gem::Version
- version: 0.30.7.2
+ version: 0.30.7.3
  platform: ruby
  authors:
  - MLX Contributors
  - Aleksey Skryl
  bindir: bin
  cert_chain: []
- date: 2026-02-14 00:00:00.000000000 Z
- dependencies: []
+ date: 1980-01-02 00:00:00.000000000 Z
+ dependencies:
+ - !ruby/object:Gem::Dependency
+ name: rake
+ requirement: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ type: :development
+ prerelease: false
+ version_requirements: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ - !ruby/object:Gem::Dependency
+ name: minitest
+ requirement: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ type: :development
+ prerelease: false
+ version_requirements: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ - !ruby/object:Gem::Dependency
+ name: benchmark
+ requirement: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ type: :development
+ prerelease: false
+ version_requirements: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
  description: A Ruby wrapper for the native MLX machine learning runtime.
  email:
  - mlx@group.apple.com
@@ -27,15 +69,25 @@ files:
  - lib/mlx/distributed_utils/config.rb
  - lib/mlx/distributed_utils/launch.rb
  - lib/mlx/dsl.rb
+ - lib/mlx/dsl/attention.rb
  - lib/mlx/dsl/builder.rb
+ - lib/mlx/dsl/config_schema.rb
  - lib/mlx/dsl/data_pipeline.rb
  - lib/mlx/dsl/experiment.rb
+ - lib/mlx/dsl/generate.rb
  - lib/mlx/dsl/graph_modules.rb
+ - lib/mlx/dsl/kv_cache.rb
+ - lib/mlx/dsl/masks.rb
  - lib/mlx/dsl/model.rb
  - lib/mlx/dsl/model_mixin.rb
+ - lib/mlx/dsl/positions.rb
+ - lib/mlx/dsl/run_stack.rb
  - lib/mlx/dsl/split_plan.rb
+ - lib/mlx/dsl/tensor.rb
  - lib/mlx/dsl/train_step.rb
  - lib/mlx/dsl/trainer.rb
+ - lib/mlx/dsl/transformer_block.rb
+ - lib/mlx/dsl/weight_map.rb
  - lib/mlx/extension.rb
  - lib/mlx/nn.rb
  - lib/mlx/nn/base.rb
@@ -640,14 +692,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '3.1'
+ version: '3.3'
  required_rubygems_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.6.2
+ rubygems_version: 4.0.3
  specification_version: 4
  summary: Ruby bindings for the native MLX library
  test_files: []