mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/mimo.rb
@@ -0,0 +1,174 @@
+ module MlxLm
+   module Models
+     module Mimo
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "mimo"
+         field :hidden_size
+         field :num_hidden_layers
+         field :intermediate_size
+         field :num_attention_heads
+         field :rms_norm_eps
+         field :vocab_size
+         field :num_key_value_heads, default: nil
+         field :max_position_embeddings, default: 32_768
+         field :rope_theta, default: 10_000.0
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+         field :num_nextn_predict_layers, default: 2
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = dim / @n_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: true)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             args.rope_traditional,
+             args.rope_scaling,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(dim, hidden_dim)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class MimoModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if h.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = MimoModel.new(args)
+           self.lm_head = nil
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           result = weights.dup
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result.reject do |k, _|
+             k.include?("self_attn.rotary_emb.inv_freq") ||
+               k.start_with?("model.mtp_layers.")
+           end
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("mimo", Model, ModelArgs)
+     end
+   end
+ end
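
The hunk above adds the dense MiMo architecture and registers it under the "mimo" model type. As a rough, hypothetical illustration of how the pieces fit together (not shipped in the gem; the tiny dimensions and the keyword-style ModelArgs construction are assumptions layered on the MLX Ruby bindings used throughout this file):

  # Hypothetical smoke test -- every dimension here is invented for illustration,
  # and MLX::Core.array is assumed to accept a nested Ruby array of token ids.
  args = MlxLm::Models::Mimo::ModelArgs.new(
    hidden_size: 64,
    num_hidden_layers: 2,
    intermediate_size: 128,
    num_attention_heads: 4,
    rms_norm_eps: 1.0e-6,
    vocab_size: 256
  )
  model = MlxLm::Models::Mimo::Model.new(args)
  tokens = MLX::Core.array([[1, 2, 3, 4]])  # one 4-token prompt
  logits = model.call(tokens)               # per-position logits over the 256-entry vocabulary

With tie_word_embeddings left at its default of false, the forward pass uses the separate lm_head projection; setting it to true reuses embed_tokens.as_linear instead, and sanitize then drops any stored lm_head.weight.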
data/lib/mlx_lm/models/mimo_v2_flash.rb
@@ -0,0 +1,491 @@
+ require_relative "activations"
+ require_relative "cache"
+ require_relative "rope_utils"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module MimoV2Flash
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "mimo_v2_flash"
+         field :num_experts_per_tok, default: 1
+         field :hybrid_layer_pattern, default: nil
+         field :moe_layer_freq, default: nil
+         field :add_swa_attention_sink_bias, default: false
+         field :add_full_attention_sink_bias, default: false
+         field :sliding_window_size, default: 4096
+         field :vocab_size
+         field :hidden_size
+         field :intermediate_size
+         field :moe_intermediate_size
+         field :num_hidden_layers
+         field :num_attention_heads
+         field :num_key_value_heads, default: nil
+         field :n_shared_experts, default: nil
+         field :n_routed_experts, default: nil
+         field :routed_scaling_factor, default: nil
+         field :topk_method, default: "noaux_tc"
+         field :scoring_func, default: "sigmoid"
+         field :norm_topk_prob, default: false
+         field :n_group, default: 1
+         field :topk_group, default: 1
+         field :max_position_embeddings, default: 32768
+         field :layernorm_epsilon, default: 1e-6
+         field :rope_theta, default: 10_000.0
+         field :swa_rope_theta, default: nil
+         field :swa_num_attention_heads, default: nil
+         field :swa_num_key_value_heads, default: nil
+         field :head_dim, default: nil
+         field :v_head_dim, default: nil
+         field :swa_head_dim, default: nil
+         field :swa_v_head_dim, default: nil
+         field :partial_rotary_factor, default: 1.0
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @swa_num_attention_heads ||= @num_attention_heads
+           @swa_num_key_value_heads ||= @num_key_value_heads
+
+           @head_dim ||= @hidden_size / @num_attention_heads
+           @v_head_dim ||= @head_dim
+           @swa_head_dim ||= @head_dim
+           @swa_v_head_dim ||= @swa_head_dim
+           @swa_rope_theta ||= @rope_theta
+
+           @n_routed_experts ||= 1
+           @routed_scaling_factor = 1.0 if @routed_scaling_factor.nil?
+           @hybrid_layer_pattern ||= Array.new(@num_hidden_layers, 0)
+           @moe_layer_freq ||= Array.new(@num_hidden_layers, 0)
+           @topk_group ||= @n_group
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args, is_sliding_window)
+           super()
+
+           dim = args.hidden_size
+           @is_sliding_window = is_sliding_window
+           if @is_sliding_window
+             @n_heads = args.swa_num_attention_heads
+             @n_kv_heads = args.swa_num_key_value_heads
+             @has_sinks = args.add_swa_attention_sink_bias
+             @head_dim = args.swa_head_dim
+             @v_head_dim = args.swa_v_head_dim
+             rope_theta = args.swa_rope_theta
+           else
+             @n_heads = args.num_attention_heads
+             @n_kv_heads = args.num_key_value_heads
+             @has_sinks = args.add_full_attention_sink_bias
+             @head_dim = args.head_dim
+             @v_head_dim = args.v_head_dim
+             rope_theta = args.rope_theta
+           end
+
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @v_head_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @v_head_dim, dim, bias: false)
+           self.attention_sink_bias = if @has_sinks
+             MLX::Core.ones([@n_heads])
+           else
+             nil
+           end
+
+           rotary_dim = [(@head_dim * args.partial_rotary_factor.to_f).to_i, 1].max
+           self.rope = MLX::NN::RoPE.new(
+             rotary_dim,
+             traditional: false,
+             base: rope_theta
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @v_head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = _scaled_dot_product_attention(queries, keys, values, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @v_head_dim])
+           o_proj.call(output)
+         end
+
+         private
+
+         def _scaled_dot_product_attention(queries, keys, values, mask)
+           mx = MLX::Core
+
+           if attention_sink_bias
+             begin
+               return mx.scaled_dot_product_attention(
+                 queries,
+                 keys,
+                 values,
+                 @scale,
+                 mask,
+                 sinks: attention_sink_bias
+               )
+             rescue StandardError
+               # Fallback when sinks are unsupported by the local MLX runtime.
+             end
+           end
+
+           mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(config, hidden_size: nil, intermediate_size: nil)
+           super()
+           @hidden_size = hidden_size || config.hidden_size
+           @intermediate_size = intermediate_size || config.intermediate_size
+
+           self.gate_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false)
+           self.up_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false)
+           self.down_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       module_function
+
+       def group_expert_select(
+         gates,
+         e_score_correction_bias,
+         top_k,
+         n_group,
+         topk_group,
+         routed_scaling_factor,
+         norm_topk_prob
+       )
+         mx = MLX::Core
+
+         scores = mx.sigmoid(gates.astype(mx.float32))
+         orig_scores = scores
+         scores = scores + e_score_correction_bias
+
+         if n_group.to_i > 1
+           experts_per_group = scores.shape[-1] / n_group
+           scores = mx.unflatten(scores, -1, [n_group, experts_per_group])
+           group_scores = mx.topk(scores, 2, -1)
+           group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1)
+
+           drop_count = n_group - topk_group.to_i
+           if drop_count > 0
+             group_idx = mx.argpartition(group_scores, drop_count - 1, -2)
+             take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32)
+             group_idx = mx.take(group_idx, take_ids, -2)
+             scores = mx.put_along_axis(
+               scores,
+               mx.stop_gradient(group_idx),
+               mx.array(0.0),
+               -2
+             )
+           end
+
+           scores = mx.flatten(scores, -2, -1)
+         end
+
+         k = [top_k.to_i, scores.shape[-1]].min
+         inds = mx.argpartition(scores * -1.0, k - 1, -1)
+         take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+         inds = mx.take(inds, take_ids, -1)
+
+         selected_scores = mx.take_along_axis(orig_scores, inds, -1)
+         if k > 1 && norm_topk_prob
+           denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1)
+           selected_scores = selected_scores / (denominator + 1e-20)
+         end
+
+         selected_scores = selected_scores * routed_scaling_factor.to_f
+         [inds, selected_scores]
+       end
+
+       class MoEGate < MLX::NN::Module
+         def initialize(config)
+           super()
+           @top_k = config.num_experts_per_tok
+           @norm_topk_prob = config.norm_topk_prob
+           @n_routed_experts = config.n_routed_experts
+           @routed_scaling_factor = config.routed_scaling_factor || 1.0
+           @n_group = config.n_group
+           @topk_group = config.topk_group
+
+           raise ArgumentError, "Unsupported topk method: #{config.topk_method}" unless config.topk_method == "noaux_tc"
+
+           mx = MLX::Core
+           self.weight = mx.zeros([@n_routed_experts, config.hidden_size])
+           self.e_score_correction_bias = mx.zeros([@n_routed_experts])
+         end
+
+         def call(x)
+           mx = MLX::Core
+           gates = mx.matmul(x, mx.transpose(weight))
+           MimoV2Flash.group_expert_select(
+             gates,
+             e_score_correction_bias,
+             @top_k,
+             @n_group,
+             @topk_group,
+             @routed_scaling_factor,
+             @norm_topk_prob
+           )
+         end
+       end
+
+       class MoE < MLX::NN::Module
+         def initialize(config)
+           super()
+           @config = config
+
+           self.switch_mlp = SwitchLayers::SwitchGLU.new(
+             config.hidden_size,
+             config.moe_intermediate_size,
+             config.n_routed_experts
+           )
+
+           self.gate = MoEGate.new(config)
+           if config.n_shared_experts
+             shared_intermediate = config.moe_intermediate_size * config.n_shared_experts
+             self.shared_experts = MLP.new(config, intermediate_size: shared_intermediate)
+           end
+         end
+
+         def call(x)
+           mx = MLX::Core
+           inds, scores = gate.call(x)
+           y = switch_mlp.call(x, inds)
+           y = mx.sum(y * mx.expand_dims(scores, -1), -2).astype(y.dtype)
+           y = y + shared_experts.call(x) if @config.n_shared_experts
+           y
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         attr_reader :is_sliding_window
+
+         def initialize(config, is_moe, is_sliding_window)
+           super()
+           @is_sliding_window = is_sliding_window
+
+           self.self_attn = Attention.new(config, is_sliding_window)
+           self.mlp = is_moe ? MoE.new(config) : MLP.new(config)
+           self.input_layernorm = MLX::NN::RMSNorm.new(
+             config.hidden_size,
+             eps: config.layernorm_epsilon
+           )
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(
+             config.hidden_size,
+             eps: config.layernorm_epsilon
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class LanguageModel < MLX::NN::Module
+         def initialize(config)
+           super()
+           @hybrid_layer_pattern = config.hybrid_layer_pattern
+           @sliding_window_size = config.sliding_window_size
+
+           self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size)
+           self.layers = Array.new(config.num_hidden_layers) do |idx|
+             DecoderLayer.new(
+               config,
+               config.moe_layer_freq[idx] == 1,
+               config.hybrid_layer_pattern[idx] == 1
+             )
+           end
+           self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.layernorm_epsilon)
+           self.swa_idx = @hybrid_layer_pattern.index(1) || 0
+           self.ga_idx = @hybrid_layer_pattern.index(0) || 0
+         end
+
+         def call(x, cache: nil)
+           h = embed_tokens.call(x)
+           layer_cache = cache || [nil] * layers.length
+
+           full_mask = _create_attention_mask(h, layer_cache[ga_idx])
+           swa_mask = _create_attention_mask(
+             h,
+             layer_cache[swa_idx],
+             window_size: @sliding_window_size
+           )
+
+           layers.each_with_index do |layer, i|
+             mask = layer.is_sliding_window ? swa_mask : full_mask
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil, window_size: nil)
+           n = h.shape[1]
+           if cache && cache.respond_to?(:make_mask)
+             return cache.make_mask(n, window_size: window_size)
+           end
+
+           if window_size
+             offset = 0
+             if cache
+               offset = cache.offset if cache.respond_to?(:offset)
+               if cache.instance_variable_defined?(:@max_size)
+                 max_size = cache.instance_variable_get(:@max_size)
+                 offset = [max_size - 1, offset].min if max_size && max_size > 0
+               end
+             end
+             return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
+           end
+
+           return nil if n == 1
+
+           "causal"
+         end
+
+         def _create_causal_mask(n, offset: 0, window_size: nil)
+           mx = MLX::Core
+           rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+           linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+           mask = mx.greater_equal(linds, rinds)
+           if window_size
+             mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+           end
+           mask
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(config)
+           super()
+           @args = config
+           self.model_type = config.model_type
+           self.model = LanguageModel.new(config)
+           self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           lm_head.call(out)
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           new_weights = {}
+
+           weights.each do |k, v|
+             if k.include?("weight_scale_inv")
+               wk = k.sub("_scale_inv", "")
+               if weights.key?(wk)
+                 new_weights[wk] = _dequant(weights[wk], v)
+               end
+             elsif !new_weights.key?(k)
+               new_weights[k] = v
+             end
+           end
+
+           result = new_weights
+           @args.num_hidden_layers.times do |layer_idx|
+             prefix = "model.layers.#{layer_idx}"
+             %w[gate_proj down_proj up_proj].each do |proj|
+               %w[weight scales biases].each do |param|
+                 first_key = "#{prefix}.mlp.experts.0.#{proj}.#{param}"
+                 next unless result.key?(first_key)
+
+                 expert_keys = (0...@args.n_routed_experts).map do |expert_idx|
+                   "#{prefix}.mlp.experts.#{expert_idx}.#{proj}.#{param}"
+                 end
+                 next unless expert_keys.all? { |key| result.key?(key) }
+
+                 stacked = expert_keys.map { |key| result.delete(key) }
+                 result["#{prefix}.mlp.switch_mlp.#{proj}.#{param}"] = mx.stack(stacked)
+               end
+             end
+           end
+
+           result.reject { |k, _| k.start_with?("model.mtp") }
+         end
+
+         def layers
+           model.layers
+         end
+
+         def cast_predicate
+           lambda { |k| !k.include?("e_score_correction_bias") }
+         end
+
+         def make_cache
+           layers.map do |layer|
+             if layer.is_sliding_window
+               MlxLm::RotatingKVCache.new(max_size: @args.sliding_window_size)
+             else
+               MlxLm::KVCache.new
+             end
+           end
+         end
+
+         private
+
+         def _dequant(weight, scale_inv)
+           mx = MLX::Core
+           dtype = mx.bfloat16
+           block_size = 128
+
+           dequantized = mx.from_fp8(weight, dtype: dtype)
+           m, n = dequantized.shape
+           pad_bottom = block_size * scale_inv.shape[0] - m
+           pad_side = block_size * scale_inv.shape[1] - n
+
+           dequantized = mx.pad(dequantized, [[0, pad_bottom], [0, pad_side]])
+           dequantized = dequantized.reshape([
+             (m + pad_bottom) / block_size,
+             block_size,
+             (n + pad_side) / block_size,
+             block_size,
+           ])
+
+           scaled = dequantized * scale_inv.reshape([scale_inv.shape[0], 1, scale_inv.shape[1], 1])
+           scaled = scaled.reshape([m + pad_bottom, n + pad_side])
+           scaled = mx.split(scaled, [m], 0)[0]
+           scaled = mx.split(scaled, [n], 1)[0]
+           scaled.astype(dtype)
+         rescue StandardError
+           weight
+         end
+       end
+
+       Models.register("mimo_v2_flash", Model, ModelArgs)
+     end
+   end
+ end
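
mimo_v2_flash interleaves two attention variants (hybrid_layer_pattern selects sliding-window versus full attention per layer) and two MLP variants (moe_layer_freq selects routed MoE versus dense blocks per layer), and make_cache mirrors that split: sliding-window layers get a RotatingKVCache bounded by sliding_window_size, the rest a plain KVCache. A hypothetical sketch of exercising those paths (all sizes invented; keyword-style ModelArgs construction and MLX::Core.array behaviour are assumptions):

  # Hypothetical configuration -- values chosen only to hit the hybrid and MoE code paths.
  config = MlxLm::Models::MimoV2Flash::ModelArgs.new(
    vocab_size: 256,
    hidden_size: 64,
    intermediate_size: 128,
    moe_intermediate_size: 32,
    num_hidden_layers: 4,
    num_attention_heads: 4,
    hybrid_layer_pattern: [1, 0, 1, 0],  # layers 0 and 2 use sliding-window attention
    moe_layer_freq: [0, 1, 0, 1],        # layers 1 and 3 use the routed MoE block
    n_routed_experts: 4,
    num_experts_per_tok: 2,
    sliding_window_size: 8
  )
  model = MlxLm::Models::MimoV2Flash::Model.new(config)
  cache = model.make_cache  # RotatingKVCache for layers 0 and 2, KVCache for layers 1 and 3
  logits = model.call(MLX::Core.array([[1, 2, 3]]), cache: cache)

sanitize handles the checkpoint side of the same split: per-expert gate/up/down projections are stacked into the switch_mlp tensors that SwitchGLU consumes, FP8 weights that ship with a weight_scale_inv companion are block-dequantized through _dequant, and keys prefixed with "model.mtp" (multi-token prediction) are discarded.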