mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/rope_utils.rb
@@ -0,0 +1,316 @@
+ module MlxLm
+   module Models
+     class SuScaledRoPE < MLX::NN::Module
+       def initialize(
+         dims,
+         base: 10_000.0,
+         max_position_embeddings: 131_072,
+         original_max_position_embeddings: 4096,
+         short_factor: 1.0,
+         long_factor: 1.0,
+         short_mscale: nil,
+         long_mscale: nil
+       )
+         super()
+         mx = MLX::Core
+         @dim = dims
+         @original_max_position_embeddings = original_max_position_embeddings
+
+         freqs = mx.power(
+           base.to_f,
+           mx.divide(mx.arange(0, dims, 2, mx.float32), dims.to_f)
+         )
+         self._freqs = mx.multiply(mx.array(long_factor, dtype: mx.float32), freqs)
+
+         factor = max_position_embeddings.to_f / original_max_position_embeddings
+         self._scale = long_mscale || if factor <= 1.0
+           1.0
+         else
+           Math.sqrt(1 + Math.log(factor) / Math.log(original_max_position_embeddings))
+         end
+       end
+
+       def call(x, offset: 0)
+         mx = MLX::Core
+         x = scale_rotary_part(x, _scale)
+         mx.rope(x, @dim, false, nil, 1.0, offset, _freqs)
+       end
+
+       private
+
+       def scale_rotary_part(x, scale)
+         return x if scale == 1.0
+
+         mx = MLX::Core
+         rotary, rest = mx.split(x, [@dim], -1)
+         mx.concatenate([mx.multiply(rotary, scale), rest], -1)
+       end
+     end
+
+     class Llama3RoPE < MLX::NN::Module
+       def initialize(
+         dims:,
+         max_position_embeddings: 2048,
+         traditional: false,
+         base: 10_000,
+         scaling_config: nil
+       )
+         super()
+         mx = MLX::Core
+
+         @dims = dims
+         @max_position_embeddings = max_position_embeddings
+         @traditional = traditional
+
+         factor = config_value(scaling_config, "factor")
+         low_freq_factor = config_value(scaling_config, "low_freq_factor", 1.0)
+         high_freq_factor = config_value(scaling_config, "high_freq_factor", 4.0)
+         old_context_len = config_value(
+           scaling_config,
+           "original_max_position_embeddings",
+           8192
+         )
+
+         low_freq_wavelen = old_context_len.to_f / low_freq_factor
+         high_freq_wavelen = old_context_len.to_f / high_freq_factor
+
+         freqs = mx.power(
+           base.to_f,
+           mx.divide(mx.arange(0, dims, 2), dims.to_f)
+         )
+         wavelens = mx.multiply(2.0 * Math::PI, freqs)
+
+         freqs = mx.where(
+           mx.greater(wavelens, low_freq_wavelen),
+           mx.multiply(freqs, factor),
+           freqs
+         )
+
+         is_medium_freq = mx.logical_and(
+           mx.greater(wavelens, high_freq_wavelen),
+           mx.less(wavelens, low_freq_wavelen)
+         )
+
+         smooth_factors = mx.divide(
+           mx.subtract(mx.divide(old_context_len.to_f, wavelens), low_freq_factor),
+           (high_freq_factor - low_freq_factor).to_f
+         )
+
+         smooth_freqs = mx.divide(
+           freqs,
+           mx.add(
+             mx.divide(mx.subtract(1.0, smooth_factors), factor.to_f),
+             smooth_factors
+           )
+         )
+
+         self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
+       end
+
+       def extra_repr
+         "#{@dims}, traditional=#{@traditional}, max_position_embeddings=#{@max_position_embeddings}"
+       end
+
+       def call(x, offset: 0)
+         MLX::Core.rope(x, @dims, @traditional, nil, 1.0, offset, _freqs)
+       end
+
+       private
+
+       def config_value(config, key, default = nil)
+         return default if config.nil?
+         return config[key] if config.key?(key)
+
+         config.fetch(key.to_sym, default)
+       end
+     end
+
+     class YarnRoPE < MLX::NN::Module
+       def initialize(
+         dims,
+         traditional: false,
+         max_position_embeddings: 2048,
+         base: 10_000,
+         scaling_factor: 1.0,
+         original_max_position_embeddings: 4096,
+         beta_fast: 32,
+         beta_slow: 1,
+         mscale: 1,
+         mscale_all_dim: 0
+       )
+         super()
+         mx = MLX::Core
+
+         self.mscale = yarn_get_mscale(scaling_factor, mscale) /
+                       yarn_get_mscale(scaling_factor, mscale_all_dim)
+
+         freq_extra = mx.power(
+           base.to_f,
+           mx.divide(mx.arange(0, dims, 2, mx.float32), dims.to_f)
+         )
+         freq_inter = mx.multiply(scaling_factor.to_f, freq_extra)
+
+         low, high = yarn_find_correction_range(
+           dims,
+           base,
+           original_max_position_embeddings,
+           beta_fast,
+           beta_slow
+         )
+
+         freq_mask = mx.subtract(1.0, yarn_linear_ramp_mask(low, high, dims / 2))
+         self._freqs = mx.divide(
+           mx.multiply(freq_inter, freq_extra),
+           mx.add(
+             mx.multiply(freq_inter, freq_mask),
+             mx.multiply(freq_extra, mx.subtract(1.0, freq_mask))
+           )
+         )
+
+         @dims = dims
+         @traditional = traditional
+       end
+
+       def call(x, offset: 0)
+         mx = MLX::Core
+         x = scale_rotary_part(x, mscale) unless mscale == 1.0
+
+         mx.rope(x, @dims, @traditional, nil, 1.0, offset, _freqs)
+       end
+
+       private
+
+       def scale_rotary_part(x, scale)
+         mx = MLX::Core
+         rotary, rest = mx.split(x, [@dims], -1)
+         mx.concatenate([mx.multiply(rotary, scale), rest], -1)
+       end
+
+       def yarn_find_correction_dim(dims, base, original_max_position_embeddings, num_rotations)
+         dims * Math.log(original_max_position_embeddings.to_f / (num_rotations * 2 * Math::PI)) /
+           (2 * Math.log(base))
+       end
+
+       def yarn_find_correction_range(dims, base, original_max_position_embeddings, beta_fast, beta_slow)
+         low = yarn_find_correction_dim(dims, base, original_max_position_embeddings, beta_fast).floor
+         high = yarn_find_correction_dim(dims, base, original_max_position_embeddings, beta_slow).ceil
+         [
+           [low, 0].max,
+           [high, dims - 1].min,
+         ]
+       end
+
+       def yarn_get_mscale(scale = 1, mscale = 1)
+         return 1.0 if scale <= 1
+
+         0.1 * mscale * Math.log(scale) + 1.0
+       end
+
+       def yarn_linear_ramp_mask(min_val, max_val, dim)
+         mx = MLX::Core
+
+         max_val += 0.001 if min_val == max_val
+
+         linear = mx.divide(
+           mx.subtract(mx.arange(0, dim, 1, mx.float32), min_val),
+           max_val - min_val
+         )
+         mx.clip(linear, 0.0, 1.0)
+       end
+     end
+
+     module_function
+
+     def initialize_rope(
+       dims,
+       base,
+       traditional,
+       scaling_config = nil,
+       max_position_embeddings: nil
+     )
+       rope_type = if scaling_config
+         config_value(scaling_config, "type") ||
+           config_value(scaling_config, "rope_type", "default")
+       else
+         "default"
+       end
+
+       case rope_type
+       when "default", "linear"
+         scale = rope_type == "linear" ? 1.0 / config_value(scaling_config, "factor") : 1.0
+         MLX::NN::RoPE.new(dims, traditional: traditional, base: base, scale: scale)
+       when "llama3"
+         Llama3RoPE.new(
+           dims: dims,
+           max_position_embeddings: max_position_embeddings,
+           traditional: traditional,
+           base: base,
+           scaling_config: scaling_config
+         )
+       when "yarn", "deepseek_yarn", "telechat3-yarn"
+         rope_kwargs = {}
+         %w[
+           original_max_position_embeddings
+           beta_fast
+           beta_slow
+           mscale
+           mscale_all_dim
+         ].each do |key|
+           value = config_value(scaling_config, key)
+           rope_kwargs[key.to_sym] = value unless value.nil?
+         end
+
+         YarnRoPE.new(
+           dims,
+           max_position_embeddings: max_position_embeddings,
+           traditional: traditional,
+           scaling_factor: config_value(scaling_config, "factor"),
+           base: base,
+           **rope_kwargs
+         )
+       when "longrope"
+         SuScaledRoPE.new(
+           dims,
+           base: base,
+           max_position_embeddings: max_position_embeddings,
+           original_max_position_embeddings: config_value(
+             scaling_config,
+             "original_max_position_embeddings"
+           ),
+           short_factor: config_value(scaling_config, "short_factor"),
+           long_factor: config_value(scaling_config, "long_factor")
+         )
+       when "mrope"
+         mrope_section = config_value(scaling_config, "mrope_section", [])
+         unless mrope_section.length == 3
+           raise ArgumentError,
+                 "MRoPE currently only supports 3 sections, got #{mrope_section.length}."
+         end
+
+         MLX::NN::RoPE.new(dims, traditional: traditional, base: base)
+       else
+         raise ArgumentError, "Unsupported RoPE type #{rope_type}"
+       end
+     end
+
+     def config_value(config, key, default = nil)
+       return default if config.nil?
+       return config[key] if config.key?(key)
+
+       config.fetch(key.to_sym, default)
+     end
+     private_class_method :config_value
+
+     module RoPEUtils
+       SuScaledRoPE = MlxLm::Models::SuScaledRoPE
+       Llama3RoPE = MlxLm::Models::Llama3RoPE
+       YarnRoPE = MlxLm::Models::YarnRoPE
+
+       module_function
+
+       def initialize_rope(*args, **kwargs)
+         MlxLm::Models.initialize_rope(*args, **kwargs)
+       end
+     end
+   end
+ end
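The dispatch above follows the Hugging Face `rope_scaling` conventions: `initialize_rope` reads `type`/`rope_type` and hands the remaining config keys to the matching class. A minimal usage sketch; the head dimension, base, and scaling values below are illustrative and not taken from any particular checkpoint:

    # Hypothetical values; in the gem these would come from a model's config.json.
    llama3_scaling = {
      "rope_type" => "llama3",
      "factor" => 8.0,
      "low_freq_factor" => 1.0,
      "high_freq_factor" => 4.0,
      "original_max_position_embeddings" => 8192,
    }

    rope = MlxLm::Models.initialize_rope(
      128,            # head_dim
      500_000.0,      # rope_theta (base)
      false,          # traditional
      llama3_scaling,
      max_position_embeddings: 131_072
    )
    # => Llama3RoPE; with scaling_config nil this falls back to MLX::NN::RoPE,
    #    while "yarn"-style and "longrope" configs return YarnRoPE and SuScaledRoPE.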
data/lib/mlx_lm/models/rwkv7.rb
@@ -0,0 +1,101 @@
+ require_relative "recurrent_gemma"
+
+ module MlxLm
+   module Models
+     module Rwkv7
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "rwkv7"
+         field :vocab_size
+         field :hidden_size
+         field :intermediate_size
+         field :norm_eps, default: 1e-5
+         field :head_dim
+         field :num_hidden_layers
+         field :a_low_rank_dim, default: nil
+         field :v_low_rank_dim, default: nil
+         field :gate_low_rank_dim, default: nil
+         field :decay_low_rank_dim, default: nil
+         field :tie_word_embeddings, default: false
+         field :rope_theta, default: 10_000.0
+         field :attention_window_size, default: 128
+         field :block_types, default: nil
+         field :num_attention_heads, default: nil
+         field :num_key_value_heads, default: nil
+
+         def initialize(**kwargs)
+           super
+           if @num_attention_heads.nil? && !@hidden_size.nil? && !@head_dim.nil? && @head_dim.to_i > 0
+             @num_attention_heads = @hidden_size / @head_dim
+           end
+           @num_attention_heads ||= 1
+           @num_key_value_heads ||= @num_attention_heads
+           @block_types ||= Array.new(@num_hidden_layers.to_i, "recurrent")
+         end
+
+         def to_recurrent_gemma_dict
+           {
+             "model_type" => @model_type,
+             "attention_bias" => false,
+             "conv1d_width" => 3,
+             "hidden_size" => @hidden_size,
+             "intermediate_size" => @intermediate_size,
+             "logits_soft_cap" => nil,
+             "num_attention_heads" => @num_attention_heads,
+             "num_hidden_layers" => @num_hidden_layers,
+             "num_key_value_heads" => @num_key_value_heads,
+             "rms_norm_eps" => @norm_eps,
+             "rope_theta" => @rope_theta,
+             "attention_window_size" => @attention_window_size,
+             "vocab_size" => @vocab_size,
+             "embeddings_scale_by_sqrt_dim" => false,
+             "block_types" => @block_types,
+           }
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.wrapped_model = RecurrentGemma::Model.new(
+             RecurrentGemma::ModelArgs.from_dict(args.to_recurrent_gemma_dict)
+           )
+         end
+
+         def call(inputs, cache: nil)
+           wrapped_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           remapped = {}
+           weights.each do |key, value|
+             remapped[_remap_weight_key(key)] = value
+           end
+           wrapped_model.sanitize(remapped)
+         end
+
+         def layers
+           wrapped_model.layers
+         end
+
+         def make_cache
+           return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache)
+
+           nil
+         end
+
+         private
+
+         def _remap_weight_key(key)
+           mapped = key.dup
+           mapped = mapped.gsub(/\Ablocks\./, "model.layers.")
+           mapped = mapped.gsub(".time_mix.", ".temporal_block.")
+           mapped
+         end
+       end
+
+       Models.register("rwkv7", Model, ModelArgs)
+     end
+   end
+ end
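This file is a thin adapter: ModelArgs translates the rwkv7 config into a RecurrentGemma config, and Model#sanitize only rewrites the two key prefixes named in _remap_weight_key before delegating to the wrapped model. A quick, self-contained illustration of that remapping; the checkpoint keys below are made-up examples, not taken from a real RWKV-7 checkpoint:

    # Plain Ruby; mirrors the two gsub rules in _remap_weight_key.
    {
      "blocks.0.time_mix.key.weight" => "model.layers.0.temporal_block.key.weight",
      "blocks.3.ffn.value.weight"    => "model.layers.3.ffn.value.weight",  # passes through unchanged
    }.each do |before, after|
      mapped = before.gsub(/\Ablocks\./, "model.layers.").gsub(".time_mix.", ".temporal_block.")
      raise "unexpected mapping for #{before}" unless mapped == after
    end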
data/lib/mlx_lm/models/seed_oss.rb
@@ -0,0 +1,167 @@
+ module MlxLm
+   module Models
+     module SeedOSS
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "seed_oss"
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 32
+         field :intermediate_size, default: 11008
+         field :num_attention_heads, default: 32
+         field :rms_norm_eps, default: 1e-6
+         field :vocab_size, default: 151936
+         field :num_key_value_heads, default: nil
+         field :head_dim, default: nil
+         field :max_position_embeddings, default: nil
+         field :attention_bias, default: false
+         field :attention_out_bias, default: false
+         field :mlp_bias, default: false
+         field :rope_theta, default: 10000.0
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: true
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           input_bias = args.attention_bias
+           output_bias = args.attention_out_bias
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: input_bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: input_bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: input_bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: output_bias)
+
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             args.rope_traditional,
+             args.rope_scaling,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(dim, hidden_dim, bias: false)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args.hidden_size, args.intermediate_size, bias: args.mlp_bias)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class SeedModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = SeedModel.new(args)
+           self.tie_word_embeddings = args.tie_word_embeddings
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           h = model.call(inputs, cache: cache)
+           if tie_word_embeddings
+             model.embed_tokens.as_linear(h)
+           else
+             lm_head.call(h)
+           end
+         end
+
+         def sanitize(weights)
+           weights.delete("lm_head.weight") if tie_word_embeddings
+           weights
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("seed_oss", Model, ModelArgs)
+     end
+   end
+ end
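The attention path above uses the usual grouped-query layout: queries use num_attention_heads, keys/values use num_key_value_heads, and all three are reshaped to (batch, heads, length, head_dim) before mx.scaled_dot_product_attention. A shape walk-through under this file's defaults (hidden_size 4096, 32 heads, so head_dim 128); the batch and sequence sizes are arbitrary and the snippet is plain Ruby bookkeeping, not MLX code:

    b, l     = 2, 16                 # arbitrary batch size and sequence length
    hidden   = 4096
    n_heads  = 32
    n_kv     = 32                    # num_key_value_heads defaults to num_attention_heads
    head_dim = hidden / n_heads      # => 128, as computed in ModelArgs#initialize

    q_after_proj = [b, l, n_heads * head_dim]   # output of q_proj
    q_split      = [b, n_heads, l, head_dim]    # after reshape + transpose([0, 2, 1, 3])
    kv_split     = [b, n_kv, l, head_dim]       # keys/values follow the same path
    merged       = [b, l, n_heads * head_dim]   # re-merged before o_proj

    scale = head_dim**(-0.5)                    # ~0.0884, the @scale passed to attention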
data/lib/mlx_lm/models/smollm3.rb
@@ -0,0 +1,89 @@
+ module MlxLm
+   module Models
+     module SmolLM3
+       class ModelArgs < Llama::ModelArgs
+         field :model_type, default: "smollm3"
+         field :no_rope_layer_interval, default: 4
+         field :no_rope_layers, default: nil
+
+         def initialize(**kwargs)
+           super
+
+           if @no_rope_layers.nil?
+             @no_rope_layers = Array.new(@num_hidden_layers) do |i|
+               ((i + 1) % @no_rope_layer_interval).zero? ? 0 : 1
+             end
+           elsif @no_rope_layers.length != @num_hidden_layers
+             raise ArgumentError, "`no_rope_layers` length mismatch"
+           end
+         end
+       end
+
+       class NoPE < MLX::NN::Module
+         def call(x, offset: 0)
+           x
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = Llama::LlamaModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+
+           args.no_rope_layers.each_with_index do |use_rope, idx|
+             next if use_rope && use_rope != 0
+
+             model.layers[idx].self_attn.rope = NoPE.new
+           end
+         end
+
+         def call(inputs, cache: nil, input_embeddings: nil)
+           out = if input_embeddings.nil?
+             model.call(inputs, cache: cache)
+           else
+             _call_with_input_embeddings(input_embeddings, cache)
+           end
+
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def layers
+           model.layers
+         end
+
+         def sanitize(weights)
+           result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result
+         end
+
+         private
+
+         def _call_with_input_embeddings(input_embeddings, cache)
+           h = input_embeddings
+           layer_cache = cache || [nil] * model.layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           model.layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           model.norm.call(h)
+         end
+       end
+
+       Models.register("smollm3", Model, ModelArgs)
+     end
+   end
+ end
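The default schedule built in ModelArgs marks every no_rope_layer_interval-th layer (counting from 1) with a 0, and Model#initialize then swaps those layers' rope modules for the identity NoPE. A quick check of the pattern with illustrative sizes (36 layers, interval 4; real values come from the model's config):

    num_hidden_layers      = 36   # illustrative
    no_rope_layer_interval = 4

    no_rope_layers = Array.new(num_hidden_layers) do |i|
      ((i + 1) % no_rope_layer_interval).zero? ? 0 : 1
    end

    # Layers whose entry is 0 get NoPE, i.e. no rotary embedding is applied:
    nope_indices = no_rope_layers.each_index.select { |i| no_rope_layers[i].zero? }
    # => [3, 7, 11, 15, 19, 23, 27, 31, 35] — every 4th layer, 0-indexed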