mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/hunyuan.rb
@@ -0,0 +1,378 @@
+ require_relative "activations"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module Hunyuan
+       module_function
+
+       def int_or_list(value, idx)
+         return value[idx] if value.is_a?(Array)
+
+         value
+       end
+
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "hunyuan"
+         field :vocab_size
+         field :hidden_size
+         field :num_hidden_layers
+         field :intermediate_size
+         field :num_attention_heads
+         field :num_key_value_heads, default: nil
+         field :attention_bias
+         field :moe_topk
+         field :num_experts
+         field :num_shared_expert
+         field :use_mixed_mlp_moe
+         field :use_qk_norm
+         field :rms_norm_eps
+         field :rope_theta
+         field :use_cla
+         field :cla_share_factor, default: 2
+         field :moe_intermediate_size, default: nil
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           _validate_rope_scaling!
+         end
+
+         private
+
+         def _validate_rope_scaling!
+           return if @rope_scaling.nil?
+
+           required_keys = %w[factor type]
+           return if required_keys.all? { |key| _rope_scaling_has_key?(key) }
+
+           raise ArgumentError, "rope_scaling must contain keys #{required_keys}"
+         end
+
+         def _rope_scaling_has_key?(key)
+           @rope_scaling.key?(key) || @rope_scaling.key?(key.to_sym)
+         end
+       end
+
+       class DynamicNTKAlphaRoPE < MLX::NN::Module
+         def initialize(dims, base: 10_000.0, scaling_alpha: 1.0)
+           super()
+           mx = MLX::Core
+
+           @dims = dims
+           adjusted_base = base * (scaling_alpha**(dims.to_f / (dims - 2)))
+           self._freqs = mx.power(
+             adjusted_base,
+             mx.divide(mx.arange(0, dims, 2, mx.float32), dims.to_f)
+           )
+         end
+
+         def call(x, offset: 0)
+           MLX::Core.rope(x, @dims, false, nil, 1.0, offset, _freqs)
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(kv_proj, args)
+           super()
+           dim = args.hidden_size
+
+           @kv_proj = kv_proj
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = dim / @n_heads
+           @scale = @head_dim**(-0.5)
+           @use_qk_norm = args.use_qk_norm
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias)
+           if kv_proj
+             self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+             self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           end
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias)
+
+           if @use_qk_norm
+             self.query_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+             self.key_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           end
+
+           scaling_alpha = _config_value(args.rope_scaling, "alpha", 1.0)
+           self.rope = DynamicNTKAlphaRoPE.new(
+             @head_dim,
+             base: args.rope_theta,
+             scaling_alpha: scaling_alpha
+           )
+         end
+
+         def call(x, mask: nil, cache: nil, kv_states: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           if kv_states
+             keys, values = kv_states
+           else
+             raise ArgumentError, "kv_states required when kv_proj is disabled" unless @kv_proj
+
+             keys = k_proj.call(x)
+             values = v_proj.call(x)
+             kv_states = [keys, values]
+           end
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           offset = cache ? cache.offset : 0
+           queries = rope.call(queries, offset: offset)
+           keys = rope.call(keys, offset: offset)
+
+           if @use_qk_norm
+             queries = query_layernorm.call(queries)
+             keys = key_layernorm.call(keys)
+           end
+
+           keys, values = cache.update_and_fetch(keys, values) if cache
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           [o_proj.call(output), kv_states]
+         end
+
+         private
+
+         def _config_value(config, key, default = nil)
+           return default if config.nil?
+           return config[key] if config.key?(key)
+
+           config.fetch(key.to_sym, default)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(dim, hidden_dim)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class Gate < MLX::NN::Module
+         def initialize(dim, num_experts)
+           super()
+           self.wg = MLX::NN::Linear.new(dim, num_experts, bias: false)
+         end
+
+         def call(x)
+           wg.call(x)
+         end
+       end
+
+       class MoeBlock < MLX::NN::Module
+         def initialize(args, layer_idx: 0)
+           super()
+           dim = args.hidden_size
+           intermediate_size = args.intermediate_size
+
+           @use_shared_mlp = args.use_mixed_mlp_moe
+           if @use_shared_mlp
+             num_shared = Hunyuan.int_or_list(args.num_shared_expert, layer_idx).to_i
+             self.shared_mlp = MLP.new(dim, (intermediate_size * num_shared).to_i)
+           end
+
+           @num_experts = args.num_experts
+           @top_k = Hunyuan.int_or_list(args.moe_topk, layer_idx).to_i
+           self.gate = Gate.new(dim, @num_experts)
+
+           expert_intermediate_size = args.moe_intermediate_size.nil? ?
+             intermediate_size :
+             Hunyuan.int_or_list(args.moe_intermediate_size, layer_idx)
+
+           self.switch_mlp = SwitchLayers::SwitchGLU.new(
+             dim,
+             expert_intermediate_size,
+             @num_experts
+           )
+         end
+
+         def call(x)
+           mx = MLX::Core
+
+           gates = gate.call(x)
+           gates = mx.softmax(gates.astype(mx.float32), -1).astype(gates.dtype)
+
+           k = [[@top_k, 1].max, @num_experts].min
+           inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1))
+           take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+           scores = mx.take_along_axis(gates, inds, -1)
+
+           y = switch_mlp.call(x, inds)
+           y = mx.sum(y * mx.expand_dims(scores.astype(mx.float32), -1), -2).astype(y.dtype)
+
+           y = y + shared_mlp.call(x) if @use_shared_mlp
+           y
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         def initialize(args, kv_proj:, layer_idx:)
+           super()
+           self.self_attn = Attention.new(kv_proj, args)
+           if args.num_experts.to_i == 1
+             self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+           else
+             self.mlp = MoeBlock.new(args, layer_idx: layer_idx)
+           end
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil, shared_kv_states: nil)
+           r, shared_kv_states = self_attn.call(
+             input_layernorm.call(x),
+             mask: mask,
+             cache: cache,
+             kv_states: shared_kv_states
+           )
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           [h + r, shared_kv_states]
+         end
+       end
+
+       class HunYuanModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) do |i|
+             kv_proj = (!args.use_cla) || (i % args.cla_share_factor).zero?
+             DecoderLayer.new(args, kv_proj: kv_proj, layer_idx: i)
+           end
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           shared_kv_states = nil
+           layers.each_with_index do |layer, i|
+             if (!@args.use_cla) || (i % @args.cla_share_factor).zero?
+               shared_kv_states = nil
+             end
+             h, shared_kv_states = layer.call(
+               h,
+               mask: mask,
+               cache: layer_cache[i],
+               shared_kv_states: shared_kv_states
+             )
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(hidden, cache)
+           return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if hidden.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = HunYuanModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           result = weights.dup
+
+           if result.key?("model.layers.0.mlp.gate_and_up_proj.weight")
+             new_weights = {}
+             d = @args.hidden_size
+             n_kv_heads = @args.num_key_value_heads
+             n_kv_groups = @args.num_attention_heads / n_kv_heads
+             head_dim = d / @args.num_attention_heads
+
+             result.each do |key, value|
+               if key.include?("qkv_proj")
+                 reshaped = value.reshape([n_kv_heads, n_kv_groups + 2, head_dim, -1])
+                 qkv_splits = mx.split(reshaped, [n_kv_groups, n_kv_groups + 1], 1)
+                 %w[q_proj k_proj v_proj].each_with_index do |proj, idx|
+                   new_weights[key.sub("qkv_proj", proj)] = mx.flatten(qkv_splits[idx], 0, 2)
+                 end
+               elsif key.include?("gate_and_up_proj")
+                 split_idx = value.shape[0] / 2
+                 up_proj, gate_proj = mx.split(value, [split_idx], 0)
+                 new_weights[key.sub("gate_and_up_proj", "up_proj")] = up_proj
+                 new_weights[key.sub("gate_and_up_proj", "gate_proj")] = gate_proj
+               else
+                 new_weights[key] = value
+               end
+             end
+
+             result = new_weights
+           end
+
+           if result.key?("model.layers.0.mlp.experts.0.up_proj.weight")
+             @args.num_hidden_layers.times do |layer_idx|
+               prefix = "model.layers.#{layer_idx}"
+               %w[up_proj down_proj gate_proj].each do |projection|
+                 %w[weight scales biases].each do |param|
+                   first_key = "#{prefix}.mlp.experts.0.#{projection}.#{param}"
+                   next unless result.key?(first_key)
+
+                   expert_keys = (0...@args.num_experts).map do |expert_idx|
+                     "#{prefix}.mlp.experts.#{expert_idx}.#{projection}.#{param}"
+                   end
+                   next unless expert_keys.all? { |k| result.key?(k) }
+
+                   stacked = expert_keys.map { |k| result.delete(k) }
+                   result["#{prefix}.mlp.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked)
+                 end
+               end
+             end
+           end
+
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("hunyuan", Model, ModelArgs)
+     end
+   end
+ end
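
A minimal usage sketch for the Hunyuan model added above, based only on the classes visible in this hunk. It assumes `require "mlx_lm"` loads the MlxLm::Models registry and that BaseModelArgs accepts the declared fields as keyword arguments; the hyperparameter values are illustrative, not from any shipped config.

  require "mlx_lm"

  args = MlxLm::Models::Hunyuan::ModelArgs.new(
    vocab_size: 1_000,
    hidden_size: 256,
    num_hidden_layers: 2,
    intermediate_size: 512,
    num_attention_heads: 4,
    num_key_value_heads: 2,
    attention_bias: false,
    num_experts: 1,            # single expert -> DecoderLayer uses a plain MLP, no MoE routing
    moe_topk: 1,
    num_shared_expert: 0,
    use_mixed_mlp_moe: false,
    use_qk_norm: true,
    rms_norm_eps: 1.0e-5,
    rope_theta: 10_000.0,
    use_cla: true,             # cross-layer attention: layer 1 reuses layer 0's KV states
    cla_share_factor: 2
  )

  model = MlxLm::Models::Hunyuan::Model.new(args)
  tokens = MLX::Core.array([[1, 2, 3, 4]])   # [batch, seq_len] token ids
  logits = model.call(tokens)                # -> [1, 4, vocab_size] logits

With use_cla enabled, only every cla_share_factor-th layer constructs k_proj/v_proj; the layers in between reuse the projected KV states returned by the previous owning layer, which is why Attention#call returns kv_states alongside its output.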
data/lib/mlx_lm/models/hunyuan_v1_dense.rb
@@ -0,0 +1,235 @@
+ require_relative "activations"
+
+ module MlxLm
+   module Models
+     module HunyuanV1Dense
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "hunyuan_v1_dense"
+         field :vocab_size, default: 151_936
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 40
+         field :intermediate_size, default: 12_288
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: 8
+         field :rms_norm_eps, default: 1e-6
+         field :rope_theta, default: 10_000.0
+         field :max_position_embeddings, default: 32_768
+         field :attention_bias, default: false
+         field :use_qk_norm, default: true
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+         field :head_dim, default: nil
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+           _validate_rope_scaling!
+         end
+
+         private
+
+         def _validate_rope_scaling!
+           return if @rope_scaling.nil?
+
+           required_keys = %w[alpha factor type]
+           missing = required_keys.reject { |key| _config_has_key?(key) }
+           return if missing.empty?
+
+           raise ArgumentError, "rope_scaling must contain keys #{required_keys}"
+         end
+
+         def _config_has_key?(key)
+           @rope_scaling.key?(key) || @rope_scaling.key?(key.to_sym)
+         end
+       end
+
+       class DynamicNTKAlphaRoPE < MLX::NN::Module
+         def initialize(dims, base: 10_000.0, scaling_alpha: 1.0)
+           super()
+           mx = MLX::Core
+
+           @dims = dims
+           adjusted_base = base * (scaling_alpha**(dims.to_f / (dims - 2)))
+           self._freqs = mx.power(
+             adjusted_base,
+             mx.divide(mx.arange(0, dims, 2, mx.float32), dims.to_f)
+           )
+         end
+
+         def call(x, offset: 0)
+           MLX::Core.rope(x, @dims, false, nil, 1.0, offset, _freqs)
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+           @use_qk_norm = args.use_qk_norm
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias)
+
+           if @use_qk_norm
+             self.query_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+             self.key_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           end
+
+           scaling_alpha = _config_value(args.rope_scaling, "alpha", 1.0)
+           self.rope = DynamicNTKAlphaRoPE.new(
+             @head_dim,
+             base: args.rope_theta,
+             scaling_alpha: scaling_alpha
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           if @use_qk_norm
+             queries = query_layernorm.call(queries)
+             keys = key_layernorm.call(keys)
+           end
+
+           if cache
+             keys, values = cache.update_and_fetch(keys, values)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+
+         private
+
+         def _config_value(config, key, default = nil)
+           return default if config.nil?
+           return config[key] if config.key?(key)
+
+           config.fetch(key.to_sym, default)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           hidden_dim = args.intermediate_size
+
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class HunyuanV1DenseModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, layer_idx|
+             h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(hidden, cache)
+           return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if hidden.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = HunyuanV1DenseModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           result = weights.dup
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("hunyuan_v1_dense", Model, ModelArgs)
+     end
+   end
+ end
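
Unlike the MoE variant above, every field in this dense variant's ModelArgs carries a default, so a config can be built standalone. A hedged sketch of the rope_scaling validation shown in this file, assuming the same `require "mlx_lm"` entry point and that `field` defines readers; values are illustrative only:

  require "mlx_lm"

  # Missing the "alpha" key -> ArgumentError from _validate_rope_scaling!
  MlxLm::Models::HunyuanV1Dense::ModelArgs.new(
    rope_scaling: { "type" => "dynamic", "factor" => 4.0 }
  )
  # => ArgumentError: rope_scaling must contain keys ["alpha", "factor", "type"]

  # String or symbol keys are both accepted (_config_has_key? checks both forms)
  args = MlxLm::Models::HunyuanV1Dense::ModelArgs.new(
    rope_scaling: { type: "dynamic", factor: 4.0, alpha: 1_000.0 }
  )
  args.head_dim   # => 128, derived in initialize as hidden_size 4096 / num_attention_heads 32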