mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,169 @@
1
+ module MlxLm
2
+ module Models
3
+ module MiniCPM
4
+ class ModelArgs < BaseModelArgs
5
+ field :model_type, default: "minicpm"
6
+ field :hidden_size
7
+ field :dim_model_base
8
+ field :num_hidden_layers
9
+ field :intermediate_size
10
+ field :num_attention_heads
11
+ field :rms_norm_eps
12
+ field :vocab_size
13
+ field :num_key_value_heads
14
+ field :scale_depth
15
+ field :scale_emb
16
+ field :max_position_embeddings, default: nil
17
+ field :rope_theta, default: 1_000_000.0
18
+ field :rope_traditional, default: false
19
+ field :rope_scaling, default: nil
20
+ field :tie_word_embeddings, default: false
21
+
22
+ def initialize(**kwargs)
23
+ super
24
+ @num_key_value_heads ||= @num_attention_heads
25
+ end
26
+ end
27
+
28
+ class MLP < MLX::NN::Module
29
+ def initialize(args)
30
+ super()
31
+ self.gate_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false)
32
+ self.up_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false)
33
+ self.down_proj = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: false)
34
+ end
35
+
36
+ def call(x)
37
+ down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
38
+ end
39
+ end
40
+
41
+ class Attention < MLX::NN::Module
42
+ def initialize(args)
43
+ super()
44
+
45
+ dim = args.hidden_size
46
+ @n_heads = args.num_attention_heads
47
+ @n_kv_heads = args.num_key_value_heads
48
+ @head_dim = dim / @n_heads
49
+ @scale = @head_dim**(-0.5)
50
+
51
+ self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
52
+ self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
53
+ self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
54
+ self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
55
+
56
+ self.rope = MlxLm::Models.initialize_rope(
57
+ @head_dim,
58
+ args.rope_theta,
59
+ args.rope_traditional,
60
+ args.rope_scaling,
61
+ max_position_embeddings: args.max_position_embeddings
62
+ )
63
+ end
64
+
65
+ def call(x, mask: nil, cache: nil)
66
+ mx = MLX::Core
67
+ b, l, _d = x.shape
68
+
69
+ queries = q_proj.call(x)
70
+ keys = k_proj.call(x)
71
+ values = v_proj.call(x)
72
+
73
+ queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
74
+ keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
75
+ values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
76
+
77
+ if cache
78
+ queries = rope.call(queries, offset: cache.offset)
79
+ keys = rope.call(keys, offset: cache.offset)
80
+ keys, values = cache.update_and_fetch(keys, values)
81
+ else
82
+ queries = rope.call(queries)
83
+ keys = rope.call(keys)
84
+ end
85
+
86
+ output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
87
+ output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
88
+ o_proj.call(output)
89
+ end
90
+ end
91
+
92
+ class DecoderLayer < MLX::NN::Module
93
+ def initialize(args)
94
+ super()
95
+ self.self_attn = Attention.new(args)
96
+ self.mlp = MLP.new(args)
97
+ self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
98
+ self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
99
+ @residual_scale = args.scale_depth / Math.sqrt(args.num_hidden_layers)
100
+ end
101
+
102
+ def call(x, mask: nil, cache: nil)
103
+ r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
104
+ h = x + r * @residual_scale
105
+ r = mlp.call(post_attention_layernorm.call(h))
106
+ h + r * @residual_scale
107
+ end
108
+ end
109
+
110
+ class MiniCPMModel < MLX::NN::Module
111
+ def initialize(args)
112
+ super()
113
+ @args = args
114
+ self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
115
+ self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) }
116
+ self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
117
+ end
118
+
119
+ def call(inputs, cache: nil)
120
+ h = embed_tokens.call(inputs) * @args.scale_emb
121
+ layer_cache = cache || [nil] * layers.length
122
+
123
+ mask = nil
124
+ mask = "causal" if h.shape[1] > 1
125
+
126
+ layers.each_with_index do |layer, i|
127
+ h = layer.call(h, mask: mask, cache: layer_cache[i])
128
+ end
129
+
130
+ norm.call(h)
131
+ end
132
+ end
133
+
134
+ class Model < MLX::NN::Module
135
+ def initialize(args)
136
+ super()
137
+ @args = args
138
+ self.model_type = args.model_type
139
+ self.model = MiniCPMModel.new(args)
140
+ self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings
141
+ end
142
+
143
+ def call(inputs, cache: nil)
144
+ mx = MLX::Core
145
+ out = model.call(inputs, cache: cache)
146
+
147
+ if @args.tie_word_embeddings
148
+ mx.matmul(out, model.embed_tokens.weight.T)
149
+ else
150
+ lm_head.call(out / (@args.hidden_size.to_f / @args.dim_model_base))
151
+ end
152
+ end
153
+
154
+ def sanitize(weights)
155
+ unless weights.key?("lm_head.weight")
156
+ weights["lm_head.weight"] = weights["model.embed_tokens.weight"]
157
+ end
158
+ weights
159
+ end
160
+
161
+ def layers
162
+ model.layers
163
+ end
164
+ end
165
+
166
+ Models.register("minicpm", Model, ModelArgs)
167
+ end
168
+ end
169
+ end
@@ -0,0 +1,237 @@
1
+ require_relative "activations"
2
+ require_relative "rope_utils"
3
+
4
+ module MlxLm
5
+ module Models
6
+ module MiniCPM3
7
+ class ModelArgs < BaseModelArgs
8
+ field :model_type, default: "minicpm3"
9
+ field :hidden_size
10
+ field :dim_model_base
11
+ field :num_hidden_layers
12
+ field :intermediate_size
13
+ field :num_attention_heads
14
+ field :rms_norm_eps
15
+ field :vocab_size
16
+ field :num_key_value_heads
17
+ field :q_lora_rank
18
+ field :qk_nope_head_dim
19
+ field :qk_rope_head_dim
20
+ field :kv_lora_rank
21
+ field :scale_depth
22
+ field :scale_emb
23
+ field :max_position_embeddings
24
+ field :attention_bias, default: false
25
+ field :rope_theta, default: 1_000_000.0
26
+ field :rope_traditional, default: false
27
+ field :rope_scaling, default: nil
28
+ field :tie_word_embeddings, default: false
29
+
30
+ def initialize(**kwargs)
31
+ super
32
+ @num_key_value_heads ||= @num_attention_heads
33
+ @rope_scaling ||= {}
34
+ end
35
+ end
36
+
37
+ class Attention < MLX::NN::Module
38
+ def initialize(args)
39
+ super()
40
+ @qk_rope_head_dim = args.qk_rope_head_dim
41
+ @qk_nope_head_dim = args.qk_nope_head_dim
42
+ @kv_lora_rank = args.kv_lora_rank
43
+ @num_heads = args.num_attention_heads
44
+ @hidden_size = args.hidden_size
45
+ @v_head_dim = @hidden_size / @num_heads
46
+ @q_head_dim = @qk_nope_head_dim + @qk_rope_head_dim
47
+ @kv_head_dim = @qk_nope_head_dim + @v_head_dim
48
+ @softmax_scale = @q_head_dim**(-0.5)
49
+
50
+ self.q_a_proj = MLX::NN::Linear.new(
51
+ @hidden_size,
52
+ args.q_lora_rank,
53
+ bias: args.attention_bias
54
+ )
55
+ self.q_a_layernorm = MLX::NN::RMSNorm.new(args.q_lora_rank, eps: args.rms_norm_eps)
56
+ self.q_b_proj = MLX::NN::Linear.new(
57
+ args.q_lora_rank,
58
+ @num_heads * @q_head_dim,
59
+ bias: false
60
+ )
61
+
62
+ self.kv_a_proj_with_mqa = MLX::NN::Linear.new(
63
+ @hidden_size,
64
+ @kv_lora_rank + @qk_rope_head_dim,
65
+ bias: args.attention_bias
66
+ )
67
+ self.kv_a_layernorm = MLX::NN::RMSNorm.new(@kv_lora_rank, eps: args.rms_norm_eps)
68
+ self.kv_b_proj = MLX::NN::Linear.new(
69
+ @kv_lora_rank,
70
+ @num_heads * @kv_head_dim,
71
+ bias: false
72
+ )
73
+
74
+ self.o_proj = MLX::NN::Linear.new(
75
+ @num_heads * @v_head_dim,
76
+ @hidden_size,
77
+ bias: args.attention_bias
78
+ )
79
+
80
+ self.rope = SuScaledRoPE.new(
81
+ @qk_rope_head_dim,
82
+ base: args.rope_theta,
83
+ max_position_embeddings: args.max_position_embeddings,
84
+ original_max_position_embeddings: scaling_value(args.rope_scaling, "original_max_position_embeddings", 4096),
85
+ short_factor: scaling_value(args.rope_scaling, "short_factor", 1.0),
86
+ long_factor: scaling_value(args.rope_scaling, "long_factor", 1.0)
87
+ )
88
+ end
89
+
90
+ def call(x, mask: nil, cache: nil)
91
+ mx = MLX::Core
92
+ b, l, _ = x.shape
93
+
94
+ q = q_b_proj.call(q_a_layernorm.call(q_a_proj.call(x)))
95
+ q = q.reshape([b, l, @num_heads, @q_head_dim]).transpose([0, 2, 1, 3])
96
+ q_nope, q_pe = mx.split(q, [@qk_nope_head_dim], -1)
97
+
98
+ compressed_kv = kv_a_proj_with_mqa.call(x)
99
+ compressed_kv, k_pe = mx.split(compressed_kv, [@kv_lora_rank], -1)
100
+ k_pe = k_pe.reshape([b, l, 1, @qk_rope_head_dim]).transpose([0, 2, 1, 3])
101
+
102
+ kv = kv_b_proj.call(kv_a_layernorm.call(compressed_kv))
103
+ kv = kv.reshape([b, l, @num_heads, @kv_head_dim]).transpose([0, 2, 1, 3])
104
+ k_nope, values = mx.split(kv, [@qk_nope_head_dim], -1)
105
+
106
+ if cache
107
+ q_pe = rope.call(q_pe, offset: cache.offset)
108
+ k_pe = rope.call(k_pe, offset: cache.offset)
109
+ else
110
+ q_pe = rope.call(q_pe)
111
+ k_pe = rope.call(k_pe)
112
+ end
113
+
114
+ k_pe_broadcasted = mx.broadcast_to(k_pe, [b, @num_heads, l, @qk_rope_head_dim])
115
+ queries = mx.concatenate([q_nope, q_pe], -1)
116
+ keys = mx.concatenate([k_nope, k_pe_broadcasted], -1)
117
+
118
+ if cache
119
+ keys, values = cache.update_and_fetch(keys, values)
120
+ end
121
+
122
+ output = mx.scaled_dot_product_attention(queries, keys, values, @softmax_scale, mask)
123
+ output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_heads * @v_head_dim])
124
+ o_proj.call(output)
125
+ end
126
+
127
+ private
128
+
129
+ def scaling_value(config, key, default)
130
+ return default if config.nil?
131
+ return config[key] if config.key?(key)
132
+
133
+ config.fetch(key.to_sym, default)
134
+ end
135
+ end
136
+
137
+ class MLP < MLX::NN::Module
138
+ def initialize(args)
139
+ super()
140
+ self.gate_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false)
141
+ self.up_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false)
142
+ self.down_proj = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: false)
143
+ end
144
+
145
+ def call(x)
146
+ down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
147
+ end
148
+ end
149
+
150
+ class DecoderLayer < MLX::NN::Module
151
+ def initialize(args)
152
+ super()
153
+ self.self_attn = Attention.new(args)
154
+ self.mlp = MLP.new(args)
155
+ self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
156
+ self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
157
+ @residual_scale = args.scale_depth / Math.sqrt(args.num_hidden_layers)
158
+ end
159
+
160
+ def call(x, mask: nil, cache: nil)
161
+ r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
162
+ h = x + r * @residual_scale
163
+ r = mlp.call(post_attention_layernorm.call(h))
164
+ h + r * @residual_scale
165
+ end
166
+ end
167
+
168
+ class MiniCPM3Model < MLX::NN::Module
169
+ def initialize(args)
170
+ super()
171
+ @args = args
172
+ self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
173
+ self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) }
174
+ self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
175
+ end
176
+
177
+ def call(inputs, mask: nil, cache: nil)
178
+ h = embed_tokens.call(inputs) * @args.scale_emb
179
+ layer_cache = cache || [nil] * layers.length
180
+ local_mask = mask || _create_attention_mask(h, layer_cache[0])
181
+
182
+ layers.each_with_index do |layer, i|
183
+ h = layer.call(h, mask: local_mask, cache: layer_cache[i])
184
+ end
185
+
186
+ norm.call(h)
187
+ end
188
+
189
+ private
190
+
191
+ def _create_attention_mask(hidden, cache)
192
+ return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask)
193
+ return nil if hidden.shape[1] == 1
194
+
195
+ "causal"
196
+ end
197
+ end
198
+
199
+ class Model < MLX::NN::Module
200
+ def initialize(args)
201
+ super()
202
+ @args = args
203
+ self.model_type = args.model_type
204
+ self.model = MiniCPM3Model.new(args)
205
+ self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings
206
+ end
207
+
208
+ def call(inputs, mask: nil, cache: nil)
209
+ out = model.call(inputs, mask: mask, cache: cache)
210
+ if @args.tie_word_embeddings
211
+ model.embed_tokens.as_linear(out)
212
+ else
213
+ lm_head.call(out / (@args.hidden_size.to_f / @args.dim_model_base))
214
+ end
215
+ end
216
+
217
+ def sanitize(weights)
218
+ result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
219
+
220
+ if @args.tie_word_embeddings
221
+ result.delete("lm_head.weight")
222
+ elsif !result.key?("lm_head.weight") && result.key?("model.embed_tokens.weight")
223
+ result["lm_head.weight"] = result["model.embed_tokens.weight"]
224
+ end
225
+
226
+ result
227
+ end
228
+
229
+ def layers
230
+ model.layers
231
+ end
232
+ end
233
+
234
+ Models.register("minicpm3", Model, ModelArgs)
235
+ end
236
+ end
237
+ end
@@ -0,0 +1,282 @@
1
+ require_relative "switch_layers"
2
+
3
+ module MlxLm
4
+ module Models
5
+ module Minimax
6
+ class ModelArgs < BaseModelArgs
7
+ field :model_type, default: "minimax"
8
+ field :hidden_size
9
+ field :intermediate_size
10
+ field :num_attention_heads
11
+ field :num_key_value_heads
12
+ field :max_position_embeddings
13
+ field :num_experts_per_tok
14
+ field :num_local_experts
15
+ field :shared_intermediate_size
16
+ field :num_hidden_layers
17
+ field :rms_norm_eps
18
+ field :rope_theta
19
+ field :rotary_dim
20
+ field :vocab_size
21
+ field :tie_word_embeddings, default: false
22
+ field :scoring_func, default: "sigmoid"
23
+ field :head_dim, default: nil
24
+ field :use_qk_norm, default: true
25
+
26
+ def initialize(**kwargs)
27
+ super
28
+ @num_key_value_heads ||= @num_attention_heads
29
+ @head_dim ||= @hidden_size / @num_attention_heads
30
+ @rotary_dim ||= @head_dim
31
+ end
32
+ end
33
+
34
+ class Attention < MLX::NN::Module
35
+ def initialize(args)
36
+ super()
37
+ dim = args.hidden_size
38
+ @num_attention_heads = args.num_attention_heads
39
+ @num_key_value_heads = args.num_key_value_heads
40
+ @head_dim = args.head_dim
41
+ @scale = @head_dim**(-0.5)
42
+ @use_qk_norm = args.use_qk_norm
43
+
44
+ self.q_proj = MLX::NN::Linear.new(dim, @num_attention_heads * @head_dim, bias: false)
45
+ self.k_proj = MLX::NN::Linear.new(dim, @num_key_value_heads * @head_dim, bias: false)
46
+ self.v_proj = MLX::NN::Linear.new(dim, @num_key_value_heads * @head_dim, bias: false)
47
+ self.o_proj = MLX::NN::Linear.new(@num_attention_heads * @head_dim, dim, bias: false)
48
+
49
+ if @use_qk_norm
50
+ self.q_norm = MLX::NN::RMSNorm.new(@head_dim * @num_attention_heads, eps: args.rms_norm_eps)
51
+ self.k_norm = MLX::NN::RMSNorm.new(@head_dim * @num_key_value_heads, eps: args.rms_norm_eps)
52
+ end
53
+
54
+ self.rope = MLX::NN::RoPE.new(args.rotary_dim, traditional: false, base: args.rope_theta)
55
+ end
56
+
57
+ def call(x, mask: nil, cache: nil)
58
+ mx = MLX::Core
59
+ b, l, _d = x.shape
60
+
61
+ queries = q_proj.call(x)
62
+ keys = k_proj.call(x)
63
+ values = v_proj.call(x)
64
+
65
+ if @use_qk_norm
66
+ queries = q_norm.call(queries)
67
+ keys = k_norm.call(keys)
68
+ end
69
+
70
+ queries = queries.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
71
+ keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
72
+ values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
73
+
74
+ if cache
75
+ queries = rope.call(queries, offset: cache.offset)
76
+ keys = rope.call(keys, offset: cache.offset)
77
+ keys, values = cache.update_and_fetch(keys, values)
78
+ else
79
+ queries = rope.call(queries)
80
+ keys = rope.call(keys)
81
+ end
82
+
83
+ output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
84
+ output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim])
85
+ o_proj.call(output)
86
+ end
87
+ end
88
+
89
+ class SparseMoeBlock < MLX::NN::Module
90
+ def initialize(args)
91
+ super()
92
+ mx = MLX::Core
93
+ @num_experts_per_tok = args.num_experts_per_tok
94
+ @num_local_experts = args.num_local_experts
95
+
96
+ self.gate = MLX::NN::Linear.new(args.hidden_size, @num_local_experts, bias: false)
97
+ self.switch_mlp = SwitchLayers::SwitchGLU.new(args.hidden_size, args.intermediate_size, @num_local_experts)
98
+ self.e_score_correction_bias = mx.zeros([@num_local_experts])
99
+ end
100
+
101
+ def call(x)
102
+ mx = MLX::Core
103
+
104
+ gates = gate.call(x.astype(mx.float32))
105
+ orig_scores = mx.sigmoid(gates)
106
+ scores = orig_scores + e_score_correction_bias
107
+
108
+ k = [[@num_experts_per_tok.to_i, 1].max, @num_local_experts.to_i].min
109
+ inds = mx.argpartition(scores * -1.0, k - 1, -1)
110
+ take_ids = mx.array((0...k).to_a, dtype: mx.int32)
111
+ inds = mx.take(inds, take_ids, -1)
112
+
113
+ scores = mx.take_along_axis(orig_scores, inds, -1)
114
+ scores = scores / (mx.expand_dims(mx.sum(scores, -1), -1) + 1e-20)
115
+ scores = scores.astype(x.dtype)
116
+
117
+ y = switch_mlp.call(x, inds)
118
+ mx.sum(y * mx.expand_dims(scores, -1), -2)
119
+ end
120
+ end
121
+
122
+ class DecoderLayer < MLX::NN::Module
123
+ def initialize(args)
124
+ super()
125
+ self.self_attn = Attention.new(args)
126
+ self.block_sparse_moe = SparseMoeBlock.new(args)
127
+ self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
128
+ self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
129
+ end
130
+
131
+ def call(x, mask: nil, cache: nil)
132
+ h = x + self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
133
+ h + block_sparse_moe.call(post_attention_layernorm.call(h))
134
+ end
135
+ end
136
+
137
+ class MiniMaxModel < MLX::NN::Module
138
+ def initialize(args)
139
+ super()
140
+ self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
141
+ self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) }
142
+ self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
143
+ end
144
+
145
+ def call(inputs, mask: nil, cache: nil)
146
+ h = embed_tokens.call(inputs)
147
+ layer_cache = cache || [nil] * layers.length
148
+ local_mask = mask || _create_attention_mask(h, layer_cache[0])
149
+
150
+ layers.each_with_index do |layer, i|
151
+ h = layer.call(h, mask: local_mask, cache: layer_cache[i])
152
+ end
153
+
154
+ norm.call(h)
155
+ end
156
+
157
+ private
158
+
159
+ def _create_attention_mask(h, cache)
160
+ n = h.shape[1]
161
+ return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
162
+ return nil if n == 1
163
+
164
+ "causal"
165
+ end
166
+ end
167
+
168
+ class Model < MLX::NN::Module
169
+ def initialize(args)
170
+ super()
171
+ @args = args
172
+ self.model_type = args.model_type
173
+ self.model = MiniMaxModel.new(args)
174
+ self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings
175
+ end
176
+
177
+ def call(inputs, mask: nil, cache: nil)
178
+ out = model.call(inputs, mask: mask, cache: cache)
179
+ if @args.tie_word_embeddings
180
+ model.embed_tokens.as_linear(out)
181
+ else
182
+ lm_head.call(out)
183
+ end
184
+ end
185
+
186
+ def sanitize(weights)
187
+ mx = MLX::Core
188
+ dequantized = {}
189
+
190
+ weights.each do |key, value|
191
+ if key.include?("weight_scale_inv")
192
+ weight_key = key.sub("_scale_inv", "")
193
+ next unless weights.key?(weight_key)
194
+
195
+ dequantized[weight_key] = _dequant(weights[weight_key], value)
196
+ elsif !dequantized.key?(key)
197
+ dequantized[key] = value
198
+ end
199
+ end
200
+
201
+ result = dequantized
202
+ return result unless result.key?("model.layers.0.block_sparse_moe.experts.0.w1.weight")
203
+
204
+ mapping = {
205
+ "w1" => "gate_proj",
206
+ "w2" => "down_proj",
207
+ "w3" => "up_proj",
208
+ }
209
+ experts_count = @args.num_local_experts.to_i
210
+ return result if experts_count <= 0
211
+
212
+ @args.num_hidden_layers.times do |layer_idx|
213
+ prefix = "model.layers.#{layer_idx}"
214
+ mapping.each do |old_name, new_name|
215
+ first_key = "#{prefix}.block_sparse_moe.experts.0.#{old_name}.weight"
216
+ next unless result.key?(first_key)
217
+
218
+ expert_keys = (0...experts_count).map do |expert_idx|
219
+ "#{prefix}.block_sparse_moe.experts.#{expert_idx}.#{old_name}.weight"
220
+ end
221
+ next unless expert_keys.all? { |k| result.key?(k) }
222
+
223
+ stacked = expert_keys.map { |k| result.delete(k) }
224
+ result["#{prefix}.block_sparse_moe.switch_mlp.#{new_name}.weight"] = mx.stack(stacked)
225
+ end
226
+ end
227
+
228
+ result
229
+ end
230
+
231
+ def layers
232
+ model.layers
233
+ end
234
+
235
+ def cast_predicate
236
+ lambda { |key| !key.include?("e_score_correction_bias") }
237
+ end
238
+
239
+ def quant_predicate
240
+ lambda do |path, _|
241
+ if path.end_with?("block_sparse_moe.gate")
242
+ { group_size: 64, bits: 8 }
243
+ else
244
+ true
245
+ end
246
+ end
247
+ end
248
+
249
+ private
250
+
251
+ def _dequant(weight, scale_inv)
252
+ mx = MLX::Core
253
+ dtype = mx.bfloat16
254
+ block_size = 128
255
+
256
+ dequantized = mx.from_fp8(weight, dtype: dtype)
257
+ m, n = dequantized.shape
258
+ pad_bottom = block_size * scale_inv.shape[0] - m
259
+ pad_side = block_size * scale_inv.shape[1] - n
260
+
261
+ dequantized = mx.pad(dequantized, [[0, pad_bottom], [0, pad_side]])
262
+ dequantized = dequantized.reshape([
263
+ (m + pad_bottom) / block_size,
264
+ block_size,
265
+ (n + pad_side) / block_size,
266
+ block_size,
267
+ ])
268
+
269
+ scaled = dequantized * scale_inv.reshape([scale_inv.shape[0], 1, scale_inv.shape[1], 1])
270
+ scaled = scaled.reshape([m + pad_bottom, n + pad_side])
271
+ scaled = mx.split(scaled, [m], 0)[0]
272
+ scaled = mx.split(scaled, [n], 1)[0]
273
+ scaled.astype(dtype)
274
+ rescue StandardError
275
+ weight
276
+ end
277
+ end
278
+
279
+ Models.register("minimax", Model, ModelArgs)
280
+ end
281
+ end
282
+ end