mlx 0.30.7 → 0.30.7.2
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- checksums.yaml +4 -4
- data/ext/mlx/native.cpp +0 -4
- data/lib/mlx/core.rb +8 -1
- data/lib/mlx/distributed_utils/launch.rb +9 -3
- data/lib/mlx/dsl/builder.rb +377 -0
- data/lib/mlx/dsl/data_pipeline.rb +284 -0
- data/lib/mlx/dsl/experiment.rb +154 -0
- data/lib/mlx/dsl/graph_modules.rb +91 -0
- data/lib/mlx/dsl/model.rb +9 -0
- data/lib/mlx/dsl/model_mixin.rb +706 -0
- data/lib/mlx/dsl/split_plan.rb +85 -0
- data/lib/mlx/dsl/train_step.rb +197 -0
- data/lib/mlx/dsl/trainer.rb +2110 -0
- data/lib/mlx/dsl.rb +16 -0
- data/lib/mlx/nn/layers/containers.rb +21 -4
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx.rb +1 -0
- data/mlx/CMakeLists.txt +4 -16
- metadata +12 -2
data/lib/mlx/dsl/data_pipeline.rb
@@ -0,0 +1,284 @@

# frozen_string_literal: true

module MLX
  module DSL
    module Data
      def self.from(source = nil, &block)
        if !source.nil? && block_given?
          raise ArgumentError, "data pipeline source accepts either a source or block, not both"
        end

        producer = block_given? ? block : source
        if producer.nil?
          raise ArgumentError, "data pipeline requires a source enumerable or source block"
        end

        Pipeline.new(__dsl_factory_for(producer))
      end

      def self.pipeline(source = nil, &block)
        from(source, &block)
      end

      def self.__dsl_factory_for(producer)
        if producer.respond_to?(:call)
          lambda do
            __dsl_to_enumerator(producer.call)
          end
        else
          lambda do
            if producer.respond_to?(:rewind)
              begin
                producer.rewind
              rescue StandardError
                # Keep default Enumerable semantics when rewind is unavailable at runtime.
              end
            end
            __dsl_to_enumerator(producer)
          end
        end
      end
      private_class_method :__dsl_factory_for

      def self.__dsl_to_enumerator(value)
        unless value.respond_to?(:each)
          raise ArgumentError, "data pipeline source must respond to #each"
        end

        value.to_enum
      end
      private_class_method :__dsl_to_enumerator

      class Pipeline
        include Enumerable

        def initialize(factory)
          @factory = factory
        end

        def each
          enum = @factory.call
          return enum unless block_given?

          enum.each { |item| yield item }
        end

        def map(&block)
          raise ArgumentError, "pipeline map requires a block" unless block_given?

          self.class.new(lambda {
            upstream = @factory.call
            Enumerator.new do |y|
              index = 0
              upstream.each do |item|
                y << __dsl_call_with_context(block, item, index, "pipeline map")
                index += 1
              end
            end
          })
        end

        def filter(&block)
          raise ArgumentError, "pipeline filter requires a block" unless block_given?

          self.class.new(lambda {
            upstream = @factory.call
            Enumerator.new do |y|
              index = 0
              upstream.each do |item|
                y << item if __dsl_call_with_context(block, item, index, "pipeline filter")
                index += 1
              end
            end
          })
        end

        def batch(size, drop_last: false)
          batch_size = size.to_i
          raise ArgumentError, "pipeline batch size must be positive" if batch_size <= 0

          self.class.new(lambda {
            upstream = @factory.call
            Enumerator.new do |y|
              chunk = []
              upstream.each do |item|
                chunk << item
                if chunk.length == batch_size
                  y << chunk
                  chunk = []
                end
              end
              y << chunk unless drop_last || chunk.empty?
            end
          })
        end

        def take(count)
          limit = count.to_i
          raise ArgumentError, "pipeline take count must be non-negative" if limit.negative?

          self.class.new(lambda {
            upstream = @factory.call
            Enumerator.new do |y|
              seen = 0
              while seen < limit
                begin
                  y << upstream.next
                  seen += 1
                rescue StopIteration
                  break
                end
              end
            end
          })
        end

        def repeat(times = nil)
          if times.nil?
            self.class.new(lambda {
              Enumerator.new do |y|
                loop do
                  upstream = @factory.call
                  produced = false
                  upstream.each do |item|
                    produced = true
                    y << item
                  end
                  break unless produced
                end
              end
            })
          else
            cycles = times.to_i
            raise ArgumentError, "pipeline repeat count must be non-negative" if cycles.negative?

            self.class.new(lambda {
              Enumerator.new do |y|
                cycles.times do
                  @factory.call.each do |item|
                    y << item
                  end
                end
              end
            })
          end
        end

        def shuffle(seed: nil, random: nil)
          if !seed.nil? && !random.nil?
            raise ArgumentError, "pipeline shuffle accepts either seed: or random:, not both"
          end

          self.class.new(lambda {
            items = @factory.call.to_a
            rng = if !random.nil?
              random
            elsif !seed.nil?
              Random.new(seed.to_i)
            else
              Random.new
            end
            items.shuffle(random: rng).to_enum
          })
        end

        def prefetch(size = 1)
          prefetch_size = size.to_i
          raise ArgumentError, "pipeline prefetch size must be positive" if prefetch_size <= 0

          self.class.new(lambda {
            upstream = @factory.call
            Enumerator.new do |y|
              buffer = []

              prefetch_size.times do
                begin
                  buffer << upstream.next
                rescue StopIteration
                  break
                end
              end

              until buffer.empty?
                y << buffer.shift
                begin
                  buffer << upstream.next
                rescue StopIteration
                  # Exhausted upstream; continue draining buffer.
                end
              end
            end
          })
        end

        private

        def __dsl_call_with_context(callable, item, index, label)
          values = {
            item: item,
            index: index,
            pipeline: self
          }
          return callable.call(item, index) unless callable.respond_to?(:parameters)

          params = callable.parameters
          return callable.call(item, index) if params.empty?

          args = __dsl_build_positional_args(
            params,
            values,
            [[:item, item], [:index, index], [:pipeline, self]],
            label
          )
          kwargs = __dsl_build_keyword_args(params, values, label)
          return callable.call(*args) if kwargs.empty?

          callable.call(*args, **kwargs)
        end

        def __dsl_build_positional_args(params, values, fallback_pairs, label)
          queue = fallback_pairs.dup
          args = []
          params.each do |type, name|
            next unless type == :req || type == :opt

            if !name.nil? && values.key?(name)
              args << values.fetch(name)
              queue.reject! { |key, _value| key == name }
              next
            end

            if queue.empty?
              raise ArgumentError, "#{label} has unsupported required positional argument: #{name.inspect}" if type == :req
              break
            end

            _key, value = queue.shift
            args << value
          end
          args
        end

        def __dsl_build_keyword_args(params, values, label)
          return values.dup if params.any? { |type, _name| type == :keyrest }

          required_keys = params.each_with_object([]) do |(type, name), out|
            out << name if type == :keyreq
          end
          missing = required_keys.reject { |name| values.key?(name) }
          unless missing.empty?
            raise ArgumentError, "#{label} requires unsupported keyword argument(s): #{missing.map(&:inspect).join(", ")}"
          end

          accepted_keys = params.each_with_object([]) do |(type, name), out|
            out << name if type == :key || type == :keyreq
          end

          values.each_with_object({}) do |(name, value), out|
            out[name] = value if accepted_keys.include?(name)
          end
        end
      end
    end
  end
end
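For orientation, here is a minimal sketch of how the pipeline above composes. It relies only on methods defined in this file plus `require "mlx"` (the gem's entry point per `data/lib/mlx.rb`); the sample data is illustrative:

require "mlx"

# Any object responding to #each works as a source; a plain Array is used here.
samples = (1..10).to_a

pipeline = MLX::DSL::Data.from(samples)
  .map { |item| item * 2 }              # blocks may take (item), (item, index), or keywords
  .filter { |item, index| index.even? } # positional params are filled as item, then index
  .shuffle(seed: 42)                    # deterministic ordering via a seeded Random
  .batch(3, drop_last: true)            # only full batches of three are yielded

pipeline.each { |batch| p batch }
# Each call to #each re-invokes the source factory, so the same pipeline can be
# iterated repeatedly (for example, once per training epoch).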
data/lib/mlx/dsl/experiment.rb
@@ -0,0 +1,154 @@

# frozen_string_literal: true

module MLX
  module DSL
    def self.experiment(name = nil, &block)
      instance = Experiment.new(name: name)
      instance.instance_eval(&block) if block_given?
      instance
    end

    class Experiment
      attr_reader :name

      def initialize(name: nil)
        @name = name
        @model_source = nil
        @optimizer_source = nil
        @trainer_source = nil
        @trainer_kwargs = {}
        @loss_block = nil
        @data_config = { train: nil, validation: nil, fit: {} }
        @artifact_config = {}
        @last_trainer = nil
        @last_report = nil
      end

      def model(value = nil, &block)
        if !value.nil? && block_given?
          raise ArgumentError, "model accepts either a value argument or block, not both"
        end

        @model_source = block_given? ? block : value
        self
      end

      def optimizer(value = nil, &block)
        if !value.nil? && block_given?
          raise ArgumentError, "optimizer accepts either a value argument or block, not both"
        end

        @optimizer_source = block_given? ? block : value
        self
      end

      def trainer(value = nil, **kwargs, &block)
        if value.is_a?(MLX::DSL::Trainer)
          if !kwargs.empty? || block_given?
            raise ArgumentError, "trainer instance injection cannot be combined with trainer kwargs or loss block"
          end
          @trainer_source = value
          return self
        end

        unless value.nil?
          raise ArgumentError, "trainer positional argument must be an MLX::DSL::Trainer instance"
        end

        @trainer_source = nil
        @trainer_kwargs = kwargs.dup
        @loss_block = block if block_given?
        self
      end

      def data(train: nil, validation: :__dsl_unset__, **fit_kwargs)
        @data_config[:train] = train unless train.nil?
        @data_config[:validation] = validation unless validation == :__dsl_unset__
        @data_config[:fit].merge!(fit_kwargs)
        self
      end

      def artifacts(**kwargs)
        @artifact_config.merge!(kwargs)
        self
      end

      def run(report: false, **overrides)
        dataset, fit_kwargs = __dsl_resolve_fit_call(overrides)
        active_trainer = __dsl_resolve_trainer
        result = if report
          active_trainer.fit_report(dataset, **fit_kwargs)
        else
          active_trainer.fit(dataset, **fit_kwargs)
        end
        @last_report = result if report
        result
      end

      def report(**overrides)
        run(report: true, **overrides)
      end

      def save_run_bundle(path, report: nil, config: {}, **overrides)
        active_report = report
        if active_report.nil?
          active_report = if !@last_report.nil?
            @last_report
          else
            self.report(**overrides)
          end
        end

        __dsl_resolve_trainer.save_run_bundle(path, report: active_report, config: config)
      end

      private

      def __dsl_resolve_fit_call(overrides)
        fit_kwargs = @data_config.fetch(:fit).dup
        fit_kwargs.merge!(@artifact_config)
        fit_kwargs[:validation_data] = @data_config[:validation] if !@data_config[:validation].nil? && !fit_kwargs.key?(:validation_data)

        incoming = overrides.dup
        dataset = if incoming.key?(:dataset)
          incoming.delete(:dataset)
        elsif incoming.key?(:train)
          incoming.delete(:train)
        else
          @data_config[:train]
        end
        if dataset.nil?
          raise ArgumentError, "experiment run requires a train dataset via data(train:) or run(dataset:)"
        end

        [dataset, fit_kwargs.merge(incoming)]
      end

      def __dsl_resolve_trainer
        return @trainer_source if @trainer_source.is_a?(MLX::DSL::Trainer)
        return @last_trainer unless @last_trainer.nil?

        model = __dsl_resolve_source(@model_source, "model")
        optimizer = __dsl_resolve_source(@optimizer_source, "optimizer")
        unless model.respond_to?(:trainer)
          raise ArgumentError, "experiment model must respond to #trainer when trainer instance is not injected"
        end
        unless @loss_block.respond_to?(:call)
          raise ArgumentError, "experiment trainer requires a loss block when trainer instance is not injected"
        end

        @last_trainer = model.trainer(optimizer: optimizer, **@trainer_kwargs, &@loss_block)
      end

      def __dsl_resolve_source(source, label)
        value = source
        value = value.call if value.respond_to?(:call)
        if value.nil?
          raise ArgumentError, "experiment #{label} section is required when trainer instance is not injected"
        end

        value
      end
    end
  end
end
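A hedged usage sketch of the experiment builder follows. `MyModel`, `MyOptimizer`, the `epochs:` kwarg, and the loss block signature are all placeholders: trainer kwargs and the block are forwarded verbatim to `model.trainer` (defined in `model_mixin.rb`, not shown in this hunk), so their exact shape depends on that method:

require "mlx"

train_batches = MLX::DSL::Data.from(my_samples).batch(32)  # my_samples is hypothetical

exp = MLX::DSL.experiment("baseline") do
  model     { MyModel.new }                 # resolved lazily via #call at run time
  optimizer { MyOptimizer.new(learning_rate: 1e-2) }
  trainer(epochs: 3) do |model, batch|      # kwargs and block are passed through to
    model.loss(batch)                       # model.trainer; the exact signature is
  end                                       # defined there, not in this file
  data train: train_batches
end

report = exp.report                    # delegates to trainer.fit_report(train_batches, ...)
exp.save_run_bundle("runs/baseline")   # reuses the report cached by #report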
data/lib/mlx/dsl/graph_modules.rb
@@ -0,0 +1,91 @@

# frozen_string_literal: true

module MLX
  module DSL
    class Callable < MLX::NN::Module
      def initialize(callable = nil, &block)
        super()
        if !callable.nil? && block_given?
          raise ArgumentError, "callable layer accepts either a callable argument or block, not both"
        end

        @callable = callable.nil? ? block : callable
        unless @callable.respond_to?(:call)
          raise ArgumentError, "callable layer requires a callable argument or block"
        end
      end

      def call(*args, **kwargs)
        return @callable.call(*args) if kwargs.empty?

        @callable.call(*args, **kwargs)
      end
    end

    class Residual < MLX::NN::Module
      def initialize(module_obj)
        super()
        self.module_obj = module_obj
      end

      def call(*args, **kwargs)
        raise ArgumentError, "residual module expects at least one positional input" if args.empty?

        identity = args[0]
        transformed = module_obj.call(*args, **kwargs)
        MLX::Core.add(identity, transformed)
      end
    end

    class Parallel < MLX::NN::Module
      def initialize(*modules)
        super()
        self.layers = modules
      end

      def call(*args, **kwargs)
        layers.map do |layer|
          layer.call(*args, **kwargs)
        end
      end
    end

    class Concat < MLX::NN::Module
      def initialize(*modules, axis: -1)
        super()
        self.layers = modules
        @axis = axis
      end

      def call(*args, **kwargs)
        outputs = layers.map do |layer|
          layer.call(*args, **kwargs)
        end
        MLX::Core.concatenate(outputs, @axis)
      end
    end

    class Reduce < MLX::NN::Module
      def initialize(*modules, mode: :sum)
        super()
        self.layers = modules
        @mode = mode.to_sym
      end

      def call(*args, **kwargs)
        outputs = layers.map do |layer|
          layer.call(*args, **kwargs)
        end

        case @mode
        when :sum
          outputs.reduce do |acc, item|
            MLX::Core.add(acc, item)
          end
        else
          raise ArgumentError, "unsupported reduce mode: #{@mode.inspect}"
        end
      end
    end
  end
end
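Finally, a sketch of how these graph combinators might compose. `MLX::NN::Linear` and `MLX::Core.ones` are assumptions here (the gem mirrors Python's `mlx.nn` / `mlx.core`, but only `MLX::Core.add` and `MLX::Core.concatenate` appear in this hunk):

require "mlx"

double = MLX::DSL::Callable.new { |x| MLX::Core.add(x, x) }  # wrap a plain block as a module
skip   = MLX::DSL::Residual.new(double)                      # computes x + double(x)

branches = MLX::DSL::Concat.new(
  MLX::NN::Linear.new(16, 8),   # assumed layer class
  MLX::NN::Linear.new(16, 8),
  axis: -1                      # branch outputs joined on the last axis
)

summed = MLX::DSL::Reduce.new(
  MLX::NN::Linear.new(16, 16),
  MLX::NN::Linear.new(16, 16),
  mode: :sum                    # elementwise sum of the branch outputs
)

x = MLX::Core.ones([4, 16])     # assumed constructor
skip.call(x)                    # => 3 * x elementwise (x plus its double)
branches.call(x)                # => shape [4, 16]: two [4, 8] outputs concatenated
summed.call(x)                  # => shape [4, 16]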