mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,192 @@
1
+ module MlxLm
2
+ module Models
3
+ module Telechat3
4
+ class ModelArgs < BaseModelArgs
5
+ field :model_type, default: "telechat3"
6
+ field :hidden_size, default: 4096
7
+ field :intermediate_size, default: 14336
8
+ field :max_position_embeddings, default: 32768
9
+ field :num_attention_heads, default: 32
10
+ field :num_hidden_layers, default: 32
11
+ field :num_key_value_heads, default: nil
12
+ field :rms_norm_eps, default: 1e-6
13
+ field :vocab_size, default: 151936
14
+ field :rope_theta, default: 10_000.0
15
+ field :mlp_bias, default: false
16
+ field :attention_bias, default: false
17
+ field :head_dim, default: nil
18
+ field :rope_scaling, default: nil
19
+ field :tie_word_embeddings, default: false
20
+
21
+ def initialize(**kwargs)
22
+ super
23
+ @num_key_value_heads ||= @num_attention_heads
24
+ @head_dim ||= @hidden_size / @num_attention_heads
25
+ end
26
+ end
27
+
28
+ class Telechat3Attention < MLX::NN::Module
29
+ def initialize(args)
30
+ super()
31
+ dim = args.hidden_size
32
+ @num_attention_heads = args.num_attention_heads
33
+ @num_key_value_heads = args.num_key_value_heads
34
+ @head_dim = args.head_dim
35
+ @scale = @head_dim**(-0.5)
36
+
37
+ self.q_proj = MLX::NN::Linear.new(
38
+ dim,
39
+ args.num_attention_heads * @head_dim,
40
+ bias: args.attention_bias
41
+ )
42
+ self.k_proj = MLX::NN::Linear.new(
43
+ dim,
44
+ args.num_key_value_heads * @head_dim,
45
+ bias: args.attention_bias
46
+ )
47
+ self.v_proj = MLX::NN::Linear.new(
48
+ dim,
49
+ args.num_key_value_heads * @head_dim,
50
+ bias: args.attention_bias
51
+ )
52
+ self.o_proj = MLX::NN::Linear.new(
53
+ args.num_attention_heads * @head_dim,
54
+ dim,
55
+ bias: args.attention_bias
56
+ )
57
+
58
+ self.rope = MlxLm::Models.initialize_rope(
59
+ @head_dim,
60
+ args.rope_theta,
61
+ false,
62
+ args.rope_scaling,
63
+ max_position_embeddings: args.max_position_embeddings
64
+ )
65
+ end
66
+
67
+ def call(x, mask: nil, cache: nil)
68
+ mx = MLX::Core
69
+ b, l, _d = x.shape
70
+
71
+ queries = q_proj.call(x)
72
+ keys = k_proj.call(x)
73
+ values = v_proj.call(x)
74
+
75
+ queries = queries.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
76
+ keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
77
+ values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
78
+
79
+ if cache
80
+ queries = rope.call(queries, offset: cache.offset)
81
+ keys = rope.call(keys, offset: cache.offset)
82
+ keys, values = cache.update_and_fetch(keys, values)
83
+ else
84
+ queries = rope.call(queries)
85
+ keys = rope.call(keys)
86
+ end
87
+
88
+ output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
89
+ output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim])
90
+ o_proj.call(output)
91
+ end
92
+ end
93
+
94
+ class Telechat3MLP < MLX::NN::Module
95
+ def initialize(args)
96
+ super()
97
+ self.gate_proj = MLX::NN::Linear.new(
98
+ args.hidden_size,
99
+ args.intermediate_size,
100
+ bias: args.mlp_bias
101
+ )
102
+ self.down_proj = MLX::NN::Linear.new(
103
+ args.intermediate_size,
104
+ args.hidden_size,
105
+ bias: args.mlp_bias
106
+ )
107
+ self.up_proj = MLX::NN::Linear.new(
108
+ args.hidden_size,
109
+ args.intermediate_size,
110
+ bias: args.mlp_bias
111
+ )
112
+ end
113
+
114
+ def call(x)
115
+ down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
116
+ end
117
+ end
118
+
119
+ class Telechat3DecoderLayer < MLX::NN::Module
120
+ def initialize(args)
121
+ super()
122
+ self.self_attn = Telechat3Attention.new(args)
123
+ self.mlp = Telechat3MLP.new(args)
124
+ self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
125
+ self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
126
+ end
127
+
128
+ def call(x, mask: nil, cache: nil)
129
+ r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
130
+ h = x + r
131
+ h + mlp.call(post_attention_layernorm.call(h))
132
+ end
133
+ end
134
+
135
+ class Telechat3Model < MLX::NN::Module
136
+ def initialize(args)
137
+ super()
138
+ self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
139
+ self.layers = Array.new(args.num_hidden_layers) { Telechat3DecoderLayer.new(args) }
140
+ self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
141
+ end
142
+
143
+ def call(inputs, cache: nil, input_embeddings: nil)
144
+ h = input_embeddings || embed_tokens.call(inputs)
145
+ layer_cache = cache || [nil] * layers.length
146
+
147
+ mask = nil
148
+ mask = "causal" if h.shape[1] > 1
149
+
150
+ layers.each_with_index do |layer, i|
151
+ h = layer.call(h, mask: mask, cache: layer_cache[i])
152
+ end
153
+
154
+ norm.call(h)
155
+ end
156
+ end
157
+
158
+ class Model < MLX::NN::Module
159
+ def initialize(args)
160
+ super()
161
+ @args = args
162
+ self.model_type = args.model_type
163
+ self.model = Telechat3Model.new(args)
164
+ unless args.tie_word_embeddings
165
+ self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
166
+ end
167
+ end
168
+
169
+ def call(inputs, cache: nil, input_embeddings: nil)
170
+ out = model.call(inputs, cache: cache, input_embeddings: input_embeddings)
171
+ if @args.tie_word_embeddings
172
+ model.embed_tokens.as_linear(out)
173
+ else
174
+ lm_head.call(out)
175
+ end
176
+ end
177
+
178
+ def sanitize(weights)
179
+ result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
180
+ result.delete("lm_head.weight") if @args.tie_word_embeddings
181
+ result
182
+ end
183
+
184
+ def layers
185
+ model.layers
186
+ end
187
+ end
188
+
189
+ Models.register("telechat3", Model, ModelArgs)
190
+ end
191
+ end
192
+ end
@@ -0,0 +1,230 @@
1
module MlxLm
  module Models
    # YoutuLLM: decoder-only transformer using multi-head latent attention
    # (MLA) — queries and KV pairs pass through low-rank projections, with
    # RoPE applied only to a dedicated rotary sub-dimension.
    module YoutuLLM
      class ModelArgs < BaseModelArgs
        field :model_type, default: "youtu_llm"
        field :vocab_size, default: 128_256
        field :hidden_size, default: 2048
        field :intermediate_size, default: 6144
        field :num_hidden_layers, default: 32
        field :num_attention_heads, default: 16
        field :num_key_value_heads, default: 16
        field :kv_lora_rank, default: 512
        field :q_lora_rank, default: 1536
        field :qk_rope_head_dim, default: 64
        field :v_head_dim, default: 128
        field :qk_nope_head_dim, default: 128
        field :max_position_embeddings, default: 131_072
        field :rms_norm_eps, default: 1e-6
        field :rope_theta, default: 1_600_000.0
        field :rope_traditional, default: true
        field :rope_scaling, default: nil
        field :attention_bias, default: false
        field :mlp_bias, default: false
        field :tie_word_embeddings, default: true
      end

      # Multi-head latent attention. Per head, queries/keys are the
      # concatenation of a non-rotary part (qk_nope_head_dim) and a rotary
      # part (qk_rope_head_dim); values have their own width (v_head_dim).
      class YoutuLLMAttention < MLX::NN::Module
        def initialize(config)
          super()
          @hidden_size = config.hidden_size
          @num_heads = config.num_attention_heads
          @q_lora_rank = config.q_lora_rank
          @qk_rope_head_dim = config.qk_rope_head_dim
          @kv_lora_rank = config.kv_lora_rank
          @v_head_dim = config.v_head_dim
          @qk_nope_head_dim = config.qk_nope_head_dim
          @q_head_dim = @qk_nope_head_dim + @qk_rope_head_dim
          @kv_head_dim = @qk_nope_head_dim + @v_head_dim
          @scale = @q_head_dim**(-0.5)

          # Query path: either a direct projection, or a low-rank
          # (A -> norm -> B) factorization when q_lora_rank is set.
          if @q_lora_rank.nil?
            self.q_proj = MLX::NN::Linear.new(
              @hidden_size,
              @num_heads * @q_head_dim,
              bias: false
            )
          else
            self.q_a_proj = MLX::NN::Linear.new(
              @hidden_size,
              @q_lora_rank,
              bias: config.attention_bias
            )
            self.q_a_layernorm = MLX::NN::RMSNorm.new(@q_lora_rank, eps: config.rms_norm_eps)
            self.q_b_proj = MLX::NN::Linear.new(@q_lora_rank, @num_heads * @q_head_dim, bias: false)
          end

          # KV path: one projection emits the compressed latent plus the
          # shared rotary key; the latent is normed then expanded per head.
          self.kv_a_proj_with_mqa = MLX::NN::Linear.new(
            @hidden_size,
            @kv_lora_rank + @qk_rope_head_dim,
            bias: config.attention_bias
          )
          self.kv_a_layernorm = MLX::NN::RMSNorm.new(@kv_lora_rank, eps: config.rms_norm_eps)
          self.kv_b_proj = MLX::NN::Linear.new(
            @kv_lora_rank,
            @num_heads * (@q_head_dim - @qk_rope_head_dim + @v_head_dim),
            bias: false
          )

          self.o_proj = MLX::NN::Linear.new(
            @num_heads * @v_head_dim,
            @hidden_size,
            bias: config.attention_bias
          )

          self.rope = MlxLm::Models.initialize_rope(
            @qk_rope_head_dim,
            config.rope_theta,
            config.rope_traditional,
            config.rope_scaling,
            max_position_embeddings: config.max_position_embeddings
          )
        end

        def call(x, mask: nil, cache: nil)
          mx = MLX::Core
          batch, seq_len, _ = x.shape

          q_states =
            if @q_lora_rank.nil?
              q_proj.call(x)
            else
              q_b_proj.call(q_a_layernorm.call(q_a_proj.call(x)))
            end

          # Split queries into non-rotary / rotary halves: [B, H, L, D].
          q_states = q_states.reshape([batch, seq_len, @num_heads, @q_head_dim]).transpose([0, 2, 1, 3])
          q_nope, q_pe = mx.split(q_states, [@qk_nope_head_dim], -1)

          # One projection yields the compressed KV latent and the shared
          # (single-head) rotary key.
          latent = kv_a_proj_with_mqa.call(x)
          latent, k_pe = mx.split(latent, [@kv_lora_rank], -1)
          k_pe = k_pe.reshape([batch, seq_len, 1, @qk_rope_head_dim]).transpose([0, 2, 1, 3])

          # Expand the latent into per-head non-rotary keys and values.
          expanded = kv_b_proj.call(kv_a_layernorm.call(latent))
          expanded = expanded.reshape([batch, seq_len, @num_heads, @kv_head_dim]).transpose([0, 2, 1, 3])
          k_nope, values = mx.split(expanded, [@qk_nope_head_dim], -1)

          if cache
            q_pe = rope.call(q_pe, offset: cache.offset)
            k_pe = rope.call(k_pe, offset: cache.offset)
            # Broadcast the shared rotary key across all heads.
            k_pe = mx.repeat(k_pe, @num_heads, 1)
            keys, values = cache.update_and_fetch(mx.concatenate([k_nope, k_pe], -1), values)
          else
            q_pe = rope.call(q_pe)
            k_pe = rope.call(k_pe)
            k_pe = mx.repeat(k_pe, @num_heads, 1)
            keys = mx.concatenate([k_nope, k_pe], -1)
          end

          queries = mx.concatenate([q_nope, q_pe], -1)
          attn_out = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
          attn_out = attn_out.transpose([0, 2, 1, 3]).reshape([batch, seq_len, @num_heads * @v_head_dim])
          o_proj.call(attn_out)
        end
      end

      # SwiGLU feed-forward block.
      class YoutuLLMMLP < MLX::NN::Module
        def initialize(config)
          super()
          self.gate_proj = MLX::NN::Linear.new(
            config.hidden_size,
            config.intermediate_size,
            bias: config.mlp_bias
          )
          self.up_proj = MLX::NN::Linear.new(
            config.hidden_size,
            config.intermediate_size,
            bias: config.mlp_bias
          )
          self.down_proj = MLX::NN::Linear.new(
            config.intermediate_size,
            config.hidden_size,
            bias: config.mlp_bias
          )
        end

        def call(x)
          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
        end
      end

      # Pre-norm transformer block with residual connections.
      class YoutuLLMDecoderLayer < MLX::NN::Module
        def initialize(config)
          super()
          self.self_attn = YoutuLLMAttention.new(config)
          self.mlp = YoutuLLMMLP.new(config)
          self.input_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
          self.post_attention_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
        end

        def call(x, mask: nil, cache: nil)
          h = x + self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
          h + mlp.call(post_attention_layernorm.call(h))
        end
      end

      # Embedding + decoder stack (no LM head).
      class YoutuLLMModel < MLX::NN::Module
        def initialize(config)
          super()
          self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size)
          self.layers = Array.new(config.num_hidden_layers) { YoutuLLMDecoderLayer.new(config) }
          self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps)
        end

        def call(inputs, cache: nil)
          h = embed_tokens.call(inputs)
          layer_cache = cache || [nil] * layers.length
          mask = _create_attention_mask(h, layer_cache[0])

          layers.each_with_index do |layer, i|
            h = layer.call(h, mask: mask, cache: layer_cache[i])
          end

          norm.call(h)
        end

        private

        # Cache-aware mask: the cache builds one when it can; a single
        # token needs none; otherwise use a standard causal mask.
        def _create_attention_mask(h, cache)
          if cache && cache.respond_to?(:make_mask)
            cache.make_mask(h.shape[1])
          elsif h.shape[1] == 1
            nil
          else
            "causal"
          end
        end
      end

      # Full causal LM: decoder stack plus a (possibly tied) output head.
      class Model < MLX::NN::Module
        def initialize(config)
          super()
          @config = config
          self.model_type = config.model_type
          self.model = YoutuLLMModel.new(config)
          unless config.tie_word_embeddings
            self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false)
          end
        end

        def call(inputs, cache: nil)
          hidden = model.call(inputs, cache: cache)
          @config.tie_word_embeddings ? model.embed_tokens.as_linear(hidden) : lm_head.call(hidden)
        end

        # Drop the LM head weight from checkpoints when embeddings are tied.
        def sanitize(weights)
          result = weights.dup
          result.delete("lm_head.weight") if @config.tie_word_embeddings
          result
        end

        def layers
          model.layers
        end
      end

      Models.register("youtu_llm", Model, ModelArgs)
    end
  end
end
@@ -0,0 +1,33 @@
1
module MlxLm
  module Models
    # Model registry: maps an architecture name (the config's model_type)
    # to its [Model, ModelArgs] class pair. Each architecture file calls
    # +register+ when it is loaded.
    REGISTRY = {}

    # Architectures that are implemented by another registered architecture.
    REMAPPING = {
      "mistral" => "llama",
      "falcon_mamba" => "mamba",
    }.freeze

    module_function

    # Register +model_class+ / +args_class+ under +name+.
    def register(name, model_class, args_class)
      REGISTRY[name] = [model_class, args_class]
    end

    # Resolve the [Model, ModelArgs] pair for a parsed config hash.
    #
    # @param config [Hash] parsed config.json; must carry "model_type"
    # @return [Array(Class, Class)] the registered [Model, ModelArgs] pair
    # @raise [ArgumentError] when model_type is missing or unregistered
    def get_classes(config)
      model_type = config["model_type"]
      raise ArgumentError, "config.json missing 'model_type' field" unless model_type

      # Resolve aliases to their canonical implementation first.
      canonical = REMAPPING.fetch(model_type, model_type)
      return REGISTRY[canonical] if REGISTRY.key?(canonical)

      raise ArgumentError, "Model architecture '#{model_type}' (canonical: '#{canonical}') not found in registry. Available: #{REGISTRY.keys.join(', ')}"
    end
  end
end
@@ -0,0 +1,48 @@
1
module MlxLm
  # Sequence-level log-likelihood and perplexity evaluation.
  module Perplexity
    module_function

    # Compute perplexity of a model on a token sequence:
    # exp(average negative log-likelihood per predicted token).
    #
    # @param model [#call] autoregressive LM whose output logits have
    #   shape [1, seq_len, vocab_size]
    # @param tokens [Array<Integer>, MLX::Core::Array] token ids
    # @param batch_size [Integer, nil] accepted for API compatibility but
    #   currently unused — the whole sequence is evaluated in one pass
    # @return [Float]
    # @raise [ArgumentError] if fewer than two tokens are given; with no
    #   next-token prediction the average NLL (and hence perplexity) is
    #   undefined, and the old code silently produced NaN here
    def compute(model, tokens, batch_size: nil)
      num_tokens = tokens.size - 1
      raise ArgumentError, "need at least 2 tokens to compute perplexity" if num_tokens < 1

      ll = log_likelihood(model, tokens, batch_size: batch_size)
      Math.exp(-ll / num_tokens.to_f)
    end

    # Total log-likelihood of a token sequence:
    # sum of log P(token_i | token_0..token_{i-1}) for i in 1..n.
    #
    # @param model [#call] autoregressive LM (see #compute)
    # @param tokens [Array<Integer>, MLX::Core::Array] token ids
    # @param batch_size [Integer, nil] unused (see #compute)
    # @return [Float] sum of per-token log-probabilities
    def log_likelihood(model, tokens, batch_size: nil)
      mx = MLX::Core

      token_arr = tokens.is_a?(MLX::Core::Array) ? tokens : mx.array(tokens, dtype: mx.int32)
      total_tokens = token_arr.size

      # Single forward pass over the full sequence.
      # NOTE(review): for very long sequences this evaluates all logits at
      # once; honoring batch_size would require chunked evaluation with a
      # KV cache — TODO if memory becomes a problem.
      input = token_arr.reshape([1, total_tokens])
      logits = model.call(input)
      mx.eval(logits)

      # logits shape: [1, total_tokens, vocab_size]; position i predicts
      # token i+1.
      vocab_size = logits.shape[-1]
      logits_2d = logits.reshape([total_tokens, vocab_size])

      # Numerically stable log-softmax.
      log_probs = logits_2d - mx.logsumexp(logits_2d, -1, true)

      # Sum the log-probabilities assigned to each actual next token.
      total_ll = 0.0
      (0...(total_tokens - 1)).each do |i|
        target_token = token_arr[i + 1].item
        total_ll += log_probs[i][target_token].item
      end

      total_ll
    end
  end
end
@@ -0,0 +1,131 @@
1
module MlxLm
  # Helpers for quantizing and dequantizing model weights.
  module Quantize
    module_function

    # Quantize a model's linear and embedding layers in place.
    #
    # @param model [nn::Module] the model to quantize
    # @param group_size [Integer] group size for quantization (default: 64)
    # @param bits [Integer] number of bits per weight (default: 4)
    # @param weights [Hash, nil] optional weight dict; when given, only
    #   layers that have a corresponding "<path>.scales" key are quantized
    #   (i.e. layers that were quantized in the checkpoint)
    # @return [Hash] the quantization config ("group_size", "bits")
    def quantize_model(model, group_size: 64, bits: 4, weights: nil)
      class_predicate =
        if weights
          # Auto-detect: only quantize layers with .scales in the checkpoint.
          ->(path, mod) { mod.respond_to?(:to_quantized) && weights.key?("#{path}.scales") }
        else
          # Quantize every quantizable layer.
          ->(_path, mod) { mod.respond_to?(:to_quantized) }
        end

      MLX::NN.quantize(model, group_size: group_size, bits: bits, class_predicate: class_predicate)

      { "group_size" => group_size, "bits" => bits }
    end

    # Dequantize a model in place (QuantizedLinear -> Linear,
    # QuantizedEmbedding -> Embedding) and return it.
    #
    # @param model [nn::Module] the quantized model
    # @return [nn::Module] the same model, dequantized
    def dequantize_model(model)
      de_quantize_layers(model)
      model
    end

    # Approximate the average number of bits of storage per logical weight.
    #
    # NOTE(review): assumes MLX packs quantized weights into 32-bit words,
    # so the packed array's element count is scaled by (32 / bits) to
    # recover the logical parameter count — confirm against the MLX
    # quantization docs. The estimate ignores the scales/biases side
    # tensors and any non-Linear parameters.
    def bits_per_weight(model)
      total_bits = 0
      total_params = 0

      model.named_modules.each do |_name, mod|
        case mod
        when MLX::NN::QuantizedLinear
          weight = mod.instance_variable_get(:@weight)
          next unless weight

          # FIX: read the layer's actual bit width instead of assuming 4,
          # and unpack the stored word count into logical parameters.
          bits = mod.instance_variable_get(:@bits) || 4
          num_params = weight.shape.reduce(:*) * (32 / bits)
          total_bits += num_params * bits
          total_params += num_params
        when MLX::NN::Linear
          weight = mod.instance_variable_get(:@weight)
          next unless weight

          num_params = weight.shape.reduce(:*)
          total_bits += num_params * 32 # float32 storage
          total_params += num_params
        end
      end

      total_params > 0 ? total_bits.to_f / total_params : 0.0
    end

    # Recursively replace quantized submodules with dequantized equivalents.
    # Walks instance variables, descending into nested modules and arrays.
    def self.de_quantize_layers(mod)
      mod.instance_variables.each do |ivar|
        val = mod.instance_variable_get(ivar)
        case val
        when MLX::NN::QuantizedLinear
          mod.instance_variable_set(ivar, linear_from_quantized(val))
        when MLX::NN::QuantizedEmbedding
          mod.instance_variable_set(ivar, embedding_from_quantized(val))
        when MLX::NN::Module
          de_quantize_layers(val)
        when ::Array
          val.each { |item| de_quantize_layers(item) if item.is_a?(MLX::NN::Module) }
        end
      end
    end

    # Build a plain Linear carrying the dequantized weights of +qlinear+.
    def self.linear_from_quantized(qlinear)
      mx = MLX::Core
      weight = qlinear.instance_variable_get(:@weight)
      scales = qlinear.instance_variable_get(:@scales)
      biases = qlinear.instance_variable_get(:@biases)
      bias = qlinear.instance_variable_get(:@bias)
      group_size = qlinear.instance_variable_get(:@group_size) || 64
      bits = qlinear.instance_variable_get(:@bits) || 4

      dequantized = mx.dequantize(weight, scales, biases, group_size, bits)
      out_features = dequantized.shape[0]
      in_features = dequantized.shape[1]

      linear = MLX::NN::Linear.new(in_features, out_features, bias: !bias.nil?)
      linear.instance_variable_set(:@weight, dequantized)
      linear.instance_variable_set(:@bias, bias) if bias
      linear
    end

    # Build a plain Embedding carrying the dequantized weights of +qembed+.
    def self.embedding_from_quantized(qembed)
      mx = MLX::Core
      weight = qembed.instance_variable_get(:@weight)
      scales = qembed.instance_variable_get(:@scales)
      biases = qembed.instance_variable_get(:@biases)
      group_size = qembed.instance_variable_get(:@group_size) || 64
      bits = qembed.instance_variable_get(:@bits) || 4

      dequantized = mx.dequantize(weight, scales, biases, group_size, bits)
      num_embeddings = dequantized.shape[0]
      dims = dequantized.shape[1]

      embedding = MLX::NN::Embedding.new(num_embeddings, dims)
      embedding.instance_variable_set(:@weight, dequantized)
      embedding
    end

    # FIX: a bare `private` has no effect on `def self.` singleton methods,
    # so the helpers were unintentionally public; hide them explicitly.
    private_class_method :de_quantize_layers, :linear_from_quantized, :embedding_from_quantized
  end
end