mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/plamo2.rb
@@ -0,0 +1,173 @@
+ require_relative "falcon_h1"
+
+ module MlxLm
+   module Models
+     module Plamo2
+       class ModelArgs < FalconH1::ModelArgs
+         field :model_type, default: "plamo2"
+         field :rope_theta, default: 10_000.0
+         field :tie_word_embeddings, default: true
+         field :hidden_size_per_head, default: nil
+         field :full_attention_idx, default: nil
+         field :mamba_d_state, default: nil
+         field :mamba_num_heads, default: nil
+         field :mamba_step, default: 2
+         field :mamba_chunk_size, default: nil
+         field :mamba_enabled, default: true
+
+         def initialize(**kwargs)
+           super
+           @head_dim = @hidden_size_per_head if kwargs.key?(:hidden_size_per_head) && !kwargs.key?(:head_dim) && !@hidden_size_per_head.nil?
+           @num_attention_heads ||= @mamba_num_heads
+           @num_key_value_heads ||= @num_attention_heads
+           @mamba_d_conv ||= 4
+           @attention_window_size ||= @max_position_embeddings
+           @block_types ||= _to_block_types
+         end
+
+         def to_falcon_h1_dict
+           hidden_size = @hidden_size
+           attention_heads = @num_attention_heads
+           inferred_head_dim = if !@head_dim.nil?
+             @head_dim
+           elsif !@hidden_size_per_head.nil?
+             @hidden_size_per_head
+           elsif !hidden_size.nil? && attention_heads.to_i > 0
+             hidden_size / attention_heads
+           else
+             64
+           end
+
+           {
+             "model_type" => @model_type,
+             "attention_bias" => @attention_bias,
+             "head_dim" => inferred_head_dim,
+             "hidden_size" => hidden_size,
+             "intermediate_size" => @intermediate_size,
+             "max_position_embeddings" => @max_position_embeddings,
+             "mamba_d_conv" => @mamba_d_conv,
+             "num_attention_heads" => attention_heads,
+             "num_hidden_layers" => @num_hidden_layers,
+             "num_key_value_heads" => @num_key_value_heads,
+             "rms_norm_eps" => @rms_norm_eps,
+             "rope_theta" => @rope_theta,
+             "vocab_size" => @vocab_size,
+             "tie_word_embeddings" => @tie_word_embeddings,
+             "attention_window_size" => @attention_window_size,
+             "block_types" => @block_types,
+           }
+         end
+
+         private
+
+         def _to_block_types
+           return @block_types if @block_types.is_a?(Array) && !@block_types.empty?
+
+           count = @num_hidden_layers.to_i
+           return nil if count <= 0
+
+           if @full_attention_idx.is_a?(Array) && !@full_attention_idx.empty?
+             full_attention = @full_attention_idx.map(&:to_i)
+             return Array.new(count) { |i| full_attention.include?(i) ? "attention" : "recurrent" }
+           end
+
+           return Array.new(count, "attention") unless @mamba_enabled
+
+           step = @mamba_step.to_i
+           step = 2 if step <= 1
+           midpoint = step / 2
+
+           if count <= midpoint
+             return Array.new(count) { |i| i == count - 1 ? "attention" : "recurrent" }
+           end
+
+           Array.new(count) { |i| (i % step) == midpoint ? "attention" : "recurrent" }
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.wrapped_model = FalconH1::Model.new(
+             FalconH1::ModelArgs.from_dict(args.to_falcon_h1_dict)
+           )
+         end
+
+         def call(inputs, cache: nil)
+           wrapped_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           normalized = weights.is_a?(Hash) ? weights.dup : weights.to_h
+           _split_gate_up_proj!(normalized)
+
+           remapped = {}
+           normalized.each do |key, value|
+             remapped[_remap_weight_key(key)] = value
+           end
+
+           wrapped_model.sanitize(remapped)
+         end
+
+         def layers
+           wrapped_model.layers
+         end
+
+         def make_cache
+           return nil unless wrapped_model.respond_to?(:make_cache)
+
+           wrapped_model.make_cache
+         end
+
+         private
+
+         def _split_gate_up_proj!(weights)
+           mx = MLX::Core
+           pattern = /\A(model\.layers(?:\.layers)?\.\d+\.mlp)\.gate_up_proj\.(weight|bias|scales|biases)\z/
+
+           weights.keys.each do |key|
+             match = pattern.match(key)
+             next unless match
+
+             prefix = match[1]
+             param = match[2]
+             gate_up = weights.delete(key)
+             mid = gate_up.shape[0] / 2
+             next if mid <= 0
+
+             gate_proj, up_proj = mx.split(gate_up, [mid], 0)
+             weights["#{prefix}.gate_proj.#{param}"] = gate_proj
+             weights["#{prefix}.up_proj.#{param}"] = up_proj
+           end
+         end
+
+         def _remap_weight_key(key)
+           mapped = key.dup
+           mapped = mapped.gsub("model.layers.layers.", "model.layers.")
+           mapped = mapped.gsub("model.norm.", "model.final_layernorm.")
+
+           mapped = mapped.gsub(/\.layers\.(\d+)\.pre_mixer_norm\./) { ".layers.#{$1}.input_layernorm." }
+           mapped = mapped.gsub(/\.layers\.(\d+)\.pre_mlp_norm\./) { ".layers.#{$1}.pre_ff_layernorm." }
+
+           mapped = mapped.gsub(".mixer.conv1d.", ".mamba.conv1d.")
+           mapped = mapped.gsub(".mixer.in_proj.", ".mamba.in_proj.")
+           mapped = mapped.gsub(".mixer.out_proj.", ".mamba.out_proj.")
+           mapped = mapped.gsub(".mixer.qkv_proj.", ".self_attn.q_proj.")
+           mapped = mapped.gsub(".mixer.q_proj.", ".self_attn.q_proj.")
+           mapped = mapped.gsub(".mixer.k_proj.", ".self_attn.k_proj.")
+           mapped = mapped.gsub(".mixer.v_proj.", ".self_attn.v_proj.")
+           mapped = mapped.gsub(".mixer.o_proj.", ".self_attn.o_proj.")
+           mapped = mapped.gsub(".mlp.gate_up_proj.", ".feed_forward.gate_proj.")
+           mapped = mapped.gsub(".mlp.gate_proj.", ".feed_forward.gate_proj.")
+           mapped = mapped.gsub(".mlp.up_proj.", ".feed_forward.up_proj.")
+           mapped = mapped.gsub(".mlp.down_proj.", ".feed_forward.down_proj.")
+           mapped
+         end
+       end
+
+       Models.register("plamo2", Model, ModelArgs)
+     end
+   end
+ end
data/lib/mlx_lm/models/qwen.rb
@@ -0,0 +1,175 @@
+ module MlxLm
+   module Models
+     module Qwen
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "qwen"
+         field :hidden_size, default: 2048
+         field :num_attention_heads, default: 16
+         field :num_hidden_layers, default: 24
+         field :kv_channels, default: 128
+         field :max_position_embeddings, default: 8192
+         field :layer_norm_epsilon, default: 1e-6
+         field :intermediate_size, default: 11008
+         field :no_bias, default: true
+         field :vocab_size, default: 151936
+         field :num_key_value_heads, default: nil
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           hidden_size = args.hidden_size
+           @num_attention_heads = args.num_attention_heads
+           hidden_size_per_attention_head = hidden_size / @num_attention_heads
+
+           self.rotary_emb = MLX::NN::RoPE.new(
+             hidden_size_per_attention_head,
+             traditional: false
+           )
+
+           @proj_size = args.kv_channels * @num_attention_heads
+
+           self.c_attn = MLX::NN::Linear.new(hidden_size, @proj_size * 3, bias: true)
+           self.c_proj = MLX::NN::Linear.new(hidden_size, @proj_size, bias: !args.no_bias)
+
+           @head_dim = args.kv_channels
+           @scale = hidden_size_per_attention_head**(-0.5)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+
+           qkv = c_attn.call(x)
+           q, k, v = mx.split(qkv, [@proj_size, 2 * @proj_size], -1)
+
+           b, l, _ = q.shape
+
+           queries = q.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rotary_emb.call(queries, offset: cache.offset)
+             keys = rotary_emb.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rotary_emb.call(queries)
+             keys = rotary_emb.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @proj_size])
+
+           c_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           self.w1 = MLX::NN::Linear.new(
+             args.hidden_size,
+             args.intermediate_size / 2,
+             bias: !args.no_bias
+           )
+           self.w2 = MLX::NN::Linear.new(
+             args.hidden_size,
+             args.intermediate_size / 2,
+             bias: !args.no_bias
+           )
+           self.c_proj = MLX::NN::Linear.new(
+             args.intermediate_size / 2,
+             args.hidden_size,
+             bias: !args.no_bias
+           )
+         end
+
+         def call(x)
+           a1 = w1.call(x)
+           a2 = w2.call(x)
+           c_proj.call(Activations.swiglu(a2, a1))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           self.ln_1 = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+           self.attn = Attention.new(args)
+           self.ln_2 = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+           self.mlp = MLP.new(args)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           residual = x
+           x = ln_1.call(x)
+           x = attn.call(x, mask: mask, cache: cache)
+           residual = x + residual
+           x = ln_2.call(residual)
+           x = mlp.call(x)
+           x + residual
+         end
+       end
+
+       class QwenModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.wte = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.h = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.ln_f = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+         end
+
+         def call(inputs, cache: nil)
+           x = wte.call(inputs)
+           layer_cache = cache || [nil] * h.length
+
+           mask = nil
+           mask = "causal" if x.shape[1] > 1
+
+           h.each_with_index do |layer, i|
+             x = layer.call(x, mask: mask, cache: layer_cache[i])
+           end
+
+           ln_f.call(x)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(config)
+           super()
+           @args = config
+           self.model_type = config.model_type
+           self.transformer = QwenModel.new(config)
+           self.lm_head = MLX::NN::Linear.new(
+             config.hidden_size,
+             config.vocab_size,
+             bias: !config.no_bias
+           )
+         end
+
+         def call(x, cache: nil)
+           y = transformer.call(x, cache: cache)
+           lm_head.call(y)
+         end
+
+         def sanitize(weights)
+           weights.reject { |k, _| k.include?("rotary_emb.inv_freq") }
+         end
+
+         def layers
+           transformer.h
+         end
+       end
+
+       Models.register("qwen", Model, ModelArgs)
+     end
+   end
+ end
data/lib/mlx_lm/models/qwen2.rb
@@ -0,0 +1,162 @@
+ module MlxLm
+   module Models
+     module Qwen2
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "qwen2"
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 32
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: 32
+         field :intermediate_size, default: 22016
+         field :vocab_size, default: 151936
+         field :rms_norm_eps, default: 1e-6
+         field :rope_theta, default: 1000000.0
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: true
+         field :head_dim, default: nil
+         field :max_position_embeddings, default: 32768
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           # Qwen2: Q, K, V have bias; O does not
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: true)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+           self.rope = MLX::NN::RoPE.new(
+             @head_dim,
+             traditional: args.rope_traditional,
+             base: args.rope_theta
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           hidden_dim = args.intermediate_size
+
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(MLX::NN.silu(gate_proj.call(x)) * up_proj.call(x))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class Qwen2Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model = Qwen2Model.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("qwen2", Model, ModelArgs)
+     end
+   end
+ end
data/lib/mlx_lm/models/qwen2_moe.rb
@@ -0,0 +1,189 @@
+ require_relative "activations"
+ require_relative "qwen2"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module Qwen2Moe
+       class ModelArgs < Qwen2::ModelArgs
+         field :model_type, default: "qwen2_moe"
+         field :num_key_value_heads, default: nil
+         field :num_experts_per_tok
+         field :num_experts
+         field :moe_intermediate_size
+         field :shared_expert_intermediate_size
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           validate_rope_scaling!
+         end
+
+         private
+
+         def validate_rope_scaling!
+           return unless @rope_scaling
+
+           required_keys = %w[factor type]
+           unless required_keys.all? { |key| _rope_scaling_has_key?(key) }
+             raise ArgumentError, "rope_scaling must contain keys #{required_keys}"
+           end
+
+           return if _rope_scaling_value("type") == "linear"
+
+           raise ArgumentError, "rope_scaling 'type' currently only supports 'linear'"
+         end
+
+         def _rope_scaling_has_key?(key)
+           @rope_scaling.key?(key) || @rope_scaling.key?(key.to_sym)
+         end
+
+         def _rope_scaling_value(key)
+           return @rope_scaling[key] if @rope_scaling.key?(key)
+
+           @rope_scaling[key.to_sym]
+         end
+       end
+
+       class SharedExpertMLP < MLX::NN::Module
+         def initialize(dim, hidden_dim)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class SparseMoeBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           intermediate_size = args.moe_intermediate_size
+           shared_expert_intermediate_size = args.shared_expert_intermediate_size
+
+           @num_experts = args.num_experts
+           @top_k = args.num_experts_per_tok
+
+           self.gate = MLX::NN::Linear.new(dim, @num_experts, bias: false)
+           self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, intermediate_size, @num_experts)
+
+           self.shared_expert = SharedExpertMLP.new(dim, shared_expert_intermediate_size)
+           self.shared_expert_gate = MLX::NN::Linear.new(dim, 1, bias: false)
+         end
+
+         def call(x)
+           mx = MLX::Core
+
+           gates = gate.call(x)
+           gates = mx.softmax(gates.astype(mx.float32), -1).astype(gates.dtype)
+
+           k = [@top_k, @num_experts].min
+           inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1))
+           take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+           scores = mx.take_along_axis(gates, inds, -1)
+
+           y = switch_mlp.call(x, inds)
+           y = mx.sum(y * mx.expand_dims(scores, -1), -2)
+
+           shared_expert_output = shared_expert.call(x)
+           shared_expert_output = mx.sigmoid(shared_expert_gate.call(x)) * shared_expert_output
+
+           y + shared_expert_output
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Qwen2::Attention.new(args)
+           self.mlp = SparseMoeBlock.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class Qwen2MoeModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, layer_idx|
+             h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = Qwen2MoeModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil)
+           lm_head.call(model.call(inputs, cache: cache))
+         end
+
+         def sanitize(weights)
+           return weights unless weights.key?("model.layers.0.mlp.experts.0.up_proj.weight")
+
+           mx = MLX::Core
+           result = weights.dup
+
+           @args.num_hidden_layers.times do |layer_idx|
+             prefix = "model.layers.#{layer_idx}"
+             %w[up_proj down_proj gate_proj].each do |projection|
+               %w[weight scales biases].each do |param|
+                 first_key = "#{prefix}.mlp.experts.0.#{projection}.#{param}"
+                 next unless result.key?(first_key)
+
+                 expert_keys = (0...@args.num_experts).map do |expert_idx|
+                   "#{prefix}.mlp.experts.#{expert_idx}.#{projection}.#{param}"
+                 end
+                 next unless expert_keys.all? { |key| result.key?(key) }
+
+                 stacked = expert_keys.map { |key| result.delete(key) }
+                 result["#{prefix}.mlp.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked)
+               end
+             end
+           end
+
+           result
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("qwen2_moe", Model, ModelArgs)
+     end
+   end
+ end