mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/dots1.rb
@@ -0,0 +1,292 @@
+ require_relative "activations"
+ require_relative "rope_utils"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module Dots1
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "dots1"
+         field :hidden_size
+         field :num_hidden_layers
+         field :intermediate_size
+         field :num_attention_heads
+         field :rms_norm_eps
+         field :vocab_size
+         field :max_position_embeddings, default: nil
+         field :num_key_value_heads
+         field :first_k_dense_replace
+         field :moe_intermediate_size
+         field :n_routed_experts
+         field :n_shared_experts
+         field :norm_topk_prob
+         field :num_experts_per_tok
+         field :rope_theta
+         field :routed_scaling_factor
+         field :head_dim, default: nil
+         field :scoring_func, default: "noaux_tc"
+         field :n_group, default: 1
+         field :topk_group, default: 1
+         field :attention_bias, default: false
+         field :mlp_bias, default: false
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+           @n_group ||= 1
+           @topk_group ||= 1
+         end
+       end
+
+       class Dots1Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias)
+
+           self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             false,
+             args.rope_scaling,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = q_norm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3])
+           keys = k_norm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class Dots1TopkRouter < MLX::NN::Module
+         def initialize(args)
+           super()
+           mx = MLX::Core
+           @top_k = args.num_experts_per_tok
+           @norm_topk_prob = args.norm_topk_prob
+           @n_routed_experts = args.n_routed_experts
+           @routed_scaling_factor = args.routed_scaling_factor
+           @n_group = args.n_group
+           @topk_group = args.topk_group
+           self.weight = mx.zeros([@n_routed_experts, args.hidden_size]).astype(mx.float32)
+           self.e_score_correction_bias = mx.zeros([@n_routed_experts]).astype(mx.float32)
+         end
+
+         def call(x)
+           mx = MLX::Core
+
+           gates = mx.matmul(x, mx.transpose(weight))
+           scores = mx.sigmoid(gates.astype(mx.float32))
+           scores = scores + e_score_correction_bias.reshape([1, 1, @n_routed_experts])
+
+           k = [[@top_k.to_i, 1].max, @n_routed_experts].min
+           inds = mx.stop_gradient(mx.argpartition(scores * -1.0, k - 1, -1))
+           take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+
+           selected_scores = mx.take_along_axis(mx.sigmoid(gates.astype(mx.float32)), inds, -1)
+           if k > 1 && @norm_topk_prob
+             denom = mx.expand_dims(mx.sum(selected_scores, -1), -1)
+             selected_scores = selected_scores / denom
+           end
+           selected_scores = selected_scores * @routed_scaling_factor.to_f
+
+           [inds, selected_scores.astype(gates.dtype)]
+         end
+       end
+
+       class Dots1MLP < MLX::NN::Module
+         def initialize(args, hidden_size: nil, intermediate_size: nil)
+           super()
+           @hidden_size = hidden_size || args.hidden_size
+           @intermediate_size = intermediate_size || args.intermediate_size
+
+           self.gate_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: args.mlp_bias)
+           self.up_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: args.mlp_bias)
+           self.down_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: args.mlp_bias)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class Dots1MoE < MLX::NN::Module
+         def initialize(args)
+           super()
+           @n_shared_experts = args.n_shared_experts
+           self.experts = SwitchLayers::SwitchGLU.new(
+             args.hidden_size,
+             args.moe_intermediate_size,
+             args.n_routed_experts,
+             bias: args.mlp_bias
+           )
+           self.gate = Dots1TopkRouter.new(args)
+
+           if @n_shared_experts && @n_shared_experts > 0
+             self.shared_experts = Dots1MLP.new(
+               args,
+               intermediate_size: args.moe_intermediate_size * @n_shared_experts
+             )
+           end
+         end
+
+         def call(x)
+           mx = MLX::Core
+           inds, scores = gate.call(x)
+           y = experts.call(x, inds)
+           y = mx.sum(y * mx.expand_dims(scores.astype(y.dtype), -1), -2)
+
+           y = y + shared_experts.call(x) if @n_shared_experts && @n_shared_experts > 0
+           y
+         end
+       end
+
+       class Dots1DecoderLayer < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+           self.self_attn = Dots1Attention.new(args)
+           if layer_idx >= args.first_k_dense_replace
+             self.mlp = Dots1MoE.new(args)
+           else
+             self.mlp = Dots1MLP.new(args)
+           end
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class Dots1Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { |layer_idx| Dots1DecoderLayer.new(args, layer_idx) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, layer_idx|
+             h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           n = h.shape[1]
+           return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
+           return nil if n == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = Dots1Model.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           result = weights.dup
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+
+           experts_count = @args.n_routed_experts.to_i
+           if experts_count > 0
+             mx = MLX::Core
+             @args.num_hidden_layers.times do |layer_idx|
+               next if layer_idx < @args.first_k_dense_replace
+
+               prefix = "model.layers.#{layer_idx}.mlp"
+               %w[gate_proj down_proj up_proj].each do |projection|
+                 %w[weight scales biases].each do |param|
+                   first_key = "#{prefix}.experts.0.#{projection}.#{param}"
+                   next unless result.key?(first_key)
+
+                   expert_keys = (0...experts_count).map do |expert_idx|
+                     "#{prefix}.experts.#{expert_idx}.#{projection}.#{param}"
+                   end
+                   next unless expert_keys.all? { |key| result.key?(key) }
+
+                   stacked = expert_keys.map { |key| result.delete(key) }
+                   result["#{prefix}.experts.#{projection}.#{param}"] = mx.stack(stacked)
+                 end
+               end
+             end
+           end
+
+           result.reject { |k, _| k.include?("rotary_emb.inv_freq") }
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("dots1", Model, ModelArgs)
+     end
+   end
+ end
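
Aside: Dots1TopkRouter above implements the sigmoid-based "noaux_tc" routing scheme — experts are selected by the bias-corrected sigmoid scores but weighted by the uncorrected ones, optionally renormalized over the chosen k and then scaled by routed_scaling_factor. A minimal plain-Ruby sketch of that arithmetic for a single token (the gate logits and hyperparameters below are made up for illustration; no MLX arrays involved):

  def sigmoid(x)
    1.0 / (1.0 + Math.exp(-x))
  end

  gates = [0.4, -1.2, 2.0, 0.1]   # router logits for one token, 4 experts
  bias  = [0.0, 0.5, 0.0, -0.3]   # e_score_correction_bias (used for selection only)
  top_k = 2
  norm_topk_prob = true
  routed_scaling_factor = 1.0

  # Selection uses sigmoid(gates) + bias ...
  scores_for_choice = gates.each_with_index.map { |g, i| sigmoid(g) + bias[i] }
  inds = scores_for_choice.each_with_index.sort_by { |s, _| -s }.first(top_k).map(&:last)

  # ... but the mixing weights come from the unbiased sigmoid scores.
  selected = inds.map { |i| sigmoid(gates[i]) }
  if top_k > 1 && norm_topk_prob
    total = selected.sum
    selected = selected.map { |s| s / total }
  end
  selected = selected.map { |s| s * routed_scaling_factor }

  p inds      # => [2, 1] for these toy numbers
  p selected  # per-expert weights applied to the expert outputs in Dots1MoE#call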
data/lib/mlx_lm/models/ernie4_5.rb
@@ -0,0 +1,165 @@
+ module MlxLm
+   module Models
+     module Ernie45
+       class ModelArgs < BaseModelArgs
+         field :hidden_size
+         field :intermediate_size
+         field :model_type, default: "ernie4_5"
+         field :max_position_embeddings
+         field :num_attention_heads
+         field :num_key_value_heads
+         field :head_dim, default: nil
+         field :num_hidden_layers
+         field :rms_norm_eps
+         field :vocab_size
+         field :rope_theta
+         field :use_bias, default: false
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.use_bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.use_bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.use_bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.use_bias)
+
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             true,
+             nil,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(dim, hidden_dim, use_bias: false)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: use_bias)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: use_bias)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: use_bias)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args.hidden_size, args.intermediate_size, use_bias: args.use_bias)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class Ernie45Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if h.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = Ernie45Model.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("ernie4_5", Model, ModelArgs)
+     end
+   end
+ end
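
Aside: dots1.rb and ernie4_5.rb build their attention mask the same way — the KV cache supplies its own mask when it knows how to (cache types that need one respond to make_mask), single-token decoding needs no mask at all, and ordinary prompt processing passes the string "causal", which MLX's scaled_dot_product_attention accepts as shorthand for a lower-triangular causal mask. A standalone sketch of that dispatch (DummyCache is invented here purely to exercise the respond_to? branch):

  def create_attention_mask(seq_len, cache)
    return cache.make_mask(seq_len) if cache && cache.respond_to?(:make_mask)
    return nil if seq_len == 1 # one new token may attend to everything cached

    "causal"                   # prompt prefill: lower-triangular mask
  end

  DummyCache = Struct.new(:window) do
    def make_mask(n)
      "mask over #{n} tokens, window #{window}"
    end
  end

  p create_attention_mask(1, nil)               # => nil
  p create_attention_mask(8, nil)               # => "causal"
  p create_attention_mask(8, DummyCache.new(4)) # => "mask over 8 tokens, window 4"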
data/lib/mlx_lm/models/ernie4_5_moe.rb
@@ -0,0 +1,97 @@
+ require_relative "activations"
+ require_relative "rope_utils"
+ require_relative "ernie4_5"
+
+ module MlxLm
+   module Models
+     module Ernie45Moe
+       class ModelArgs < Ernie45::ModelArgs
+         field :model_type, default: "ernie4_5_moe"
+         field :moe_num_experts, default: 0
+         field :moe_layer_start_index, default: 0
+         field :moe_intermediate_size, default: 0
+         field :moe_capacity, default: []
+         field :moe_k, default: 1
+         field :moe_layer_interval, default: 1
+         field :moe_use_aux_free, default: false
+         field :moe_num_shared_experts, default: 0
+         field :moe_layer_end_index, default: nil
+         field :moe_gate_act, default: "softmax"
+
+         def initialize(**kwargs)
+           super
+           @moe_capacity = Array(@moe_capacity).dup
+         end
+       end
+
+       class Model < Ernie45::Model
+         REMOVE_PATTERNS = [
+           "mtp_block.",
+           "mtp_linear_proj.",
+           "mtp_hidden_norm.",
+           "mtp_emb_norm.",
+           "e_score_correction_bias",
+         ].freeze
+
+         EXPERT_PROJ_NAMES = %w[gate_proj down_proj up_proj].freeze
+
+         def sanitize(weights)
+           result = weights.reject do |key, _|
+             REMOVE_PATTERNS.any? { |pattern| key.include?(pattern) }
+           end
+
+           stack_expert_weights!(result)
+         end
+
+         private
+
+         def stack_expert_weights!(weights)
+           mx = MLX::Core
+           num_experts = @args.moe_num_experts.to_i
+           return weights if num_experts <= 0
+
+           @args.num_hidden_layers.times do |layer_idx|
+             prefix = "model.layers.#{layer_idx}.mlp"
+
+             EXPERT_PROJ_NAMES.each do |proj_name|
+               expert_weights = pop_complete_expert_weights(weights, prefix, proj_name, num_experts)
+               next unless expert_weights
+
+               weights["#{prefix}.switch_mlp.#{proj_name}.weight"] = mx.stack(expert_weights)
+             end
+           end
+
+           weights
+         end
+
+         def pop_complete_expert_weights(weights, prefix, proj_name, num_experts)
+           first_key = expert_weight_key(prefix, 0, proj_name)
+           return nil unless weights.key?(first_key)
+
+           popped = []
+           num_experts.times do |expert_idx|
+             key = expert_weight_key(prefix, expert_idx, proj_name)
+             unless weights.key?(key)
+               restore_popped_weights!(weights, prefix, proj_name, popped)
+               return nil
+             end
+             popped << weights.delete(key)
+           end
+           popped
+         end
+
+         def restore_popped_weights!(weights, prefix, proj_name, popped)
+           popped.each_with_index do |tensor, idx|
+             weights[expert_weight_key(prefix, idx, proj_name)] = tensor
+           end
+         end
+
+         def expert_weight_key(prefix, expert_idx, proj_name)
+           "#{prefix}.experts.#{expert_idx}.#{proj_name}.weight"
+         end
+       end
+
+       Models.register("ernie4_5_moe", Model, ModelArgs)
+     end
+   end
+ end
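
Aside: the sanitize path in ernie4_5_moe.rb (and its counterpart in dots1.rb) only rewrites checkpoint keys — per-expert tensors are popped and re-inserted as one stacked tensor per projection so a fused switch layer can consume them, and a projection is left untouched if any expert's key is missing. A toy plain-Ruby sketch of that key transformation, with nested arrays standing in for MLX arrays and MLX::Core.stack; the layer/expert counts and the extra gate key are invented for illustration:

  weights = {
    "model.layers.0.mlp.experts.0.gate_proj.weight" => [[1, 0], [0, 1]],
    "model.layers.0.mlp.experts.1.gate_proj.weight" => [[2, 0], [0, 2]],
    "model.layers.0.mlp.gate.weight"                => [[0.1, 0.2]],
  }

  prefix = "model.layers.0.mlp"
  num_experts = 2

  stacked = (0...num_experts).map { |e| weights.delete("#{prefix}.experts.#{e}.gate_proj.weight") }
  weights["#{prefix}.switch_mlp.gate_proj.weight"] = stacked # mx.stack would add the leading expert axis

  p weights.keys
  # => ["model.layers.0.mlp.gate.weight", "model.layers.0.mlp.switch_mlp.gate_proj.weight"]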