mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/baichuan_m1.rb
@@ -0,0 +1,306 @@
+module MlxLm
+  module Models
+    module BaichuanM1
+      class ModelArgs < BaseModelArgs
+        field :vocab_size
+        field :hidden_size
+        field :intermediate_size
+        field :num_hidden_layers
+        field :num_attention_heads
+        field :num_key_value_heads
+        field :rope_theta
+        field :sliding_window
+        field :sliding_window_layers
+        field :conv_window
+        field :rms_norm_eps
+        field :model_type, default: "baichuan_m1"
+        field :num_swa_attention_heads, default: nil
+        field :num_swa_key_value_heads, default: nil
+        field :tie_word_embeddings, default: false
+      end
+
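+      # Baichuan-M1 interleaves sliding-window attention (SWA) layers, listed
+      # in config.sliding_window_layers, with global attention layers; SWA
+      # layers may use their own head counts. Keys and values are additionally
+      # smoothed by a learned two-tap causal convolution along the sequence axis.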
+      class Attention < MLX::NN::Module
+        def initialize(config, layer_idx: nil)
+          super()
+
+          raise ArgumentError, "Layer index must be provided to Attention module." if layer_idx.nil?
+
+          swa_layers = config.sliding_window_layers || []
+          @is_swa = swa_layers.include?(layer_idx)
+
+          @num_heads = if @is_swa && config.num_swa_attention_heads
+            config.num_swa_attention_heads
+          else
+            config.num_attention_heads
+          end
+
+          @num_kv_heads = if @is_swa && config.num_swa_key_value_heads
+            config.num_swa_key_value_heads
+          else
+            config.num_key_value_heads
+          end
+
+          @hidden_size = config.hidden_size
+          @head_dim = @hidden_size / @num_heads
+
+          unless (@head_dim * @num_heads) == @hidden_size
+            raise ArgumentError, "hidden_size must be divisible by num_heads"
+          end
+
+          @scale = @head_dim**(-0.5)
+
+          self.w_pack = MLX::NN::Linear.new(
+            config.hidden_size,
+            @hidden_size + 2 * @num_kv_heads * @head_dim,
+            bias: false
+          )
+          self.o_proj = MLX::NN::Linear.new(
+            @num_heads * @head_dim,
+            config.hidden_size,
+            bias: false
+          )
+
+          self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: config.rope_theta)
+
+          @conv_window = config.conv_window
+          raise ArgumentError, "conv_window must be 2" unless @conv_window == 2
+
+          mx = MLX::Core
+          self.conv_k = mx.zeros([1, 1, @num_kv_heads, 1, @conv_window])
+          self.conv_v = mx.zeros([1, 1, @num_kv_heads, 1, @conv_window])
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, d = x.shape
+
+          proj = w_pack.call(x)
+          q, k, v = mx.split(proj, [d, d + @num_kv_heads * @head_dim], -1)
+
+          q = q.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+          k = k.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          v = v.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          layer_cache = cache || [nil, nil]
+          conv_cache = layer_cache[0]
+          kv_cache = layer_cache[1]
+
+          if conv_cache
+            offset = kv_cache.offset
+            last_k = conv_cache[0]
+            last_v = conv_cache[1]
+          else
+            offset = 0
+            last_k = nil
+            last_v = nil
+          end
+
+          k_init = k
+          v_init = v
+
+          k = _custom_convolution(k, conv_k, state: last_k)
+          v = _custom_convolution(v, conv_v, state: last_v)
+          q = rope.call(q, offset: offset)
+          k = rope.call(k, offset: offset)
+
+          if conv_cache
+            k, v = kv_cache.update_and_fetch(k, v)
+            if l > 0
+              conv_cache[0] = mx.split(k_init, [l - 1], 2)[1]
+              conv_cache[1] = mx.split(v_init, [l - 1], 2)[1]
+            end
+          end
+
+          out = mx.scaled_dot_product_attention(q, k, v, @scale, mask)
+          out = out.transpose([0, 2, 1, 3]).reshape([b, l, @num_heads * @head_dim])
+          o_proj.call(out)
+        end
+
+        private
+
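+        # Two-tap causal convolution along the sequence axis: per head,
+        # out[t] = w0 * u[t - 1] + w1 * u[t]. `state` carries the last timestep
+        # of the previous chunk so incremental decoding matches full-sequence
+        # evaluation.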
+        def _custom_convolution(u, weights, state: nil)
+          mx = MLX::Core
+          b, h, l, d = u.shape
+
+          weights = weights.reshape([1, h, @conv_window, 1, 1])
+          w0 = mx.take(weights, 0, 2)
+          w1 = mx.take(weights, 1, 2)
+
+          state ||= mx.zeros([b, h, 1, d], u.dtype)
+          u_prev = if l > 1
+            mx.concatenate([state, mx.split(u, [l - 1], 2)[0]], 2)
+          else
+            state
+          end
+
+          mx.add(mx.multiply(u_prev, w0), mx.multiply(u, w1))
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(config)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(
+            config.hidden_size,
+            config.intermediate_size,
+            bias: false
+          )
+          self.up_proj = MLX::NN::Linear.new(
+            config.hidden_size,
+            config.intermediate_size,
+            bias: false
+          )
+          self.down_proj = MLX::NN::Linear.new(
+            config.intermediate_size,
+            config.hidden_size,
+            bias: false
+          )
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
+      class DecoderLayer < MLX::NN::Module
+        def initialize(config, layer_idx)
+          super()
+          self.self_attn = Attention.new(config, layer_idx: layer_idx)
+          self.mlp = MLP.new(config)
+          self.input_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
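+      # The decoder builds two masks per forward pass: a plain causal mask for
+      # global layers and a windowed causal mask for SWA layers, each derived
+      # from the cache of the first layer of that kind.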
+      class BaichuanModel < MLX::NN::Module
+        def initialize(config)
+          super()
+          @config = config
+          @sliding_window = config.sliding_window
+          @swa_layers = config.sliding_window_layers || []
+
+          self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size)
+          self.layers = Array.new(config.num_hidden_layers) { |i| DecoderLayer.new(config, i) }
+          self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+
+          self.first_swa_idx = @swa_layers.empty? ? nil : @swa_layers[0]
+          self.first_global_idx = nil
+          config.num_hidden_layers.times do |i|
+            next if @swa_layers.include?(i)
+
+            self.first_global_idx = i
+            break
+          end
+        end
+
+        def call(inputs, cache: nil)
+          x = embed_tokens.call(inputs)
+          layer_cache = cache || Array.new(layers.length) { [nil, nil] }
+
+          c_global = first_global_idx.nil? ? nil : layer_cache[first_global_idx][1]
+          c_swa = first_swa_idx.nil? ? nil : layer_cache[first_swa_idx][1]
+
+          global_mask = _create_attention_mask(x, c_global)
+          swa_mask = _create_attention_mask(x, c_swa, window_size: @sliding_window)
+
+          layers.each_with_index do |layer, i|
+            mask = @swa_layers.include?(i) ? swa_mask : global_mask
+            x = layer.call(x, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(x)
+        end
+
+        private
+
+        def _create_attention_mask(x, cache = nil, window_size: nil)
+          n = x.shape[1]
+          return cache.make_mask(n, window_size: window_size) if cache && cache.respond_to?(:make_mask)
+          return nil if n == 1
+          return _create_causal_mask(n, window_size: window_size) if window_size && n > window_size
+
+          "causal"
+        end
+
+        def _create_causal_mask(n, offset: 0, window_size: nil)
+          mx = MLX::Core
+          rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+          linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+          mask = mx.greater_equal(linds, rinds)
+          if window_size
+            mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+          end
+          mask
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(config)
+          super()
+          @config = config
+          self.model_type = config.model_type
+          self.model = BaichuanModel.new(config)
+          @tie_word_embeddings = config.tie_word_embeddings
+          unless @tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
+          end
+        end
+
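+        # One CacheList per layer, pairing an ArraysCache (the convolution
+        # states for K and V) with a KV cache: a RotatingKVCache bounded by the
+        # sliding window for SWA layers, an unbounded KVCache otherwise.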
+        def make_cache
+          caches = []
+          swa_layers = @config.sliding_window_layers || []
+          @config.num_hidden_layers.times do |i|
+            is_swa = swa_layers.include?(i)
+            conv_cache = MlxLm::ArraysCache.new(2)
+            kv_cache = if is_swa
+              MlxLm::RotatingKVCache.new(max_size: @config.sliding_window)
+            else
+              MlxLm::KVCache.new
+            end
+            caches << MlxLm::CacheList.new(conv_cache, kv_cache)
+          end
+          caches
+        end
+
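+        # Unless the checkpoint is already quantized, L2-normalize each row of
+        # lm_head.weight in float32 (with a 1e-7 epsilon), matching the
+        # norm-head output layer used by Baichuan models.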
+        def sanitize(weights)
+          mx = MLX::Core
+          is_quantized = weights.key?("lm_head.scales")
+
+          if !is_quantized && weights.key?("lm_head.weight")
+            w = weights["lm_head.weight"]
+            dtype = w.dtype
+            w = w.astype(mx.float32)
+            norm = mx.norm(w, nil, -1, true)
+            w = (w / (norm + 1e-7)).astype(dtype)
+            weights["lm_head.weight"] = w
+          end
+
+          weights
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if @tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("baichuan_m1", Model, ModelArgs)
+    end
+  end
+end
data/lib/mlx_lm/models/bailing_moe.rb
@@ -0,0 +1,399 @@
+require_relative "activations"
+require_relative "rope_utils"
+require_relative "switch_layers"
+
+module MlxLm
+  module Models
+    module BailingMoe
+      class ModelArgs < BaseModelArgs
+        field :model_type
+        field :hidden_size
+        field :intermediate_size
+        field :max_position_embeddings
+        field :moe_intermediate_size
+        field :num_experts
+        field :num_shared_experts
+        field :norm_topk_prob
+        field :num_attention_heads
+        field :num_experts_per_tok
+        field :num_hidden_layers
+        field :num_key_value_heads
+        field :rms_norm_eps
+        field :rope_theta
+        field :vocab_size
+        field :first_k_dense_replace
+        field :rope_scaling, default: nil
+        field :use_bias, default: false
+        field :use_qkv_bias, default: false
+        field :norm_head, default: false
+        field :norm_softmax, default: false
+        field :use_qk_norm, default: false
+        field :tie_word_embeddings, default: false
+        field :partial_rotary_factor, default: 1.0
+        field :rotary_dim, default: nil
+        field :moe_router_enable_expert_bias, default: false
+        field :moe_router_enable_routed_scaling, default: true
+        field :routed_scaling_factor, default: 1.0
+        field :score_function, default: "softmax"
+        field :n_group, default: 1
+        field :topk_group, default: 4
+        field :moe_shared_expert_intermediate_size, default: nil
+        field :moe_router_enable_shared_expert, default: true
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+        end
+      end
+
+      module_function
+
+      def aggregate_expert_outputs(expert_outputs, scores)
+        mx = MLX::Core
+        mx.sum(expert_outputs * mx.expand_dims(scores, -1), -2).astype(expert_outputs.dtype)
+      end
+
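+      # Grouped top-k routing in the style of DeepSeek-V3: score every expert
+      # (softmax or sigmoid), optionally add a learned correction bias, keep
+      # the topk_group expert groups ranked by the sum of their two best
+      # scores, then pick the per-token top_k experts, optionally normalize
+      # their weights, and scale by routed_scaling_factor.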
+      def group_expert_select(
+        gates,
+        e_score_correction_bias,
+        top_k,
+        n_group,
+        topk_group,
+        routed_scaling_factor,
+        norm_topk_prob,
+        score_function
+      )
+        mx = MLX::Core
+        in_type = gates.dtype
+
+        scores = if score_function == "sigmoid"
+          mx.sigmoid(gates.astype(mx.float32))
+        else
+          mx.softmax(gates.astype(mx.float32), -1)
+        end
+        orig_scores = scores
+        scores = scores + e_score_correction_bias if e_score_correction_bias
+
+        if n_group.to_i > 1
+          experts_per_group = scores.shape[-1] / n_group
+          scores = mx.unflatten(scores, -1, [n_group, experts_per_group])
+          group_scores = mx.topk(scores, 2, -1)
+          group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1)
+
+          drop_count = n_group - topk_group.to_i
+          if drop_count > 0
+            group_idx = mx.argpartition(group_scores, drop_count - 1, -2)
+            take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32)
+            group_idx = mx.take(group_idx, take_ids, -2)
+            scores = mx.put_along_axis(
+              scores,
+              mx.stop_gradient(group_idx),
+              mx.array(0.0),
+              -2
+            )
+          end
+
+          scores = mx.flatten(scores, -2, -1)
+        end
+
+        k = [top_k.to_i, scores.shape[-1]].min
+        inds = mx.argpartition(scores * -1.0, k - 1, -1)
+        take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+        inds = mx.take(inds, take_ids, -1)
+
+        selected_scores = mx.take_along_axis(orig_scores, inds, -1)
+        if k > 1 && norm_topk_prob
+          denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1) + 1e-20
+          selected_scores = selected_scores / denominator
+        end
+
+        selected_scores = selected_scores * routed_scaling_factor.to_f
+        [inds, selected_scores.astype(in_type)]
+      end
+
+      class BailingMoeMLP < MLX::NN::Module
+        def initialize(args, intermediate_size: nil)
+          super()
+          hidden_dim = intermediate_size || args.intermediate_size
+
+          self.gate_proj = MLX::NN::Linear.new(args.hidden_size, hidden_dim, bias: args.use_bias)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, args.hidden_size, bias: args.use_bias)
+          self.up_proj = MLX::NN::Linear.new(args.hidden_size, hidden_dim, bias: args.use_bias)
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
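+      # Attention with a fused query_key_value projection, optional per-head
+      # RMSNorm on queries and keys, and partial rotary embeddings: only the
+      # first rotary_dim (or head_dim * partial_rotary_factor) dimensions of
+      # each head are rotated.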
+      class BailingMoeAttention < MLX::NN::Module
+        def initialize(args)
+          super()
+          @use_qk_norm = args.use_qk_norm
+          @num_attention_heads = args.num_attention_heads
+          @num_key_value_heads = args.num_key_value_heads
+          @head_dim = args.hidden_size / @num_attention_heads
+          @scale = @head_dim**(-0.5)
+
+          self.query_key_value = MLX::NN::Linear.new(
+            args.hidden_size,
+            (@num_attention_heads + 2 * @num_key_value_heads) * @head_dim,
+            bias: args.use_qkv_bias
+          )
+          self.dense = MLX::NN::Linear.new(
+            @num_attention_heads * @head_dim,
+            args.hidden_size,
+            bias: args.use_bias
+          )
+
+          if @use_qk_norm
+            self.key_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+            self.query_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+          end
+
+          rope_dim = args.rotary_dim || (@head_dim * args.partial_rotary_factor.to_f).to_i
+          rope_dim = [rope_dim, 1].max
+          self.rope = MlxLm::Models.initialize_rope(
+            rope_dim,
+            args.rope_theta,
+            false,
+            args.rope_scaling,
+            max_position_embeddings: args.max_position_embeddings
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          qkv = query_key_value.call(x)
+
+          q_size = @num_attention_heads * @head_dim
+          kv_size = @num_key_value_heads * @head_dim
+          q, k, v = mx.split(qkv, [q_size, q_size + kv_size], -1)
+
+          queries = q.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if @use_qk_norm
+            queries = query_layernorm.call(queries)
+            keys = key_layernorm.call(keys)
+          end
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim])
+          dense.call(output)
+        end
+      end
+
+      class BailingMoeGate < MLX::NN::Module
+        def initialize(args)
+          super()
+          @norm_topk_prob = args.norm_topk_prob
+          @top_k = args.num_experts_per_tok
+          @n_group = args.n_group
+          @topk_group = args.topk_group
+          @routed_scaling_factor = args.routed_scaling_factor
+          @score_function = args.score_function
+
+          self.gate_proj = MLX::NN::Linear.new(args.hidden_size, args.num_experts, bias: false)
+          self.expert_bias = if args.moe_router_enable_expert_bias
+            MLX::Core.zeros([args.num_experts])
+          else
+            nil
+          end
+        end
+
+        def call(x)
+          BailingMoe.group_expert_select(
+            gate_proj.call(x),
+            expert_bias,
+            @top_k,
+            @n_group,
+            @topk_group,
+            @routed_scaling_factor,
+            @norm_topk_prob,
+            @score_function
+          )
+        end
+      end
+
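+      # Sparse MoE block: the gate selects per-token experts, SwitchGLU runs
+      # all selected experts in a single batched call, the weighted outputs are
+      # summed, and an optional always-on shared-expert MLP is added on top.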
+      class BailingMoeSparseMoeBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.switch_mlp = SwitchLayers::SwitchGLU.new(
+            args.hidden_size,
+            args.moe_intermediate_size,
+            args.num_experts,
+            bias: args.use_bias
+          )
+          self.gate = BailingMoeGate.new(args)
+
+          shared_dim = args.moe_shared_expert_intermediate_size || args.moe_intermediate_size
+          self.shared_experts = if args.num_shared_experts.to_i > 0 && args.moe_router_enable_shared_expert
+            BailingMoeMLP.new(
+              args,
+              intermediate_size: shared_dim * args.num_shared_experts
+            )
+          end
+        end
+
+        def call(x)
+          topk_idx, topk_weight = gate.call(x)
+          out = switch_mlp.call(x, topk_idx)
+          out = BailingMoe.aggregate_expert_outputs(out, topk_weight)
+          out = out + shared_experts.call(x) if respond_to?(:shared_experts)
+          out
+        end
+      end
+
+      class BailingMoeDecoderLayer < MLX::NN::Module
+        def initialize(args, layer_idx:)
+          super()
+          self.attention = BailingMoeAttention.new(args)
+          self.mlp = if !args.num_experts.nil? && layer_idx >= args.first_k_dense_replace
+            BailingMoeSparseMoeBlock.new(args)
+          else
+            BailingMoeMLP.new(args)
+          end
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = attention.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
+      class BailingMoeModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.word_embeddings = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) do |layer_idx|
+            BailingMoeDecoderLayer.new(args, layer_idx: layer_idx)
+          end
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          h = word_embeddings.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+          mask = _create_attention_mask(h, layer_cache[0])
+
+          layers.each_with_index do |layer, layer_idx|
+            h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+          end
+          norm.call(h)
+        end
+
+        private
+
+        def _create_attention_mask(hidden, cache)
+          return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask)
+          return nil if hidden.shape[1] == 1
+
+          "causal"
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          @norm_head = args.norm_head
+          self.model_type = args.model_type
+          self.model = BailingMoeModel.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            model.word_embeddings.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
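+        # Adapts checkpoint keys to this implementation: drops lm_head when the
+        # embeddings are tied, optionally column-normalizes lm_head.weight
+        # (norm_head), stacks the per-expert gate/down/up projections into the
+        # batched switch_mlp tensors, and renames the router gate parameters.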
+        def sanitize(weights)
+          mx = MLX::Core
+          result = weights.dup
+
+          result.delete("lm_head.weight") if @args.tie_word_embeddings
+
+          if @norm_head && result.key?("lm_head.weight")
+            w = result["lm_head.weight"]
+            dtype = w.dtype
+            w_fp32 = w.astype(mx.float32)
+            weight_norm = mx.sqrt(mx.sum(mx.square(w_fp32), 0, true)) + 1e-7
+            result["lm_head.weight"] = (w_fp32 / weight_norm).astype(dtype)
+          end
+
+          @args.num_hidden_layers.times do |layer_idx|
+            next if layer_idx < @args.first_k_dense_replace.to_i
+
+            prefix = "model.layers.#{layer_idx}"
+            %w[gate_proj down_proj up_proj].each do |projection|
+              %w[weight scales biases].each do |param|
+                first_key = "#{prefix}.mlp.experts.0.#{projection}.#{param}"
+                next unless result.key?(first_key)
+
+                expert_keys = (0...@args.num_experts).map do |expert_idx|
+                  "#{prefix}.mlp.experts.#{expert_idx}.#{projection}.#{param}"
+                end
+                next unless expert_keys.all? { |key| result.key?(key) }
+
+                stacked = expert_keys.map { |key| result.delete(key) }
+                result["#{prefix}.mlp.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked)
+              end
+            end
+
+            if result.key?("#{prefix}.mlp.gate.weight")
+              result["#{prefix}.mlp.gate.gate_proj.weight"] = result.delete("#{prefix}.mlp.gate.weight")
+            end
+            if result.key?("#{prefix}.mlp.gate.bias")
+              result["#{prefix}.mlp.gate.gate_proj.bias"] = result.delete("#{prefix}.mlp.gate.bias")
+            end
+          end
+
+          result
+        end
+
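+        # Pin the MoE router gate to 8-bit, group-size-64 quantization; all
+        # other linear layers use the caller's requested settings.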
+        def quant_predicate
+          lambda do |path, _|
+            if path.to_s.end_with?("mlp.gate.gate_proj")
+              { group_size: 64, bits: 8 }
+            else
+              true
+            end
+          end
+        end
+
+        def cast_predicate
+          lambda { |key| !key.to_s.include?("expert_bias") }
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("bailing_moe", Model, ModelArgs)
+    end
+  end
+end