mlx-ruby-lm 0.30.7.1

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/nemotron_nas.rb
@@ -0,0 +1,404 @@
+ require_relative "cache"
+ require_relative "rope_utils"
+
+ module MlxLm
+   module Models
+     module NemotronNas
+       module_function
+
+       def find_multiple(n, k)
+         remainder = n % k
+         remainder.zero? ? n : (n + k - remainder)
+       end
+
+       def ffn_mult_to_intermediate_size(ffn_mult, hidden_size)
+         intermediate_size = (2 * ffn_mult.to_f * hidden_size / 3).to_i
+         find_multiple(intermediate_size, 256)
+       end
+
+       class AttentionConfig
+         attr_reader :no_op, :replace_with_linear, :sparsify, :n_heads_in_group, :window_length,
+                     :num_sink_tokens, :use_prefill_window_in_sink_attention, :unshifted_sink
+
+         def initialize(
+           no_op: false,
+           replace_with_linear: false,
+           sparsify: nil,
+           n_heads_in_group: nil,
+           window_length: nil,
+           num_sink_tokens: nil,
+           use_prefill_window_in_sink_attention: false,
+           unshifted_sink: false
+         )
+           @no_op = no_op
+           @replace_with_linear = replace_with_linear
+           @sparsify = sparsify
+           @n_heads_in_group = n_heads_in_group
+           @window_length = window_length
+           @num_sink_tokens = num_sink_tokens
+           @use_prefill_window_in_sink_attention = use_prefill_window_in_sink_attention
+           @unshifted_sink = unshifted_sink
+
+           if @no_op || @replace_with_linear
+             @n_heads_in_group = nil
+             @window_length = nil
+             @num_sink_tokens = nil
+           else
+             raise ArgumentError, "n_heads_in_group must be specified for active attention blocks" if @n_heads_in_group.nil?
+             raise ArgumentError, "n_heads_in_group must be positive, got #{@n_heads_in_group}" if @n_heads_in_group.to_i <= 0
+           end
+         end
+
+         def self.from_dict(data)
+           hash = _symbolize_keys(data || {})
+           new(**hash)
+         end
+
+         def self._symbolize_keys(hash)
+           hash.each_with_object({}) { |(k, v), out| out[k.to_sym] = v }
+         end
+         private_class_method :_symbolize_keys
+       end
+
+       class FFNConfig
+         attr_reader :no_op, :replace_with_linear, :sparsify, :ffn_mult
+
+         def initialize(
+           no_op: false,
+           replace_with_linear: false,
+           sparsify: nil,
+           ffn_mult: nil
+         )
+           @no_op = no_op
+           @replace_with_linear = replace_with_linear
+           @sparsify = sparsify
+           @ffn_mult = ffn_mult
+
+           if @no_op || @replace_with_linear
+             @ffn_mult = nil
+           else
+             raise ArgumentError, "ffn_mult must be specified for active FFN blocks" if @ffn_mult.nil?
+             @ffn_mult = @ffn_mult.to_f.round(6)
+           end
+         end
+
+         def self.from_dict(data)
+           hash = _symbolize_keys(data || {})
+           new(**hash)
+         end
+
+         def self._symbolize_keys(hash)
+           hash.each_with_object({}) { |(k, v), out| out[k.to_sym] = v }
+         end
+         private_class_method :_symbolize_keys
+       end
+
+       class BlockConfig
+         attr_reader :attention, :ffn
+
+         def initialize(attention:, ffn:)
+           @attention = attention
+           @ffn = ffn
+         end
+
+         def self.from_dict(data)
+           hash = data || {}
+           attention_data = hash["attention"] || hash[:attention] || {}
+           ffn_data = hash["ffn"] || hash[:ffn] || {}
+           new(
+             attention: AttentionConfig.from_dict(attention_data),
+             ffn: FFNConfig.from_dict(ffn_data)
+           )
+         end
+       end
+
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "nemotron-nas"
+         field :hidden_size, default: 8192
+         field :num_hidden_layers, default: 80
+         field :num_attention_heads, default: 64
+         field :rms_norm_eps, default: 1e-5
+         field :vocab_size, default: 128_256
+         field :block_configs, default: []
+         field :hidden_act, default: "silu"
+         field :attention_bias, default: false
+         field :mlp_bias, default: false
+         field :rope_theta, default: 500_000.0
+         field :rope_scaling, default: nil
+         field :max_position_embeddings, default: 131_072
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           @block_configs = Array(@block_configs).map do |config|
+             config.is_a?(BlockConfig) ? config : BlockConfig.from_dict(config)
+           end
+
+           if @block_configs.length != @num_hidden_layers
+             raise ArgumentError,
+                   "Number of block_configs (#{@block_configs.length}) must match num_hidden_layers (#{@num_hidden_layers})"
+           end
+
+           validate_rope_scaling!
+           validate_block_configs!
+         end
+
+         private
+
+         def validate_rope_scaling!
+           return unless @rope_scaling
+
+           factor = rope_scaling_value(:factor)
+           raise ArgumentError, "rope_scaling must contain 'factor'" if factor.nil?
+
+           rope_type = rope_scaling_value(:rope_type) || rope_scaling_value(:type)
+           raise ArgumentError, "rope_scaling must contain 'rope_type'" if rope_type.nil?
+
+           normalized = @rope_scaling.dup
+           normalized["rope_type"] = rope_type
+           normalized[:rope_type] = rope_type
+           @rope_scaling = normalized
+         end
+
+         def rope_scaling_value(key)
+           return nil unless @rope_scaling
+           return @rope_scaling[key] if @rope_scaling.key?(key)
+
+           @rope_scaling[key.to_s]
+         end
+
+         def validate_block_configs!
+           @block_configs.each_with_index do |block_config, i|
+             attention = block_config.attention
+             next if attention.no_op || attention.replace_with_linear
+
+             heads_in_group = attention.n_heads_in_group.to_i
+             if (@num_attention_heads % heads_in_group) != 0
+               raise ArgumentError,
+                     "Layer #{i}: num_attention_heads (#{@num_attention_heads}) must be divisible by n_heads_in_group (#{attention.n_heads_in_group})"
+             end
+           end
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args, attention_config)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = @n_heads / attention_config.n_heads_in_group
+           @head_dim = args.hidden_size / @n_heads
+           raise ArgumentError, "hidden_size (#{dim}) must be divisible by num_attention_heads (#{@n_heads})" if (@head_dim * @n_heads) != dim
+
+           @scale = @head_dim**(-0.5)
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias)
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             false,
+             args.rope_scaling,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args, ffn_config)
+           super()
+           hidden_dim = NemotronNas.ffn_mult_to_intermediate_size(ffn_config.ffn_mult, args.hidden_size)
+           @act_fn = args.hidden_act
+
+           supported = %w[silu relu gelu gelu_new gelu_fast]
+           unless supported.include?(@act_fn)
+             raise ArgumentError, "Unknown activation function: #{@act_fn}"
+           end
+
+           self.gate_proj = MLX::NN::Linear.new(args.hidden_size, hidden_dim, bias: args.mlp_bias)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, args.hidden_size, bias: args.mlp_bias)
+           self.up_proj = MLX::NN::Linear.new(args.hidden_size, hidden_dim, bias: args.mlp_bias)
+         end
+
+         def call(x)
+           gate = _activate(gate_proj.call(x))
+           down_proj.call(gate * up_proj.call(x))
+         end
+
+         private
+
+         def _activate(x)
+           case @act_fn
+           when "silu"
+             MLX::NN.silu(x)
+           when "relu"
+             MLX::NN.relu(x)
+           when "gelu"
+             MLX::NN.gelu(x)
+           when "gelu_new", "gelu_fast"
+             MLX::NN.gelu_approx(x)
+           else
+             x
+           end
+         end
+       end
+
+       class LinearSubblockReplacement < MLX::NN::Module
+         def initialize(hidden_size, bias)
+           super()
+           self.linear = MLX::NN::Linear.new(hidden_size, hidden_size, bias: bias)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           _ = mask
+           _ = cache
+           linear.call(x)
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+           block_config = args.block_configs[layer_idx]
+           @attention_config = block_config.attention
+           @ffn_config = block_config.ffn
+
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) unless @attention_config.no_op
+           self.self_attn = if @attention_config.no_op
+             nil
+           elsif @attention_config.replace_with_linear
+             LinearSubblockReplacement.new(args.hidden_size, args.attention_bias)
+           else
+             Attention.new(args, @attention_config)
+           end
+
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) unless @ffn_config.no_op
+           self.mlp = if @ffn_config.no_op
+             nil
+           elsif @ffn_config.replace_with_linear
+             LinearSubblockReplacement.new(args.hidden_size, args.mlp_bias)
+           else
+             MLP.new(args, @ffn_config)
+           end
+         end
+
+         def call(x, mask: nil, cache: nil)
+           if self_attn
+             residual = x
+             h = input_layernorm.call(x)
+             x = residual + self_attn.call(h, mask: mask, cache: cache)
+           end
+
+           if mlp
+             residual = x
+             h = post_attention_layernorm.call(x)
+             x = residual + mlp.call(h)
+           end
+
+           x
+         end
+       end
+
+       class NemotronNASModel < MLX::NN::Module
+         attr_reader :num_attn_layers
+
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { |layer_idx| TransformerBlock.new(args, layer_idx) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           @num_attn_layers = layers.count { |layer| !layer.self_attn.nil? }
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * @num_attn_layers
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           cache_idx = 0
+           layers.each do |layer|
+             layer_state = if layer.self_attn
+               state = layer_cache[cache_idx]
+               cache_idx += 1
+               state
+             end
+             h = layer.call(h, mask: mask, cache: layer_state)
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           n = h.shape[1]
+           return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
+           return nil if n == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = NemotronNASModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+           result.delete("lm_head.weight") if @args.tie_word_embeddings
+           result
+         end
+
+         def layers
+           model.layers
+         end
+
+         def make_cache
+           layers.filter_map do |layer|
+             MlxLm::KVCache.new if layer.self_attn
+           end
+         end
+       end
+
+       Models.register("nemotron-nas", Model, ModelArgs)
+     end
+   end
+ end
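
A note on the nemotron_nas.rb port above: `block_configs` drives construction layer by layer. A block whose attention entry sets `no_op` gets no attention sub-block at all, `replace_with_linear` swaps it for a plain `Linear`, and only the remaining blocks receive a KV cache from `Model#make_cache`. The stand-alone Ruby sketch below (plain Ruby, no MLX required; the `block_configs` values are hypothetical) mirrors that slot-assignment logic:

```ruby
# Illustrative only: mirrors the make_cache slot assignment with made-up
# block_configs. Only blocks with a real attention layer (neither no_op nor
# replace_with_linear) get a KV-cache entry.
block_configs = [
  { "attention" => { "n_heads_in_group" => 8 },       "ffn" => { "ffn_mult" => 2.5 } },
  { "attention" => { "no_op" => true },               "ffn" => { "ffn_mult" => 2.5 } },
  { "attention" => { "replace_with_linear" => true }, "ffn" => { "no_op" => true } },
  { "attention" => { "n_heads_in_group" => 8 },       "ffn" => { "ffn_mult" => 3.5 } }
]

cached_layers = block_configs.each_with_index.filter_map do |cfg, i|
  att = cfg["attention"]
  i unless att["no_op"] || att["replace_with_linear"]
end

puts "layers that receive a KVCache: #{cached_layers.inspect}"  # => [0, 3]
```

For active FFN blocks, `ffn_mult_to_intermediate_size` rounds the width up to a multiple of 256: with a `hidden_size` of 4096 and an `ffn_mult` of 2.5, `(2 * 2.5 * 4096 / 3).to_i` gives 6826, which `find_multiple` rounds up to 6912.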
data/lib/mlx_lm/models/olmo.rb
@@ -0,0 +1,165 @@
+ module MlxLm
+   module Models
+     module OLMo
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "olmo"
+         field :d_model, default: nil
+         field :n_layers, default: nil
+         field :mlp_hidden_size, default: nil
+         field :n_heads, default: nil
+         field :vocab_size, default: 50304
+         field :embedding_size, default: nil
+         field :rope_theta, default: 10000.0
+         field :rope_traditional, default: false
+         field :mlp_ratio, default: 4
+         field :weight_tying, default: false
+
+         # Compatibility aliases used in some generic tests/config builders.
+         field :hidden_size, default: nil
+         field :num_hidden_layers, default: nil
+         field :intermediate_size, default: nil
+         field :num_attention_heads, default: nil
+         field :tie_word_embeddings, default: nil
+
+         def initialize(**kwargs)
+           super
+           @d_model = @hidden_size if @hidden_size
+           @n_layers = @num_hidden_layers if @num_hidden_layers
+           @n_heads = @num_attention_heads if @num_attention_heads
+           @mlp_hidden_size = @intermediate_size if @intermediate_size
+           @weight_tying = @tie_word_embeddings unless @tie_word_embeddings.nil?
+
+           @d_model ||= 4096
+           @n_layers ||= 32
+           @n_heads ||= 32
+           @embedding_size ||= @vocab_size
+           @mlp_hidden_size ||= @mlp_ratio * @d_model
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.d_model
+           @n_heads = args.n_heads
+           @head_dim = dim / @n_heads
+           @scale = @head_dim**(-0.5)
+           @ff_hidden_size = args.mlp_hidden_size
+
+           self.ff_proj = MLX::NN::Linear.new(dim, @ff_hidden_size, bias: false)
+           self.ff_out = MLX::NN::Linear.new(@ff_hidden_size / 2, dim, bias: false)
+
+           self.att_norm = MLX::NN::LayerNorm.new(dim, affine: false)
+           self.ff_norm = MLX::NN::LayerNorm.new(dim, affine: false)
+
+           self.att_proj = MLX::NN::Linear.new(dim, 3 * dim, bias: false)
+           self.attn_out = MLX::NN::Linear.new(dim, dim, bias: false)
+
+           self.rope = MLX::NN::RoPE.new(
+             @head_dim,
+             traditional: args.rope_traditional,
+             base: args.rope_theta
+           )
+         end
+
+         def attend(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, d = x.shape
+
+           qkv = att_proj.call(x)
+           queries, keys, values = mx.split(qkv, [d, 2 * d], 2)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, d])
+           attn_out.call(output)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+
+           r = attend(att_norm.call(x), mask: mask, cache: cache)
+           h = x + r
+
+           ff_hidden = ff_proj.call(ff_norm.call(h))
+           x1, x2 = mx.split(ff_hidden, [@ff_hidden_size / 2], 2)
+           h + ff_out.call(Activations.swiglu(x2, x1))
+         end
+       end
+
+       class Transformer < MLX::NN::Module
+         def initialize(args)
+           super()
+           @weight_tying = args.weight_tying
+
+           self.wte = MLX::NN::Embedding.new(args.embedding_size, args.d_model)
+           self.blocks = Array.new(args.n_layers) { TransformerBlock.new(args) }
+           self.ff_out = MLX::NN::Linear.new(args.d_model, args.embedding_size, bias: false) unless @weight_tying
+           self.norm = MLX::NN::LayerNorm.new(args.d_model, affine: false)
+         end
+
+         def call(inputs, cache: nil)
+           h = wte.call(inputs)
+           layer_cache = cache || [nil] * blocks.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           blocks.each_with_index do |block, i|
+             h = block.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           h = norm.call(h)
+
+           if @weight_tying
+             wte.as_linear(h)
+           else
+             ff_out.call(h)
+           end
+         end
+       end
+
+       class OlmoModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.transformer = Transformer.new(args)
+         end
+
+         def call(inputs, cache: nil)
+           transformer.call(inputs, cache: cache)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.model_type = args.model_type
+           self.model = OlmoModel.new(args)
+           self.args = args
+         end
+
+         def call(inputs, cache: nil)
+           model.call(inputs, cache: cache)
+         end
+
+         def layers
+           model.transformer.blocks
+         end
+       end
+
+       Models.register("olmo", Model, ModelArgs)
+     end
+   end
+ end
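
Unlike the per-projection attention in nemotron_nas.rb, the olmo.rb port fuses Q, K and V into a single `att_proj` and produces both SwiGLU halves from one `ff_proj`, which is why `ff_out` takes `@ff_hidden_size / 2` inputs. A plain-Ruby shape check (numbers taken from the `ModelArgs` defaults above, so they are illustrative only) makes the two splits explicit:

```ruby
# Shape bookkeeping for OLMo's fused projections, using the ModelArgs
# defaults: d_model 4096, n_heads 32, mlp_ratio 4. Illustrative only.
d_model  = 4096
n_heads  = 32
head_dim = d_model / n_heads                # 128

# att_proj is one d_model -> 3 * d_model linear; splitting at [d, 2 * d] on
# the last axis recovers query, key and value, each d_model wide.
qkv_width = 3 * d_model                     # 12_288
q_range   = 0...d_model
k_range   = d_model...(2 * d_model)
v_range   = (2 * d_model)...qkv_width

# ff_proj emits mlp_hidden_size activations that are split in half for
# SwiGLU, so ff_out maps mlp_hidden_size / 2 back down to d_model.
mlp_hidden_size = 4 * d_model               # mlp_ratio * d_model = 16_384
ff_out_inputs   = mlp_hidden_size / 2       # 8_192

chunk_sizes = [q_range, k_range, v_range].map(&:size)
puts "head_dim=#{head_dim} qkv_width=#{qkv_width} qkv_chunks=#{chunk_sizes.inspect} ff_out_in=#{ff_out_inputs}"
```

Keeping the fused layout presumably lets published OLMo checkpoints load without resplitting their projection weights, since no `sanitize` step reshuffles them here.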