mlx-ruby-lm 0.30.7.1

This diff shows the content of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/exaone_moe.rb
@@ -0,0 +1,421 @@
+ require_relative "activations"
+ require_relative "cache"
+ require_relative "rope_utils"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module ExaoneMoe
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "exaone_moe"
+         field :vocab_size
+         field :hidden_size
+         field :intermediate_size
+         field :moe_intermediate_size
+         field :num_hidden_layers
+         field :num_attention_heads
+         field :num_key_value_heads, default: nil
+         field :head_dim, default: nil
+         field :num_experts
+         field :num_experts_per_tok
+         field :num_shared_experts
+         field :rms_norm_eps
+         field :max_position_embeddings
+         field :sliding_window
+         field :layer_types, default: nil
+         field :is_moe_layer, default: nil
+         field :n_group, default: 1
+         field :topk_group, default: 1
+         field :routed_scaling_factor, default: 2.5
+         field :norm_topk_prob, default: true
+         field :scoring_func, default: "sigmoid"
+         field :topk_method, default: "noaux_tc"
+         field :rope_theta, default: 1_000_000.0
+         field :rope_scaling, default: nil
+         field :rope_parameters, default: nil
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+           @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" }
+           @is_moe_layer ||= Array.new(@num_hidden_layers, false)
+
+           return unless @rope_parameters.respond_to?(:[])
+
+           rope_theta = @rope_parameters["rope_theta"] || @rope_parameters[:rope_theta]
+           @rope_theta = rope_theta unless rope_theta.nil?
+         end
+       end
+
+       module_function
+
+       def group_expert_select(
+         gates,
+         e_score_correction_bias,
+         top_k,
+         n_group,
+         topk_group,
+         routed_scaling_factor,
+         norm_topk_prob
+       )
+         mx = MLX::Core
+
+         scores = mx.sigmoid(gates.astype(mx.float32))
+         orig_scores = scores
+         scores = scores + e_score_correction_bias
+
+         if n_group.to_i > 1
+           experts_per_group = scores.shape[-1] / n_group
+           scores = mx.unflatten(scores, -1, [n_group, experts_per_group])
+           group_scores = mx.topk(scores, 2, -1)
+           group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1)
+
+           drop_count = n_group - topk_group.to_i
+           if drop_count > 0
+             group_idx = mx.argpartition(group_scores, drop_count - 1, -2)
+             take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32)
+             group_idx = mx.take(group_idx, take_ids, -2)
+             scores = mx.put_along_axis(
+               scores,
+               mx.stop_gradient(group_idx),
+               mx.array(0.0),
+               -2
+             )
+           end
+
+           scores = mx.flatten(scores, -2, -1)
+         end
+
+         k = [[top_k.to_i, 1].max, scores.shape[-1]].min
+         inds = mx.argpartition(scores * -1.0, k - 1, -1)
+         take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+         inds = mx.take(inds, take_ids, -1)
+
+         selected_scores = mx.take_along_axis(orig_scores, inds, -1)
+         if k > 1 && norm_topk_prob
+           denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1)
+           selected_scores = selected_scores / (denominator + 1e-20)
+         end
+
+         selected_scores = selected_scores * routed_scaling_factor.to_f
+         [inds, selected_scores]
+       end
+
+       class MoEGate < MLX::NN::Module
+         def initialize(args)
+           super()
+           @top_k = args.num_experts_per_tok
+           @norm_topk_prob = args.norm_topk_prob
+           @n_routed_experts = args.num_experts
+           @routed_scaling_factor = args.routed_scaling_factor
+           @n_group = args.n_group
+           @topk_group = args.topk_group
+
+           raise ArgumentError, "Unsupported topk method: #{args.topk_method}" unless args.topk_method == "noaux_tc"
+
+           mx = MLX::Core
+           self.weight = mx.zeros([@n_routed_experts, args.hidden_size])
+           self.e_score_correction_bias = mx.zeros([@n_routed_experts])
+         end
+
+         def call(x)
+           mx = MLX::Core
+           gates = mx.matmul(x, mx.transpose(weight))
+           ExaoneMoe.group_expert_select(
+             gates,
+             e_score_correction_bias,
+             @top_k,
+             @n_group,
+             @topk_group,
+             @routed_scaling_factor,
+             @norm_topk_prob
+           )
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args, intermediate_size: nil)
+           super()
+           hidden_size = args.hidden_size
+           intermediate_size ||= args.intermediate_size
+
+           self.gate_proj = MLX::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+           self.up_proj = MLX::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+           self.down_proj = MLX::NN::Linear.new(intermediate_size, hidden_size, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class MoE < MLX::NN::Module
+         def initialize(args)
+           super()
+           @num_shared_experts = args.num_shared_experts
+
+           self.switch_mlp = SwitchLayers::SwitchGLU.new(
+             args.hidden_size,
+             args.moe_intermediate_size,
+             args.num_experts
+           )
+           self.gate = MoEGate.new(args)
+
+           if !@num_shared_experts.nil? && @num_shared_experts > 0
+             shared_intermediate = args.moe_intermediate_size * @num_shared_experts
+             self.shared_experts = MLP.new(args, intermediate_size: shared_intermediate)
+           end
+         end
+
+         def call(x)
+           mx = MLX::Core
+           inds, scores = gate.call(x)
+           y = switch_mlp.call(x, inds)
+           y = mx.sum(y * mx.expand_dims(scores, -1), -2).astype(y.dtype)
+           y = y + shared_experts.call(x) if respond_to?(:shared_experts)
+           y
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         attr_reader :is_sliding_window
+
+         def initialize(args, layer_idx)
+           super()
+
+           @hidden_size = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(@hidden_size, @n_heads * @head_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(@hidden_size, @n_kv_heads * @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(@hidden_size, @n_kv_heads * @head_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, @hidden_size, bias: false)
+
+           self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+
+           @is_sliding_window = args.layer_types[layer_idx] == "sliding_attention"
+           apply_rope_all_layers = !args.layer_types.include?("sliding_attention")
+           @use_rope = @is_sliding_window || apply_rope_all_layers
+
+           if @use_rope
+             self.rope = MlxLm::Models.initialize_rope(
+               @head_dim,
+               args.rope_theta,
+               false,
+               args.rope_scaling,
+               max_position_embeddings: args.max_position_embeddings
+             )
+           end
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = q_norm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3])
+           keys = k_norm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             if @use_rope
+               queries = rope.call(queries, offset: cache.offset)
+               keys = rope.call(keys, offset: cache.offset)
+             end
+             keys, values = cache.update_and_fetch(keys, values)
+           elsif @use_rope
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         attr_reader :is_sliding_window
+
+         def initialize(args, layer_idx)
+           super()
+
+           self.self_attn = Attention.new(args, layer_idx)
+           self.mlp = args.is_moe_layer[layer_idx] ? MoE.new(args) : MLP.new(args)
+           @is_sliding_window = self_attn.is_sliding_window
+
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class ExaoneMoeModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           @window_size = args.sliding_window
+
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { |idx| DecoderLayer.new(args, idx) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+
+           self.swa_idx = nil
+           self.ga_idx = nil
+           layers.each_with_index do |layer, idx|
+             self.swa_idx = idx if swa_idx.nil? && layer.is_sliding_window
+             self.ga_idx = idx if ga_idx.nil? && !layer.is_sliding_window
+             break unless swa_idx.nil? || ga_idx.nil?
+           end
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           global_cache = ga_idx.nil? ? layer_cache[0] : layer_cache[ga_idx]
+           swa_cache = swa_idx.nil? ? layer_cache[0] : layer_cache[swa_idx]
+
+           global_mask = _create_attention_mask(h, global_cache)
+           swa_mask = _create_attention_mask(h, swa_cache, window_size: @window_size)
+
+           layers.each_with_index do |layer, idx|
+             mask = layer.is_sliding_window ? swa_mask : global_mask
+             h = layer.call(h, mask: mask, cache: layer_cache[idx])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil, window_size: nil)
+           n = h.shape[1]
+           if cache && cache.respond_to?(:make_mask)
+             return cache.make_mask(n, window_size: window_size)
+           end
+
+           if window_size
+             offset = 0
+             if cache
+               offset = cache.offset if cache.respond_to?(:offset)
+               if cache.instance_variable_defined?(:@max_size)
+                 max_size = cache.instance_variable_get(:@max_size)
+                 offset = [max_size - 1, offset].min if max_size && max_size > 0
+               end
+             end
+             return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
+           end
+
+           return nil if n == 1
+
+           "causal"
+         end
+
+         def _create_causal_mask(n, offset: 0, window_size: nil)
+           mx = MLX::Core
+           rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+           linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+           mask = mx.greater_equal(linds, rinds)
+           if window_size
+             mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+           end
+           mask
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = ExaoneMoeModel.new(args)
+
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           result = weights.reject { |k, _| k.start_with?("mtp.") }
+           num_experts = @args.num_experts.to_i
+
+           @args.num_hidden_layers.to_i.times do |layer_idx|
+             next unless @args.is_moe_layer[layer_idx]
+
+             prefix = "model.layers.#{layer_idx}.mlp"
+             bias_key = "#{prefix}.e_score_correction_bias"
+             if result.key?(bias_key)
+               result["#{prefix}.gate.e_score_correction_bias"] = result.delete(bias_key)
+             end
+
+             %w[gate_proj down_proj up_proj].each do |proj_name|
+               %w[weight scales biases].each do |param_name|
+                 first_key = "#{prefix}.experts.0.#{proj_name}.#{param_name}"
+                 last_key = "#{prefix}.experts.#{num_experts - 1}.#{proj_name}.#{param_name}"
+                 next unless result.key?(first_key) && result.key?(last_key)
+
+                 expert_keys = (0...num_experts).map do |expert_idx|
+                   "#{prefix}.experts.#{expert_idx}.#{proj_name}.#{param_name}"
+                 end
+                 next unless expert_keys.all? { |key| result.key?(key) }
+
+                 stacked = expert_keys.map { |key| result.delete(key) }
+                 result["#{prefix}.switch_mlp.#{proj_name}.#{param_name}"] = mx.stack(stacked)
+               end
+             end
+           end
+
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result
+         end
+
+         def layers
+           model.layers
+         end
+
+         def cast_predicate
+           lambda { |key| !key.include?("e_score_correction_bias") }
+         end
+
+         def make_cache
+           max_window = @args.sliding_window || @args.max_position_embeddings || 1
+           layers.map do |layer|
+             if layer.is_sliding_window
+               RotatingKVCache.new(max_size: max_window, keep: 0)
+             else
+               KVCache.new
+             end
+           end
+         end
+       end
+
+       Models.register("exaone_moe", Model, ModelArgs)
+     end
+   end
+ end
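
Note on the routing above: ExaoneMoe.group_expert_select implements the DeepSeek-style "noaux_tc" scheme. Experts are scored with a sigmoid plus a learned correction bias, split into n_group groups, whole low-scoring groups are masked out (each group is ranked by the sum of its top-2 biased scores), and the global top-k experts are then taken from the survivors, with their pre-bias scores renormalized and scaled by routed_scaling_factor. A minimal shape-check sketch of a direct call, assuming the MLX::Core Ruby bindings behave exactly as used in the file above; the toy sizes and zero-filled inputs are illustrative only and not part of the gem:

  mx = MLX::Core
  gates = mx.zeros([1, 1, 8])   # (batch, tokens, num_experts) router logits -- toy values
  bias  = mx.zeros([8])         # stands in for e_score_correction_bias
  inds, scores = MlxLm::Models::ExaoneMoe.group_expert_select(
    gates, bias,
    2,    # top_k: experts kept per token
    2,    # n_group: the 8 experts are viewed as 2 groups of 4
    1,    # topk_group: only the best-scoring group survives
    2.5,  # routed_scaling_factor applied to the kept weights
    true  # norm_topk_prob: renormalize the kept weights to sum to 1
  )
  # inds -> (1, 1, 2) expert indices; scores -> (1, 1, 2) routing weights

In the model itself this function is only reached through MoEGate#call, which supplies the learned gate weight matrix and bias.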
data/lib/mlx_lm/models/falcon_h1.rb
@@ -0,0 +1,102 @@
+ require_relative "recurrent_gemma"
+
+ module MlxLm
+   module Models
+     module FalconH1
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "falcon_h1"
+         field :attention_bias, default: false
+         field :head_dim, default: 64
+         field :hidden_size, default: 1024
+         field :intermediate_size, default: 2048
+         field :max_position_embeddings, default: 131_072
+         field :mamba_d_conv, default: 4
+         field :num_attention_heads, default: 8
+         field :num_hidden_layers, default: 36
+         field :num_key_value_heads, default: 2
+         field :rms_norm_eps, default: 1e-5
+         field :rope_theta, default: 100_000_000_000.0
+         field :vocab_size, default: 32_784
+         field :tie_word_embeddings, default: true
+         field :logits_soft_cap, default: nil
+         field :attention_window_size, default: nil
+         field :block_types, default: nil
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.language_model = RecurrentGemma::Model.new(
+             RecurrentGemma::ModelArgs.from_dict(_to_recurrent_gemma_config(args))
+           )
+
+           if args.tie_word_embeddings
+             language_model.instance_variable_set(:@tie_word_embeddings, true)
+             language_model.lm_head = nil if language_model.respond_to?(:lm_head=)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           language_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           remapped = {}
+           weights.each do |key, value|
+             remapped[_remap_weight_key(key)] = value
+           end
+           language_model.sanitize(remapped)
+         end
+
+         def layers
+           language_model.layers
+         end
+
+         def make_cache
+           return language_model.make_cache if language_model.respond_to?(:make_cache)
+
+           nil
+         end
+
+         private
+
+         def _to_recurrent_gemma_config(args)
+           {
+             "model_type" => args.model_type,
+             "attention_bias" => args.attention_bias,
+             "conv1d_width" => args.mamba_d_conv || 4,
+             "hidden_size" => args.hidden_size,
+             "intermediate_size" => args.intermediate_size,
+             "logits_soft_cap" => args.logits_soft_cap,
+             "num_attention_heads" => args.num_attention_heads,
+             "num_hidden_layers" => args.num_hidden_layers,
+             "num_key_value_heads" => args.num_key_value_heads || args.num_attention_heads,
+             "rms_norm_eps" => args.rms_norm_eps,
+             "rope_theta" => args.rope_theta,
+             "attention_window_size" => args.attention_window_size || [args.max_position_embeddings.to_i, 128].min,
+             "vocab_size" => args.vocab_size,
+             "embeddings_scale_by_sqrt_dim" => false,
+             "block_types" => args.block_types || ["recurrent", "attention"],
+           }
+         end
+
+         def _remap_weight_key(key)
+           mapped = key.dup
+           mapped = mapped.gsub(".mamba.conv1d.", ".temporal_block.conv_1d.")
+           mapped = mapped.gsub(".mamba.out_proj.", ".temporal_block.linear_out.")
+           mapped = mapped.gsub(".mamba.in_proj.", ".temporal_block.linear_x.")
+           mapped = mapped.gsub(".self_attn.", ".temporal_block.")
+           mapped = mapped.gsub(".feed_forward.", ".mlp_block.")
+           mapped = mapped.gsub(".input_layernorm.", ".temporal_pre_norm.")
+           mapped = mapped.gsub(".pre_ff_layernorm.", ".channel_pre_norm.")
+           mapped = mapped.gsub("model.final_layernorm.", "model.final_norm.")
+           mapped
+         end
+       end
+
+       Models.register("falcon_h1", Model, ModelArgs)
+     end
+   end
+ end
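
FalconH1 is a thin adapter: it translates its config into RecurrentGemma::ModelArgs and only renames checkpoint keys in sanitize before delegating everything to the RecurrentGemma implementation. The substring rules in _remap_weight_key imply renames such as the following (the concrete Falcon-H1 key names are assumed for illustration; only the substitution rules come from the code above):

  model.layers.0.mamba.in_proj.weight    -> model.layers.0.temporal_block.linear_x.weight
  model.layers.0.self_attn.o_proj.weight -> model.layers.0.temporal_block.o_proj.weight
  model.layers.0.input_layernorm.weight  -> model.layers.0.temporal_pre_norm.weight
  model.final_layernorm.weight           -> model.final_norm.weight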
data/lib/mlx_lm/models/gated_delta.rb
@@ -0,0 +1,136 @@
+ module MlxLm
+   module Models
+     module GatedDelta
+       module_function
+
+       def compute_g(a_log, a, dt_bias)
+         mx = MLX::Core
+         decay = mx.exp(a_log.astype(mx.float32)) * MLX::NN.softplus(a + dt_bias)
+         mx.exp(
+           mx.multiply(-1.0, decay)
+         ).astype(a.dtype)
+       end
+
+       def gated_delta_kernel(q, k, v, g, beta, state, mask = nil)
+         # TODO: Add a Metal custom-kernel specialization for prefill throughput parity.
+         gated_delta_ops(q, k, v, g, beta, state, mask)
+       end
+
+       def gated_delta_ops(q, k, v, g, beta, state = nil, mask = nil)
+         mx = MLX::Core
+         bsz, steps, hk, dk = q.shape
+         v_shape = v.shape
+         hv = v_shape[-2]
+         dv = v_shape[-1]
+
+         state ||= mx.zeros([bsz, hv, dv, dk], q.dtype)
+
+         repeat_factor = hv / hk
+         if repeat_factor > 1
+           q = mx.repeat(q, repeat_factor, -2)
+           k = mx.repeat(k, repeat_factor, -2)
+         end
+
+         q_steps = mx.split(q, steps, 1).map { |x| mx.squeeze(x, 1) }
+         k_steps = mx.split(k, steps, 1).map { |x| mx.squeeze(x, 1) }
+         v_steps = mx.split(v, steps, 1).map { |x| mx.squeeze(x, 1) }
+         g_steps = mx.split(g, steps, 1).map { |x| mx.squeeze(x, 1) }
+         beta_steps = mx.split(beta, steps, 1).map { |x| mx.squeeze(x, 1) }
+         mask_steps =
+           if mask.nil?
+             nil
+           elsif mask.ndim == 1
+             [mask]
+           else
+             mx.split(mask, steps, 1).map { |x| mx.squeeze(x, 1) }
+           end
+
+         ys = []
+         steps.times do |t|
+           y, state = _gated_delta_step_ops(
+             q_steps[t],
+             k_steps[t],
+             v_steps[t],
+             g_steps[t],
+             beta_steps[t],
+             state,
+             mask_steps&.[](t)
+           )
+           ys << y
+         end
+
+         [mx.stack(ys, 1), state]
+       end
+
+       def gated_delta_update(
+         q,
+         k,
+         v,
+         a,
+         b,
+         a_log,
+         dt_bias,
+         state = nil,
+         mask = nil,
+         use_kernel: true
+       )
+         mx = MLX::Core
+         beta = mx.sigmoid(b)
+         g = compute_g(a_log, a, dt_bias)
+
+         if state.nil?
+           bsz, = q.shape
+           dk = q.shape[-1]
+           hv = v.shape[-2]
+           dv = v.shape[-1]
+           state = mx.zeros([bsz, hv, dv, dk], q.dtype)
+         end
+
+         if use_kernel && metal_kernel_available?
+           gated_delta_kernel(q, k, v, g, beta, state, mask)
+         else
+           gated_delta_ops(q, k, v, g, beta, state, mask)
+         end
+       end
+
+       def _gated_delta_step_ops(q, k, v, g, beta, state, mask = nil)
+         mx = MLX::Core
+         old_state = state
+
+         decay = case g.ndim
+         when 2
+           mx.expand_dims(g, [2, 3])
+         when 3
+           mx.expand_dims(g, 2)
+         else
+           raise ArgumentError, "Unsupported gating shape #{g.shape.inspect}"
+         end
+
+         state = state * decay
+         k_expanded = mx.expand_dims(k, 2)
+         kv_mem = (state * k_expanded).sum(-1)
+         delta = (v - kv_mem) * mx.expand_dims(beta, -1)
+         state = state + k_expanded * mx.expand_dims(delta, -1)
+         y = (state * mx.expand_dims(q, 2)).sum(-1)
+
+         unless mask.nil?
+           mask_shape = [mask.shape[0]] + [1] * (state.ndim - 1)
+           state = mx.where(mask.reshape(mask_shape), state, old_state)
+         end
+
+         [y, state]
+       end
+       private_class_method :_gated_delta_step_ops
+
+       def metal_kernel_available?
+         mx = MLX::Core
+         return false unless mx.respond_to?(:metal_is_available) && mx.metal_is_available
+         return false unless mx.respond_to?(:default_device)
+
+         device = mx.default_device
+         device.respond_to?(:type) && device.type == :gpu
+       end
+       private_class_method :metal_kernel_available?
+     end
+   end
+ end
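
For orientation, a shape-only sketch of driving GatedDelta.gated_delta_update, assuming the MLX::Core Ruby bindings behave as used in this file; the sizes and variable names are illustrative and not part of the gem:

  mx = MLX::Core
  b, t, hk, hv, dk, dv = 1, 4, 2, 4, 16, 32      # toy sizes
  q       = mx.zeros([b, t, hk, dk])             # queries per key-head
  k       = mx.zeros([b, t, hk, dk])             # keys (repeated internally up to hv heads)
  v       = mx.zeros([b, t, hv, dv])             # values per value-head
  a       = mx.zeros([b, t, hv])                 # decay pre-activation (softplus input)
  b_pre   = mx.zeros([b, t, hv])                 # beta pre-activation (sigmoid input)
  a_log   = mx.zeros([hv])                       # per-head log decay scale
  dt_bias = mx.zeros([hv])                       # per-head bias added before softplus

  y, state = MlxLm::Models::GatedDelta.gated_delta_update(
    q, k, v, a, b_pre, a_log, dt_bias, nil, nil, use_kernel: false
  )
  # y -> (b, t, hv, dv) per-step outputs; state -> (b, hv, dv, dk) memory to carry into the next call

Each step decays the (hv, dv, dk) memory by g, writes a beta-weighted delta of v against what the memory already predicts for k, and reads the result out with q; gated_delta_kernel currently just forwards to this reference path (see the TODO in the file).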