mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/llama.rb
@@ -0,0 +1,183 @@
+module MlxLm
+  module Models
+    module Llama
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "llama"
+        field :hidden_size, default: 4096
+        field :num_hidden_layers, default: 32
+        field :num_attention_heads, default: 32
+        field :num_key_value_heads, default: nil
+        field :intermediate_size, default: 11008
+        field :vocab_size, default: 32000
+        field :rms_norm_eps, default: 1e-6
+        field :rope_theta, default: 10000.0
+        field :rope_traditional, default: false
+        field :rope_scaling, default: nil
+        field :tie_word_embeddings, default: true
+        field :attention_bias, default: false
+        field :mlp_bias, default: false
+        field :head_dim, default: nil
+        field :max_position_embeddings, default: 2048
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @hidden_size / @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+
+          bias = args.attention_bias
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias)
+
+          self.rope = MLX::NN::RoPE.new(
+            @head_dim,
+            traditional: args.rope_traditional,
+            base: args.rope_theta
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x)
+          keys = k_proj.call(x)
+          values = v_proj.call(x)
+
+          queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(
+            queries, keys, values, @scale, mask
+          )
+
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          hidden_dim = args.intermediate_size
+          bias = args.mlp_bias
+
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+        end
+
+        def call(x)
+          mx = MLX::Core
+          down_proj.call(MLX::NN.silu(gate_proj.call(x)) * up_proj.call(x))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args)
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
+      class LlamaModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          mx = MLX::Core
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = _create_attention_mask(h, layer_cache[0])
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+
+        private
+
+        def _create_attention_mask(h, cache)
+          mx = MLX::Core
+          n = h.shape[1]
+          return nil if n == 1
+          "causal"
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model = LlamaModel.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def sanitize(weights)
+          result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+          result.delete("lm_head.weight") if @args.tie_word_embeddings
+          result
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      # Register in model registry
+      Models.register("llama", Model, ModelArgs)
+    end
+  end
+end
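
For orientation, a minimal sketch of exercising the class above directly. Everything outside the diff is an assumption: the require path, keyword construction through BaseModelArgs, and nested-array support in MLX::Core.array (the diff itself only shows 1-D mx.array calls). All values are illustrative.

require "mlx_lm"  # assumed entry point; the gem ships data/lib/mlx_lm.rb

# ModelArgs.new(**kwargs) appears in the diff; mapping keywords to fields is
# assumed to be provided by BaseModelArgs.
args = MlxLm::Models::Llama::ModelArgs.new(
  hidden_size: 512,
  num_hidden_layers: 4,
  num_attention_heads: 8,
  intermediate_size: 1024,
  vocab_size: 32000
)
model = MlxLm::Models::Llama::Model.new(args)

# One forward pass over a batch of token ids; nested-array construction is an
# assumption about the mlx binding.
tokens = MLX::Core.array([[1, 3290, 338]])
logits = model.call(tokens)  # shape [1, 3, vocab_size]

In practice weights would come through load_utils.rb and pass through sanitize before use; this sketch only runs the randomly initialized graph.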
data/lib/mlx_lm/models/llama4.rb
@@ -0,0 +1,357 @@
+require_relative "activations"
+require_relative "cache"
+require_relative "rope_utils"
+require_relative "switch_layers"
+
+module MlxLm
+  module Models
+    module Llama4
+      class TextArgs < BaseModelArgs
+        field :model_type, default: "llama4_text"
+        field :attention_bias, default: false
+        field :attention_chunk_size, default: 1024
+        field :head_dim, default: nil
+        field :hidden_size
+        field :interleave_moe_layer_step, default: 1
+        field :intermediate_size
+        field :intermediate_size_mlp, default: nil
+        field :max_position_embeddings, default: 4096
+        field :num_attention_heads
+        field :num_experts_per_tok, default: 1
+        field :num_hidden_layers
+        field :num_key_value_heads, default: nil
+        field :num_local_experts, default: 1
+        field :rms_norm_eps, default: 1e-5
+        field :rope_scaling, default: nil
+        field :rope_theta, default: 10_000.0
+        field :use_qk_norm, default: false
+        field :vocab_size
+        field :attn_temperature_tuning, default: 4
+        field :floor_scale, default: 8192
+        field :attn_scale, default: 0.1
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @hidden_size / @num_attention_heads
+          @intermediate_size_mlp ||= @intermediate_size
+          @attention_chunk_size = [@attention_chunk_size.to_i, 1].max
+          @interleave_moe_layer_step = [@interleave_moe_layer_step.to_i, 1].max
+        end
+      end
+
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "llama4"
+        field :text_config, default: nil
+
+        def self.from_dict(params)
+          has_text_config = params.key?("text_config") || params.key?(:text_config)
+          return super if has_text_config
+
+          new(model_type: params["model_type"] || params[:model_type], text_config: params)
+        end
+
+        def initialize(**kwargs)
+          super
+          @text_config = _to_text_args(@text_config || {})
+        end
+
+        private
+
+        def _to_text_args(config)
+          return config if config.is_a?(TextArgs)
+
+          normalized = {}
+          config.each { |key, value| normalized[key.to_s] = value }
+          TextArgs.from_dict(normalized)
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args, layer_idx)
+          super()
+
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+          @use_rope = ((layer_idx + 1) % 4) != 0
+          @attn_temperature_tuning = args.attn_temperature_tuning
+          @floor_scale = args.floor_scale
+          @attn_scale = args.attn_scale
+          @use_qk_norm = args.use_qk_norm && @use_rope
+
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias)
+
+          if @use_rope
+            self.rope = MlxLm::Models.initialize_rope(
+              @head_dim,
+              args.rope_theta,
+              true,
+              args.rope_scaling,
+              max_position_embeddings: args.max_position_embeddings
+            )
+          end
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          offset = cache ? cache.offset : 0
+          if @use_rope
+            queries = rope.call(queries, offset: offset)
+            keys = rope.call(keys, offset: offset)
+          end
+
+          if @use_qk_norm
+            queries = mx.rms_norm(queries, nil, 1e-6)
+            keys = mx.rms_norm(keys, nil, 1e-6)
+          end
+
+          if @attn_temperature_tuning && !@use_rope
+            attn_scales = (mx.log(mx.floor(mx.arange(offset + 1, offset + l + 1) / @floor_scale) + 1.0) * @attn_scale) + 1.0
+            queries = (queries * attn_scales.reshape([l, 1])).astype(queries.dtype)
+          end
+
+          keys, values = cache.update_and_fetch(keys, values) if cache
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(args, intermediate_size = nil)
+          super()
+          dim = args.hidden_size
+          hidden_dim = intermediate_size || args.intermediate_size
+
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
+      class MoE < MLX::NN::Module
+        def initialize(args)
+          super()
+          @top_k = args.num_experts_per_tok
+          raise ArgumentError, "Only 1 expert per token supported" unless @top_k == 1
+
+          @num_experts = args.num_local_experts
+          self.experts = SwitchLayers::SwitchGLU.new(
+            args.hidden_size,
+            args.intermediate_size,
+            @num_experts
+          )
+          self.router = MLX::NN::Linear.new(args.hidden_size, @num_experts, bias: false)
+          self.shared_expert = MLP.new(args)
+        end
+
+        def call(x)
+          mx = MLX::Core
+          logits = router.call(x)
+
+          indices = mx.argpartition(logits * -1.0, @top_k - 1, -1)
+          take_ids = mx.array((0...@top_k).to_a, dtype: mx.int32)
+          indices = mx.take(indices, take_ids, -1)
+          scores = mx.take_along_axis(logits, indices, -1)
+          scores = mx.sigmoid(scores.astype(mx.float32)).astype(x.dtype)
+
+          out = mx.squeeze(experts.call(x * scores, indices), 2)
+          out + shared_expert.call(x)
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args, layer_idx)
+          super()
+          self.self_attn = Attention.new(args, layer_idx)
+          is_moe_layer = (layer_idx % args.interleave_moe_layer_step) == (args.interleave_moe_layer_step - 1)
+          if is_moe_layer
+            self.feed_forward = MoE.new(args)
+          else
+            self.feed_forward = MLP.new(args, args.intermediate_size_mlp)
+          end
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = feed_forward.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
+      class LlamaModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          @attention_chunk_size = args.attention_chunk_size
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { |i| TransformerBlock.new(args, i) }
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          mx = MLX::Core
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || Array.new(layers.length)
+
+          if cache
+            cache.each_with_index do |c, idx|
+              next unless ((idx + 1) % 4) != 0
+              next unless c && c.respond_to?(:maybe_trim_front)
+
+              c.maybe_trim_front
+            end
+            first_cache = cache[0]
+            start = first_cache&.respond_to?(:start_position) ? first_cache.start_position : 0
+            offset = first_cache&.respond_to?(:offset) ? first_cache.offset : 0
+          else
+            start = 0
+            offset = 0
+          end
+
+          finish = offset + h.shape[1]
+          linds = mx.arange(start, finish)
+          rinds = mx.arange(offset, finish).reshape([h.shape[1], 1])
+
+          block_pos = mx.abs(
+            mx.floor_divide(linds, @attention_chunk_size) -
+            mx.floor_divide(rinds, @attention_chunk_size)
+          )
+          token_pos = mx.less_equal(linds, rinds)
+          chunk_mask = mx.logical_and(mx.equal(block_pos, 0), token_pos)
+          global_mask = _create_attention_mask(h, layer_cache[3])
+
+          layers.each_with_index do |layer, idx|
+            use_chunked_attention = ((idx + 1) % 4) != 0
+            mask = use_chunked_attention ? chunk_mask : global_mask
+            h = layer.call(h, mask: mask, cache: layer_cache[idx])
+          end
+
+          norm.call(h)
+        end
+
+        private
+
+        def _create_attention_mask(h, cache = nil)
+          return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+          return nil if h.shape[1] == 1
+
+          "causal"
+        end
+      end
+
+      class LanguageModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.model = LlamaModel.new(args)
+          self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+        end
+
+        def call(inputs, cache: nil)
+          lm_head.call(model.call(inputs, cache: cache))
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.language_model = LanguageModel.new(args.text_config)
+        end
+
+        def call(inputs, cache: nil)
+          language_model.call(inputs, cache: cache)
+        end
+
+        def sanitize(weights)
+          mx = MLX::Core
+
+          sanitized = {}
+          weights.each do |key, value|
+            next if _multimodal_key?(key)
+
+            sanitized[key] = value
+          end
+
+          @args.text_config.num_hidden_layers.to_i.times do |layer_idx|
+            prefix = "language_model.model.layers.#{layer_idx}.feed_forward.experts"
+
+            gate_up = _pop_first(
+              sanitized,
+              ["#{prefix}.gate_up_proj", "#{prefix}.gate_up_proj.weight"]
+            )
+            if gate_up
+              split = gate_up.shape[-1] / 2
+              gate_proj, up_proj = mx.split(gate_up, [split], -1)
+              sanitized["#{prefix}.gate_proj.weight"] = mx.swapaxes(gate_proj, 1, 2)
+              sanitized["#{prefix}.up_proj.weight"] = mx.swapaxes(up_proj, 1, 2)
+            end
+
+            down_proj = _pop_first(
+              sanitized,
+              ["#{prefix}.down_proj", "#{prefix}.down_proj.weight"]
+            )
+            if down_proj
+              sanitized["#{prefix}.down_proj.weight"] = mx.swapaxes(down_proj, 1, 2)
+            end
+          end
+
+          sanitized
+        end
+
+        def layers
+          language_model.model.layers
+        end
+
+        def make_cache
+          chunk_size = [@args.text_config.attention_chunk_size.to_i, 1].max
+          Array.new(layers.length) do |i|
+            if ((i + 1) % 4) != 0
+              MlxLm::ChunkedKVCache.new(chunk_size)
+            else
+              MlxLm::KVCache.new
+            end
+          end
+        end
+
+        private
+
+        def _pop_first(weights, keys)
+          keys.each do |key|
+            return weights.delete(key) if weights.key?(key)
+          end
+          nil
+        end
+
+        def _multimodal_key?(key)
+          key_name = key.to_s
+          key_name.include?("vision_model") ||
+            key_name.include?("vision_tower") ||
+            key_name.include?("multi_modal_projector")
+        end
+      end
+
+      Models.register("llama4", Model, ModelArgs)
+    end
+  end
+end
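
A note on the MoE class in llama4.rb: with num_experts_per_tok pinned to 1, the argpartition/take sequence is simply an argmax over the router logits, and the chosen expert's input is pre-scaled by the sigmoid of that logit before SwitchGLU dispatch. A plain-Ruby rendition of the per-token decision (illustrative only, no MLX arrays):

# Top-1 routing for one token, mirroring MoE#call.
logits = [0.2, -1.3, 2.1, 0.5]                       # router logits, 4 experts
expert = (0...logits.size).max_by { |i| logits[i] }  # => 2, the top-1 index
gate   = 1.0 / (1.0 + Math.exp(-logits[expert]))     # sigmoid gate, ~0.891
# x * gate is dispatched to expert 2 via SwitchGLU, and shared_expert(x)
# is added to the result.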
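
The chunked-attention mask built in LlamaModel#call is the least obvious part of llama4.rb: in three of every four layers, a query at position r may only see key positions l that fall in the same attention_chunk_size window and satisfy l <= r; every fourth layer uses the global causal mask. The predicate, restated in plain Ruby for hand-checking the array expression:

# Mirrors block_pos/token_pos in LlamaModel#call (illustrative only).
def chunked_attend?(l, r, chunk_size)
  (l / chunk_size) == (r / chunk_size) && l <= r
end

chunk = 4
(0..7).each { |r| puts (0..7).map { |l| chunked_attend?(l, r, chunk) ? "x" : "." }.join }
# Positions 0..3 and 4..7 form independent causal blocks:
# x.......
# xx......
# xxx.....
# xxxx....
# ....x...
# ....xx..
# ....xxx.
# ....xxxx

make_cache follows the same 3:1 pattern, pairing chunked layers with ChunkedKVCache so keys that fall out of the current window can be dropped via maybe_trim_front.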
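
Model#sanitize converts fused per-expert weights into the split, transposed layout SwitchGLU consumes. Assuming the fused checkpoint layout is [num_experts, hidden_size, 2 * intermediate_size] (an assumption about the upstream export, not stated in this diff), the shape arithmetic works out as:

# Hypothetical sizes: 16 experts, hidden_size 2048, intermediate_size 1024.
# gate_up_proj              : [16, 2048, 2048]   last axis = 2 * intermediate_size
# mx.split(..., [1024], -1) : gate, up each [16, 2048, 1024]
# mx.swapaxes(_, 1, 2)      : gate_proj.weight, up_proj.weight [16, 1024, 2048]
# down_proj                 : [16, 1024, 2048] -> swapaxes -> [16, 2048, 1024]

Vision-tower and projector weights are dropped up front, since this port only registers the text model.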