mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/ministral3.rb
@@ -0,0 +1,304 @@
+ require_relative "activations"
+ require_relative "cache"
+ require_relative "pipeline"
+ require_relative "rope_utils"
+
+ module MlxLm
+   module Models
+     module Ministral3
+       def self.llama4_attn_scale(size, offset, beta, max_position_embeddings)
+         mx = MLX::Core
+         positions = mx.arange(size) + offset
+         scale = 1.0 + beta.to_f * mx.log(1.0 + mx.floor(positions / max_position_embeddings.to_f))
+         scale.reshape([size, 1])
+       end
+
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "ministral3"
+         field :hidden_size
+         field :num_hidden_layers
+         field :intermediate_size
+         field :num_attention_heads
+         field :rms_norm_eps
+         field :vocab_size
+         field :head_dim, default: nil
+         field :max_position_embeddings, default: nil
+         field :num_key_value_heads, default: nil
+         field :rope_parameters, default: nil
+         field :tie_word_embeddings, default: true
+         field :layer_types, default: nil
+         field :sliding_window, default: nil
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+           @rope_parameters = _stringify_keys(@rope_parameters || {})
+           @rope_parameters["rope_theta"] = 10_000.0 unless @rope_parameters.key?("rope_theta")
+           @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" }
+         end
+
+         def rope_parameter(key, default = nil)
+           return default unless @rope_parameters.is_a?(Hash)
+           return @rope_parameters[key.to_s] if @rope_parameters.key?(key.to_s)
+           return @rope_parameters[key.to_sym] if @rope_parameters.key?(key.to_sym)
+
+           default
+         end
+
+         private
+
+         def _stringify_keys(hash)
+           hash.each_with_object({}) do |(key, value), out|
+             out[key.to_s] = value
+           end
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_parameter("rope_theta", 10_000.0),
+             false,
+             args.rope_parameters,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def call(x, attn_scale:, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           queries = queries * attn_scale
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           hidden_dim = args.intermediate_size
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         attr_reader :use_sliding
+
+         def initialize(args, use_sliding: false)
+           super()
+           @use_sliding = use_sliding
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, attn_scale:, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), attn_scale: attn_scale, mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class LanguageModel < MLX::NN::Module
+         include PipelineMixin
+         attr_reader :sliding_window
+
+         def initialize(args)
+           super()
+           @args = args
+           @sliding_window = args.sliding_window
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = args.layer_types.map do |layer_type|
+             TransformerBlock.new(args, use_sliding: layer_type == "sliding_attention")
+           end
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           h = input_embeddings || embed_tokens.call(inputs)
+           active_layers = pipeline_layers
+           layer_cache = cache || Array.new(active_layers.length)
+
+           first_cache = layer_cache.find { |entry| !entry.nil? }
+           offset = first_cache ? first_cache.offset : 0
+
+           fa_idx = nil
+           swa_idx = nil
+           active_layers.each_with_index do |layer, i|
+             if layer.use_sliding
+               swa_idx ||= i
+             else
+               fa_idx ||= i
+             end
+             break if fa_idx && swa_idx
+           end
+
+           fa_mask = fa_idx.nil? ? nil : _create_attention_mask(h, layer_cache[fa_idx])
+           swa_mask = if swa_idx.nil?
+             nil
+           else
+             _create_attention_mask(h, layer_cache[swa_idx], window_size: @sliding_window)
+           end
+
+           beta = @args.rope_parameter("llama_4_scaling_beta", 0.0).to_f
+           max_pos = @args.rope_parameter(
+             "original_max_position_embeddings",
+             @args.max_position_embeddings || h.shape[1]
+           ).to_i
+           max_pos = 1 if max_pos <= 0
+
+           attn_scale = MlxLm::Models::Ministral3.llama4_attn_scale(
+             h.shape[1], # use h: inputs may be nil when only input_embeddings is given
+             offset,
+             beta,
+             max_pos
+           ).astype(h.dtype)
+
+           active_layers.each_with_index do |layer, idx|
+             mask = layer.use_sliding ? swa_mask : fa_mask
+             h = layer.call(h, attn_scale: attn_scale, mask: mask, cache: layer_cache[idx])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil, window_size: nil)
+           n = h.shape[1]
+           if cache && cache.respond_to?(:make_mask)
+             return cache.make_mask(n, window_size: window_size)
+           end
+
+           if window_size
+             offset = cache ? cache.offset : 0
+             if cache && cache.instance_variable_defined?(:@max_size)
+               max_size = cache.instance_variable_get(:@max_size)
+               offset = [max_size - 1, offset].min if max_size && max_size > 0
+             end
+
+             return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
+           end
+
+           return nil if n == 1
+
+           "causal"
+         end
+
+         def _create_causal_mask(n, offset: 0, window_size: nil)
+           mx = MLX::Core
+           rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+           linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+           mask = mx.greater_equal(linds, rinds)
+           if window_size
+             mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+           end
+           mask
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = LanguageModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           out = model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           sanitized = weights.reject do |key, _|
+             key_name = key.to_s
+             key_name.include?("self_attn.rotary_emb.inv_freq") || key_name.include?("self_attn.rope.inv_freq")
+           end
+           sanitized.delete("lm_head.weight") if @args.tie_word_embeddings
+
+           new_weights = {}
+           sanitized.each do |key, value|
+             key_name = key.to_s
+             if key_name.include?("weight_scale_inv")
+               wk = key_name.sub("_scale_inv", "")
+               next unless sanitized.key?(wk)
+
+               new_weights[wk] = sanitized[wk] * value
+             elsif key_name.include?("activation_scale")
+               next
+             elsif !new_weights.key?(key)
+               new_weights[key] = value
+             end
+           end
+           new_weights
+         end
+
+         def layers
+           model.pipeline_layers
+         end
+
+         def make_cache
+           max_size = @args.sliding_window || @args.max_position_embeddings || 1
+           layers.map do |layer|
+             if layer.use_sliding
+               MlxLm::RotatingKVCache.new(max_size: max_size)
+             else
+               MlxLm::KVCache.new
+             end
+           end
+         end
+       end
+
+       Models.register("ministral3", Model, ModelArgs)
+     end
+   end
+ end
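
The distinguishing detail in this file is the Llama-4-style attention scaling: before scaled-dot-product attention, each query row is multiplied by a factor that grows logarithmically once its absolute position passes `original_max_position_embeddings`. Below is a minimal pure-Ruby sketch of the same formula applied to scalar positions, with plain floats standing in for the MLX arrays that `llama4_attn_scale` operates on; the `beta` and `max_pos` values are illustrative, not defaults from this file.

```ruby
# Scalar sketch of llama4_attn_scale: scale = 1 + beta * ln(1 + floor(pos / max_pos)).
def llama4_attn_scale_scalar(position, beta, max_pos)
  1.0 + beta * Math.log(1.0 + (position / max_pos.to_f).floor)
end

beta    = 0.1     # stands in for the "llama_4_scaling_beta" rope parameter
max_pos = 4096    # stands in for "original_max_position_embeddings"
[0, 4095, 4096, 16_384].each do |pos|
  puts format("pos=%6d scale=%.4f", pos, llama4_attn_scale_scalar(pos, beta, max_pos))
end
# Positions below max_pos keep scale 1.0; beyond it the queries are scaled up
# logarithmically, keeping attention logits sharp at long context lengths.
```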
data/lib/mlx_lm/models/mistral3.rb
@@ -0,0 +1,84 @@
+ require_relative "llama"
+
+ module MlxLm
+   module Models
+     module Mistral3
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "mistral3"
+         field :text_config, default: nil
+
+         def initialize(**kwargs)
+           super
+           @text_config = (@text_config || {}).dup
+           @text_config["tie_word_embeddings"] = false unless @text_config.key?("tie_word_embeddings")
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+
+           text_config = args.text_config || {}
+           text_model_type = text_config["model_type"]
+
+           if text_model_type == "ministral3" && Models::REGISTRY.key?("ministral3")
+             model_class, args_class = Models.get_classes(text_config)
+             self.language_model = model_class.new(args_class.from_dict(text_config))
+           else
+             self.language_model = Llama::Model.new(Llama::ModelArgs.from_dict(text_config))
+           end
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           supports_input_embeddings = language_model.method(:call).parameters.any? do |_, name|
+             name == :input_embeddings
+           end
+
+           if supports_input_embeddings
+             language_model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+           else
+             language_model.call(inputs, cache: cache)
+           end
+         end
+
+         def sanitize(weights)
+           result = {}
+           language_weights = {}
+
+           weights.each do |k, v|
+             next if k == "vision_tower" || k.start_with?("vision_tower.")
+             next if k == "multi_modal_projector" || k.start_with?("multi_modal_projector.")
+
+             if k.start_with?("language_model.")
+               language_weights[k.delete_prefix("language_model.")] = v
+             else
+               result[k] = v
+             end
+           end
+
+           sanitized_language = if language_model.respond_to?(:sanitize)
+             language_model.sanitize(language_weights)
+           else
+             language_weights
+           end
+
+           sanitized_language.each do |k, v|
+             result["language_model.#{k}"] = v
+           end
+
+           result
+         end
+
+         def layers
+           return language_model.model.layers if language_model.respond_to?(:model) && language_model.model.respond_to?(:layers)
+
+           language_model.layers
+         end
+       end
+
+       Models.register("mistral3", Model, ModelArgs)
+     end
+   end
+ end
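
The `supports_input_embeddings` probe in `Model#call` relies on Ruby's `Method#parameters` introspection so the wrapper can forward `input_embeddings:` only to language models that accept it. Here is a standalone sketch of the same pattern; the two method names are made up for illustration.

```ruby
# Method#parameters lists each parameter as [kind, name] pairs
# (e.g. [[:req, :x], [:key, :cache], [:key, :input_embeddings]]),
# which lets callers probe for a keyword before forwarding it.
def with_embeddings(x, cache: nil, input_embeddings: nil); :rich; end
def without_embeddings(x, cache: nil); :plain; end

%i[with_embeddings without_embeddings].each do |name|
  supported = method(name).parameters.any? { |_, pname| pname == :input_embeddings }
  puts "#{name}: input_embeddings supported? #{supported}"
end
# => with_embeddings: input_embeddings supported? true
# => without_embeddings: input_embeddings supported? false
```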
data/lib/mlx_lm/models/mixtral.rb
@@ -0,0 +1,192 @@
+ module MlxLm
+   module Models
+     module Mixtral
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "mixtral"
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 32
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: 8
+         field :intermediate_size, default: 14336
+         field :vocab_size, default: 32000
+         field :rms_norm_eps, default: 1e-5
+         field :rope_theta, default: 1e6
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :num_local_experts, default: 8
+         field :num_experts_per_tok, default: 2
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = dim / @n_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+           self.rope = MLX::NN::RoPE.new(
+             @head_dim,
+             traditional: args.rope_traditional,
+             base: args.rope_theta
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class SparseMoeBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           @num_experts = args.num_local_experts
+           @num_experts_per_tok = args.num_experts_per_tok
+           dim = args.hidden_size
+           hidden_dim = args.intermediate_size
+
+           self.gate = MLX::NN::Linear.new(dim, @num_experts, bias: false)
+           self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, hidden_dim, @num_experts)
+         end
+
+         def call(x)
+           mx = MLX::Core
+           k = @num_experts_per_tok
+
+           gates = gate.call(x)
+           inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1))
+           take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+
+           scores = mx.take_along_axis(gates, inds, -1)
+           scores = mx.softmax(scores.astype(mx.float32), -1).astype(gates.dtype)
+
+           y = switch_mlp.call(x, inds)
+           y = mx.sum(y * mx.expand_dims(scores, -1), -2)
+           y
+         end
+       end
+
+       class MixtralDecoderLayer < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.block_sparse_moe = SparseMoeBlock.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = block_sparse_moe.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class MixtralModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { MixtralDecoderLayer.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model = MixtralModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+
+           # Convert per-expert weights to stacked SwitchGLU format
+           @args.num_hidden_layers.times do |l|
+             prefix = "model.layers.#{l}"
+             [["w1", "gate_proj"], ["w2", "down_proj"], ["w3", "up_proj"]].each do |n, m|
+               ["weight", "scales", "biases"].each do |k|
+                 key0 = "#{prefix}.block_sparse_moe.experts.0.#{n}.#{k}"
+                 if result.key?(key0)
+                   to_join = (0...@args.num_local_experts).map { |e|
+                     result.delete("#{prefix}.block_sparse_moe.experts.#{e}.#{n}.#{k}")
+                   }
+                   result["#{prefix}.block_sparse_moe.switch_mlp.#{m}.#{k}"] = mx.stack(to_join)
+                 end
+               end
+             end
+           end
+
+           result
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("mixtral", Model, ModelArgs)
+     end
+   end
+ end
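
`SparseMoeBlock#call` implements top-k expert routing: pick the `num_experts_per_tok` largest router logits, softmax over just those, and sum the selected experts' outputs weighted by the renormalized scores. The sketch below reproduces that logic for a single token in pure Ruby, with `argpartition`/`take_along_axis`/`softmax` replaced by plain `Array` operations and the expert MLPs stubbed as lambdas; all names here are illustrative.

```ruby
# Top-k routing for one token's router logits.
def top_k_route(gate_logits, k)
  inds = gate_logits.each_index.max_by(k) { |i| gate_logits[i] }  # top-k expert ids
  exps = inds.map { |i| Math.exp(gate_logits[i]) }
  sum  = exps.sum
  [inds, exps.map { |e| e / sum }]                                # softmax over the k kept logits
end

experts = Array.new(8) { |e| ->(x) { x * (e + 1) } }              # stub expert MLPs
logits  = [0.1, 2.0, -1.0, 0.5, 1.5, 0.0, -0.5, 0.3]
inds, weights = top_k_route(logits, 2)

# Weighted sum of the selected experts' outputs, as in mx.sum(y * scores, -2).
y = inds.zip(weights).sum { |e, w| w * experts[e].call(1.0) }
puts "experts=#{inds.inspect} weights=#{weights.map { |w| w.round(3) }.inspect} y=#{y.round(3)}"
# => experts=[1, 4] weights=[0.622, 0.378] y=3.133
```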
data/lib/mlx_lm/models/mla.rb
@@ -0,0 +1,75 @@
+ module MlxLm
+   module Models
+     module MLA
+       class MultiLinear < MLX::NN::Module
+         def initialize(input_dims, output_dims, num_heads)
+           super()
+           scale = Math.sqrt(1.0 / input_dims)
+           self.weight = MLX::Core.uniform([num_heads, output_dims, input_dims], -scale, scale)
+         end
+
+         def call(x, transpose: true)
+           if transpose
+             MLX::Core.matmul(x, MLX::Core.swapaxes(weight, -1, -2))
+           else
+             MLX::Core.matmul(x, weight)
+           end
+         end
+
+         def to_quantized(group_size: nil, bits: nil, mode: "affine", quantize_input: false)
+           raise ArgumentError, "Quantized input is not supported." if quantize_input
+
+           QuantizedMultiLinear.from_multi_linear(self, group_size, bits, mode: mode)
+         end
+       end
+
+       class QuantizedMultiLinear < MLX::NN::Module
+         attr_reader :group_size, :bits, :mode
+
+         def initialize(input_dims, output_dims, num_heads, group_size = nil, bits = nil, mode: "affine")
+           super()
+
+           @group_size, @bits = MLX::NN.__send__(:defaults_for_mode, mode, group_size, bits)
+           @mode = mode
+
+           scale = Math.sqrt(1.0 / input_dims)
+           weight = MLX::Core.uniform([num_heads, output_dims, input_dims], -scale, scale)
+           q_weight, q_scales, *q_biases = MLX::Core.quantize(weight, @group_size, @bits, @mode)
+           self.weight = q_weight
+           self.scales = q_scales
+           self.biases = q_biases.empty? ? nil : q_biases[0]
+
+           freeze
+         end
+
+         def call(x, transpose: true)
+           MLX::Core.quantized_matmul(
+             x,
+             weight,
+             scales,
+             biases,
+             transpose,
+             @group_size,
+             @bits,
+             @mode
+           )
+         end
+
+         def self.from_multi_linear(multi_linear_layer, group_size = nil, bits = nil, mode: "affine")
+           num_heads, output_dims, input_dims = multi_linear_layer.weight.shape
+           out = new(input_dims, output_dims, num_heads, group_size, bits, mode: mode)
+           q_weight, q_scales, *q_biases = MLX::Core.quantize(
+             multi_linear_layer.weight,
+             out.group_size,
+             out.bits,
+             out.mode
+           )
+           out.weight = q_weight
+           out.scales = q_scales
+           out.biases = q_biases.empty? ? nil : q_biases[0]
+           out
+         end
+       end
+     end
+   end
+ end
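
A hedged usage sketch for the pair of layers above, not an excerpt from the gem's documentation. Shapes follow the `[num_heads, output_dims, input_dims]` weight layout used in this file; the `group_size`/`bits` values are common affine-quantization settings chosen here purely for illustration, and the input shape assumes the broadcasting behavior implied by the per-head `matmul`.

```ruby
require "mlx_lm"

# input_dims = 128, output_dims = 64, num_heads = 16
lin = MlxLm::Models::MLA::MultiLinear.new(128, 64, 16)
q   = lin.to_quantized(group_size: 64, bits: 4)  # => QuantizedMultiLinear over the same weight

# Random input, using the same MLX::Core.uniform(shape, low, high) call as the file:
# [batch, heads, seq, input_dims]
x = MLX::Core.uniform([1, 16, 8, 128], -1.0, 1.0)

# With transpose: true both layers contract over the trailing input_dims axis,
# producing [1, 16, 8, 64]; the quantized path goes through quantized_matmul.
y_full  = lin.call(x)
y_quant = q.call(x)
```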