mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/lfm2_moe.rb
@@ -0,0 +1,421 @@
+ require_relative "activations"
+ require_relative "cache"
+ require_relative "rope_utils"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module Lfm2Moe
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "lfm2_moe"
+         field :vocab_size
+         field :hidden_size
+         field :intermediate_size
+         field :moe_intermediate_size
+         field :num_hidden_layers
+         field :num_experts
+         field :num_experts_per_tok
+         field :norm_topk_prob
+         field :num_attention_heads
+         field :num_key_value_heads, default: nil
+         field :max_position_embeddings
+         field :use_expert_bias
+         field :num_dense_layers
+         field :norm_eps
+         field :conv_bias
+         field :conv_L_cache
+         field :rope_theta, default: 1_000_000.0
+         field :rope_parameters, default: nil
+         field :full_attn_idxs, default: nil
+         field :layer_types, default: nil
+
+         def initialize(**kwargs)
+           super
+           rope_theta_from_params = _rope_theta_from_parameters
+           @rope_theta = rope_theta_from_params unless rope_theta_from_params.nil?
+           @num_key_value_heads ||= @num_attention_heads
+           @full_attn_idxs ||= _full_attn_idxs_from_layer_types
+         end
+
+         private
+
+         def _rope_theta_from_parameters
+           return nil unless @rope_parameters.is_a?(Hash)
+
+           @rope_parameters["rope_theta"] || @rope_parameters[:rope_theta]
+         end
+
+         def _full_attn_idxs_from_layer_types
+           return [] unless @layer_types.is_a?(Array)
+
+           @layer_types.each_with_index.filter_map do |layer_type, i|
+             i if layer_type.to_s == "full_attention"
+           end
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.hidden_size / @n_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.norm_eps)
+           self.k_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.norm_eps)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.out_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             false,
+             nil,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = q_layernorm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3])
+           keys = k_layernorm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           out_proj.call(output)
+         end
+       end
+
+       class ShortConv < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+           _ = layer_idx
+           @args = args
+           @l_cache = args.conv_L_cache
+           @hidden_size = args.hidden_size
+
+           self.conv = MLX::NN::Conv1d.new(
+             args.hidden_size,
+             args.hidden_size,
+             @l_cache,
+             padding: 0,
+             groups: args.hidden_size,
+             bias: args.conv_bias
+           )
+           self.in_proj = MLX::NN::Linear.new(args.hidden_size, 3 * args.hidden_size, bias: args.conv_bias)
+           self.out_proj = MLX::NN::Linear.new(args.hidden_size, args.hidden_size, bias: args.conv_bias)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+
+           projected = in_proj.call(x)
+           b_gate, c_gate, x_gate = mx.split(projected, [@hidden_size, 2 * @hidden_size], -1)
+           bx = b_gate * x_gate
+           bx = mx.where(mask.reshape([mask.shape[0], mask.shape[1], 1]), bx, 0) unless mask.nil?
+
+           if cache
+             state = if cache[0].nil?
+               mx.zeros([bx.shape[0], @l_cache - 1, @hidden_size], dtype: bx.dtype)
+             else
+               cache[0]
+             end
+
+             bx = mx.concatenate([state, bx], 1)
+             n_keep = @l_cache - 1
+             t = x_gate.shape[1]
+
+             if cache.lengths
+               ends = mx.clip(cache.lengths, 0, t)
+               positions = mx.expand_dims(
+                 mx.expand_dims(ends, 1) + mx.arange(n_keep),
+                 -1
+               )
+               cache[0] = mx.take_along_axis(bx, positions, 1)
+             else
+               if n_keep > 0
+                 split_at = bx.shape[1] - n_keep
+                 cache[0] = mx.split(bx, [split_at], 1)[1]
+               else
+                 cache[0] = mx.zeros([bx.shape[0], 0, bx.shape[2]], dtype: bx.dtype)
+               end
+             end
+
+             cache.advance(t)
+           else
+             bx = mx.pad(
+               bx,
+               [
+                 [0, 0],
+                 [@l_cache - 1, 0],
+                 [0, 0],
+               ]
+             )
+           end
+
+           conv_out = conv.call(bx)
+           out_proj.call(c_gate * conv_out)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(config, intermediate_size: nil)
+           super()
+           @hidden_size = config.hidden_size
+           @intermediate_size = intermediate_size || config.intermediate_size
+           self.gate_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false)
+           self.up_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false)
+           self.down_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class SparseMoeBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           intermediate_size = args.moe_intermediate_size
+
+           @num_experts = args.num_experts
+           @top_k = args.num_experts_per_tok
+           @norm_topk_prob = args.norm_topk_prob
+           @use_expert_bias = args.use_expert_bias
+
+           self.gate = MLX::NN::Linear.new(dim, @num_experts, bias: false)
+           self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, intermediate_size, @num_experts)
+           self.expert_bias = MLX::Core.zeros([@num_experts]) if @use_expert_bias
+         end
+
+         def call(x)
+           mx = MLX::Core
+
+           gates = gate.call(x).astype(mx.float32)
+           gates = mx.softmax(gates, -1)
+           gates = gates + expert_bias if @use_expert_bias
+
+           k = [[@top_k.to_i, 1].max, @num_experts].min
+           inds = mx.argpartition(gates, -k, -1)
+           take_ids = mx.array((@num_experts - k...@num_experts).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+
+           scores = mx.take_along_axis(gates, inds, -1)
+           if @norm_topk_prob
+             scores = scores / (mx.expand_dims(mx.sum(scores, -1), -1) + 1e-20)
+           end
+           scores = scores.astype(x.dtype)
+
+           y = switch_mlp.call(x, inds)
+           mx.sum(y * mx.expand_dims(scores, -1), -2)
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         attr_reader :is_attention_layer
+
+         def initialize(args, layer_idx)
+           super()
+           @is_attention_layer = args.full_attn_idxs.include?(layer_idx)
+
+           if @is_attention_layer
+             self.self_attn = Attention.new(args)
+           else
+             self.conv = ShortConv.new(args, layer_idx)
+           end
+
+           self.feed_forward = if layer_idx < args.num_dense_layers
+             MLP.new(args, intermediate_size: args.intermediate_size)
+           else
+             SparseMoeBlock.new(args)
+           end
+
+           self.operator_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.norm_eps)
+           self.ffn_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = if @is_attention_layer
+             self_attn.call(operator_norm.call(x), mask: mask, cache: cache)
+           else
+             conv.call(operator_norm.call(x), mask: mask, cache: cache)
+           end
+
+           h = x + r
+           h + feed_forward.call(ffn_norm.call(h))
+         end
+       end
+
+       class Lfm2MoeModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { |i| DecoderLayer.new(args, i) }
+           self.embedding_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.norm_eps)
+
+           self.fa_idx = args.full_attn_idxs[0] || 0
+           self.conv_idx = 0
+           args.num_hidden_layers.times do |i|
+             if args.full_attn_idxs.include?(i)
+               self.conv_idx += 1
+             else
+               break
+             end
+           end
+           self.conv_idx = [conv_idx, args.num_hidden_layers - 1].min
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           h = input_embeddings || embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           attn_mask = _create_attention_mask(h, layer_cache[fa_idx])
+           conv_mask = _create_ssm_mask(h, layer_cache[conv_idx])
+
+           layers.each_with_index do |layer, i|
+             mask = layer.is_attention_layer ? attn_mask : conv_mask
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           embedding_norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil)
+           n = h.shape[1]
+           return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
+           return nil if n == 1
+
+           "causal"
+         end
+
+         def _create_ssm_mask(h, cache = nil)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+
+           nil
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = Lfm2MoeModel.new(args)
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           out = model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+           model.embed_tokens.as_linear(out)
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           sanitized = {}
+
+           weights.each do |name, param|
+             current = param
+             if name.include?("conv.weight") && _transpose_conv_weight?(param)
+               current = mx.swapaxes(param, 1, 2)
+             end
+
+             key = name
+             {
+               "w1.weight" => "gate_proj.weight",
+               "w2.weight" => "down_proj.weight",
+               "w3.weight" => "up_proj.weight",
+             }.each do |old_name, new_name|
+               key = key.gsub(old_name, new_name) if key.include?(old_name)
+             end
+
+             sanitized[key] = current
+           end
+
+           @args.num_hidden_layers.times do |layer_idx|
+             prefix = "model.layers.#{layer_idx}"
+             %w[gate_proj down_proj up_proj].each do |projection|
+               first_key = "#{prefix}.feed_forward.experts.0.#{projection}.weight"
+               next unless sanitized.key?(first_key)
+
+               expert_keys = (0...@args.num_experts).map do |expert_idx|
+                 "#{prefix}.feed_forward.experts.#{expert_idx}.#{projection}.weight"
+               end
+               next unless expert_keys.all? { |k| sanitized.key?(k) }
+
+               stacked = expert_keys.map { |k| sanitized.delete(k) }
+               sanitized["#{prefix}.feed_forward.switch_mlp.#{projection}.weight"] = mx.stack(stacked)
+             end
+           end
+
+           sanitized
+         end
+
+         def layers
+           model.layers
+         end
+
+         def make_cache
+           layers.map do |layer|
+             if layer.is_attention_layer
+               MlxLm::KVCache.new
+             else
+               MlxLm::ArraysCache.new(1)
+             end
+           end
+         end
+
+         def quant_predicate
+           lambda do |path, _|
+             if path.end_with?("feed_forward.gate")
+               { group_size: 64, bits: 8 }
+             else
+               true
+             end
+           end
+         end
+
+         def cast_predicate
+           lambda { |k| !k.include?("expert_bias") }
+         end
+
+         private
+
+         def _transpose_conv_weight?(param)
+           return false unless param.respond_to?(:shape)
+           return false unless param.shape.is_a?(Array)
+           return false unless param.shape.length >= 3
+
+           param.shape[-1] > param.shape[1]
+         end
+       end
+
+       Models.register("lfm2_moe", Model, ModelArgs)
+     end
+   end
+ end
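Note on the routing above: SparseMoeBlock#call softmaxes the gate logits, keeps the k largest via argpartition/take, and optionally renormalizes the kept probabilities (the norm_topk_prob branch). The same arithmetic on plain Ruby arrays, as a minimal illustrative sketch — the values and variable names here are invented, not part of the gem:

    # Softmaxed router output for one token; top_k plays the role of
    # num_experts_per_tok. Illustrative values only.
    gates = [0.05, 0.40, 0.10, 0.30, 0.15]
    top_k = 2

    # Indices of the k largest gate values (argpartition + take in the MLX code).
    inds = gates.each_with_index.max_by(top_k) { |g, _| g }.map(&:last)

    # Gather the selected scores and renormalize, as in the norm_topk_prob branch.
    scores = inds.map { |i| gates[i] }
    total = scores.sum + 1e-20
    scores = scores.map { |s| s / total }

    inds   # => [1, 3]
    scores # => [0.571..., 0.428...]

The selected expert outputs are then combined with these scores as weights, which is what the final mx.sum over the expert axis does in the tensor version.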
data/lib/mlx_lm/models/lfm2_vl.rb
@@ -0,0 +1,67 @@
+ require_relative "lfm2"
+
+ module MlxLm
+   module Models
+     module Lfm2VL
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "lfm2-vl"
+         field :text_config, default: nil
+
+         def self.from_dict(params)
+           has_text_config = params.key?("text_config") || params.key?(:text_config)
+           return super if has_text_config
+
+           new(model_type: params["model_type"] || params[:model_type], text_config: params)
+         end
+
+         def initialize(**kwargs)
+           super
+           @text_config = _stringify_keys(@text_config || {})
+           @text_config["tie_word_embeddings"] = false
+         end
+
+         private
+
+         def _stringify_keys(hash)
+           hash.each_with_object({}) do |(key, value), out|
+             out[key.to_s] = value
+           end
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.language_model = Lfm2::Model.new(Lfm2::ModelArgs.from_dict(args.text_config))
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           language_model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+         end
+
+         def sanitize(weights)
+           nested = MLX::Utils.tree_unflatten(weights.to_a)
+           if nested.is_a?(Hash)
+             nested.delete("vision_tower")
+             nested.delete("multi_modal_projector")
+           end
+           MLX::Utils.tree_flatten(nested, destination: {})
+         end
+
+         def layers
+           language_model.layers
+         end
+
+         def make_cache
+           return language_model.make_cache if language_model.respond_to?(:make_cache)
+
+           nil
+         end
+       end
+
+       Models.register("lfm2-vl", Model, ModelArgs)
+     end
+   end
+ end
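Note on ModelArgs.from_dict above: it accepts either a full multimodal config with a nested text_config, or a bare text config, which it wraps wholesale before the Lfm2 language model is built. A hedged usage sketch — the hash contents below are invented for illustration:

    full = { "model_type" => "lfm2-vl", "text_config" => { "hidden_size" => 1024 } }
    bare = { "model_type" => "lfm2-vl", "hidden_size" => 1024 }

    # Nested text_config is used as-is; a bare hash becomes the text_config.
    a = MlxLm::Models::Lfm2VL::ModelArgs.from_dict(full)
    b = MlxLm::Models::Lfm2VL::ModelArgs.from_dict(bare)

    # In both cases initialize then stringifies the config keys and forces
    # "tie_word_embeddings" => false before handing it to Lfm2::ModelArgs.

Since this model drops vision_tower and multi_modal_projector weights in sanitize, the registered "lfm2-vl" entry effectively runs the text backbone only.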
data/lib/mlx_lm/models/lille_130m.rb
@@ -0,0 +1,148 @@
+ module MlxLm
+   module Models
+     module Lille130m
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "lille-130m"
+         field :block_size
+         field :layer_norm_eps
+         field :n_embd
+         field :n_head
+         field :n_kv_heads
+         field :n_layer
+         field :rope_theta
+         field :vocab_size
+         field :tie_word_embeddings, default: true
+       end
+
+       class Lille130mAttention < MLX::NN::Module
+         def initialize(args)
+           super()
+           @n_head = args.n_head
+           @n_kv_heads = args.n_kv_heads
+           @head_dim = args.n_embd / @n_head
+           @scale = @head_dim**(-0.5)
+
+           self.qkv_proj = MLX::NN::Linear.new(
+             args.n_embd,
+             (@n_head + (2 * @n_kv_heads)) * @head_dim,
+             bias: false
+           )
+           self.out_proj = MLX::NN::Linear.new(@n_head * @head_dim, args.n_embd, bias: false)
+           self.norm = MLX::NN::RMSNorm.new(args.n_embd, eps: args.layer_norm_eps)
+           self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           qkv = qkv_proj.call(norm.call(x))
+           q_size = @n_head * @head_dim
+           kv_size = @n_kv_heads * @head_dim
+           queries, keys, values = mx.split(qkv, [q_size, q_size + kv_size], 2)
+
+           queries = queries.reshape([b, l, @n_head, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_head * @head_dim])
+           out_proj.call(output)
+         end
+       end
+
+       class Lille130mMLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           hidden_dim = 256 * ((8 * args.n_embd / 3) / 256.0).round
+           hidden_dim = 256 if hidden_dim.zero?
+
+           self.norm = MLX::NN::RMSNorm.new(args.n_embd, eps: args.layer_norm_eps)
+           self.gate_proj = MLX::NN::Linear.new(args.n_embd, hidden_dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(args.n_embd, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, args.n_embd, bias: false)
+         end
+
+         def call(x)
+           h = norm.call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(h), up_proj.call(h)))
+         end
+       end
+
+       class Lille130Block < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.attention = Lille130mAttention.new(args)
+           self.feed_forward = Lille130mMLP.new(args)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           h = x + attention.call(x, mask: mask, cache: cache)
+           h + feed_forward.call(h)
+         end
+       end
+
+       class Lille130 < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.tok_embeddings = MLX::NN::Embedding.new(args.vocab_size, args.n_embd)
+           self.layers = Array.new(args.n_layer) { Lille130Block.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.n_embd, eps: args.layer_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = tok_embeddings.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           tok_embeddings.as_linear(norm.call(h))
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if h.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.args = args
+           self.model_type = args.model_type
+           self.transformer = Lille130.new(args)
+         end
+
+         def call(inputs, cache: nil)
+           transformer.call(inputs, cache: cache)
+         end
+
+         def layers
+           transformer.layers
+         end
+
+         def sanitize(weights)
+           weights.reject { |k, _| k.include?("rotary_emb") }
+         end
+       end
+
+       Models.register("lille-130m", Model, ModelArgs)
+     end
+   end
+ end
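Note on the hidden_dim line in Lille130mMLP above: it rounds 8/3 of the embedding width to the nearest multiple of 256, a common SwiGLU sizing rule. A worked example of the same arithmetic — the n_embd values are invented for illustration:

    # n_embd values below are examples, not taken from any shipped config.
    n_embd = 768
    raw = 8 * n_embd / 3                 # => 2048 (Ruby integer division)
    256 * (raw / 256.0).round            # => 2048, already a multiple of 256

    n_embd = 640
    raw = 8 * n_embd / 3                 # => 1706
    256 * (raw / 256.0).round            # => 1792 (6.66 blocks rounds up to 7)

The follow-up `hidden_dim = 256 if hidden_dim.zero?` guard only matters for very small n_embd, where the rounded block count would otherwise be zero.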