mlx 0.30.7 → 0.30.7.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,193 @@
1
+ # frozen_string_literal: true
2
+
3
module MLX
  module DSL
    # Streaming token generator for decoder-only and encoder-decoder models.
    #
    # The model must expose either:
    #   * call(input_ids, cache:, **kwargs)              (decoder-only / fallback)
    #   * encode(source) + decode(ids, memory, cache:)   (encoder-decoder)
    # Either entry point may return bare logits or a [logits, cache] pair.
    class Generate
      # @param model [Object] the model driven by this generator (see above)
      # @param tokenizer [Object, nil] optional tokenizer; needs #encode to
      #   accept a String prompt and #decode to emit text chunks
      # @param eos_id [Integer, nil] stop-token id; the stop token itself is
      #   still yielded before iteration halts
      # @param sampler [Hash, nil] sampling config; keys (string or symbol):
      #   :strategy (:argmax, :top_k, :temperature/:categorical),
      #   :temperature (Float), :k (Integer, top-k only)
      # @param mode [Symbol, String] :decoder_only or :encoder_decoder
      # @param decoder_start_id [Integer, nil] decoder BOS id for
      #   encoder-decoder mode (falls back to tokenizer.decoder_start_id)
      def initialize(
        model:,
        tokenizer: nil,
        eos_id: nil,
        sampler: nil,
        mode: :decoder_only,
        decoder_start_id: nil
      )
        @model = model
        @tokenizer = tokenizer
        @eos_id = eos_id
        # Normalize user-supplied keys to symbols and default to greedy decoding.
        @sampler = { strategy: :argmax }.merge((sampler || {}).transform_keys(&:to_sym))
        @mode = mode.to_sym
        @decoder_start_id = decoder_start_id
      end

      # Yields [token_id, chunk] pairs as tokens are generated; chunk is the
      # decoded text for the token, or nil when no decoding tokenizer is set.
      # Returns an Enumerator when called without a block, otherwise self.
      #
      # @param prompt [String, nil] text prompt (requires a tokenizer)
      # @param input_ids [Array<Integer>, MLX::Core::Array, nil] pre-tokenized
      #   prompt; takes precedence over +prompt+ when given
      # @param max_tokens [Integer] upper bound on yielded tokens
      # @raise [ArgumentError] for an unknown generation mode
      def each_token(prompt: nil, input_ids: nil, max_tokens: 128, **kwargs)
        return enum_for(__method__, prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) unless block_given?

        case @mode
        when :decoder_only
          __dsl_each_decoder_only(prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) do |id, chunk|
            yield id, chunk
          end
        when :encoder_decoder
          __dsl_each_encoder_decoder(prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) do |id, chunk|
            yield id, chunk
          end
        else
          raise ArgumentError, "unsupported generation mode: #{@mode.inspect}"
        end

        self
      end

      private

      # Autoregressive loop for decoder-only models: evaluate the prompt once,
      # then sample one token per forward step. The loop stops *before*
      # running a forward pass whose logits would never be sampled — i.e. no
      # extra model call is made after the final yielded token when the token
      # budget runs out or EOS is produced.
      def __dsl_each_decoder_only(prompt:, input_ids:, max_tokens:, **kwargs)
        tokens = input_ids.nil? ? __dsl_encode(prompt) : input_ids
        model_input = __dsl_input_array(tokens)
        logits, cache = __dsl_decode_step(model_input, cache: nil, **kwargs)

        remaining = max_tokens.to_i
        while remaining.positive?
          token_id = __dsl_token_id(__dsl_sample(__dsl_last_logits(logits)))
          yield token_id, __dsl_decode_token(token_id)
          remaining -= 1
          # Skip the next (potentially expensive) forward pass when we would
          # never sample from its logits anyway.
          break if remaining.zero? || (!@eos_id.nil? && token_id == @eos_id)

          next_input = MLX::Core.array([[token_id]], MLX::Core.int32)
          logits, cache = __dsl_decode_step(next_input, cache: cache, **kwargs)
        end
      end

      # Encoder-decoder loop: encode the source once, then step the decoder
      # token by token. Falls back to the decoder-only path when the model
      # only exposes a call-style API.
      def __dsl_each_encoder_decoder(prompt:, input_ids:, max_tokens:, **kwargs)
        tokens = input_ids.nil? ? __dsl_encode(prompt) : input_ids
        source = __dsl_input_array(tokens)

        if @model.respond_to?(:encode) && @model.respond_to?(:decode)
          memory = @model.encode(source)
          start_id = __dsl_decoder_start_id
          decoder_input = MLX::Core.array([[start_id]], MLX::Core.int32)
          cache = nil

          max_tokens.to_i.times do
            decoded = @model.decode(decoder_input, memory, cache: cache, **kwargs)
            logits, cache = __dsl_split_logits_and_cache(decoded, cache)
            token_id = __dsl_token_id(__dsl_sample(__dsl_last_logits(logits)))
            yield token_id, __dsl_decode_token(token_id)
            break if !@eos_id.nil? && token_id == @eos_id

            # NOTE(review): only the newest token is fed back each step, which
            # assumes #decode carries history via +cache+. A cacheless #decode
            # would silently lose context — confirm against supported models.
            decoder_input = MLX::Core.array([[token_id]], MLX::Core.int32)
          end
          return
        end

        # Fallback path for model.call-style APIs (tokens already resolved, so
        # the prompt is not re-encoded).
        __dsl_each_decoder_only(prompt: prompt, input_ids: tokens, max_tokens: max_tokens, **kwargs) do |id, chunk|
          yield id, chunk
        end
      end

      # Runs one forward step and normalizes the result to [logits, cache].
      def __dsl_decode_step(input_ids, cache:, **kwargs)
        output = @model.call(input_ids, cache: cache, **kwargs)
        __dsl_split_logits_and_cache(output, cache)
      end

      # Accepts either bare logits or a [logits, cache] pair; when the model
      # returns no cache, the caller's previous cache value is carried forward.
      def __dsl_split_logits_and_cache(output, fallback_cache)
        if output.is_a?(Array) && output.length == 2
          [output[0], output[1]]
        else
          [output, fallback_cache]
        end
      end

      # Extracts the logits for the final sequence position. Rank-1 and
      # rank-2 inputs are returned untouched (assumed to already be per-token
      # distributions); rank-3 (batch, seq, vocab) inputs are sliced at the
      # last sequence index and squeezed back to (batch, vocab).
      def __dsl_last_logits(logits)
        return logits if logits.ndim == 2
        return logits if logits.ndim == 1

        index = MLX::Core.array([logits.shape[1] - 1], MLX::Core.int32)
        MLX::Core.squeeze(MLX::Core.take(logits, index, 1), 1)
      end

      # Picks the next token from the logits according to the configured
      # sampler. Greedy argmax is used by default and whenever temperature is
      # exactly zero.
      def __dsl_sample(logits)
        strategy = @sampler.fetch(:strategy, :argmax).to_sym
        temperature = @sampler.fetch(:temperature, 1.0).to_f

        return MLX::Core.argmax(logits, -1) if strategy == :argmax || temperature.zero?

        case strategy
        when :top_k
          __dsl_top_k_sample(logits, k: Integer(@sampler.fetch(:k, 40)), temperature: temperature)
        when :temperature, :categorical
          __dsl_temperature_sample(logits, temperature: temperature)
        else
          raise ArgumentError, "unsupported sampler strategy: #{strategy.inspect}"
        end
      end

      # Samples from the categorical distribution after temperature scaling
      # (logits are multiplied by 1/temperature; skipped when it is 1.0).
      def __dsl_temperature_sample(logits, temperature:)
        scaled = if temperature == 1.0
          logits
        else
          MLX::Core.multiply(logits, 1.0 / temperature)
        end
        MLX::Core.categorical(scaled)
      end

      # Top-k sampling: masks all but the k highest logits per row to -inf on
      # the host (Ruby-side sort per row), then defers to temperature
      # sampling. O(vocab log vocab) per step in Ruby, so not a hot path.
      def __dsl_top_k_sample(logits, k:, temperature:)
        rows = logits.ndim == 1 ? [logits.to_a] : logits.to_a
        masked = rows.map do |row|
          pairs = row.each_with_index.sort_by { |(value, _index)| -value }
          keep = pairs.first([k, row.length].min).map(&:last)
          filtered = Array.new(row.length, -Float::INFINITY)
          keep.each { |idx| filtered[idx] = row[idx] }
          filtered
        end
        masked_logits = MLX::Core.array(masked, logits.dtype)
        __dsl_temperature_sample(masked_logits, temperature: temperature)
      end

      # Tokenizes a String prompt; requires a tokenizer to be configured.
      def __dsl_encode(prompt)
        raise ArgumentError, "prompt/input_ids required when tokenizer is unavailable" if @tokenizer.nil?

        @tokenizer.encode(prompt.to_s)
      end

      # Coerces prompt tokens into a rank-2 int32 batch of shape (1, seq),
      # passing through MLX arrays that are already batched.
      def __dsl_input_array(tokens)
        if tokens.is_a?(MLX::Core::Array)
          return tokens if tokens.ndim > 1

          return MLX::Core.expand_dims(tokens.astype(MLX::Core.int32), 0)
        end

        arr = tokens.to_a
        nested = arr.empty? ? [[]] : (arr.first.is_a?(Array) ? arr : [arr])
        MLX::Core.array(nested, MLX::Core.int32)
      end

      # Converts a sampled scalar/array token back to a plain Integer id.
      def __dsl_token_id(token)
        return token.item.to_i if token.respond_to?(:item)

        value = token.to_a
        if value.is_a?(Array)
          first = value.first
          return first.first.to_i if first.is_a?(Array)
          return first.to_i
        end
        value.to_i
      end

      # Decodes a single token id to a text chunk, or nil when the tokenizer
      # is missing or cannot decode.
      def __dsl_decode_token(token_id)
        return nil if @tokenizer.nil? || !@tokenizer.respond_to?(:decode)

        @tokenizer.decode([token_id])
      end

      # Resolves the decoder BOS id from the explicit option or the tokenizer.
      def __dsl_decoder_start_id
        return @decoder_start_id unless @decoder_start_id.nil?
        return @tokenizer.decoder_start_id if !@tokenizer.nil? && @tokenizer.respond_to?(:decoder_start_id)

        raise ArgumentError, "decoder_start_id is required for encoder-decoder mode"
      end
    end
  end
end
@@ -0,0 +1,91 @@
1
+ # frozen_string_literal: true
2
+
3
+ module MLX
4
+ module DSL
5
+ class Callable < MLX::NN::Module
6
+ def initialize(callable = nil, &block)
7
+ super()
8
+ if !callable.nil? && block_given?
9
+ raise ArgumentError, "callable layer accepts either a callable argument or block, not both"
10
+ end
11
+
12
+ @callable = callable.nil? ? block : callable
13
+ unless @callable.respond_to?(:call)
14
+ raise ArgumentError, "callable layer requires a callable argument or block"
15
+ end
16
+ end
17
+
18
+ def call(*args, **kwargs)
19
+ return @callable.call(*args) if kwargs.empty?
20
+
21
+ @callable.call(*args, **kwargs)
22
+ end
23
+ end
24
+
25
+ class Residual < MLX::NN::Module
26
+ def initialize(module_obj)
27
+ super()
28
+ self.module_obj = module_obj
29
+ end
30
+
31
+ def call(*args, **kwargs)
32
+ raise ArgumentError, "residual module expects at least one positional input" if args.empty?
33
+
34
+ identity = args[0]
35
+ transformed = module_obj.call(*args, **kwargs)
36
+ MLX::Core.add(identity, transformed)
37
+ end
38
+ end
39
+
40
+ class Parallel < MLX::NN::Module
41
+ def initialize(*modules)
42
+ super()
43
+ self.layers = modules
44
+ end
45
+
46
+ def call(*args, **kwargs)
47
+ layers.map do |layer|
48
+ layer.call(*args, **kwargs)
49
+ end
50
+ end
51
+ end
52
+
53
+ class Concat < MLX::NN::Module
54
+ def initialize(*modules, axis: -1)
55
+ super()
56
+ self.layers = modules
57
+ @axis = axis
58
+ end
59
+
60
+ def call(*args, **kwargs)
61
+ outputs = layers.map do |layer|
62
+ layer.call(*args, **kwargs)
63
+ end
64
+ MLX::Core.concatenate(outputs, @axis)
65
+ end
66
+ end
67
+
68
+ class Reduce < MLX::NN::Module
69
+ def initialize(*modules, mode: :sum)
70
+ super()
71
+ self.layers = modules
72
+ @mode = mode.to_sym
73
+ end
74
+
75
+ def call(*args, **kwargs)
76
+ outputs = layers.map do |layer|
77
+ layer.call(*args, **kwargs)
78
+ end
79
+
80
+ case @mode
81
+ when :sum
82
+ outputs.reduce do |acc, item|
83
+ MLX::Core.add(acc, item)
84
+ end
85
+ else
86
+ raise ArgumentError, "unsupported reduce mode: #{@mode.inspect}"
87
+ end
88
+ end
89
+ end
90
+ end
91
+ end
@@ -0,0 +1,96 @@
1
+ # frozen_string_literal: true
2
+
3
module MLX
  module DSL
    # Per-layer key/value cache for autoregressive attention.
    #
    # Each layer slot holds either +nil+ (empty) or a +[keys, values]+ pair
    # whose tensors grow along axis 2 (the sequence axis).
    class KVCache
      attr_reader :num_layers

      # @param num_layers [Integer] number of layer slots to allocate
      # @raise [ArgumentError] when num_layers is negative
      def initialize(num_layers:)
        @num_layers = Integer(num_layers)
        raise ArgumentError, "num_layers must be non-negative" if @num_layers.negative?

        @layers = Array.new(@num_layers)
      end

      # Returns the [keys, values] pair for +index+, or nil when empty.
      def layer(index)
        @layers.fetch(__dsl_index(index))
      end

      # Replaces the state stored for +index+ wholesale.
      def []=(index, value)
        @layers[__dsl_index(index)] = value
      end

      # Number of cached sequence positions for the layer (0 when empty).
      def offset(layer:)
        entry = self.layer(layer)
        entry.nil? ? 0 : entry.first.shape[2]
      end

      # Appends new keys/values along the sequence axis, creating the entry
      # when the slot is empty. Returns the updated [keys, values] pair.
      def append(layer:, keys:, values:)
        idx = __dsl_index(layer)
        existing = @layers[idx]

        @layers[idx] =
          if existing.nil?
            [keys, values]
          else
            old_keys, old_values = existing
            [
              MLX::Core.concatenate([old_keys, keys], 2),
              MLX::Core.concatenate([old_values, values], 2)
            ]
          end
      end

      # Keeps only the most recent +tokens+ positions — for every layer, or
      # for a single layer when +layer+ is given. Returns self.
      def truncate!(tokens:, layer: nil)
        keep = Integer(tokens)
        targets = layer.nil? ? (0...@num_layers).to_a : [__dsl_index(layer)]
        targets.each { |idx| __dsl_truncate_layer!(idx, keep) }
        self
      end

      # Clears every layer, or a single layer when +layer+ is given.
      # Returns self.
      def reset!(layer: nil)
        if layer.nil?
          @layers.fill(nil)
        else
          @layers[__dsl_index(layer)] = nil
        end
        self
      end

      private

      # Validates and normalizes a layer index.
      # @raise [IndexError] when the index falls outside 0...num_layers
      def __dsl_index(index)
        idx = Integer(index)
        raise IndexError, "layer index #{idx} out of range (0...#{@num_layers})" unless (0...@num_layers).cover?(idx)

        idx
      end

      # Drops all but the trailing +keep+ positions for one layer slot.
      def __dsl_truncate_layer!(idx, keep)
        state = @layers[idx]
        return if state.nil?

        if keep <= 0
          # A non-positive budget empties the slot entirely.
          @layers[idx] = nil
          return
        end

        keys, values = state
        total = keys.shape[2]
        return if keep >= total

        window = MLX::Core.arange(total - keep, total, 1, MLX::Core.int32)
        @layers[idx] = [
          MLX::Core.take(keys, window, 2),
          MLX::Core.take(values, window, 2)
        ]
      end
    end
  end
end
@@ -0,0 +1,32 @@
1
+ # frozen_string_literal: true
2
+
3
module MLX
  module DSL
    # Helpers for building attention masks.
    module Masks
      module_function

      # Builds an additive causal mask of shape (length, offset + length).
      #
      # Allowed (current or past) positions hold 0 and future positions hold
      # the most negative representable value for +dtype+, so the mask can be
      # added to attention scores before softmax. +offset+ shifts the query
      # rows to account for previously cached tokens.
      #
      # @param length [Integer] number of query positions (non-negative)
      # @param offset [Integer] number of already-processed key positions
      # @param dtype [Object] MLX dtype of the returned mask
      # @raise [ArgumentError] when length or offset is negative
      def causal(length:, offset: 0, dtype: MLX::Core.float32)
        length = Integer(length)
        offset = Integer(offset)
        raise ArgumentError, "length must be non-negative" if length.negative?
        raise ArgumentError, "offset must be non-negative" if offset.negative?

        key_positions = MLX::Core.arange(0, offset + length, 1)
        query_positions = offset.zero? ? key_positions : MLX::Core.arange(offset, offset + length, 1)

        lhs = MLX::Core.expand_dims(query_positions, 1)
        rhs = MLX::Core.expand_dims(key_positions, 0)
        # 1 where the key position lies strictly in the future of the query.
        future = MLX::Core.less(lhs, rhs).astype(dtype)

        # Prefer the dtype's true minimum; fall back to a large negative
        # constant when finfo is unavailable in this binding.
        neg = MLX::Core.respond_to?(:finfo) ? MLX::Core.finfo(dtype).min : -1e9
        MLX::Core.multiply(future, neg)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
+ # frozen_string_literal: true
2
+
3
module MLX
  module DSL
    # Concrete base class for models built with the MLX DSL: an NN module
    # with the shared DSL behavior from ModelMixin included.
    #
    # NOTE(review): ModelMixin is defined elsewhere in this gem; its methods
    # make up this class's effective interface — see that module for details.
    class Model < MLX::NN::Module
      include ModelMixin
    end
  end
end