mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
--- /dev/null
+++ b/data/lib/mlx_lm/models/cohere.rb
@@ -0,0 +1,150 @@
+module MlxLm
+  module Models
+    module Cohere
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "cohere"
+        field :hidden_size, default: 8192
+        field :num_hidden_layers, default: 40
+        field :num_attention_heads, default: 64
+        field :num_key_value_heads, default: 64
+        field :intermediate_size, default: 22528
+        field :vocab_size, default: 256000
+        field :rope_theta, default: 8000000.0
+        field :layer_norm_eps, default: 1e-5
+        field :logit_scale, default: 0.0625
+        field :attention_bias, default: false
+        field :layer_norm_bias, default: false
+        field :use_qk_norm, default: false
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = dim / @n_heads
+          @scale = @head_dim**(-0.5)
+
+          bias = args.attention_bias
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias)
+
+          self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(dim, hidden_dim)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+        end
+
+        def call(x)
+          down_proj.call(MLX::NN.silu(gate_proj.call(x)) * up_proj.call(x))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+          self.input_layernorm = MLX::NN::LayerNorm.new(
+            args.hidden_size, eps: args.layer_norm_eps
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          # Cohere uses parallel residuals: attn + mlp + x
+          h = input_layernorm.call(x)
+          attn_h = self_attn.call(h, mask: mask, cache: cache)
+          ff_h = mlp.call(h)
+          attn_h + ff_h + x
+        end
+      end
+
+      class CohereModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.norm = MLX::NN::LayerNorm.new(
+            args.hidden_size, eps: args.layer_norm_eps
+          )
+        end
+
+        def call(inputs, cache: nil)
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model = CohereModel.new(args)
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          # Tied embeddings + logit scaling
+          out = model.embed_tokens.as_linear(out)
+          out * @args.logit_scale
+        end
+
+        def sanitize(weights)
+          weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("cohere", Model, ModelArgs)
+    end
+  end
+end
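The block above differs from the usual pre-norm decoder layer: TransformerBlock applies a single input_layernorm, feeds the same normalized activations to both the attention and MLP branches, and adds both branch outputs onto the raw residual (attn_h + ff_h + x), with no second norm after attention. A minimal scalar sketch of the two wirings, using made-up stand-in lambdas rather than the real tensor layers:

# Toy scalar stand-ins for the real layers (values are invented).
attn = ->(h) { 2.0 * h }    # placeholder for self_attn
mlp  = ->(h) { h + 1.0 }    # placeholder for mlp
norm = ->(x) { x / 10.0 }   # placeholder for input_layernorm

x = 5.0
h = norm.call(x)

# Cohere-style parallel residual: both branches read the same h.
parallel = attn.call(h) + mlp.call(h) + x          # => 7.5

# Conventional sequential pre-norm, for contrast: the MLP sees a
# residual that has already been updated by attention.
sequential = x + attn.call(norm.call(x))
sequential += mlp.call(norm.call(sequential))      # => 7.6

Because both branches share one normalization of the same input, the block only needs the single input_layernorm defined above.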
--- /dev/null
+++ b/data/lib/mlx_lm/models/cohere2.rb
@@ -0,0 +1,224 @@
+module MlxLm
+  module Models
+    module Cohere2
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "cohere2"
+        field :hidden_size, default: 4096
+        field :head_dim, default: 128
+        field :num_hidden_layers, default: 32
+        field :intermediate_size, default: 14336
+        field :num_attention_heads, default: 32
+        field :num_key_value_heads, default: 8
+        field :rope_theta, default: 50_000.0
+        field :vocab_size, default: 256000
+        field :layer_norm_eps, default: 1e-5
+        field :logit_scale, default: 0.0625
+        field :attention_bias, default: false
+        field :layer_norm_bias, default: false
+        field :sliding_window, default: 4096
+        field :sliding_window_pattern, default: 4
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args, layer_idx)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          if (@head_dim * @n_heads) != dim
+            raise ArgumentError,
+                  "hidden_size must equal num_attention_heads * head_dim (got #{dim} and #{@n_heads} * #{@head_dim})"
+          end
+          @scale = @head_dim**(-0.5)
+
+          bias = args.attention_bias
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias)
+
+          self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta)
+          @use_sliding_window = ((layer_idx + 1) % args.sliding_window_pattern) != 0
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if @use_sliding_window
+            if cache
+              queries = rope.call(queries, offset: cache.offset)
+              keys = rope.call(keys, offset: cache.offset)
+            else
+              queries = rope.call(queries)
+              keys = rope.call(keys)
+            end
+          end
+
+          keys, values = cache.update_and_fetch(keys, values) if cache
+
+          sdpa_type = queries.dtype == mx.float16 ? mx.float32 : queries.dtype
+          output = mx.scaled_dot_product_attention(
+            queries.astype(sdpa_type),
+            keys,
+            values,
+            @scale,
+            mask
+          ).astype(queries.dtype)
+
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(dim, hidden_dim)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args, layer_idx)
+          super()
+          self.self_attn = Attention.new(args, layer_idx)
+          self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+          self.input_layernorm = MLX::NN::LayerNorm.new(
+            args.hidden_size,
+            eps: args.layer_norm_eps,
+            bias: args.layer_norm_bias
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          h = input_layernorm.call(x)
+          attn_h = self_attn.call(h, mask: mask, cache: cache)
+          ff_h = mlp.call(h)
+          attn_h + ff_h + x
+        end
+      end
+
+      class Cohere2Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          @window_size = args.sliding_window
+
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { |i| TransformerBlock.new(args, i) }
+          self.norm = MLX::NN::LayerNorm.new(
+            args.hidden_size,
+            eps: args.layer_norm_eps,
+            bias: args.layer_norm_bias
+          )
+        end
+
+        def call(inputs, cache: nil)
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          pattern = @args.sliding_window_pattern
+          full_mask = _create_attention_mask(h, layer_cache[pattern - 1])
+          swa_mask = _create_attention_mask(h, layer_cache[0], window_size: @window_size)
+
+          layers.each_with_index do |layer, i|
+            is_global = (i % pattern) == (pattern - 1)
+            mask = is_global ? full_mask : swa_mask
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+
+        private
+
+        def _create_attention_mask(h, cache, window_size: nil)
+          n = h.shape[1]
+          offset = cache ? cache.offset : 0
+
+          if window_size
+            if cache || n > window_size
+              return _create_causal_mask(n, offset, window_size)
+            end
+            return nil if n == 1
+
+            return "causal"
+          end
+
+          return nil if n == 1
+
+          "causal"
+        end
+
+        def _create_causal_mask(n, offset, window_size = nil)
+          mx = MLX::Core
+          rinds = mx.arange(offset + n)
+          linds = offset.zero? ? rinds : mx.arange(offset, offset + n)
+
+          linds = mx.expand_dims(linds, 1)
+          rinds = mx.expand_dims(rinds, 0)
+          mask = mx.greater_equal(linds, rinds)
+
+          if window_size
+            mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+          end
+
+          mask
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model = Cohere2Model.new(args)
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          out = model.embed_tokens.as_linear(out)
+          out * @args.logit_scale
+        end
+
+        def make_cache
+          caches = []
+          @args.num_hidden_layers.times do |i|
+            is_global = (i % @args.sliding_window_pattern) == (@args.sliding_window_pattern - 1)
+            if is_global
+              caches << MlxLm::KVCache.new
+            else
+              caches << MlxLm::RotatingKVCache.new(max_size: @args.sliding_window, keep: 0)
+            end
+          end
+          caches
+        end
+
+        def sanitize(weights)
+          weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("cohere2", Model, ModelArgs)
+    end
+  end
+end
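The masking logic above is easier to follow with concrete numbers. Below is a pure-Ruby sketch of the same construction as Cohere2Model#_create_causal_mask, using nested boolean arrays instead of MLX tensors; the token counts, cache offset, and window size are made up for illustration:

# Which key positions each new query may attend to: causal, optionally windowed.
def causal_mask(n, offset, window_size = nil)
  rinds = (0...(offset + n)).to_a       # all key positions, including cached ones
  linds = (offset...(offset + n)).to_a  # absolute positions of the n new queries
  linds.map do |q|
    rinds.map do |k|
      allowed = q >= k                                   # causal constraint
      allowed &&= q < k + window_size if window_size     # sliding-window constraint
      allowed
    end
  end
end

# 3 new tokens after 2 cached ones, window of 4 keys:
causal_mask(3, 2, 4).each { |row| p row }
# [true, true, true, false, false]
# [true, true, true, true, false]
# [false, true, true, true, true]

# With sliding_window_pattern = 4, only every fourth layer uses the full
# causal mask (and a plain KVCache in make_cache); the others use the
# windowed mask and a RotatingKVCache:
pattern = 4
p (0...8).map { |i| i % pattern == pattern - 1 ? :global : :sliding }
# => [:sliding, :sliding, :sliding, :global, :sliding, :sliding, :sliding, :global]

Note also that rotary embeddings are only applied on the sliding-window layers (@use_sliding_window in Attention); the periodic global layers rely on the full causal mask without positional rotation.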
--- /dev/null
+++ b/data/lib/mlx_lm/models/dbrx.rb
@@ -0,0 +1,286 @@
+require_relative "activations"
+
+module MlxLm
+  module Models
+    module Dbrx
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "dbrx"
+        field :vocab_size, default: 32_000
+        field :d_model, default: 6144
+        field :ffn_config, default: {}
+        field :attn_config, default: {}
+        field :n_layers, default: 40
+        field :n_heads, default: 48
+
+        def initialize(**kwargs)
+          super
+          @ffn_config ||= {}
+          @attn_config ||= {}
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+
+          @num_heads = args.n_heads
+          @d_model = args.d_model
+          @head_dim = @d_model / args.n_heads
+          @num_key_value_heads = _attn_value(args.attn_config, "kv_n_heads", args.n_heads).to_i
+          @clip_qkv = _attn_value(args.attn_config, "clip_qkv", 8.0).to_f
+          @rope_theta = _attn_value(args.attn_config, "rope_theta", 10_000.0).to_f
+          @scale = @head_dim**(-0.5)
+
+          self.wqkv = MLX::NN::Linear.new(
+            args.d_model,
+            (@num_key_value_heads * 2 + @num_heads) * @head_dim,
+            bias: false
+          )
+          self.out_proj = MLX::NN::Linear.new(args.d_model, args.d_model, bias: false)
+          self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: @rope_theta)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          qkv = wqkv.call(x)
+          qkv = mx.clip(qkv, -@clip_qkv, @clip_qkv)
+
+          splits = [@d_model, @d_model + @head_dim * @num_key_value_heads]
+          queries, keys, values = mx.split(qkv, splits, -1)
+
+          queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @d_model])
+          out_proj.call(output)
+        end
+
+        private
+
+        def _attn_value(config, key, default = nil)
+          return default if config.nil?
+          return config[key] if config.key?(key)
+
+          config.fetch(key.to_sym, default)
+        end
+      end
+
+      class NormAttnNorm < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.norm_1 = MLX::NN::LayerNorm.new(args.d_model, bias: false)
+          self.norm_2 = MLX::NN::LayerNorm.new(args.d_model, bias: false)
+          self.attn = Attention.new(args)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          h = attn.call(norm_1.call(x), mask: mask, cache: cache)
+          residual = x + h
+          [residual, norm_2.call(residual)]
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(d_model, ffn_dim)
+          super()
+          self.v1 = MLX::NN::Linear.new(d_model, ffn_dim, bias: false)
+          self.w1 = MLX::NN::Linear.new(d_model, ffn_dim, bias: false)
+          self.w2 = MLX::NN::Linear.new(ffn_dim, d_model, bias: false)
+        end
+
+        def call(x)
+          w2.call(Activations.swiglu(w1.call(x), v1.call(x)))
+        end
+      end
+
+      class Router < MLX::NN::Module
+        def initialize(d_model, num_experts)
+          super()
+          self.layer = MLX::NN::Linear.new(d_model, num_experts, bias: false)
+        end
+
+        def call(x)
+          layer.call(x)
+        end
+      end
+
+      class SparseMoeBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          @d_model = args.d_model
+          @ffn_dim = _ffn_value(args.ffn_config, "ffn_hidden_size", args.d_model * 4).to_i
+          @num_experts = _ffn_value(args.ffn_config, "moe_num_experts", 1).to_i
+          @num_experts_per_tok = _ffn_value(args.ffn_config, "moe_top_k", 1).to_i
+
+          self.router = Router.new(@d_model, @num_experts)
+          self.experts = Array.new(@num_experts) { MLP.new(@d_model, @ffn_dim) }
+        end
+
+        def call(x)
+          mx = MLX::Core
+
+          top_k = [[@num_experts_per_tok, 1].max, @num_experts].min
+          orig_shape = x.shape
+          token_count = orig_shape[0...-1].reduce(1, :*)
+          flat_x = x.reshape([token_count, orig_shape[-1]])
+
+          gates = router.call(flat_x)
+          gates = mx.softmax(gates.astype(mx.float32), -1)
+
+          inds = mx.stop_gradient(mx.argpartition(gates * -1.0, top_k - 1, -1))
+          take_ids = mx.array((0...top_k).to_a, dtype: mx.int32)
+          inds = mx.take(inds, take_ids, -1)
+          scores = mx.take_along_axis(gates, inds, -1)
+          scores = scores / mx.expand_dims(mx.sum(scores, -1), -1)
+          scores = scores.astype(flat_x.dtype)
+
+          expert_ids = inds.to_a
+          expert_scores = scores.to_a
+
+          outputs = Array.new(flat_x.shape[0]) do |token_idx|
+            token_ids = mx.array([token_idx], dtype: mx.int32)
+            token_state = mx.squeeze(mx.take(flat_x, token_ids, 0), 0)
+
+            token_out = nil
+            expert_ids[token_idx].each_with_index do |expert_idx, score_idx|
+              expert_out = experts[expert_idx.to_i].call(token_state)
+              weighted = expert_out * expert_scores[token_idx][score_idx].to_f
+              token_out = token_out.nil? ? weighted : (token_out + weighted)
+            end
+
+            token_out
+          end
+
+          mx.stack(outputs, 0).reshape(orig_shape)
+        end
+
+        private
+
+        def _ffn_value(config, key, default = nil)
+          return default if config.nil?
+          return config[key] if config.key?(key)
+
+          config.fetch(key.to_sym, default)
+        end
+      end
+
+      class DecoderLayer < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.ffn = SparseMoeBlock.new(args)
+          self.norm_attn_norm = NormAttnNorm.new(args)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          residual, hidden = norm_attn_norm.call(x, mask: mask, cache: cache)
+          ffn.call(hidden) + residual
+        end
+      end
+
+      class DbrxModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.wte = MLX::NN::Embedding.new(args.vocab_size, args.d_model)
+          self.blocks = Array.new(args.n_layers) { DecoderLayer.new(args) }
+          self.norm_f = MLX::NN::LayerNorm.new(args.d_model, bias: false)
+        end
+
+        def call(inputs, cache: nil)
+          h = wte.call(inputs)
+          layer_cache = cache || [nil] * blocks.length
+          mask = _create_attention_mask(h, layer_cache[0])
+
+          blocks.each_with_index do |layer, layer_idx|
+            h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+          end
+
+          norm_f.call(h)
+        end
+
+        private
+
+        def _create_attention_mask(hidden, cache)
+          return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask)
+          return nil if hidden.shape[1] == 1
+
+          "causal"
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.transformer = DbrxModel.new(args)
+          self.lm_head = MLX::NN::Linear.new(args.d_model, args.vocab_size, bias: false)
+        end
+
+        def call(inputs, cache: nil)
+          out = transformer.call(inputs, cache: cache)
+          lm_head.call(out)
+        end
+
+        def layers
+          transformer.blocks
+        end
+
+        def sanitize(weights)
+          mx = MLX::Core
+          num_experts = _ffn_value(@args.ffn_config, "moe_num_experts", 0).to_i
+          return weights if num_experts <= 0
+
+          pattern = "experts.mlp"
+          sanitized = {}
+
+          weights.each do |key, value|
+            unless key.include?(pattern)
+              sanitized[key] = value
+              next
+            end
+
+            split_weights = mx.split(value, num_experts, 0)
+            split_weights.each_with_index do |slice, expert_idx|
+              expert_key = _expert_weight_key(key, expert_idx)
+              if key.end_with?("w2") || key.end_with?("w2.weight")
+                slice = slice.transpose([1, 0])
+              end
+              sanitized[expert_key] = slice
+            end
+          end
+
+          sanitized
+        end
+
+        private
+
+        def _expert_weight_key(key, expert_idx)
+          base = key.end_with?(".weight") ? key.sub(/\.weight\z/, "") : key
+          "#{base.sub('.mlp', ".#{expert_idx}")}.weight"
+        end
+
+        def _ffn_value(config, key, default = nil)
+          return default if config.nil?
+          return config[key] if config.key?(key)
+
+          config.fetch(key.to_sym, default)
+        end
+      end
+
+      Models.register("dbrx", Model, ModelArgs)
+    end
+  end
+end
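The routing arithmetic inside SparseMoeBlock#call reduces to: softmax the router logits, keep the moe_top_k largest gates, renormalize them, and mix the selected experts' outputs with those weights. A plain-Ruby sketch for a single token (the logits and the softmax helper here are invented for illustration; the real code does this with MLX arrays via argpartition and take_along_axis):

def softmax(xs)
  m = xs.max
  exps = xs.map { |x| Math.exp(x - m) }
  total = exps.sum
  exps.map { |e| e / total }
end

gate_logits = [1.2, -0.3, 0.8, 0.1]   # router output for one token, 4 experts
top_k = 2

gates  = softmax(gate_logits)
picked = gates.each_with_index.max_by(top_k) { |gate, _| gate }  # top_k [gate, expert] pairs
total_gate = picked.sum { |gate, _| gate }
mix = picked.map { |gate, idx| [idx, gate / total_gate] }        # renormalized mixing weights

p mix
# => [[0, 0.598...], [2, 0.401...]]  (expert index, mixing weight)
# The token's output is the weighted sum of experts[idx].call(token_state).

The sanitize method exists because DBRX checkpoints store every expert's weights stacked in single experts.mlp.{w1,v1,w2} tensors; it splits each stacked tensor into num_experts slices along the first axis, renames the keys to per-expert form via _expert_weight_key, and transposes the w2 slices so their orientation matches the per-expert MLP Linear layers defined here.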