mlx-ruby-lm 0.30.7.1

Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
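
Two of the new model files are reproduced in full below: data/lib/mlx_lm/models/klear.rb and data/lib/mlx_lm/models/lfm2.rb. The gem also ships an executable (data/exe/mlx_lm) and a library entry point (data/lib/mlx_lm.rb); their public API is not visible in this excerpt, but if the Ruby port mirrors the Python mlx-lm entry points, usage might look roughly like the sketch below. The MlxLm.load and MlxLm.generate names and signatures are assumptions for illustration, not confirmed by this diff.

  # Hypothetical usage sketch. MlxLm.load / MlxLm.generate are assumed to
  # mirror Python mlx-lm's load/generate; this diff does not show the API.
  require "mlx_lm"

  model, tokenizer = MlxLm.load("some/model-path")                        # assumed signature
  puts MlxLm.generate(model, tokenizer, prompt: "Hello", max_tokens: 64)  # assumed signature
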
data/lib/mlx_lm/models/klear.rb
@@ -0,0 +1,283 @@
+ require_relative "activations"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module Klear
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "Klear"
+         field :hidden_size
+         field :num_hidden_layers
+         field :intermediate_size
+         field :num_attention_heads
+         field :attention_bias
+         field :mlp_only_layers
+         field :num_experts
+         field :num_experts_per_tok
+         field :decoder_sparse_step
+         field :n_shared_experts
+         field :moe_intermediate_size
+         field :rms_norm_eps
+         field :vocab_size
+         field :num_key_value_heads
+         field :rope_theta
+         field :max_position_embeddings
+         field :norm_topk_prob
+
+         def initialize(**kwargs)
+           super
+           @mlp_only_layers ||= []
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class KlearAttention < MLX::NN::Module
+         def initialize(args)
+           super()
+           @num_attention_heads = args.num_attention_heads
+           @num_key_value_heads = args.num_key_value_heads
+           @head_dim = args.hidden_size / args.num_attention_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(
+             args.hidden_size,
+             @num_attention_heads * @head_dim,
+             bias: args.attention_bias
+           )
+           self.k_proj = MLX::NN::Linear.new(
+             args.hidden_size,
+             @num_key_value_heads * @head_dim,
+             bias: args.attention_bias
+           )
+           self.v_proj = MLX::NN::Linear.new(
+             args.hidden_size,
+             @num_key_value_heads * @head_dim,
+             bias: args.attention_bias
+           )
+           self.o_proj = MLX::NN::Linear.new(
+             @num_attention_heads * @head_dim,
+             args.hidden_size,
+             bias: args.attention_bias
+           )
+
+           self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+           self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: args.rope_theta)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = q_norm.call(queries.reshape([b, l, @num_attention_heads, @head_dim])).transpose([0, 2, 1, 3])
+           keys = k_norm.call(keys.reshape([b, l, @num_key_value_heads, @head_dim])).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class KlearMLP < MLX::NN::Module
+         def initialize(dim, hidden_dim)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class KlearSparseMoeBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           @norm_topk_prob = args.norm_topk_prob
+           @num_experts = args.num_experts
+           @top_k = [args.num_experts_per_tok.to_i, 1].max
+
+           self.gate = MLX::NN::Linear.new(args.hidden_size, @num_experts, bias: false)
+           self.experts = SwitchLayers::SwitchGLU.new(
+             args.hidden_size,
+             args.moe_intermediate_size,
+             @num_experts
+           )
+           self.shared_experts = KlearMLP.new(
+             args.hidden_size,
+             args.moe_intermediate_size * args.n_shared_experts
+           )
+           self.coefficient = MLX::NN::Linear.new(args.hidden_size, 2)
+
+           mx = MLX::Core
+           self.expert_bias = mx.zeros([@num_experts]).astype(mx.float32)
+         end
+
+         def call(x)
+           mx = MLX::Core
+
+           routing_weights = mx.sigmoid(gate.call(x).astype(mx.float32))
+           biased_weights = routing_weights + expert_bias.reshape([1, 1, @num_experts])
+
+           k = [@top_k, @num_experts].min
+           inds = mx.argpartition(biased_weights * -1.0, k - 1, -1)
+           take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+
+           scores = mx.take_along_axis(routing_weights, inds, -1)
+           if @norm_topk_prob
+             denom = mx.expand_dims(mx.sum(scores, -1), -1)
+             scores = scores / denom
+           end
+
+           scores = scores.astype(x.dtype)
+           expert_out = experts.call(x, inds)
+           y_experts = mx.sum(expert_out * mx.expand_dims(scores, -1), -2)
+
+           coef = mx.softmax(coefficient.call(x).astype(mx.float32), -1).astype(x.dtype)
+           coef_expert, coef_shared = mx.split(coef, [1], -1)
+           shared = shared_experts.call(x)
+
+           y_experts * coef_expert + shared * coef_shared
+         end
+       end
+
+       class KlearDecoderLayer < MLX::NN::Module
+         def initialize(args, layer_idx:)
+           super()
+           self.self_attn = KlearAttention.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+
+           if _use_sparse_moe_layer?(args, layer_idx)
+             self.mlp = KlearSparseMoeBlock.new(args)
+           else
+             self.mlp = KlearMLP.new(args.hidden_size, args.intermediate_size)
+           end
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+
+         private
+
+         def _use_sparse_moe_layer?(args, layer_idx)
+           sparse_step = [args.decoder_sparse_step.to_i, 1].max
+           mlp_only_layers = args.mlp_only_layers || []
+
+           !mlp_only_layers.include?(layer_idx) &&
+             args.num_experts.to_i > 0 &&
+             ((layer_idx + 1) % sparse_step).zero?
+         end
+       end
+
+       class KlearModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) do |layer_idx|
+             KlearDecoderLayer.new(args, layer_idx: layer_idx)
+           end
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, layer_idx|
+             h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           n = h.shape[1]
+           return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
+           return nil if n == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = KlearModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil)
+           lm_head.call(model.call(inputs, cache: cache))
+         end
+
+         def sanitize(weights)
+           return weights unless weights.key?("model.layers.0.mlp.experts.0.gate_proj.weight")
+
+           mx = MLX::Core
+           result = weights.dup
+
+           @args.num_hidden_layers.times do |layer_idx|
+             prefix = "model.layers.#{layer_idx}.mlp.experts"
+             %w[gate_proj up_proj down_proj].each do |name|
+               expert_keys = (0...@args.num_experts).map do |expert_idx|
+                 "#{prefix}.#{expert_idx}.#{name}.weight"
+               end
+               next unless expert_keys.all? { |key| result.key?(key) }
+
+               stacked = expert_keys.map { |key| result.delete(key) }
+               result["#{prefix}.#{name}.weight"] = mx.stack(stacked)
+             end
+           end
+
+           result
+         end
+
+         def layers
+           model.layers
+         end
+
+         def quant_predicate
+           lambda do |path, _module|
+             if path.to_s.end_with?("mlp.gate")
+               { "group_size" => 64, "bits" => 8 }
+             else
+               true
+             end
+           end
+         end
+
+         def cast_predicate
+           lambda { |key| !key.to_s.include?("expert_bias") }
+         end
+       end
+
+       Models.register("Klear", Model, ModelArgs)
+     end
+   end
+ end
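
KlearSparseMoeBlock routes with sigmoid gate scores plus a loaded expert_bias, in the style of DeepSeek-V3's auxiliary-loss-free balancing: the bias only influences which k experts are selected, while the unbiased sigmoid scores (renormalized when norm_topk_prob is set) weight their outputs. A plain-Ruby sketch of that selection for a single token, with no MLX dependency (the toy logits and bias values are illustrative only):

  # Sketch of the top-k routing in KlearSparseMoeBlock, one token, plain Ruby:
  # the bias steers *selection*, the unbiased sigmoid scores do the *weighting*.
  logits      = [0.2, -1.0, 1.5, 0.3]   # gate.call(x) for one token (toy values)
  expert_bias = [0.0,  2.0, 0.0, 0.0]   # loaded balancing bias (toy values)
  k = 2

  scores = logits.map { |v| 1.0 / (1.0 + Math.exp(-v)) }   # sigmoid routing weights
  biased = scores.zip(expert_bias).map { |s, b| s + b }

  # Pick the k largest *biased* scores (argpartition on -biased in the model).
  top = biased.each_with_index.max_by(k) { |v, _i| v }.map(&:last)   # => [1, 2]

  # Weight outputs by the *unbiased* scores, renormalized over the chosen k
  # (the norm_topk_prob branch).
  picked  = top.map { |i| scores[i] }
  denom   = picked.sum
  weights = picked.map { |s| s / denom }
  p top: top, weights: weights

Note how the bias pulls expert 1 into the top-k even though its raw score is the lowest, yet its output weight stays proportional to the raw score. Separately, the file's sanitize consolidates checkpoints that store one tensor per expert, stacking model.layers.N.mlp.experts.E.{gate,up,down}_proj.weight into the single stacked tensors the SwitchGLU experts consume.
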
data/lib/mlx_lm/models/lfm2.rb
@@ -0,0 +1,120 @@
+ require_relative "qwen3"
+
+ module MlxLm
+   module Models
+     module Lfm2
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "lfm2"
+         field :vocab_size, default: 32000
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 32
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: nil
+         field :max_position_embeddings, default: 2048
+         field :norm_eps, default: 1e-6
+         field :conv_bias, default: false
+         field :conv_L_cache, default: 4
+         field :block_dim, default: nil
+         field :block_ff_dim, default: nil
+         field :block_multiple_of, default: 256
+         field :block_ffn_dim_multiplier, default: nil
+         field :block_auto_adjust_ff_dim, default: false
+         field :rope_theta, default: 1_000_000.0
+         field :rope_parameters, default: nil
+         field :full_attn_idxs, default: nil
+         field :layer_types, default: nil
+         field :tie_word_embeddings, default: true
+
+         def initialize(**kwargs)
+           super
+           rope_theta_from_params = _rope_theta_from_parameters
+           @rope_theta = rope_theta_from_params unless rope_theta_from_params.nil?
+           @num_key_value_heads ||= @num_attention_heads
+           @block_dim ||= @hidden_size
+           @block_ff_dim ||= @block_dim * 4
+           @full_attn_idxs ||= _full_attn_idxs_from_layer_types
+         end
+
+         private
+
+         def _rope_theta_from_parameters
+           return nil unless @rope_parameters.is_a?(Hash)
+
+           @rope_parameters["rope_theta"] || @rope_parameters[:rope_theta]
+         end
+
+         def _full_attn_idxs_from_layer_types
+           return [] unless @layer_types.is_a?(Array)
+
+           @layer_types.each_with_index.filter_map do |layer_type, i|
+             i if layer_type.to_s == "full_attention"
+           end
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.language_model = Qwen3::Model.new(Qwen3::ModelArgs.from_dict(_qwen3_config(args)))
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           language_model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+         end
+
+         def sanitize(weights)
+           sanitized = {}
+           weights.each do |name, param|
+             current = param
+             if name.include?("conv.weight") && _transpose_conv_weight?(param)
+               current = MLX::Core.swapaxes(param, 1, 2)
+             end
+             sanitized[name] = current
+           end
+           sanitized
+         end
+
+         def layers
+           language_model.layers
+         end
+
+         def make_cache
+           return language_model.make_cache if language_model.respond_to?(:make_cache)
+           return nil unless defined?(MlxLm::KVCache)
+
+           Array.new(layers.length) { MlxLm::KVCache.new }
+         end
+
+         private
+
+         def _transpose_conv_weight?(param)
+           return false unless param.respond_to?(:shape)
+           return false unless param.shape.is_a?(Array)
+           return false unless param.shape.length >= 3
+
+           param.shape[-1] > param.shape[1]
+         end
+
+         def _qwen3_config(args)
+           {
+             "model_type" => "qwen3",
+             "hidden_size" => args.hidden_size,
+             "num_hidden_layers" => args.num_hidden_layers,
+             "intermediate_size" => args.block_ff_dim,
+             "num_attention_heads" => args.num_attention_heads,
+             "num_key_value_heads" => args.num_key_value_heads,
+             "rms_norm_eps" => args.norm_eps,
+             "vocab_size" => args.vocab_size,
+             "rope_theta" => args.rope_theta,
+             "max_position_embeddings" => args.max_position_embeddings,
+             "tie_word_embeddings" => args.tie_word_embeddings,
+           }
+         end
+       end
+
+       Models.register("lfm2", Model, ModelArgs)
+     end
+   end
+ end
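
lfm2.rb is thinner than klear.rb: ModelArgs normalizes the LFM2 checkpoint config (rope_theta pulled out of rope_parameters, full_attn_idxs derived from layer_types when absent), and Model delegates to the Qwen3 implementation through a translated config hash. The layer-type derivation is easy to check standalone; this plain-Ruby snippet mirrors _full_attn_idxs_from_layer_types with toy input:

  # Mirrors _full_attn_idxs_from_layer_types: collect the indices of the
  # full-attention layers from an LFM2-style layer_types list.
  layer_types = %w[conv conv full_attention conv full_attention]

  full_attn_idxs = layer_types.each_with_index.filter_map do |layer_type, i|
    i if layer_type == "full_attention"
  end

  p full_attn_idxs  # => [2, 4]
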