mlx 0.30.7.2 → 0.30.7.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/mlx/native.cpp +8 -2
- data/lib/mlx/distributed_utils/common.rb +1 -1
- data/lib/mlx/distributed_utils/config.rb +7 -4
- data/lib/mlx/distributed_utils/launch.rb +2 -0
- data/lib/mlx/dsl/attention.rb +132 -0
- data/lib/mlx/dsl/builder.rb +8 -0
- data/lib/mlx/dsl/config_schema.rb +133 -0
- data/lib/mlx/dsl/generate.rb +193 -0
- data/lib/mlx/dsl/kv_cache.rb +96 -0
- data/lib/mlx/dsl/masks.rb +32 -0
- data/lib/mlx/dsl/positions.rb +35 -0
- data/lib/mlx/dsl/run_stack.rb +68 -0
- data/lib/mlx/dsl/tensor.rb +126 -0
- data/lib/mlx/dsl/transformer_block.rb +113 -0
- data/lib/mlx/dsl/weight_map.rb +140 -0
- data/lib/mlx/dsl.rb +10 -0
- data/lib/mlx/version.rb +1 -1
- metadata +57 -5
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 25d582e4816d69b27713a4027534b75cd00ca72557e69681daf07146d3e79ef2
|
|
4
|
+
data.tar.gz: c010252aa355370a531fa4f3b9bf8cc729876d2f7fb9ae8b8e0d6a1eb6cb57c4
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 53e629e845342f173c04c7c6d9d976a29dd5492ae945239897d3168a586288ec958ba58753345317f301220ac5f4b91a22f97731ab799fcea5d59f3d19e48214
|
|
7
|
+
data.tar.gz: 5b04f2e63e3dcdb6a0282184a310600f4fb72e606b45e8e7a27a7b9461abef3a598afe2eb5525e66457da45b8a84a6c2f07c871c955e9cddd4308971496c7fd1
|
data/ext/mlx/native.cpp
CHANGED
|
@@ -6625,7 +6625,8 @@ static VALUE core_clear_cache(VALUE) {
|
|
|
6625
6625
|
|
|
6626
6626
|
static VALUE core_metal_is_available(VALUE) {
|
|
6627
6627
|
try {
|
|
6628
|
-
|
|
6628
|
+
const mx::Device gpu_device(mx::Device::gpu, 0);
|
|
6629
|
+
return mx::is_available(gpu_device) ? Qtrue : Qfalse;
|
|
6629
6630
|
} catch (const std::exception& error) {
|
|
6630
6631
|
raise_std_exception(error);
|
|
6631
6632
|
return Qnil;
|
|
@@ -6654,7 +6655,12 @@ static VALUE core_metal_stop_capture(VALUE) {
|
|
|
6654
6655
|
|
|
6655
6656
|
static VALUE core_metal_device_info(VALUE) {
|
|
6656
6657
|
try {
|
|
6657
|
-
const
|
|
6658
|
+
const mx::Device gpu_device(mx::Device::gpu, 0);
|
|
6659
|
+
if (!mx::is_available(gpu_device)) {
|
|
6660
|
+
rb_raise(rb_eRuntimeError, "[metal_device_info] Metal GPU device is not available");
|
|
6661
|
+
}
|
|
6662
|
+
|
|
6663
|
+
const auto& info = mx::device_info(gpu_device);
|
|
6658
6664
|
VALUE hash = rb_hash_new();
|
|
6659
6665
|
for (const auto& [key, value] : info) {
|
|
6660
6666
|
VALUE ruby_key = rb_utf8_str_new(key.c_str(), static_cast<long>(key.size()));
|
|
@@ -8,13 +8,14 @@ require "shellwords"
|
|
|
8
8
|
|
|
9
9
|
module MLX
|
|
10
10
|
module DistributedUtils
|
|
11
|
-
SSHInfo =
|
|
11
|
+
SSHInfo = Data.define(:can_ssh, :has_sudo) do
|
|
12
12
|
def to_bool
|
|
13
13
|
can_ssh
|
|
14
14
|
end
|
|
15
15
|
end
|
|
16
|
-
ThunderboltPort =
|
|
17
|
-
ThunderboltHost =
|
|
16
|
+
ThunderboltPort = Data.define(:iface, :uuid, :connected_to)
|
|
17
|
+
ThunderboltHost = Data.define(:name, :ports)
|
|
18
|
+
CommandResult = Data.define(:stdout, :stderr, :status)
|
|
18
19
|
|
|
19
20
|
class IPConfigurator
|
|
20
21
|
attr_reader :ips, :hosts, :tb_hosts
|
|
@@ -509,6 +510,8 @@ module MLX
|
|
|
509
510
|
end
|
|
510
511
|
|
|
511
512
|
def config_main(argv = ARGV, runner: nil)
|
|
513
|
+
Process.warmup if Process.respond_to?(:warmup)
|
|
514
|
+
|
|
512
515
|
opts = {
|
|
513
516
|
verbose: false,
|
|
514
517
|
hosts: "127.0.0.1",
|
|
@@ -577,7 +580,7 @@ module MLX
|
|
|
577
580
|
return runner.call(cmd) unless runner.nil?
|
|
578
581
|
|
|
579
582
|
stdout, stderr, status = Open3.capture3(*cmd)
|
|
580
|
-
|
|
583
|
+
CommandResult.new(stdout: stdout, stderr: stderr, status: status)
|
|
581
584
|
end
|
|
582
585
|
|
|
583
586
|
def stdout_for(result)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
class Attention < MLX::NN::Module
|
|
6
|
+
def initialize(
|
|
7
|
+
dims:,
|
|
8
|
+
num_heads:,
|
|
9
|
+
kv_heads: nil,
|
|
10
|
+
qkv_bias: false,
|
|
11
|
+
backend: :sdpa,
|
|
12
|
+
rope: nil,
|
|
13
|
+
cache: false
|
|
14
|
+
)
|
|
15
|
+
super()
|
|
16
|
+
|
|
17
|
+
@dims = Integer(dims)
|
|
18
|
+
@num_heads = Integer(num_heads)
|
|
19
|
+
@kv_heads = kv_heads.nil? ? @num_heads : Integer(kv_heads)
|
|
20
|
+
if (@dims % @num_heads) != 0
|
|
21
|
+
raise ArgumentError, "dims must be divisible by num_heads"
|
|
22
|
+
end
|
|
23
|
+
if (@num_heads % @kv_heads) != 0
|
|
24
|
+
raise ArgumentError, "num_heads must be divisible by kv_heads"
|
|
25
|
+
end
|
|
26
|
+
|
|
27
|
+
@head_dim = @dims / @num_heads
|
|
28
|
+
@kv_repeats = @num_heads / @kv_heads
|
|
29
|
+
@backend = backend.to_sym
|
|
30
|
+
@cache_enabled = !!cache
|
|
31
|
+
@scale = Math.sqrt(1.0 / @head_dim)
|
|
32
|
+
|
|
33
|
+
self.query_proj = MLX::NN::Linear.new(@dims, @num_heads * @head_dim, bias: qkv_bias)
|
|
34
|
+
self.key_proj = MLX::NN::Linear.new(@dims, @kv_heads * @head_dim, bias: qkv_bias)
|
|
35
|
+
self.value_proj = MLX::NN::Linear.new(@dims, @kv_heads * @head_dim, bias: qkv_bias)
|
|
36
|
+
self.out_proj = MLX::NN::Linear.new(@num_heads * @head_dim, @dims, bias: qkv_bias)
|
|
37
|
+
self.rope = __dsl_build_rope(rope)
|
|
38
|
+
end
|
|
39
|
+
|
|
40
|
+
def call(queries, keys = nil, values = nil, mask: nil, cache: nil)
|
|
41
|
+
keys ||= queries
|
|
42
|
+
values ||= keys
|
|
43
|
+
q_was_2d = queries.ndim == 2
|
|
44
|
+
|
|
45
|
+
queries = MLX::Core.expand_dims(queries, 0) if q_was_2d
|
|
46
|
+
keys = MLX::Core.expand_dims(keys, 0) if keys.ndim == 2
|
|
47
|
+
values = MLX::Core.expand_dims(values, 0) if values.ndim == 2
|
|
48
|
+
|
|
49
|
+
batch_size, q_len, = queries.shape
|
|
50
|
+
|
|
51
|
+
q = __dsl_pack_heads(query_proj.call(queries), @num_heads)
|
|
52
|
+
k = __dsl_pack_heads(key_proj.call(keys), @kv_heads)
|
|
53
|
+
v = __dsl_pack_heads(value_proj.call(values), @kv_heads)
|
|
54
|
+
|
|
55
|
+
offset = cache.nil? ? 0 : cache[0].shape[2]
|
|
56
|
+
if !rope.nil?
|
|
57
|
+
if offset.zero?
|
|
58
|
+
q = rope.call(q)
|
|
59
|
+
k = rope.call(k)
|
|
60
|
+
else
|
|
61
|
+
q = rope.call(q, offset: offset)
|
|
62
|
+
k = rope.call(k, offset: offset)
|
|
63
|
+
end
|
|
64
|
+
end
|
|
65
|
+
|
|
66
|
+
unless cache.nil?
|
|
67
|
+
key_cache, value_cache = cache
|
|
68
|
+
k = MLX::Core.concatenate([key_cache, k], 2)
|
|
69
|
+
v = MLX::Core.concatenate([value_cache, v], 2)
|
|
70
|
+
end
|
|
71
|
+
next_cache = [k, v]
|
|
72
|
+
|
|
73
|
+
k_for_attn = __dsl_repeat_kv(k)
|
|
74
|
+
v_for_attn = __dsl_repeat_kv(v)
|
|
75
|
+
out = __dsl_attention(q, k_for_attn, v_for_attn, mask)
|
|
76
|
+
out = MLX::Core.transpose(out, [0, 2, 1, 3])
|
|
77
|
+
out = MLX::Core.reshape(out, [batch_size, q_len, @num_heads * @head_dim])
|
|
78
|
+
out = out_proj.call(out)
|
|
79
|
+
out = MLX::Core.squeeze(out, 0) if q_was_2d
|
|
80
|
+
|
|
81
|
+
if @cache_enabled || !cache.nil?
|
|
82
|
+
[out, next_cache]
|
|
83
|
+
else
|
|
84
|
+
out
|
|
85
|
+
end
|
|
86
|
+
end
|
|
87
|
+
|
|
88
|
+
private
|
|
89
|
+
|
|
90
|
+
def __dsl_build_rope(config)
|
|
91
|
+
return nil if config.nil?
|
|
92
|
+
|
|
93
|
+
opts = config.transform_keys(&:to_sym)
|
|
94
|
+
rope_kwargs = {
|
|
95
|
+
traditional: opts.fetch(:traditional, false),
|
|
96
|
+
base: opts.fetch(:base, 10_000.0)
|
|
97
|
+
}
|
|
98
|
+
rope_kwargs[:scale] = opts[:scale] if opts.key?(:scale)
|
|
99
|
+
MLX::NN::RoPE.new(@head_dim, **rope_kwargs)
|
|
100
|
+
end
|
|
101
|
+
|
|
102
|
+
def __dsl_pack_heads(x, heads)
|
|
103
|
+
batch, length, = x.shape
|
|
104
|
+
x = MLX::Core.reshape(x, [batch, length, heads, @head_dim])
|
|
105
|
+
MLX::Core.transpose(x, [0, 2, 1, 3])
|
|
106
|
+
end
|
|
107
|
+
|
|
108
|
+
def __dsl_repeat_kv(x)
|
|
109
|
+
return x if @kv_repeats == 1
|
|
110
|
+
|
|
111
|
+
batch, _heads, length, dim = x.shape
|
|
112
|
+
expanded = MLX::Core.expand_dims(x, 2)
|
|
113
|
+
repeated = MLX::Core.concatenate(Array.new(@kv_repeats, expanded), 2)
|
|
114
|
+
MLX::Core.reshape(repeated, [batch, @num_heads, length, dim])
|
|
115
|
+
end
|
|
116
|
+
|
|
117
|
+
def __dsl_attention(q, k, v, mask)
|
|
118
|
+
if @backend == :sdpa && MLX::Core.respond_to?(:scaled_dot_product_attention)
|
|
119
|
+
return MLX::Core.scaled_dot_product_attention(q, k, v, @scale, mask)
|
|
120
|
+
end
|
|
121
|
+
|
|
122
|
+
scores = MLX::Core.matmul(
|
|
123
|
+
MLX::Core.multiply(q, @scale),
|
|
124
|
+
MLX::Core.transpose(k, [0, 1, 3, 2])
|
|
125
|
+
)
|
|
126
|
+
scores = MLX::Core.add(scores, mask.astype(scores.dtype)) unless mask.nil?
|
|
127
|
+
probs = MLX::Core.softmax(scores.astype(MLX::Core.float32), -1).astype(scores.dtype)
|
|
128
|
+
MLX::Core.matmul(probs, v)
|
|
129
|
+
end
|
|
130
|
+
end
|
|
131
|
+
end
|
|
132
|
+
end
|
data/lib/mlx/dsl/builder.rb
CHANGED
|
@@ -266,6 +266,14 @@ module MLX
|
|
|
266
266
|
push(MLX::NN::Transformer.new(*args, **kwargs))
|
|
267
267
|
end
|
|
268
268
|
|
|
269
|
+
def attention(*args, **kwargs)
|
|
270
|
+
push(MLX::DSL::Attention.new(*args, **kwargs))
|
|
271
|
+
end
|
|
272
|
+
|
|
273
|
+
def transformer_block(*args, **kwargs)
|
|
274
|
+
push(MLX::DSL::TransformerBlock.new(*args, **kwargs))
|
|
275
|
+
end
|
|
276
|
+
|
|
269
277
|
def rope(*args, **kwargs)
|
|
270
278
|
push(MLX::NN::RoPE.new(*args, **kwargs))
|
|
271
279
|
end
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
module ConfigSchema
|
|
6
|
+
UNSET = Object.new.freeze
|
|
7
|
+
class DefaultContext
|
|
8
|
+
def initialize(values)
|
|
9
|
+
@values = values
|
|
10
|
+
end
|
|
11
|
+
|
|
12
|
+
def method_missing(name, *args, &block)
|
|
13
|
+
if args.empty? && block.nil? && @values.key?(name.to_sym)
|
|
14
|
+
return @values[name.to_sym]
|
|
15
|
+
end
|
|
16
|
+
|
|
17
|
+
super
|
|
18
|
+
end
|
|
19
|
+
|
|
20
|
+
def respond_to_missing?(name, include_private = false)
|
|
21
|
+
@values.key?(name.to_sym) || super
|
|
22
|
+
end
|
|
23
|
+
end
|
|
24
|
+
|
|
25
|
+
def self.included(base)
|
|
26
|
+
base.extend(ClassMethods)
|
|
27
|
+
end
|
|
28
|
+
|
|
29
|
+
module ClassMethods
|
|
30
|
+
def field(name, type = nil, required: false, default: UNSET, &validator)
|
|
31
|
+
key = name.to_sym
|
|
32
|
+
config_schema_fields[key] = {
|
|
33
|
+
type: type,
|
|
34
|
+
required: !!required,
|
|
35
|
+
default: default,
|
|
36
|
+
validator: validator
|
|
37
|
+
}
|
|
38
|
+
attr_accessor key unless method_defined?(key) && method_defined?(:"#{key}=")
|
|
39
|
+
end
|
|
40
|
+
|
|
41
|
+
def config_schema_fields
|
|
42
|
+
@config_schema_fields ||= {}
|
|
43
|
+
end
|
|
44
|
+
|
|
45
|
+
def inherited(subclass)
|
|
46
|
+
super
|
|
47
|
+
copied = config_schema_fields.each_with_object({}) do |(key, value), out|
|
|
48
|
+
out[key] = value.dup
|
|
49
|
+
end
|
|
50
|
+
subclass.instance_variable_set(:@config_schema_fields, copied)
|
|
51
|
+
end
|
|
52
|
+
|
|
53
|
+
def from_hash(raw)
|
|
54
|
+
source = (raw || {}).each_with_object({}) do |(key, value), out|
|
|
55
|
+
out[key.to_sym] = value
|
|
56
|
+
end
|
|
57
|
+
new(**source)
|
|
58
|
+
end
|
|
59
|
+
|
|
60
|
+
private
|
|
61
|
+
|
|
62
|
+
def __dsl_call_default(default, resolved)
|
|
63
|
+
context = DefaultContext.new(resolved)
|
|
64
|
+
return default unless default.respond_to?(:call)
|
|
65
|
+
return default.call(context) if default.is_a?(Proc) && default.arity == 1
|
|
66
|
+
return default.call if !default.is_a?(Proc)
|
|
67
|
+
return default.call if default.arity.zero?
|
|
68
|
+
|
|
69
|
+
default.call(context)
|
|
70
|
+
end
|
|
71
|
+
|
|
72
|
+
def __dsl_validate_field(name, value, spec)
|
|
73
|
+
type = spec.fetch(:type)
|
|
74
|
+
if !type.nil? && !value.nil?
|
|
75
|
+
allowed_types = type.is_a?(Array) ? type : [type]
|
|
76
|
+
unless allowed_types.any? { |klass| value.is_a?(klass) }
|
|
77
|
+
raise TypeError,
|
|
78
|
+
"config field #{name} must be #{allowed_types.map(&:to_s).join(' or ')}, got #{value.class}"
|
|
79
|
+
end
|
|
80
|
+
end
|
|
81
|
+
|
|
82
|
+
validator = spec.fetch(:validator)
|
|
83
|
+
unless validator.nil?
|
|
84
|
+
if validator.arity == 2
|
|
85
|
+
validator.call(value, name)
|
|
86
|
+
else
|
|
87
|
+
validator.call(value)
|
|
88
|
+
end
|
|
89
|
+
end
|
|
90
|
+
|
|
91
|
+
value
|
|
92
|
+
end
|
|
93
|
+
end
|
|
94
|
+
|
|
95
|
+
def initialize(**kwargs)
|
|
96
|
+
source = kwargs.each_with_object({}) do |(key, value), out|
|
|
97
|
+
out[key.to_sym] = value
|
|
98
|
+
end
|
|
99
|
+
resolved = {}
|
|
100
|
+
unknown = source.keys - self.class.config_schema_fields.keys
|
|
101
|
+
unless unknown.empty?
|
|
102
|
+
names = unknown.map(&:to_s).sort.join(", ")
|
|
103
|
+
raise ArgumentError, "unknown config field(s): #{names}"
|
|
104
|
+
end
|
|
105
|
+
|
|
106
|
+
self.class.config_schema_fields.each do |name, spec|
|
|
107
|
+
if source.key?(name)
|
|
108
|
+
value = source.fetch(name)
|
|
109
|
+
else
|
|
110
|
+
default = spec.fetch(:default)
|
|
111
|
+
if default.equal?(UNSET)
|
|
112
|
+
if spec.fetch(:required)
|
|
113
|
+
raise ArgumentError, "missing required config field: #{name}"
|
|
114
|
+
end
|
|
115
|
+
next
|
|
116
|
+
end
|
|
117
|
+
value = self.class.send(:__dsl_call_default, default, resolved)
|
|
118
|
+
end
|
|
119
|
+
|
|
120
|
+
value = self.class.send(:__dsl_validate_field, name, value, spec)
|
|
121
|
+
resolved[name] = value
|
|
122
|
+
public_send(:"#{name}=", value)
|
|
123
|
+
end
|
|
124
|
+
end
|
|
125
|
+
|
|
126
|
+
def to_h
|
|
127
|
+
self.class.config_schema_fields.keys.each_with_object({}) do |name, out|
|
|
128
|
+
out[name.to_s] = public_send(name)
|
|
129
|
+
end
|
|
130
|
+
end
|
|
131
|
+
end
|
|
132
|
+
end
|
|
133
|
+
end
|
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
class Generate
|
|
6
|
+
def initialize(
|
|
7
|
+
model:,
|
|
8
|
+
tokenizer: nil,
|
|
9
|
+
eos_id: nil,
|
|
10
|
+
sampler: nil,
|
|
11
|
+
mode: :decoder_only,
|
|
12
|
+
decoder_start_id: nil
|
|
13
|
+
)
|
|
14
|
+
@model = model
|
|
15
|
+
@tokenizer = tokenizer
|
|
16
|
+
@eos_id = eos_id
|
|
17
|
+
@sampler = { strategy: :argmax }.merge((sampler || {}).transform_keys(&:to_sym))
|
|
18
|
+
@mode = mode.to_sym
|
|
19
|
+
@decoder_start_id = decoder_start_id
|
|
20
|
+
end
|
|
21
|
+
|
|
22
|
+
def each_token(prompt: nil, input_ids: nil, max_tokens: 128, **kwargs)
|
|
23
|
+
return enum_for(__method__, prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) unless block_given?
|
|
24
|
+
|
|
25
|
+
case @mode
|
|
26
|
+
when :decoder_only
|
|
27
|
+
__dsl_each_decoder_only(prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) do |id, chunk|
|
|
28
|
+
yield id, chunk
|
|
29
|
+
end
|
|
30
|
+
when :encoder_decoder
|
|
31
|
+
__dsl_each_encoder_decoder(prompt: prompt, input_ids: input_ids, max_tokens: max_tokens, **kwargs) do |id, chunk|
|
|
32
|
+
yield id, chunk
|
|
33
|
+
end
|
|
34
|
+
else
|
|
35
|
+
raise ArgumentError, "unsupported generation mode: #{@mode.inspect}"
|
|
36
|
+
end
|
|
37
|
+
|
|
38
|
+
self
|
|
39
|
+
end
|
|
40
|
+
|
|
41
|
+
private
|
|
42
|
+
|
|
43
|
+
def __dsl_each_decoder_only(prompt:, input_ids:, max_tokens:, **kwargs)
|
|
44
|
+
tokens = input_ids.nil? ? __dsl_encode(prompt) : input_ids
|
|
45
|
+
model_input = __dsl_input_array(tokens)
|
|
46
|
+
logits, cache = __dsl_decode_step(model_input, cache: nil, **kwargs)
|
|
47
|
+
|
|
48
|
+
max_tokens.to_i.times do
|
|
49
|
+
token = __dsl_sample(__dsl_last_logits(logits))
|
|
50
|
+
token_id = __dsl_token_id(token)
|
|
51
|
+
chunk = __dsl_decode_token(token_id)
|
|
52
|
+
yield token_id, chunk
|
|
53
|
+
break if !@eos_id.nil? && token_id == @eos_id
|
|
54
|
+
|
|
55
|
+
next_input = MLX::Core.array([[token_id]], MLX::Core.int32)
|
|
56
|
+
logits, cache = __dsl_decode_step(next_input, cache: cache, **kwargs)
|
|
57
|
+
end
|
|
58
|
+
end
|
|
59
|
+
|
|
60
|
+
def __dsl_each_encoder_decoder(prompt:, input_ids:, max_tokens:, **kwargs)
|
|
61
|
+
tokens = input_ids.nil? ? __dsl_encode(prompt) : input_ids
|
|
62
|
+
source = __dsl_input_array(tokens)
|
|
63
|
+
|
|
64
|
+
if @model.respond_to?(:encode) && @model.respond_to?(:decode)
|
|
65
|
+
memory = @model.encode(source)
|
|
66
|
+
start_id = __dsl_decoder_start_id
|
|
67
|
+
decoder_input = MLX::Core.array([[start_id]], MLX::Core.int32)
|
|
68
|
+
cache = nil
|
|
69
|
+
|
|
70
|
+
max_tokens.to_i.times do
|
|
71
|
+
decoded = @model.decode(decoder_input, memory, cache: cache, **kwargs)
|
|
72
|
+
logits, cache = __dsl_split_logits_and_cache(decoded, cache)
|
|
73
|
+
token = __dsl_sample(__dsl_last_logits(logits))
|
|
74
|
+
token_id = __dsl_token_id(token)
|
|
75
|
+
chunk = __dsl_decode_token(token_id)
|
|
76
|
+
yield token_id, chunk
|
|
77
|
+
break if !@eos_id.nil? && token_id == @eos_id
|
|
78
|
+
|
|
79
|
+
decoder_input = MLX::Core.array([[token_id]], MLX::Core.int32)
|
|
80
|
+
end
|
|
81
|
+
return
|
|
82
|
+
end
|
|
83
|
+
|
|
84
|
+
# Fallback path for model.call-style APIs.
|
|
85
|
+
__dsl_each_decoder_only(prompt: prompt, input_ids: tokens, max_tokens: max_tokens, **kwargs) do |id, chunk|
|
|
86
|
+
yield id, chunk
|
|
87
|
+
end
|
|
88
|
+
end
|
|
89
|
+
|
|
90
|
+
def __dsl_decode_step(input_ids, cache:, **kwargs)
|
|
91
|
+
output = @model.call(input_ids, cache: cache, **kwargs)
|
|
92
|
+
__dsl_split_logits_and_cache(output, cache)
|
|
93
|
+
end
|
|
94
|
+
|
|
95
|
+
def __dsl_split_logits_and_cache(output, fallback_cache)
|
|
96
|
+
if output.is_a?(Array) && output.length == 2
|
|
97
|
+
[output[0], output[1]]
|
|
98
|
+
else
|
|
99
|
+
[output, fallback_cache]
|
|
100
|
+
end
|
|
101
|
+
end
|
|
102
|
+
|
|
103
|
+
def __dsl_last_logits(logits)
|
|
104
|
+
return logits if logits.ndim == 2
|
|
105
|
+
return logits if logits.ndim == 1
|
|
106
|
+
|
|
107
|
+
index = MLX::Core.array([logits.shape[1] - 1], MLX::Core.int32)
|
|
108
|
+
MLX::Core.squeeze(MLX::Core.take(logits, index, 1), 1)
|
|
109
|
+
end
|
|
110
|
+
|
|
111
|
+
def __dsl_sample(logits)
|
|
112
|
+
strategy = @sampler.fetch(:strategy, :argmax).to_sym
|
|
113
|
+
temperature = @sampler.fetch(:temperature, 1.0).to_f
|
|
114
|
+
|
|
115
|
+
return MLX::Core.argmax(logits, -1) if strategy == :argmax || temperature.zero?
|
|
116
|
+
|
|
117
|
+
case strategy
|
|
118
|
+
when :top_k
|
|
119
|
+
__dsl_top_k_sample(logits, k: Integer(@sampler.fetch(:k, 40)), temperature: temperature)
|
|
120
|
+
when :temperature, :categorical
|
|
121
|
+
__dsl_temperature_sample(logits, temperature: temperature)
|
|
122
|
+
else
|
|
123
|
+
raise ArgumentError, "unsupported sampler strategy: #{strategy.inspect}"
|
|
124
|
+
end
|
|
125
|
+
end
|
|
126
|
+
|
|
127
|
+
def __dsl_temperature_sample(logits, temperature:)
|
|
128
|
+
scaled = if temperature == 1.0
|
|
129
|
+
logits
|
|
130
|
+
else
|
|
131
|
+
MLX::Core.multiply(logits, 1.0 / temperature)
|
|
132
|
+
end
|
|
133
|
+
MLX::Core.categorical(scaled)
|
|
134
|
+
end
|
|
135
|
+
|
|
136
|
+
def __dsl_top_k_sample(logits, k:, temperature:)
|
|
137
|
+
rows = logits.ndim == 1 ? [logits.to_a] : logits.to_a
|
|
138
|
+
masked = rows.map do |row|
|
|
139
|
+
pairs = row.each_with_index.sort_by { |(value, _index)| -value }
|
|
140
|
+
keep = pairs.first([k, row.length].min).map(&:last)
|
|
141
|
+
filtered = Array.new(row.length, -Float::INFINITY)
|
|
142
|
+
keep.each { |idx| filtered[idx] = row[idx] }
|
|
143
|
+
filtered
|
|
144
|
+
end
|
|
145
|
+
masked_logits = MLX::Core.array(masked, logits.dtype)
|
|
146
|
+
__dsl_temperature_sample(masked_logits, temperature: temperature)
|
|
147
|
+
end
|
|
148
|
+
|
|
149
|
+
def __dsl_encode(prompt)
|
|
150
|
+
raise ArgumentError, "prompt/input_ids required when tokenizer is unavailable" if @tokenizer.nil?
|
|
151
|
+
|
|
152
|
+
@tokenizer.encode(prompt.to_s)
|
|
153
|
+
end
|
|
154
|
+
|
|
155
|
+
def __dsl_input_array(tokens)
|
|
156
|
+
if tokens.is_a?(MLX::Core::Array)
|
|
157
|
+
return tokens if tokens.ndim > 1
|
|
158
|
+
|
|
159
|
+
return MLX::Core.expand_dims(tokens.astype(MLX::Core.int32), 0)
|
|
160
|
+
end
|
|
161
|
+
|
|
162
|
+
arr = tokens.to_a
|
|
163
|
+
nested = arr.empty? ? [[]] : (arr.first.is_a?(Array) ? arr : [arr])
|
|
164
|
+
MLX::Core.array(nested, MLX::Core.int32)
|
|
165
|
+
end
|
|
166
|
+
|
|
167
|
+
def __dsl_token_id(token)
|
|
168
|
+
return token.item.to_i if token.respond_to?(:item)
|
|
169
|
+
|
|
170
|
+
value = token.to_a
|
|
171
|
+
if value.is_a?(Array)
|
|
172
|
+
first = value.first
|
|
173
|
+
return first.first.to_i if first.is_a?(Array)
|
|
174
|
+
return first.to_i
|
|
175
|
+
end
|
|
176
|
+
value.to_i
|
|
177
|
+
end
|
|
178
|
+
|
|
179
|
+
def __dsl_decode_token(token_id)
|
|
180
|
+
return nil if @tokenizer.nil? || !@tokenizer.respond_to?(:decode)
|
|
181
|
+
|
|
182
|
+
@tokenizer.decode([token_id])
|
|
183
|
+
end
|
|
184
|
+
|
|
185
|
+
def __dsl_decoder_start_id
|
|
186
|
+
return @decoder_start_id unless @decoder_start_id.nil?
|
|
187
|
+
return @tokenizer.decoder_start_id if !@tokenizer.nil? && @tokenizer.respond_to?(:decoder_start_id)
|
|
188
|
+
|
|
189
|
+
raise ArgumentError, "decoder_start_id is required for encoder-decoder mode"
|
|
190
|
+
end
|
|
191
|
+
end
|
|
192
|
+
end
|
|
193
|
+
end
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
class KVCache
|
|
6
|
+
attr_reader :num_layers
|
|
7
|
+
|
|
8
|
+
def initialize(num_layers:)
|
|
9
|
+
@num_layers = Integer(num_layers)
|
|
10
|
+
raise ArgumentError, "num_layers must be non-negative" if @num_layers.negative?
|
|
11
|
+
|
|
12
|
+
@layers = Array.new(@num_layers)
|
|
13
|
+
end
|
|
14
|
+
|
|
15
|
+
def layer(index)
|
|
16
|
+
@layers.fetch(__dsl_index(index))
|
|
17
|
+
end
|
|
18
|
+
|
|
19
|
+
def []=(index, value)
|
|
20
|
+
@layers[__dsl_index(index)] = value
|
|
21
|
+
end
|
|
22
|
+
|
|
23
|
+
def offset(layer:)
|
|
24
|
+
state = self.layer(layer)
|
|
25
|
+
return 0 if state.nil?
|
|
26
|
+
|
|
27
|
+
keys, = state
|
|
28
|
+
keys.shape[2]
|
|
29
|
+
end
|
|
30
|
+
|
|
31
|
+
def append(layer:, keys:, values:)
|
|
32
|
+
idx = __dsl_index(layer)
|
|
33
|
+
current = @layers[idx]
|
|
34
|
+
if current.nil?
|
|
35
|
+
@layers[idx] = [keys, values]
|
|
36
|
+
return @layers[idx]
|
|
37
|
+
end
|
|
38
|
+
|
|
39
|
+
key_cache, value_cache = current
|
|
40
|
+
next_keys = MLX::Core.concatenate([key_cache, keys], 2)
|
|
41
|
+
next_values = MLX::Core.concatenate([value_cache, values], 2)
|
|
42
|
+
@layers[idx] = [next_keys, next_values]
|
|
43
|
+
end
|
|
44
|
+
|
|
45
|
+
def truncate!(tokens:, layer: nil)
|
|
46
|
+
keep = Integer(tokens)
|
|
47
|
+
if layer.nil?
|
|
48
|
+
@layers.each_index { |idx| __dsl_truncate_layer!(idx, keep) }
|
|
49
|
+
else
|
|
50
|
+
__dsl_truncate_layer!(__dsl_index(layer), keep)
|
|
51
|
+
end
|
|
52
|
+
self
|
|
53
|
+
end
|
|
54
|
+
|
|
55
|
+
def reset!(layer: nil)
|
|
56
|
+
if layer.nil?
|
|
57
|
+
@layers.map! { nil }
|
|
58
|
+
else
|
|
59
|
+
@layers[__dsl_index(layer)] = nil
|
|
60
|
+
end
|
|
61
|
+
self
|
|
62
|
+
end
|
|
63
|
+
|
|
64
|
+
private
|
|
65
|
+
|
|
66
|
+
def __dsl_index(index)
|
|
67
|
+
idx = Integer(index)
|
|
68
|
+
if idx.negative? || idx >= @num_layers
|
|
69
|
+
raise IndexError, "layer index #{idx} out of range (0...#{@num_layers})"
|
|
70
|
+
end
|
|
71
|
+
|
|
72
|
+
idx
|
|
73
|
+
end
|
|
74
|
+
|
|
75
|
+
def __dsl_truncate_layer!(idx, keep)
|
|
76
|
+
state = @layers[idx]
|
|
77
|
+
return if state.nil?
|
|
78
|
+
|
|
79
|
+
if keep <= 0
|
|
80
|
+
@layers[idx] = nil
|
|
81
|
+
return
|
|
82
|
+
end
|
|
83
|
+
|
|
84
|
+
keys, values = state
|
|
85
|
+
total = keys.shape[2]
|
|
86
|
+
return if keep >= total
|
|
87
|
+
|
|
88
|
+
start = total - keep
|
|
89
|
+
indices = MLX::Core.arange(start, total, 1, MLX::Core.int32)
|
|
90
|
+
trimmed_keys = MLX::Core.take(keys, indices, 2)
|
|
91
|
+
trimmed_values = MLX::Core.take(values, indices, 2)
|
|
92
|
+
@layers[idx] = [trimmed_keys, trimmed_values]
|
|
93
|
+
end
|
|
94
|
+
end
|
|
95
|
+
end
|
|
96
|
+
end
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
module Masks
|
|
6
|
+
module_function
|
|
7
|
+
|
|
8
|
+
def causal(length:, offset: 0, dtype: MLX::Core.float32)
|
|
9
|
+
length = Integer(length)
|
|
10
|
+
offset = Integer(offset)
|
|
11
|
+
raise ArgumentError, "length must be non-negative" if length.negative?
|
|
12
|
+
raise ArgumentError, "offset must be non-negative" if offset.negative?
|
|
13
|
+
|
|
14
|
+
rinds = MLX::Core.arange(0, offset + length, 1)
|
|
15
|
+
linds = if offset.zero?
|
|
16
|
+
rinds
|
|
17
|
+
else
|
|
18
|
+
MLX::Core.arange(offset, offset + length, 1)
|
|
19
|
+
end
|
|
20
|
+
lhs = MLX::Core.expand_dims(linds, 1)
|
|
21
|
+
rhs = MLX::Core.expand_dims(rinds, 0)
|
|
22
|
+
mask = MLX::Core.less(lhs, rhs).astype(dtype)
|
|
23
|
+
min_value = if MLX::Core.respond_to?(:finfo)
|
|
24
|
+
MLX::Core.finfo(dtype).min
|
|
25
|
+
else
|
|
26
|
+
-1e9
|
|
27
|
+
end
|
|
28
|
+
MLX::Core.multiply(mask, min_value)
|
|
29
|
+
end
|
|
30
|
+
end
|
|
31
|
+
end
|
|
32
|
+
end
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
module Positions
|
|
6
|
+
module_function
|
|
7
|
+
|
|
8
|
+
def ids_like(input_ids, offset: 0, dtype: nil)
|
|
9
|
+
shape = input_ids.shape
|
|
10
|
+
seq_len = shape[-1]
|
|
11
|
+
dtype ||= input_ids.respond_to?(:dtype) ? input_ids.dtype : MLX::Core.int32
|
|
12
|
+
|
|
13
|
+
base = MLX::Core.arange(offset.to_i, offset.to_i + seq_len, 1, dtype)
|
|
14
|
+
return base if shape.length == 1
|
|
15
|
+
|
|
16
|
+
reshape_dims = Array.new(shape.length, 1)
|
|
17
|
+
reshape_dims[-1] = seq_len
|
|
18
|
+
expanded = MLX::Core.reshape(base, reshape_dims)
|
|
19
|
+
MLX::Core.broadcast_to(expanded, shape)
|
|
20
|
+
end
|
|
21
|
+
|
|
22
|
+
def offset_from_cache(cache, layer: 0)
|
|
23
|
+
return 0 if cache.nil?
|
|
24
|
+
return cache.offset(layer: layer) if cache.respond_to?(:offset)
|
|
25
|
+
|
|
26
|
+
if cache.respond_to?(:[]) && !cache[layer].nil?
|
|
27
|
+
keys, = cache[layer]
|
|
28
|
+
return keys.shape[2]
|
|
29
|
+
end
|
|
30
|
+
|
|
31
|
+
0
|
|
32
|
+
end
|
|
33
|
+
end
|
|
34
|
+
end
|
|
35
|
+
end
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
def self.run_stack(layers, input, cache: nil, **kwargs)
|
|
6
|
+
modules = layers.to_a
|
|
7
|
+
if cache.is_a?(MLX::DSL::KVCache)
|
|
8
|
+
hidden = input
|
|
9
|
+
modules.each_with_index do |layer, index|
|
|
10
|
+
hidden, next_cache = __dsl_run_stack_layer(
|
|
11
|
+
layer,
|
|
12
|
+
hidden,
|
|
13
|
+
kwargs,
|
|
14
|
+
cache: cache.layer(index),
|
|
15
|
+
use_cache: true
|
|
16
|
+
)
|
|
17
|
+
cache[index] = next_cache
|
|
18
|
+
end
|
|
19
|
+
return [hidden, cache]
|
|
20
|
+
end
|
|
21
|
+
|
|
22
|
+
use_cache = !cache.nil?
|
|
23
|
+
cache_state = if use_cache
|
|
24
|
+
entries = cache.to_a
|
|
25
|
+
entries.length < modules.length ? entries + Array.new(modules.length - entries.length) : entries.dup
|
|
26
|
+
else
|
|
27
|
+
nil
|
|
28
|
+
end
|
|
29
|
+
|
|
30
|
+
hidden = input
|
|
31
|
+
modules.each_with_index do |layer, index|
|
|
32
|
+
layer_cache = use_cache ? cache_state[index] : nil
|
|
33
|
+
hidden, next_cache = __dsl_run_stack_layer(
|
|
34
|
+
layer,
|
|
35
|
+
hidden,
|
|
36
|
+
kwargs,
|
|
37
|
+
cache: layer_cache,
|
|
38
|
+
use_cache: use_cache
|
|
39
|
+
)
|
|
40
|
+
cache_state[index] = next_cache if use_cache
|
|
41
|
+
end
|
|
42
|
+
|
|
43
|
+
use_cache ? [hidden, cache_state] : hidden
|
|
44
|
+
end
|
|
45
|
+
|
|
46
|
+
def self.__dsl_run_stack_layer(layer, hidden, kwargs, cache:, use_cache:)
|
|
47
|
+
call_kwargs = kwargs.dup
|
|
48
|
+
call_kwargs[:cache] = cache if use_cache
|
|
49
|
+
result = layer.call(hidden, **call_kwargs)
|
|
50
|
+
|
|
51
|
+
if use_cache && result.is_a?(Array) && result.length == 2
|
|
52
|
+
[result[0], result[1]]
|
|
53
|
+
else
|
|
54
|
+
[result, cache]
|
|
55
|
+
end
|
|
56
|
+
rescue ArgumentError => e
|
|
57
|
+
if use_cache && e.message.include?("unknown keyword: :cache")
|
|
58
|
+
result = layer.call(hidden, **kwargs)
|
|
59
|
+
if result.is_a?(Array) && result.length == 2
|
|
60
|
+
return [result[0], result[1]]
|
|
61
|
+
end
|
|
62
|
+
return [result, cache]
|
|
63
|
+
end
|
|
64
|
+
raise
|
|
65
|
+
end
|
|
66
|
+
private_class_method :__dsl_run_stack_layer
|
|
67
|
+
end
|
|
68
|
+
end
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
module Tensor
|
|
6
|
+
module_function
|
|
7
|
+
|
|
8
|
+
def scatter_rows(base:, row_indices:, values:, axis: nil)
|
|
9
|
+
axis = __dsl_default_scatter_axis(base, axis)
|
|
10
|
+
row_indices = __dsl_to_index_array(row_indices)
|
|
11
|
+
values = __dsl_to_array(values, dtype: base.dtype)
|
|
12
|
+
|
|
13
|
+
if base.ndim != values.ndim
|
|
14
|
+
raise ArgumentError, "base and values must have the same rank"
|
|
15
|
+
end
|
|
16
|
+
|
|
17
|
+
unless base.shape.each_with_index.all? { |dim, idx| idx == axis || dim == values.shape[idx] }
|
|
18
|
+
raise ArgumentError, "values shape must match base shape except along axis #{axis}"
|
|
19
|
+
end
|
|
20
|
+
|
|
21
|
+
indices = __dsl_expand_indices_for_axis(
|
|
22
|
+
row_indices: row_indices,
|
|
23
|
+
values_shape: values.shape,
|
|
24
|
+
axis: axis
|
|
25
|
+
)
|
|
26
|
+
MLX::Core.put_along_axis(base, indices, values, axis)
|
|
27
|
+
end
|
|
28
|
+
|
|
29
|
+
def where_labels(base:, labels:, mapping:, mode: :add_or_replace)
|
|
30
|
+
mode = mode.to_sym
|
|
31
|
+
unless [:add_or_replace, :replace].include?(mode)
|
|
32
|
+
raise ArgumentError, "mode must be :add_or_replace or :replace"
|
|
33
|
+
end
|
|
34
|
+
|
|
35
|
+
out = base
|
|
36
|
+
mapping.each do |label_value, mapped_value|
|
|
37
|
+
mask = MLX::Core.equal(labels, label_value)
|
|
38
|
+
mask = __dsl_expand_trailing_dims(mask, out.ndim)
|
|
39
|
+
mapped = __dsl_broadcast_mapping(mapped_value, base: out, labels_ndim: labels.ndim)
|
|
40
|
+
replacement = if mode == :replace
|
|
41
|
+
mapped
|
|
42
|
+
else
|
|
43
|
+
MLX::Core.add(out, mapped)
|
|
44
|
+
end
|
|
45
|
+
out = MLX::Core.where(mask, replacement, out)
|
|
46
|
+
end
|
|
47
|
+
out
|
|
48
|
+
end
|
|
49
|
+
|
|
50
|
+
def __dsl_default_scatter_axis(base, axis)
|
|
51
|
+
return axis.to_i unless axis.nil?
|
|
52
|
+
return 1 if base.ndim == 3
|
|
53
|
+
|
|
54
|
+
0
|
|
55
|
+
end
|
|
56
|
+
private_class_method :__dsl_default_scatter_axis
|
|
57
|
+
|
|
58
|
+
def __dsl_to_index_array(indices)
|
|
59
|
+
if indices.is_a?(MLX::Core::Array)
|
|
60
|
+
return indices.astype(MLX::Core.int32)
|
|
61
|
+
end
|
|
62
|
+
|
|
63
|
+
MLX::Core.array(indices, MLX::Core.int32)
|
|
64
|
+
end
|
|
65
|
+
private_class_method :__dsl_to_index_array
|
|
66
|
+
|
|
67
|
+
def __dsl_to_array(value, dtype:)
|
|
68
|
+
if value.is_a?(MLX::Core::Array)
|
|
69
|
+
return value.astype(dtype)
|
|
70
|
+
end
|
|
71
|
+
|
|
72
|
+
MLX::Core.array(value, dtype)
|
|
73
|
+
end
|
|
74
|
+
private_class_method :__dsl_to_array
|
|
75
|
+
|
|
76
|
+
# Normalizes +row_indices+ into an index array shaped like +values_shape+
# (the form put_along_axis expects).
#
# * Rank-1 indices must have length values_shape[axis]; they are reshaped
#   so that length sits along +axis+ and then broadcast to the full shape.
# * Indices already matching +values_shape+ pass through unchanged.
# * Anything else raises ArgumentError.
def __dsl_expand_indices_for_axis(row_indices:, values_shape:, axis:)
  if row_indices.ndim == 1
    expected = values_shape[axis]
    if row_indices.shape[0] != expected
      raise ArgumentError, "row_indices length must match values shape at axis #{axis}"
    end

    # Build a [1, ..., expected, ..., 1] shape with `expected` at `axis`
    # so broadcasting replicates the indices across every other dimension.
    reshape = Array.new(values_shape.length, 1)
    reshape[axis] = expected
    base = MLX::Core.reshape(row_indices, reshape)
    return MLX::Core.broadcast_to(base, values_shape)
  end

  if row_indices.shape == values_shape
    return row_indices
  end

  raise ArgumentError, "row_indices must be rank-1 or match values shape"
end
private_class_method :__dsl_expand_indices_for_axis
|
|
96
|
+
|
|
97
|
+
# Appends size-1 trailing dimensions to +array+ until it reaches
# +target_ndim+; returns the input unchanged if it is already wide enough.
def __dsl_expand_trailing_dims(array, target_ndim)
  expanded = array
  expanded = MLX::Core.expand_dims(expanded, expanded.ndim) until expanded.ndim >= target_ndim
  expanded
end
private_class_method :__dsl_expand_trailing_dims
|
|
105
|
+
|
|
106
|
+
# Converts a mapping value into an array broadcastable to +base.shape+
# (after casting to base's dtype). Three layouts are recognized:
#
# * exactly base.shape           — returned as-is;
# * rank-1 of length base.shape[-1] — treated as a feature vector and
#   broadcast across the leading (label) dimensions;
# * matching base's first labels_ndim dims — treated as a per-position
#   scalar and broadcast across the trailing (feature) dimensions.
#
# Any other shape falls through to a bare broadcast_to, which is expected
# to raise inside MLX when the shapes are incompatible.
def __dsl_broadcast_mapping(value, base:, labels_ndim:)
  mapped = __dsl_to_array(value, dtype: base.dtype)
  return mapped if mapped.shape == base.shape

  if mapped.ndim == 1 && mapped.shape[0] == base.shape[-1]
    # Feature vector: prepend one singleton dim per label dimension.
    reshape = Array.new(labels_ndim, 1) + [mapped.shape[0]]
    mapped = MLX::Core.reshape(mapped, reshape)
    return MLX::Core.broadcast_to(mapped, base.shape)
  end

  if mapped.ndim == labels_ndim && mapped.shape == base.shape[0...labels_ndim]
    # Per-position scalar: pad with trailing singleton dims, then broadcast.
    mapped = __dsl_expand_trailing_dims(mapped, base.ndim)
    return MLX::Core.broadcast_to(mapped, base.shape)
  end

  MLX::Core.broadcast_to(mapped, base.shape)
end
private_class_method :__dsl_broadcast_mapping
|
|
124
|
+
end
|
|
125
|
+
end
|
|
126
|
+
end
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module MLX
|
|
4
|
+
module DSL
|
|
5
|
+
# Two-layer MLP used inside transformer blocks. Supports plain GELU/ReLU
# stacks (in_proj/out_proj) and the gated SwiGLU variant
# (gate_proj/up_proj/down_proj).
class FeedForward < MLX::NN::Module
  # dims:        input/output width
  # hidden_dims: inner projection width
  # kind:        :gelu (default), :relu, or :swiglu
  # bias:        whether the Linear layers carry bias terms
  def initialize(dims:, hidden_dims:, kind: :gelu, bias: false)
    super()
    @kind = kind.to_sym
    @dims = Integer(dims)
    @hidden_dims = Integer(hidden_dims)

    if @kind == :swiglu
      # Gated layout: two parallel input projections, one silu-gated.
      self.gate_proj = MLX::NN::Linear.new(@dims, @hidden_dims, bias: bias)
      self.up_proj = MLX::NN::Linear.new(@dims, @hidden_dims, bias: bias)
      self.down_proj = MLX::NN::Linear.new(@hidden_dims, @dims, bias: bias)
    else
      # Plain layout: every non-swiglu kind shares these two layers.
      self.in_proj = MLX::NN::Linear.new(@dims, @hidden_dims, bias: bias)
      self.out_proj = MLX::NN::Linear.new(@hidden_dims, @dims, bias: bias)
    end
  end

  # Applies the MLP; any kind other than :swiglu/:relu falls back to GELU.
  def call(x)
    if @kind == :swiglu
      gated = MLX::NN.silu(gate_proj.call(x))
      down_proj.call(MLX::Core.multiply(gated, up_proj.call(x)))
    elsif @kind == :relu
      out_proj.call(MLX::NN.relu(in_proj.call(x)))
    else
      out_proj.call(MLX::NN.gelu(in_proj.call(x)))
    end
  end
end
|
|
35
|
+
|
|
36
|
+
# Pre-norm transformer block: x + Attn(Norm(x)), then h + FFN(Norm(h)).
# Construction is keyword-driven so model DSLs can build blocks
# declaratively from config hashes.
class TransformerBlock < MLX::NN::Module
  # dims:      model/embedding width
  # num_heads: attention head count
  # kv_heads:  key/value head count forwarded to Attention (nil leaves
  #            Attention's own default — presumably num_heads; confirm
  #            against attention.rb)
  # ffn_dims:  FFN hidden width override; otherwise ffn[:hidden_dims],
  #            otherwise 4 * dims
  # norm:      :rms/:rms_norm or :layer/:layer_norm (see __dsl_build_norm)
  # norm_eps:  epsilon for both norm layers
  # ffn:       optional hash with :kind, :hidden_dims, :bias
  # rope:      rotary-embedding config forwarded to Attention
  # qkv_bias:  bias flag forwarded to Attention projections
  # backend:   attention backend selector (default :sdpa)
  # cache:     when truthy, #call always returns [output, cache]
  def initialize(
    dims:,
    num_heads:,
    kv_heads: nil,
    ffn_dims: nil,
    norm: :rms,
    norm_eps: 1e-5,
    ffn: nil,
    rope: nil,
    qkv_bias: false,
    backend: :sdpa,
    cache: false
  )
    super()

    # Accept string or symbol keys in the ffn config hash.
    ffn_config = (ffn || {}).transform_keys(&:to_sym)
    ffn_kind = ffn_config.fetch(:kind, :gelu)
    # Explicit ffn_dims wins over the config hash; default is 4 * dims.
    hidden_dims = ffn_dims || ffn_config.fetch(:hidden_dims, Integer(dims) * 4)
    ffn_bias = ffn_config.fetch(:bias, false)

    self.attention_norm = __dsl_build_norm(norm, Integer(dims), norm_eps)
    self.attention = MLX::DSL::Attention.new(
      dims: dims,
      num_heads: num_heads,
      kv_heads: kv_heads,
      qkv_bias: qkv_bias,
      backend: backend,
      rope: rope,
      cache: cache
    )
    self.ffn_norm = __dsl_build_norm(norm, Integer(dims), norm_eps)
    self.feed_forward = MLX::DSL::FeedForward.new(
      dims: dims,
      hidden_dims: hidden_dims,
      kind: ffn_kind,
      bias: ffn_bias
    )

    @cache_enabled = !!cache
  end

  # Runs the block. Returns the output array, or [output, cache] when
  # caching was enabled at construction or a cache was passed in.
  # Extra keyword arguments are accepted and ignored so heterogeneous
  # layer stacks can share a single call signature.
  def call(x, mask: nil, cache: nil, **_kwargs)
    attn_input = attention_norm.call(x)
    # Self-attention: query, key and value all come from the normed input.
    attn_result = attention.call(attn_input, attn_input, attn_input, mask: mask, cache: cache)
    # Attention may return either the output alone or [output, new_cache].
    if attn_result.is_a?(Array) && attn_result.length == 2
      attn_out, next_cache = attn_result
    else
      attn_out = attn_result
      next_cache = cache
    end

    # Residual connections around attention and FFN (pre-norm layout).
    hidden = MLX::Core.add(x, attn_out)
    ffn_out = feed_forward.call(ffn_norm.call(hidden))
    output = MLX::Core.add(hidden, ffn_out)

    if @cache_enabled || !cache.nil?
      [output, next_cache]
    else
      output
    end
  end

  private

  # Builds the requested normalization layer; raises ArgumentError for
  # anything other than the layer/rms families.
  def __dsl_build_norm(kind, dims, eps)
    case kind.to_sym
    when :layer, :layer_norm
      MLX::NN::LayerNorm.new(dims, eps: eps)
    when :rms, :rms_norm
      MLX::NN::RMSNorm.new(dims, eps: eps)
    else
      raise ArgumentError, "unsupported norm kind: #{kind.inspect}"
    end
  end
end
|
|
112
|
+
end
|
|
113
|
+
end
|
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# frozen_string_literal: true

module MLX
  module DSL
    # Builds a WeightMap via an instance_eval'd block:
    #
    #   map = MLX::DSL.weight_map do
    #     strip_prefix "model."
    #     rename "transformer." => "layers."
    #   end
    #   renamed = map.apply(weights)
    def self.weight_map(&block)
      mapper = WeightMap.new
      mapper.instance_eval(&block) if block_given?
      mapper
    end

    # Declarative key/value rewriting for checkpoint weight dictionaries.
    # Rules apply to every entry in declaration order; a rule may also fan
    # one entry out into several (see #split_qkv). All builder methods
    # return self so rules can be chained.
    class WeightMap
      def initialize
        @rules = []
      end

      # Removes +prefix+ from the front of matching keys; other keys pass
      # through untouched.
      def strip_prefix(prefix)
        @rules << [:strip_prefix, prefix.to_s]
        self
      end

      # Substring rename. Accepts either a hash of src => dst pairs or a
      # single from/to pair; every occurrence inside a key is replaced.
      # Raises ArgumentError when neither form is supplied.
      def rename(from = nil, to = nil)
        if from.is_a?(Hash)
          from.each do |src, dst|
            @rules << [:rename, src.to_s, dst.to_s]
          end
          return self
        end

        if from.nil? || to.nil?
          raise ArgumentError, "rename requires either a mapping hash or from/to arguments"
        end

        @rules << [:rename, from.to_s, to.to_s]
        self
      end

      # Regexp rename: +replacement+ string or a block, forwarded to
      # String#gsub. Raises ArgumentError when the pattern is missing or
      # no replacement/block is given.
      def regex(pattern, replacement = nil, &block)
        if pattern.nil?
          raise ArgumentError, "regex requires a Regexp pattern"
        end
        if replacement.nil? && !block_given?
          raise ArgumentError, "regex requires a replacement argument or block"
        end

        @rules << [:regex, pattern, replacement, block]
        self
      end

      # Splits the tensor stored under +source+ into equal parts along
      # +axis+, emitting one entry per name in +into+ (at least two).
      def split_qkv(source, into:, axis: 0)
        names = into.to_a.map(&:to_s)
        # length < 2 already covers the empty case.
        if names.length < 2
          raise ArgumentError, "split_qkv :into must include at least two output names"
        end

        @rules << [:split, source.to_s, names, axis.to_i]
        self
      end

      # Transposes any value whose rank equals +rank+ using +order+
      # (values without a #shape are left untouched).
      def transpose_if(rank:, order:)
        @rules << [:transpose_if, rank.to_i, order.to_a]
        self
      end

      # Runs every rule over every entry of +weights+ (a Hash or any
      # key/value collection responding to #to_a) and returns the
      # rewritten Hash. Later entries overwrite earlier ones when rules
      # map them to the same key.
      def apply(weights)
        unless weights.respond_to?(:to_a)
          raise ArgumentError, "weights must be a Hash or array-like key/value collection"
        end

        weights.to_a.each_with_object({}) do |(key, value), out|
          __dsl_apply_rules_for_entry(key.to_s, value).each do |mapped_key, mapped_value|
            out[mapped_key] = mapped_value
          end
        end
      end

      private

      # Threads one entry through every rule; each rule sees the entries
      # produced by the previous one (flat_map allows fan-out).
      def __dsl_apply_rules_for_entry(key, value)
        current_entries = [[key, value]]
        @rules.each do |rule|
          current_entries = current_entries.flat_map do |curr_key, curr_value|
            __dsl_apply_rule(rule, curr_key, curr_value)
          end
        end
        current_entries
      end

      # Applies a single rule tuple to one entry, returning an array of
      # [key, value] pairs (usually one; :split emits several).
      def __dsl_apply_rule(rule, key, value)
        kind, *args = rule
        case kind
        when :strip_prefix
          prefix = args[0]
          key.start_with?(prefix) ? [[key[prefix.length..], value]] : [[key, value]]
        when :rename
          from, to = args
          [[key.gsub(from, to), value]]
        when :regex
          pattern, replacement, block = args
          if block
            [[key.gsub(pattern, &block), value]]
          else
            [[key.gsub(pattern, replacement.to_s), value]]
          end
        when :split
          source, targets, axis = args
          return [[key, value]] unless key == source

          parts = MLX::Core.split(value, targets.length, axis)
          targets.each_with_index.map { |target, index| [target, parts[index]] }
        when :transpose_if
          target_rank, order = args
          if value.respond_to?(:shape) && value.shape.length == target_rank
            [[key, MLX::Core.transpose(value, order)]]
          else
            [[key, value]]
          end
        else
          # Unknown rule kinds are ignored, passing the entry through.
          [[key, value]]
        end
      end
    end
  end
end
|
data/lib/mlx/dsl.rb
CHANGED
|
@@ -10,6 +10,16 @@ require_relative "dsl/data_pipeline"
|
|
|
10
10
|
require_relative "dsl/experiment"
|
|
11
11
|
require_relative "dsl/split_plan"
|
|
12
12
|
require_relative "dsl/builder"
|
|
13
|
+
require_relative "dsl/config_schema"
|
|
14
|
+
require_relative "dsl/weight_map"
|
|
15
|
+
require_relative "dsl/kv_cache"
|
|
16
|
+
require_relative "dsl/masks"
|
|
17
|
+
require_relative "dsl/positions"
|
|
18
|
+
require_relative "dsl/tensor"
|
|
19
|
+
require_relative "dsl/run_stack"
|
|
20
|
+
require_relative "dsl/attention"
|
|
21
|
+
require_relative "dsl/transformer_block"
|
|
22
|
+
require_relative "dsl/generate"
|
|
13
23
|
require_relative "dsl/train_step"
|
|
14
24
|
require_relative "dsl/model_mixin"
|
|
15
25
|
require_relative "dsl/model"
|
data/lib/mlx/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,15 +1,57 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: mlx
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.30.7.
|
|
4
|
+
version: 0.30.7.3
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- MLX Contributors
|
|
8
8
|
- Aleksey Skryl
|
|
9
9
|
bindir: bin
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date:
|
|
12
|
-
dependencies:
|
|
11
|
+
date: 1980-01-02 00:00:00.000000000 Z
|
|
12
|
+
dependencies:
|
|
13
|
+
- !ruby/object:Gem::Dependency
|
|
14
|
+
name: rake
|
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
|
16
|
+
requirements:
|
|
17
|
+
- - ">="
|
|
18
|
+
- !ruby/object:Gem::Version
|
|
19
|
+
version: '0'
|
|
20
|
+
type: :development
|
|
21
|
+
prerelease: false
|
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
23
|
+
requirements:
|
|
24
|
+
- - ">="
|
|
25
|
+
- !ruby/object:Gem::Version
|
|
26
|
+
version: '0'
|
|
27
|
+
- !ruby/object:Gem::Dependency
|
|
28
|
+
name: minitest
|
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
|
30
|
+
requirements:
|
|
31
|
+
- - ">="
|
|
32
|
+
- !ruby/object:Gem::Version
|
|
33
|
+
version: '0'
|
|
34
|
+
type: :development
|
|
35
|
+
prerelease: false
|
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
37
|
+
requirements:
|
|
38
|
+
- - ">="
|
|
39
|
+
- !ruby/object:Gem::Version
|
|
40
|
+
version: '0'
|
|
41
|
+
- !ruby/object:Gem::Dependency
|
|
42
|
+
name: benchmark
|
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
|
44
|
+
requirements:
|
|
45
|
+
- - ">="
|
|
46
|
+
- !ruby/object:Gem::Version
|
|
47
|
+
version: '0'
|
|
48
|
+
type: :development
|
|
49
|
+
prerelease: false
|
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
|
51
|
+
requirements:
|
|
52
|
+
- - ">="
|
|
53
|
+
- !ruby/object:Gem::Version
|
|
54
|
+
version: '0'
|
|
13
55
|
description: A Ruby wrapper for the native MLX machine learning runtime.
|
|
14
56
|
email:
|
|
15
57
|
- mlx@group.apple.com
|
|
@@ -27,15 +69,25 @@ files:
|
|
|
27
69
|
- lib/mlx/distributed_utils/config.rb
|
|
28
70
|
- lib/mlx/distributed_utils/launch.rb
|
|
29
71
|
- lib/mlx/dsl.rb
|
|
72
|
+
- lib/mlx/dsl/attention.rb
|
|
30
73
|
- lib/mlx/dsl/builder.rb
|
|
74
|
+
- lib/mlx/dsl/config_schema.rb
|
|
31
75
|
- lib/mlx/dsl/data_pipeline.rb
|
|
32
76
|
- lib/mlx/dsl/experiment.rb
|
|
77
|
+
- lib/mlx/dsl/generate.rb
|
|
33
78
|
- lib/mlx/dsl/graph_modules.rb
|
|
79
|
+
- lib/mlx/dsl/kv_cache.rb
|
|
80
|
+
- lib/mlx/dsl/masks.rb
|
|
34
81
|
- lib/mlx/dsl/model.rb
|
|
35
82
|
- lib/mlx/dsl/model_mixin.rb
|
|
83
|
+
- lib/mlx/dsl/positions.rb
|
|
84
|
+
- lib/mlx/dsl/run_stack.rb
|
|
36
85
|
- lib/mlx/dsl/split_plan.rb
|
|
86
|
+
- lib/mlx/dsl/tensor.rb
|
|
37
87
|
- lib/mlx/dsl/train_step.rb
|
|
38
88
|
- lib/mlx/dsl/trainer.rb
|
|
89
|
+
- lib/mlx/dsl/transformer_block.rb
|
|
90
|
+
- lib/mlx/dsl/weight_map.rb
|
|
39
91
|
- lib/mlx/extension.rb
|
|
40
92
|
- lib/mlx/nn.rb
|
|
41
93
|
- lib/mlx/nn/base.rb
|
|
@@ -640,14 +692,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
|
640
692
|
requirements:
|
|
641
693
|
- - ">="
|
|
642
694
|
- !ruby/object:Gem::Version
|
|
643
|
-
version: '3.
|
|
695
|
+
version: '3.3'
|
|
644
696
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
645
697
|
requirements:
|
|
646
698
|
- - ">="
|
|
647
699
|
- !ruby/object:Gem::Version
|
|
648
700
|
version: '0'
|
|
649
701
|
requirements: []
|
|
650
|
-
rubygems_version:
|
|
702
|
+
rubygems_version: 4.0.3
|
|
651
703
|
specification_version: 4
|
|
652
704
|
summary: Ruby bindings for the native MLX library
|
|
653
705
|
test_files: []
|