mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/solar_open.rb
@@ -0,0 +1,79 @@
+require_relative "deepseek"
+
+module MlxLm
+  module Models
+    module SolarOpen
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "solar_open"
+        field :vocab_size
+        field :hidden_size
+        field :intermediate_size
+        field :moe_intermediate_size
+        field :num_hidden_layers
+        field :num_attention_heads
+        field :num_key_value_heads
+        field :head_dim
+        field :n_shared_experts
+        field :n_routed_experts
+        field :routed_scaling_factor
+        field :num_experts_per_tok
+        field :first_k_dense_replace
+        field :norm_topk_prob
+        field :max_position_embeddings
+        field :rms_norm_eps
+        field :rope_theta
+        field :tie_word_embeddings
+        field :partial_rotary_factor
+        field :rope_scaling, default: nil
+        field :attention_bias, default: false
+        field :use_qk_norm, default: false
+        field :n_group, default: 1
+        field :topk_group, default: 1
+        field :scoring_func, default: "sigmoid"
+        field :topk_method, default: "noaux_tc"
+      end
+
+      class Model < DeepSeek::Model
+        def initialize(args)
+          super(DeepSeek::ModelArgs.from_dict(_to_deepseek_config(args)))
+          self.model_type = args.model_type
+        end
+
+        def sanitize(weights)
+          sanitized = super(weights)
+          mtp_prefix = "model.layers.#{@args.num_hidden_layers}" # trailing multi-token-prediction layer
+          sanitized.reject do |k, _|
+            k == mtp_prefix || k.start_with?("#{mtp_prefix}.")
+          end
+        end
+
+        private
+
+        def _to_deepseek_config(args)
+          {
+            "model_type" => args.model_type,
+            "vocab_size" => args.vocab_size,
+            "hidden_size" => args.hidden_size,
+            "intermediate_size" => args.intermediate_size,
+            "moe_intermediate_size" => args.moe_intermediate_size,
+            "num_hidden_layers" => args.num_hidden_layers,
+            "num_attention_heads" => args.num_attention_heads,
+            "num_key_value_heads" => args.num_key_value_heads,
+            "n_shared_experts" => args.n_shared_experts,
+            "n_routed_experts" => args.n_routed_experts,
+            "num_experts_per_tok" => args.num_experts_per_tok,
+            "first_k_dense_replace" => args.first_k_dense_replace,
+            "moe_layer_freq" => 1,
+            "max_position_embeddings" => args.max_position_embeddings,
+            "rms_norm_eps" => args.rms_norm_eps,
+            "rope_theta" => args.rope_theta,
+            "rope_scaling" => args.rope_scaling,
+            "attention_bias" => args.attention_bias,
+          }
+        end
+      end
+
+      Models.register("solar_open", Model, ModelArgs)
+    end
+  end
+end
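SolarOpen is a thin adapter: `ModelArgs` mirrors the checkpoint config, `_to_deepseek_config` remaps the fields the DeepSeek implementation understands, and the MoE routing, attention, and weight layout are all inherited from `DeepSeek::Model`. A minimal construction sketch, assuming `from_dict` is available on `BaseModelArgs` (the same helper the DeepSeek call above relies on) and using illustrative toy-sized values rather than a real checkpoint config:

# Toy config for illustration only; real values come from the model's config.json.
config = {
  "vocab_size" => 32_000, "hidden_size" => 256, "intermediate_size" => 512,
  "moe_intermediate_size" => 128, "num_hidden_layers" => 4,
  "num_attention_heads" => 8, "num_key_value_heads" => 2, "head_dim" => 32,
  "n_shared_experts" => 1, "n_routed_experts" => 8,
  "routed_scaling_factor" => 1.0, "num_experts_per_tok" => 2,
  "first_k_dense_replace" => 1, "norm_topk_prob" => true,
  "max_position_embeddings" => 4096, "rms_norm_eps" => 1e-6,
  "rope_theta" => 10_000.0, "tie_word_embeddings" => false,
  "partial_rotary_factor" => 1.0,
}
args = MlxLm::Models::SolarOpen::ModelArgs.from_dict(config)
model = MlxLm::Models::SolarOpen::Model.new(args)

Note that `sanitize` strips the extra layer at index `num_hidden_layers`: DeepSeek-style checkpoints ship a multi-token-prediction layer there, which this runtime does not use.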
data/lib/mlx_lm/models/ssm.rb
@@ -0,0 +1,162 @@
+module MlxLm
+  module Models
+    module SSM
+      module_function
+
+      def compute_dt(dt, dt_bias, time_step_limit = [0.001, 100.0])
+        dt = MLX::NN.softplus(dt + dt_bias)
+        MLX::Core.clip(dt, time_step_limit[0], time_step_limit[1])
+      end
+
+      def segsum(x, mask: nil)
+        mx = MLX::Core
+        l = x.shape[-1]
+
+        unless mask.nil?
+          mask_e = mx.expand_dims(mask, 1)
+          x = x * mask_e
+        end
+
+        x = mx.repeat(mx.expand_dims(x, -1), l, -1)
+        x = mx.tril(x, -1)
+        x_segsum = mx.cumsum(x, -2)
+
+        unless mask.nil?
+          mask_e = mx.expand_dims(mask, 1)
+          valid = mx.multiply(mx.expand_dims(mask_e, -1), mx.expand_dims(mask_e, -2))
+          x_segsum = mx.where(valid, x_segsum, -Float::INFINITY)
+        end
+
+        x_segsum
+      end
+
+      # Baseline implementation for SSD-SSM using explicit recurrence.
+      def ssm_attn(
+        x,
+        a_log,
+        b,
+        c,
+        d,
+        dt,
+        dt_bias,
+        state: nil,
+        time_step_limit: [0.001, 100.0],
+        mask: nil,
+        lengths: nil,
+        step: 256
+      )
+        _ = step
+        raise NotImplementedError, "length-aware SSM path is not implemented yet" unless lengths.nil?
+
+        mx = MLX::Core
+        batch_size, seq_len, num_heads, head_dim = x.shape
+        _, _, num_groups, state_dim = b.shape
+
+        repeats = num_heads / num_groups
+        dt = compute_dt(dt, dt_bias, time_step_limit)
+        dt = mx.expand_dims(dt, 0) if dt.ndim == 2
+        a = mx.multiply(-1.0, mx.exp(a_log).astype(dt.dtype))
+
+        state ||= mx.zeros([batch_size, num_heads, head_dim, state_dim], x.dtype)
+
+        ys = []
+        seq_len.times do |t|
+          x_t = _slice_step(x, t)
+          dt_t = _slice_step(dt, t)
+          b_t = _slice_step(b, t)
+          c_t = _slice_step(c, t)
+
+          if repeats > 1
+            b_t = mx.repeat(b_t, repeats, 1)
+            c_t = mx.repeat(c_t, repeats, 1)
+          end
+
+          decay = mx.exp(dt_t * a.reshape([1, num_heads]))
+          prev_state = state
+          state = state * decay.reshape([batch_size, num_heads, 1, 1])
+
+          dB = dt_t.reshape([batch_size, num_heads, 1, 1]) * b_t.reshape([batch_size, num_heads, 1, state_dim])
+          state = state + x_t.reshape([batch_size, num_heads, head_dim, 1]) * dB
+
+          y_t = (state * c_t.reshape([batch_size, num_heads, 1, state_dim])).sum(-1)
+          y_t = y_t + x_t * d.reshape([1, num_heads, 1])
+
+          unless mask.nil?
+            m_t = _slice_step(mask, t)
+            m_t = m_t.reshape([batch_size, 1, 1])
+            state = mx.where(mx.expand_dims(m_t, -1), state, prev_state) # lift mask to 4-D to broadcast against state
+            y_t = mx.where(m_t, y_t, mx.zeros(y_t.shape, y_t.dtype))
+          end
+
+          ys << y_t
+        end
+
+        [mx.stack(ys, 1), state]
+      end
+
+      def ssm_update_kernel(*_args, **_kwargs)
+        raise NotImplementedError,
+          "SSM metal kernel path is not implemented in mlx-ruby-lm yet"
+      end
+
+      def ssm_update(
+        hidden_states,
+        a_log,
+        b,
+        c,
+        d,
+        dt,
+        dt_bias,
+        state: nil,
+        time_step_limit: [0.001, 100.0],
+        mask: nil,
+        lengths: nil
+      )
+        mx = MLX::Core
+        seq_len = hidden_states.shape[1]
+
+        use_attn_path = seq_len > 1 ||
+          state.nil? ||
+          !mx.respond_to?(:metal_is_available) ||
+          !mx.metal_is_available ||
+          !mx.respond_to?(:default_device) ||
+          (mx.default_device.respond_to?(:type) && mx.default_device.type != :gpu)
+
+        if use_attn_path
+          return ssm_attn(
+            hidden_states,
+            a_log,
+            b,
+            c,
+            d,
+            dt,
+            dt_bias,
+            state: state,
+            time_step_limit: time_step_limit,
+            mask: mask,
+            lengths: lengths
+          )
+        end
+
+        ssm_update_kernel(
+          hidden_states,
+          a_log,
+          b,
+          c,
+          d,
+          dt,
+          dt_bias,
+          state,
+          time_step_limit
+        )
+      end
+
+      def _slice_step(array, idx)
+        mx = MLX::Core
+        tail = idx.zero? ? array : mx.split(array, [idx], 1)[1]
+        mx.squeeze(mx.split(tail, [1], 1)[0], 1)
+      end
+      private_class_method :_slice_step
+    end
+  end
+end
1
+ module MlxLm
2
+ module Models
3
+ module StableLM
4
+ class ModelArgs < BaseModelArgs
5
+ field :model_type, default: "stablelm"
6
+ field :hidden_size, default: 2048
7
+ field :num_hidden_layers, default: 24
8
+ field :num_attention_heads, default: 32
9
+ field :num_key_value_heads, default: 32
10
+ field :intermediate_size, default: 5632
11
+ field :vocab_size, default: 50304
12
+ field :rope_theta, default: 10000.0
13
+ field :use_qkv_bias, default: false
14
+ field :partial_rotary_factor, default: 0.25
15
+ field :layer_norm_eps, default: 1e-5
16
+ field :use_parallel_residual, default: false
17
+ field :qk_layernorm, default: false
18
+
19
+ def initialize(**kwargs)
20
+ super
21
+ @num_key_value_heads ||= @num_attention_heads
22
+ end
23
+ end
24
+
25
+ class Attention < MLX::NN::Module
26
+ def initialize(args)
27
+ super()
28
+ dim = args.hidden_size
29
+ @n_heads = args.num_attention_heads
30
+ @n_kv_heads = args.num_key_value_heads
31
+ @head_dim = dim / @n_heads
32
+ @scale = @head_dim**(-0.5)
33
+
34
+ bias = args.use_qkv_bias
35
+ self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
36
+ self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
37
+ self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
38
+ self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
39
+
40
+ # Partial rotary: only apply RoPE to a fraction of head_dim
41
+ rope_dim = (args.partial_rotary_factor * @head_dim).to_i
42
+ self.rope = MLX::NN::RoPE.new(
43
+ rope_dim,
44
+ traditional: false,
45
+ base: args.rope_theta
46
+ )
47
+ end
48
+
49
+ def call(x, mask: nil, cache: nil)
50
+ mx = MLX::Core
51
+ b, l, _d = x.shape
52
+
53
+ queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
54
+ keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
55
+ values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
56
+
57
+ if cache
58
+ queries = rope.call(queries, offset: cache.offset)
59
+ keys = rope.call(keys, offset: cache.offset)
60
+ keys, values = cache.update_and_fetch(keys, values)
61
+ else
62
+ queries = rope.call(queries)
63
+ keys = rope.call(keys)
64
+ end
65
+
66
+ output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
67
+ output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
68
+ o_proj.call(output)
69
+ end
70
+ end
71
+
72
+ class MLP < MLX::NN::Module
73
+ def initialize(dim, hidden_dim)
74
+ super()
75
+ self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
76
+ self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
77
+ self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
78
+ end
79
+
80
+ def call(x)
81
+ down_proj.call(MLX::NN.silu(gate_proj.call(x)) * up_proj.call(x))
82
+ end
83
+ end
84
+
85
+ class DecoderLayer < MLX::NN::Module
86
+ def initialize(args)
87
+ super()
88
+ self.self_attn = Attention.new(args)
89
+ self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
90
+ self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps)
91
+ @use_parallel_residual = args.use_parallel_residual
92
+ unless @use_parallel_residual
93
+ self.post_attention_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps)
94
+ end
95
+ end
96
+
97
+ def call(x, mask: nil, cache: nil)
98
+ h = input_layernorm.call(x)
99
+ r = self_attn.call(h, mask: mask, cache: cache)
100
+
101
+ if @use_parallel_residual
102
+ x + r + mlp.call(h)
103
+ else
104
+ h = x + r
105
+ r = mlp.call(post_attention_layernorm.call(h))
106
+ h + r
107
+ end
108
+ end
109
+ end
110
+
111
+ class StableLMModel < MLX::NN::Module
112
+ def initialize(args)
113
+ super()
114
+ @args = args
115
+ self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
116
+ self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) }
117
+ self.norm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps)
118
+ end
119
+
120
+ def call(inputs, cache: nil)
121
+ h = embed_tokens.call(inputs)
122
+ layer_cache = cache || [nil] * layers.length
123
+
124
+ mask = nil
125
+ mask = "causal" if h.shape[1] > 1
126
+
127
+ layers.each_with_index do |layer, i|
128
+ h = layer.call(h, mask: mask, cache: layer_cache[i])
129
+ end
130
+
131
+ norm.call(h)
132
+ end
133
+ end
134
+
135
+ class Model < MLX::NN::Module
136
+ def initialize(args)
137
+ super()
138
+ @args = args
139
+ self.model = StableLMModel.new(args)
140
+ self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
141
+ end
142
+
143
+ def call(inputs, cache: nil)
144
+ out = model.call(inputs, cache: cache)
145
+ lm_head.call(out)
146
+ end
147
+
148
+ def sanitize(weights)
149
+ weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
150
+ end
151
+
152
+ def layers
153
+ model.layers
154
+ end
155
+ end
156
+
157
+ Models.register("stablelm", Model, ModelArgs)
158
+ end
159
+ end
160
+ end
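The only structural wrinkle versus a vanilla Llama-style decoder is the residual layout: with `use_parallel_residual` the block computes `x + Attn(LN(x)) + MLP(LN(x))` from a single shared LayerNorm (GPT-NeoX style), otherwise it applies the usual sequential `h = x + Attn(LN1(x))` then `h + MLP(LN2(h))`. A construction sketch, assuming the `field` DSL lets `ModelArgs.new` take keyword overrides on top of the declared defaults:

# Toy-sized args for illustration; the real defaults target StableLM-3B-class
# shapes (hidden_size 2048, 24 layers).
args = MlxLm::Models::StableLM::ModelArgs.new(
  hidden_size: 64,
  num_hidden_layers: 2,
  num_attention_heads: 8,
  num_key_value_heads: 8,
  intermediate_size: 128,
  vocab_size: 100,
  use_parallel_residual: true
)
model = MlxLm::Models::StableLM::Model.new(args)
model.layers.length # => 2

Note that the `qk_layernorm` field is accepted for config compatibility but is not applied anywhere in `Attention`, so checkpoints that rely on it would need that path implemented before their outputs match upstream.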
data/lib/mlx_lm/models/starcoder2.rb
@@ -0,0 +1,161 @@
+module MlxLm
+  module Models
+    module Starcoder2
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "starcoder2"
+        field :hidden_size, default: 3072
+        field :num_hidden_layers, default: 30
+        field :num_attention_heads, default: 24
+        field :num_key_value_heads, default: 2
+        field :intermediate_size, default: 12288
+        field :vocab_size, default: 49152
+        field :norm_epsilon, default: 1e-5
+        field :rope_theta, default: 100000.0
+        field :rope_traditional, default: false
+        field :tie_word_embeddings, default: true
+        field :head_dim, default: nil
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @hidden_size / @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+
+          # StarCoder2: all projections have bias
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: true)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: true)
+
+          self.rope = MLX::NN::RoPE.new(
+            @head_dim,
+            traditional: args.rope_traditional,
+            base: args.rope_theta
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      # StarCoder2: simple gelu MLP (no gating)
+      class MLP < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          hidden_dim = args.intermediate_size
+
+          self.c_fc = MLX::NN::Linear.new(dim, hidden_dim, bias: true)
+          self.c_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: true)
+        end
+
+        def call(x)
+          c_proj.call(MLX::NN.gelu(c_fc.call(x)))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args)
+          # StarCoder2 uses LayerNorm, not RMSNorm
+          self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.norm_epsilon)
+          self.post_attention_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.norm_epsilon)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
+      class Starcoder2Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.norm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.norm_epsilon)
+        end
+
+        def call(inputs, cache: nil)
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model = Starcoder2Model.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def sanitize(weights)
+          result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+          result.delete("lm_head.weight") if @args.tie_word_embeddings
+          result
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("starcoder2", Model, ModelArgs)
+    end
+  end
+end
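StarCoder2 diverges from the Llama template in three visible ways: every attention projection carries a bias, the MLP is a plain c_fc/gelu/c_proj stack with no gating, and normalization is LayerNorm rather than RMSNorm. Output weights are tied by default, which `sanitize` enforces by discarding any `lm_head.weight` a checkpoint ships. A sketch of that path, again assuming keyword construction on `ModelArgs` (toy sizes; the weight values are placeholder symbols, since `sanitize` only inspects keys):

args = MlxLm::Models::Starcoder2::ModelArgs.new(
  hidden_size: 48,
  num_hidden_layers: 2,
  num_attention_heads: 6,
  num_key_value_heads: 2,
  intermediate_size: 96,
  vocab_size: 100
)
model = MlxLm::Models::Starcoder2::Model.new(args)

weights = {
  "model.embed_tokens.weight" => :kept,
  "model.layers.0.self_attn.rotary_emb.inv_freq" => :dropped_runtime_buffer,
  "lm_head.weight" => :dropped_because_tied,
}
model.sanitize(weights).keys
# => ["model.embed_tokens.weight"]

With tying enabled, the forward pass projects hidden states back through the embedding matrix via `embed_tokens.as_linear`, so the model has no separate output head to load or quantize.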