mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/llama4_text.rb
@@ -0,0 +1,195 @@
+ module MlxLm
+   module Models
+     module Llama4Text
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "llama4_text"
+         field :hidden_size
+         field :num_attention_heads
+         field :num_hidden_layers
+         field :vocab_size
+         field :intermediate_size, default: nil
+         field :intermediate_size_mlp, default: nil
+         field :num_key_value_heads, default: nil
+         field :rms_norm_eps, default: 1e-5
+         field :rope_theta, default: 10_000.0
+         field :head_dim, default: nil
+         field :tie_word_embeddings, default: true
+         field :no_rope_layers, default: nil
+         field :use_qk_norm, default: false
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+           @intermediate_size_mlp ||= @intermediate_size
+
+           if @no_rope_layers.nil?
+             # Default: every layer uses RoPE (1 = RoPE layer, 0 = NoPE layer).
+             @no_rope_layers = Array.new(@num_hidden_layers, 1)
+           elsif @no_rope_layers.length != @num_hidden_layers
+             raise ArgumentError, "`no_rope_layers` must have #{@num_hidden_layers} entries " \
+                                  "(one per layer), got #{@no_rope_layers.length}"
+           end
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args, use_rope)
+           super()
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+           @use_rope = !!use_rope
+           @use_qk_norm = !!args.use_qk_norm
+           @rms_norm_eps = args.rms_norm_eps
+
+           self.q_proj = MLX::NN::Linear.new(args.hidden_size, @n_heads * @head_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(args.hidden_size, @n_kv_heads * @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(args.hidden_size, @n_kv_heads * @head_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, args.hidden_size, bias: false)
+
+           if @use_rope
+             self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta)
+           end
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim])
+
+           # QK-norm is applied per head, before the heads are moved to axis 1.
+           if @use_qk_norm
+             queries = mx.rms_norm(queries, nil, @rms_norm_eps)
+             keys = mx.rms_norm(keys, nil, @rms_norm_eps)
+           end
+
+           queries = queries.transpose([0, 2, 1, 3])
+           keys = keys.transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if @use_rope
+             if cache
+               queries = rope.call(queries, offset: cache.offset)
+               keys = rope.call(keys, offset: cache.offset)
+             else
+               queries = rope.call(queries)
+               keys = rope.call(keys)
+             end
+           end
+
+           keys, values = cache.update_and_fetch(keys, values) if cache
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(dim, intermediate_size)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, intermediate_size, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, intermediate_size, bias: false)
+           self.down_proj = MLX::NN::Linear.new(intermediate_size, dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args, use_rope)
+           super()
+           self.self_attn = Attention.new(args, use_rope)
+           self.feed_forward = MLP.new(args.hidden_size, args.intermediate_size_mlp)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = feed_forward.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class LanguageModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) do |i|
+             # Compare against 0 explicitly: `no_rope_layers` holds 0/1 integers, and
+             # 0 is truthy in Ruby, so passing the raw value would enable RoPE on
+             # every layer, including the NoPE ones.
+             TransformerBlock.new(args, args.no_rope_layers[i] != 0)
+           end
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if h.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = LanguageModel.new(args)
+           self.output = nil
+           unless args.tie_word_embeddings
+             self.output = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           h = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(h)
+           else
+             output.call(h)
+           end
+         end
+
+         def sanitize(weights)
+           sanitized = weights.reject do |k, _|
+             k.include?("self_attn.rotary_emb.inv_freq") || k.include?("self_attn.rope.inv_freq")
+           end
+           if @args.tie_word_embeddings
+             sanitized.delete("output.weight")
+             sanitized.delete("lm_head.weight")
+           end
+           sanitized
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("llama4_text", Model, ModelArgs)
+     end
+   end
+ end
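A smoke-test sketch for the new llama4_text module (not part of the diff): it assumes the gem's field DSL lets ModelArgs.new accept the declared fields as keyword arguments, consistent with the initialize(**kwargs) above, and that MLX::Core.array builds an integer token batch; both constructors are assumptions about the surrounding gem rather than anything this release confirms.

require "mlx_lm"

# Tiny illustrative config; real checkpoints supply these values from config.json.
args = MlxLm::Models::Llama4Text::ModelArgs.new(
  hidden_size: 64,
  num_attention_heads: 4,
  num_hidden_layers: 2,
  vocab_size: 128,
  intermediate_size: 256
)
model = MlxLm::Models::Llama4Text::Model.new(args)

tokens = MLX::Core.array([[1, 2, 3]])  # batch of one 3-token prompt (assumed constructor)
logits = model.call(tokens)            # [1, 3, 128] logits via the tied embedding head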
data/lib/mlx_lm/models/longcat_flash.rb
@@ -0,0 +1,153 @@
+ require_relative "glm4_moe_lite"
+
+ module MlxLm
+   module Models
+     module LongcatFlash
+       class ModelArgs < Glm4MoeLite::ModelArgs
+         field :model_type, default: "longcat_flash"
+         field :hidden_dim, default: nil
+         field :ffn_hidden_size, default: nil
+         field :num_layers, default: nil
+         field :num_heads, default: nil
+         field :num_kv_heads, default: nil
+         field :num_experts, default: nil
+         field :num_local_experts, default: nil
+         field :num_shared_experts, default: nil
+         field :top_k, default: nil
+         field :score_function, default: nil
+
+         def self.from_dict(params)
+           normalized = params.each_with_object({}) do |(key, value), out|
+             out[key.to_s] = value
+           end
+
+           # Copy each LongCat alias onto its canonical Glm4MoeLite key without
+           # clobbering a canonical key that was supplied directly. Insertion order
+           # matters: "num_local_experts" takes precedence over "num_experts".
+           {
+             "hidden_dim" => "hidden_size",
+             "ffn_hidden_size" => "intermediate_size",
+             "num_layers" => "num_hidden_layers",
+             "num_heads" => "num_attention_heads",
+             "num_kv_heads" => "num_key_value_heads",
+             "num_local_experts" => "n_routed_experts",
+             "num_experts" => "n_routed_experts",
+             "num_shared_experts" => "n_shared_experts",
+             "top_k" => "num_experts_per_tok",
+             "score_function" => "scoring_func",
+           }.each do |source_key, target_key|
+             next unless normalized.key?(source_key)
+
+             normalized[target_key] = normalized[source_key] unless normalized.key?(target_key)
+           end
+
+           normalized["model_type"] ||= "longcat_flash"
+           super(normalized)
+         end
+
+         def initialize(**kwargs)
+           super
+           # Same aliasing for direct keyword construction: adopt an alias only when
+           # it was explicitly passed, is non-nil, and the canonical kwarg is absent.
+           @hidden_size = @hidden_dim if kwargs.key?(:hidden_dim) && !kwargs.key?(:hidden_size) && !@hidden_dim.nil?
+           @intermediate_size = @ffn_hidden_size if kwargs.key?(:ffn_hidden_size) && !kwargs.key?(:intermediate_size) && !@ffn_hidden_size.nil?
+           @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !kwargs.key?(:num_hidden_layers) && !@num_layers.nil?
+           @num_attention_heads = @num_heads if kwargs.key?(:num_heads) && !kwargs.key?(:num_attention_heads) && !@num_heads.nil?
+           @num_key_value_heads = @num_kv_heads if kwargs.key?(:num_kv_heads) && !kwargs.key?(:num_key_value_heads) && !@num_kv_heads.nil?
+           @n_routed_experts = @num_local_experts if kwargs.key?(:num_local_experts) && !kwargs.key?(:n_routed_experts) && !@num_local_experts.nil?
+           @n_routed_experts = @num_experts if kwargs.key?(:num_experts) && !kwargs.key?(:n_routed_experts) && !kwargs.key?(:num_local_experts) && !@num_experts.nil?
+           @n_shared_experts = @num_shared_experts if kwargs.key?(:num_shared_experts) && !kwargs.key?(:n_shared_experts) && !@num_shared_experts.nil?
+           @num_experts_per_tok = @top_k if kwargs.key?(:top_k) && !kwargs.key?(:num_experts_per_tok) && !@top_k.nil?
+           @scoring_func = @score_function if kwargs.key?(:score_function) && !kwargs.key?(:scoring_func) && !@score_function.nil?
+           @num_key_value_heads ||= @num_attention_heads
+         end
+
+         # Full canonical config for constructing the wrapped Glm4MoeLite model.
+         def to_glm4_moe_lite_dict
+           {
+             "model_type" => @model_type,
+             "vocab_size" => @vocab_size,
+             "hidden_size" => @hidden_size,
+             "intermediate_size" => @intermediate_size,
+             "moe_intermediate_size" => @moe_intermediate_size,
+             "num_hidden_layers" => @num_hidden_layers,
+             "num_attention_heads" => @num_attention_heads,
+             "num_key_value_heads" => @num_key_value_heads,
+             "n_shared_experts" => @n_shared_experts,
+             "n_routed_experts" => @n_routed_experts,
+             "routed_scaling_factor" => @routed_scaling_factor,
+             "kv_lora_rank" => @kv_lora_rank,
+             "q_lora_rank" => @q_lora_rank,
+             "qk_rope_head_dim" => @qk_rope_head_dim,
+             "qk_nope_head_dim" => @qk_nope_head_dim,
+             "v_head_dim" => @v_head_dim,
+             "topk_method" => @topk_method,
+             "scoring_func" => @scoring_func,
+             "norm_topk_prob" => @norm_topk_prob,
+             "n_group" => @n_group,
+             "topk_group" => @topk_group,
+             "num_experts_per_tok" => @num_experts_per_tok,
+             "moe_layer_freq" => @moe_layer_freq,
+             "first_k_dense_replace" => @first_k_dense_replace,
+             "max_position_embeddings" => @max_position_embeddings,
+             "rms_norm_eps" => @rms_norm_eps,
+             "rope_theta" => @rope_theta,
+             "rope_scaling" => @rope_scaling,
+             "attention_bias" => @attention_bias,
+             "attention_dropout" => @attention_dropout,
+             "partial_rotary_factor" => @partial_rotary_factor,
+             "tie_word_embeddings" => @tie_word_embeddings,
+             "num_nextn_predict_layers" => @num_nextn_predict_layers,
+             "quantization" => @quantization,
+           }
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.wrapped_model = Glm4MoeLite::Model.new(
+             Glm4MoeLite::ModelArgs.from_dict(args.to_glm4_moe_lite_dict)
+           )
+         end
+
+         def call(inputs, cache: nil)
+           wrapped_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           remapped = {}
+           flat_weights = weights.is_a?(Hash) ? weights : weights.to_h
+           flat_weights.each do |key, value|
+             remapped[_remap_weight_key(key)] = value
+           end
+           wrapped_model.sanitize(remapped)
+         end
+
+         def layers
+           wrapped_model.layers
+         end
+
+         def make_cache
+           return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache)
+
+           nil
+         end
+
+         def cast_predicate
+           return wrapped_model.cast_predicate if wrapped_model.respond_to?(:cast_predicate)
+
+           lambda { |_key| true }
+         end
+
+         private
+
+         # Rename LongCat checkpoint modules to the names the wrapped model defines.
+         def _remap_weight_key(key)
+           key.to_s
+              .gsub(".attention.", ".self_attn.")
+              .gsub(".block_sparse_moe.", ".mlp.")
+              .gsub(".mlp.router.", ".mlp.gate.")
+         end
+       end
+
+       Models.register("longcat_flash", Model, ModelArgs)
+     end
+   end
+ end
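To see what the aliasing buys, a hedged sketch of from_dict at work. The values are illustrative and omit the rest of a real config (the remaining Glm4MoeLite fields would come from the checkpoint's config.json), and the readers such as hidden_size are assumed to be generated by the field DSL:

args = MlxLm::Models::LongcatFlash::ModelArgs.from_dict(
  "hidden_dim" => 4096, # copied to hidden_size
  "num_layers" => 28,   # copied to num_hidden_layers
  "num_heads" => 64,    # copied to num_attention_heads
  "num_experts" => 512, # copied to n_routed_experts
  "top_k" => 12         # copied to num_experts_per_tok
)
args.hidden_size          # => 4096
args.num_hidden_layers    # => 28
args.num_experts_per_tok  # => 12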
data/lib/mlx_lm/models/longcat_flash_ngram.rb
@@ -0,0 +1,137 @@
+ require_relative "longcat_flash"
+
+ module MlxLm
+   module Models
+     module LongcatFlashNgram
+       class ModelArgs < LongcatFlash::ModelArgs
+         field :model_type, default: "longcat_flash_ngram"
+         field :attention_method, default: nil
+         field :zero_expert_type, default: "identity"
+         field :moe_topk, default: nil
+         field :expert_ffn_hidden_size, default: nil
+         field :zero_expert_num, default: nil
+         field :num_layers, default: nil
+         field :ngram_vocab_size_ratio, default: 78
+         field :emb_neighbor_num, default: 4
+         field :emb_split_num, default: 4
+         field :mla_scale_q_lora, default: nil
+         field :mla_scale_kv_lora, default: nil
+         field :router_bias, default: false
+
+         def self.from_dict(params)
+           normalized = params.each_with_object({}) do |(key, value), out|
+             out[key.to_s] = value
+           end
+
+           {
+             "num_layers" => "num_hidden_layers",
+             "moe_topk" => "num_experts_per_tok",
+             "expert_ffn_hidden_size" => "moe_intermediate_size",
+           }.each do |source_key, target_key|
+             next unless normalized.key?(source_key)
+
+             normalized[target_key] = normalized[source_key] unless normalized.key?(target_key)
+           end
+
+           # Zero experts are identity passthroughs; fold them into the expert
+           # count the wrapped model routes over.
+           if normalized.key?("n_routed_experts") && normalized.key?("zero_expert_num") && !normalized.key?("num_local_experts")
+             normalized["num_local_experts"] = normalized["n_routed_experts"].to_i + normalized["zero_expert_num"].to_i
+           end
+
+           if normalized.key?("num_attention_heads") && !normalized.key?("num_key_value_heads") && !normalized.key?("num_kv_heads")
+             normalized["num_key_value_heads"] = normalized["num_attention_heads"]
+           end
+
+           normalized["model_type"] ||= "longcat_flash_ngram"
+           super(normalized)
+         end
+
+         def initialize(**kwargs)
+           super
+           @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !kwargs.key?(:num_hidden_layers) && !@num_layers.nil?
+           @num_experts_per_tok = @moe_topk if kwargs.key?(:moe_topk) && !kwargs.key?(:num_experts_per_tok) && !@moe_topk.nil?
+           @moe_intermediate_size = @expert_ffn_hidden_size if kwargs.key?(:expert_ffn_hidden_size) && !kwargs.key?(:moe_intermediate_size) && !@expert_ffn_hidden_size.nil?
+
+           if kwargs.key?(:zero_expert_num) && !@zero_expert_num.nil? && !kwargs.key?(:num_local_experts) && !kwargs.key?(:n_routed_experts) && !@n_routed_experts.nil?
+             @n_routed_experts = @n_routed_experts.to_i + @zero_expert_num.to_i
+           end
+
+           if kwargs.key?(:num_attention_heads) && !kwargs.key?(:num_key_value_heads) && !kwargs.key?(:num_kv_heads)
+             @num_key_value_heads = @num_attention_heads
+           end
+
+           @num_key_value_heads ||= @num_attention_heads
+         end
+
+         def to_longcat_flash_dict
+           routed_experts = @n_routed_experts
+           if !@zero_expert_num.nil? && !routed_experts.nil?
+             routed_experts = routed_experts.to_i + @zero_expert_num.to_i
+           end
+
+           dict = to_glm4_moe_lite_dict
+           dict["model_type"] = @model_type
+           dict["n_routed_experts"] = routed_experts unless routed_experts.nil?
+           dict
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.wrapped_model = LongcatFlash::Model.new(
+             LongcatFlash::ModelArgs.from_dict(args.to_longcat_flash_dict)
+           )
+         end
+
+         def call(inputs, cache: nil)
+           wrapped_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           remapped = {}
+           flat_weights = weights.is_a?(Hash) ? weights : weights.to_h
+           flat_weights.each do |key, value|
+             remapped[_to_longcat_flash_key(key)] = value
+           end
+
+           # Sanitize under the wrapped model's names, then restore the ngram
+           # embedding keys so they match this model's parameter tree.
+           sanitized = wrapped_model.sanitize(remapped)
+           restored = {}
+           sanitized.each do |key, value|
+             restored[_from_longcat_flash_key(key)] = value
+           end
+           restored
+         end
+
+         def layers
+           wrapped_model.layers
+         end
+
+         def make_cache
+           return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache)
+
+           nil
+         end
+
+         def cast_predicate
+           return wrapped_model.cast_predicate if wrapped_model.respond_to?(:cast_predicate)
+
+           lambda { |_key| true }
+         end
+
+         private
+
+         def _to_longcat_flash_key(key)
+           key.to_s.gsub("model.ngram_embeddings.word_embeddings.", "model.embed_tokens.")
+         end
+
+         def _from_longcat_flash_key(key)
+           key.to_s.gsub("model.embed_tokens.", "model.ngram_embeddings.word_embeddings.")
+         end
+       end
+
+       Models.register("longcat_flash_ngram", Model, ModelArgs)
+     end
+   end
+ end
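The sanitize above relies on the rename being exactly reversible. Tracing the two private helpers on one checkpoint key, in plain Ruby that runs as-is:

key   = "model.ngram_embeddings.word_embeddings.weight"
inner = key.gsub("model.ngram_embeddings.word_embeddings.", "model.embed_tokens.")
# => "model.embed_tokens.weight" (the name the wrapped LongcatFlash model sanitizes)
back  = inner.gsub("model.embed_tokens.", "model.ngram_embeddings.word_embeddings.")
# => "model.ngram_embeddings.word_embeddings.weight" (original key restored)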