mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/afmoe.rb
@@ -0,0 +1,421 @@
+ require_relative "activations"
+ require_relative "cache"
+ require_relative "rope_utils"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module Afmoe
+       class ModelArgs < BaseModelArgs
+         field :model_type
+         field :layer_types
+         field :vocab_size, default: 200_192
+         field :hidden_size, default: 2048
+         field :intermediate_size, default: 6144
+         field :moe_intermediate_size, default: 1024
+         field :num_hidden_layers, default: 32
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: 4
+         field :head_dim, default: 64
+         field :max_position_embeddings, default: 131_072
+         field :rms_norm_eps, default: 1e-5
+         field :rope_theta, default: 10_000.0
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+         field :num_experts, default: 128
+         field :num_experts_per_tok, default: 8
+         field :num_shared_experts, default: 1
+         field :num_dense_layers, default: 2
+         field :route_norm, default: true
+         field :route_scale, default: 2.826
+         field :score_func, default: "sigmoid"
+         field :n_group, default: 1
+         field :topk_group, default: 1
+         field :sliding_window, default: 2048
+         field :mup_enabled, default: true
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" }
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args, is_local_attention: false)
+           super()
+           @hidden_size = args.hidden_size
+           @num_attention_heads = args.num_attention_heads
+           @num_key_value_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @is_local_attention = is_local_attention
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(
+             @hidden_size,
+             @num_attention_heads * @head_dim,
+             bias: false
+           )
+           self.k_proj = MLX::NN::Linear.new(
+             @hidden_size,
+             @num_key_value_heads * @head_dim,
+             bias: false
+           )
+           self.v_proj = MLX::NN::Linear.new(
+             @hidden_size,
+             @num_key_value_heads * @head_dim,
+             bias: false
+           )
+           self.o_proj = MLX::NN::Linear.new(
+             @num_attention_heads * @head_dim,
+             @hidden_size,
+             bias: false
+           )
+
+           self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.gate_proj = MLX::NN::Linear.new(
+             @hidden_size,
+             @num_attention_heads * @head_dim,
+             bias: false
+           )
+
+           if @is_local_attention
+             self.rope = MlxLm::Models.initialize_rope(
+               @head_dim,
+               args.rope_theta,
+               false,
+               args.rope_scaling,
+               max_position_embeddings: args.max_position_embeddings
+             )
+           end
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           queries = q_norm.call(queries)
+           keys = k_norm.call(keys)
+
+           if @is_local_attention && respond_to?(:rope)
+             if cache
+               queries = rope.call(queries, offset: cache.offset)
+               keys = rope.call(keys, offset: cache.offset)
+             else
+               queries = rope.call(queries)
+               keys = rope.call(keys)
+             end
+           end
+
+           if cache
+             keys, values = cache.update_and_fetch(keys, values)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim])
+
+           gate = mx.sigmoid(gate_proj.call(x))
+           output = output * gate
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args, intermediate_size: nil)
+           super()
+           dim = args.hidden_size
+           hidden_dim = intermediate_size || args.intermediate_size
+
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class MoERouter < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.gate = MLX::NN::Linear.new(args.hidden_size, args.num_experts, bias: false)
+         end
+
+         def call(x)
+           gate.call(x)
+         end
+       end
+
+       class AfmoeMoE < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           @num_experts = args.num_experts
+           @num_experts_per_tok = args.num_experts_per_tok
+           @route_norm = args.route_norm
+           @route_scale = args.route_scale
+           @score_func = args.score_func
+           @n_group = args.n_group
+           @topk_group = args.topk_group
+
+           self.router = MoERouter.new(args)
+           self.expert_bias = MLX::Core.zeros([args.num_experts])
+           self.experts = SwitchLayers::SwitchGLU.new(
+             args.hidden_size,
+             args.moe_intermediate_size,
+             args.num_experts
+           )
+
+           if args.num_shared_experts.to_i > 0
+             shared_intermediate_size = args.moe_intermediate_size * args.num_shared_experts
+             self.shared_experts = MLP.new(args, intermediate_size: shared_intermediate_size)
+           end
+         end
+
+         def call(x)
+           mx = MLX::Core
+
+           gates = router.call(x)
+           scores = if @score_func == "sigmoid"
+             mx.sigmoid(gates.astype(mx.float32))
+           else
+             mx.softmax(gates.astype(mx.float32), -1)
+           end
+
+           selection_scores = scores + expert_bias
+
+           if @n_group.to_i > 1
+             experts_per_group = selection_scores.shape[-1] / @n_group
+             selection_scores = mx.unflatten(selection_scores, -1, [@n_group, experts_per_group])
+             group_scores = mx.topk(selection_scores, 2, -1)
+             group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1)
+
+             drop_count = @n_group - @topk_group.to_i
+             if drop_count > 0
+               group_idx = mx.argpartition(group_scores, drop_count - 1, -2)
+               take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32)
+               group_idx = mx.take(group_idx, take_ids, -2)
+               selection_scores = mx.put_along_axis(
+                 selection_scores,
+                 mx.stop_gradient(group_idx),
+                 mx.array(0.0),
+                 -2
+               )
+             end
+
+             selection_scores = mx.flatten(selection_scores, -2, -1)
+           end
+
+           k = [@num_experts_per_tok.to_i, selection_scores.shape[-1]].min
+           inds = mx.argpartition(selection_scores * -1.0, k - 1, -1)
+           take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+
+           selected_scores = mx.take_along_axis(scores, inds, -1)
+           if @route_norm && k > 1
+             denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1)
+             selected_scores = selected_scores / denominator
+           end
+           selected_scores = selected_scores * @route_scale
+
+           y = experts.call(x, inds)
+           y = mx.sum(y * mx.expand_dims(selected_scores, -1), -2).astype(y.dtype)
+           y = y + shared_experts.call(x) if @args.num_shared_experts.to_i > 0
+           y
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         attr_reader :use_sliding
+
+         def initialize(args, layer_idx, use_sliding: false)
+           super()
+           @use_sliding = use_sliding
+           self.self_attn = Attention.new(args, is_local_attention: @use_sliding)
+           self.mlp = if layer_idx < args.num_dense_layers
+             MLP.new(args)
+           else
+             AfmoeMoE.new(args)
+           end
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.pre_mlp_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_mlp_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           r = post_attention_layernorm.call(r)
+           h = x + r
+
+           r = mlp.call(pre_mlp_layernorm.call(h))
+           r = post_mlp_layernorm.call(r)
+           h + r
+         end
+       end
+
+       class AfmoeModel < MLX::NN::Module
+         attr_reader :layer_types, :sliding_window
+
+         def initialize(args)
+           super()
+           @hidden_size = args.hidden_size
+           @layer_types = args.layer_types
+           @sliding_window = args.sliding_window
+           @mup_enabled = args.mup_enabled
+
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = @layer_types.each_with_index.map do |layer_type, idx|
+             DecoderLayer.new(args, idx, use_sliding: layer_type == "sliding_attention")
+           end
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+
+           self.fa_idx = @layer_types.index("full_attention") || 0
+           self.swa_idx = @layer_types.index("sliding_attention")
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           h = h * Math.sqrt(@hidden_size) if @mup_enabled
+
+           layer_cache = cache || [nil] * layers.length
+           full_mask = _create_attention_mask(h, layer_cache[fa_idx])
+           sliding_mask = if swa_idx.nil?
+             nil
+           else
+             _create_attention_mask(h, layer_cache[swa_idx], window_size: @sliding_window)
+           end
+
+           layers.each_with_index do |layer, i|
+             mask = layer.use_sliding ? sliding_mask : full_mask
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil, window_size: nil)
+           n = h.shape[1]
+           if cache && cache.respond_to?(:make_mask)
+             return cache.make_mask(n, window_size: window_size)
+           end
+
+           if window_size
+             offset = 0
+             if cache
+               offset = cache.offset if cache.respond_to?(:offset)
+               if cache.instance_variable_defined?(:@max_size)
+                 max_size = cache.instance_variable_get(:@max_size)
+                 offset = [max_size - 1, offset].min if max_size && max_size > 0
+               end
+             end
+             return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
+           end
+
+           return nil if n == 1
+
+           "causal"
+         end
+
+         def _create_causal_mask(n, offset: 0, window_size: nil)
+           mx = MLX::Core
+           rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+           linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+           mask = mx.greater_equal(linds, rinds)
+           if window_size
+             mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+           end
+           mask
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = AfmoeModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           result = weights.reject { |key, _| key.to_s.include?("rotary_emb.inv_freq") }
+           result = result.dup
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+
+           @args.num_hidden_layers.times do |layer_idx|
+             next if layer_idx < @args.num_dense_layers.to_i
+
+             prefix = "model.layers.#{layer_idx}"
+             %w[up_proj down_proj gate_proj].each do |projection|
+               %w[weight scales biases].each do |param|
+                 first_key = "#{prefix}.mlp.experts.0.#{projection}.#{param}"
+                 next unless result.key?(first_key)
+
+                 expert_keys = (0...@args.num_experts).map do |expert_idx|
+                   "#{prefix}.mlp.experts.#{expert_idx}.#{projection}.#{param}"
+                 end
+                 next unless expert_keys.all? { |key| result.key?(key) }
+
+                 stacked = expert_keys.map { |key| result.delete(key) }
+                 result["#{prefix}.mlp.experts.#{projection}.#{param}"] = mx.stack(stacked)
+               end
+             end
+           end
+
+           result
+         end
+
+         def layers
+           model.layers
+         end
+
+         def make_cache
+           layers.map do |layer|
+             if layer.use_sliding
+               MlxLm::RotatingKVCache.new(max_size: model.sliding_window)
+             else
+               MlxLm::KVCache.new
+             end
+           end
+         end
+
+         def cast_predicate
+           lambda { |key| !key.to_s.include?("expert_bias") }
+         end
+
+         def quant_predicate
+           lambda do |path, _|
+             if path.to_s.include?("router.gate")
+               { group_size: 64, bits: 8 }
+             else
+               true
+             end
+           end
+         end
+       end
+
+       Models.register("afmoe", Model, ModelArgs)
+     end
+   end
+ end
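
For readers skimming the diff, the densest part of afmoe.rb is the expert routing in AfmoeMoE#call: router logits become sigmoid (or softmax) scores, the learned expert_bias is added only when selecting the top-k experts, and the mixing weights are taken from the unbiased scores, optionally renormalized and then multiplied by route_scale. The following is an illustrative sketch only, replaying that arithmetic for a single token in plain Ruby (no MLX), with made-up numbers and n_group = 1 so the group-selection branch is skipped.

    # Illustrative sketch of the per-token routing math in AfmoeMoE#call.
    num_experts_per_tok = 2
    route_scale = 2.826

    logits      = [0.4, -1.2, 2.0, 0.1]                           # router output for one token
    scores      = logits.map { |g| 1.0 / (1.0 + Math.exp(-g)) }   # score_func == "sigmoid"
    expert_bias = [0.0, 0.3, -0.1, 0.2]

    # Selection uses the biased scores; the mixing weights come from the unbiased ones.
    selection = scores.zip(expert_bias).map { |s, b| s + b }
    top_ids   = selection.each_with_index.max_by(num_experts_per_tok) { |v, _| v }.map(&:last)

    selected = top_ids.map { |i| scores[i] }
    selected = selected.map { |s| s / selected.sum }   # route_norm with k > 1
    selected = selected.map { |s| s * route_scale }

    puts "experts #{top_ids.inspect} with weights #{selected.map { |s| s.round(3) }.inspect}"
    # The model then sums the chosen experts' outputs with these weights and
    # adds the shared-expert MLP when num_shared_experts > 0.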
data/lib/mlx_lm/models/apertus.rb
@@ -0,0 +1,179 @@
+ module MlxLm
+   module Models
+     module Apertus
+       class ModelArgs < BaseModelArgs
+         field :model_type
+         field :hidden_size
+         field :num_hidden_layers
+         field :intermediate_size
+         field :mlp_bias
+         field :num_attention_heads
+         field :attention_bias
+         field :rms_norm_eps
+         field :vocab_size
+         field :num_key_value_heads
+         field :max_position_embeddings
+         field :rope_theta
+         field :post_norm
+         field :qk_norm
+         field :tie_word_embeddings
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class ApertusMLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.up_proj = MLX::NN::Linear.new(
+             args.hidden_size,
+             args.intermediate_size,
+             bias: args.mlp_bias
+           )
+           self.down_proj = MLX::NN::Linear.new(
+             args.intermediate_size,
+             args.hidden_size,
+             bias: args.mlp_bias
+           )
+           self.act_fn = Activations::XieLU.new
+         end
+
+         def call(x)
+           down_proj.call(act_fn.call(up_proj.call(x)))
+         end
+       end
+
+       class ApertusAttention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @num_attention_heads = args.num_attention_heads
+           @num_key_value_heads = args.num_key_value_heads
+           @head_dim = dim / @num_attention_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @num_attention_heads * @head_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(dim, @num_key_value_heads * @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(dim, @num_key_value_heads * @head_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@num_attention_heads * @head_dim, dim, bias: false)
+
+           self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             args.rope_traditional,
+             args.rope_scaling,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = q_norm.call(queries.reshape([b, l, @num_attention_heads, @head_dim])).transpose([0, 2, 1, 3])
+           keys = k_norm.call(keys.reshape([b, l, @num_key_value_heads, @head_dim])).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class ApertusDecoderLayer < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = ApertusAttention.new(args)
+           self.mlp = ApertusMLP.new(args)
+           self.attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.feedforward_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           h = x + self_attn.call(attention_layernorm.call(x), mask: mask, cache: cache)
+           h + mlp.call(feedforward_layernorm.call(h))
+         end
+       end
+
+       class ApertusModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { ApertusDecoderLayer.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if h.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.args = args
+           self.model_type = args.model_type
+           self.model = ApertusModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           lm_head.call(out)
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           weights.each do |k, v|
+             if k.end_with?("alpha_p") || k.end_with?("alpha_n")
+               weights[k] = mx.squeeze(v)
+             end
+           end
+           weights
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("apertus", Model, ModelArgs)
+     end
+   end
+ end
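
Both new files end by calling Models.register(model_type, Model, ModelArgs). The registry itself lives in data/lib/mlx_lm/models.rb, which is listed above but not shown in this diff, so the sketch below is only an assumption about the shape of such a lookup table; the gem's actual implementation may differ.

    # Hypothetical sketch of a model registry keyed by a config's "model_type".
    # The real MlxLm::Models code is not part of this diff.
    module Models
      REGISTRY = {}

      # Called at the bottom of each model file, e.g.
      # Models.register("afmoe", Model, ModelArgs).
      def self.register(model_type, model_class, args_class)
        REGISTRY[model_type] = [model_class, args_class]
      end

      # A loader would fetch the pair for the model_type read from config.json.
      def self.lookup(model_type)
        REGISTRY.fetch(model_type) do
          raise ArgumentError, "unsupported model_type: #{model_type}"
        end
      end
    end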