mlx-ruby-lm 0.30.7.1

Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0

data/lib/mlx_lm/models/qwen2_vl.rb
@@ -0,0 +1,48 @@
+module MlxLm
+  module Models
+    module Qwen2VL
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "qwen2_vl"
+        field :text_config
+
+        def self.from_dict(params)
+          has_text_config = params.key?("text_config") || params.key?(:text_config)
+          return super if has_text_config
+
+          new(model_type: params["model_type"] || params[:model_type], text_config: params)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.language_model = Qwen2::Model.new(Qwen2::ModelArgs.from_dict(args.text_config))
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          language_model.call(inputs, cache: cache)
+        end
+
+        def sanitize(weights)
+          sanitized = {}
+          weights.each do |key, value|
+            next if key == "visual" || key.start_with?("visual.")
+            next if key == "vision_tower" || key.start_with?("vision_tower.")
+
+            mapped_key = key.start_with?("language_model.") ? key : "language_model.#{key}"
+            sanitized[mapped_key] = value
+          end
+          sanitized
+        end
+
+        def layers
+          language_model.model.layers
+        end
+      end
+
+      Models.register("qwen2_vl", Model, ModelArgs)
+    end
+  end
+end
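
The `sanitize` hook above does two jobs: it drops any vision-tower weights (this wrapper only runs the text stack) and prefixes the remaining keys so they resolve under the wrapper's `language_model` attribute. A minimal sketch of the effect, with hypothetical checkpoint keys and an already-instantiated `Qwen2VL::Model`:

    model.sanitize(
      "visual.patch_embed.proj.weight" => :vision,  # skipped: vision weights
      "model.embed_tokens.weight"      => :embed,   # bare key: gets prefixed
      "language_model.lm_head.weight"  => :head     # already prefixed: unchanged
    )
    # => { "language_model.model.embed_tokens.weight" => :embed,
    #      "language_model.lm_head.weight"            => :head }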

data/lib/mlx_lm/models/qwen3.rb
@@ -0,0 +1,167 @@
+module MlxLm
+  module Models
+    module Qwen3
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "qwen3"
+        field :hidden_size, default: 2048
+        field :num_hidden_layers, default: 24
+        field :intermediate_size, default: 11008
+        field :num_attention_heads, default: 16
+        field :rms_norm_eps, default: 1e-6
+        field :vocab_size, default: 151936
+        field :num_key_value_heads, default: nil
+        field :max_position_embeddings, default: 32768
+        field :rope_theta, default: 1_000_000.0
+        field :head_dim, default: nil
+        field :tie_word_embeddings, default: true
+        field :rope_scaling, default: nil
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @hidden_size / @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+          self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+          self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+          self.rope = MlxLm::Models.initialize_rope(
+            @head_dim,
+            args.rope_theta,
+            false,
+            args.rope_scaling,
+            max_position_embeddings: args.max_position_embeddings
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x)
+          keys = k_proj.call(x)
+          values = v_proj.call(x)
+
+          queries = q_norm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3])
+          keys = k_norm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3])
+          values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(dim, hidden_dim)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
+      class Qwen3Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          h = input_embeddings || embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.model = Qwen3Model.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          out = model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+          if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def sanitize(weights)
+          result = weights.dup
+          result.delete("lm_head.weight") if @args.tie_word_embeddings
+          result
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("qwen3", Model, ModelArgs)
+    end
+  end
+end
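
`ModelArgs#initialize` back-fills the grouped-query attention settings: a config that omits `num_key_value_heads` falls back to full multi-head attention, and `head_dim` defaults to `hidden_size / num_attention_heads`. A quick sketch, assuming `field` generates plain attribute readers:

    args = MlxLm::Models::Qwen3::ModelArgs.new(
      hidden_size: 2048,
      num_attention_heads: 16,
      num_key_value_heads: 8   # two query heads share each KV head
    )
    args.head_dim             # => 128 (2048 / 16)
    args.num_key_value_heads  # => 8; defaults to 16 when omitted

In `Attention#call` this shows up as queries of shape `[b, 16, l, 128]` attending over keys and values of shape `[b, 8, l, 128]`; MLX's scaled dot-product attention handles the grouped KV heads directly.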

data/lib/mlx_lm/models/qwen3_5.rb
@@ -0,0 +1,69 @@
+module MlxLm
+  module Models
+    module Qwen35
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "qwen3_5"
+        field :text_config, default: nil
+
+        def self.from_dict(params)
+          has_text_config = params.key?("text_config") || params.key?(:text_config)
+          return super if has_text_config
+
+          new(model_type: params["model_type"] || params[:model_type], text_config: params)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.language_model = Qwen3::Model.new(Qwen3::ModelArgs.from_dict(_text_config_for_qwen3(args)))
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          language_model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+        end
+
+        def sanitize(weights)
+          language_model.sanitize(remap_language_model_weights(weights))
+        end
+
+        def layers
+          language_model.layers
+        end
+
+        protected
+
+        def remap_language_model_weights(weights)
+          remapped = {}
+          weights.each do |key, value|
+            next if key.start_with?("model.visual")
+
+            mapped_key = if key.start_with?("model.language_model")
+              key.sub("model.language_model", "language_model.model")
+            elsif key.start_with?("language_model.")
+              key
+            else
+              "language_model.#{key}"
+            end
+            remapped[mapped_key] = value
+          end
+          remapped
+        end
+
+        private
+
+        def _text_config_for_qwen3(args)
+          config = {}
+          (args.text_config || {}).each { |key, value| config[key.to_s] = value }
+          config["model_type"] ||= args.model_type
+          config["tie_word_embeddings"] = false unless config.key?("tie_word_embeddings")
+          config
+        end
+      end
+
+      Models.register("qwen3_5", Model, ModelArgs)
+    end
+  end
+end
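
`remap_language_model_weights` normalizes three checkpoint layouts in one pass: `model.language_model.*` keys (the Hugging Face multimodal layout) have their prefix swapped to `language_model.model.*`, keys already under `language_model.` pass through, and bare keys get prefixed; `model.visual*` weights are dropped. The method is `protected`, so a toy round-trip (hypothetical keys) needs `send`:

    model.send(:remap_language_model_weights,
      "model.language_model.embed_tokens.weight" => :a,  # prefix swapped
      "lm_head.weight"                           => :b,  # bare: prefixed
      "model.visual.patch_embed.weight"          => :c   # dropped
    ).keys
    # => ["language_model.model.embed_tokens.weight", "language_model.lm_head.weight"]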

data/lib/mlx_lm/models/qwen3_5_moe.rb
@@ -0,0 +1,54 @@
+module MlxLm
+  module Models
+    module Qwen35Moe
+      class ModelArgs < Qwen35::ModelArgs
+        field :model_type, default: "qwen3_5_moe"
+      end
+
+      class Model < Qwen35::Model
+        def sanitize(weights)
+          remapped = remap_language_model_weights(weights)
+          rewrite_moe_expert_weights(remapped)
+          language_model.sanitize(remapped)
+        end
+
+        private
+
+        def rewrite_moe_expert_weights(weights)
+          mx = MLX::Core
+
+          layers.length.times do |layer_idx|
+            prefix = "language_model.model.layers.#{layer_idx}.mlp"
+            gate_up_key = _first_existing_key(
+              weights,
+              ["#{prefix}.experts.gate_up_proj", "#{prefix}.experts.gate_up_proj.weight"]
+            )
+            down_proj_key = _first_existing_key(
+              weights,
+              ["#{prefix}.experts.down_proj", "#{prefix}.experts.down_proj.weight"]
+            )
+
+            next unless gate_up_key && down_proj_key
+
+            gate_up = weights.delete(gate_up_key)
+            down_proj = weights.delete(down_proj_key)
+            mid = gate_up.shape[-2] / 2
+            gate_proj, up_proj = mx.split(gate_up, [mid], -2)
+
+            weights["#{prefix}.switch_mlp.gate_proj.weight"] = gate_proj
+            weights["#{prefix}.switch_mlp.up_proj.weight"] = up_proj
+            weights["#{prefix}.switch_mlp.down_proj.weight"] = down_proj
+          end
+
+          weights
+        end
+
+        def _first_existing_key(weights, candidates)
+          candidates.find { |key| weights.key?(key) }
+        end
+      end
+
+      Models.register("qwen3_5_moe", Model, ModelArgs)
+    end
+  end
+end
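
These checkpoints store each layer's experts as fused tensors, with the gate and up projections concatenated along the second-to-last axis; `rewrite_moe_expert_weights` halves that axis with `mx.split` and renames the pieces to the `switch_mlp.*` keys the runtime expects. A shape-level sketch, assuming a fused layout of `[num_experts, 2 * moe_intermediate_size, hidden_size]` and an `mx.zeros` constructor mirroring the Python MLX API:

    mx = MLX::Core
    gate_up = mx.zeros([4, 2 * 8, 16])    # 4 experts, intermediate 8, hidden 16
    mid = gate_up.shape[-2] / 2           # => 8
    gate_proj, up_proj = mx.split(gate_up, [mid], -2)
    gate_proj.shape                       # => [4, 8, 16]
    up_proj.shape                         # => [4, 8, 16]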

data/lib/mlx_lm/models/qwen3_moe.rb
@@ -0,0 +1,166 @@
+require_relative "qwen3"
+require_relative "switch_layers"
+
+module MlxLm
+  module Models
+    module Qwen3Moe
+      class ModelArgs < Qwen3::ModelArgs
+        field :model_type, default: "qwen3_moe"
+        field :num_experts, default: 128
+        field :num_experts_per_tok, default: 8
+        field :decoder_sparse_step, default: 1
+        field :mlp_only_layers, default: []
+        field :moe_intermediate_size, default: 1408
+        field :norm_topk_prob, default: false
+
+        def initialize(**kwargs)
+          super
+          @mlp_only_layers ||= []
+        end
+      end
+
+      class SparseMoeBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          @top_k = [args.num_experts_per_tok.to_i, 1].max
+          @num_experts = args.num_experts
+          @norm_topk_prob = args.norm_topk_prob
+
+          dim = args.hidden_size
+          hidden_dim = args.moe_intermediate_size
+
+          self.gate = MLX::NN::Linear.new(dim, @num_experts, bias: false)
+          self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, hidden_dim, @num_experts)
+        end
+
+        def call(x)
+          mx = MLX::Core
+
+          gates = gate.call(x)
+          gates = mx.softmax(gates.astype(mx.float32), -1).astype(gates.dtype)
+
+          k = [@top_k, @num_experts].min
+          inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1))
+          take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+          inds = mx.take(inds, take_ids, -1)
+          scores = mx.take_along_axis(gates, inds, -1)
+
+          if @norm_topk_prob
+            denom = mx.expand_dims(mx.sum(scores, -1), -1)
+            scores = scores / denom
+          end
+
+          y = switch_mlp.call(x, inds)
+          mx.sum(y * mx.expand_dims(scores, -1), -2)
+        end
+      end
+
+      class DecoderLayer < MLX::NN::Module
+        def initialize(args, layer_idx)
+          super()
+          self.self_attn = Qwen3::Attention.new(args)
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+
+          if _use_sparse_moe_layer?(args, layer_idx)
+            self.mlp = SparseMoeBlock.new(args)
+          else
+            self.mlp = Qwen3::MLP.new(args.hidden_size, args.intermediate_size)
+          end
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+
+        private
+
+        def _use_sparse_moe_layer?(args, layer_idx)
+          sparse_step = [args.decoder_sparse_step.to_i, 1].max
+          mlp_only_layers = args.mlp_only_layers || []
+
+          !mlp_only_layers.include?(layer_idx) &&
+            args.num_experts.to_i > 0 &&
+            ((layer_idx + 1) % sparse_step).zero?
+        end
+      end
+
+      class Qwen3MoeModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { |layer_idx| DecoderLayer.new(args, layer_idx) }
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          h = input_embeddings || embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, layer_idx|
+            h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.model = Qwen3MoeModel.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          out = model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+          if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def sanitize(weights)
+          mx = MLX::Core
+
+          result = weights.dup
+          result.delete("lm_head.weight") if @args.tie_word_embeddings
+          return result unless result.key?("model.layers.0.mlp.experts.0.up_proj.weight")
+
+          @args.num_hidden_layers.times do |layer_idx|
+            prefix = "model.layers.#{layer_idx}.mlp"
+            %w[up_proj down_proj gate_proj].each do |projection|
+              expert_keys = (0...@args.num_experts).map do |expert_idx|
+                "#{prefix}.experts.#{expert_idx}.#{projection}.weight"
+              end
+              next unless expert_keys.all? { |key| result.key?(key) }
+
+              stacked = expert_keys.map { |key| result.delete(key) }
+              result["#{prefix}.switch_mlp.#{projection}.weight"] = mx.stack(stacked)
+            end
+          end
+
+          result
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("qwen3_moe", Model, ModelArgs)
+    end
+  end
+end
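
`SparseMoeBlock#call` is standard top-k softmax routing: softmax over all experts, `argpartition` on the negated gates to pull the k best indices without a full sort, `take_along_axis` to gather their probabilities, and optional renormalization when `norm_topk_prob` is set. The same selection in plain Ruby, as a reference for what the MLX calls compute per token (illustrative numbers):

    gates = [0.02, 0.30, 0.05, 0.25, 0.08, 0.10, 0.15, 0.05]  # softmax scores, 8 experts
    k = 2

    inds   = gates.each_index.max_by(k) { |i| gates[i] }      # => [1, 3]
    scores = inds.map { |i| gates[i] }                        # => [0.30, 0.25]
    scores = scores.map { |s| s / scores.sum }                # => [0.545..., 0.454...] with norm_topk_prob

The routed output is then the score-weighted sum of the selected experts' `SwitchGLU` outputs.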

data/lib/mlx_lm/models/qwen3_next.rb
@@ -0,0 +1,147 @@
+require_relative "kimi_linear"
+
+module MlxLm
+  module Models
+    module Qwen3Next
+      class ModelArgs < KimiLinear::ModelArgs
+        field :model_type, default: "qwen3_next"
+        field :linear_num_value_heads, default: nil
+        field :linear_num_key_heads, default: nil
+        field :linear_key_head_dim, default: nil
+        field :linear_value_head_dim, default: nil
+        field :linear_conv_kernel_dim, default: nil
+        field :decoder_sparse_step, default: nil
+        field :shared_expert_intermediate_size, default: nil
+        field :mlp_only_layers, default: []
+        field :full_attention_interval, default: 4
+        field :head_dim, default: nil
+        field :attention_bias, default: false
+        field :num_shared_experts, default: 1
+        field :norm_topk_prob, default: false
+        field :first_k_dense_replace, default: 0
+
+        def self.from_dict(params)
+          normalized = params.each_with_object({}) do |(key, value), out|
+            out[key.to_s] = value
+          end
+
+          {
+            "shared_expert_intermediate_size" => "moe_shared_expert_intermediate_size",
+          }.each do |source_key, target_key|
+            next unless normalized.key?(source_key)
+
+            normalized[target_key] = normalized[source_key] unless normalized.key?(target_key)
+          end
+
+          if normalized.key?("attention_bias")
+            normalized["use_bias"] = normalized["attention_bias"] unless normalized.key?("use_bias")
+            normalized["use_qkv_bias"] = normalized["attention_bias"] unless normalized.key?("use_qkv_bias")
+          end
+
+          if normalized.key?("linear_num_key_heads") && !normalized.key?("num_key_value_heads")
+            normalized["num_key_value_heads"] = normalized["linear_num_key_heads"]
+          end
+
+          if normalized.key?("mlp_only_layers") && !normalized.key?("first_k_dense_replace")
+            normalized["first_k_dense_replace"] = _dense_prefix_length(normalized["mlp_only_layers"])
+          end
+
+          normalized["num_shared_experts"] = 1 unless normalized.key?("num_shared_experts")
+          normalized["norm_topk_prob"] = false unless normalized.key?("norm_topk_prob")
+          normalized["first_k_dense_replace"] = 0 unless normalized.key?("first_k_dense_replace")
+          normalized["model_type"] ||= "qwen3_next"
+          super(normalized)
+        end
+
+        def initialize(**kwargs)
+          super
+          @moe_shared_expert_intermediate_size = @shared_expert_intermediate_size if kwargs.key?(:shared_expert_intermediate_size) && !kwargs.key?(:moe_shared_expert_intermediate_size) && !@shared_expert_intermediate_size.nil?
+
+          if kwargs.key?(:attention_bias) && !@attention_bias.nil?
+            @use_bias = @attention_bias unless kwargs.key?(:use_bias)
+            @use_qkv_bias = @attention_bias unless kwargs.key?(:use_qkv_bias)
+          end
+
+          if kwargs.key?(:mlp_only_layers) && !kwargs.key?(:first_k_dense_replace)
+            @first_k_dense_replace = self.class._dense_prefix_length(@mlp_only_layers)
+          end
+
+          @num_shared_experts = 1 if @num_shared_experts.nil?
+          @norm_topk_prob = false if @norm_topk_prob.nil?
+          @first_k_dense_replace = 0 if @first_k_dense_replace.nil?
+          @num_key_value_heads ||= @num_attention_heads
+        end
+
+        def to_kimi_linear_dict
+          dict = to_bailing_moe_linear_dict
+          dict["model_type"] = @model_type
+          dict["num_shared_experts"] = @num_shared_experts || 1
+          dict["norm_topk_prob"] = @norm_topk_prob.nil? ? false : @norm_topk_prob
+          dict["first_k_dense_replace"] = @first_k_dense_replace || 0
+          dict["use_bias"] = @use_bias
+          dict["use_qkv_bias"] = @use_qkv_bias
+          dict["moe_shared_expert_intermediate_size"] = @moe_shared_expert_intermediate_size unless @moe_shared_expert_intermediate_size.nil?
+          dict
+        end
+
+        def self._dense_prefix_length(mlp_only_layers)
+          layers = Array(mlp_only_layers).map(&:to_i)
+          count = 0
+          count += 1 while layers.include?(count)
+          count
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.wrapped_model = KimiLinear::Model.new(
+            KimiLinear::ModelArgs.from_dict(args.to_kimi_linear_dict)
+          )
+        end
+
+        def call(inputs, cache: nil)
+          wrapped_model.call(inputs, cache: cache)
+        end
+
+        def sanitize(weights)
+          remapped = {}
+          flat_weights = weights.is_a?(Hash) ? weights : weights.to_h
+          flat_weights.each do |key, value|
+            mapped = key.to_s.gsub(".mlp.shared_expert.", ".mlp.shared_experts.")
+            next if mapped.include?(".mtp.")
+
+            remapped[mapped] = value
+          end
+          wrapped_model.sanitize(remapped)
+        end
+
+        def layers
+          wrapped_model.layers
+        end
+
+        def make_cache
+          return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache)
+
+          nil
+        end
+
+        def cast_predicate
+          return wrapped_model.cast_predicate if wrapped_model.respond_to?(:cast_predicate)
+
+          lambda { |_key| true }
+        end
+
+        def quant_predicate
+          return wrapped_model.quant_predicate if wrapped_model.respond_to?(:quant_predicate)
+
+          lambda { |_key, _value| true }
+        end
+      end
+
+      Models.register("qwen3_next", Model, ModelArgs)
+    end
+  end
+end
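
`_dense_prefix_length` translates Qwen3-Next's `mlp_only_layers` list into KimiLinear's `first_k_dense_replace` convention, which can only express a contiguous run of dense layers at the start of the stack; any gap stops the count:

    args = MlxLm::Models::Qwen3Next::ModelArgs
    args._dense_prefix_length([0, 1, 2])  # => 3 (layers 0-2 dense)
    args._dense_prefix_length([0, 2, 3])  # => 1 (gap at layer 1 stops the scan)
    args._dense_prefix_length([5])        # => 0 (no dense prefix)
    args._dense_prefix_length(nil)        # => 0 (Array(nil) => [])

Note that dense layers listed after a gap are effectively promoted to MoE layers by this translation, which matters if a config ever ships a non-contiguous `mlp_only_layers`.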

data/lib/mlx_lm/models/qwen3_vl.rb
@@ -0,0 +1,48 @@
+module MlxLm
+  module Models
+    module Qwen3VL
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "qwen3_vl"
+        field :text_config, default: nil
+
+        def self.from_dict(params)
+          return super if params.key?("text_config") || params.key?(:text_config)
+
+          new(model_type: params["model_type"] || params[:model_type], text_config: params)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.language_model = Qwen3::Model.new(Qwen3::ModelArgs.from_dict(args.text_config))
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          language_model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+        end
+
+        def sanitize(weights)
+          nested = MLX::Utils.tree_unflatten(weights.to_a)
+          nested.delete("vision_tower") if nested.is_a?(Hash)
+
+          flattened = MLX::Utils.tree_flatten(nested, destination: {})
+          sanitized = {}
+          flattened.each do |key, value|
+            sanitized_key = key.start_with?("language_model.") ? key : "language_model.#{key}"
+            sanitized[sanitized_key] = value
+          end
+          sanitized
+        end
+
+        def layers
+          language_model.layers
+        end
+      end
+
+      Models.register("qwen3_vl", Model, ModelArgs)
+    end
+  end
+end
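
Unlike the Qwen2-VL sanitizer earlier in this diff, this one prunes vision weights structurally: the flat key/value pairs are unflattened into a nested hash, the entire `vision_tower` subtree is deleted in one step, and the remainder is re-flattened and prefixed. A sketch with hypothetical keys, assuming `MLX::Utils.tree_unflatten`/`tree_flatten` round-trip dotted keys like their Python MLX counterparts:

    model.sanitize(
      "vision_tower.patch_embed.weight" => :vision,  # removed as one subtree
      "model.embed_tokens.weight"       => :embed
    )
    # => { "language_model.model.embed_tokens.weight" => :embed }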