mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/gpt_bigcode.rb
@@ -0,0 +1,154 @@
+module MlxLm
+  module Models
+    module GPTBigCode
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "gpt_bigcode"
+        field :n_embd
+        field :n_layer
+        field :n_inner
+        field :n_head
+        field :n_positions
+        field :layer_norm_epsilon
+        field :vocab_size
+        field :num_key_value_heads, default: nil
+        field :multi_query, default: true
+        field :attention_bias, default: true
+        field :mlp_bias, default: true
+        field :tie_word_embeddings, default: true
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @multi_query ? 1 : @n_head
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+
+          @dim = args.n_embd
+          @n_heads = args.n_head
+          @n_kv_heads = args.multi_query ? 1 : args.n_head
+          @head_dim = @dim / @n_heads
+          @kv_dim = @n_kv_heads * @head_dim
+          @scale = @head_dim**(-0.5)
+
+          bias = args.attention_bias
+          self.c_attn = MLX::NN::Linear.new(@dim, @dim + 2 * @kv_dim, bias: bias)
+          self.c_proj = MLX::NN::Linear.new(@dim, @dim, bias: bias)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          qkv = c_attn.call(x)
+          queries, keys, values = mx.split(qkv, [@dim, @dim + @kv_dim], -1)
+
+          queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            keys, values = cache.update_and_fetch(keys, values)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @dim])
+          c_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(args)
+          super()
+
+          dim = args.n_embd
+          hidden_dim = args.n_inner
+          bias = args.mlp_bias
+          self.c_fc = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+          self.c_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias)
+        end
+
+        def call(x)
+          c_proj.call(MLX::NN.gelu(c_fc.call(x)))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.attn = Attention.new(args)
+          self.mlp = MLP.new(args)
+          self.ln_1 = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon)
+          self.ln_2 = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = attn.call(ln_1.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(ln_2.call(h))
+          h + r
+        end
+      end
+
+      class GPTBigCodeModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.wte = MLX::NN::Embedding.new(args.vocab_size, args.n_embd)
+          self.wpe = MLX::NN::Embedding.new(args.n_positions, args.n_embd)
+          self.h = Array.new(args.n_layer) { TransformerBlock.new(args) }
+          self.ln_f = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon)
+        end
+
+        def call(inputs, cache: nil)
+          mx = MLX::Core
+          _b, l = inputs.shape
+
+          hidden_states = wte.call(inputs)
+          layer_cache = cache || [nil] * h.length
+          offset = layer_cache[0] ? layer_cache[0].offset : 0
+          position_ids = mx.arange(offset, offset + l, 1, mx.int32)
+
+          mask = nil
+          mask = "causal" if hidden_states.shape[1] > 1
+
+          hidden_states = hidden_states + wpe.call(position_ids)
+
+          h.each_with_index do |layer, i|
+            hidden_states = layer.call(hidden_states, mask: mask, cache: layer_cache[i])
+          end
+
+          ln_f.call(hidden_states)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.transformer = GPTBigCodeModel.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.n_embd, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = transformer.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            transformer.wte.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def layers
+          transformer.h
+        end
+      end
+
+      Models.register("gpt_bigcode", Model, ModelArgs)
+    end
+  end
+end
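Each model file in this diff follows the same pattern: a ModelArgs dataclass, the NN modules, and a closing Models.register call. As a minimal sketch of driving the file above directly, assuming (from the initialize(**kwargs) signature) that BaseModelArgs subclasses accept their fields as keyword arguments; the dimensions below are hypothetical GPT-2-sized values, not taken from any real checkpoint:

  # Hypothetical config values; real ones come from a checkpoint's config.json.
  args = MlxLm::Models::GPTBigCode::ModelArgs.new(
    n_embd: 768, n_layer: 12, n_inner: 3072, n_head: 12,
    n_positions: 1024, layer_norm_epsilon: 1e-5, vocab_size: 49_152
  )
  model = MlxLm::Models::GPTBigCode::Model.new(args)
  # multi_query defaults to true, so num_key_value_heads resolves to 1;
  # tie_word_embeddings defaults to true, so logits go through wte.as_linear.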
data/lib/mlx_lm/models/gpt_neox.rb
@@ -0,0 +1,178 @@
+module MlxLm
+  module Models
+    module GPTNeoX
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "gpt_neox"
+        field :hidden_size, default: 2560
+        field :num_hidden_layers, default: 32
+        field :num_attention_heads, default: 32
+        field :num_key_value_heads, default: nil
+        field :vocab_size, default: 50432
+        field :layer_norm_eps, default: 1e-5
+        field :rotary_emb_base, default: 10000
+        field :rotary_pct, default: 0.25
+        field :use_parallel_residual, default: true
+        field :max_position_embeddings, default: 2048
+        field :intermediate_size, default: nil
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @intermediate_size ||= 4 * @hidden_size
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = dim / @n_heads
+          @scale = @head_dim**(-0.5)
+
+          # Partial rotary: only apply RoPE to rotary_pct fraction of head_dim
+          rope_dim = (args.rotary_pct * @head_dim).to_i
+
+          # Combined QKV projection
+          total_qkv = (@n_heads + 2 * @n_kv_heads) * @head_dim
+          self.query_key_value = MLX::NN::Linear.new(dim, total_qkv, bias: true)
+          self.dense = MLX::NN::Linear.new(dim, dim, bias: true)
+
+          self.rope = MLX::NN::RoPE.new(
+            rope_dim,
+            traditional: false,
+            base: args.rotary_emb_base
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          qkv = query_key_value.call(x)
+
+          # Split into Q, K, V
+          q_size = @n_heads * @head_dim
+          kv_size = @n_kv_heads * @head_dim
+          queries, keys, values = mx.split(qkv, [q_size, q_size + kv_size], 2)
+
+          queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          dense.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(dim, hidden_dim)
+          super()
+          self.dense_h_to_4h = MLX::NN::Linear.new(dim, hidden_dim, bias: true)
+          self.dense_4h_to_h = MLX::NN::Linear.new(hidden_dim, dim, bias: true)
+        end
+
+        def call(x)
+          dense_4h_to_h.call(MLX::NN.gelu_approx(dense_h_to_4h.call(x)))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+          self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps)
+          @use_parallel_residual = args.use_parallel_residual
+          unless @use_parallel_residual
+            self.post_attention_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps)
+          end
+        end
+
+        def call(x, mask: nil, cache: nil)
+          h = input_layernorm.call(x)
+          r = self_attn.call(h, mask: mask, cache: cache)
+
+          if @use_parallel_residual
+            x + r + mlp.call(h)
+          else
+            h = x + r
+            r = mlp.call(post_attention_layernorm.call(h))
+            h + r
+          end
+        end
+      end
+
+      class GPTNeoXModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_in = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.h = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.final_layer_norm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          hidden = embed_in.call(inputs)
+          layer_cache = cache || [nil] * h.length
+
+          mask = nil
+          mask = "causal" if hidden.shape[1] > 1
+
+          h.each_with_index do |layer, i|
+            hidden = layer.call(hidden, mask: mask, cache: layer_cache[i])
+          end
+
+          final_layer_norm.call(hidden)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model = GPTNeoXModel.new(args)
+          self.embed_out = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          embed_out.call(out)
+        end
+
+        def sanitize(weights)
+          result = {}
+          weights.each do |k, v|
+            next if k.include?(".attention.bias") || k.include?(".attention.masked_bias")
+            next if k.include?(".attention.rotary_emb.inv_freq")
+
+            # Remap weight keys
+            key = k.dup
+            key.gsub!(".gpt_neox.layers.", ".h.")
+            key.gsub!(".gpt_neox.", ".")
+            key = "model.#{key}" unless key.start_with?("model.")
+            result[key] = v
+          end
+          result
+        end
+
+        def layers
+          model.h
+        end
+      end
+
+      Models.register("gpt_neox", Model, ModelArgs)
+    end
+  end
+end
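Unlike GPTBigCode, every field on this ModelArgs carries a default (the shapes match a Pythia-2.8B-style configuration), so a bare construction is enough to exercise the classes. A minimal sketch under the same keyword-argument assumption as above:

  args = MlxLm::Models::GPTNeoX::ModelArgs.new
  model = MlxLm::Models::GPTNeoX::Model.new(args)
  # head_dim = 2560 / 32 = 80, and rotary_pct 0.25 gives rope_dim = 20,
  # so RoPE rotates only the first 20 dimensions of each head.
  # use_parallel_residual (default true) feeds one shared layernorm output
  # into both attention and MLP: x + attn(ln(x)) + mlp(ln(x)).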
data/lib/mlx_lm/models/gpt_oss.rb
@@ -0,0 +1,319 @@
+require_relative "cache"
+require_relative "rope_utils"
+require_relative "switch_layers"
+
+module MlxLm
+  module Models
+    module GptOss
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "gpt_oss"
+        field :num_hidden_layers, default: 36
+        field :num_local_experts, default: 128
+        field :num_experts_per_tok, default: 4
+        field :vocab_size, default: 201_088
+        field :rms_norm_eps, default: 1e-5
+        field :hidden_size, default: 2880
+        field :intermediate_size, default: 2880
+        field :head_dim, default: 64
+        field :num_attention_heads, default: 64
+        field :num_key_value_heads, default: 8
+        field :sliding_window, default: 128
+        field :rope_theta, default: 150_000
+        field :rope_scaling, default: nil
+        field :layer_types, default: nil
+
+        def initialize(**kwargs)
+          super
+          @layer_types ||= Array.new(@num_hidden_layers) do |i|
+            i.even? ? "sliding_attention" : "full_attention"
+          end
+        end
+      end
+
+      class AttentionBlock < MLX::NN::Module
+        def initialize(config)
+          super()
+          @head_dim = config.head_dim
+          @num_attention_heads = config.num_attention_heads
+          @num_key_value_heads = config.num_key_value_heads
+          @sm_scale = 1.0 / Math.sqrt(@head_dim)
+
+          self.q_proj = MLX::NN::Linear.new(
+            config.hidden_size,
+            @num_attention_heads * @head_dim,
+            bias: true
+          )
+          self.k_proj = MLX::NN::Linear.new(
+            config.hidden_size,
+            @num_key_value_heads * @head_dim,
+            bias: true
+          )
+          self.v_proj = MLX::NN::Linear.new(
+            config.hidden_size,
+            @num_key_value_heads * @head_dim,
+            bias: true
+          )
+          self.o_proj = MLX::NN::Linear.new(
+            @num_attention_heads * @head_dim,
+            config.hidden_size,
+            bias: true
+          )
+
+          self.rope = MlxLm::Models.initialize_rope(
+            @head_dim,
+            config.rope_theta,
+            false,
+            config.rope_scaling
+          )
+        end
+
+        def call(x, mask:, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          q = q_proj.call(x).reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
+          k = k_proj.call(x).reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+          v = v_proj.call(x).reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            q = rope.call(q, offset: cache.offset)
+            k = rope.call(k, offset: cache.offset)
+            k, v = cache.update_and_fetch(k, v)
+          else
+            q = rope.call(q)
+            k = rope.call(k)
+          end
+
+          out = mx.scaled_dot_product_attention(q, k, v, @sm_scale, mask)
+          out = out.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim])
+          o_proj.call(out)
+        end
+      end
+
+      class MLPBlock < MLX::NN::Module
+        def initialize(config)
+          super()
+          @num_local_experts = config.num_local_experts
+          @num_experts_per_tok = config.num_experts_per_tok
+
+          self.experts = SwitchLayers::SwitchGLU.new(
+            config.hidden_size,
+            config.intermediate_size,
+            @num_local_experts,
+            bias: true
+          )
+          self.router = MLX::NN::Linear.new(
+            config.hidden_size,
+            @num_local_experts,
+            bias: true
+          )
+        end
+
+        def call(x)
+          mx = MLX::Core
+
+          gates = router.call(x)
+          k = [@num_experts_per_tok, @num_local_experts].min
+          inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1))
+          take_ids = mx.array((0...k).to_a, dtype: mx.int32)
+          inds = mx.take(inds, take_ids, -1)
+          expert_weights = mx.take_along_axis(gates, inds, -1)
+          expert_weights = mx.softmax(expert_weights.astype(mx.float32), -1).astype(expert_weights.dtype)
+
+          x = experts.call(x, inds)
+          x = x * mx.expand_dims(expert_weights, -1)
+          mx.sum(x, -2)
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(config)
+          super()
+          self.self_attn = AttentionBlock.new(config)
+          self.mlp = MLPBlock.new(config)
+          self.input_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
+        end
+
+        def call(x, mask:, cache: nil)
+          h = x + self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h + mlp.call(post_attention_layernorm.call(h))
+        end
+      end
+
+      class GptOssMoeModel < MLX::NN::Module
+        attr_reader :layer_types
+
+        def initialize(args)
+          super()
+          @window_size = args.sliding_window
+          @layer_types = args.layer_types
+
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+
+          @swa_idx = @layer_types.index("sliding_attention") || 0
+          @ga_idx = @layer_types.index("full_attention") || 0
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          x = input_embeddings || embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          full_mask = _create_attention_mask(x, layer_cache[@ga_idx])
+          swa_mask = _create_attention_mask(
+            x,
+            layer_cache[@swa_idx],
+            window_size: @window_size
+          )
+
+          layers.each_with_index do |layer, i|
+            layer_type = @layer_types[i]
+            mask = layer_type == "full_attention" ? full_mask : swa_mask
+            x = layer.call(x, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(x)
+        end
+
+        private
+
+        def _create_attention_mask(h, cache = nil, window_size: nil)
+          n = h.shape[1]
+          if cache && cache.respond_to?(:make_mask)
+            return cache.make_mask(n, window_size: window_size)
+          end
+
+          if window_size
+            offset = 0
+            if cache
+              offset = cache.offset
+              if cache.instance_variable_defined?(:@max_size)
+                max_size = cache.instance_variable_get(:@max_size)
+                offset = [max_size - 1, offset].min if max_size && max_size > 0
+              end
+            end
+            return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
+          end
+
+          return nil if n == 1
+
+          "causal"
+        end
+
+        def _create_causal_mask(n, offset: 0, window_size: nil)
+          mx = MLX::Core
+          rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+          linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+          mask = mx.greater_equal(linds, rinds)
+          if window_size
+            mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+          end
+          mask
+        end
+      end
+
+      class Model < MLX::NN::Module
+        attr_reader :args
+
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.model = GptOssMoeModel.new(args)
+          self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          lm_head.call(model.call(inputs, cache: cache, input_embeddings: input_embeddings))
+        end
+
+        def sanitize(weights)
+          return weights if weights.keys.any? { |key| key.include?("gate_proj.weight") }
+
+          result = {}
+          weights.each do |key, value|
+            if key.include?("gate_up_proj") && !key.include?("bias")
+              normalized_key, normalized_value = _normalize_moe_weight_param(key, value)
+              split_axis = normalized_value.shape.length - 2
+              result[normalized_key.sub("gate_up_proj", "gate_proj")] = _take_every_other(
+                normalized_value,
+                start: 0,
+                axis: split_axis
+              )
+              result[normalized_key.sub("gate_up_proj", "up_proj")] = _take_every_other(
+                normalized_value,
+                start: 1,
+                axis: split_axis
+              )
+            elsif key.include?("down_proj") && !key.include?("bias")
+              normalized_key, normalized_value = _normalize_moe_weight_param(key, value)
+              result[normalized_key] = normalized_value
+            elsif key.include?("gate_up_proj_bias")
+              split_axis = value.shape.length - 1
+              result[key.sub("gate_up_proj_bias", "gate_proj.bias")] = _take_every_other(
+                value,
+                start: 0,
+                axis: split_axis
+              )
+              result[key.sub("gate_up_proj_bias", "up_proj.bias")] = _take_every_other(
+                value,
+                start: 1,
+                axis: split_axis
+              )
+            elsif key.include?("down_proj_bias")
+              result[key.sub("down_proj_bias", "down_proj.bias")] = value
+            else
+              result[key] = value
+            end
+          end
+
+          result
+        end
+
+        def layers
+          model.layers
+        end
+
+        def make_cache
+          model.layer_types.map do |layer_type|
+            if layer_type == "full_attention"
+              MlxLm::KVCache.new
+            else
+              MlxLm::RotatingKVCache.new(max_size: @args.sliding_window)
+            end
+          end
+        end
+
+        private
+
+        def _normalize_moe_weight_param(key, value)
+          mx = MLX::Core
+          normalized_key = key
+          normalized_value = value
+
+          if key.include?("_blocks")
+            normalized_value = mx.flatten(value.view(mx.uint32), -2, -1)
+            normalized_key = normalized_key.sub("_blocks", ".weight")
+          end
+          if key.include?("_scales")
+            normalized_key = normalized_key.sub("_scales", ".scales")
+          end
+
+          [normalized_key, normalized_value]
+        end
+
+        def _take_every_other(value, start:, axis:)
+          mx = MLX::Core
+          indices = (start...value.shape[axis]).step(2).to_a
+          take_ids = mx.array(indices, dtype: mx.int32)
+          mx.take(value, take_ids, axis)
+        end
+      end
+
+      Models.register("gpt_oss", Model, ModelArgs)
+    end
+  end
+end
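Two details of the gpt_oss file deserve a note. The sanitize method handles checkpoints whose expert gate and up projections are stored fused and interleaved along one axis; a plain-Ruby analogue of what _take_every_other does along that axis (MLX arrays and axis bookkeeping elided for illustration):

  fused = [:g0, :u0, :g1, :u1, :g2, :u2]        # gate_up_proj rows, interleaved
  gate  = fused.each_slice(2).map { |g, _u| g } # start: 0 => [:g0, :g1, :g2]
  up    = fused.each_slice(2).map { |_g, u| u } # start: 1 => [:u0, :u1, :u2]

And make_cache mirrors the default layer_types alternation: even-indexed layers use sliding-window attention backed by a RotatingKVCache capped at sliding_window (128) entries, while odd-indexed layers use full attention with an unbounded KVCache.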