mlx 0.30.7 → 0.30.7.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/native.cpp +8 -6
- data/lib/mlx/core.rb +8 -1
- data/lib/mlx/distributed_utils/common.rb +1 -1
- data/lib/mlx/distributed_utils/config.rb +7 -4
- data/lib/mlx/distributed_utils/launch.rb +11 -3
- data/lib/mlx/dsl/attention.rb +132 -0
- data/lib/mlx/dsl/builder.rb +385 -0
- data/lib/mlx/dsl/config_schema.rb +133 -0
- data/lib/mlx/dsl/data_pipeline.rb +284 -0
- data/lib/mlx/dsl/experiment.rb +154 -0
- data/lib/mlx/dsl/generate.rb +193 -0
- data/lib/mlx/dsl/graph_modules.rb +91 -0
- data/lib/mlx/dsl/kv_cache.rb +96 -0
- data/lib/mlx/dsl/masks.rb +32 -0
- data/lib/mlx/dsl/model.rb +9 -0
- data/lib/mlx/dsl/model_mixin.rb +706 -0
- data/lib/mlx/dsl/positions.rb +35 -0
- data/lib/mlx/dsl/run_stack.rb +68 -0
- data/lib/mlx/dsl/split_plan.rb +85 -0
- data/lib/mlx/dsl/tensor.rb +126 -0
- data/lib/mlx/dsl/train_step.rb +197 -0
- data/lib/mlx/dsl/trainer.rb +2110 -0
- data/lib/mlx/dsl/transformer_block.rb +113 -0
- data/lib/mlx/dsl/weight_map.rb +140 -0
- data/lib/mlx/dsl.rb +26 -0
- data/lib/mlx/nn/layers/containers.rb +21 -4
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx.rb +1 -0
- data/mlx/CMakeLists.txt +4 -16
- metadata +67 -5
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
# frozen_string_literal: true

module MLX
  module DSL
    # Streaming token generation for decoder-only and encoder-decoder
    # models. Yields one sampled token id at a time, together with the
    # decoded text chunk when a tokenizer is available.
    class Generate
      # model:            object responding to #call (and optionally to
      #                   #encode/#decode for encoder-decoder generation)
      # tokenizer:        optional; used for #encode and, if it responds
      #                   to #decode, for per-token text chunks
      # eos_id:           optional token id that stops generation
      # sampler:          hash such as { strategy: :top_k, k: 40, temperature: 0.8 };
      #                   keys are normalized to symbols, default is greedy
      # mode:             :decoder_only or :encoder_decoder
      # decoder_start_id: first decoder token for encoder-decoder mode
      def initialize(
        model:,
        tokenizer: nil,
        eos_id: nil,
        sampler: nil,
        mode: :decoder_only,
        decoder_start_id: nil
      )
        @model = model
        @tokenizer = tokenizer
        @eos_id = eos_id
        # Merge user sampler settings over the greedy default, accepting
        # string or symbol keys.
        @sampler = { strategy: :argmax }.merge((sampler || {}).transform_keys(&:to_sym))
        @mode = mode.to_sym
        @decoder_start_id = decoder_start_id
      end

      # Yields [token_id, decoded_chunk] up to max_tokens times, stopping
      # early when eos_id is produced. Returns an Enumerator when called
      # without a block; returns self otherwise.
      #
      # Raises ArgumentError for an unsupported @mode.
      def each_token(prompt: nil, input_ids: nil, max_tokens: 128, **kwargs)
        return enum_for(__method__, prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) unless block_given?

        case @mode
        when :decoder_only
          __dsl_each_decoder_only(prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) do |id, chunk|
            yield id, chunk
          end
        when :encoder_decoder
          __dsl_each_encoder_decoder(prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) do |id, chunk|
            yield id, chunk
          end
        else
          raise ArgumentError, "unsupported generation mode: #{@mode.inspect}"
        end

        self
      end

      private

      # Autoregressive loop: run the full prompt once, then feed each
      # sampled token back through the model with the running cache.
      def __dsl_each_decoder_only(prompt:, input_ids:, max_tokens:, **kwargs)
        tokens = input_ids.nil? ? __dsl_encode(prompt) : input_ids
        model_input = __dsl_input_array(tokens)
        logits, cache = __dsl_decode_step(model_input, cache: nil, **kwargs)

        max_tokens.to_i.times do
          token = __dsl_sample(__dsl_last_logits(logits))
          token_id = __dsl_token_id(token)
          chunk = __dsl_decode_token(token_id)
          yield token_id, chunk
          # The eos token itself is yielded before stopping.
          break if !@eos_id.nil? && token_id == @eos_id

          next_input = MLX::Core.array([[token_id]], MLX::Core.int32)
          logits, cache = __dsl_decode_step(next_input, cache: cache, **kwargs)
        end
      end

      # Encoder-decoder loop: encode the source once, then decode token by
      # token against the encoder memory. Falls back to the decoder-only
      # path when the model exposes only a call-style API.
      def __dsl_each_encoder_decoder(prompt:, input_ids:, max_tokens:, **kwargs)
        tokens = input_ids.nil? ? __dsl_encode(prompt) : input_ids
        source = __dsl_input_array(tokens)

        if @model.respond_to?(:encode) && @model.respond_to?(:decode)
          memory = @model.encode(source)
          start_id = __dsl_decoder_start_id
          decoder_input = MLX::Core.array([[start_id]], MLX::Core.int32)
          cache = nil

          max_tokens.to_i.times do
            decoded = @model.decode(decoder_input, memory, cache: cache, **kwargs)
            logits, cache = __dsl_split_logits_and_cache(decoded, cache)
            token = __dsl_sample(__dsl_last_logits(logits))
            token_id = __dsl_token_id(token)
            chunk = __dsl_decode_token(token_id)
            yield token_id, chunk
            break if !@eos_id.nil? && token_id == @eos_id

            decoder_input = MLX::Core.array([[token_id]], MLX::Core.int32)
          end
          return
        end

        # Fallback path for model.call-style APIs.
        __dsl_each_decoder_only(prompt: prompt, input_ids: tokens, max_tokens: max_tokens, **kwargs) do |id, chunk|
          yield id, chunk
        end
      end

      # One forward pass; returns [logits, cache].
      def __dsl_decode_step(input_ids, cache:, **kwargs)
        output = @model.call(input_ids, cache: cache, **kwargs)
        __dsl_split_logits_and_cache(output, cache)
      end

      # Models may return either bare logits or a [logits, cache] pair;
      # when only logits come back, keep the previous cache.
      def __dsl_split_logits_and_cache(output, fallback_cache)
        if output.is_a?(Array) && output.length == 2
          [output[0], output[1]]
        else
          [output, fallback_cache]
        end
      end

      # Reduces (batch, seq, vocab) logits to the last position's
      # (batch, vocab) slice; 1-D and 2-D logits pass through unchanged.
      def __dsl_last_logits(logits)
        return logits if logits.ndim == 2
        return logits if logits.ndim == 1

        index = MLX::Core.array([logits.shape[1] - 1], MLX::Core.int32)
        MLX::Core.squeeze(MLX::Core.take(logits, index, 1), 1)
      end

      # Dispatches on the configured sampling strategy. Temperature 0
      # always degenerates to greedy argmax.
      #
      # Raises ArgumentError for an unknown strategy.
      def __dsl_sample(logits)
        strategy = @sampler.fetch(:strategy, :argmax).to_sym
        temperature = @sampler.fetch(:temperature, 1.0).to_f

        return MLX::Core.argmax(logits, -1) if strategy == :argmax || temperature.zero?

        case strategy
        when :top_k
          __dsl_top_k_sample(logits, k: Integer(@sampler.fetch(:k, 40)), temperature: temperature)
        when :temperature, :categorical
          __dsl_temperature_sample(logits, temperature: temperature)
        else
          raise ArgumentError, "unsupported sampler strategy: #{strategy.inspect}"
        end
      end

      # Categorical sampling after scaling logits by 1/temperature.
      def __dsl_temperature_sample(logits, temperature:)
        scaled = if temperature == 1.0
          logits
        else
          MLX::Core.multiply(logits, 1.0 / temperature)
        end
        MLX::Core.categorical(scaled)
      end

      # Masks all but the k highest logits per row to -Infinity (done on
      # the Ruby side via to_a), then temperature-samples the result.
      def __dsl_top_k_sample(logits, k:, temperature:)
        rows = logits.ndim == 1 ? [logits.to_a] : logits.to_a
        masked = rows.map do |row|
          pairs = row.each_with_index.sort_by { |(value, _index)| -value }
          keep = pairs.first([k, row.length].min).map(&:last)
          filtered = Array.new(row.length, -Float::INFINITY)
          keep.each { |idx| filtered[idx] = row[idx] }
          filtered
        end
        masked_logits = MLX::Core.array(masked, logits.dtype)
        __dsl_temperature_sample(masked_logits, temperature: temperature)
      end

      # Tokenizes a prompt string; requires a tokenizer.
      def __dsl_encode(prompt)
        raise ArgumentError, "prompt/input_ids required when tokenizer is unavailable" if @tokenizer.nil?

        @tokenizer.encode(prompt.to_s)
      end

      # Normalizes tokens (MLX array, flat Ruby array, or nested array)
      # into an int32 batch of shape (1, seq) or the array as given when
      # it is already batched.
      def __dsl_input_array(tokens)
        if tokens.is_a?(MLX::Core::Array)
          return tokens if tokens.ndim > 1

          return MLX::Core.expand_dims(tokens.astype(MLX::Core.int32), 0)
        end

        arr = tokens.to_a
        nested = arr.empty? ? [[]] : (arr.first.is_a?(Array) ? arr : [arr])
        MLX::Core.array(nested, MLX::Core.int32)
      end

      # Extracts a plain Integer token id from a scalar MLX array, a
      # (nested) Ruby array, or anything responding to #item.
      def __dsl_token_id(token)
        return token.item.to_i if token.respond_to?(:item)

        value = token.to_a
        if value.is_a?(Array)
          first = value.first
          return first.first.to_i if first.is_a?(Array)
          return first.to_i
        end
        value.to_i
      end

      # Decodes a single token id to text, or nil when no decoding
      # tokenizer is available.
      def __dsl_decode_token(token_id)
        return nil if @tokenizer.nil? || !@tokenizer.respond_to?(:decode)

        @tokenizer.decode([token_id])
      end

      # Resolves the decoder start token: explicit option first, then the
      # tokenizer's, otherwise raises.
      def __dsl_decoder_start_id
        return @decoder_start_id unless @decoder_start_id.nil?
        return @tokenizer.decoder_start_id if !@tokenizer.nil? && @tokenizer.respond_to?(:decoder_start_id)

        raise ArgumentError, "decoder_start_id is required for encoder-decoder mode"
      end
    end
  end
end
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
# frozen_string_literal: true

module MLX
  module DSL
    # Wraps an arbitrary callable (proc, lambda, or any object responding
    # to #call) so plain Ruby code can participate in a module graph.
    class Callable < MLX::NN::Module
      # Accepts either a callable argument or a block, but not both.
      #
      # Raises ArgumentError when both or neither are given.
      def initialize(callable = nil, &block)
        super()
        if !callable.nil? && block_given?
          raise ArgumentError, "callable layer accepts either a callable argument or block, not both"
        end

        @callable = callable.nil? ? block : callable
        unless @callable.respond_to?(:call)
          raise ArgumentError, "callable layer requires a callable argument or block"
        end
      end

      def call(*args, **kwargs)
        # Procs/lambdas with no keyword params reject **{} on some Ruby
        # versions, so only forward kwargs when present.
        return @callable.call(*args) if kwargs.empty?

        @callable.call(*args, **kwargs)
      end
    end

    # Skip connection: output = input + wrapped_module(input).
    class Residual < MLX::NN::Module
      def initialize(module_obj)
        super()
        self.module_obj = module_obj
      end

      # Raises ArgumentError when called without a positional input; the
      # first positional argument is used as the identity branch.
      def call(*args, **kwargs)
        raise ArgumentError, "residual module expects at least one positional input" if args.empty?

        identity = args[0]
        transformed = module_obj.call(*args, **kwargs)
        MLX::Core.add(identity, transformed)
      end
    end

    # Fans the same input out to several modules; returns an Array of
    # their outputs in order.
    class Parallel < MLX::NN::Module
      def initialize(*modules)
        super()
        self.layers = modules
      end

      def call(*args, **kwargs)
        layers.map do |layer|
          layer.call(*args, **kwargs)
        end
      end
    end

    # Fans the input out to several modules and concatenates their
    # outputs along `axis` (default: last).
    class Concat < MLX::NN::Module
      def initialize(*modules, axis: -1)
        super()
        self.layers = modules
        @axis = axis
      end

      def call(*args, **kwargs)
        outputs = layers.map do |layer|
          layer.call(*args, **kwargs)
        end
        MLX::Core.concatenate(outputs, @axis)
      end
    end

    # Fans the input out to several modules and reduces their outputs.
    # Only :sum is supported at present.
    class Reduce < MLX::NN::Module
      def initialize(*modules, mode: :sum)
        super()
        self.layers = modules
        @mode = mode.to_sym
      end

      # Raises ArgumentError for an unsupported reduce mode.
      def call(*args, **kwargs)
        outputs = layers.map do |layer|
          layer.call(*args, **kwargs)
        end

        case @mode
        when :sum
          outputs.reduce do |acc, item|
            MLX::Core.add(acc, item)
          end
        else
          raise ArgumentError, "unsupported reduce mode: #{@mode.inspect}"
        end
      end
    end
  end
end
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# frozen_string_literal: true

module MLX
  module DSL
    # Per-layer key/value cache for autoregressive attention. Each layer
    # slot holds either nil (empty) or a [keys, values] pair. The token
    # dimension is axis 2 of each array (presumably
    # (batch, heads, tokens, head_dim) — confirm against the attention code).
    class KVCache
      attr_reader :num_layers

      # Raises ArgumentError when num_layers is negative or not coercible
      # to an Integer.
      def initialize(num_layers:)
        @num_layers = Integer(num_layers)
        raise ArgumentError, "num_layers must be non-negative" if @num_layers.negative?

        @layers = Array.new(@num_layers)
      end

      # Returns the [keys, values] pair for the layer, or nil when empty.
      # Raises IndexError for an out-of-range index.
      def layer(index)
        @layers.fetch(__dsl_index(index))
      end

      # Replaces a layer's state wholesale.
      def []=(index, value)
        @layers[__dsl_index(index)] = value
      end

      # Number of cached tokens for the layer (0 when empty).
      def offset(layer:)
        state = self.layer(layer)
        return 0 if state.nil?

        keys, = state
        keys.shape[2]
      end

      # Appends keys/values along the token axis. Returns the layer's new
      # [keys, values] pair.
      def append(layer:, keys:, values:)
        idx = __dsl_index(layer)
        current = @layers[idx]
        if current.nil?
          @layers[idx] = [keys, values]
          return @layers[idx]
        end

        key_cache, value_cache = current
        next_keys = MLX::Core.concatenate([key_cache, keys], 2)
        next_values = MLX::Core.concatenate([value_cache, values], 2)
        @layers[idx] = [next_keys, next_values]
      end

      # Keeps only the most recent `tokens` cache entries, for one layer
      # or (by default) all of them. tokens <= 0 clears the layer(s);
      # tokens >= current size is a no-op. Returns self.
      def truncate!(tokens:, layer: nil)
        keep = Integer(tokens)
        if layer.nil?
          @layers.each_index { |idx| __dsl_truncate_layer!(idx, keep) }
        else
          __dsl_truncate_layer!(__dsl_index(layer), keep)
        end
        self
      end

      # Empties one layer or (by default) all layers. Returns self.
      def reset!(layer: nil)
        if layer.nil?
          @layers.map! { nil }
        else
          @layers[__dsl_index(layer)] = nil
        end
        self
      end

      private

      # Coerces and bounds-checks a layer index.
      def __dsl_index(index)
        idx = Integer(index)
        if idx.negative? || idx >= @num_layers
          raise IndexError, "layer index #{idx} out of range (0...#{@num_layers})"
        end

        idx
      end

      # Drops all but the trailing `keep` tokens of one layer's cache.
      def __dsl_truncate_layer!(idx, keep)
        state = @layers[idx]
        return if state.nil?

        if keep <= 0
          @layers[idx] = nil
          return
        end

        keys, values = state
        total = keys.shape[2]
        return if keep >= total

        start = total - keep
        indices = MLX::Core.arange(start, total, 1, MLX::Core.int32)
        trimmed_keys = MLX::Core.take(keys, indices, 2)
        trimmed_values = MLX::Core.take(values, indices, 2)
        @layers[idx] = [trimmed_keys, trimmed_values]
      end
    end
  end
end
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# frozen_string_literal: true

module MLX
  module DSL
    # Attention-mask builders.
    module Masks
      module_function

      # Builds an additive causal mask of shape (length, offset + length):
      # 0 where a query may attend, the dtype's most negative value where
      # it would attend to a future key. `offset` shifts the query
      # positions to account for previously cached tokens.
      #
      # Raises ArgumentError when length or offset is negative.
      def causal(length:, offset: 0, dtype: MLX::Core.float32)
        length = Integer(length)
        offset = Integer(offset)
        raise ArgumentError, "length must be non-negative" if length.negative?
        raise ArgumentError, "offset must be non-negative" if offset.negative?

        # Key positions 0...(offset + length); query positions start at
        # `offset` (reusing rinds when there is no offset).
        rinds = MLX::Core.arange(0, offset + length, 1)
        linds = if offset.zero?
          rinds
        else
          MLX::Core.arange(offset, offset + length, 1)
        end
        lhs = MLX::Core.expand_dims(linds, 1)
        rhs = MLX::Core.expand_dims(rinds, 0)
        # mask[i, j] = 1 where query position i < key position j (future).
        mask = MLX::Core.less(lhs, rhs).astype(dtype)
        min_value = if MLX::Core.respond_to?(:finfo)
          MLX::Core.finfo(dtype).min
        else
          -1e9 # fallback when dtype introspection is unavailable
        end
        MLX::Core.multiply(mask, min_value)
      end
    end
  end
end
|