mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,239 @@
1
+ module MlxLm
2
+ module Models
3
+ module DeepSeek
4
class ModelArgs < BaseModelArgs
  # Configuration for DeepSeek (v1) models, mirroring the Hugging Face
  # config field names so checkpoints load without renaming.
  field :model_type, default: "deepseek"
  field :hidden_size, default: 4096
  field :num_hidden_layers, default: 30
  field :num_attention_heads, default: 32
  field :num_key_value_heads, default: 32
  field :intermediate_size, default: 11008
  # Hidden size of each routed expert MLP (used only by MoE checkpoints).
  field :moe_intermediate_size, default: 1407
  field :vocab_size, default: 102400
  field :rms_norm_eps, default: 1e-6
  field :rope_theta, default: 10000.0
  field :rope_scaling, default: nil
  field :attention_bias, default: false
  # MoE routing parameters; nil means the checkpoint is dense.
  field :n_shared_experts, default: nil
  field :n_routed_experts, default: nil
  field :num_experts_per_tok, default: nil
  # A layer is MoE when its index is >= first_k_dense_replace and is a
  # multiple of moe_layer_freq (see DecoderLayer#initialize).
  field :moe_layer_freq, default: 1
  field :first_k_dense_replace, default: 0
  field :max_position_embeddings, default: 2048

  def initialize(**kwargs)
    super
    # Fall back to full multi-head attention when the config omits the
    # KV-head count (i.e. no grouped-query attention).
    @num_key_value_heads ||= @num_attention_heads
  end
end
29
+
30
class Attention < MLX::NN::Module
  # Multi-head self-attention with rotary position embeddings.
  # Supports grouped key/value heads and an optional KV cache.
  def initialize(args)
    super()
    hidden = args.hidden_size
    @n_heads = args.num_attention_heads
    @n_kv_heads = args.num_key_value_heads
    @head_dim = hidden / @n_heads
    @scale = @head_dim**(-0.5)

    bias = args.attention_bias
    self.q_proj = MLX::NN::Linear.new(hidden, @n_heads * @head_dim, bias: bias)
    self.k_proj = MLX::NN::Linear.new(hidden, @n_kv_heads * @head_dim, bias: bias)
    self.v_proj = MLX::NN::Linear.new(hidden, @n_kv_heads * @head_dim, bias: bias)
    self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, hidden, bias: bias)

    # Linear rope_scaling divides positions by the configured factor;
    # any other (or absent) scaling type leaves positions untouched.
    scaling = args.rope_scaling
    rope_scale =
      if scaling && scaling["type"] == "linear"
        1.0 / scaling["factor"]
      else
        1.0
      end

    self.rope = MLX::NN::RoPE.new(
      @head_dim,
      traditional: false,
      base: args.rope_theta,
      scale: rope_scale
    )
  end

  # x: [batch, seq, hidden]. Returns the attended hidden states with the
  # same shape. When a cache is given, positions are offset by its length
  # and new keys/values are appended to it.
  def call(x, mask: nil, cache: nil)
    mx = MLX::Core
    batch, seq_len, = x.shape

    queries = q_proj.call(x)
    keys = k_proj.call(x)
    values = v_proj.call(x)

    # Split heads: [B, S, H*D] -> [B, H, S, D].
    queries = queries.reshape([batch, seq_len, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
    keys = keys.reshape([batch, seq_len, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
    values = values.reshape([batch, seq_len, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])

    if cache
      offset = cache.offset
      queries = rope.call(queries, offset: offset)
      keys = rope.call(keys, offset: offset)
      keys, values = cache.update_and_fetch(keys, values)
    else
      queries = rope.call(queries)
      keys = rope.call(keys)
    end

    attended = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
    # Merge heads back: [B, H, S, D] -> [B, S, H*D].
    attended = attended.transpose([0, 2, 1, 3]).reshape([batch, seq_len, @n_heads * @head_dim])
    o_proj.call(attended)
  end
end
80
+
81
class DeepseekMLP < MLX::NN::Module
  # SwiGLU feed-forward block: down(silu(gate(x)) * up(x)).
  def initialize(dim, hidden_dim)
    super()
    self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
    self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
    self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
  end

  def call(x)
    gated = MLX::NN.silu(gate_proj.call(x))
    down_proj.call(gated * up_proj.call(x))
  end
end
93
+
94
class MoEGate < MLX::NN::Module
  # Softmax top-k router: scores every routed expert per token and selects
  # the num_experts_per_tok best.
  def initialize(args)
    super()
    @top_k = args.num_experts_per_tok
    # Router weight matrix, shape [n_routed_experts, hidden_size].
    # Zeros here are placeholders; real values come from the checkpoint.
    self.weight = MLX::Core.zeros([args.n_routed_experts, args.hidden_size])
  end

  # Returns [inds, scores]: the top-k expert indices per token and their
  # softmax scores (both with a trailing axis of size k).
  def call(x)
    mx = MLX::Core
    gates = mx.matmul(x, mx.transpose(weight))
    # Softmax in float32 for numerical stability, then cast back.
    scores = mx.softmax(gates.astype(mx.float32), -1).astype(gates.dtype)
    k = @top_k
    # argpartition on negated scores places the k largest experts in the
    # first k slots; stop_gradient keeps routing out of the backward pass.
    inds = mx.stop_gradient(mx.argpartition(scores * -1.0, k - 1, -1))
    take_ids = mx.array((0...k).to_a, dtype: mx.int32)
    inds = mx.take(inds, take_ids, -1)
    scores = mx.take_along_axis(scores, inds, -1)
    [inds, scores]
  end
end
113
+
114
class DeepseekMoE < MLX::NN::Module
  # Mixture-of-experts MLP: a gated set of routed experts (fused into one
  # SwitchGLU) plus optional always-on shared experts.
  def initialize(args)
    super()
    @n_shared_experts = args.n_shared_experts
    dim = args.hidden_size
    moe_dim = args.moe_intermediate_size

    self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, moe_dim, args.n_routed_experts)
    self.gate = MoEGate.new(args)

    shared = args.n_shared_experts
    if shared && shared > 0
      # Shared experts are one wide SwiGLU MLP sized moe_dim * n_shared.
      self.shared_experts = DeepseekMLP.new(dim, moe_dim * shared)
    end
  end

  def call(x)
    mx = MLX::Core
    inds, scores = gate.call(x)
    routed = switch_mlp.call(x, inds)
    # Weighted sum over the k selected experts.
    out = mx.sum(routed * mx.expand_dims(scores, -1), -2)
    out = out + shared_experts.call(x) if @n_shared_experts && @n_shared_experts > 0
    out
  end
end
143
+
144
class DecoderLayer < MLX::NN::Module
  # One transformer block: pre-norm attention and a pre-norm MLP, each with
  # a residual connection. The MLP is MoE or dense depending on layer_idx.
  def initialize(args, layer_idx)
    super()
    self.self_attn = Attention.new(args)
    self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
    self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)

    self.mlp =
      if moe_layer?(args, layer_idx)
        DeepseekMoE.new(args)
      else
        DeepseekMLP.new(args.hidden_size, args.intermediate_size)
      end
  end

  def call(x, mask: nil, cache: nil)
    h = x + self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
    h + mlp.call(post_attention_layernorm.call(h))
  end

  private

  # MoE layers require routed experts and start at first_k_dense_replace,
  # recurring every moe_layer_freq layers.
  def moe_layer?(args, layer_idx)
    args.n_routed_experts &&
      layer_idx >= args.first_k_dense_replace &&
      (layer_idx % args.moe_layer_freq).zero?
  end
end
170
+
171
class DeepseekModel < MLX::NN::Module
  # Decoder stack: token embedding, num_hidden_layers decoder layers, and a
  # final RMSNorm.
  def initialize(args)
    super()
    @args = args
    self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
    self.layers = (0...args.num_hidden_layers).map { |i| DecoderLayer.new(args, i) }
    self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
  end

  # inputs: token ids. cache: optional per-layer KV caches (nil entries OK).
  def call(inputs, cache: nil)
    h = embed_tokens.call(inputs)
    # Causal masking is only needed when processing more than one token;
    # single-token decode steps attend to the whole cache.
    mask = h.shape[1] > 1 ? "causal" : nil

    caches = cache || Array.new(layers.length)
    layers.each_with_index do |layer, idx|
      h = layer.call(h, mask: mask, cache: caches[idx])
    end

    norm.call(h)
  end
end
194
+
195
class Model < MLX::NN::Module
  # DeepSeek (v1) causal language model: decoder stack plus an untied LM head.
  def initialize(args)
    super()
    @args = args
    self.model = DeepseekModel.new(args)
    self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
  end

  # Returns vocabulary logits for the given token ids.
  def call(inputs, cache: nil)
    out = model.call(inputs, cache: cache)
    lm_head.call(out)
  end

  # Adapts a Hugging-Face-layout checkpoint to this module tree:
  # - drops precomputed rotary frequencies (RoPE recomputes them)
  # - stacks per-expert MLP tensors (mlp.experts.<e>.<proj>) along a new
  #   leading axis into the fused mlp.switch_mlp.<proj> layout
  def sanitize(weights)
    mx = MLX::Core
    result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }

    # Convert per-expert weights to stacked SwitchGLU format
    @args.num_hidden_layers.times do |l|
      prefix = "model.layers.#{l}"
      ["gate_proj", "down_proj", "up_proj"].each do |m|
        # "scales"/"biases" exist only in quantized checkpoints; the key0
        # probe skips dense layers and missing params harmlessly.
        ["weight", "scales", "biases"].each do |k|
          key0 = "#{prefix}.mlp.experts.0.#{m}.#{k}"
          if result.key?(key0)
            to_join = (0...@args.n_routed_experts).map { |e|
              result.delete("#{prefix}.mlp.experts.#{e}.#{m}.#{k}")
            }
            result["#{prefix}.mlp.switch_mlp.#{m}.#{k}"] = mx.stack(to_join)
          end
        end
      end
    end

    result
  end

  # Exposed so cache construction and weight utilities can reach the
  # decoder layers directly.
  def layers
    model.layers
  end
end
235
+
236
+ Models.register("deepseek", Model, ModelArgs)
237
+ end
238
+ end
239
+ end
@@ -0,0 +1,108 @@
1
+ require_relative "deepseek"
2
+
3
+ module MlxLm
4
+ module Models
5
+ module DeepseekV2
6
class ModelArgs < BaseModelArgs
  # Configuration for DeepSeek-V2 models, mirroring the HF config names.
  # Note: several MLA fields (kv_lora_rank, q_lora_rank, qk_*_head_dim,
  # v_head_dim) and routing fields (topk_method, n_group, topk_group,
  # routed_scaling_factor) are accepted for config compatibility; the
  # visible Model below maps only a subset onto the DeepSeek (v1) runtime.
  field :model_type, default: "deepseek_v2"
  field :vocab_size, default: 102_400
  field :hidden_size, default: 4096
  field :intermediate_size, default: 11_008
  field :moe_intermediate_size, default: 1407
  field :num_hidden_layers, default: 30
  field :num_attention_heads, default: 32
  field :num_key_value_heads, default: 32
  field :n_shared_experts, default: nil
  field :n_routed_experts, default: nil
  field :routed_scaling_factor, default: 1.0
  field :kv_lora_rank, default: 512
  field :q_lora_rank, default: 1536
  field :qk_rope_head_dim, default: 64
  field :v_head_dim, default: 128
  field :qk_nope_head_dim, default: 128
  # NOTE(review): "gready" reproduces the misspelling shipped in upstream
  # DeepSeek-V2 configs — do not "correct" it without checking released
  # checkpoint config.json files.
  field :topk_method, default: "gready"
  field :n_group, default: nil
  field :topk_group, default: nil
  field :num_experts_per_tok, default: nil
  field :moe_layer_freq, default: 1
  field :first_k_dense_replace, default: 0
  field :max_position_embeddings, default: 2048
  field :rms_norm_eps, default: 1e-6
  field :rope_theta, default: 10_000.0
  field :rope_scaling, default: nil
  field :attention_bias, default: false

  def initialize(**kwargs)
    super
    # Fall back to full multi-head attention when KV heads are omitted.
    @num_key_value_heads ||= @num_attention_heads
  end
end
+ end
40
+
41
class Model < DeepSeek::Model
  # DeepSeek-V2 runtime built on the DeepSeek (v1) implementation: the V2
  # config is down-converted to a v1 config, so MLA-specific fields are
  # not used by the compute path visible here.
  def initialize(args)
    super(DeepSeek::ModelArgs.from_dict(_to_deepseek_config(args)))
    self.model_type = args.model_type
  end

  # Stacks per-expert weights into the fused SwitchGLU layout. Unlike the
  # v1 sanitize, this does not drop rotary_emb.inv_freq keys —
  # NOTE(review): confirm V2 checkpoints never carry them.
  def sanitize(weights)
    _stack_expert_weights(weights.dup)
  end

  private

  # Projects the V2 config onto the field set understood by
  # DeepSeek::ModelArgs (string keys as expected by from_dict).
  def _to_deepseek_config(args)
    {
      "model_type" => args.model_type,
      "vocab_size" => args.vocab_size,
      "hidden_size" => args.hidden_size,
      "intermediate_size" => args.intermediate_size,
      "moe_intermediate_size" => args.moe_intermediate_size,
      "num_hidden_layers" => args.num_hidden_layers,
      "num_attention_heads" => args.num_attention_heads,
      "num_key_value_heads" => args.num_key_value_heads,
      "n_shared_experts" => args.n_shared_experts,
      "n_routed_experts" => args.n_routed_experts,
      "num_experts_per_tok" => args.num_experts_per_tok,
      "moe_layer_freq" => args.moe_layer_freq,
      "first_k_dense_replace" => args.first_k_dense_replace,
      "max_position_embeddings" => args.max_position_embeddings,
      "rms_norm_eps" => args.rms_norm_eps,
      "rope_theta" => args.rope_theta,
      "rope_scaling" => args.rope_scaling,
      "attention_bias" => args.attention_bias,
    }
  end

  # Moves model.layers.<l>.mlp.experts.<e>.<proj>.<param> tensors into a
  # single stacked model.layers.<l>.mlp.switch_mlp.<proj>.<param> tensor
  # per layer/projection/param. Mutates and returns `weights`.
  def _stack_expert_weights(weights)
    num_experts = @args.n_routed_experts.to_i
    return weights if num_experts <= 0

    mx = MLX::Core
    projections = %w[gate_proj down_proj up_proj].freeze
    # "scales"/"biases" appear only in quantized checkpoints.
    params = %w[weight scales biases].freeze

    @args.num_hidden_layers.times do |layer_idx|
      prefix = "model.layers.#{layer_idx}.mlp"
      projections.each do |projection|
        params.each do |param|
          expert_keys = (0...num_experts).map do |expert_idx|
            "#{prefix}.experts.#{expert_idx}.#{projection}.#{param}"
          end
          # Only stack when every expert supplies the tensor (skips dense
          # layers and absent quantization params).
          next unless expert_keys.all? { |key| weights.key?(key) }

          stacked = expert_keys.map { |key| weights.delete(key) }
          weights["#{prefix}.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked)
        end
      end
    end

    weights
  end
end
102
+
103
+ Models.register("deepseek_v2", Model, ModelArgs)
104
+ end
105
+
106
+ DeepSeekV2 = DeepseekV2 unless const_defined?(:DeepSeekV2)
107
+ end
108
+ end
@@ -0,0 +1,34 @@
1
+ require_relative "deepseek_v2"
2
+
3
+ module MlxLm
4
+ module Models
5
+ module DeepseekV3
6
class ModelArgs < DeepseekV2::ModelArgs
  # DeepSeek-V3 configuration: V2 fields with V3 routing defaults
  # (auxiliary-loss-free top-k routing over sigmoid scores).
  field :model_type, default: "deepseek_v3"
  field :topk_method, default: "noaux_tc"
  field :scoring_func, default: "sigmoid"
  field :norm_topk_prob, default: true
  field :n_group, default: 1
  field :topk_group, default: 1
  field :num_experts_per_tok, default: 1
end
15
+
16
class Model < DeepseekV2::Model
  # DeepSeek-V3 causal LM. Reuses the V2 runtime; only checkpoint
  # sanitization and quantization-cast policy differ.

  # Drops weights the runtime does not use:
  # - the multi-token-prediction (MTP) block, stored as extra decoder
  #   layers at index >= num_hidden_layers (index 61 in released V3
  #   checkpoints) — derived from the config instead of hard-coding
  #   "model.layers.61", matching DeepseekV32#drop_mtp_layer_weights and
  #   avoiding accidental prefix matches
  # - precomputed rotary-embedding frequencies
  def sanitize(weights)
    cutoff = @args.num_hidden_layers.to_i
    layer_key = /\Amodel\.layers\.(\d+)(?:\.|\z)/

    super(weights).reject do |key, _|
      key_name = key.to_s
      m = key_name.match(layer_key)
      (m && m[1].to_i >= cutoff) || key_name.include?("rotary_emb.inv_freq")
    end
  end

  # Keep router bias-correction terms out of dtype casting.
  def cast_predicate
    ->(key) { !key.to_s.include?("e_score_correction_bias") }
  end
end
28
+
29
+ Models.register("deepseek_v3", Model, ModelArgs)
30
+ end
31
+
32
+ DeepSeekV3 = DeepseekV3 unless const_defined?(:DeepSeekV3)
33
+ end
34
+ end
@@ -0,0 +1,45 @@
1
+ require_relative "deepseek"
2
+
3
+ module MlxLm
4
+ module Models
5
+ module DeepseekV32
6
class ModelArgs < DeepSeek::ModelArgs
  # DeepSeek-V3.2 configuration: v1 fields plus sparse-attention indexer
  # (index_*) and MLA / V3-style routing fields. NOTE(review): the Model
  # below only filters checkpoints; these extra fields are accepted for
  # config compatibility — confirm which ones the runtime consumes.
  field :model_type, default: "deepseek_v32"
  field :index_head_dim, default: 128
  field :index_n_heads, default: 64
  field :index_topk, default: 2048
  field :routed_scaling_factor, default: 1.0
  field :kv_lora_rank, default: 512
  field :q_lora_rank, default: 1536
  field :qk_rope_head_dim, default: 64
  field :v_head_dim, default: 128
  field :qk_nope_head_dim, default: 128
  field :topk_method, default: "noaux_tc"
  field :scoring_func, default: "sigmoid"
  field :norm_topk_prob, default: true
  field :n_group, default: 1
  field :topk_group, default: 1
end
23
+
24
class Model < DeepSeek::Model
  # DeepSeek-V3.2 causal LM: the base DeepSeek runtime, with checkpoint
  # sanitization extended to strip multi-token-prediction (MTP) layers.
  def sanitize(weights)
    drop_mtp_layer_weights(super(weights))
  end

  private

  # Removes every entry belonging to a decoder layer whose index is at or
  # beyond num_hidden_layers — those layers implement the MTP head, which
  # this runtime does not execute.
  def drop_mtp_layer_weights(weights)
    limit = @args.num_hidden_layers.to_i
    layer_key = /\Amodel\.layers\.(\d+)(?:\.|\z)/

    weights.reject do |key, _|
      m = layer_key.match(key)
      !m.nil? && m[1].to_i >= limit
    end
  end
end
41
+
42
+ Models.register("deepseek_v32", Model, ModelArgs)
43
+ end
44
+ end
45
+ end