mlx 0.30.7 → 0.30.7.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,284 @@
1
+ # frozen_string_literal: true
2
+
3
+ module MLX
4
+ module DSL
5
+ module Data
6
+ def self.from(source = nil, &block)
7
+ if !source.nil? && block_given?
8
+ raise ArgumentError, "data pipeline source accepts either a source or block, not both"
9
+ end
10
+
11
+ producer = block_given? ? block : source
12
+ if producer.nil?
13
+ raise ArgumentError, "data pipeline requires a source enumerable or source block"
14
+ end
15
+
16
+ Pipeline.new(__dsl_factory_for(producer))
17
+ end
18
+
19
+ def self.pipeline(source = nil, &block)
20
+ from(source, &block)
21
+ end
22
+
23
+ def self.__dsl_factory_for(producer)
24
+ if producer.respond_to?(:call)
25
+ lambda do
26
+ __dsl_to_enumerator(producer.call)
27
+ end
28
+ else
29
+ lambda do
30
+ if producer.respond_to?(:rewind)
31
+ begin
32
+ producer.rewind
33
+ rescue StandardError
34
+ # Keep default Enumerable semantics when rewind is unavailable at runtime.
35
+ end
36
+ end
37
+ __dsl_to_enumerator(producer)
38
+ end
39
+ end
40
+ end
41
+ private_class_method :__dsl_factory_for
42
+
43
+ def self.__dsl_to_enumerator(value)
44
+ unless value.respond_to?(:each)
45
+ raise ArgumentError, "data pipeline source must respond to #each"
46
+ end
47
+
48
+ value.to_enum
49
+ end
50
+ private_class_method :__dsl_to_enumerator
51
+
52
+ class Pipeline
53
+ include Enumerable
54
+
55
+ def initialize(factory)
56
+ @factory = factory
57
+ end
58
+
59
+ def each
60
+ enum = @factory.call
61
+ return enum unless block_given?
62
+
63
+ enum.each { |item| yield item }
64
+ end
65
+
66
+ def map(&block)
67
+ raise ArgumentError, "pipeline map requires a block" unless block_given?
68
+
69
+ self.class.new(lambda {
70
+ upstream = @factory.call
71
+ Enumerator.new do |y|
72
+ index = 0
73
+ upstream.each do |item|
74
+ y << __dsl_call_with_context(block, item, index, "pipeline map")
75
+ index += 1
76
+ end
77
+ end
78
+ })
79
+ end
80
+
81
+ def filter(&block)
82
+ raise ArgumentError, "pipeline filter requires a block" unless block_given?
83
+
84
+ self.class.new(lambda {
85
+ upstream = @factory.call
86
+ Enumerator.new do |y|
87
+ index = 0
88
+ upstream.each do |item|
89
+ y << item if __dsl_call_with_context(block, item, index, "pipeline filter")
90
+ index += 1
91
+ end
92
+ end
93
+ })
94
+ end
95
+
96
+ def batch(size, drop_last: false)
97
+ batch_size = size.to_i
98
+ raise ArgumentError, "pipeline batch size must be positive" if batch_size <= 0
99
+
100
+ self.class.new(lambda {
101
+ upstream = @factory.call
102
+ Enumerator.new do |y|
103
+ chunk = []
104
+ upstream.each do |item|
105
+ chunk << item
106
+ if chunk.length == batch_size
107
+ y << chunk
108
+ chunk = []
109
+ end
110
+ end
111
+ y << chunk unless drop_last || chunk.empty?
112
+ end
113
+ })
114
+ end
115
+
116
+ def take(count)
117
+ limit = count.to_i
118
+ raise ArgumentError, "pipeline take count must be non-negative" if limit.negative?
119
+
120
+ self.class.new(lambda {
121
+ upstream = @factory.call
122
+ Enumerator.new do |y|
123
+ seen = 0
124
+ while seen < limit
125
+ begin
126
+ y << upstream.next
127
+ seen += 1
128
+ rescue StopIteration
129
+ break
130
+ end
131
+ end
132
+ end
133
+ })
134
+ end
135
+
136
+ def repeat(times = nil)
137
+ if times.nil?
138
+ self.class.new(lambda {
139
+ Enumerator.new do |y|
140
+ loop do
141
+ upstream = @factory.call
142
+ produced = false
143
+ upstream.each do |item|
144
+ produced = true
145
+ y << item
146
+ end
147
+ break unless produced
148
+ end
149
+ end
150
+ })
151
+ else
152
+ cycles = times.to_i
153
+ raise ArgumentError, "pipeline repeat count must be non-negative" if cycles.negative?
154
+
155
+ self.class.new(lambda {
156
+ Enumerator.new do |y|
157
+ cycles.times do
158
+ @factory.call.each do |item|
159
+ y << item
160
+ end
161
+ end
162
+ end
163
+ })
164
+ end
165
+ end
166
+
167
+ def shuffle(seed: nil, random: nil)
168
+ if !seed.nil? && !random.nil?
169
+ raise ArgumentError, "pipeline shuffle accepts either seed: or random:, not both"
170
+ end
171
+
172
+ self.class.new(lambda {
173
+ items = @factory.call.to_a
174
+ rng = if !random.nil?
175
+ random
176
+ elsif !seed.nil?
177
+ Random.new(seed.to_i)
178
+ else
179
+ Random.new
180
+ end
181
+ items.shuffle(random: rng).to_enum
182
+ })
183
+ end
184
+
185
+ def prefetch(size = 1)
186
+ prefetch_size = size.to_i
187
+ raise ArgumentError, "pipeline prefetch size must be positive" if prefetch_size <= 0
188
+
189
+ self.class.new(lambda {
190
+ upstream = @factory.call
191
+ Enumerator.new do |y|
192
+ buffer = []
193
+
194
+ prefetch_size.times do
195
+ begin
196
+ buffer << upstream.next
197
+ rescue StopIteration
198
+ break
199
+ end
200
+ end
201
+
202
+ until buffer.empty?
203
+ y << buffer.shift
204
+ begin
205
+ buffer << upstream.next
206
+ rescue StopIteration
207
+ # Exhausted upstream; continue draining buffer.
208
+ end
209
+ end
210
+ end
211
+ })
212
+ end
213
+
214
+ private
215
+
216
+ def __dsl_call_with_context(callable, item, index, label)
217
+ values = {
218
+ item: item,
219
+ index: index,
220
+ pipeline: self
221
+ }
222
+ return callable.call(item, index) unless callable.respond_to?(:parameters)
223
+
224
+ params = callable.parameters
225
+ return callable.call(item, index) if params.empty?
226
+
227
+ args = __dsl_build_positional_args(
228
+ params,
229
+ values,
230
+ [[:item, item], [:index, index], [:pipeline, self]],
231
+ label
232
+ )
233
+ kwargs = __dsl_build_keyword_args(params, values, label)
234
+ return callable.call(*args) if kwargs.empty?
235
+
236
+ callable.call(*args, **kwargs)
237
+ end
238
+
239
+ def __dsl_build_positional_args(params, values, fallback_pairs, label)
240
+ queue = fallback_pairs.dup
241
+ args = []
242
+ params.each do |type, name|
243
+ next unless type == :req || type == :opt
244
+
245
+ if !name.nil? && values.key?(name)
246
+ args << values.fetch(name)
247
+ queue.reject! { |key, _value| key == name }
248
+ next
249
+ end
250
+
251
+ if queue.empty?
252
+ raise ArgumentError, "#{label} has unsupported required positional argument: #{name.inspect}" if type == :req
253
+ break
254
+ end
255
+
256
+ _key, value = queue.shift
257
+ args << value
258
+ end
259
+ args
260
+ end
261
+
262
+ def __dsl_build_keyword_args(params, values, label)
263
+ return values.dup if params.any? { |type, _name| type == :keyrest }
264
+
265
+ required_keys = params.each_with_object([]) do |(type, name), out|
266
+ out << name if type == :keyreq
267
+ end
268
+ missing = required_keys.reject { |name| values.key?(name) }
269
+ unless missing.empty?
270
+ raise ArgumentError, "#{label} requires unsupported keyword argument(s): #{missing.map(&:inspect).join(", ")}"
271
+ end
272
+
273
+ accepted_keys = params.each_with_object([]) do |(type, name), out|
274
+ out << name if type == :key || type == :keyreq
275
+ end
276
+
277
+ values.each_with_object({}) do |(name, value), out|
278
+ out[name] = value if accepted_keys.include?(name)
279
+ end
280
+ end
281
+ end
282
+ end
283
+ end
284
+ end
@@ -0,0 +1,154 @@
1
+ # frozen_string_literal: true
2
+
3
+ module MLX
4
+ module DSL
5
+ def self.experiment(name = nil, &block)
6
+ instance = Experiment.new(name: name)
7
+ instance.instance_eval(&block) if block_given?
8
+ instance
9
+ end
10
+
11
+ class Experiment
12
+ attr_reader :name
13
+
14
      # Sets up an empty experiment; all sections are filled in later via the
      # DSL methods (#model, #optimizer, #trainer, #data, #artifacts).
      #
      # @param name [Object, nil] optional experiment label (exposed via #name).
      def initialize(name: nil)
        @name = name
        @model_source = nil       # value or factory block set by #model
        @optimizer_source = nil   # value or factory block set by #optimizer
        @trainer_source = nil     # injected Trainer instance, if any
        @trainer_kwargs = {}      # kwargs forwarded to model.trainer
        @loss_block = nil         # loss block forwarded to model.trainer
        @data_config = { train: nil, validation: nil, fit: {} }
        @artifact_config = {}     # merged into fit kwargs at run time
        @last_trainer = nil       # memoized trainer built by __dsl_resolve_trainer
        @last_report = nil        # cached result of the last run(report: true)
      end
26
+
27
+ def model(value = nil, &block)
28
+ if !value.nil? && block_given?
29
+ raise ArgumentError, "model accepts either a value argument or block, not both"
30
+ end
31
+
32
+ @model_source = block_given? ? block : value
33
+ self
34
+ end
35
+
36
+ def optimizer(value = nil, &block)
37
+ if !value.nil? && block_given?
38
+ raise ArgumentError, "optimizer accepts either a value argument or block, not both"
39
+ end
40
+
41
+ @optimizer_source = block_given? ? block : value
42
+ self
43
+ end
44
+
45
+ def trainer(value = nil, **kwargs, &block)
46
+ if value.is_a?(MLX::DSL::Trainer)
47
+ if !kwargs.empty? || block_given?
48
+ raise ArgumentError, "trainer instance injection cannot be combined with trainer kwargs or loss block"
49
+ end
50
+ @trainer_source = value
51
+ return self
52
+ end
53
+
54
+ unless value.nil?
55
+ raise ArgumentError, "trainer positional argument must be an MLX::DSL::Trainer instance"
56
+ end
57
+
58
+ @trainer_source = nil
59
+ @trainer_kwargs = kwargs.dup
60
+ @loss_block = block if block_given?
61
+ self
62
+ end
63
+
64
+ def data(train: nil, validation: :__dsl_unset__, **fit_kwargs)
65
+ @data_config[:train] = train unless train.nil?
66
+ @data_config[:validation] = validation unless validation == :__dsl_unset__
67
+ @data_config[:fit].merge!(fit_kwargs)
68
+ self
69
+ end
70
+
71
+ def artifacts(**kwargs)
72
+ @artifact_config.merge!(kwargs)
73
+ self
74
+ end
75
+
76
+ def run(report: false, **overrides)
77
+ dataset, fit_kwargs = __dsl_resolve_fit_call(overrides)
78
+ active_trainer = __dsl_resolve_trainer
79
+ result = if report
80
+ active_trainer.fit_report(dataset, **fit_kwargs)
81
+ else
82
+ active_trainer.fit(dataset, **fit_kwargs)
83
+ end
84
+ @last_report = result if report
85
+ result
86
+ end
87
+
88
+ def report(**overrides)
89
+ run(report: true, **overrides)
90
+ end
91
+
92
+ def save_run_bundle(path, report: nil, config: {}, **overrides)
93
+ active_report = report
94
+ if active_report.nil?
95
+ active_report = if !@last_report.nil?
96
+ @last_report
97
+ else
98
+ self.report(**overrides)
99
+ end
100
+ end
101
+
102
+ __dsl_resolve_trainer.save_run_bundle(path, report: active_report, config: config)
103
+ end
104
+
105
+ private
106
+
107
+ def __dsl_resolve_fit_call(overrides)
108
+ fit_kwargs = @data_config.fetch(:fit).dup
109
+ fit_kwargs.merge!(@artifact_config)
110
+ fit_kwargs[:validation_data] = @data_config[:validation] if !@data_config[:validation].nil? && !fit_kwargs.key?(:validation_data)
111
+
112
+ incoming = overrides.dup
113
+ dataset = if incoming.key?(:dataset)
114
+ incoming.delete(:dataset)
115
+ elsif incoming.key?(:train)
116
+ incoming.delete(:train)
117
+ else
118
+ @data_config[:train]
119
+ end
120
+ if dataset.nil?
121
+ raise ArgumentError, "experiment run requires a train dataset via data(train:) or run(dataset:)"
122
+ end
123
+
124
+ [dataset, fit_kwargs.merge(incoming)]
125
+ end
126
+
127
+ def __dsl_resolve_trainer
128
+ return @trainer_source if @trainer_source.is_a?(MLX::DSL::Trainer)
129
+ return @last_trainer unless @last_trainer.nil?
130
+
131
+ model = __dsl_resolve_source(@model_source, "model")
132
+ optimizer = __dsl_resolve_source(@optimizer_source, "optimizer")
133
+ unless model.respond_to?(:trainer)
134
+ raise ArgumentError, "experiment model must respond to #trainer when trainer instance is not injected"
135
+ end
136
+ unless @loss_block.respond_to?(:call)
137
+ raise ArgumentError, "experiment trainer requires a loss block when trainer instance is not injected"
138
+ end
139
+
140
+ @last_trainer = model.trainer(optimizer: optimizer, **@trainer_kwargs, &@loss_block)
141
+ end
142
+
143
+ def __dsl_resolve_source(source, label)
144
+ value = source
145
+ value = value.call if value.respond_to?(:call)
146
+ if value.nil?
147
+ raise ArgumentError, "experiment #{label} section is required when trainer instance is not injected"
148
+ end
149
+
150
+ value
151
+ end
152
+ end
153
+ end
154
+ end
@@ -0,0 +1,91 @@
1
+ # frozen_string_literal: true
2
+
3
+ module MLX
4
+ module DSL
5
+ class Callable < MLX::NN::Module
6
+ def initialize(callable = nil, &block)
7
+ super()
8
+ if !callable.nil? && block_given?
9
+ raise ArgumentError, "callable layer accepts either a callable argument or block, not both"
10
+ end
11
+
12
+ @callable = callable.nil? ? block : callable
13
+ unless @callable.respond_to?(:call)
14
+ raise ArgumentError, "callable layer requires a callable argument or block"
15
+ end
16
+ end
17
+
18
+ def call(*args, **kwargs)
19
+ return @callable.call(*args) if kwargs.empty?
20
+
21
+ @callable.call(*args, **kwargs)
22
+ end
23
+ end
24
+
25
+ class Residual < MLX::NN::Module
26
+ def initialize(module_obj)
27
+ super()
28
+ self.module_obj = module_obj
29
+ end
30
+
31
+ def call(*args, **kwargs)
32
+ raise ArgumentError, "residual module expects at least one positional input" if args.empty?
33
+
34
+ identity = args[0]
35
+ transformed = module_obj.call(*args, **kwargs)
36
+ MLX::Core.add(identity, transformed)
37
+ end
38
+ end
39
+
40
+ class Parallel < MLX::NN::Module
41
+ def initialize(*modules)
42
+ super()
43
+ self.layers = modules
44
+ end
45
+
46
+ def call(*args, **kwargs)
47
+ layers.map do |layer|
48
+ layer.call(*args, **kwargs)
49
+ end
50
+ end
51
+ end
52
+
53
+ class Concat < MLX::NN::Module
54
+ def initialize(*modules, axis: -1)
55
+ super()
56
+ self.layers = modules
57
+ @axis = axis
58
+ end
59
+
60
+ def call(*args, **kwargs)
61
+ outputs = layers.map do |layer|
62
+ layer.call(*args, **kwargs)
63
+ end
64
+ MLX::Core.concatenate(outputs, @axis)
65
+ end
66
+ end
67
+
68
+ class Reduce < MLX::NN::Module
69
+ def initialize(*modules, mode: :sum)
70
+ super()
71
+ self.layers = modules
72
+ @mode = mode.to_sym
73
+ end
74
+
75
+ def call(*args, **kwargs)
76
+ outputs = layers.map do |layer|
77
+ layer.call(*args, **kwargs)
78
+ end
79
+
80
+ case @mode
81
+ when :sum
82
+ outputs.reduce do |acc, item|
83
+ MLX::Core.add(acc, item)
84
+ end
85
+ else
86
+ raise ArgumentError, "unsupported reduce mode: #{@mode.inspect}"
87
+ end
88
+ end
89
+ end
90
+ end
91
+ end
@@ -0,0 +1,9 @@
1
+ # frozen_string_literal: true
2
+
3
+ module MLX
4
+ module DSL
5
    # DSL model base class: an MLX::NN::Module with ModelMixin's helpers
    # mixed in. (The mixin's contract is defined elsewhere — see ModelMixin.)
    class Model < MLX::NN::Module
      include ModelMixin
    end
8
+ end
9
+ end