mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/internlm2.rb
@@ -0,0 +1,160 @@
+ module MlxLm
+   module Models
+     module InternLM2
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "internlm2"
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 32
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: nil
+         field :intermediate_size, default: 11008
+         field :vocab_size, default: 103168
+         field :rms_norm_eps, default: 1e-6
+         field :rope_theta, default: 10000.0
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :bias, default: true
+         field :tie_word_embeddings, default: false
+         field :max_position_embeddings, default: 32768
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = dim / @n_heads
+           @scale = @head_dim**(-0.5)
+
+           # Combined QKV projection
+           total_qkv = (@n_heads + 2 * @n_kv_heads) * @head_dim
+           self.wqkv = MLX::NN::Linear.new(dim, total_qkv, bias: args.bias)
+           self.wo = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.bias)
+
+           self.rope = MLX::NN::RoPE.new(
+             @head_dim,
+             traditional: args.rope_traditional,
+             base: args.rope_theta
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           qkv = wqkv.call(x)
+           q_size = @n_heads * @head_dim
+           kv_size = @n_kv_heads * @head_dim
+           queries, keys, values = mx.split(qkv, [q_size, q_size + kv_size], 2)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           wo.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(dim, hidden_dim)
+           super()
+           self.w1 = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+           self.w2 = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+           self.w3 = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+         end
+
+         def call(x)
+           w2.call(MLX::NN.silu(w1.call(x)) * w3.call(x))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.attention = Attention.new(args)
+           self.feed_forward = MLP.new(args.hidden_size, args.intermediate_size)
+           self.attention_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.ffn_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = attention.call(attention_norm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = feed_forward.call(ffn_norm.call(h))
+           h + r
+         end
+       end
+
+       class InternLM2Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.tok_embeddings = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = tok_embeddings.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model = InternLM2Model.new(args)
+           unless args.tie_word_embeddings
+             self.output = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.tok_embeddings.as_linear(out)
+           else
+             output.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           weights.reject { |k, _| k.include?("attention.rope.inv_freq") }
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("internlm2", Model, ModelArgs)
+     end
+   end
+ end
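For orientation, a minimal sketch of how the InternLM2 classes registered above could be exercised directly. It assumes the Ruby MLX bindings expose MLX::Core.array and that ModelArgs accepts keyword overrides as the field DSL suggests (both assumptions; converted checkpoints would normally go through the gem's load utilities instead):

    require "mlx_lm"

    # Hypothetical hand-built toy configuration, for illustration only.
    args = MlxLm::Models::InternLM2::ModelArgs.new(
      hidden_size: 256,
      num_hidden_layers: 2,
      num_attention_heads: 8,
      intermediate_size: 512,
      vocab_size: 1024
    )
    model = MlxLm::Models::InternLM2::Model.new(args)

    tokens = MLX::Core.array([[1, 5, 42]])  # batch of 1, three token ids
    logits = model.call(tokens)             # expected shape: [1, 3, vocab_size]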
data/lib/mlx_lm/models/internlm3.rb
@@ -0,0 +1,237 @@
+ module MlxLm
+   module Models
+     module InternLM3
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "internlm3"
+         field :hidden_size, default: 4096
+         field :num_hidden_layers, default: 32
+         field :intermediate_size, default: 11008
+         field :num_attention_heads, default: 32
+         field :rms_norm_eps, default: 1e-6
+         field :vocab_size, default: 103168
+         field :bias, default: false
+         field :qkv_bias, default: false
+         field :max_position_embeddings, default: 32768
+         field :num_key_value_heads, default: nil
+         field :rope_theta, default: 10_000.0
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+
+           return if @rope_scaling.nil?
+
+           required_keys = %w[factor rope_type]
+           missing = required_keys.reject { |k| _config_has_key?(k) }
+           unless missing.empty?
+             raise ArgumentError, "rope_scaling must contain keys #{required_keys}"
+           end
+
+           rope_type = _config_value("rope_type")
+           unless %w[linear dynamic].include?(rope_type)
+             raise ArgumentError, "rope_scaling 'rope_type' only supports 'linear' or 'dynamic'"
+           end
+         end
+
+         private
+
+         def _config_has_key?(key)
+           return false unless @rope_scaling.respond_to?(:key?)
+
+           @rope_scaling.key?(key) || @rope_scaling.key?(key.to_sym)
+         end
+
+         def _config_value(key, default = nil)
+           return default unless _config_has_key?(key)
+
+           if @rope_scaling.key?(key)
+             @rope_scaling[key]
+           else
+             @rope_scaling[key.to_sym]
+           end
+         end
+       end
+
+       class DynamicNTKScalingRoPE < MLX::NN::Module
+         def initialize(
+           dims,
+           max_position_embeddings: 2048,
+           traditional: false,
+           base: 10_000.0,
+           scale: 1.0
+         )
+           super()
+           @max_position_embeddings = max_position_embeddings
+           @original_base = base
+           @dims = dims
+           @traditional = traditional
+           @scale = scale
+         end
+
+         def call(x, offset: 0)
+           seq_len = x.shape[-2] + offset
+           if seq_len > @max_position_embeddings
+             scaled_ctx = (@scale * seq_len.to_f / @max_position_embeddings) - (@scale - 1.0)
+             base = @original_base * (scaled_ctx**(@dims.to_f / (@dims - 2)))
+           else
+             base = @original_base
+           end
+
+           MLX::Core.rope(x, @dims, @traditional, base, @scale, offset)
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           qkv_bias = args.qkv_bias
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.hidden_size / @n_heads
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: qkv_bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: qkv_bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: qkv_bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.bias)
+
+           rope_scale = if args.rope_scaling && _config_value(args.rope_scaling, "rope_type") == "linear"
+                          1.0 / _config_value(args.rope_scaling, "factor").to_f
+                        else
+                          1.0
+                        end
+
+           self.rope = DynamicNTKScalingRoPE.new(
+             @head_dim,
+             max_position_embeddings: args.max_position_embeddings,
+             traditional: args.rope_traditional,
+             base: args.rope_theta,
+             scale: rope_scale
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+
+         private
+
+         def _config_value(config, key, default = nil)
+           return default if config.nil? || !config.respond_to?(:key?)
+           return config[key] if config.key?(key)
+
+           config.fetch(key.to_sym, default)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(dim, hidden_dim, bias)
+           super()
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+         end
+
+         def call(x)
+           down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args.hidden_size, args.intermediate_size, args.bias)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class InternLM3Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model = InternLM3Model.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("internlm3", Model, ModelArgs)
+     end
+   end
+ end
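The notable piece above is DynamicNTKScalingRoPE: inside max_position_embeddings it behaves like plain RoPE, and past that window the rotary base is inflated with the sequence length. The same arithmetic in standalone Ruby (no MLX needed; the helper name and values are illustrative, not part of the gem):

    # Mirrors the base adjustment in DynamicNTKScalingRoPE#call.
    def ntk_base(dims, base, scale, max_pos, seq_len)
      return base if seq_len <= max_pos

      scaled_ctx = (scale * seq_len.to_f / max_pos) - (scale - 1.0)
      base * (scaled_ctx**(dims.to_f / (dims - 2)))
    end

    ntk_base(128, 10_000.0, 1.0, 32_768, 16_384) # => 10000.0 (inside the window)
    ntk_base(128, 10_000.0, 1.0, 32_768, 65_536) # => ~20221.3 (base roughly doubles)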
data/lib/mlx_lm/models/iquestloopcoder.rb
@@ -0,0 +1,261 @@
+ require_relative "cache"
+ require_relative "rope_utils"
+
+ module MlxLm
+   module Models
+     module Iquestloopcoder
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "iquestloopcoder"
+         field :hidden_size
+         field :num_hidden_layers
+         field :intermediate_size
+         field :num_attention_heads
+         field :rms_norm_eps
+         field :vocab_size
+         field :head_dim
+         field :num_key_value_heads
+         field :max_position_embeddings, default: 131_072
+         field :attention_bias, default: false
+         field :mlp_bias, default: false
+         field :rope_theta, default: 500_000.0
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+         field :loop_num, default: 2
+         field :loop_window_size, default: 64
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       class LoopGateProjection < MLX::NN::Module
+         def initialize(num_heads, head_dim)
+           super()
+           @num_heads = num_heads
+           @head_dim = head_dim
+
+           mx = MLX::Core
+           self.weight = mx.zeros([num_heads, head_dim])
+           self.bias = mx.zeros([num_heads])
+         end
+
+         def call(query)
+           mx = MLX::Core
+           projection = weight.reshape([@num_heads, @head_dim, 1])
+           gate_logits = mx.matmul(query, projection)
+           gate_logits = gate_logits + bias.reshape([1, @num_heads, 1, 1])
+           mx.sigmoid(gate_logits)
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias)
+           self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias)
+
+           self.rope = MlxLm::Models.initialize_rope(
+             @head_dim,
+             args.rope_theta,
+             false,
+             args.rope_scaling,
+             max_position_embeddings: args.max_position_embeddings
+           )
+         end
+
+         def get_qkv(x, offset: 0)
+           b, l, _d = x.shape
+
+           queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           queries = rope.call(queries, offset: offset)
+           keys = rope.call(keys, offset: offset)
+
+           [queries, keys, values]
+         end
+
+         def attention(queries, keys, values, mask: nil, cache: nil)
+           _cache = cache
+           MLX::Core.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           hidden_dim = args.intermediate_size
+           self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: args.mlp_bias)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: args.mlp_bias)
+           self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: args.mlp_bias)
+         end
+
+         def call(x)
+           down_proj.call(MLX::NN.silu(gate_proj.call(x)) * up_proj.call(x))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+       end
+
+       class IQuestLoopCoderModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           unless args.loop_num == 2
+             raise ArgumentError, "Only loop_num=2 is supported, got #{args.loop_num}"
+           end
+
+           self.vocab_size = args.vocab_size
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.gate_projections = Array.new(args.num_hidden_layers) do
+             LoopGateProjection.new(args.num_attention_heads, args.head_dim)
+           end
+           self.loop_num = args.loop_num
+           self.loop_window_size = args.loop_window_size
+         end
+
+         def call(inputs, cache: nil)
+           mx = MLX::Core
+           b, l = inputs.shape[0], inputs.shape[1]
+
+           h = embed_tokens.call(inputs)
+           layer_count = layers.length
+           layer_cache = cache || [nil] * (2 * layer_count)
+
+           mask = _create_attention_mask(h, layer_cache[0])
+           window_mask = _create_attention_mask(h, layer_cache[layer_count], window_size: loop_window_size)
+
+           loop1_kv = []
+           layers.each_with_index do |layer, idx|
+             c = layer_cache[idx]
+             h_norm = layer.input_layernorm.call(h)
+             offset = c ? c.offset : 0
+             q1, k1, v1 = layer.self_attn.get_qkv(h_norm, offset: offset)
+
+             if c
+               k1, v1 = c.update_and_fetch(k1, v1)
+             end
+             loop1_kv << [k1, v1]
+
+             out = layer.self_attn.attention(q1, k1, v1, mask: mask, cache: c)
+             r = layer.self_attn.o_proj.call(out.transpose([0, 2, 1, 3]).reshape([b, l, @args.hidden_size]))
+             h = h + r
+             r = layer.mlp.call(layer.post_attention_layernorm.call(h))
+             h = h + r
+           end
+
+           layers.each_with_index do |layer, idx|
+             gate_proj = gate_projections[idx]
+             c = layer_cache[layer_count + idx]
+             k1, v1 = loop1_kv[idx]
+
+             h_norm = layer.input_layernorm.call(h)
+             offset = c ? c.offset : 0
+             q2, k2, v2 = layer.self_attn.get_qkv(h_norm, offset: offset)
+
+             gate = gate_proj.call(q2)
+             attn_global = layer.self_attn.attention(q2, k1, v1, mask: mask, cache: c)
+
+             if c
+               k2, v2 = c.update_and_fetch(k2, v2)
+             end
+
+             attn_local = layer.self_attn.attention(q2, k2, v2, mask: window_mask, cache: c)
+             mixed = _mix_attention(gate, attn_global, attn_local)
+
+             r = layer.self_attn.o_proj.call(mixed.transpose([0, 2, 1, 3]).reshape([b, l, @args.hidden_size]))
+             h = h + r
+             r = layer.mlp.call(layer.post_attention_layernorm.call(h))
+             h = h + r
+           end
+
+           norm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache = nil, window_size: nil)
+           n = h.shape[1]
+           return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
+           return nil if n == 1
+           return _create_causal_mask(n, window_size: window_size) if window_size && n > window_size
+
+           "causal"
+         end
+
+         def _create_causal_mask(n, offset: 0, window_size: nil)
+           mx = MLX::Core
+           rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+           linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+           mask = mx.greater_equal(linds, rinds)
+           if window_size
+             mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+           end
+           mask
+         end
+
+         def _mix_attention(gate, attn_global, attn_local)
+           gate = gate.astype(attn_global.dtype)
+           (gate * attn_global) + ((1.0 - gate) * attn_local)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = IQuestLoopCoderModel.new(args)
+           unless args.tie_word_embeddings
+             self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+           end
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def layers
+           model.layers
+         end
+
+         def make_cache
+           Array.new(layers.length) { MlxLm::KVCache.new } +
+             Array.new(layers.length) { MlxLm::RotatingKVCache.new(max_size: @args.loop_window_size) }
+         end
+       end
+
+       Models.register("iquestloopcoder", Model, ModelArgs)
+     end
+
+     IQuestLoopCoder = Iquestloopcoder unless const_defined?(:IQuestLoopCoder)
+   end
+ end
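The distinctive step in IQuestLoopCoderModel#call is _mix_attention: a per-head sigmoid gate takes a convex combination of the global-context attention from the first loop and the sliding-window attention from the second. A scalar sketch of that formula (plain Ruby; the helper name and values are illustrative):

    # gate comes from a sigmoid, so it lies in (0, 1):
    # near 1.0 the global pass dominates, near 0.0 the local window does.
    def mix(gate, attn_global, attn_local)
      (gate * attn_global) + ((1.0 - gate) * attn_local)
    end

    mix(0.9, 1.0, -1.0) # => 0.8  (mostly global context)
    mix(0.1, 1.0, -1.0) # => -0.8 (mostly local window)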