mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/phimoe.rb
@@ -0,0 +1,206 @@
+ require_relative "rope_utils"
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module PhiMoe
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "phimoe"
+         field :vocab_size, default: 32064
+         field :hidden_size, default: 4096
+         field :intermediate_size, default: 6400
+         field :num_hidden_layers, default: 32
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: 8
+         field :max_position_embeddings, default: 131072
+         field :original_max_position_embeddings, default: 4096
+         field :rms_norm_eps, default: 1e-6
+         field :rope_scaling, default: nil
+         field :num_local_experts, default: 16
+         field :num_experts_per_tok, default: 2
+         field :rope_theta, default: 10_000.0
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = dim / @n_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: true)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: true)
+
+           scaling = args.rope_scaling || {}
+           self.rope = SuScaledRoPE.new(
+             @head_dim,
+             base: args.rope_theta,
+             max_position_embeddings: args.max_position_embeddings,
+             original_max_position_embeddings: args.original_max_position_embeddings,
+             short_factor: _config_value(scaling, "short_factor", 1.0),
+             long_factor: _config_value(scaling, "long_factor", 1.0),
+             short_mscale: _config_value(scaling, "short_mscale"),
+             long_mscale: _config_value(scaling, "long_mscale")
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+
+         private
+
+         def _config_value(config, key, default = nil)
+           return default if config.nil?
+           return config[key] if config.key?(key)
+
+           config.fetch(key.to_sym, default)
+         end
+       end
+
+       class PhiMoESparseMoeBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           @hidden_dim = args.hidden_size
+           @ffn_dim = args.intermediate_size
+           @num_experts = args.num_local_experts
+           @top_k = args.num_experts_per_tok
+
+           self.gate = MLX::NN::Linear.new(@hidden_dim, @num_experts, bias: false)
+           self.switch_mlp = SwitchLayers::SwitchGLU.new(@hidden_dim, @ffn_dim, @num_experts)
+         end
+
+         def call(x)
+           mx = MLX::Core
+
+           k = [@top_k, @num_experts].min
+           gates = gate.call(x)
+           inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1))
+           take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+           scores = mx.take_along_axis(gates, inds, -1)
+           scores = mx.softmax(scores.astype(mx.float32), -1).astype(gates.dtype)
+
+           y = switch_mlp.call(x, inds)
+           mx.sum(y * mx.expand_dims(scores, -1), -2)
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.block_sparse_moe = PhiMoESparseMoeBlock.new(args)
+           self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           residual = x
+           hidden_states = input_layernorm.call(x)
+           hidden_states = self_attn.call(hidden_states, mask: mask, cache: cache)
+           hidden_states = residual + hidden_states
+
+           residual = hidden_states
+           hidden_states = post_attention_layernorm.call(hidden_states)
+           hidden_states = block_sparse_moe.call(hidden_states)
+           residual + hidden_states
+         end
+       end
+
+       class PhiMoEModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) }
+           self.norm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, layer_idx|
+             h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = PhiMoEModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: true)
+         end
+
+         def call(inputs, cache: nil)
+           lm_head.call(model.call(inputs, cache: cache))
+         end
+
+         def sanitize(weights)
+           return weights unless weights.key?("model.layers.0.block_sparse_moe.experts.0.w1.weight")
+
+           mx = MLX::Core
+           result = weights.dup
+
+           @args.num_hidden_layers.times do |layer_idx|
+             prefix = "model.layers.#{layer_idx}"
+             [["w1", "gate_proj"], ["w2", "down_proj"], ["w3", "up_proj"]].each do |source, target|
+               %w[weight scales biases].each do |param|
+                 first_key = "#{prefix}.block_sparse_moe.experts.0.#{source}.#{param}"
+                 next unless result.key?(first_key)
+
+                 expert_keys = (0...@args.num_local_experts).map do |expert_idx|
+                   "#{prefix}.block_sparse_moe.experts.#{expert_idx}.#{source}.#{param}"
+                 end
+                 next unless expert_keys.all? { |key| result.key?(key) }
+
+                 stacked = expert_keys.map { |key| result.delete(key) }
+                 result["#{prefix}.block_sparse_moe.switch_mlp.#{target}.#{param}"] = mx.stack(stacked)
+               end
+             end
+           end
+
+           result
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("phimoe", Model, ModelArgs)
+     end
+   end
+ end
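
Note on the routing in PhiMoESparseMoeBlock#call above: `argpartition` on the negated gate logits surfaces the indices of the k largest gates without a full sort, and the softmax then normalizes only those k selected scores, not all 16 experts. A minimal pure-Ruby sketch of the same selection (plain arrays stand in for MLX tensors, the logits are made-up values, and a full sort emulates argpartition):

    # Hypothetical gate logits for one token over 4 experts, top-2 routing.
    logits = [0.1, 2.3, -0.7, 1.9]
    k = 2

    # Indices of the k largest logits (the argpartition + take step).
    inds = logits.each_index.sort_by { |i| -logits[i] }.first(k)  # => [1, 3]

    # Softmax over only the selected scores (the take_along_axis + softmax step).
    selected = inds.map { |i| logits[i] }
    exps = selected.map { |s| Math.exp(s - selected.max) }
    scores = exps.map { |e| e / exps.sum }                        # => ~[0.60, 0.40]

    # The block then returns scores[0] * expert_1(x) + scores[1] * expert_3(x).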
data/lib/mlx_lm/models/phixtral.rb
@@ -0,0 +1,208 @@
+ require_relative "switch_layers"
+
+ module MlxLm
+   module Models
+     module Phixtral
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "phixtral"
+         field :num_vocab, default: 51_200
+         field :model_dim, default: 2_560
+         field :num_heads, default: 32
+         field :num_layers, default: 32
+         field :rotary_dim, default: 32
+         field :num_experts_per_tok, default: 2
+         field :num_local_experts, default: 4
+       end
+
+       class RoPEAttention < MLX::NN::Module
+         def initialize(dims, num_heads, rotary_dim)
+           super()
+           @num_heads = num_heads
+           @head_dim = dims / num_heads
+           @scale = @head_dim**(-0.5)
+
+           self.rope = MLX::NN::RoPE.new(rotary_dim, traditional: false)
+           self.wqkv = MLX::NN::Linear.new(dims, 3 * dims)
+           self.out_proj = MLX::NN::Linear.new(dims, dims)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, d = x.shape
+
+           qkv = wqkv.call(x)
+           queries, keys, values = mx.split(qkv, [d, 2 * d], -1)
+
+           queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           queries = queries.astype(mx.float32)
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask).astype(values.dtype)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, d])
+           out_proj.call(output)
+         end
+       end
+
+       class MOE < MLX::NN::Module
+         def initialize(args, dim, hidden_dim)
+           super()
+           @num_experts = args.num_local_experts
+           @num_experts_per_tok = args.num_experts_per_tok
+
+           self.switch_mlp = SwitchLayers::SwitchMLP.new(
+             dim,
+             hidden_dim,
+             @num_experts,
+             bias: true
+           )
+           self.gate = MLX::NN::Linear.new(args.model_dim, @num_experts, bias: false)
+         end
+
+         def call(x)
+           mx = MLX::Core
+           k = @num_experts_per_tok
+
+           gates = gate.call(x)
+           inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1))
+           take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+           inds = mx.take(inds, take_ids, -1)
+
+           scores = mx.take_along_axis(gates, inds, -1)
+           scores = mx.softmax(scores.astype(mx.float32), -1).astype(gates.dtype)
+
+           y = switch_mlp.call(x, inds)
+           mx.sum(y * mx.expand_dims(scores, -1), -2)
+         end
+       end
+
+       class ParallelBlock < MLX::NN::Module
+         def initialize(config)
+           super()
+           dims = config.model_dim
+           mlp_dims = dims * 4
+
+           self.mixer = RoPEAttention.new(dims, config.num_heads, config.rotary_dim)
+           self.ln = MLX::NN::LayerNorm.new(dims)
+           self.moe = MOE.new(config, dims, mlp_dims)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           h = ln.call(x)
+           attn_h = mixer.call(h, mask: mask, cache: cache)
+           ff_h = moe.call(h)
+           attn_h + ff_h + x
+         end
+       end
+
+       class Embd < MLX::NN::Module
+         def initialize(config)
+           super()
+           self.wte = MLX::NN::Embedding.new(config.num_vocab, config.model_dim)
+         end
+
+         def call(x)
+           wte.call(x)
+         end
+       end
+
+       class TransformerDecoder < MLX::NN::Module
+         def initialize(config)
+           super()
+           self.embd = Embd.new(config)
+           self.h = Array.new(config.num_layers) { ParallelBlock.new(config) }
+         end
+
+         def call(x, mask: nil, cache: nil)
+           hidden = embd.call(x)
+           layer_cache = cache || [nil] * h.length
+
+           h.each_with_index do |layer, i|
+             hidden = layer.call(hidden, mask: mask, cache: layer_cache[i])
+           end
+
+           hidden
+         end
+       end
+
+       class OutputHead < MLX::NN::Module
+         def initialize(config)
+           super()
+           self.ln = MLX::NN::LayerNorm.new(config.model_dim)
+           self.linear = MLX::NN::Linear.new(config.model_dim, config.num_vocab)
+         end
+
+         def call(inputs)
+           linear.call(ln.call(inputs))
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(config)
+           super()
+           @args = config
+
+           self.model_type = config.model_type
+           self.transformer = TransformerDecoder.new(config)
+           self.lm_head = OutputHead.new(config)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           local_mask = mask || _create_attention_mask(x, cache)
+           y = transformer.call(x, mask: local_mask, cache: cache)
+           lm_head.call(y)
+         end
+
+         def sanitize(weights)
+           first_key = "transformer.h.0.moe.mlp.0.fc1.weight"
+           return weights unless weights.key?(first_key)
+
+           mx = MLX::Core
+           result = weights.dup
+
+           @args.num_layers.times do |layer_idx|
+             prefix = "transformer.h.#{layer_idx}"
+             %w[fc1 fc2].each do |proj|
+               %w[weight scales biases bias].each do |suffix|
+                 expert_keys = (0...@args.num_local_experts).map do |expert_idx|
+                   "#{prefix}.moe.mlp.#{expert_idx}.#{proj}.#{suffix}"
+                 end
+                 next unless expert_keys.all? { |k| result.key?(k) }
+
+                 stacked = expert_keys.map { |k| result.delete(k) }
+                 result["#{prefix}.moe.switch_mlp.#{proj}.#{suffix}"] = mx.stack(stacked)
+               end
+             end
+           end
+
+           result
+         end
+
+         def layers
+           transformer.h
+         end
+
+         private
+
+         def _create_attention_mask(tokens, cache)
+           first_cache = cache.is_a?(Array) ? cache[0] : cache
+           return first_cache.make_mask(tokens.shape[1]) if first_cache && first_cache.respond_to?(:make_mask)
+           return nil if tokens.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       Models.register("phixtral", Model, ModelArgs)
+     end
+   end
+ end
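
Both sanitize methods above perform the same checkpoint rewrite: per-expert matrices stored under individual `experts.<i>` / `mlp.<i>` keys are stacked along a new leading axis into the single fused parameter that the switch layers consume. A toy illustration with plain Ruby arrays in place of MLX arrays and `mx.stack` (key names shortened, values made up):

    # Hypothetical checkpoint fragment: one layer, 2 experts.
    weights = {
      "h.0.moe.mlp.0.fc1.weight" => [[1, 2]],
      "h.0.moe.mlp.1.fc1.weight" => [[3, 4]],
    }

    # Collapse the per-expert entries into one stacked entry, as sanitize does;
    # collecting into an Array adds the leading expert axis that mx.stack would.
    expert_keys = (0...2).map { |i| "h.0.moe.mlp.#{i}.fc1.weight" }
    weights["h.0.moe.switch_mlp.fc1.weight"] = expert_keys.map { |k| weights.delete(k) }
    # => {"h.0.moe.switch_mlp.fc1.weight" => [[[1, 2]], [[3, 4]]]}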
data/lib/mlx_lm/models/pipeline.rb
@@ -0,0 +1,37 @@
+ module MlxLm
+   module Models
+     module PipelineMixin
+       attr_accessor :pipeline_rank, :pipeline_size, :start_idx, :end_idx
+
+       def initialize(*args, **kwargs)
+         super(*args, **kwargs)
+         @pipeline_rank = 0
+         @pipeline_size = 1
+         @start_idx = 0
+         @end_idx = nil
+       end
+
+       def pipeline_layers
+         layers[@start_idx...@end_idx]
+       end
+
+       def pipeline(group)
+         # Split layers in reverse so rank=0 gets the last layers and
+         # rank=pipeline_size-1 gets the first.
+         @pipeline_rank = group.rank
+         @pipeline_size = group.size
+         layers_per_rank = layers.length / @pipeline_size
+         extra = layers.length - (layers_per_rank * @pipeline_size)
+         layers_per_rank += 1 if @pipeline_rank < extra
+
+         @start_idx = (@pipeline_size - @pipeline_rank - 1) * layers_per_rank
+         @end_idx = @start_idx + layers_per_rank
+
+         self.layers = layers[0...@end_idx]
+         # Keep layer numbering stable for checkpoint loading.
+         self.layers[0...@start_idx] = Array.new(@start_idx, nil)
+         self
+       end
+     end
+   end
+ end
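
To make the reverse split in #pipeline concrete, here is the index arithmetic for 32 layers over a group of 4 (a standalone sketch; the Struct is a hypothetical stand-in for the distributed group object):

    Group = Struct.new(:rank, :size)

    total_layers = 32
    (0...4).each do |rank|
      group = Group.new(rank, 4)
      layers_per_rank = total_layers / group.size  # => 8, no remainder here
      start_idx = (group.size - group.rank - 1) * layers_per_rank
      end_idx = start_idx + layers_per_rank
      puts "rank #{rank}: layers #{start_idx}...#{end_idx}"
    end
    # rank 0: layers 24...32   (the last block, as the comment promises)
    # rank 1: layers 16...24
    # rank 2: layers 8...16
    # rank 3: layers 0...8     (the first block)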
data/lib/mlx_lm/models/pixtral.rb
@@ -0,0 +1,47 @@
+ module MlxLm
+   module Models
+     module Pixtral
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "pixtral"
+         field :text_config
+
+         def initialize(**kwargs)
+           super
+           @text_config ||= {}
+           @text_config["tie_word_embeddings"] = false
+           unless @text_config.key?("num_attention_heads") || @text_config.key?(:num_attention_heads)
+             @text_config["num_attention_heads"] = 32
+           end
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.language_model = Llama::Model.new(Llama::ModelArgs.from_dict(args.text_config))
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           language_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           weights.reject do |key, _|
+             key == "vision_tower" ||
+               key.start_with?("vision_tower.") ||
+               key == "multi_modal_projector" ||
+               key.start_with?("multi_modal_projector.")
+           end
+         end
+
+         def layers
+           language_model.model.layers
+         end
+       end
+
+       Models.register("pixtral", Model, ModelArgs)
+     end
+   end
+ end
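
Pixtral is handled text-only here: sanitize drops every vision weight so the remainder loads directly into the wrapped Llama model. A condensed illustration of the filter (hypothetical key names, prefix checks only):

    weights = {
      "vision_tower.patch_embed.weight"          => :vision,
      "multi_modal_projector.linear.weight"      => :vision,
      "language_model.model.embed_tokens.weight" => :text,
    }
    kept = weights.reject do |key, _|
      key.start_with?("vision_tower.", "multi_modal_projector.")
    end
    kept.keys  # => ["language_model.model.embed_tokens.weight"]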
data/lib/mlx_lm/models/plamo.rb
@@ -0,0 +1,169 @@
+ module MlxLm
+   module Models
+     module Plamo
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "plamo"
+         field :hidden_size
+         field :num_hidden_layers
+         field :intermediate_size
+         field :num_attention_heads
+         field :rms_norm_eps
+         field :vocab_size
+         field :n_shared_head, default: 8
+         field :rope_theta, default: 10_000.0
+         field :rope_traditional, default: false
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(config)
+           super()
+           @config = config
+           @hidden_size = config.hidden_size
+           @q_num_heads = config.num_attention_heads
+           @head_dim = @hidden_size / @q_num_heads
+           @qk_dim = @head_dim
+           @v_dim = @head_dim
+           @k_num_heads = (@q_num_heads.to_f / config.n_shared_head).ceil
+           @v_num_heads = @k_num_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(@hidden_size, @q_num_heads * @qk_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(@hidden_size, @k_num_heads * @qk_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(@hidden_size, @v_num_heads * @v_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@q_num_heads * @v_dim, @hidden_size, bias: false)
+           self.rotary_emb = MLX::NN::RoPE.new(
+             @head_dim,
+             traditional: config.rope_traditional,
+             base: config.rope_theta,
+             scale: 1.0
+           )
+         end
+
+         def call(hidden_states, attention_mask: nil, cache: nil)
+           mx = MLX::Core
+           bsz, q_len, _d = hidden_states.shape
+
+           queries = q_proj.call(hidden_states)
+           keys = k_proj.call(hidden_states)
+           values = v_proj.call(hidden_states)
+
+           queries = queries.reshape([bsz, q_len, @q_num_heads, @qk_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([bsz, q_len, @k_num_heads, @qk_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([bsz, q_len, @v_num_heads, @v_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rotary_emb.call(queries, offset: cache.offset)
+             keys = rotary_emb.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rotary_emb.call(queries)
+             keys = rotary_emb.call(keys)
+           end
+
+           keys = mx.tile(keys, [1, @config.n_shared_head, 1, 1])
+           values = mx.tile(values, [1, @config.n_shared_head, 1, 1])
+
+           output = mx.scaled_dot_product_attention(
+             queries,
+             keys,
+             values,
+             @scale,
+             attention_mask
+           )
+           output = output.transpose([0, 2, 1, 3]).reshape([bsz, q_len, @q_num_heads * @v_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(config)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(config.hidden_size, config.intermediate_size, bias: false)
+           self.up_proj = MLX::NN::Linear.new(config.hidden_size, config.intermediate_size, bias: false)
+           self.down_proj = MLX::NN::Linear.new(config.intermediate_size, config.hidden_size, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class DecoderLayer < MLX::NN::Module
+         def initialize(config)
+           super()
+           self.self_attn = Attention.new(config)
+           self.mlp = MLP.new(config)
+           self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+         end
+
+         def call(hidden_states, attention_mask: nil, cache: nil)
+           residual = hidden_states
+           hidden_states = norm.call(hidden_states)
+
+           hidden_states_sa = self_attn.call(
+             hidden_states,
+             attention_mask: attention_mask,
+             cache: cache
+           )
+           hidden_states_mlp = mlp.call(hidden_states)
+
+           residual + hidden_states_sa + hidden_states_mlp
+         end
+       end
+
+       class PlamoModel < MLX::NN::Module
+         def initialize(config)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size)
+           self.layers = Array.new(config.num_hidden_layers) { DecoderLayer.new(config) }
+           self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, attention_mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if h.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.model_type = args.model_type
+           self.model = PlamoModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           lm_head.call(out)
+         end
+
+         def sanitize(weights)
+           weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("plamo", Model, ModelArgs)
+     end
+   end
+ end
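
PLaMo's attention is grouped-query: every `n_shared_head` query heads share one key/value head, and `mx.tile` expands the KV heads back to the query head count before the attention call. The head arithmetic with the defaults above (plain Ruby, illustrative only):

    q_num_heads = 32    # config.num_attention_heads
    n_shared_head = 8   # config.n_shared_head default

    k_num_heads = (q_num_heads.to_f / n_shared_head).ceil  # => 4 KV heads
    # Keys leave k_proj with 4 heads ([B, 4, L, head_dim]); tiling by
    # n_shared_head along the head axis gives [B, 32, L, head_dim],
    # matching the 32 query heads.
    k_num_heads * n_shared_head  # => 32 == q_num_heads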