mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/mamba.rb
@@ -0,0 +1,301 @@
+ require_relative "activations"
+ require_relative "cache"
+
+ module MlxLm
+   module Models
+     module Mamba
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "mamba"
+         field :vocab_size
+         field :hidden_size, default: nil
+         field :intermediate_size, default: nil
+         field :state_size, default: nil
+         field :num_hidden_layers, default: nil
+         field :conv_kernel, default: nil
+         field :use_bias, default: nil
+         field :use_conv_bias, default: nil
+         field :time_step_rank, default: "auto"
+         field :tie_word_embeddings, default: true
+         field :use_bcdt_rms, default: false
+         field :mixer_rms_eps, default: 1e-6
+
+         field :d_model, default: nil
+         field :d_inner, default: nil
+         field :d_state, default: nil
+         field :n_layer, default: nil
+         field :n_layers, default: nil
+         field :d_conv, default: nil
+         field :bias, default: nil
+         field :conv_bias, default: nil
+
+         def initialize(**kwargs)
+           super
+
+           @hidden_size ||= @d_model
+           @intermediate_size ||= @d_inner
+           @state_size ||= @d_state
+           @num_hidden_layers ||= @n_layer
+           @num_hidden_layers ||= @n_layers
+           @conv_kernel ||= @d_conv
+           @use_bias = @bias if @use_bias.nil?
+           @use_conv_bias = @conv_bias if @use_conv_bias.nil?
+
+           @time_step_rank = (@hidden_size.to_f / 16.0).ceil if @time_step_rank == "auto"
+           @use_bcdt_rms = true if @model_type == "falcon_mamba"
+
+           @hidden_size ||= 768
+           @intermediate_size ||= 1536
+           @state_size ||= 16
+           @num_hidden_layers ||= 24
+           @conv_kernel ||= 4
+           @use_bias = true if @use_bias.nil?
+           @use_conv_bias = true if @use_conv_bias.nil?
+         end
+       end
+
+       class MambaBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           @hidden_size = args.hidden_size
+           @ssm_state_size = args.state_size
+           @conv_kernel_size = args.conv_kernel
+           @intermediate_size = args.intermediate_size
+           @time_step_rank = args.time_step_rank.to_i
+           @use_conv_bias = args.use_conv_bias
+           @use_bcdt_rms = args.use_bcdt_rms
+           @mixer_rms_eps = args.mixer_rms_eps
+
+           self.in_proj = MLX::NN::Linear.new(
+             @hidden_size,
+             @intermediate_size * 2,
+             bias: args.use_bias
+           )
+
+           self.conv1d = MLX::NN::Conv1d.new(
+             @intermediate_size,
+             @intermediate_size,
+             @conv_kernel_size,
+             groups: @intermediate_size,
+             bias: @use_conv_bias,
+             padding: 0
+           )
+
+           self.x_proj = MLX::NN::Linear.new(
+             @intermediate_size,
+             @time_step_rank + 2 * @ssm_state_size,
+             bias: false
+           )
+           self.dt_proj = MLX::NN::Linear.new(@time_step_rank, @intermediate_size, bias: true)
+
+           mx = MLX::Core
+           a = mx.repeat(
+             mx.arange(1.0, @ssm_state_size + 1.0, 1.0).reshape([1, @ssm_state_size]),
+             @intermediate_size,
+             0
+           )
+           self.a_log = mx.log(a)
+           self.d = mx.ones([@intermediate_size])
+
+           self.out_proj = MLX::NN::Linear.new(
+             @intermediate_size,
+             @hidden_size,
+             bias: args.use_bias
+           )
+         end
+
+         def call(x, cache)
+           if cache.nil?
+             conv_cache = nil
+             state_cache = nil
+           else
+             conv_cache = cache[0]
+             state_cache = cache[1]
+           end
+
+           output, new_conv_cache, new_state_cache = _process_sequence(x, conv_cache, state_cache)
+
+           if cache.is_a?(MlxLm::ArraysCache)
+             cache[0] = new_conv_cache
+             cache[1] = new_state_cache
+           end
+
+           output
+         end
+
+         def ssm_step(x, a, state = nil)
+           mx = MLX::Core
+
+           delta_bc = x_proj.call(x)
+           delta, b, c = mx.split(
+             delta_bc,
+             [@time_step_rank, @time_step_rank + @ssm_state_size],
+             -1
+           )
+
+           if @use_bcdt_rms
+             delta = _rms_norm(delta, eps: @mixer_rms_eps)
+             b = _rms_norm(b, eps: @mixer_rms_eps)
+             c = _rms_norm(c, eps: @mixer_rms_eps)
+           end
+
+           delta = MLX::NN.softplus(dt_proj.call(delta))
+           new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(b, 1)
+
+           unless state.nil?
+             new_state = new_state + state * mx.exp(mx.expand_dims(delta, -1) * a)
+           end
+
+           y = mx.squeeze(mx.matmul(new_state, mx.expand_dims(c, -1)), 2)
+           y = y + d * x
+
+           [y, new_state]
+         end
+
+         private
+
+         def _process_sequence(x, conv_cache, state_cache)
+           mx = MLX::Core
+
+           xz = in_proj.call(x)
+           x_part, z = mx.split(xz, 2, -1)
+
+           if conv_cache.nil?
+             x_full = mx.pad(
+               x_part,
+               [
+                 [0, 0],
+                 [@conv_kernel_size - 1, 0],
+                 [0, 0],
+               ]
+             )
+           else
+             x_full = mx.concatenate([conv_cache, x_part], 1)
+           end
+
+           conv_out = conv1d.call(x_full)
+
+           n_keep = @conv_kernel_size - 1
+           new_conv_cache = if n_keep > 0
+             split_at = x_full.shape[1] - n_keep
+             mx.split(x_full, [split_at], 1)[1]
+           else
+             mx.zeros([x_full.shape[0], 0, x_full.shape[2]], x_full.dtype)
+           end
+
+           x_part = MLX::NN.silu(conv_out)
+           a = mx.multiply(-1.0, mx.exp(a_log))
+
+           current_state = state_cache
+           ys = []
+           x_part.shape[1].times do |t|
+             x_t = _slice_step(x_part, t)
+             y_t, current_state = ssm_step(x_t, a, current_state)
+             ys << y_t
+           end
+
+           y = mx.stack(ys, 1)
+           out = out_proj.call(Activations.swiglu(z, y))
+
+           [out, new_conv_cache, current_state]
+         end
+
+         def _slice_step(array, idx)
+           mx = MLX::Core
+           tail = idx.zero? ? array : mx.split(array, [idx], 1)[1]
+           mx.squeeze(mx.split(tail, [1], 1)[0], 1)
+         end
+
+         def _rms_norm(x, eps:)
+           mx = MLX::Core
+           variance = mx.mean(mx.square(x), -1, true)
+           x * mx.rsqrt(variance + eps)
+         end
+       end
+
+       class ResidualBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.mixer = MambaBlock.new(args)
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size)
+         end
+
+         def call(x, cache)
+           mixer.call(norm.call(x), cache) + x
+         end
+       end
+
+       class MambaModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embeddings = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { ResidualBlock.new(args) }
+           self.norm_f = MLX::NN::RMSNorm.new(args.hidden_size)
+         end
+
+         def call(x, cache)
+           hidden = embeddings.call(x)
+           layer_cache = cache || [nil] * layers.length
+
+           layers.each_with_index do |layer, i|
+             hidden = layer.call(hidden, layer_cache[i])
+           end
+
+           norm_f.call(hidden)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.args = args
+           self.model_type = args.model_type
+           self.backbone = MambaModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings
+         end
+
+         def call(inputs, cache: nil)
+           hidden = backbone.call(inputs, cache)
+
+           if args.tie_word_embeddings
+             backbone.embeddings.as_linear(hidden)
+           else
+             lm_head.call(hidden)
+           end
+         end
+
+         def make_cache
+           Array.new(layers.length) { MlxLm::ArraysCache.new(2) }
+         end
+
+         def layers
+           backbone.layers
+         end
+
+         def sanitize(weights)
+           sanitized = {}
+           weights.each do |name, param|
+             current = param
+             if name.include?("conv1d.weight") && _transpose_conv_weight?(param)
+               current = MLX::Core.swapaxes(param, 1, 2)
+             end
+             sanitized[name] = current
+           end
+           sanitized
+         end
+
+         private
+
+         def _transpose_conv_weight?(param)
+           return false unless param.respond_to?(:shape)
+           return false unless param.shape.is_a?(Array)
+           return false unless param.shape.length >= 3
+
+           param.shape[-1] != 1
+         end
+       end
+
+       Models.register("mamba", Model, ModelArgs)
+     end
+   end
+ end
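For readers checking this port against the reference Mamba implementation: `MambaBlock#ssm_step` above is the standard discretized selective-SSM update, applied one token at a time by the loop in `_process_sequence`. In the notation of the Mamba paper (these symbols are not names from the gem), with $A = -\exp(\texttt{a\_log})$ and $\Delta_t = \operatorname{softplus}(\texttt{dt\_proj}(\cdot))$:

$$h_t = e^{\Delta_t A} \odot h_{t-1} + (\Delta_t \odot x_t)\,B_t^{\top}, \qquad y_t = h_t\,C_t + D \odot x_t$$

The two-slot `ArraysCache` written by `call` therefore carries exactly the recurrent state a block needs between decode steps: the trailing `conv_kernel - 1` convolution inputs and the current $h_t$.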
data/lib/mlx_lm/models/mamba2.rb
@@ -0,0 +1,292 @@
+ require_relative "activations"
+ require_relative "cache"
+ require_relative "ssm"
+
+ module MlxLm
+   module Models
+     module Mamba2
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "mamba2"
+         field :num_heads
+         field :head_dim
+         field :vocab_size
+         field :hidden_size
+         field :intermediate_size, default: nil
+         field :state_size
+         field :num_hidden_layers
+         field :layer_norm_epsilon, default: 1e-6
+         field :conv_kernel
+         field :n_groups
+         field :use_bias, default: true
+         field :use_conv_bias, default: true
+         field :tie_word_embeddings, default: true
+         field :time_step_limit, default: [0.001, 100.0]
+         field :time_step_rank, default: "auto"
+         field :ssm_state_size, default: nil
+         field :max_position_embeddings, default: 2056
+
+         def initialize(**kwargs)
+           super
+
+           @time_step_rank = (@hidden_size.to_f / 16.0).ceil if @time_step_rank == "auto"
+           @ssm_state_size ||= @state_size
+           @intermediate_size ||= @num_heads * @head_dim
+         end
+       end
+
+       class MambaRMSNormGated < MLX::NN::Module
+         def initialize(hidden_size, eps: 1e-6)
+           super()
+           @eps = eps
+           self.weight = MLX::Core.ones([hidden_size])
+         end
+
+         def call(hidden_states, gate = nil)
+           hidden_states = Activations.swiglu(gate, hidden_states) unless gate.nil?
+           MLX::Core.rms_norm(hidden_states, weight, @eps)
+         end
+       end
+
+       class Mamba2Block < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+
+           _ = layer_idx
+           @num_heads = args.num_heads
+           @hidden_size = args.hidden_size
+           @ssm_state_size = args.ssm_state_size
+           @conv_kernel_size = args.conv_kernel
+           @intermediate_size = args.num_heads * args.head_dim
+           @n_groups = args.n_groups
+           @head_dim = args.head_dim
+           @time_step_limit = args.time_step_limit
+           @heads_per_group = @num_heads / @n_groups
+
+           @conv_dim = @intermediate_size + 2 * @n_groups * @ssm_state_size
+
+           self.conv1d = MLX::NN::Conv1d.new(
+             @conv_dim,
+             @conv_dim,
+             args.conv_kernel,
+             padding: 0,
+             groups: @conv_dim,
+             bias: args.use_conv_bias
+           )
+
+           projection_size = @intermediate_size + @conv_dim + @num_heads
+           self.in_proj = MLX::NN::Linear.new(@hidden_size, projection_size, bias: args.use_bias)
+
+           mx = MLX::Core
+           self.dt_bias = mx.ones([@num_heads])
+           self.a_log = mx.log(mx.arange(1, @num_heads + 1, 1, mx.float32))
+           self.d = mx.ones([@num_heads])
+
+           self.norm = MambaRMSNormGated.new(@intermediate_size, eps: args.layer_norm_epsilon)
+           self.out_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: args.use_bias)
+         end
+
+         def call(hidden_states, mask, cache = nil)
+           mx = MLX::Core
+
+           projected = in_proj.call(hidden_states)
+           gate, conv_input, dt = mx.split(
+             projected,
+             [@intermediate_size, @intermediate_size + @conv_dim],
+             -1
+           )
+
+           conv_output = _conv(conv_input, cache, mask)
+           ssm_hidden, b, c = mx.split(
+             conv_output,
+             [@intermediate_size, @intermediate_size + @n_groups * @ssm_state_size],
+             -1
+           )
+
+           y = _ssm(ssm_hidden, b, c, dt, cache, mask: mask)
+           cache.advance(y.shape[1]) if cache
+
+           out_proj.call(norm.call(y, gate))
+         end
+
+         private
+
+         def _conv(conv_input, cache, mask)
+           mx = MLX::Core
+
+           conv_input = mx.where(mx.expand_dims(mask, -1), conv_input, 0) unless mask.nil?
+
+           if cache
+             conv_state = if cache[0].nil?
+               mx.zeros(
+                 [conv_input.shape[0], @conv_kernel_size - 1, @conv_dim],
+                 conv_input.dtype
+               )
+             else
+               cache[0]
+             end
+
+             padded_input = mx.concatenate([conv_state, conv_input], 1)
+             n_keep = @conv_kernel_size - 1
+
+             if cache.lengths
+               t = padded_input.shape[1]
+               ends = mx.clip(cache.lengths, 0, t - n_keep)
+               positions = mx.expand_dims(
+                 mx.expand_dims(ends, 1) + mx.arange(n_keep),
+                 -1
+               )
+               cache[0] = mx.take_along_axis(padded_input, positions, 1)
+             else
+               if n_keep > 0
+                 split_at = padded_input.shape[1] - n_keep
+                 cache[0] = mx.split(padded_input, [split_at], 1)[1]
+               else
+                 cache[0] = mx.zeros([padded_input.shape[0], 0, padded_input.shape[2]], padded_input.dtype)
+               end
+             end
+           else
+             padded_input = mx.pad(
+               conv_input,
+               [
+                 [0, 0],
+                 [@conv_kernel_size - 1, 0],
+                 [0, 0],
+               ]
+             )
+           end
+
+           MLX::NN.silu(conv1d.call(padded_input))
+         end
+
+         def _ssm(hidden_states, b, c, dt, cache, mask:)
+           batch_size, seq_len, = hidden_states.shape
+           hidden_states = hidden_states.reshape(
+             [batch_size, seq_len, @num_heads, @head_dim]
+           )
+           b = b.reshape([batch_size, seq_len, @n_groups, @ssm_state_size])
+           c = c.reshape([batch_size, seq_len, @n_groups, @ssm_state_size])
+
+           if cache
+             state = cache[1]
+             lengths = cache.lengths
+           else
+             state = nil
+             lengths = nil
+           end
+
+           y, state = SSM.ssm_update(
+             hidden_states,
+             a_log,
+             b,
+             c,
+             d,
+             dt,
+             dt_bias,
+             state: state,
+             time_step_limit: @time_step_limit,
+             mask: mask,
+             lengths: lengths
+           )
+
+           cache[1] = state if cache
+           y.reshape([batch_size, seq_len, @intermediate_size])
+         end
+       end
+
+       class ResidualBlock < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+           self.mixer = Mamba2Block.new(args, layer_idx)
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size)
+         end
+
+         def call(x, mask, cache = nil)
+           mixer.call(norm.call(x), mask, cache) + x
+         end
+       end
+
+       class Mamba2Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.embeddings = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { |i| ResidualBlock.new(args, i) }
+           self.norm_f = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+         end
+
+         def call(x, cache = nil)
+           hidden = embeddings.call(x)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = _create_ssm_mask(hidden, layer_cache[0])
+           layers.each_with_index do |layer, i|
+             hidden = layer.call(hidden, mask, layer_cache[i])
+           end
+
+           norm_f.call(hidden)
+         end
+
+         private
+
+         def _create_ssm_mask(hidden, cache)
+           return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask)
+
+           nil
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.args = args
+           self.model_type = args.model_type
+           self.backbone = Mamba2Model.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings
+         end
+
+         def call(inputs, cache: nil)
+           hidden = backbone.call(inputs, cache)
+
+           if args.tie_word_embeddings
+             backbone.embeddings.as_linear(hidden)
+           else
+             lm_head.call(hidden)
+           end
+         end
+
+         def make_cache(batch_size: 1)
+           _ = batch_size
+           Array.new(args.num_hidden_layers) { MlxLm::ArraysCache.new(2) }
+         end
+
+         def layers
+           backbone.layers
+         end
+
+         def sanitize(weights)
+           sanitized = {}
+           weights.each do |name, param|
+             current = param
+             if name.include?("conv1d.weight") && _transpose_conv_weight?(param)
+               current = MLX::Core.swapaxes(param, 1, 2)
+             end
+             sanitized[name] = current
+           end
+           sanitized
+         end
+
+         private
+
+         def _transpose_conv_weight?(param)
+           return false unless param.respond_to?(:shape)
+           return false unless param.shape.is_a?(Array)
+           return false unless param.shape.length >= 3
+
+           param.shape[-1] != 1
+         end
+       end
+
+       Models.register("mamba2", Model, ModelArgs)
+     end
+   end
+ end
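As a quick orientation for how these registered classes are driven, here is a minimal decode-loop sketch. Only `ModelArgs`, `Model`, `make_cache`, and `Model#call` appear in this diff; the configuration values and the `MLX::Core.array` constructor are assumptions about the surrounding mlx-ruby API, not part of this package diff.

require "mlx_lm"

# Hypothetical configuration; the field names come from ModelArgs above,
# the values are illustrative only.
args = MlxLm::Models::Mamba2::ModelArgs.new(
  vocab_size: 50_288, hidden_size: 768, num_heads: 24, head_dim: 64,
  state_size: 128, num_hidden_layers: 24, conv_kernel: 4, n_groups: 1
)
model = MlxLm::Models::Mamba2::Model.new(args)

cache  = model.make_cache                    # one two-slot ArraysCache per layer
tokens = MLX::Core.array([[1, 2, 3]])        # [batch, seq] token ids (assumed constructor)
logits = model.call(tokens, cache: cache)    # prefill: fills conv tail and SSM state
step   = model.call(MLX::Core.array([[4]]), cache: cache)  # one-token decode step

The two cache slots per layer mirror the two kinds of recurrent state a Mamba-2 block needs between steps: the last `conv_kernel - 1` convolution inputs and the per-head `[num_heads, head_dim, state_size]` SSM state, which is why `make_cache` allocates `ArraysCache.new(2)` rather than an attention-style KV cache.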