mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/gemma3_text.rb
@@ -0,0 +1,270 @@
+ require_relative "cache"
+ require_relative "rope_utils"
+
+ module MlxLm
+   module Models
+     module Gemma3Text
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "gemma3_text"
+         field :hidden_size, default: 1152
+         field :num_hidden_layers, default: 26
+         field :intermediate_size, default: 6912
+         field :num_attention_heads, default: 4
+         field :head_dim, default: 256
+         field :rms_norm_eps, default: 1.0e-6
+         field :vocab_size, default: 262144
+         field :num_key_value_heads, default: 1
+         field :rope_theta, default: 1_000_000.0
+         field :rope_local_base_freq, default: 10_000.0
+         field :query_pre_attn_scalar, default: 256.0
+         field :sliding_window, default: 512
+         field :sliding_window_pattern, default: 6
+         field :max_position_embeddings, default: 32768
+         field :rope_scaling, default: nil
+       end
+
+       class RMSNorm < MLX::NN::Module
+         def initialize(dims:, eps: 1e-6)
+           super()
+           self.weight = MLX::Core.ones([dims])
+           @eps = eps
+         end
+
+         def call(x)
+           mx = MLX::Core
+           x_sq = x * x
+           mean_sq = mx.mean(x_sq, -1, keepdims: true)
+           x * mx.rsqrt(mean_sq + @eps) * (1.0 + weight)
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = args.query_pre_attn_scalar**(-0.5)
+           pattern = [args.sliding_window_pattern.to_i, 1].max
+           @is_sliding = ((layer_idx + 1) % pattern) != 0
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+           self.q_norm = RMSNorm.new(dims: @head_dim, eps: args.rms_norm_eps)
+           self.k_norm = RMSNorm.new(dims: @head_dim, eps: args.rms_norm_eps)
+
+           if @is_sliding
+             self.rope = MlxLm::Models.initialize_rope(
+               @head_dim,
+               args.rope_local_base_freq,
+               false
+             )
+           else
+             self.rope = MlxLm::Models.initialize_rope(
+               @head_dim,
+               args.rope_theta,
+               false,
+               args.rope_scaling,
+               max_position_embeddings: args.max_position_embeddings
+             )
+           end
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           queries = q_norm.call(queries)
+           keys = k_norm.call(keys)
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(dim, hidden_dim)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           down_proj.call(MLX::NN.gelu_approx(gate_proj.call(x)) * up_proj.call(x))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+           self.self_attn = Attention.new(args, layer_idx)
+           self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+           self.input_layernorm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps)
+           self.pre_feedforward_layernorm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps)
+           self.post_feedforward_layernorm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = clip_residual(x, post_attention_layernorm.call(r))
+           r = mlp.call(pre_feedforward_layernorm.call(h))
+           clip_residual(h, post_feedforward_layernorm.call(r))
+         end
+
+         private
+
+         def clip_residual(x, y)
+           mx = MLX::Core
+           return x + y unless x.dtype == mx.float16
+
+           bound = mx.finfo(mx.float16).max
+           mx.clip(
+             x.astype(mx.float32) + y.astype(mx.float32),
+             -bound,
+             bound
+           ).astype(mx.float16)
+         end
+       end
+
+       class Gemma3Model < MLX::NN::Module
+         attr_reader :sliding_window_pattern
+
+         def initialize(args)
+           super()
+           @args = args
+           @window_size = args.sliding_window
+           @sliding_window_pattern = [args.sliding_window_pattern.to_i, 1].max
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) do |layer_idx|
+             TransformerBlock.new(args, layer_idx)
+           end
+           self.norm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           h = input_embeddings || embed_tokens.call(inputs)
+           h = h * Math.sqrt(@args.hidden_size)
+           layer_cache = cache || [nil] * layers.length
+
+           global_idx = sliding_window_pattern - 1
+           global_mask = _create_attention_mask(h, layer_cache[global_idx])
+           sliding_window_mask = if sliding_window_pattern > 1
+             _create_attention_mask(h, layer_cache[0], window_size: @window_size)
+           else
+             nil
+           end
+
+           layers.each_with_index do |layer, i|
+             is_global = (i % sliding_window_pattern) == (sliding_window_pattern - 1)
+             mask = is_global ? global_mask : sliding_window_mask
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil, window_size: nil)
+           n = h.shape[1]
+           return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
+
+           if window_size
+             offset = cache ? cache.offset : 0
+             if cache && cache.instance_variable_defined?(:@max_size)
+               max_size = cache.instance_variable_get(:@max_size)
+               offset = [max_size - 1, offset].min if max_size && max_size > 0
+             end
+             return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
+           end
+
+           return nil if n == 1
+
+           "causal"
+         end
+
+         def _create_causal_mask(n, offset: 0, window_size: nil)
+           mx = MLX::Core
+           rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+           linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+           mask = mx.greater_equal(linds, rinds)
+           if window_size
+             mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+           end
+           mask
+         end
+       end
+
+       class Model < MLX::NN::Module
+         attr_reader :args
+
+         def initialize(args)
+           super()
+           @args = args
+           @tie_word_embeddings = false
+           self.model_type = args.model_type
+           self.model = Gemma3Model.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           out = model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+           if @tie_word_embeddings || lm_head.nil?
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           sanitized = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+           unless sanitized.key?("lm_head.weight")
+             @tie_word_embeddings = true
+             self.lm_head = nil
+           end
+           sanitized
+         end
+
+         def layers
+           model.layers
+         end
+
+         def make_cache
+           pattern = [@args.sliding_window_pattern.to_i, 1].max
+           max_size = @args.sliding_window || @args.max_position_embeddings || 1
+           Array.new(@args.num_hidden_layers) do |i|
+             is_global = (i % pattern) == (pattern - 1)
+             if is_global
+               MlxLm::KVCache.new
+             else
+               MlxLm::RotatingKVCache.new(max_size: max_size, keep: 0)
+             end
+           end
+         end
+       end
+
+       Models.register("gemma3_text", Model, ModelArgs)
+     end
+   end
+ end
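
Note on gemma3_text: the model interleaves local and global attention. With the default sliding_window_pattern of 6, every sixth layer (indices 5, 11, 17, ...) is a global layer using rope_theta = 1_000_000 and a plain KVCache, while the remaining layers attend within a 512-token sliding window using rope_local_base_freq = 10_000 and a RotatingKVCache; clip_residual additionally guards float16 residual additions against overflow by accumulating in float32 and clipping back. The schedule logic, as a standalone plain-Ruby sketch (the helper name is illustrative, not part of the gem):

  # Reproduces the layer schedule used by the per-layer mask selection
  # in Gemma3Model#call and by Model#make_cache above.
  def gemma3_layer_schedule(num_layers, pattern)
    pattern = [pattern.to_i, 1].max
    Array.new(num_layers) do |i|
      # A layer is global exactly when it closes a pattern-sized group.
      (i % pattern) == (pattern - 1) ? :global : :sliding
    end
  end

  schedule = gemma3_layer_schedule(26, 6)
  p schedule.each_index.select { |i| schedule[i] == :global }
  # => [5, 11, 17, 23] -- all other layers use the 512-token window
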
data/lib/mlx_lm/models/gemma3n.rb
@@ -0,0 +1,79 @@
+ require_relative "gemma2"
+
+ module MlxLm
+   module Models
+     module Gemma3n
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "gemma3n"
+         field :text_config, default: nil
+
+         def self.from_dict(params)
+           has_text_config = params.key?("text_config") || params.key?(:text_config)
+           return super if has_text_config
+
+           new(model_type: params["model_type"] || params[:model_type], text_config: params)
+         end
+
+         def initialize(**kwargs)
+           super
+           @text_config = (@text_config || {}).dup
+         end
+       end
+
+       class Model < MLX::NN::Module
+         MULTIMODAL_MODEL_PREFIXES = %w[
+           model.vision_tower
+           model.audio_tower
+           model.embed_audio
+           model.embed_vision
+         ].freeze
+
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.language_model = Gemma2::Model.new(Gemma2::ModelArgs.from_dict(_text_config_for_gemma2(args)))
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           supports_input_embeddings = language_model.method(:call).parameters.any? do |_, name|
+             name == :input_embeddings
+           end
+
+           if supports_input_embeddings
+             language_model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+           else
+             language_model.call(inputs, cache: cache)
+           end
+         end
+
+         def sanitize(weights)
+           weights.reject do |key, _|
+             MULTIMODAL_MODEL_PREFIXES.any? { |prefix| key == prefix || key.start_with?("#{prefix}.") }
+           end
+         end
+
+         def layers
+           language_model.layers
+         end
+
+         def make_cache
+           return language_model.make_cache if language_model.respond_to?(:make_cache)
+
+           nil
+         end
+
+         private
+
+         def _text_config_for_gemma2(args)
+           config = {}
+           (args.text_config || {}).each { |key, value| config[key.to_s] = value }
+           config["model_type"] ||= args.model_type
+           config
+         end
+       end
+
+       Models.register("gemma3n", Model, ModelArgs)
+     end
+   end
+ end
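
Note on gemma3n: this wrapper reuses the Gemma2 text model as its language tower and, in sanitize, drops every weight belonging to the vision/audio towers so only language-model weights are loaded. A standalone plain-Ruby sketch of that prefix filter (the weight keys here are hypothetical, for illustration only):

  prefixes = %w[model.vision_tower model.audio_tower model.embed_audio model.embed_vision]
  weights = {
    "model.vision_tower.patch_embed.weight" => :vision,          # hypothetical key
    "model.language_model.layers.0.mlp.up_proj.weight" => :text  # hypothetical key
  }
  # Reject a key if it equals a prefix or sits anywhere under it.
  kept = weights.reject do |key, _|
    prefixes.any? { |p| key == p || key.start_with?("#{p}.") }
  end
  p kept.keys  # => ["model.language_model.layers.0.mlp.up_proj.weight"]
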
data/lib/mlx_lm/models/glm.rb
@@ -0,0 +1,164 @@
+ module MlxLm
+   module Models
+     module GLM
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "glm"
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 28
+         field :intermediate_size, default: 13696
+         field :num_attention_heads, default: 32
+         field :rms_norm_eps, default: 1e-5
+         field :vocab_size, default: 151552
+         field :head_dim, default: nil
+         field :num_key_value_heads, default: nil
+         field :max_position_embeddings, default: nil
+         field :attention_bias, default: false
+         field :rope_theta, default: 10_000.0
+         field :tie_word_embeddings, default: true
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           bias = args.attention_bias
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+           self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.gate_up_proj = MLX::NN::Linear.new(
+             args.hidden_size,
+             2 * args.intermediate_size,
+             bias: false
+           )
+           self.down_proj = MLX::NN::Linear.new(
+             args.intermediate_size,
+             args.hidden_size,
+             bias: false
+           )
+         end
+
+         def call(x)
+           mx = MLX::Core
+           x = gate_up_proj.call(x)
+           split_dim = x.shape[-1] / 2
+           gate, up = mx.split(x, [split_dim], -1)
+           down_proj.call(Activations.swiglu(gate, up))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class GLMModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           @model_type = args.model_type
+           self.model = GLMModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("glm", Model, ModelArgs)
+     end
+   end
+ end
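
Note on glm: the MLP fuses the gate and up projections into one gate_up_proj of width 2 * intermediate_size, splits the activation in half along the last axis, and applies SwiGLU via Activations.swiglu (assumed here to compute silu(gate) * up, the usual definition). The split-and-gate arithmetic on a plain Ruby array, as a minimal standalone sketch:

  # Endless method defining SiLU (sigmoid-weighted linear unit); Ruby 3.0+.
  def silu(v) = v / (1.0 + Math.exp(-v))

  fused = [0.5, -1.0, 2.0, 3.0]        # one row of gate_up_proj output, width 2 * d
  d = fused.length / 2
  gate, up = fused[0...d], fused[d..]  # the same halves mx.split produces
  p gate.each_index.map { |i| silu(gate[i]) * up[i] }
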
data/lib/mlx_lm/models/glm4.rb
@@ -0,0 +1,180 @@
+ module MlxLm
+   module Models
+     module GLM4
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "glm4"
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 40
+         field :intermediate_size, default: 13696
+         field :num_attention_heads, default: 32
+         field :attention_bias, default: false
+         field :head_dim, default: nil
+         field :rms_norm_eps, default: 1e-5
+         field :vocab_size, default: 151552
+         field :num_key_value_heads, default: nil
+         field :partial_rotary_factor, default: 0.5
+         field :rope_theta, default: 10_000.0
+         field :rope_traditional, default: true
+         field :max_position_embeddings, default: 32768
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       class GLM4MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.gate_up_proj = MLX::NN::Linear.new(
+             args.hidden_size,
+             2 * args.intermediate_size,
+             bias: false
+           )
+           self.down_proj = MLX::NN::Linear.new(
+             args.intermediate_size,
+             args.hidden_size,
+             bias: false
+           )
+         end
+
+         def call(x)
+           mx = MLX::Core
+           x = gate_up_proj.call(x)
+           split_dim = x.shape[-1] / 2
+           gate, up_states = mx.split(x, [split_dim], -1)
+           down_proj.call(Activations.swiglu(gate, up_states))
+         end
+       end
+
+       class GLM4Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @head_dim = args.head_dim
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(
+             dim,
+             args.num_attention_heads * @head_dim,
+             bias: args.attention_bias
+           )
+           self.k_proj = MLX::NN::Linear.new(
+             dim,
+             args.num_key_value_heads * @head_dim,
+             bias: args.attention_bias
+           )
+           self.v_proj = MLX::NN::Linear.new(
+             dim,
+             args.num_key_value_heads * @head_dim,
+             bias: args.attention_bias
+           )
+           self.o_proj = MLX::NN::Linear.new(
+             args.num_attention_heads * @head_dim,
+             args.hidden_size,
+             bias: false
+           )
+
+           self.rope = MLX::NN::RoPE.new(
+             (args.partial_rotary_factor * @head_dim).to_i,
+             base: args.rope_theta,
+             traditional: args.rope_traditional
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class GLM4DecoderLayer < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = GLM4Attention.new(args)
+           self.mlp = GLM4MLP.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_self_attn_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_mlp_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           x = x + post_self_attn_layernorm.call(
+             self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           )
+           residual = x
+           post_mlp_layernorm.call(mlp.call(post_attention_layernorm.call(x))) + residual
+         end
+       end
+
+       class GLM4Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { GLM4DecoderLayer.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = GLM4Model.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           lm_head.call(out)
+         end
+
+         def sanitize(weights)
+           weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("glm4", Model, ModelArgs)
+     end
+   end
+ end
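
Note on glm4: unlike glm, GLM4 rotates only part of each head (partial_rotary_factor defaults to 0.5, so only half of each head's dimensions pass through RoPE while the rest are left unrotated) and wraps both sub-blocks in "sandwich" norms, applying post_self_attn_layernorm and post_mlp_layernorm to each branch output before the residual add. The rotary-dimension arithmetic from the defaults above, standalone:

  hidden_size = 4096
  num_attention_heads = 32
  partial_rotary_factor = 0.5

  head_dim = hidden_size / num_attention_heads           # => 128
  rotary_dims = (partial_rotary_factor * head_dim).to_i  # => 64 dims actually rotated
  puts "RoPE rotates #{rotary_dims} of #{head_dim} dims per head"
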