mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,479 @@
1
+ require_relative "activations"
2
+ require_relative "cache"
3
+ require_relative "rope_utils"
4
+ require_relative "switch_layers"
5
+
6
+ module MlxLm
7
+ module Models
8
+ module Step3p5
9
# Clamped SwiGLU: the silu-activated gate is capped from above at +limit
# and the linear branch is clipped to [-limit, limit] before the
# elementwise product. Used by Step3p5MLP when a positive limit is set.
def self.clamped_swiglu(x, gate, limit)
  mx = MLX::Core
  capped_gate = mx.minimum(MLX::NN.silu(gate), limit)
  bounded_x = mx.clip(x, -limit, limit)
  capped_gate * bounded_x
end
15
+
16
# Configuration for the Step3p5 architecture, mirroring the upstream
# HuggingFace config keys.
class ModelArgs < BaseModelArgs
  field :model_type, default: "step3p5"
  field :hidden_size
  field :num_hidden_layers
  field :vocab_size
  field :num_attention_heads
  field :num_attention_groups # KV-head count (see Step3p5Attention)
  field :head_dim
  field :intermediate_size
  field :rms_norm_eps, default: 1e-5
  field :rope_theta, default: 10_000.0 # scalar, or per-layer array
  field :rope_scaling, default: nil
  field :max_position_embeddings, default: 262_144
  field :sliding_window, default: 512
  field :layer_types, default: nil # per-layer "sliding_attention"/"full_attention"
  field :yarn_only_types, default: nil # layer types that receive rope_scaling
  field :partial_rotary_factors, default: nil
  field :attention_other_setting, default: nil # head-count overrides for sliding layers
  field :use_head_wise_attn_gate, default: true
  field :moe_num_experts, default: 288
  field :moe_top_k, default: 8
  field :moe_intermediate_size, default: 1280
  field :share_expert_dim, default: 1280
  field :moe_layers_enum, default: nil # comma-separated list of MoE layer indices
  field :moe_router_scaling_factor, default: 3.0
  field :norm_expert_weight, default: true
  field :swiglu_limits, default: nil
  field :swiglu_limits_shared, default: nil
  field :tie_word_embeddings, default: false
end
46
+
47
# RMS normalization whose stored gain is 1-centered (initialised to ones).
# Checkpoints that ship zero-centered gains are shifted by +1 during
# Model#sanitize so they match this parameterization.
class ZeroCenteredRMSNorm < MLX::NN::Module
  def initialize(dims, eps: 1e-5)
    super()
    self.weight = MLX::Core.ones([dims])
    @eps = eps
  end

  # Normalizes over the last axis: x * rsqrt(mean(x^2) + eps) * weight.
  def call(x)
    mx = MLX::Core
    ms = mx.mean(x * x, -1, keepdims: true)
    x * mx.rsqrt(ms + @eps) * weight
  end
end
60
+
61
# Dense feed-forward block. Applies plain SwiGLU unless a strictly
# positive swiglu_limit is configured, in which case the clamped variant
# (Step3p5.clamped_swiglu) is used instead.
class Step3p5MLP < MLX::NN::Module
  def initialize(args, intermediate_size:, swiglu_limit: 0)
    super()
    @hidden_size = args.hidden_size
    @intermediate_size = intermediate_size

    self.gate_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false)
    self.up_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false)
    self.down_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: false)

    # nil disables clamping; nil/0/negative limits all mean "no clamp".
    @limit = swiglu_limit && swiglu_limit > 0 ? swiglu_limit : nil
  end

  def call(x)
    activated =
      if @limit
        Step3p5.clamped_swiglu(up_proj.call(x), gate_proj.call(x), @limit)
      else
        Activations.swiglu(gate_proj.call(x), up_proj.call(x))
      end
    down_proj.call(activated)
  end
end
84
+
85
# Sigmoid router with an additive correction bias. Experts are *selected*
# by the bias-corrected scores, but the mixing weights come from the raw
# sigmoid scores — the bias only steers which experts are picked.
class Step3p5MoEGate < MLX::NN::Module
  def initialize(args)
    super()
    @top_k = args.moe_top_k
    @n_routed_experts = args.moe_num_experts
    @routed_scaling_factor = args.moe_router_scaling_factor
    @norm_topk_prob = args.norm_expert_weight

    self.gate = MLX::NN::Linear.new(args.hidden_size, @n_routed_experts, bias: false)
    self.router_bias = MLX::Core.zeros([@n_routed_experts])
  end

  # Returns [topk_indices, topk_weights] for the routed experts.
  def call(x)
    _moe_gate_select(gate.call(x))
  end

  private

  def _moe_gate_select(gates)
    mx = MLX::Core
    scores = mx.sigmoid(gates.astype(mx.float32))
    corrected = scores + router_bias

    # Clamp k into [1, n_routed_experts] before partitioning.
    k = [[@top_k.to_i, 1].max, @n_routed_experts].min
    # argpartition on negated scores places the k largest entries first.
    indices = mx.argpartition(corrected * -1.0, k - 1, -1)
    head = mx.array((0...k).to_a, dtype: mx.int32)
    indices = mx.take(indices, head, -1)
    weights = mx.take_along_axis(scores, indices, -1)

    if @norm_topk_prob
      # Normalize selected weights to sum to 1 (epsilon guards zero sums).
      weights = weights / (mx.expand_dims(mx.sum(weights, -1), -1) + 1e-20)
    end

    [indices, weights * @routed_scaling_factor]
  end
end
121
+
122
# Mixture-of-experts block: top-k routed experts (dispatched through
# SwitchGLU) plus an always-on shared expert, summed together.
class Step3p5MoE < MLX::NN::Module
  def initialize(args, layer_idx)
    super()
    routed_limit = _limit_at(args.swiglu_limits, layer_idx)
    shared_limit = _limit_at(args.swiglu_limits_shared, layer_idx)

    self.gate = Step3p5MoEGate.new(args)
    self.switch_mlp = SwitchLayers::SwitchGLU.new(
      args.hidden_size,
      args.moe_intermediate_size,
      args.moe_num_experts
    )
    self.share_expert = Step3p5MLP.new(
      args,
      intermediate_size: args.share_expert_dim,
      swiglu_limit: shared_limit
    )

    # NOTE(review): the routed-expert limit is stored but never applied —
    # SwitchGLU uses plain SwiGLU. Confirm whether routed experts should
    # also clamp when swiglu_limits is configured.
    @swiglu_limit = routed_limit
  end

  def call(x)
    mx = MLX::Core
    indices, weights = gate.call(x)

    routed = switch_mlp.call(x, indices)
    # Weighted sum over the k selected experts, cast back to input dtype.
    routed = mx.sum(routed * mx.expand_dims(weights, -1), -2).astype(routed.dtype)
    routed + share_expert.call(x)
  end

  private

  # Per-layer limit lookup; 0 (disabled) when the list is nil/short/holey.
  def _limit_at(values, idx)
    arr = Array(values)
    return 0 unless idx < arr.length

    arr[idx] || 0
  end
end
161
+
162
# Multi-head attention with per-head QK RMS-norm, optional head-wise
# sigmoid output gating, and per-layer RoPE configuration (theta,
# partial rotary factor, and scaling restricted to listed layer types).
class Step3p5Attention < MLX::NN::Module
  attr_reader :is_sliding

  def initialize(args, layer_idx)
    super()
    dim = args.hidden_size
    layer_types = Array(args.layer_types)

    # With no explicit layer_types list, even-indexed layers slide and
    # odd-indexed layers attend over the full context.
    @is_sliding =
      if layer_types.empty?
        layer_idx.even?
      else
        layer_types[layer_idx] == "sliding_attention"
      end

    # Sliding layers may override head counts via attention_other_setting.
    if @is_sliding && args.attention_other_setting
      overrides = args.attention_other_setting
      @num_heads = _cfg_value(overrides, "num_attention_heads", args.num_attention_heads)
      @num_kv_heads = _cfg_value(overrides, "num_attention_groups", args.num_attention_groups)
    else
      @num_heads = args.num_attention_heads
      @num_kv_heads = args.num_attention_groups
    end

    @head_dim = args.head_dim
    @scale = @head_dim**(-0.5)

    self.q_proj = MLX::NN::Linear.new(dim, @num_heads * @head_dim, bias: false)
    self.k_proj = MLX::NN::Linear.new(dim, @num_kv_heads * @head_dim, bias: false)
    self.v_proj = MLX::NN::Linear.new(dim, @num_kv_heads * @head_dim, bias: false)
    self.o_proj = MLX::NN::Linear.new(@num_heads * @head_dim, dim, bias: false)

    self.q_norm = ZeroCenteredRMSNorm.new(@head_dim, eps: args.rms_norm_eps)
    self.k_norm = ZeroCenteredRMSNorm.new(@head_dim, eps: args.rms_norm_eps)

    @use_head_wise_attn_gate = args.use_head_wise_attn_gate
    self.g_proj = MLX::NN::Linear.new(dim, @num_heads, bias: false) if @use_head_wise_attn_gate

    # rope_theta may be a per-layer array; fall back to the first entry
    # when this layer has no value of its own.
    theta = args.rope_theta
    theta = theta[layer_idx] || theta[0] if theta.is_a?(Array)

    rotary_factor = _partial_rotary_factor(args.partial_rotary_factors, layer_idx)
    rope_dims = (@head_dim * rotary_factor).to_i
    rope_dims = 1 if rope_dims < 1

    # When yarn_only_types is given, rope_scaling applies only to layers
    # whose type appears in that list.
    yarn_types = Array(args.yarn_only_types)
    layer_type = layer_types.empty? ? "full_attention" : layer_types[layer_idx]
    scaling =
      if !yarn_types.empty? && !yarn_types.include?(layer_type)
        nil
      else
        args.rope_scaling
      end

    self.rope = MlxLm::Models.initialize_rope(
      rope_dims,
      theta,
      false,
      scaling,
      max_position_embeddings: args.max_position_embeddings
    )
  end

  def call(x, mask: nil, cache: nil)
    mx = MLX::Core
    b, l, = x.shape

    q = q_proj.call(x)
    k = k_proj.call(x)
    v = v_proj.call(x)

    # [B, L, H*D] -> [B, H, L, D]; QK norm is applied per head.
    q = q_norm.call(q.reshape([b, l, @num_heads, @head_dim])).transpose([0, 2, 1, 3])
    k = k_norm.call(k.reshape([b, l, @num_kv_heads, @head_dim])).transpose([0, 2, 1, 3])
    v = v.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3])

    if cache
      q = rope.call(q, offset: cache.offset)
      k = rope.call(k, offset: cache.offset)
      k, v = cache.update_and_fetch(k, v)
    else
      q = rope.call(q)
      k = rope.call(k)
    end

    out = mx.scaled_dot_product_attention(q, k, v, @scale, mask)
    out = out.transpose([0, 2, 1, 3])

    # Per-head sigmoid gate over the attention output.
    out = out * mx.expand_dims(mx.sigmoid(g_proj.call(x)), -1) if @use_head_wise_attn_gate

    o_proj.call(out.reshape([b, l, @num_heads * @head_dim]))
  end

  private

  # Per-layer partial rotary factor; defaults to 1.0 (fully rotary).
  def _partial_rotary_factor(factors, idx)
    arr = Array(factors)
    return 1.0 unless idx < arr.length

    arr[idx] || 1.0
  end

  # Fetch a config value by string key, then symbol key, then default.
  def _cfg_value(hash, key, default = nil)
    return hash[key] if hash.key?(key)

    hash.fetch(key.to_sym, default)
  end
end
271
+
272
# Pre-norm transformer decoder layer. The feed-forward sublayer is either
# a dense Step3p5MLP or a Step3p5MoE, chosen per layer index from
# moe_layers_enum (default: every layer except layer 0 is MoE).
class Step3p5DecoderLayer < MLX::NN::Module
  attr_reader :is_sliding

  def initialize(args, layer_idx)
    super()
    self.self_attn = Step3p5Attention.new(args, layer_idx)
    @is_sliding = self_attn.is_sliding

    if _moe_layer?(args, layer_idx)
      self.mlp = Step3p5MoE.new(args, layer_idx)
    else
      self.mlp = Step3p5MLP.new(
        args,
        intermediate_size: args.intermediate_size,
        swiglu_limit: _limit_at(args.swiglu_limits_shared, layer_idx)
      )
    end

    self.input_layernorm = ZeroCenteredRMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
    self.post_attention_layernorm = ZeroCenteredRMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
  end

  # Residual pre-norm: h = x + attn(norm(x)); out = h + mlp(norm(h)).
  def call(x, mask: nil, cache: nil)
    h = x + self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
    h + mlp.call(post_attention_layernorm.call(h))
  end

  private

  # True when this layer index should use the MoE block. Replaces the
  # previous implementation, which rebuilt a full index->bool hash for
  # every layer (O(num_layers) work per layer, O(L^2) per model); only a
  # membership query is needed.
  def _moe_layer?(args, layer_idx)
    if args.moe_layers_enum
      # moe_layers_enum is a comma-separated list of MoE layer indices.
      args.moe_layers_enum.split(",").any? do |token|
        stripped = token.strip
        !stripped.empty? && stripped.to_i == layer_idx
      end
    else
      # Default: all layers except the first are MoE.
      layer_idx >= 1 && layer_idx < args.num_hidden_layers
    end
  end

  # Per-layer limit lookup; 0 (disabled) when the list is nil/short/holey.
  def _limit_at(values, idx)
    arr = Array(values)
    return 0 unless idx < arr.length

    arr[idx] || 0
  end
end
328
+
329
# Transformer backbone: token embeddings, decoder layers, and final norm.
# At most two attention masks are built per forward pass — one for full
# attention and one for sliding-window layers — and shared across all
# layers of the matching kind.
class Step3p5Model < MLX::NN::Module
  def initialize(args)
    super()
    @args = args
    @num_layers = args.num_hidden_layers

    self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
    self.layers = Array.new(args.num_hidden_layers) { |i| Step3p5DecoderLayer.new(args, i) }
    self.norm = ZeroCenteredRMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)

    # Representative layer indices whose caches drive each shared mask.
    @swa_idx = layers.index(&:is_sliding)
    @full_idx = layers.index { |layer| !layer.is_sliding }
  end

  def call(inputs, cache: nil)
    h = embed_tokens.call(inputs)
    caches = cache || [nil] * @num_layers

    full_mask = @full_idx.nil? ? nil : _create_attention_mask(h, caches[@full_idx])
    swa_mask =
      unless @swa_idx.nil?
        _create_attention_mask(h, caches[@swa_idx], window_size: @args.sliding_window)
      end

    layers.each_with_index do |layer, i|
      h = layer.call(h, mask: layer.is_sliding ? swa_mask : full_mask, cache: caches[i])
    end

    norm.call(h)
  end

  private

  # Build (or delegate to the cache) a mask for n new tokens. Returns nil
  # (no mask needed for a single step), the string "causal", or an
  # explicit boolean mask when a sliding window must truncate history.
  def _create_attention_mask(h, cache = nil, window_size: nil)
    n = h.shape[1]
    return cache.make_mask(n, window_size: window_size) if cache && cache.respond_to?(:make_mask)

    if window_size
      offset = 0
      if cache
        offset = cache.offset
        # Rotating caches cap visible history at @max_size entries.
        if cache.instance_variable_defined?(:@max_size)
          max_size = cache.instance_variable_get(:@max_size)
          offset = [max_size - 1, offset].min if max_size && max_size > 0
        end
      end
      return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
    end
    return nil if n == 1

    "causal"
  end

  # Boolean mask of shape [n, offset + n]: query i may attend to key j
  # iff j <= i, and with window_size also i < j + window_size.
  def _create_causal_mask(n, offset: 0, window_size: nil)
    mx = MLX::Core
    key_pos = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
    query_pos = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])

    mask = mx.greater_equal(query_pos, key_pos)
    if window_size
      mask = mx.logical_and(mask, mx.less(query_pos, mx.add(key_pos, window_size)))
    end
    mask
  end
end
398
+
399
# Top-level causal LM: Step3p5 backbone plus an (untied) lm_head, with
# checkpoint-key sanitization and quantization/cast policies.
class Model < MLX::NN::Module
  attr_reader :args

  def initialize(args)
    super()
    @args = args
    self.model_type = args.model_type
    self.model = Step3p5Model.new(args)
    self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
  end

  def call(inputs, cache: nil)
    lm_head.call(model.call(inputs, cache: cache))
  end

  def layers
    model.layers
  end

  def make_cache
    Array.new(layers.length) { MlxLm::KVCache.new }
  end

  # Remap upstream ("vanilla") checkpoint keys onto this module tree,
  # drop MTP heads and layers beyond the configured depth, and shift
  # zero-centered norm gains by +1 to match ZeroCenteredRMSNorm's
  # ones-initialised weight.
  def sanitize(weights)
    # Order matters: ".moe.gate_proj." must be tested before ".moe.gate.".
    remappings = [
      [".moe.gate_proj.", ".mlp.switch_mlp.gate_proj."],
      [".moe.up_proj.", ".mlp.switch_mlp.up_proj."],
      [".moe.down_proj.", ".mlp.switch_mlp.down_proj."],
      [".moe.gate.", ".mlp.gate.gate."],
      [".moe.router_bias", ".mlp.gate.router_bias"],
      [".share_expert.", ".mlp.share_expert."],
    ]

    # A checkpoint is "vanilla" (not yet converted) when any key still
    # uses the source naming scheme.
    is_vanilla = weights.any? do |key, _|
      remappings.any? { |src, dst| key.include?(src) && !key.include?(dst) }
    end

    weights.each_with_object({}) do |(key, value), out|
      next if key.include?(".mtp") # multi-token-prediction weights: dropped

      layer_match = key.match(/model\.layers\.(\d+)\./)
      next if layer_match && layer_match[1].to_i >= args.num_hidden_layers

      new_key = key
      remappings.each do |src, dst|
        next unless new_key.include?(src) && !new_key.include?(dst)

        new_key = new_key.gsub(src, dst)
        break # apply at most one remapping per key
      end

      new_value = value
      # Vanilla checkpoints store zero-centered gains; recenter at 1.
      if is_vanilla && new_key.end_with?(".weight") && new_key.include?("norm")
        new_value = new_value + 1
      end

      out[new_key] = new_value
    end
  end

  # Keep the router bias out of dtype casts (stays full precision).
  def cast_predicate
    ->(key) { !key.include?("router_bias") }
  end

  # Quantize the MoE router gate conservatively (8-bit, group size 64);
  # everything else follows the default policy.
  def quant_predicate
    lambda do |path, _|
      return {group_size: 64, bits: 8} if path.include?("mlp.gate.gate")

      true
    end
  end
end
475
+
476
+ Models.register("step3p5", Model, ModelArgs)
477
+ end
478
+ end
479
+ end
@@ -0,0 +1,221 @@
1
+ module MlxLm
2
+ module Models
3
+ module SwitchLayers
4
# Reorder routed rows so that rows assigned to the same expert become
# contiguous. `indices` holds K expert ids per token; each token
# therefore occupies K slots in the sorted order.
# Returns [sorted_x, sorted_expert_ids, inv_order], where inv_order
# undoes the permutation (see scatter_unsort).
def self.gather_sort(x, indices)
  mx = MLX::Core
  experts_per_token = indices.shape[-1]
  flat = mx.flatten(indices)
  order = mx.argsort(flat)
  inv_order = mx.argsort(order)
  # Map each sorted slot back to its source token.
  token_ids = mx.floor_divide(order, experts_per_token)
  sorted_x = mx.take(mx.flatten(x, 0, -3), token_ids, 0)
  [sorted_x, mx.take(flat, order), inv_order]
end
17
+
18
# Inverse of gather_sort: restore original row order after the sorted
# expert computation, optionally unflattening the leading axis to `shape`.
def self.scatter_unsort(x, inv_order, shape = nil)
  mx = MLX::Core
  restored = mx.take(x, inv_order, 0)
  shape ? mx.unflatten(restored, 0, shape) : restored
end
25
+
26
# Batched expert linear layer: all expert weights are stacked into one
# [num_experts, output_dims, input_dims] tensor and rows are dispatched
# to their experts via mx.gather_mm.
class SwitchLinear < MLX::NN::Module
  def initialize(input_dims, output_dims, num_experts, bias: false)
    super()
    mx = MLX::Core
    # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) initialisation.
    bound = Math.sqrt(1.0 / input_dims)
    self.weight = mx.random_uniform(
      [num_experts, output_dims, input_dims],
      bound * -1.0, bound, mx.float32
    )
    self.bias = mx.zeros([num_experts, output_dims]) if bias
  end

  # indices selects one expert per row; pass sorted_indices: true when
  # rows are already grouped by expert (see SwitchLayers.gather_sort).
  def call(x, indices, sorted_indices: false)
    mx = MLX::Core
    out = mx.gather_mm(
      x,
      mx.swapaxes(weight, -1, -2),
      nil,
      indices,
      sorted_indices
    )
    out = out + mx.expand_dims(mx.take(bias, indices, 0), -2) if respond_to?(:bias)
    out
  end

  # Convert this layer to its gather_qmm-based quantized counterpart.
  def to_quantized(group_size: nil, bits: nil, mode: "affine", quantize_input: false)
    raise ArgumentError, "Quantized input is not supported." if quantize_input

    QuantizedSwitchLinear.from_switch_linear(self, group_size, bits, mode: mode)
  end
end
62
+
63
# Quantized counterpart of SwitchLinear, dispatching via mx.gather_qmm.
class QuantizedSwitchLinear < MLX::NN::Module
  attr_reader :group_size, :bits, :mode

  def initialize(input_dims, output_dims, num_experts, bias: false, group_size: nil, bits: nil, mode: "affine")
    super()

    # Resolve mode-specific defaults for group size / bit width.
    @group_size, @bits = MLX::NN.__send__(:defaults_for_mode, mode, group_size, bits)
    @mode = mode

    mx = MLX::Core
    bound = Math.sqrt(1.0 / input_dims)
    # Quantize a freshly initialised weight; real parameters are loaded
    # (or copied in from_switch_linear) afterwards.
    q_weight, q_scales, *q_biases = mx.quantize(
      mx.random_uniform(
        [num_experts, output_dims, input_dims],
        bound * -1.0,
        bound,
        mx.float32
      ),
      @group_size,
      @bits,
      @mode
    )
    self.weight = q_weight
    self.scales = q_scales
    # Some quantization modes produce no bias tensor.
    self.biases = q_biases.empty? ? nil : q_biases[0]
    self.bias = mx.zeros([num_experts, output_dims]) if bias

    # Quantized parameters are not trainable.
    freeze
  end

  def call(x, indices, sorted_indices: false)
    mx = MLX::Core
    quant_biases = respond_to?(:biases) ? biases : nil
    out = mx.gather_qmm(
      x,
      weight,
      scales,
      quant_biases,
      nil,
      indices,
      true,
      @group_size,
      @bits,
      @mode,
      sorted_indices
    )
    out = out + mx.expand_dims(mx.take(bias, indices, 0), -2) if respond_to?(:bias)
    out
  end

  # Build a quantized layer from an existing SwitchLinear, quantizing its
  # (possibly trained) weights and carrying over any additive bias.
  def self.from_switch_linear(linear_layer, group_size = nil, bits = nil, mode: "affine")
    num_experts, output_dims, input_dims = linear_layer.weight.shape
    layer = new(
      input_dims,
      output_dims,
      num_experts,
      bias: false,
      group_size: group_size,
      bits: bits,
      mode: mode
    )
    q_weight, q_scales, *q_biases = MLX::Core.quantize(
      linear_layer.weight,
      layer.group_size,
      layer.bits,
      layer.mode
    )
    layer.weight = q_weight
    layer.scales = q_scales
    layer.biases = q_biases.empty? ? nil : q_biases[0]
    layer.bias = linear_layer.bias if linear_layer.state.key?("bias")
    layer
  end
end
140
+
141
# Batched expert MLP with SwiGLU activation (silu(gate) * up), built on
# SwitchLinear so expert dispatch is a gather_mm rather than a loop.
class SwitchGLU < MLX::NN::Module
  def initialize(input_dims, hidden_dims, num_experts, bias: false)
    super()
    self.gate_proj = SwitchLinear.new(input_dims, hidden_dims, num_experts, bias: bias)
    self.up_proj = SwitchLinear.new(input_dims, hidden_dims, num_experts, bias: bias)
    self.down_proj = SwitchLinear.new(hidden_dims, input_dims, num_experts, bias: bias)
  end

  def call(x, indices)
    mx = MLX::Core
    x = mx.expand_dims(x, [-2, -3])

    # With >= 64 routed rows, grouping same-expert rows first pays for
    # the sort; below that the overhead dominates.
    sorted = indices.size >= 64
    idx = indices
    inv_order = nil
    x, idx, inv_order = SwitchLayers.gather_sort(x, indices) if sorted

    # Do not backprop through the integer routing indices.
    idx = mx.stop_gradient(idx) if training

    up = up_proj.call(x, idx, sorted_indices: sorted)
    gate = gate_proj.call(x, idx, sorted_indices: sorted)
    x = down_proj.call(MLX::NN.silu(gate) * up, idx, sorted_indices: sorted)

    x = SwitchLayers.scatter_unsort(x, inv_order, indices.shape) if sorted

    mx.squeeze(x, -2)
  end
end
183
+
184
# Batched expert two-layer MLP with a configurable activation (defaults
# to precise GELU), using the same sorted-dispatch scheme as SwitchGLU.
class SwitchMLP < MLX::NN::Module
  def initialize(input_dims, hidden_dims, num_experts, activation: nil, bias: false)
    super()
    self.fc1 = SwitchLinear.new(input_dims, hidden_dims, num_experts, bias: bias)
    self.fc2 = SwitchLinear.new(hidden_dims, input_dims, num_experts, bias: bias)
    self.activation = activation || MLX::NN::GELU.new("precise")
  end

  def call(x, indices)
    mx = MLX::Core
    x = mx.expand_dims(x, [-2, -3])

    # With >= 64 routed rows, grouping same-expert rows first pays for
    # the sort; below that the overhead dominates.
    sorted = indices.size >= 64
    idx = indices
    inv_order = nil
    x, idx, inv_order = SwitchLayers.gather_sort(x, indices) if sorted

    # Do not backprop through the integer routing indices.
    idx = mx.stop_gradient(idx) if training

    x = fc1.call(x, idx, sorted_indices: sorted)
    x = activation.call(x)
    x = fc2.call(x, idx, sorted_indices: sorted)

    x = SwitchLayers.scatter_unsort(x, inv_order, indices.shape) if sorted

    mx.squeeze(x, -2)
  end
end
219
+ end
220
+ end
221
+ end