mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,343 @@
+ require_relative "activations"
+ require_relative "pipeline"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module Glm4Moe
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "glm4_moe"
+         field :vocab_size
+         field :hidden_size
+         field :intermediate_size
+         field :max_position_embeddings
+         field :moe_intermediate_size
+         field :norm_topk_prob
+         field :num_attention_heads
+         field :n_group
+         field :head_dim, default: nil
+         field :topk_group
+         field :n_shared_experts
+         field :n_routed_experts
+         field :routed_scaling_factor
+         field :num_experts_per_tok
+         field :first_k_dense_replace
+         field :num_hidden_layers
+         field :num_key_value_heads, default: nil
+         field :rms_norm_eps
+         field :rope_theta
+         field :rope_scaling, default: nil
+         field :use_qk_norm
+         field :tie_word_embeddings
+         field :attention_bias
+         field :partial_rotary_factor
+         field :scoring_func, default: "sigmoid"
+         field :topk_method, default: "noaux_tc"
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+           @use_qk_norm = args.use_qk_norm
+           if @use_qk_norm
+             self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+             self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           end
+
+           rope_dims = [(@head_dim * args.partial_rotary_factor.to_f).to_i, 1].max
+           self.rope = MLX::NN::RoPE.new(
+             rope_dims,
+             traditional: false,
+             base: args.rope_theta
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim])
+
+           if @use_qk_norm
+             queries = q_norm.call(queries)
+             keys = k_norm.call(keys)
+           end
+
+           queries = queries.transpose([0, 2, 1, 3])
+           keys = keys.transpose([0, 2, 1, 3])
+           values = values.transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(config, hidden_size: nil, intermediate_size: nil)
+           super()
+           hidden_size ||= config.hidden_size
+           intermediate_size ||= config.intermediate_size
+
+           self.gate_proj = MLX::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+           self.up_proj = MLX::NN::Linear.new(hidden_size, intermediate_size, bias: false)
+           self.down_proj = MLX::NN::Linear.new(intermediate_size, hidden_size, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       module_function
+
+       def group_expert_select(
+         gates,
+         e_score_correction_bias,
+         top_k,
+         n_group,
+         topk_group,
+         routed_scaling_factor,
+         norm_topk_prob
+       )
+         mx = MLX::Core
+
+         scores = mx.sigmoid(gates.astype(mx.float32))
+         orig_scores = scores
+         scores = scores + e_score_correction_bias
+
+         if n_group.to_i > 1
+           experts_per_group = scores.shape[-1] / n_group
+           scores = mx.unflatten(scores, -1, [n_group, experts_per_group])
+           group_scores = mx.topk(scores, 2, -1)
+           group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1)
+
+           drop_count = n_group - topk_group.to_i
+           if drop_count > 0
+             group_idx = mx.argpartition(group_scores, drop_count - 1, -2)
+             take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32)
+             group_idx = mx.take(group_idx, take_ids, -2)
+             scores = mx.put_along_axis(
+               scores,
+               mx.stop_gradient(group_idx),
+               mx.array(0.0),
+               -2
+             )
+           end
+
+           scores = mx.flatten(scores, -2, -1)
+         end
+
+         k = [[top_k.to_i, 1].max, scores.shape[-1]].min
+         inds = mx.argpartition(scores * -1.0, k - 1, -1)
+         take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+         inds = mx.take(inds, take_ids, -1)
+
+         selected_scores = mx.take_along_axis(orig_scores, inds, -1)
+         if k > 1 && norm_topk_prob
+           denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1)
+           selected_scores = selected_scores / (denominator + 1e-20)
+         end
+
+         selected_scores = selected_scores * routed_scaling_factor.to_f
+         [inds, selected_scores]
+       end
+
+       class MoEGate < MLX::NN::Module
+         def initialize(config)
+           super()
+           @top_k = config.num_experts_per_tok
+           @norm_topk_prob = config.norm_topk_prob
+           @n_routed_experts = config.n_routed_experts
+           @routed_scaling_factor = config.routed_scaling_factor
+           @n_group = config.n_group
+           @topk_group = config.topk_group
+
+           raise ArgumentError, "Unsupported topk method: #{config.topk_method}" unless config.topk_method == "noaux_tc"
+
+           mx = MLX::Core
+           self.weight = mx.zeros([@n_routed_experts, config.hidden_size])
+           self.e_score_correction_bias = mx.zeros([@n_routed_experts])
+         end
+
+         def call(x)
+           mx = MLX::Core
+           gates = mx.matmul(x, mx.transpose(weight))
+           Glm4Moe.group_expert_select(
+             gates,
+             e_score_correction_bias,
+             @top_k,
+             @n_group,
+             @topk_group,
+             @routed_scaling_factor,
+             @norm_topk_prob
+           )
+         end
+       end
+
+       class MoE < MLX::NN::Module
+         def initialize(config)
+           super()
+           @config = config
+
+           self.switch_mlp = SwitchLayers::SwitchGLU.new(
+             config.hidden_size,
+             config.moe_intermediate_size,
+             config.n_routed_experts
+           )
+
+           self.gate = MoEGate.new(config)
+           unless config.n_shared_experts.nil?
+             shared_intermediate = config.moe_intermediate_size * config.n_shared_experts
+             self.shared_experts = MLP.new(config, intermediate_size: shared_intermediate)
+           end
+         end
+
+         def call(x)
+           mx = MLX::Core
+           inds, scores = gate.call(x)
+           y = switch_mlp.call(x, inds)
+           y = mx.sum(y * mx.expand_dims(scores, -1), -2).astype(y.dtype)
+           y = y + shared_experts.call(x) unless @config.n_shared_experts.nil?
+           y
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         def initialize(config, layer_idx)
+           super()
+           self.self_attn = Attention.new(config)
+           self.mlp = if !config.n_routed_experts.nil? && layer_idx >= config.first_k_dense_replace
+             MoE.new(config)
+           else
+             MLP.new(config)
+           end
+
+           self.input_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class LanguageModel < MLX::NN::Module
+         include PipelineMixin
+
+         def initialize(config)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size)
+           self.layers = Array.new(config.num_hidden_layers) { |idx| DecoderLayer.new(config, idx) }
+           self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+         end
+
+         def call(x, cache: nil)
+           h = embed_tokens.call(x)
+           active_layers = pipeline_layers
+           layer_cache = cache || [nil] * active_layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           active_layers.each_with_index do |layer, idx|
+             h = layer.call(h, mask: mask, cache: layer_cache[idx])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil)
+           n = h.shape[1]
+           return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
+           return nil if n == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(config)
+           super()
+           @args = config
+           self.model_type = config.model_type
+           self.model = LanguageModel.new(config)
+           self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           lm_head.call(out)
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           result = weights.dup
+           mtp_layer = @args.num_hidden_layers.to_i # first layer index past the main stack: the multi-token-prediction (MTP) layer
+
+           @args.num_hidden_layers.to_i.times do |layer_idx|
+             prefix = "model.layers.#{layer_idx}.mlp"
+             %w[gate_proj down_proj up_proj].each do |proj_name|
+               %w[weight scales biases].each do |param_name|
+                 first_key = "#{prefix}.experts.0.#{proj_name}.#{param_name}"
+                 next unless result.key?(first_key)
+
+                 expert_keys = (0...@args.n_routed_experts.to_i).map do |expert_idx|
+                   "#{prefix}.experts.#{expert_idx}.#{proj_name}.#{param_name}"
+                 end
+                 next unless expert_keys.all? { |key| result.key?(key) }
+
+                 # Stack the per-expert matrices into one tensor for SwitchGLU.
+                 stacked = expert_keys.map { |key| result.delete(key) }
+                 result["#{prefix}.switch_mlp.#{proj_name}.#{param_name}"] = mx.stack(stacked)
+               end
+             end
+           end
+
+           # Drop the MTP layer's weights; this port does not run the MTP head.
+           result.reject { |key, _| key.start_with?("model.layers.#{mtp_layer}") }
+         end
+
+         def layers
+           model.pipeline_layers
+         end
+
+         def cast_predicate
+           lambda { |key| !key.include?("e_score_correction_bias") }
+         end
+       end
+
+       Models.register("glm4_moe", Model, ModelArgs)
+     end
+   end
+ end
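
Note on the routing above: group_expert_select implements the "noaux_tc" scheme used by GLM-4.5-style MoE models. The correction bias only influences which experts win, the unbiased sigmoid scores become the mixing weights, and whole expert groups are discarded before the final top-k. The plain-Ruby sketch below mirrors that logic for a single token; toy_group_expert_select and the sample scores are illustrative stand-ins, not part of the gem.

  # Toy version of the "noaux_tc" routing on plain arrays (illustrative only).
  def toy_group_expert_select(scores, bias, top_k:, n_group:, topk_group:, scale:, normalize:)
    biased = scores.zip(bias).map { |s, b| s + b }          # bias steers selection only
    per_group = biased.each_slice(scores.size / n_group).to_a
    group_rank = per_group.map { |g| g.max(2).sum }         # sum of each group's two best scores
    keep = group_rank.each_index.max_by(topk_group) { |i| group_rank[i] }
    masked = per_group.each_with_index.flat_map do |g, i|   # zero experts outside kept groups
      keep.include?(i) ? g : [0.0] * g.size
    end
    inds = masked.each_index.max_by(top_k) { |i| masked[i] }
    weights = inds.map { |i| scores[i] }                    # unbiased scores are what get mixed
    total = weights.sum
    weights = weights.map { |w| w / (total + 1e-20) } if normalize && top_k > 1
    [inds, weights.map { |w| w * scale }]
  end

  scores = [0.9, 0.1, 0.4, 0.8, 0.2, 0.7, 0.3, 0.6]         # 8 experts, 2 groups of 4
  bias = [0.0] * 8
  p toy_group_expert_select(scores, bias, top_k: 2, n_group: 2,
                            topk_group: 1, scale: 1.8, normalize: true)
  # => [[0, 3], [0.9529..., 0.8470...]] -- both picks come from the winning group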
@@ -0,0 +1,131 @@
+ require_relative "glm4_moe"
+
+ module MlxLm
+   module Models
+     module Glm4MoeLite
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "glm4_moe_lite"
+         field :vocab_size, default: 154_880
+         field :hidden_size, default: 2048
+         field :intermediate_size, default: 10_240
+         field :moe_intermediate_size, default: 1536
+         field :num_hidden_layers, default: 47
+         field :num_attention_heads, default: 20
+         field :num_key_value_heads, default: 20
+         field :n_shared_experts, default: 1
+         field :n_routed_experts, default: 64
+         field :routed_scaling_factor, default: 1.8
+         field :kv_lora_rank, default: 512
+         field :q_lora_rank, default: 768
+         field :qk_rope_head_dim, default: 64
+         field :qk_nope_head_dim, default: 192
+         field :v_head_dim, default: 256
+         field :topk_method, default: "noaux_tc"
+         field :scoring_func, default: "sigmoid"
+         field :norm_topk_prob, default: true
+         field :n_group, default: 1
+         field :topk_group, default: 1
+         field :num_experts_per_tok, default: 4
+         field :moe_layer_freq, default: 1
+         field :first_k_dense_replace, default: 1
+         field :max_position_embeddings, default: 202_752
+         field :rms_norm_eps, default: 1e-5
+         field :rope_theta, default: 1_000_000.0
+         field :rope_scaling, default: nil
+         field :attention_bias, default: false
+         field :attention_dropout, default: 0.0
+         field :partial_rotary_factor, default: 1.0
+         field :tie_word_embeddings, default: false
+         field :num_nextn_predict_layers, default: 1
+         field :quantization, default: nil
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.language_model = Glm4Moe::Model.new(
+             Glm4Moe::ModelArgs.from_dict(_to_glm4_moe_config(args))
+           )
+         end
+
+         def call(inputs, cache: nil)
+           language_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           remapped = {}
+           weights.each do |key, value|
+             remapped[_remap_weight_key(key)] = value
+           end
+           language_model.sanitize(remapped)
+         end
+
+         def layers
+           language_model.layers
+         end
+
+         def make_cache
+           return language_model.make_cache if language_model.respond_to?(:make_cache)
+
+           nil
+         end
+
+         private
+
+         def _to_glm4_moe_config(args)
+           inferred_head_dim = args.qk_nope_head_dim.to_i + args.qk_rope_head_dim.to_i
+           inferred_head_dim = args.hidden_size / args.num_attention_heads if inferred_head_dim <= 0
+
+           {
+             "model_type" => args.model_type,
+             "vocab_size" => args.vocab_size,
+             "hidden_size" => args.hidden_size,
+             "intermediate_size" => args.intermediate_size,
+             "max_position_embeddings" => args.max_position_embeddings,
+             "moe_intermediate_size" => args.moe_intermediate_size,
+             "norm_topk_prob" => args.norm_topk_prob,
+             "num_attention_heads" => args.num_attention_heads,
+             "n_group" => args.n_group,
+             "head_dim" => inferred_head_dim,
+             "topk_group" => args.topk_group,
+             "n_shared_experts" => args.n_shared_experts,
+             "n_routed_experts" => args.n_routed_experts,
+             "routed_scaling_factor" => args.routed_scaling_factor,
+             "num_experts_per_tok" => args.num_experts_per_tok,
+             "first_k_dense_replace" => args.first_k_dense_replace,
+             "num_hidden_layers" => args.num_hidden_layers,
+             "num_key_value_heads" => args.num_key_value_heads,
+             "rms_norm_eps" => args.rms_norm_eps,
+             "rope_theta" => args.rope_theta,
+             "rope_scaling" => args.rope_scaling,
+             "use_qk_norm" => false,
+             "tie_word_embeddings" => args.tie_word_embeddings,
+             "attention_bias" => args.attention_bias,
+             "partial_rotary_factor" => args.partial_rotary_factor,
+             "scoring_func" => args.scoring_func,
+             "topk_method" => args.topk_method,
+           }
+         end
+
+         def _remap_weight_key(key)
+           mapped = key.dup
+           mapped = mapped.gsub(".self_attn.embed_q.", ".self_attn.q_proj.")
+           mapped = mapped.gsub(".self_attn.unembed_out.", ".self_attn.v_proj.")
+           mapped = mapped.gsub(".self_attn.kv_a_proj_with_mqa.", ".self_attn.k_proj.")
+           mapped = mapped.gsub(".self_attn.q_a_proj.", ".self_attn.q_proj.")
+           mapped = mapped.gsub(".self_attn.q_b_proj.", ".self_attn.q_proj.")
+           mapped
+         end
+       end
+
+       Models.register("glm4_moe_lite", Model, ModelArgs)
+     end
+   end
+ end
1
+ require_relative "deepseek_v32"
2
+
3
+ module MlxLm
4
+ module Models
5
+ module GlmMoeDsa
6
+ class ModelArgs < DeepseekV32::ModelArgs
7
+ field :model_type, default: "glm_moe_dsa"
8
+ field :rope_parameters, default: nil
9
+
10
+ def initialize(**kwargs)
11
+ super
12
+ return unless @rope_parameters.respond_to?(:[])
13
+
14
+ @rope_scaling = @rope_parameters
15
+ rope_theta = @rope_parameters["rope_theta"] || @rope_parameters[:rope_theta]
16
+ @rope_theta = rope_theta unless rope_theta.nil?
17
+ end
18
+ end
19
+
20
+ class Model < DeepseekV32::Model
21
+ end
22
+
23
+ Models.register("glm_moe_dsa", Model, ModelArgs)
24
+ end
25
+ end
26
+ end
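
This shim only normalizes the config shape: a rope_parameters hash, when present, is folded back into the rope_scaling and rope_theta fields that DeepseekV32::ModelArgs already reads. A minimal sketch of that normalization, assuming a hypothetical config hash:

  # Hypothetical rope_parameters block as it might appear in a config.json.
  rope_parameters = { "rope_theta" => 10_000.0, "factor" => 4.0 }

  # Same folding as GlmMoeDsa::ModelArgs#initialize: the whole hash becomes
  # rope_scaling, and rope_theta is lifted out when present (string or symbol key).
  rope_scaling = rope_parameters
  rope_theta = rope_parameters["rope_theta"] || rope_parameters[:rope_theta]

  p rope_scaling # => {"rope_theta"=>10000.0, "factor"=>4.0}
  p rope_theta   # => 10000.0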
@@ -0,0 +1,166 @@
+ module MlxLm
+   module Models
+     module GPT2
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "gpt2"
+         field :n_ctx
+         field :n_embd
+         field :n_head
+         field :n_layer
+         field :n_positions
+         field :layer_norm_epsilon
+         field :vocab_size
+         field :num_key_value_heads, default: nil
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @n_head
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           unless (args.n_embd % args.n_head).zero?
+             raise ArgumentError, "n_embd must be divisible by n_head"
+           end
+
+           @n_embd = args.n_embd
+           @n_head = args.n_head
+           @head_dim = @n_embd / @n_head
+           @scale = @head_dim**(-0.5)
+
+           self.c_attn = MLX::NN::Linear.new(@n_embd, 3 * @n_embd, bias: true)
+           self.c_proj = MLX::NN::Linear.new(@n_embd, @n_embd, bias: true)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           qkv = c_attn.call(x)
+           queries, keys, values = mx.split(qkv, 3, 2)
+
+           queries = queries.reshape([b, l, @n_head, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_head, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_head, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             keys, values = cache.update_and_fetch(keys, values)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_embd])
+           c_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.c_fc = MLX::NN::Linear.new(args.n_embd, 4 * args.n_embd, bias: true)
+           self.c_proj = MLX::NN::Linear.new(4 * args.n_embd, args.n_embd, bias: true)
+         end
+
+         def call(x)
+           c_proj.call(MLX::NN.gelu_approx(c_fc.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.ln_1 = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon)
+           self.ln_2 = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = attn.call(ln_1.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(ln_2.call(h))
+           h + r
+         end
+       end
+
+       class GPT2Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.wte = MLX::NN::Embedding.new(args.vocab_size, args.n_embd)
+           self.wpe = MLX::NN::Embedding.new(args.n_positions, args.n_embd)
+           self.h = Array.new(args.n_layer) { TransformerBlock.new(args) }
+           self.ln_f = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon)
+         end
+
+         def call(inputs, cache: nil)
+           mx = MLX::Core
+           _b, l = inputs.shape
+
+           hidden_states = wte.call(inputs)
+           layer_cache = cache || [nil] * h.length
+           offset = layer_cache[0] ? layer_cache[0].offset : 0
+           position_ids = mx.add(mx.arange(0, l, 1, mx.int32), offset)
+           hidden_states = hidden_states + wpe.call(position_ids)
+
+           mask = _create_attention_mask(hidden_states, layer_cache[0])
+           h.each_with_index do |layer, i|
+             hidden_states = layer.call(hidden_states, mask: mask, cache: layer_cache[i])
+           end
+           ln_f.call(hidden_states)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if h.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         attr_reader :args
+
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = GPT2Model.new(args)
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           model.wte.as_linear(out)
+         end
+
+         def sanitize(weights)
+           result = {}
+           weights.each do |k, v|
+             next if k.match?(/\Ah\.\d+\.attn\.bias\z/)
+
+             value = if k.match?(/\Ah\.\d+\.(attn\.c_attn|attn\.c_proj|mlp\.c_fc|mlp\.c_proj)\.weight\z/)
+               v.transpose([1, 0])
+             else
+               v
+             end
+
+             if k.start_with?("model.")
+               result[k] = value
+             else
+               result["model.#{k}"] = value
+             end
+           end
+           result
+         end
+
+         def layers
+           model.h
+         end
+       end
+
+       Models.register("gpt2", Model, ModelArgs)
+     end
+   end
+ end
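
A note on GPT2's sanitize: Hugging Face GPT-2 checkpoints store c_attn, c_proj, and c_fc as Conv1D weights, whose matrices are transposed relative to MLX::NN::Linear, and each block carries an attn.bias causal-mask buffer that is not a real parameter. sanitize therefore drops the mask buffers, transposes exactly those weight matrices, and prefixes every remaining key with "model.". A plain-Ruby check of the key patterns; the sample keys are hypothetical:

  # The same two regexes used by GPT2::Model#sanitize above.
  TRANSPOSE = /\Ah\.\d+\.(attn\.c_attn|attn\.c_proj|mlp\.c_fc|mlp\.c_proj)\.weight\z/
  DROP      = /\Ah\.\d+\.attn\.bias\z/ # the causal-mask buffer, not a parameter

  %w[
    h.0.attn.c_attn.weight
    h.0.attn.c_attn.bias
    h.0.attn.bias
    wte.weight
  ].each do |key|
    action = if key.match?(DROP) then "drop"
             elsif key.match?(TRANSPOSE) then "transpose, rename to model.#{key}"
             else "rename to model.#{key}"
             end
    puts "#{key} -> #{action}"
  end
  # h.0.attn.c_attn.weight -> transpose, rename to model.h.0.attn.c_attn.weight
  # h.0.attn.c_attn.bias   -> rename to model.h.0.attn.c_attn.bias
  # h.0.attn.bias          -> drop
  # wte.weight             -> rename to model.wte.weight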