mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/nanochat.rb
@@ -0,0 +1,167 @@
+ module MlxLm
+   module Models
+     module Nanochat
+       module_function
+
+       def rms_norm(x, eps: 1e-5)
+         mx = MLX::Core
+         variance = mx.mean(mx.square(x), -1, true)
+         mx.multiply(x, mx.rsqrt(mx.add(variance, eps)))
+       end
+
+       def softcap(logits, cap: 15.0)
+         mx = MLX::Core
+         mx.multiply(cap, mx.tanh(mx.divide(logits, cap)))
+       end
+
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "nanochat"
+         field :hidden_size, default: 1280
+         field :num_hidden_layers, default: 20
+         field :num_attention_heads, default: 10
+         field :num_key_value_heads, default: 10
+         field :vocab_size, default: 65_536
+         field :max_position_embeddings, default: 2048
+         field :intermediate_size, default: 5120
+         field :rope_theta, default: 10_000.0
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           @hidden_size = args.hidden_size
+           @num_heads = args.num_attention_heads
+           @num_kv_heads = args.num_key_value_heads
+           @head_dim = @hidden_size / @num_heads
+           @scale = @head_dim**(-0.5)
+           @rope_theta = args.rope_theta
+
+           self.c_q = MLX::NN::Linear.new(@hidden_size, @num_heads * @head_dim, bias: false)
+           self.c_k = MLX::NN::Linear.new(@hidden_size, @num_kv_heads * @head_dim, bias: false)
+           self.c_v = MLX::NN::Linear.new(@hidden_size, @num_kv_heads * @head_dim, bias: false)
+           self.c_proj = MLX::NN::Linear.new(@hidden_size, @hidden_size, bias: false)
+
+           mx = MLX::Core
+           exponent = mx.multiply(
+             mx.arange(0, @head_dim, 2, mx.float32),
+             Math.log(@rope_theta) / @head_dim.to_f
+           )
+           self._rope_freqs = mx.exp(exponent) # theta^(2i / head_dim), positive per-pair RoPE frequencies
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = c_q.call(x)
+           keys = c_k.call(x)
+           values = c_v.call(x)
+
+           queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           offset = cache ? cache.offset : 0
+           queries = _apply_rotary_emb(queries, offset: offset)
+           keys = _apply_rotary_emb(keys, offset: offset)
+
+           queries = Nanochat.rms_norm(queries)
+           keys = Nanochat.rms_norm(keys)
+
+           if cache
+             keys, values = cache.update_and_fetch(keys, values)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @hidden_size])
+           c_proj.call(output)
+         end
+
+         private
+
+         def _apply_rotary_emb(x, offset:)
+           MLX::Core.rope(x, @head_dim, false, nil, 1.0, offset, _rope_freqs)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.c_fc = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false)
+           self.c_proj = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: false)
+         end
+
+         def call(x)
+           c_proj.call(MLX::NN.relu2(c_fc.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           h = x + attn.call(Nanochat.rms_norm(x), mask: mask, cache: cache)
+           h + mlp.call(Nanochat.rms_norm(h))
+         end
+       end
+
+       class NanoChatModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.wte = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.h = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+         end
+
+         def call(inputs, cache: nil)
+           hidden = wte.call(inputs)
+           hidden = Nanochat.rms_norm(hidden)
+
+           layer_cache = cache || [nil] * h.length
+           mask = _create_attention_mask(hidden, layer_cache[0])
+
+           h.each_with_index do |layer, i|
+             hidden = layer.call(hidden, mask: mask, cache: layer_cache[i])
+           end
+
+           Nanochat.rms_norm(hidden)
+         end
+
+         private
+
+         def _create_attention_mask(hidden, cache)
+           return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask)
+           return nil if hidden.shape[1] == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.args = args
+           self.model_type = args.model_type
+           self.transformer = NanoChatModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+         end
+
+         def call(inputs, cache: nil)
+           out = transformer.call(inputs, cache: cache)
+           logits = lm_head.call(out)
+           Nanochat.softcap(logits)
+         end
+
+         def layers
+           transformer.h
+         end
+       end
+
+       Models.register("nanochat", Model, ModelArgs)
+     end
+   end
+ end
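
The two module-level helpers at the top of nanochat.rb are one-line tensor formulas. A minimal pure-Ruby sketch of the same arithmetic (hypothetical `*_ref` names, plain Arrays instead of MLX arrays) makes them easy to check by hand:

```ruby
# Parameter-free RMSNorm: x / sqrt(mean(x^2) + eps).
def rms_norm_ref(values, eps: 1e-5)
  variance = values.sum { |v| v * v } / values.length.to_f
  scale = 1.0 / Math.sqrt(variance + eps)
  values.map { |v| v * scale }
end

# Logit softcapping: cap * tanh(z / cap) is near-identity for |z| << cap
# and saturates at +/-cap, bounding the final logits to (-15, 15) by default.
def softcap_ref(z, cap: 15.0)
  cap * Math.tanh(z / cap)
end

p rms_norm_ref([1.0, 2.0, 3.0]) # => [~0.46, ~0.93, ~1.39]; result has RMS ~= 1
p softcap_ref(100.0)            # => ~15.0 (saturated)
p softcap_ref(1.0)              # => ~0.9985 (near-identity)
```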
data/lib/mlx_lm/models/nemotron.rb
@@ -0,0 +1,202 @@
+ module MlxLm
+   module Models
+     module Nemotron
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "nemotron"
+         field :hidden_size
+         field :hidden_act
+         field :num_hidden_layers
+         field :intermediate_size
+         field :num_attention_heads
+         field :norm_eps
+         field :vocab_size
+         field :num_key_value_heads
+         field :head_dim, default: nil
+         field :max_position_embeddings, default: nil
+         field :attention_bias, default: false
+         field :mlp_bias, default: false
+         field :partial_rotary_factor, default: 0.5
+         field :rope_theta, default: 10_000.0
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           @head_dim ||= @hidden_size / @num_attention_heads
+           validate_rope_scaling!
+         end
+
+         private
+
+         def rope_scaling_value(key)
+           return nil unless @rope_scaling
+
+           @rope_scaling[key] || @rope_scaling[key.to_s]
+         end
+
+         def validate_rope_scaling!
+           return unless @rope_scaling
+
+           raise ArgumentError, "rope_scaling must contain 'factor'" if rope_scaling_value(:factor).nil?
+
+           rope_type = rope_scaling_value(:type) || rope_scaling_value(:rope_type)
+           if rope_type.nil?
+             raise ArgumentError, "rope_scaling must contain either 'type' or 'rope_type'"
+           end
+           return if rope_type == "linear"
+
+           raise ArgumentError, "rope_scaling 'type' currently only supports 'linear'"
+         end
+       end
+
+       class NemotronLayerNorm1P < MLX::NN::LayerNorm
+         def call(x)
+           w = state.key?("weight") ? weight + 1.0 : nil
+           b = state.key?("bias") ? bias : nil
+           MLX::Core.layer_norm(x, w, b, @eps)
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @partial_rotary_factor = args.partial_rotary_factor
+           @scale = @head_dim**(-0.5)
+
+           bias = args.attention_bias
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias)
+
+           rope_scale = 1.0
+           if args.rope_scaling
+             rope_type = args.rope_scaling[:type] || args.rope_scaling["type"] ||
+                         args.rope_scaling[:rope_type] || args.rope_scaling["rope_type"]
+             if rope_type == "linear"
+               factor = args.rope_scaling[:factor] || args.rope_scaling["factor"]
+               rope_scale = 1.0 / factor.to_f
+             end
+           end
+
+           self.rope = MLX::NN::RoPE.new(
+             (@partial_rotary_factor * @head_dim).to_i,
+             base: args.rope_theta,
+             scale: rope_scale
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+
+           dim = args.hidden_size
+           hidden_dim = args.intermediate_size
+           bias = args.mlp_bias
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+         end
+
+         def call(x)
+           down_proj.call(MLX::NN.relu2(up_proj.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = NemotronLayerNorm1P.new(args.hidden_size, eps: args.norm_eps)
+           self.post_attention_layernorm = NemotronLayerNorm1P.new(args.hidden_size, eps: args.norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class NemotronModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = NemotronLayerNorm1P.new(args.hidden_size, eps: args.norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = NemotronModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("nemotron", Model, ModelArgs)
+     end
+   end
+ end
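
`NemotronLayerNorm1P` applies `weight + 1.0` rather than `weight`: Nemotron checkpoints follow the "layernorm1p" convention of storing layer-norm gains as zero-centered offsets, so the effective gain is the stored value plus one. A scalar pure-Ruby sketch of what the layer then computes (assuming standard layer-norm semantics; names here are hypothetical):

```ruby
# Reference layer norm with an explicit gain and bias.
def layer_norm_ref(values, gain, bias, eps: 1e-5)
  mean = values.sum / values.length.to_f
  var = values.sum { |v| (v - mean)**2 } / values.length.to_f
  values.map { |v| gain * (v - mean) / Math.sqrt(var + eps) + bias }
end

x = [0.5, -1.0, 2.0]
stored_weight = 0.02 # checkpoint stores the gain near zero...
p layer_norm_ref(x, stored_weight + 1.0, 0.0) # ...so the layer applies weight + 1
```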
data/lib/mlx_lm/models/nemotron_h.rb
@@ -0,0 +1,212 @@
+ require_relative "falcon_h1"
+
+ module MlxLm
+   module Models
+     module NemotronH
+       class ModelArgs < FalconH1::ModelArgs
+         field :model_type, default: "nemotron_h"
+         field :tie_word_embeddings, default: false
+         field :mamba_num_heads, default: nil
+         field :mamba_head_dim, default: nil
+         field :mamba_proj_bias, default: nil
+         field :ssm_state_size, default: nil
+         field :conv_kernel, default: nil
+         field :n_groups, default: nil
+         field :mlp_bias, default: nil
+         field :layer_norm_epsilon, default: nil
+         field :use_bias, default: nil
+         field :use_conv_bias, default: nil
+         field :hybrid_override_pattern, default: nil
+         field :moe_intermediate_size, default: nil
+         field :moe_shared_expert_intermediate_size, default: nil
+         field :n_group, default: nil
+         field :n_routed_experts, default: nil
+         field :n_shared_experts, default: nil
+         field :topk_group, default: nil
+         field :num_experts_per_tok, default: nil
+         field :norm_topk_prob, default: nil
+         field :routed_scaling_factor, default: nil
+         field :time_step_limit, default: nil
+         field :time_step_min, default: nil
+         field :time_step_max, default: nil
+
+         def initialize(**kwargs)
+           super
+
+           @mamba_d_conv = @conv_kernel if kwargs.key?(:conv_kernel) && !kwargs.key?(:mamba_d_conv) && !@conv_kernel.nil?
+           @rms_norm_eps = @layer_norm_epsilon if kwargs.key?(:layer_norm_epsilon) && !kwargs.key?(:rms_norm_eps) && !@layer_norm_epsilon.nil?
+           @num_attention_heads ||= @mamba_num_heads
+           @head_dim ||= @mamba_head_dim
+
+           pattern = _hybrid_pattern_array
+           @hybrid_override_pattern = pattern unless pattern.nil?
+           @hybrid_override_pattern ||= _default_hybrid_pattern
+
+           if @num_hidden_layers.nil? && @hybrid_override_pattern.is_a?(Array) && !@hybrid_override_pattern.empty?
+             @num_hidden_layers = @hybrid_override_pattern.length
+           end
+
+           @num_key_value_heads ||= @num_attention_heads
+           @mamba_d_conv ||= 4
+           @block_types ||= _to_block_types(@hybrid_override_pattern)
+         end
+
+         def to_falcon_h1_dict
+           hidden_size = @hidden_size
+           attention_heads = @num_attention_heads
+           inferred_head_dim = if !@head_dim.nil?
+             @head_dim
+           elsif !@mamba_head_dim.nil?
+             @mamba_head_dim
+           elsif !hidden_size.nil? && attention_heads.to_i > 0
+             hidden_size / attention_heads
+           else
+             64
+           end
+
+           {
+             "model_type" => @model_type,
+             "attention_bias" => @attention_bias,
+             "head_dim" => inferred_head_dim,
+             "hidden_size" => hidden_size,
+             "intermediate_size" => @intermediate_size || @moe_shared_expert_intermediate_size || hidden_size.to_i * 2,
+             "max_position_embeddings" => @max_position_embeddings,
+             "mamba_d_conv" => @mamba_d_conv,
+             "num_attention_heads" => attention_heads,
+             "num_hidden_layers" => @num_hidden_layers,
+             "num_key_value_heads" => @num_key_value_heads,
+             "rms_norm_eps" => @rms_norm_eps || @layer_norm_epsilon || 1e-5,
+             "rope_theta" => @rope_theta,
+             "vocab_size" => @vocab_size,
+             "tie_word_embeddings" => @tie_word_embeddings,
+             "attention_window_size" => @attention_window_size,
+             "block_types" => @block_types,
+           }
+         end
+
+         private
+
+         def _hybrid_pattern_array
+           return nil if @hybrid_override_pattern.nil?
+           return @hybrid_override_pattern if @hybrid_override_pattern.is_a?(Array)
+           return @hybrid_override_pattern.chars if @hybrid_override_pattern.is_a?(String)
+
+           nil
+         end
+
+         def _default_hybrid_pattern
+           count = @num_hidden_layers.to_i
+           return nil if count <= 0
+
+           Array.new(count) { |idx| idx.even? ? "*" : "M" }
+         end
+
+         def _to_block_types(pattern)
+           return @block_types if @block_types.is_a?(Array) && !@block_types.empty?
+           return nil unless pattern.is_a?(Array) && !pattern.empty?
+
+           pattern.map do |block_type|
+             case block_type.to_s
+             when "*"
+               "attention"
+             else
+               "recurrent"
+             end
+           end
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.wrapped_model = FalconH1::Model.new(
+             FalconH1::ModelArgs.from_dict(args.to_falcon_h1_dict)
+           )
+         end
+
+         def call(inputs, cache: nil)
+           wrapped_model.call(inputs, cache: cache)
+         end
+
+         def sanitize(weights)
+           normalized = weights.is_a?(Hash) ? weights.dup : weights.to_h
+           _stack_experts!(normalized)
+
+           remapped = {}
+           normalized.each do |key, value|
+             remapped[_remap_weight_key(key)] = value
+           end
+
+           wrapped_model.sanitize(remapped)
+         end
+
+         def layers
+           wrapped_model.layers
+         end
+
+         def make_cache
+           return nil unless wrapped_model.respond_to?(:make_cache)
+
+           wrapped_model.make_cache
+         end
+
+         private
+
+         def _stack_experts!(weights)
+           mx = MLX::Core
+           grouped = Hash.new { |h, k| h[k] = [] }
+           pattern = /\A(backbone\.layers\.\d+\.mixer|model\.layers(?:\.layers)?\.\d+\.mixer)\.experts\.(\d+)\.(up_proj|down_proj)\.(weight|bias|scales|biases)\z/
+
+           weights.keys.each do |key|
+             match = pattern.match(key)
+             next unless match
+
+             prefix = match[1]
+             expert_idx = match[2].to_i
+             projection = match[3]
+             param = match[4]
+             grouped[[prefix, projection, param]] << [expert_idx, key]
+           end
+
+           grouped.each do |(prefix, projection, param), entries|
+             next if entries.empty?
+
+             stacked = entries.sort_by(&:first).map { |_, key| weights.delete(key) }
+             target = projection == "up_proj" ? "fc1" : "fc2"
+             weights["#{prefix}.switch_mlp.#{target}.#{param}"] = mx.stack(stacked)
+           end
+         end
+
+         def _remap_weight_key(key)
+           mapped = key.dup
+           mapped = mapped.gsub("backbone.embeddings.", "model.embed_tokens.")
+           mapped = mapped.gsub("backbone.norm_f.", "model.final_layernorm.")
+           mapped = mapped.gsub("backbone.layers.", "model.layers.")
+           mapped = mapped.gsub("model.layers.layers.", "model.layers.")
+
+           mapped = mapped.gsub(/\.layers\.(\d+)\.norm\./) { ".layers.#{$1}.input_layernorm." }
+
+           mapped = mapped.gsub(".mixer.conv1d.", ".mamba.conv1d.")
+           mapped = mapped.gsub(".mixer.in_proj.", ".mamba.in_proj.")
+           mapped = mapped.gsub(".mixer.out_proj.", ".mamba.out_proj.")
+           mapped = mapped.gsub(".mixer.q_proj.", ".self_attn.q_proj.")
+           mapped = mapped.gsub(".mixer.k_proj.", ".self_attn.k_proj.")
+           mapped = mapped.gsub(".mixer.v_proj.", ".self_attn.v_proj.")
+           mapped = mapped.gsub(".mixer.o_proj.", ".self_attn.o_proj.")
+           mapped = mapped.gsub(".mixer.gate.", ".feed_forward.router.")
+           mapped = mapped.gsub(".mixer.switch_mlp.fc1.", ".feed_forward.switch_mlp.up_proj.")
+           mapped = mapped.gsub(".mixer.switch_mlp.fc2.", ".feed_forward.switch_mlp.down_proj.")
+           mapped = mapped.gsub(".mixer.shared_experts.up_proj.", ".feed_forward.up_proj.")
+           mapped = mapped.gsub(".mixer.shared_experts.down_proj.", ".feed_forward.down_proj.")
+           mapped = mapped.gsub(".mixer.up_proj.", ".feed_forward.up_proj.")
+           mapped = mapped.gsub(".mixer.down_proj.", ".feed_forward.down_proj.")
+           mapped
+         end
+       end
+
+       Models.register("nemotron_h", Model, ModelArgs)
+     end
+   end
+ end
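
The `sanitize` path in nemotron_h.rb makes two passes: `_stack_experts!` gathers per-expert MoE weights into stacked `switch_mlp.fc1`/`fc2` tensors, then `_remap_weight_key` rewrites checkpoint key prefixes into the FalconH1 module layout. A standalone trace of a few of the rename rules, copied from `_remap_weight_key` above (subset only, pure string rewriting, no MLX required):

```ruby
# Trace a checkpoint key through a subset of the NemotronH remap rules.
def remap(key)
  key.dup
     .gsub("backbone.embeddings.", "model.embed_tokens.")
     .gsub("backbone.norm_f.", "model.final_layernorm.")
     .gsub("backbone.layers.", "model.layers.")
     .gsub(".mixer.q_proj.", ".self_attn.q_proj.")
     .gsub(".mixer.switch_mlp.fc1.", ".feed_forward.switch_mlp.up_proj.")
end

p remap("backbone.embeddings.weight")
# => "model.embed_tokens.weight"
p remap("backbone.layers.0.mixer.q_proj.weight")
# => "model.layers.0.self_attn.q_proj.weight"
p remap("backbone.layers.3.mixer.switch_mlp.fc1.weight") # after expert stacking
# => "model.layers.3.feed_forward.switch_mlp.up_proj.weight"
```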