mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/phi.rb
@@ -0,0 +1,156 @@
+ module MlxLm
+   module Models
+     module Phi
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "phi"
+         field :max_position_embeddings, default: 2048
+         field :vocab_size, default: 51_200
+         field :hidden_size, default: 2560
+         field :num_attention_heads, default: 32
+         field :num_hidden_layers, default: 32
+         field :num_key_value_heads, default: nil
+         field :partial_rotary_factor, default: 0.4
+         field :intermediate_size, default: 10_240
+         field :layer_norm_eps, default: 1e-5
+         field :rope_theta, default: 10_000.0
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class PhiAttention < MLX::NN::Module
+         def initialize(args)
+           super()
+           @hidden_size = args.hidden_size
+           @num_heads = args.num_attention_heads
+           @head_dim = @hidden_size / @num_heads
+           @num_key_value_heads = args.num_key_value_heads
+           @scale = @head_dim**(-0.5)
+
+           if (@head_dim * @num_heads) != @hidden_size
+             raise ArgumentError,
+                   "hidden_size must be divisible by num_heads (hidden_size=#{@hidden_size}, num_heads=#{@num_heads})"
+           end
+
+           self.q_proj = MLX::NN::Linear.new(@hidden_size, @num_heads * @head_dim, bias: true)
+           self.k_proj = MLX::NN::Linear.new(@hidden_size, @num_key_value_heads * @head_dim, bias: true)
+           self.v_proj = MLX::NN::Linear.new(@hidden_size, @num_key_value_heads * @head_dim, bias: true)
+           self.dense = MLX::NN::Linear.new(@num_heads * @head_dim, @hidden_size, bias: true)
+
+           # Phi applies RoPE to only a partial_rotary_factor fraction of each head.
+           self.rope = MLX::NN::RoPE.new(
+             (args.partial_rotary_factor * @head_dim).to_i,
+             traditional: false,
+             base: args.rope_theta
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           queries = q_proj.call(x)
+           keys = k_proj.call(x)
+           values = v_proj.call(x)
+
+           b, l, _d = queries.shape
+           queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           # Attention is computed in float32 and cast back to the value dtype.
+           output = mx.scaled_dot_product_attention(
+             queries.astype(mx.float32),
+             keys,
+             values,
+             @scale,
+             mask
+           ).astype(values.dtype)
+
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_heads * @head_dim])
+           dense.call(output)
+         end
+       end
+
+       class PhiMLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.fc1 = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: true)
+           self.fc2 = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: true)
+         end
+
+         def call(x)
+           fc2.call(MLX::NN.gelu_approx(fc1.call(x)))
+         end
+       end
+
+       class PhiDecoderLayer < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = PhiAttention.new(args)
+           self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps)
+           self.mlp = PhiMLP.new(args)
+         end
+
+         # Parallel residual: attention and MLP both read the same normed input.
+         def call(x, mask: nil, cache: nil)
+           h = input_layernorm.call(x)
+           attn_h = self_attn.call(h, mask: mask, cache: cache)
+           ff_h = mlp.call(h)
+           attn_h + ff_h + x
+         end
+       end
+
+       class PhiModel < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { PhiDecoderLayer.new(args) }
+           self.final_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           final_layernorm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.model = PhiModel.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: true)
+         end
+
+         def call(inputs, cache: nil)
+           lm_head.call(model.call(inputs, cache: cache))
+         end
+
+         def sanitize(weights)
+           weights.reject { |k, _| k.include?("rotary_emb.inv_freq") }
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("phi", Model, ModelArgs)
+     end
+   end
+ end
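
A note on the structure above: RoPE covers only partial_rotary_factor * head_dim dimensions per head (0.4 * 80 = 32 with the defaults), and each decoder layer adds its attention and MLP branches to the residual in parallel rather than sequentially. A minimal sketch of standing the model up from the classes in this file (hypothetical driver code, not shipped in the gem; it assumes MLX::Core.array is the array constructor):

    require "mlx_lm"

    # Defaults from ModelArgs above: hidden_size 2560 over 32 heads gives
    # head_dim 80, of which (0.4 * 80).to_i = 32 dimensions are rotated.
    args  = MlxLm::Models::Phi::ModelArgs.new
    model = MlxLm::Models::Phi::Model.new(args)

    # One forward pass over a dummy batch of 4 token ids; logits come back
    # with shape [batch, sequence, vocab_size] = [1, 4, 51200].
    tokens = MLX::Core.array([[1, 2, 3, 4]])  # assumed constructor, random weights
    logits = model.call(tokens)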
data/lib/mlx_lm/models/phi3.rb
@@ -0,0 +1,171 @@
+ module MlxLm
+   module Models
+     module Phi3
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "phi3"
+         field :hidden_size, default: 3072
+         field :num_hidden_layers, default: 32
+         field :num_attention_heads, default: 32
+         field :num_key_value_heads, default: nil
+         field :intermediate_size, default: 8192
+         field :vocab_size, default: 32_064
+         field :rms_norm_eps, default: 1e-5
+         field :rope_theta, default: 10_000.0
+         field :rope_traditional, default: false
+         field :rope_scaling, default: nil
+         field :tie_word_embeddings, default: false
+         field :head_dim, default: nil
+         field :max_position_embeddings, default: 131_072
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+           @head_dim ||= @hidden_size / @num_attention_heads
+         end
+       end
+
+       # Phi3 uses a combined QKV projection.
+       class Attention < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @head_dim = args.head_dim
+           @scale = @head_dim**(-0.5)
+
+           qkv_dim = (@n_heads + 2 * @n_kv_heads) * @head_dim
+           self.qkv_proj = MLX::NN::Linear.new(dim, qkv_dim, bias: false)
+           self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+           self.rope = MLX::NN::RoPE.new(
+             @head_dim,
+             traditional: args.rope_traditional,
+             base: args.rope_theta
+           )
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           qkv = qkv_proj.call(x)
+           q_size = @n_heads * @head_dim
+           k_size = @n_kv_heads * @head_dim
+
+           # One split at offsets q_size and q_size + k_size yields Q, K, and V.
+           queries, keys, values = mx.split(qkv, [q_size, q_size + k_size], -1)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           o_proj.call(output)
+         end
+       end
+
+       # Phi3 uses a combined gate_up projection.
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           hidden_dim = args.intermediate_size
+
+           self.gate_up_proj = MLX::NN::Linear.new(dim, 2 * hidden_dim, bias: false)
+           self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+         end
+
+         def call(x)
+           mx = MLX::Core
+           x = gate_up_proj.call(x)
+           hidden_dim = x.shape[-1] / 2
+           gate, up = mx.split(x, [hidden_dim], -1)
+           down_proj.call(MLX::NN.silu(gate) * up)
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args)
+           super()
+           self.self_attn = Attention.new(args)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+           self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class Phi3Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+           self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           layer_cache = cache || [nil] * layers.length
+
+           mask = nil
+           mask = "causal" if h.shape[1] > 1
+
+           layers.each_with_index do |layer, i|
+             h = layer.call(h, mask: mask, cache: layer_cache[i])
+           end
+
+           norm.call(h)
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model = Phi3Model.new(args)
+           self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings
+         end
+
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           if @args.tie_word_embeddings
+             model.embed_tokens.as_linear(out)
+           else
+             lm_head.call(out)
+           end
+         end
+
+         def sanitize(weights)
+           weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("phi3", Model, ModelArgs)
+     end
+   end
+ end
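
Since the fused projection packs queries, keys, and values into one matrix, the split offsets are worth spelling out. A scalar check with the Phi3 defaults above (plain Ruby arithmetic, no MLX required):

    n_heads    = 32
    n_kv_heads = 32                                     # defaults to num_attention_heads
    head_dim   = 3072 / 32                              # => 96
    q_size     = n_heads * head_dim                     # => 3072
    k_size     = n_kv_heads * head_dim                  # => 3072
    qkv_dim    = (n_heads + 2 * n_kv_heads) * head_dim  # => 9216

    # mx.split(qkv, [q_size, q_size + k_size], -1) cuts the last axis at
    # 3072 and 6144, leaving three 3072-wide slices for Q, K, and V.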
data/lib/mlx_lm/models/phi3small.rb
@@ -0,0 +1,196 @@
+ module MlxLm
+   module Models
+     module Phi3small
+       class ModelArgs < BaseModelArgs
+         field :model_type, default: "phi3small"
+         field :hidden_size
+         field :dense_attention_every_n_layers
+         field :ff_intermediate_size
+         field :gegelu_limit
+         field :num_hidden_layers
+         field :num_attention_heads
+         field :layer_norm_epsilon
+         field :vocab_size
+         field :num_key_value_heads
+         field :mup_attn_multiplier, default: 1.0
+         field :mup_use_scaling, default: true
+         field :mup_embedding_multiplier, default: 10.0
+         field :mup_width_multiplier, default: 8.0
+         field :rope_embedding_base, default: 1_000_000.0
+         field :rope_position_scale, default: 1.0
+         field :blocksparse_block_size, default: 64
+         field :blocksparse_num_local_blocks, default: 16
+         field :blocksparse_vert_stride, default: 8
+
+         def initialize(**kwargs)
+           super
+           @num_key_value_heads ||= @num_attention_heads
+         end
+       end
+
+       class Attention < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+
+           dim = args.hidden_size
+           @n_heads = args.num_attention_heads
+           @n_kv_heads = args.num_key_value_heads
+           @n_q_per_kv = @n_heads / @n_kv_heads
+           @head_dim = dim / @n_heads
+
+           self.query_key_value = MLX::NN::Linear.new(
+             dim,
+             (@n_heads + 2 * @n_kv_heads) * @head_dim
+           )
+           self.dense = MLX::NN::Linear.new(dim, dim)
+
+           # muP scaling uses head_dim / mup_attn_multiplier in place of sqrt(head_dim).
+           norm_factor = if args.mup_use_scaling
+                           @head_dim / args.mup_attn_multiplier.to_f
+                         else
+                           Math.sqrt(@head_dim)
+                         end
+           @scale = 1.0 / norm_factor
+
+           self.rope = MLX::NN::RoPE.new(
+             @head_dim,
+             traditional: false,
+             base: args.rope_embedding_base,
+             scale: args.rope_position_scale
+           )
+
+           @block_sparse = (layer_idx % args.dense_attention_every_n_layers).zero?
+         end
+
+         def call(x, mask: nil, cache: nil)
+           mx = MLX::Core
+           b, l, _d = x.shape
+
+           qkv = query_key_value.call(x)
+           q_size = @n_heads * @head_dim
+           k_size = @n_kv_heads * @head_dim
+
+           queries, keys, values = mx.split(qkv, [q_size, q_size + k_size], -1)
+
+           queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+           keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+           values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+           if cache
+             queries = rope.call(queries, offset: cache.offset)
+             keys = rope.call(keys, offset: cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+           else
+             queries = rope.call(queries)
+             keys = rope.call(keys)
+           end
+
+           output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+           output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+           dense.call(output)
+         end
+       end
+
+       class MLP < MLX::NN::Module
+         def initialize(args)
+           super()
+           dim = args.hidden_size
+           @hidden_dim = args.ff_intermediate_size
+           self.up_proj = MLX::NN::Linear.new(dim, 2 * @hidden_dim)
+           self.down_proj = MLX::NN::Linear.new(@hidden_dim, dim)
+         end
+
+         # Gegelu: a GELU-activated half gates a (linear + 1) half.
+         def call(x)
+           mx = MLX::Core
+           x = up_proj.call(x)
+           a_gelu, a_linear = mx.split(x, [@hidden_dim], -1)
+           out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu) # sigmoid-based GELU approximation
+           down_proj.call(out_gelu * (a_linear + 1.0))
+         end
+       end
+
+       class TransformerBlock < MLX::NN::Module
+         def initialize(args, layer_idx)
+           super()
+           self.self_attn = Attention.new(args, layer_idx)
+           self.mlp = MLP.new(args)
+           self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+           self.post_attention_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+         end
+
+         def call(x, mask: nil, cache: nil)
+           r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+           h = x + r
+           r = mlp.call(post_attention_layernorm.call(h))
+           h + r
+         end
+       end
+
+       class Phi3Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+           self.layers = Array.new(args.num_hidden_layers) { |layer_idx| TransformerBlock.new(args, layer_idx) }
+           self.final_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_epsilon)
+         end
+
+         def call(inputs, cache: nil)
+           h = embed_tokens.call(inputs)
+           h = h * @args.mup_embedding_multiplier if @args.mup_embedding_multiplier
+
+           layer_cache = cache || [nil] * layers.length
+           mask = _create_attention_mask(h, layer_cache[0])
+
+           layers.each_with_index do |layer, layer_idx|
+             h = layer.call(h, mask: mask, cache: layer_cache[layer_idx])
+           end
+
+           final_layernorm.call(h)
+         end
+
+         private
+
+         def _create_attention_mask(h, cache)
+           n = h.shape[1]
+           return cache.make_mask(n) if cache && cache.respond_to?(:make_mask)
+           return nil if n == 1
+
+           "causal"
+         end
+       end
+
+       class Model < MLX::NN::Module
+         def initialize(args)
+           super()
+           @args = args
+           self.model_type = args.model_type
+           self.model = Phi3Model.new(args)
+         end
+
+         # The output head is always tied to the input embeddings.
+         def call(inputs, cache: nil)
+           out = model.call(inputs, cache: cache)
+           out = model.embed_tokens.as_linear(out)
+           out = out / @args.mup_width_multiplier if @args.mup_width_multiplier
+           out
+         end
+
+         def sanitize(weights)
+           weights.reject do |key, _|
+             key_name = key.to_s
+             key_name.include?("self_attn.rotary_emb.inv_freq") ||
+               key_name.include?("rotary_emb.inv_freq") ||
+               key_name.include?("position_embeddings.inv_freq")
+           end
+         end
+
+         def layers
+           model.layers
+         end
+       end
+
+       Models.register("phi3small", Model, ModelArgs)
+     end
+   end
+ end
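
The two phi3small quirks worth calling out are the muP scalings (embeddings multiplied by mup_embedding_multiplier on the way in, logits divided by mup_width_multiplier on the way out, and attention scaled by mup_attn_multiplier / head_dim rather than 1 / sqrt(head_dim)) and the gegelu MLP. The activation is easy to verify on scalars; a sketch mirroring MLP#call above (plain Ruby; endless-method syntax needs Ruby 3.0+):

    def sigmoid(x) = 1.0 / (1.0 + Math.exp(-x))
    def gelu_approx(x) = x * sigmoid(1.702 * x)  # same approximation as MLP#call

    a_gelu, a_linear = 1.5, 0.25
    y = gelu_approx(a_gelu) * (a_linear + 1.0)
    # gelu_approx(1.5) = 1.5 * sigmoid(2.553) ≈ 1.5 * 0.928 ≈ 1.392
    # y ≈ 1.392 * 1.25 ≈ 1.74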