mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/granite.rb
@@ -0,0 +1,170 @@
+module MlxLm
+  module Models
+    module Granite
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "granite"
+        field :hidden_size
+        field :num_hidden_layers
+        field :intermediate_size
+        field :num_attention_heads
+        field :rms_norm_eps
+        field :vocab_size
+        field :logits_scaling
+        field :attention_multiplier
+        field :embedding_multiplier
+        field :residual_multiplier
+        field :max_position_embeddings
+        field :num_key_value_heads
+        field :attention_bias
+        field :mlp_bias
+        field :rope_theta
+        field :rope_scaling, default: nil
+        field :tie_word_embeddings, default: true
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = dim / @n_heads
+          @scale = args.attention_multiplier
+
+          bias = args.attention_bias
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias)
+
+          self.rope = MlxLm::Models.initialize_rope(
+            @head_dim,
+            args.rope_theta,
+            false,
+            args.rope_scaling,
+            max_position_embeddings: args.max_position_embeddings
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x)
+          keys = k_proj.call(x)
+          values = v_proj.call(x)
+
+          queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(args)
+          super()
+
+          dim = args.hidden_size
+          hidden_dim = args.intermediate_size
+          bias = args.mlp_bias
+
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args)
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          @residual_multiplier = args.residual_multiplier
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r * @residual_multiplier
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r * @residual_multiplier
+        end
+      end
+
+      class GraniteModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          @embedding_multiplier = args.embedding_multiplier
+        end
+
+        def call(inputs, cache: nil)
+          h = embed_tokens.call(inputs) * @embedding_multiplier
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.model = GraniteModel.new(args)
+          self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings
+          @logits_scaling = args.logits_scaling
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          out = if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+          out / @logits_scaling
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("granite", Model, ModelArgs)
+    end
+  end
+end
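
Note: what distinguishes Granite from a vanilla Llama-style decoder in this file is its set of fixed scaling constants: token embeddings are multiplied by `embedding_multiplier`, `attention_multiplier` is passed directly as the SDPA scale (instead of `1/sqrt(head_dim)`), each residual update is damped by `residual_multiplier`, and the final logits are divided by `logits_scaling`. A minimal pure-Ruby sketch of that arithmetic, using made-up constants rather than real Granite config values:

```ruby
# Toy scalar trace of Granite's scaled residual stream. All numbers here,
# including the 0.1 "sublayer output", are illustrative only.
embedding_multiplier = 12.0
residual_multiplier  = 0.22
logits_scaling       = 8.0

h = 0.05 * embedding_multiplier   # embed_tokens output, scaled once on entry
2.times do                        # one attention update, then one MLP update
  r = 0.1 * h                     # stand-in for the sublayer's output
  h += r * residual_multiplier    # h = h + r * residual_multiplier
end
logits = h / logits_scaling       # note: logits are divided, not multiplied
puts logits
```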
data/lib/mlx_lm/models/granitemoe.rb
@@ -0,0 +1,58 @@
+require_relative "granite"
+
+module MlxLm
+  module Models
+    module GraniteMoe
+      class ModelArgs < Granite::ModelArgs
+        field :model_type, default: "granitemoe"
+        field :num_local_experts
+        field :num_experts_per_tok
+      end
+
+      class Model < Granite::Model
+        def sanitize(weights)
+          result = weights.dup
+          rewrite_legacy_moe_weights(result)
+          result.delete("lm_head.weight") if @args.tie_word_embeddings
+          result
+        end
+
+        private
+
+        def rewrite_legacy_moe_weights(weights)
+          mx = MLX::Core
+
+          layers.length.times do |layer_idx|
+            prefix = "model.layers.#{layer_idx}.block_sparse_moe"
+            input_key = _first_existing_key(
+              weights,
+              ["#{prefix}.input_linear.weight", "#{prefix}.input_linear"]
+            )
+            output_key = _first_existing_key(
+              weights,
+              ["#{prefix}.output_linear.weight", "#{prefix}.output_linear"]
+            )
+            next unless input_key && output_key
+
+            input_linear = weights.delete(input_key)
+            output_linear = weights.delete(output_key)
+            mid = input_linear.shape[1] / 2
+            gate_proj, up_proj = mx.split(input_linear, [mid], 1)
+
+            weights["#{prefix}.switch_mlp.gate_proj.weight"] = gate_proj
+            weights["#{prefix}.switch_mlp.up_proj.weight"] = up_proj
+            weights["#{prefix}.switch_mlp.down_proj.weight"] = output_linear
+          end
+
+          weights
+        end
+
+        def _first_existing_key(weights, candidates)
+          candidates.find { |key| weights.key?(key) }
+        end
+      end
+
+      Models.register("granitemoe", Model, ModelArgs)
+    end
+  end
+end
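
Note: the `sanitize` hook above exists to load GraniteMoE checkpoints that store each layer's expert weights as one fused `input_linear` tensor, with the gate rows stacked above the up rows along axis 1; `mx.split(input_linear, [mid], 1)` cuts that axis in half. A hedged sketch of the same split on plain Ruby nested arrays standing in for the tensor (the shapes are tiny and hypothetical: 2 experts, 2 gate rows, 2 up rows):

```ruby
# input_linear mimicked as a [num_experts, 2 * hidden, dim] nested array.
input_linear = [
  [[1, 1], [2, 2], [3, 3], [4, 4]], # expert 0: two gate rows, then two up rows
  [[5, 5], [6, 6], [7, 7], [8, 8]], # expert 1
]
mid = input_linear.first.length / 2 # == input_linear.shape[1] / 2
gate_proj = input_linear.map { |expert| expert[0...mid] }
up_proj   = input_linear.map { |expert| expert[mid..] }

p gate_proj # => [[[1, 1], [2, 2]], [[5, 5], [6, 6]]]
p up_proj   # => [[[3, 3], [4, 4]], [[7, 7], [8, 8]]]
```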
data/lib/mlx_lm/models/granitemoehybrid.rb
@@ -0,0 +1,178 @@
+require_relative "falcon_h1"
+
+module MlxLm
+  module Models
+    module GraniteMoeHybrid
+      class ModelArgs < FalconH1::ModelArgs
+        field :model_type, default: "granitemoehybrid"
+        field :embedding_multiplier, default: 1.0
+        field :attention_multiplier, default: 1.0
+        field :logits_scaling, default: 1.0
+        field :residual_multiplier, default: 1.0
+        field :num_local_experts, default: nil
+        field :num_experts_per_tok, default: nil
+        field :shared_intermediate_size, default: nil
+        field :mamba_n_heads, default: nil
+        field :mamba_d_head, default: nil
+        field :mamba_proj_bias, default: false
+        field :mamba_d_state, default: nil
+        field :mamba_n_groups, default: nil
+        field :mamba_conv_bias, default: false
+        field :layer_types, default: nil
+        field :position_embedding_type, default: "rope"
+        field :time_step_limit, default: [0.001, 100.0]
+        field :mlp_bias, default: false
+
+        def initialize(**kwargs)
+          super
+          @num_hidden_layers ||= Array(@layer_types).length
+          @num_attention_heads ||= @mamba_n_heads
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @mamba_d_head
+          @mamba_d_conv ||= 4
+          @layer_types ||= _default_layer_types
+          @block_types ||= _to_block_types
+        end
+
+        def to_falcon_h1_dict
+          hidden_size = @hidden_size
+          attention_heads = @num_attention_heads
+          inferred_head_dim = if !@head_dim.nil?
+            @head_dim
+          elsif !@mamba_d_head.nil?
+            @mamba_d_head
+          elsif !hidden_size.nil? && attention_heads.to_i > 0
+            hidden_size / attention_heads
+          else
+            64
+          end
+
+          {
+            "model_type" => @model_type,
+            "attention_bias" => @attention_bias,
+            "head_dim" => inferred_head_dim,
+            "hidden_size" => hidden_size,
+            "intermediate_size" => @intermediate_size || @shared_intermediate_size || hidden_size.to_i * 2,
+            "max_position_embeddings" => @max_position_embeddings,
+            "mamba_d_conv" => @mamba_d_conv,
+            "num_attention_heads" => attention_heads,
+            "num_hidden_layers" => @num_hidden_layers,
+            "num_key_value_heads" => @num_key_value_heads,
+            "rms_norm_eps" => @rms_norm_eps,
+            "rope_theta" => @rope_theta,
+            "vocab_size" => @vocab_size,
+            "tie_word_embeddings" => @tie_word_embeddings,
+            "attention_window_size" => @attention_window_size,
+            "block_types" => @block_types,
+          }
+        end
+
+        private
+
+        def _default_layer_types
+          count = @num_hidden_layers.to_i
+          return nil if count <= 0
+
+          Array.new(count) { |idx| idx.even? ? "mamba" : "attention" }
+        end
+
+        def _to_block_types
+          return @block_types if @block_types.is_a?(Array) && !@block_types.empty?
+          return nil unless @layer_types.is_a?(Array) && !@layer_types.empty?
+
+          @layer_types.map { |layer_type| layer_type.to_s == "mamba" ? "recurrent" : "attention" }
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.wrapped_model = FalconH1::Model.new(
+            FalconH1::ModelArgs.from_dict(args.to_falcon_h1_dict)
+          )
+        end
+
+        def call(inputs, cache: nil)
+          wrapped_model.call(inputs, cache: cache)
+        end
+
+        def sanitize(weights)
+          normalized = weights.dup
+          _rewrite_block_sparse_moe!(normalized)
+          _rewrite_shared_mlp!(normalized)
+          normalized.delete("lm_head.weight") if @args.tie_word_embeddings
+
+          remapped = {}
+          normalized.each do |key, value|
+            remapped[_remap_weight_key(key)] = value
+          end
+          wrapped_model.sanitize(remapped)
+        end
+
+        def layers
+          wrapped_model.layers
+        end
+
+        def make_cache
+          return nil unless wrapped_model.respond_to?(:make_cache)
+
+          wrapped_model.make_cache
+        end
+
+        private
+
+        def _rewrite_block_sparse_moe!(weights)
+          mx = MLX::Core
+
+          @args.num_hidden_layers.to_i.times do |layer_idx|
+            prefix = "model.layers.#{layer_idx}.block_sparse_moe"
+            input_key = "#{prefix}.input_linear.weight"
+            output_key = "#{prefix}.output_linear.weight"
+            next unless weights.key?(input_key) && weights.key?(output_key)
+
+            input_linear = weights.delete(input_key)
+            output_linear = weights.delete(output_key)
+            mid = input_linear.shape[1] / 2
+            gate_proj, up_proj = mx.split(input_linear, [mid], 1)
+
+            weights["#{prefix}.switch_mlp.gate_proj.weight"] = gate_proj
+            weights["#{prefix}.switch_mlp.up_proj.weight"] = up_proj
+            weights["#{prefix}.switch_mlp.down_proj.weight"] = output_linear
+          end
+        end
+
+        def _rewrite_shared_mlp!(weights)
+          mx = MLX::Core
+
+          @args.num_hidden_layers.to_i.times do |layer_idx|
+            prefix = "model.layers.#{layer_idx}.shared_mlp"
+            input_key = "#{prefix}.input_linear.weight"
+            output_key = "#{prefix}.output_linear.weight"
+            next unless weights.key?(input_key) && weights.key?(output_key)
+
+            input_linear = weights.delete(input_key)
+            mid = input_linear.shape[0] / 2
+            gate_proj, up_proj = mx.split(input_linear, [mid], 0)
+
+            weights["model.layers.#{layer_idx}.mlp.gate_proj.weight"] = gate_proj
+            weights["model.layers.#{layer_idx}.mlp.up_proj.weight"] = up_proj
+            weights["model.layers.#{layer_idx}.mlp.down_proj.weight"] = weights.delete(output_key)
+          end
+        end
+
+        def _remap_weight_key(key)
+          mapped = key.dup
+          mapped = mapped.gsub(".block_sparse_moe.", ".feed_forward.")
+          mapped = mapped.gsub(".shared_mlp.", ".feed_forward.")
+          mapped = mapped.gsub(".post_attention_layernorm.", ".pre_ff_layernorm.")
+          mapped = mapped.gsub("model.norm.", "model.final_layernorm.")
+          mapped
+        end
+      end
+
+      Models.register("granitemoehybrid", Model, ModelArgs)
+    end
+  end
+end
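
Note: GraniteMoeHybrid does not reimplement the hybrid attention/Mamba stack. It translates its own config into a `FalconH1::ModelArgs` dict, wraps a `FalconH1::Model`, and in `sanitize` reconciles checkpoint naming with string rewrites before delegating to the wrapped model. A runnable sketch of that key remapping, applied to hypothetical weight names (the example keys are illustrative, not taken from a real checkpoint):

```ruby
# Same gsub chain as Model#_remap_weight_key in the diff above.
def remap_weight_key(key)
  key.gsub(".block_sparse_moe.", ".feed_forward.")
     .gsub(".shared_mlp.", ".feed_forward.")
     .gsub(".post_attention_layernorm.", ".pre_ff_layernorm.")
     .gsub("model.norm.", "model.final_layernorm.")
end

[
  "model.layers.0.block_sparse_moe.router.layer.weight", # hypothetical key
  "model.layers.0.post_attention_layernorm.weight",
  "model.norm.weight",
].each { |key| puts "#{key} -> #{remap_weight_key(key)}" }
```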
data/lib/mlx_lm/models/helium.rb
@@ -0,0 +1,158 @@
+module MlxLm
+  module Models
+    module Helium
+      class ModelArgs < BaseModelArgs
+        field :hidden_size, default: 256
+        field :num_hidden_layers, default: 24
+        field :intermediate_size, default: 1024
+        field :num_attention_heads, default: 4
+        field :num_key_value_heads, default: nil
+        field :rms_norm_eps, default: 1e-5
+        field :vocab_size, default: 32_000
+        field :attention_bias, default: false
+        field :head_dim, default: nil
+        field :max_position_embeddings, default: 2048
+        field :mlp_bias, default: false
+        field :model_type, default: "helium"
+        field :rope_theta, default: 10_000.0
+        field :tie_word_embeddings, default: false
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @hidden_size / @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.hidden_size / @n_heads
+          @scale = @head_dim**(-0.5)
+
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+          self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(
+            args.hidden_size,
+            args.intermediate_size,
+            bias: args.mlp_bias
+          )
+          self.up_proj = MLX::NN::Linear.new(
+            args.hidden_size,
+            args.intermediate_size,
+            bias: args.mlp_bias
+          )
+          self.down_proj = MLX::NN::Linear.new(
+            args.intermediate_size,
+            args.hidden_size,
+            bias: args.mlp_bias
+          )
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
+      class DecoderLayer < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args)
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
+      class HeliumModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) }
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          @model_type = args.model_type
+          self.model = HeliumModel.new(args)
+          self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("helium", Model, ModelArgs)
+    end
+  end
+end
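
Note: Helium is a plain Llama-style decoder with grouped-query attention: `num_key_value_heads` may be smaller than `num_attention_heads`, and the SDPA scale is the usual `1/sqrt(head_dim)` rather than Granite's fixed `attention_multiplier`. A small sketch of the head bookkeeping; the KV-head count of 2 is hypothetical, since the default config ties it to `num_attention_heads`:

```ruby
hidden_size = 256              # Helium's default hidden_size
n_heads     = 4                # default num_attention_heads
n_kv_heads  = 2                # hypothetical; the default falls back to n_heads
head_dim    = hidden_size / n_heads
scale       = head_dim**(-0.5) # the 1/sqrt(head_dim) SDPA scale

puts "head_dim=#{head_dim}, scale=#{scale}"             # => head_dim=64, scale=0.125
puts "query heads per KV head: #{n_heads / n_kv_heads}" # => 2
```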