mlx 0.30.7 → 0.30.7.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/native.cpp +8 -6
- data/lib/mlx/core.rb +8 -1
- data/lib/mlx/distributed_utils/common.rb +1 -1
- data/lib/mlx/distributed_utils/config.rb +7 -4
- data/lib/mlx/distributed_utils/launch.rb +11 -3
- data/lib/mlx/dsl/attention.rb +132 -0
- data/lib/mlx/dsl/builder.rb +385 -0
- data/lib/mlx/dsl/config_schema.rb +133 -0
- data/lib/mlx/dsl/data_pipeline.rb +284 -0
- data/lib/mlx/dsl/experiment.rb +154 -0
- data/lib/mlx/dsl/generate.rb +193 -0
- data/lib/mlx/dsl/graph_modules.rb +91 -0
- data/lib/mlx/dsl/kv_cache.rb +96 -0
- data/lib/mlx/dsl/masks.rb +32 -0
- data/lib/mlx/dsl/model.rb +9 -0
- data/lib/mlx/dsl/model_mixin.rb +706 -0
- data/lib/mlx/dsl/positions.rb +35 -0
- data/lib/mlx/dsl/run_stack.rb +68 -0
- data/lib/mlx/dsl/split_plan.rb +85 -0
- data/lib/mlx/dsl/tensor.rb +126 -0
- data/lib/mlx/dsl/train_step.rb +197 -0
- data/lib/mlx/dsl/trainer.rb +2110 -0
- data/lib/mlx/dsl/transformer_block.rb +113 -0
- data/lib/mlx/dsl/weight_map.rb +140 -0
- data/lib/mlx/dsl.rb +26 -0
- data/lib/mlx/nn/layers/containers.rb +21 -4
- data/lib/mlx/version.rb +1 -1
- data/lib/mlx.rb +1 -0
- data/mlx/CMakeLists.txt +4 -16
- metadata +67 -5
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 25d582e4816d69b27713a4027534b75cd00ca72557e69681daf07146d3e79ef2
+  data.tar.gz: c010252aa355370a531fa4f3b9bf8cc729876d2f7fb9ae8b8e0d6a1eb6cb57c4
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 53e629e845342f173c04c7c6d9d976a29dd5492ae945239897d3168a586288ec958ba58753345317f301220ac5f4b91a22f97731ab799fcea5d59f3d19e48214
+  data.tar.gz: 5b04f2e63e3dcdb6a0282184a310600f4fb72e606b45e8e7a27a7b9461abef3a598afe2eb5525e66457da45b8a84a6c2f07c871c955e9cddd4308971496c7fd1
data/ext/mlx/native.cpp
CHANGED
@@ -6625,7 +6625,8 @@ static VALUE core_clear_cache(VALUE) {
 
 static VALUE core_metal_is_available(VALUE) {
   try {
-
+    const mx::Device gpu_device(mx::Device::gpu, 0);
+    return mx::is_available(gpu_device) ? Qtrue : Qfalse;
   } catch (const std::exception& error) {
     raise_std_exception(error);
    return Qnil;
@@ -6654,7 +6655,12 @@ static VALUE core_metal_stop_capture(VALUE) {
 
 static VALUE core_metal_device_info(VALUE) {
   try {
-    const
+    const mx::Device gpu_device(mx::Device::gpu, 0);
+    if (!mx::is_available(gpu_device)) {
+      rb_raise(rb_eRuntimeError, "[metal_device_info] Metal GPU device is not available");
+    }
+
+    const auto& info = mx::device_info(gpu_device);
     VALUE hash = rb_hash_new();
     for (const auto& [key, value] : info) {
       VALUE ruby_key = rb_utf8_str_new(key.c_str(), static_cast<long>(key.size()));
@@ -7884,9 +7890,6 @@ extern "C" void Init_native(void) {
       "scaled_dot_product_attention",
       RUBY_METHOD_FUNC(core_scaled_dot_product_attention),
       -1);
-  rb_define_singleton_method(
-      mCore, "scaled_dot_product_attention", RUBY_METHOD_FUNC(core_scaled_dot_product_attention), -1);
-  rb_define_singleton_method(mCore, "scaled_dot_product_attention", RUBY_METHOD_FUNC(core_scaled_dot_product_attention), -1);
   rb_define_singleton_method(mCore, "arange", RUBY_METHOD_FUNC(core_arange), -1);
   rb_define_singleton_method(mCore, "linspace", RUBY_METHOD_FUNC(core_linspace), -1);
   rb_define_singleton_method(mCore, "zeros", RUBY_METHOD_FUNC(core_zeros), -1);
@@ -8023,5 +8026,4 @@ extern "C" void Init_native(void) {
       "precompiled_cuda_kernel",
       RUBY_METHOD_FUNC(core_precompiled_cuda_kernel),
       -1);
-  rb_define_singleton_method(mCore, "precompiled_cuda_kernel", RUBY_METHOD_FUNC(core_precompiled_cuda_kernel), -1);
 }
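The net effect of these four hunks: metal_is_available now probes an explicit GPU device handle, metal_device_info raises a RuntimeError instead of failing opaquely when no Metal device exists, and accidental duplicate registrations of scaled_dot_product_attention and precompiled_cuda_kernel are dropped. A minimal usage sketch, assuming these bindings surface as singleton methods on MLX::Core the way the surrounding rb_define_singleton_method calls suggest:

    # Hedged sketch; method names inferred from the binding code above.
    if MLX::Core.metal_is_available
      MLX::Core.metal_device_info.each { |key, value| puts "#{key}: #{value}" }
    else
      # metal_device_info would raise RuntimeError here, so check availability first.
      warn "No Metal GPU available"
    end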
data/lib/mlx/core.rb
CHANGED
@@ -335,6 +335,12 @@ module MLX
     alias_method :native_export_to_dot,
                  :export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot)
 
+    %i[savez savez_compressed].each do |method_name|
+      if method_defined?(method_name) && instance_method(method_name).owner == self
+        remove_method(method_name)
+      end
+    end
+
     ARRAY_LEAF = :__mlx_array_leaf__
 
     def load(file, format = nil, return_metadata = false)
@@ -963,7 +969,8 @@ module MLX
       end
     end
 
-
+    remove_method(:eql?) if method_defined?(:eql?) && instance_method(:eql?).owner == self
+    alias_method :eql?, :==
   end
 
   class Array
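Both hunks use the same defensive pattern: instance_method(name).owner == self ensures a method is removed only when this reopening, rather than the C extension or an ancestor, defined it, so the native savez/savez_compressed and the ==-backed eql? take over without clobbering anything else. A standalone illustration of the guard, not taken from the gem:

    class Greeter
      def hello = "pure-Ruby hello"
    end

    class Greeter
      # Remove :hello only if this class itself defined it; an inherited or
      # native definition of the same name would be left in place.
      if method_defined?(:hello) && instance_method(:hello).owner == self
        remove_method(:hello)
      end
    end

    Greeter.new.respond_to?(:hello)  # => false, since nothing else defined it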
data/lib/mlx/distributed_utils/config.rb
CHANGED
@@ -8,13 +8,14 @@ require "shellwords"
 
 module MLX
   module DistributedUtils
-    SSHInfo =
+    SSHInfo = Data.define(:can_ssh, :has_sudo) do
       def to_bool
         can_ssh
       end
     end
-    ThunderboltPort =
-    ThunderboltHost =
+    ThunderboltPort = Data.define(:iface, :uuid, :connected_to)
+    ThunderboltHost = Data.define(:name, :ports)
+    CommandResult = Data.define(:stdout, :stderr, :status)
 
     class IPConfigurator
       attr_reader :ips, :hosts, :tb_hosts
@@ -509,6 +510,8 @@ module MLX
     end
 
     def config_main(argv = ARGV, runner: nil)
+      Process.warmup if Process.respond_to?(:warmup)
+
       opts = {
         verbose: false,
         hosts: "127.0.0.1",
@@ -577,7 +580,7 @@ module MLX
       return runner.call(cmd) unless runner.nil?
 
       stdout, stderr, status = Open3.capture3(*cmd)
-
+      CommandResult.new(stdout: stdout, stderr: stderr, status: status)
    end
 
    def stdout_for(result)
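The Struct-style constants become Ruby 3.2+ Data.define value objects, the Open3.capture3 wrapper now returns a CommandResult instead of a bare array, and Process.warmup (Ruby 3.3+) is guarded with respond_to? so older Rubies still run. A short sketch of how the new value objects behave (illustrative values, not from the gem):

    CommandResult = Data.define(:stdout, :stderr, :status)

    result = CommandResult.new(stdout: "ok\n", stderr: "", status: 0)
    result.stdout           # => "ok\n"
    result.frozen?          # => true; Data instances are immutable
    result.with(status: 1)  # => a new copy with one field replaced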
data/lib/mlx/distributed_utils/launch.rb
CHANGED
@@ -314,6 +314,8 @@ module MLX
     end
 
     def main(argv = ARGV)
+      Process.warmup if Process.respond_to?(:warmup)
+
       opts = {
         print_python: false,
         verbose: false,
@@ -373,12 +375,18 @@ module MLX
       opts[:env] = hostfile.envs + opts[:env]
 
       command = rest.dup
-
-
+      command_name = command.first.to_s
+      script = Pathname.new(command_name)
+      explicit_path = command_name.include?(File::SEPARATOR) || command_name.start_with?(".", "~")
+
+      if explicit_path && script.file?
         command[0] = opts[:python]
         command.insert(1, script.realpath.to_s)
-      elsif (resolved = find_executable(
+      elsif (resolved = find_executable(command_name))
         command[0] = resolved
+      elsif script.file?
+        command[0] = opts[:python]
+        command.insert(1, script.realpath.to_s)
       elsif opts[:verify_script]
         raise ArgumentError, "Invalid script or command #{command.first}"
       end
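The rewritten block resolves the launched command in a fixed order: an explicit path (one containing a separator or starting with "." or "~") runs under the configured Python, then a PATH lookup, then a bare script name in the working directory, and only then the verify_script error. A condensed standalone restatement of that order, with hypothetical helper names and a minimal PATH search standing in for the gem's find_executable (the gem raises only when verify_script is set; this sketch always raises):

    require "pathname"

    # Minimal stand-in for the gem's find_executable.
    def find_executable(name)
      ENV.fetch("PATH", "").split(File::PATH_SEPARATOR)
         .map { |dir| File.join(dir, name) }
         .find { |path| File.file?(path) && File.executable?(path) }
    end

    def resolve_command(command, python: "python3")
      name = command.first.to_s
      script = Pathname.new(name)
      explicit = name.include?(File::SEPARATOR) || name.start_with?(".", "~")

      if explicit && script.file?            # 1. explicit script path
        [python, script.realpath.to_s, *command[1..]]
      elsif (exe = find_executable(name))    # 2. executable found on PATH
        [exe, *command[1..]]
      elsif script.file?                     # 3. bare script name in the cwd
        [python, script.realpath.to_s, *command[1..]]
      else                                   # 4. otherwise reject
        raise ArgumentError, "Invalid script or command #{name}"
      end
    end

    resolve_command(["./train.py", "--epochs", "3"])
    # => ["python3", "/abs/path/train.py", "--epochs", "3"] when ./train.py exists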
data/lib/mlx/dsl/attention.rb
ADDED
@@ -0,0 +1,132 @@
+# frozen_string_literal: true
+
+module MLX
+  module DSL
+    class Attention < MLX::NN::Module
+      def initialize(
+        dims:,
+        num_heads:,
+        kv_heads: nil,
+        qkv_bias: false,
+        backend: :sdpa,
+        rope: nil,
+        cache: false
+      )
+        super()
+
+        @dims = Integer(dims)
+        @num_heads = Integer(num_heads)
+        @kv_heads = kv_heads.nil? ? @num_heads : Integer(kv_heads)
+        if (@dims % @num_heads) != 0
+          raise ArgumentError, "dims must be divisible by num_heads"
+        end
+        if (@num_heads % @kv_heads) != 0
+          raise ArgumentError, "num_heads must be divisible by kv_heads"
+        end
+
+        @head_dim = @dims / @num_heads
+        @kv_repeats = @num_heads / @kv_heads
+        @backend = backend.to_sym
+        @cache_enabled = !!cache
+        @scale = Math.sqrt(1.0 / @head_dim)
+
+        self.query_proj = MLX::NN::Linear.new(@dims, @num_heads * @head_dim, bias: qkv_bias)
+        self.key_proj = MLX::NN::Linear.new(@dims, @kv_heads * @head_dim, bias: qkv_bias)
+        self.value_proj = MLX::NN::Linear.new(@dims, @kv_heads * @head_dim, bias: qkv_bias)
+        self.out_proj = MLX::NN::Linear.new(@num_heads * @head_dim, @dims, bias: qkv_bias)
+        self.rope = __dsl_build_rope(rope)
+      end
+
+      def call(queries, keys = nil, values = nil, mask: nil, cache: nil)
+        keys ||= queries
+        values ||= keys
+        q_was_2d = queries.ndim == 2
+
+        queries = MLX::Core.expand_dims(queries, 0) if q_was_2d
+        keys = MLX::Core.expand_dims(keys, 0) if keys.ndim == 2
+        values = MLX::Core.expand_dims(values, 0) if values.ndim == 2
+
+        batch_size, q_len, = queries.shape
+
+        q = __dsl_pack_heads(query_proj.call(queries), @num_heads)
+        k = __dsl_pack_heads(key_proj.call(keys), @kv_heads)
+        v = __dsl_pack_heads(value_proj.call(values), @kv_heads)
+
+        offset = cache.nil? ? 0 : cache[0].shape[2]
+        if !rope.nil?
+          if offset.zero?
+            q = rope.call(q)
+            k = rope.call(k)
+          else
+            q = rope.call(q, offset: offset)
+            k = rope.call(k, offset: offset)
+          end
+        end
+
+        unless cache.nil?
+          key_cache, value_cache = cache
+          k = MLX::Core.concatenate([key_cache, k], 2)
+          v = MLX::Core.concatenate([value_cache, v], 2)
+        end
+        next_cache = [k, v]
+
+        k_for_attn = __dsl_repeat_kv(k)
+        v_for_attn = __dsl_repeat_kv(v)
+        out = __dsl_attention(q, k_for_attn, v_for_attn, mask)
+        out = MLX::Core.transpose(out, [0, 2, 1, 3])
+        out = MLX::Core.reshape(out, [batch_size, q_len, @num_heads * @head_dim])
+        out = out_proj.call(out)
+        out = MLX::Core.squeeze(out, 0) if q_was_2d
+
+        if @cache_enabled || !cache.nil?
+          [out, next_cache]
+        else
+          out
+        end
+      end
+
+      private
+
+      def __dsl_build_rope(config)
+        return nil if config.nil?
+
+        opts = config.transform_keys(&:to_sym)
+        rope_kwargs = {
+          traditional: opts.fetch(:traditional, false),
+          base: opts.fetch(:base, 10_000.0)
+        }
+        rope_kwargs[:scale] = opts[:scale] if opts.key?(:scale)
+        MLX::NN::RoPE.new(@head_dim, **rope_kwargs)
+      end
+
+      def __dsl_pack_heads(x, heads)
+        batch, length, = x.shape
+        x = MLX::Core.reshape(x, [batch, length, heads, @head_dim])
+        MLX::Core.transpose(x, [0, 2, 1, 3])
+      end
+
+      def __dsl_repeat_kv(x)
+        return x if @kv_repeats == 1
+
+        batch, _heads, length, dim = x.shape
+        expanded = MLX::Core.expand_dims(x, 2)
+        repeated = MLX::Core.concatenate(Array.new(@kv_repeats, expanded), 2)
+        MLX::Core.reshape(repeated, [batch, @num_heads, length, dim])
+      end
+
+      def __dsl_attention(q, k, v, mask)
+        if @backend == :sdpa && MLX::Core.respond_to?(:scaled_dot_product_attention)
+          return MLX::Core.scaled_dot_product_attention(q, k, v, @scale, mask)
+        end
+
+        scores = MLX::Core.matmul(
+          MLX::Core.multiply(q, @scale),
+          MLX::Core.transpose(k, [0, 1, 3, 2])
+        )
+        scores = MLX::Core.add(scores, mask.astype(scores.dtype)) unless mask.nil?
+        probs = MLX::Core.softmax(scores.astype(MLX::Core.float32), -1).astype(scores.dtype)
+        MLX::Core.matmul(probs, v)
+      end
+    end
+  end
+end
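A usage sketch for the new layer, assuming MLX::Core.zeros mirrors mx.zeros. With kv_heads smaller than num_heads the layer performs grouped-query attention, and with cache: true (or an explicit cache argument) call returns the output together with the updated [keys, values] pair for incremental decoding:

    attn = MLX::DSL::Attention.new(
      dims: 512, num_heads: 8, kv_heads: 2,   # grouped-query attention: 4 query heads per KV head
      rope: { base: 10_000.0 },
      cache: true
    )

    x = MLX::Core.zeros([1, 16, 512])          # [batch, seq_len, dims]
    out, cache = attn.call(x)                  # cache: true => [output, [k, v]]

    step = MLX::Core.zeros([1, 1, 512])        # one decoded token
    out, cache = attn.call(step, cache: cache) # keys/values grow along axis 2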
data/lib/mlx/dsl/builder.rb
ADDED
@@ -0,0 +1,385 @@
+# frozen_string_literal: true
+
+module MLX
+  module DSL
+    class Builder
+      def initialize(owner = nil)
+        @owner = owner
+        @collector = nil
+      end
+
+      def build(&block)
+        raise ArgumentError, "builder requires a block" unless block_given?
+
+        instance_eval(&block)
+      end
+
+      def sequential(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        push(MLX::NN::Sequential.new(*collected))
+      end
+
+      def layer(entry = nil, *args, **kwargs, &block)
+        if !entry.nil? && block_given?
+          raise ArgumentError, "layer accepts either a module entry or block, not both"
+        end
+
+        if block_given?
+          return push(MLX::DSL::Callable.new(&block))
+        end
+
+        if entry.nil?
+          raise ArgumentError, "layer requires a module entry or block"
+        end
+
+        if entry.is_a?(MLX::NN::Module)
+          __dsl_reject_layer_constructor_args!(args, kwargs, entry.class)
+          return push(entry)
+        end
+
+        if entry.is_a?(Class)
+          unless entry <= MLX::NN::Module
+            raise TypeError, "layer class must inherit from MLX::NN::Module"
+          end
+          return push(entry.new(*args, **kwargs))
+        end
+
+        if entry.respond_to?(:call)
+          __dsl_reject_layer_constructor_args!(args, kwargs, entry.class)
+          return push(MLX::DSL::Callable.new(entry))
+        end
+
+        raise TypeError, "layer requires an MLX::NN::Module instance, MLX::NN::Module class, callable, or block"
+      end
+
+      def residual(module_obj = nil, &block)
+        modules = __dsl_modules_from(module_obj.nil? ? [] : [module_obj], &block)
+        raise ArgumentError, "residual requires at least one module" if modules.empty?
+
+        target = if modules.length == 1
+          modules[0]
+        else
+          MLX::NN::Sequential.new(*modules)
+        end
+        push(MLX::DSL::Residual.new(target))
+      end
+
+      def branch(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "branch requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Parallel.new(*collected))
+      end
+
+      def concat(*modules, axis: -1, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "concat requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Concat.new(*collected, axis: axis))
+      end
+
+      def sum(*modules, &block)
+        collected = __dsl_modules_from(modules, &block)
+        raise ArgumentError, "sum requires at least one module" if collected.empty?
+
+        push(MLX::DSL::Reduce.new(*collected, mode: :sum))
+      end
+
+      def fn(callable = nil, &block)
+        push(MLX::DSL::Callable.new(callable, &block))
+      end
+      alias_method :lambda_layer, :fn
+
+      def repeat_layers(count, &block)
+        entries = __dsl_collect_repeated_entries(count, &block)
+        layers = entries.map { |entry| __dsl_normalize_module_entry(entry) }
+        layers.each { |layer| push(layer) }
+        layers
+      end
+
+      def stack(count, layer_class = nil, *args, **kwargs, &block)
+        if !layer_class.nil? && block_given?
+          raise ArgumentError, "stack accepts either a layer class or block, not both"
+        end
+
+        layers = if layer_class.nil?
+          __dsl_collect_repeated_entries(count, &block).map { |entry| __dsl_normalize_module_entry(entry) }
+        else
+          __dsl_build_class_stack_layers(count, layer_class, args, kwargs)
+        end
+        push(MLX::NN::Sequential.new(*layers))
+      end
+
+      def identity(*args, **kwargs)
+        push(MLX::NN::Identity.new(*args, **kwargs))
+      end
+
+      def embedding(*args, **kwargs)
+        push(MLX::NN::Embedding.new(*args, **kwargs))
+      end
+
+      def linear(*args, **kwargs)
+        push(MLX::NN::Linear.new(*args, **kwargs))
+      end
+
+      def bilinear(*args, **kwargs)
+        push(MLX::NN::Bilinear.new(*args, **kwargs))
+      end
+
+      def relu
+        push(MLX::NN::ReLU.new)
+      end
+
+      def relu6
+        push(MLX::NN::ReLU6.new)
+      end
+
+      def leaky_relu(*args)
+        push(MLX::NN::LeakyReLU.new(*args))
+      end
+
+      def gelu(*args, **kwargs)
+        push(MLX::NN::GELU.new(*args, **kwargs))
+      end
+
+      def tanh
+        push(MLX::NN::Tanh.new)
+      end
+
+      def sigmoid
+        push(MLX::NN::Sigmoid.new)
+      end
+
+      def dropout(*args)
+        push(MLX::NN::Dropout.new(*args))
+      end
+
+      def dropout2d(*args)
+        push(MLX::NN::Dropout2d.new(*args))
+      end
+
+      def dropout3d(*args)
+        push(MLX::NN::Dropout3d.new(*args))
+      end
+
+      def conv1d(*args, **kwargs)
+        push(MLX::NN::Conv1d.new(*args, **kwargs))
+      end
+
+      def conv2d(*args, **kwargs)
+        push(MLX::NN::Conv2d.new(*args, **kwargs))
+      end
+
+      def conv3d(*args, **kwargs)
+        push(MLX::NN::Conv3d.new(*args, **kwargs))
+      end
+
+      def conv_transpose1d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose1d.new(*args, **kwargs))
+      end
+
+      def conv_transpose2d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose2d.new(*args, **kwargs))
+      end
+
+      def conv_transpose3d(*args, **kwargs)
+        push(MLX::NN::ConvTranspose3d.new(*args, **kwargs))
+      end
+
+      def layer_norm(*args, **kwargs)
+        push(MLX::NN::LayerNorm.new(*args, **kwargs))
+      end
+
+      def rms_norm(*args, **kwargs)
+        push(MLX::NN::RMSNorm.new(*args, **kwargs))
+      end
+
+      def batch_norm(*args, **kwargs)
+        push(MLX::NN::BatchNorm.new(*args, **kwargs))
+      end
+
+      def instance_norm(*args, **kwargs)
+        push(MLX::NN::InstanceNorm.new(*args, **kwargs))
+      end
+
+      def group_norm(*args, **kwargs)
+        push(MLX::NN::GroupNorm.new(*args, **kwargs))
+      end
+
+      def max_pool2d(*args, **kwargs)
+        push(MLX::NN::MaxPool2d.new(*args, **kwargs))
+      end
+
+      def avg_pool2d(*args, **kwargs)
+        push(MLX::NN::AvgPool2d.new(*args, **kwargs))
+      end
+
+      def max_pool1d(*args, **kwargs)
+        push(MLX::NN::MaxPool1d.new(*args, **kwargs))
+      end
+
+      def avg_pool1d(*args, **kwargs)
+        push(MLX::NN::AvgPool1d.new(*args, **kwargs))
+      end
+
+      def max_pool3d(*args, **kwargs)
+        push(MLX::NN::MaxPool3d.new(*args, **kwargs))
+      end
+
+      def avg_pool3d(*args, **kwargs)
+        push(MLX::NN::AvgPool3d.new(*args, **kwargs))
+      end
+
+      def rnn(*args, **kwargs)
+        push(MLX::NN::RNN.new(*args, **kwargs))
+      end
+
+      def gru(*args, **kwargs)
+        push(MLX::NN::GRU.new(*args, **kwargs))
+      end
+
+      def lstm(*args, **kwargs)
+        push(MLX::NN::LSTM.new(*args, **kwargs))
+      end
+
+      def multi_head_attention(*args, **kwargs)
+        push(MLX::NN::MultiHeadAttention.new(*args, **kwargs))
+      end
+
+      def transformer_encoder_layer(*args, **kwargs)
+        push(MLX::NN::TransformerEncoderLayer.new(*args, **kwargs))
+      end
+
+      def transformer_encoder(*args, **kwargs)
+        push(MLX::NN::TransformerEncoder.new(*args, **kwargs))
+      end
+
+      def transformer_decoder_layer(*args, **kwargs)
+        push(MLX::NN::TransformerDecoderLayer.new(*args, **kwargs))
+      end
+
+      def transformer_decoder(*args, **kwargs)
+        push(MLX::NN::TransformerDecoder.new(*args, **kwargs))
+      end
+
+      def transformer(*args, **kwargs)
+        push(MLX::NN::Transformer.new(*args, **kwargs))
+      end
+
+      def attention(*args, **kwargs)
+        push(MLX::DSL::Attention.new(*args, **kwargs))
+      end
+
+      def transformer_block(*args, **kwargs)
+        push(MLX::DSL::TransformerBlock.new(*args, **kwargs))
+      end
+
+      def rope(*args, **kwargs)
+        push(MLX::NN::RoPE.new(*args, **kwargs))
+      end
+
+      def sinusoidal_positional_encoding(*args, **kwargs)
+        push(MLX::NN::SinusoidalPositionalEncoding.new(*args, **kwargs))
+      end
+
+      def alibi(*args, **kwargs)
+        push(MLX::NN::ALiBi.new(*args, **kwargs))
+      end
+
+      def upsample(*args, **kwargs)
+        push(MLX::NN::Upsample.new(*args, **kwargs))
+      end
+
+      def method_missing(name, *args, **kwargs, &block)
+        if !@owner.nil? && @owner.respond_to?(name)
+          @owner.public_send(name, *args, **kwargs, &block)
+        else
+          super
+        end
+      end
+
+      def respond_to_missing?(name, include_private = false)
+        (!@owner.nil? && @owner.respond_to?(name, include_private)) || super
+      end
+
+      private
+
+      def collect_modules(&block)
+        previous = @collector
+        @collector = []
+        returned = instance_eval(&block)
+        collected = @collector.dup
+        if collected.empty? && !returned.nil?
+          collected << returned
+        end
+        collected
+      ensure
+        @collector = previous
+      end
+
+      def push(module_obj)
+        @collector << module_obj unless @collector.nil?
+        module_obj
+      end
+
+      def __dsl_modules_from(existing, &block)
+        out = existing.dup
+        out.concat(collect_modules(&block)) if block_given?
+        out.map { |entry| __dsl_normalize_module_entry(entry) }
+      end
+
+      def __dsl_normalize_module_entry(entry)
+        return entry if entry.is_a?(MLX::NN::Module)
+
+        if entry.is_a?(Class)
+          return entry.new if entry <= MLX::NN::Module
+
+          raise TypeError, "builder entries must be MLX::NN::Module instances, MLX::NN::Module classes, or callables"
+        end
+
+        return MLX::DSL::Callable.new(entry) if entry.respond_to?(:call)
+
+        raise TypeError, "builder entries must be MLX::NN::Module instances, MLX::NN::Module classes, or callables"
+      end
+
+      def __dsl_reject_layer_constructor_args!(args, kwargs, entry_type)
+        return if args.empty? && kwargs.empty?
+
+        raise ArgumentError, "layer entry #{entry_type} does not accept constructor arguments"
+      end
+
+      def __dsl_collect_repeated_entries(count, &block)
+        raise ArgumentError, "repeat requires a block" unless block_given?
+
+        repeats = count.to_i
+        raise ArgumentError, "repeat count must be non-negative" if repeats.negative?
+
+        out = []
+        repeats.times do |index|
+          out.concat(
+            collect_modules do
+              __dsl_call_repeat_block(block, index)
+            end
+          )
+        end
+        out
+      end
+
+      def __dsl_call_repeat_block(block, index)
+        return instance_eval(&block) if block.arity.zero?
+
+        block.call(index)
+      end
+
+      def __dsl_build_class_stack_layers(count, layer_class, args, kwargs)
+        repeats = count.to_i
+        raise ArgumentError, "stack count must be non-negative" if repeats.negative?
+        unless layer_class.is_a?(Class) && layer_class <= MLX::NN::Module
+          raise TypeError, "stack layer class must inherit from MLX::NN::Module"
+        end
+
+        Array.new(repeats) { layer_class.new(*args, **kwargs) }
+      end
+    end
+  end
+end
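A sketch of how the builder reads in practice. Each helper constructs a layer and pushes it into the enclosing block's collector; sequential gathers everything pushed inside its block, and build returns the block's result:

    builder = MLX::DSL::Builder.new
    model = builder.build do
      sequential do
        linear(784, 256)
        relu
        dropout(0.1)
        linear(256, 10)
      end
    end
    # model is an MLX::NN::Sequential of the four layers, in push order.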