mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/qwen3_vl_moe.rb
@@ -0,0 +1,92 @@
+ require_relative "qwen3_moe"
+
+ module MlxLm
+   module Models
+     module Qwen3VLMoe
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "qwen3_vl_moe"
+         field :text_config, default: nil
+
+         def self.from_dict(params)
+           has_text_config = params.key?("text_config") || params.key?(:text_config)
+           return super if has_text_config
+
+           new(model_type: params["model_type"] || params[:model_type], text_config: params)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.language_model = Qwen3Moe::Model.new(Qwen3Moe::ModelArgs.from_dict(args.text_config))
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           language_model.call(inputs, cache: cache, input_embeddings: input_embeddings)
+         end
+
+         def sanitize(weights)
+           nested = MLX::Utils.tree_unflatten(weights.to_a)
+           nested.delete("visual") if nested.is_a?(Hash)
+
+           language_model_tree = {}
+           if nested.is_a?(Hash)
+             language_model_node = nested["language_model"]
+             if language_model_node.is_a?(Hash)
+               language_model_tree["model"] = language_model_node["model"] if language_model_node.key?("model")
+               language_model_tree["lm_head"] = language_model_node["lm_head"] if language_model_node.key?("lm_head")
+             end
+           end
+
+           flattened = MLX::Utils.tree_flatten({ "language_model" => language_model_tree }, destination: {})
+           sanitized = flattened.is_a?(Hash) ? flattened : {}
+           rewrite_moe_expert_weights(sanitized)
+           sanitized
+         end
+
+         def layers
+           language_model.model.layers
+         end
+
+         private
+
+         def rewrite_moe_expert_weights(weights)
+           mx = MLX::Core
+
+           layers.length.times do |layer_idx|
+             prefix = "language_model.model.layers.#{layer_idx}.mlp"
+             gate_up_key = _first_existing_key(
+               weights,
+               ["#{prefix}.experts.gate_up_proj", "#{prefix}.experts.gate_up_proj.weight"]
+             )
+             down_proj_key = _first_existing_key(
+               weights,
+               ["#{prefix}.experts.down_proj", "#{prefix}.experts.down_proj.weight"]
+             )
+
+             next unless gate_up_key && down_proj_key
+
+             gate_up = weights.delete(gate_up_key)
+             down_proj = weights.delete(down_proj_key)
+             mid = gate_up.shape[-1] / 2
+             gate_proj, up_proj = mx.split(gate_up, [mid], -1)
+
+             weights["#{prefix}.switch_mlp.gate_proj.weight"] = mx.swapaxes(gate_proj, -2, -1)
+             weights["#{prefix}.switch_mlp.up_proj.weight"] = mx.swapaxes(up_proj, -2, -1)
+             weights["#{prefix}.switch_mlp.down_proj.weight"] = mx.swapaxes(down_proj, -2, -1)
+           end
+
+           weights
+         end
+
+         def _first_existing_key(weights, candidates)
+           candidates.find { |key| weights.key?(key) }
+         end
+       end
+
+       Models.register("qwen3_vl_moe", Model, ModelArgs)
+     end
+   end
+ end
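
Note on the weight rewriting above: rewrite_moe_expert_weights splits the fused expert tensor gate_up_proj in half along its last axis into separate gate and up projections, then swaps the last two axes so each half lands in the switch_mlp "*.weight" layout. Below is an illustrative plain-Ruby sketch of that shape handling; it is not part of the gem, uses toy 2-D nested arrays instead of the real 3-D MLX tensors (which carry a leading expert axis), and all names in it are made up for the example.

# Plain-Ruby sketch (illustrative only, not part of the gem): mirrors the shape
# handling in rewrite_moe_expert_weights on a toy 2-D matrix.
gate_up = [
  [1, 2, 3, 4],  # row 0: [gate_0, gate_1, up_0, up_1]
  [5, 6, 7, 8]   # row 1
]
mid = gate_up.first.length / 2             # split point on the last axis

gate = gate_up.map { |row| row[0...mid] }  # => [[1, 2], [5, 6]]
up   = gate_up.map { |row| row[mid..-1] }  # => [[3, 4], [7, 8]]

# mx.swapaxes(..., -2, -1) on a 2-D matrix is an ordinary transpose
gate_proj_weight = gate.transpose          # => [[1, 5], [2, 6]]
up_proj_weight   = up.transpose            # => [[3, 7], [4, 8]]

p gate_proj_weight
p up_proj_weight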
data/lib/mlx_lm/models/recurrent_gemma.rb
@@ -0,0 +1,444 @@
+ require_relative "cache"
+
+ module MlxLm
+   module Models
+     module RecurrentGemma
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "recurrent_gemma"
+         field :attention_bias
+         field :conv1d_width
+         field :hidden_size
+         field :intermediate_size
+         field :logits_soft_cap
+         field :num_attention_heads
+         field :num_hidden_layers
+         field :num_key_value_heads
+         field :rms_norm_eps
+         field :rope_theta
+         field :attention_window_size
+         field :vocab_size
+         field :embeddings_scale_by_sqrt_dim, default: true
+         field :block_types, default: nil
+         field :_block_types, default: nil
+
+         def initialize(**kwargs)
+           super
+           @block_types ||= @_block_types
+           @block_types ||= ["recurrent", "attention"]
+         end
+       end
+
+       class RMSNorm < MLX::NN::Module
+         def initialize(dims, eps: 1e-5)
+           super()
+           self.weight = MLX::Core.ones([dims])
+           @eps = eps
+         end
+
+         def call(x)
+           mx = MLX::Core
+           mean_sq = mx.mean(x * x, -1, keepdims: true)
+           norm = x * mx.rsqrt(mean_sq + @eps)
+           norm * (weight + 1.0)
+         end
+       end
+
+       class RGLRU < MLX::NN::Module
+         def initialize(width:, num_heads:)
+           super()
+           @width = width
+           @num_heads = num_heads
+           @head_dim = @width / @num_heads
+
+           mx = MLX::Core
+           self.recurrent_param = mx.zeros([@width])
+           self.input_gate_weight = mx.zeros([@num_heads, @head_dim, @head_dim])
+           self.input_gate_bias = mx.zeros([@num_heads, @head_dim])
+           self.recurrent_gate_weight = mx.zeros([@num_heads, @head_dim, @head_dim])
+           self.recurrent_gate_bias = mx.zeros([@num_heads, @head_dim])
+         end
+
+         def call(x, cache: nil)
+           mx = MLX::Core
+           b, l, _ = x.shape
+
+           gate_x = _apply_block_linear(x, input_gate_weight, input_gate_bias, batch: b, seq: l)
+           gate_a = _apply_block_linear(x, recurrent_gate_weight, recurrent_gate_bias, batch: b, seq: l)
+
+           log_a = -8.0 * gate_a * MLX::NN.softplus(recurrent_param)
+           a = mx.exp(log_a)
+           a_square = mx.exp(2.0 * log_a)
+
+           gated_x = x * gate_x
+           multiplier = mx.sqrt(1.0 - a_square)
+           if cache.nil?
+             first = mx.ones([b, 1, @width], multiplier.dtype)
+             if l == 1
+               multiplier = first
+             else
+               rest = mx.split(multiplier, [1], 1)[1]
+               multiplier = mx.concatenate([first, rest], 1)
+             end
+           end
+
+           normalized_x = gated_x * multiplier.astype(x.dtype)
+           _rnn_scan(normalized_x, a, cache)
+         end
+
+         private
+
+         def _apply_block_linear(h, w, b, batch:, seq:)
+           mx = MLX::Core
+           h = h.reshape([batch, seq, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+           h = mx.matmul(h, w).transpose([0, 2, 1, 3]) + b
+           mx.sigmoid(h.reshape([batch, seq, @width]))
+         end
+
+         def _rnn_scan(x, a, h0)
+           mx = MLX::Core
+           b, l, d = x.shape
+
+           if l == 1
+             if h0.nil?
+               return x, _slice_step(x, 0)
+             end
+
+             y = a * mx.expand_dims(h0, 1) + x
+             return y, _slice_step(y, 0)
+           end
+
+           h_t = h0 || mx.zeros([b, d], x.dtype)
+           ys = []
+           l.times do |t|
+             h_t = _slice_step(a, t) * h_t + _slice_step(x, t)
+             ys << h_t
+           end
+           [mx.stack(ys, 1), h_t]
+         end
+
+         def _slice_step(array, idx)
+           mx = MLX::Core
+           idx_arr = mx.array([idx], dtype: mx.int32)
+           mx.squeeze(mx.take(array, idx_arr, 1), 1)
+         end
+       end
+
+       class RecurrentBlock < MLX::NN::Module
+         def initialize(width:, num_heads:, lru_width: nil, conv1d_temporal_width: 4)
+           super()
+           @width = width
+           @num_heads = num_heads
+           @lru_width = lru_width || width
+           @conv1d_temporal_width = conv1d_temporal_width
+
+           self.linear_y = MLX::NN::Linear.new(width, @lru_width)
+           self.linear_x = MLX::NN::Linear.new(width, @lru_width)
+           self.linear_out = MLX::NN::Linear.new(@lru_width, width)
+           self.conv_1d = MLX::NN::Conv1d.new(
+             @lru_width,
+             @lru_width,
+             @conv1d_temporal_width,
+             groups: @lru_width,
+             bias: true,
+             padding: 0
+           )
+           self.rg_lru = RGLRU.new(width: @lru_width, num_heads: @num_heads)
+         end
+
+         def call(x, cache: nil, mask: nil)
+           _ = mask
+           mx = MLX::Core
+
+           y = MLX::NN.gelu_approx(linear_y.call(x))
+           x = linear_x.call(x)
+
+           conv_cache = _read_cache(cache, 0)
+           rnn_cache = _read_cache(cache, 1)
+
+           x = if conv_cache
+             mx.concatenate([conv_cache, x], 1)
+           else
+             mx.pad(x, [[0, 0], [@conv1d_temporal_width - 1, 0], [0, 0]])
+           end
+
+           conv_input = x
+           x = conv_1d.call(x)
+           _write_cache(cache, 0, _tail_cache(conv_input))
+
+           x, last_h = rg_lru.call(x, cache: rnn_cache)
+           _write_cache(cache, 1, last_h)
+
+           linear_out.call(x * y)
+         end
+
+         private
+
+         def _tail_cache(full_x)
+           mx = MLX::Core
+           n_keep = @conv1d_temporal_width - 1
+           return mx.zeros([full_x.shape[0], 0, full_x.shape[2]], full_x.dtype) if n_keep <= 0
+
+           split_at = full_x.shape[1] - n_keep
+           mx.split(full_x, [split_at], 1)[1]
+         end
+
+         def _read_cache(cache, idx)
+           if cache.is_a?(MlxLm::ArraysCache) || cache.is_a?(Array)
+             cache[idx]
+           else
+             nil
+           end
+         end
+
+         def _write_cache(cache, idx, value)
+           return unless cache.is_a?(MlxLm::ArraysCache) || cache.is_a?(Array)
+
+           cache[idx] = value
+         end
+       end
+
+       class LocalAttentionBlock < MLX::NN::Module
+         def initialize(width:, num_heads:, window_size:)
+           super()
+           @width = width
+           @num_heads = num_heads
+           @window_size = window_size
+           @scale = (width / num_heads)**(-0.5)
+           @head_dim = @width / @num_heads
+
+           self.q_proj = MLX::NN::Linear.new(@width, @width, bias: false)
+           self.k_proj = MLX::NN::Linear.new(@width, @head_dim, bias: false)
+           self.v_proj = MLX::NN::Linear.new(@width, @head_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@width, @width, bias: true)
+           self.rope = MLX::NN::RoPE.new(@head_dim / 2, traditional: false)
+         end
+
+         def call(x, cache: nil, mask: nil)
+           mx = MLX::Core
+           b, l, _ = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, 1, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, 1, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @width])
+           o_proj.call(output)
+         end
+       end
+
+       class MLPBlock < MLX::NN::Module
+         def initialize(width:, expanded_width:)
+           super()
+           hidden = expanded_width / 2
+           self.up_proj = MLX::NN::Linear.new(width, hidden)
+           self.gate_proj = MLX::NN::Linear.new(width, hidden)
+           self.down_proj = MLX::NN::Linear.new(hidden, width)
+         end
+
+         def call(x)
+           down_proj.call(MLX::NN.gelu_approx(gate_proj.call(x)) * up_proj.call(x))
+         end
+       end
+
+       class ResidualBlock < MLX::NN::Module
+         attr_reader :temporal_block_type
+
+         def initialize(
+           width:,
+           mlp_expanded_width:,
+           num_heads:,
+           attention_window_size:,
+           temporal_block_type:,
+           lru_width: nil,
+           conv1d_temporal_width: 4
+         )
+           super()
+           @temporal_block_type = temporal_block_type
+
+           self.temporal_pre_norm = RMSNorm.new(width)
+           self.temporal_block = if temporal_block_type == "recurrent"
+             RecurrentBlock.new(
+               width: width,
+               num_heads: num_heads,
+               lru_width: lru_width,
+               conv1d_temporal_width: conv1d_temporal_width
+             )
+           else
+             LocalAttentionBlock.new(
+               width: width,
+               num_heads: num_heads,
+               window_size: attention_window_size
+             )
+           end
+
+           self.channel_pre_norm = RMSNorm.new(width)
+           self.mlp_block = MLPBlock.new(width: width, expanded_width: mlp_expanded_width)
+         end
+
+         def call(x, cache: nil, mask: nil)
+           raw_x = x
+           x = temporal_block.call(temporal_pre_norm.call(raw_x), cache: cache, mask: mask)
+           residual = x + raw_x
+           x = mlp_block.call(channel_pre_norm.call(residual))
+           x + residual
+         end
+       end
+
+       class Griffin < MLX::NN::Module
+         attr_reader :window_size, :swa_idx
+
+         def initialize(config)
+           super()
+           @config = config
+           @scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
+
+           block_types = Array(config.block_types)
+           block_types = ["recurrent"] if block_types.empty?
+
+           self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size)
+           self.layers = Array.new(config.num_hidden_layers) do |i|
+             ResidualBlock.new(
+               width: config.hidden_size,
+               mlp_expanded_width: config.intermediate_size,
+               num_heads: config.num_attention_heads,
+               attention_window_size: config.attention_window_size,
+               temporal_block_type: block_types[i % block_types.length],
+               lru_width: nil,
+               conv1d_temporal_width: config.conv1d_width
+             )
+           end
+           self.final_norm = RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+
+           @window_size = config.attention_window_size
+           @swa_idx = block_types.index("attention") || 0
+         end
+
+         def call(tokens, cache: nil)
+           x = embed_tokens.call(tokens)
+           x = x * Math.sqrt(x.shape[-1]) if @scale_by_sqrt_dim
+
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(x, layer_cache[@swa_idx], window_size: @window_size)
+
+           layers.each_with_index do |block, i|
+             x = block.call(x, mask: mask, cache: layer_cache[i])
+           end
+
+           final_norm.call(x)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil, window_size: nil)
+           n = h.shape[1]
+           if cache && cache.respond_to?(:make_mask)
+             return cache.make_mask(n, window_size: window_size)
+           end
+
+           if window_size
+             offset = 0
+             if cache
+               offset = cache.offset
+               if cache.instance_variable_defined?(:@max_size)
+                 max_size = cache.instance_variable_get(:@max_size)
+                 offset = [max_size - 1, offset].min if max_size && max_size > 0
+               end
+             end
+             return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
+           end
+           return nil if n == 1
+
+           "causal"
+         end
+
+         def _create_causal_mask(n, offset: 0, window_size: nil)
+           mx = MLX::Core
+           rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+           linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+           mask = mx.greater_equal(linds, rinds)
+           if window_size
+             mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+           end
+           mask
+         end
+       end
+
+       class Model < MLX::NN::Module
+         attr_reader :args
+
+         def initialize(config)
+           super()
+           @args = config
+           @tie_word_embeddings = false
+           self.model_type = config.model_type
+           self.model = Griffin.new(config)
+           self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
+         end
+
+         def call(tokens, cache: nil)
+           mx = MLX::Core
+           logits = model.call(tokens, cache: cache)
+           logits = if @tie_word_embeddings || lm_head.nil?
+             model.embed_tokens.as_linear(logits)
+           else
+             lm_head.call(logits)
+           end
+
+           c = args.logits_soft_cap
+           logits = mx.tanh(logits / c) * c if c && c != 0
+           logits
+         end
+
+         def layers
+           model.layers
+         end
+
+         def sanitize(weights)
+           mx = MLX::Core
+           sanitized = {}
+           weights.each do |key, value|
+             current = value
+             if key.include?("conv_1d.weight") && value.shape[-1] != 1
+               current = mx.swapaxes(value, 1, 2)
+             end
+             sanitized[key] = current
+           end
+
+           unless sanitized.key?("lm_head.weight")
+             @tie_word_embeddings = true
+             self.lm_head = nil
+           end
+
+           sanitized
+         end
+
+         def make_cache
+           layers.map do |layer|
+             if layer.temporal_block_type == "recurrent"
+               MlxLm::ArraysCache.new(2)
+             else
+               MlxLm::RotatingKVCache.new(max_size: args.attention_window_size)
+             end
+           end
+         end
+       end
+
+       Models.register("recurrent_gemma", Model, ModelArgs)
+     end
+   end
+ end
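
For orientation on the recurrent path above: the core of RGLRU#_rnn_scan is a per-feature linear recurrence h_t = a_t * h_{t-1} + x_t, whose final state is what gets written back into the recurrent slot of the layer cache. The sketch below is plain Ruby with scalars per step, illustrative only; the helper name rnn_scan is made up for the example and is not part of the gem.

# Plain-Ruby sketch (illustrative only): the linear recurrence evaluated by
# _rnn_scan, one feature at a time: h_t = a_t * h_{t-1} + x_t. It returns the
# per-step outputs plus the final state that becomes the cache entry.
def rnn_scan(xs, coeffs, h0 = 0.0)
  h = h0
  ys = xs.zip(coeffs).map do |x, a|
    h = a * h + x
    h
  end
  [ys, h]
end

ys, last_h = rnn_scan([1.0, 0.5, 0.25], [0.9, 0.9, 0.9])
p ys      # => [1.0, 1.4, 1.51]
p last_h  # => 1.51

Similarly, Griffin#_create_causal_mask builds a causal mask restricted to a sliding window: query position i may attend to key position j only when j <= i and i < j + window_size. The boolean sketch below is again illustrative plain Ruby, not the gem's API, and the helper name windowed_causal_mask is made up for the example.

# Plain-Ruby sketch (illustrative only): the windowed causal predicate used by
# _create_causal_mask, expressed as booleans instead of MLX int32 comparisons.
def windowed_causal_mask(n, offset: 0, window_size: nil)
  cols = (0...(offset + n)).to_a       # key positions
  rows = (offset...(offset + n)).to_a  # query positions
  rows.map do |i|
    cols.map do |j|
      causal    = i >= j
      in_window = window_size.nil? || i < j + window_size
      causal && in_window
    end
  end
end

pp windowed_causal_mask(3, window_size: 2)
# => [[true, false, false],
#     [true, true, false],
#     [false, true, true]]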