mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
data/lib/mlx_lm/models/olmo2.rb
@@ -0,0 +1,169 @@
+module MlxLm
+  module Models
+    module OLMo2
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "olmo2"
+        field :hidden_size, default: 4096
+        field :num_hidden_layers, default: 32
+        field :num_attention_heads, default: 32
+        field :num_key_value_heads, default: nil
+        field :intermediate_size, default: 11008
+        field :vocab_size, default: 50304
+        field :rms_norm_eps, default: 1e-6
+        field :rope_theta, default: 10000.0
+        field :rope_traditional, default: false
+        field :rope_scaling, default: nil
+        field :attention_bias, default: false
+        field :mlp_bias, default: false
+        field :tie_word_embeddings, default: true
+        field :head_dim, default: nil
+        field :max_position_embeddings, default: nil
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @hidden_size / @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+
+          bias = args.attention_bias
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias)
+
+          # OLMo2: Q and K normalization
+          self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+          self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+
+          self.rope = MLX::NN::RoPE.new(
+            @head_dim,
+            traditional: args.rope_traditional,
+            base: args.rope_theta
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          # Apply Q/K normalization
+          queries = q_norm.call(queries)
+          keys = k_norm.call(keys)
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(dim, hidden_dim, bias: false)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias)
+        end
+
+        def call(x)
+          down_proj.call(MLX::NN.silu(gate_proj.call(x)) * up_proj.call(x))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args.hidden_size, args.intermediate_size, bias: args.mlp_bias)
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
+      class OLMo2Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model = OLMo2Model.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def sanitize(weights)
+          result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+          result.delete("lm_head.weight") if @args.tie_word_embeddings
+          result
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("olmo2", Model, ModelArgs)
+    end
+  end
+end
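
Note on the OLMo2 port above: every ModelArgs field carries a default, and #initialize derives num_key_value_heads and head_dim when they are not supplied. A minimal sketch of that behavior, assuming the gem and the MLX runtime are installed; the hyperparameters here are illustrative, not from any released checkpoint:

require "mlx_lm"

args = MlxLm::Models::OLMo2::ModelArgs.new(hidden_size: 2048, num_attention_heads: 16)
args.num_key_value_heads  # => 16 (defaults to num_attention_heads)
args.head_dim             # => 128 (hidden_size / num_attention_heads)

# With tie_word_embeddings at its default of true, sanitize drops any stray
# lm_head.weight, since the embedding matrix doubles as the output head via
# embed_tokens.as_linear.
model = MlxLm::Models::OLMo2::Model.new(args)
model.sanitize("lm_head.weight" => :unused)  # => {}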
data/lib/mlx_lm/models/olmo3.rb
@@ -0,0 +1,254 @@
+module MlxLm
+  module Models
+    module OLMo3
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "olmo3"
+        field :hidden_size
+        field :num_hidden_layers
+        field :intermediate_size
+        field :num_attention_heads
+        field :rms_norm_eps
+        field :vocab_size
+        field :max_position_embeddings
+        field :sliding_window
+        field :rope_theta
+        field :attention_bias, default: false
+        field :layer_types, default: nil
+        field :num_key_value_heads, default: nil
+        field :head_dim, default: nil
+        field :rope_scaling, default: nil
+        field :tie_word_embeddings, default: false
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+          @head_dim ||= @hidden_size / @num_attention_heads
+          @layer_types ||= Array.new(@num_hidden_layers) do |i|
+            ((i + 1) % 4).zero? ? "full_attention" : "sliding_attention"
+          end
+        end
+      end
+
+      class Olmo3Attention < MLX::NN::Module
+        def initialize(args, layer_idx:)
+          super()
+          @num_attention_heads = args.num_attention_heads
+          @num_key_value_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+
+          self.q_proj = MLX::NN::Linear.new(
+            args.hidden_size,
+            @num_attention_heads * @head_dim,
+            bias: args.attention_bias
+          )
+          self.k_proj = MLX::NN::Linear.new(
+            args.hidden_size,
+            @num_key_value_heads * @head_dim,
+            bias: args.attention_bias
+          )
+          self.v_proj = MLX::NN::Linear.new(
+            args.hidden_size,
+            @num_key_value_heads * @head_dim,
+            bias: args.attention_bias
+          )
+          self.o_proj = MLX::NN::Linear.new(
+            @num_attention_heads * @head_dim,
+            args.hidden_size,
+            bias: args.attention_bias
+          )
+
+          self.q_norm = MLX::NN::RMSNorm.new(
+            @num_attention_heads * @head_dim,
+            eps: args.rms_norm_eps
+          )
+          self.k_norm = MLX::NN::RMSNorm.new(
+            @num_key_value_heads * @head_dim,
+            eps: args.rms_norm_eps
+          )
+
+          if args.layer_types[layer_idx] != "full_attention"
+            self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: args.rope_theta)
+          else
+            self.rope = MlxLm::Models.initialize_rope(
+              @head_dim,
+              args.rope_theta,
+              false,
+              args.rope_scaling,
+              max_position_embeddings: args.max_position_embeddings
+            )
+          end
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_norm.call(q_proj.call(x))
+          keys = k_norm.call(k_proj.call(x))
+          values = v_proj.call(x)
+
+          queries = queries.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class Olmo3MLP < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false)
+          self.down_proj = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: false)
+          self.up_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false)
+        end
+
+        def call(x)
+          down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x)))
+        end
+      end
+
+      class Olmo3DecoderLayer < MLX::NN::Module
+        def initialize(args, layer_idx:)
+          super()
+          self.self_attn = Olmo3Attention.new(args, layer_idx: layer_idx)
+          self.mlp = Olmo3MLP.new(args)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_feedforward_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = post_attention_layernorm.call(self_attn.call(x, mask: mask, cache: cache))
+          h = x + r
+          r = post_feedforward_layernorm.call(mlp.call(h))
+          h + r
+        end
+      end
+
+      class Olmo3Model < MLX::NN::Module
+        attr_reader :layer_types
+
+        def initialize(args)
+          super()
+          @sliding_window = args.sliding_window
+          @layer_types = args.layer_types
+
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) do |i|
+            Olmo3DecoderLayer.new(args, layer_idx: i)
+          end
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+
+          self.swa_idx = @layer_types.index("sliding_attention") || 0
+          self.ga_idx = @layer_types.index("full_attention") || 0
+        end
+
+        def call(inputs, cache: nil)
+          h = embed_tokens.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+
+          full_mask = _create_attention_mask(h, layer_cache[ga_idx])
+          sliding_window_mask = _create_attention_mask(
+            h,
+            layer_cache[swa_idx],
+            window_size: @sliding_window
+          )
+
+          layers.each_with_index do |layer, i|
+            mask = @layer_types[i] == "full_attention" ? full_mask : sliding_window_mask
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+
+        private
+
+        def _create_attention_mask(h, cache = nil, window_size: nil)
+          n = h.shape[1]
+          if cache && cache.respond_to?(:make_mask)
+            return cache.make_mask(n, window_size: window_size)
+          end
+
+          if window_size
+            offset = 0
+            if cache
+              offset = cache.offset
+              if cache.instance_variable_defined?(:@max_size)
+                max_size = cache.instance_variable_get(:@max_size)
+                offset = [max_size - 1, offset].min if max_size && max_size > 0
+              end
+            end
+            return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size
+          end
+          return nil if n == 1
+
+          "causal"
+        end
+
+        def _create_causal_mask(n, offset: 0, window_size: nil)
+          mx = MLX::Core
+          rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n])
+          linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1])
+
+          mask = mx.greater_equal(linds, rinds)
+          if window_size
+            mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size)))
+          end
+          mask
+        end
+      end
+
+      class Model < MLX::NN::Module
+        attr_reader :args
+
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.model = Olmo3Model.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def layers
+          model.layers
+        end
+
+        def make_cache
+          model.layer_types.map do |layer_type|
+            if layer_type == "full_attention"
+              KVCache.new
+            else
+              RotatingKVCache.new(max_size: args.sliding_window)
+            end
+          end
+        end
+      end
+
+      Models.register("olmo3", Model, ModelArgs)
+    end
+  end
+end
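
The OLMo3 port interleaves sliding-window and full attention. When layer_types is not supplied, ModelArgs#initialize makes every fourth layer "full_attention" and the rest "sliding_attention", and Model#make_cache mirrors that schedule: a KVCache for full-attention layers, a RotatingKVCache bounded by sliding_window otherwise. The default schedule, reproduced as plain Ruby for an illustrative eight-layer model:

layer_types = Array.new(8) { |i| ((i + 1) % 4).zero? ? "full_attention" : "sliding_attention" }
# => ["sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
#     "sliding_attention", "sliding_attention", "sliding_attention", "full_attention"]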
data/lib/mlx_lm/models/olmoe.rb
@@ -0,0 +1,64 @@
+require_relative "olmo2"
+
+module MlxLm
+  module Models
+    module OLMoE
+      class ModelArgs < OLMo2::ModelArgs
+        field :model_type, default: "olmoe"
+        field :num_experts
+        field :num_experts_per_tok
+        field :norm_topk_prob, default: false
+      end
+
+      class Model < OLMo2::Model
+        def sanitize(weights)
+          result = super(weights)
+          rewrite_expert_weights(result)
+        end
+
+        private
+
+        def rewrite_expert_weights(weights)
+          return weights unless weights.key?("model.layers.0.mlp.experts.0.up_proj.weight")
+
+          mx = MLX::Core
+
+          layers.length.times do |layer_idx|
+            prefix = "model.layers.#{layer_idx}.mlp"
+            %w[up_proj down_proj gate_proj].each do |projection|
+              %w[weight scales biases].each do |param|
+                first_key = "#{prefix}.experts.0.#{projection}.#{param}"
+                next unless weights.key?(first_key)
+
+                expert_count = @args.num_experts || infer_expert_count(weights, prefix, projection, param)
+                next unless expert_count && expert_count.positive?
+
+                expert_keys = (0...expert_count).map do |expert_idx|
+                  "#{prefix}.experts.#{expert_idx}.#{projection}.#{param}"
+                end
+                next unless expert_keys.all? { |key| weights.key?(key) }
+
+                weights["#{prefix}.switch_mlp.#{projection}.#{param}"] = mx.stack(expert_keys.map { |key| weights.delete(key) })
+              end
+            end
+          end
+
+          weights
+        end
+
+        def infer_expert_count(weights, prefix, projection, param)
+          pattern = /\A#{Regexp.escape(prefix)}\.experts\.(\d+)\.#{projection}\.#{param}\z/
+          indices = weights.keys.filter_map do |key|
+            match = pattern.match(key)
+            match[1].to_i if match
+          end
+          return 0 if indices.empty?
+
+          indices.max + 1
+        end
+      end
+
+      Models.register("olmoe", Model, ModelArgs)
+    end
+  end
+end
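
OLMoE reuses the OLMo2 model wholesale and only extends sanitize: per-expert projection tensors are collapsed into one stacked switch_mlp tensor per projection. A plain-Ruby illustration of the key rewrite, with symbols standing in for the MLX arrays that mx.stack would actually combine (two experts, one projection, hypothetical values):

weights = {
  "model.layers.0.mlp.experts.0.up_proj.weight" => :w0,
  "model.layers.0.mlp.experts.1.up_proj.weight" => :w1
}
prefix = "model.layers.0.mlp"
expert_keys = (0...2).map { |e| "#{prefix}.experts.#{e}.up_proj.weight" }
weights["#{prefix}.switch_mlp.up_proj.weight"] = expert_keys.map { |k| weights.delete(k) }
weights
# => { "model.layers.0.mlp.switch_mlp.up_proj.weight" => [:w0, :w1] }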
data/lib/mlx_lm/models/openelm.rb
@@ -0,0 +1,208 @@
+module MlxLm
+  module Models
+    module OpenELM
+      module_function
+
+      def make_divisible(v, divisor = 8, min_value = nil)
+        min_value ||= divisor
+        rounded = ((v + (divisor.to_f / 2)).to_i / divisor) * divisor
+        new_v = [min_value, rounded].max
+        new_v += divisor if new_v < (0.9 * v)
+        new_v
+      end
+
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "openelm"
+        field :head_dim, default: 64
+        field :num_transformer_layers, default: 12
+        field :model_dim, default: 2048
+        field :vocab_size, default: 32_000
+        field :ffn_dim_divisor, default: 8
+        field :num_query_heads, default: [32]
+        field :num_kv_heads, default: []
+        field :ffn_multipliers, default: [1.0]
+        field :ffn_with_glu, default: true
+        field :normalize_qk_projections, default: true
+        field :share_input_output_layers, default: true
+        field :rms_norm_eps, default: 1e-6
+        field :rope_freq_constant, default: 10_000.0
+
+        def initialize(**kwargs)
+          super
+          @num_query_heads = normalize_schedule(@num_query_heads, @num_transformer_layers, 1, "num_query_heads").map(&:to_i)
+
+          if @num_kv_heads.nil? || Array(@num_kv_heads).empty?
+            @num_kv_heads = @num_query_heads.dup
+          else
+            @num_kv_heads = normalize_schedule(@num_kv_heads, @num_transformer_layers, @num_query_heads[0], "num_kv_heads").map(&:to_i)
+          end
+
+          @ffn_multipliers = normalize_schedule(@ffn_multipliers, @num_transformer_layers, 1.0, "ffn_multipliers").map(&:to_f)
+        end
+
+        private
+
+        def normalize_schedule(values, layers, fallback, field_name)
+          items = Array(values)
+          items = [fallback] if items.empty?
+          items = Array.new(layers, items[0]) if items.length == 1 && layers > 1
+
+          unless items.length == layers
+            raise ArgumentError, "#{field_name} must have #{layers} entries, got #{items.length}"
+          end
+
+          items
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args, layer_id:)
+          super()
+          @head_dim = args.head_dim
+          @n_heads = args.num_query_heads[layer_id]
+          @n_kv_heads = args.num_kv_heads[layer_id]
+          @scale = @head_dim**(-0.5)
+          @normalize_qk_projections = args.normalize_qk_projections
+
+          op_size = (@n_heads + (2 * @n_kv_heads)) * @head_dim
+          self.qkv_proj = MLX::NN::Linear.new(args.model_dim, op_size, bias: false)
+          self.out_proj = MLX::NN::Linear.new(@n_heads * @head_dim, args.model_dim, bias: false)
+
+          if @normalize_qk_projections
+            self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+            self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps)
+          end
+
+          self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: args.rope_freq_constant)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          qkv = qkv_proj.call(x)
+          qkv = qkv.reshape([b, l, @n_heads + (2 * @n_kv_heads), @head_dim]).transpose([0, 2, 1, 3])
+          queries, keys, values = mx.split(qkv, [@n_heads, @n_heads + @n_kv_heads], 1)
+
+          if @normalize_qk_projections
+            queries = q_norm.call(queries)
+            keys = k_norm.call(keys)
+          end
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          out_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(args, layer_id:)
+          super()
+          @ffn_with_glu = args.ffn_with_glu
+          dim = args.model_dim
+          multiplier = args.ffn_multipliers[layer_id]
+          @intermediate_dim = OpenELM.make_divisible(multiplier * dim, args.ffn_dim_divisor).to_i
+
+          proj_1_dim = @ffn_with_glu ? (2 * @intermediate_dim) : @intermediate_dim
+          self.proj_1 = MLX::NN::Linear.new(dim, proj_1_dim, bias: false)
+          self.proj_2 = MLX::NN::Linear.new(@intermediate_dim, dim, bias: false)
+        end
+
+        def call(x)
+          x = proj_1.call(x)
+          x = if @ffn_with_glu
+            gate, value = MLX::Core.split(x, [@intermediate_dim], -1)
+            Activations.swiglu(gate, value)
+          else
+            MLX::NN.gelu_approx(x)
+          end
+          proj_2.call(x)
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args, layer_id:)
+          super()
+          self.attn = Attention.new(args, layer_id: layer_id)
+          self.ffn = MLP.new(args, layer_id: layer_id)
+          self.attn_norm = MLX::NN::RMSNorm.new(args.model_dim, eps: args.rms_norm_eps)
+          self.ffn_norm = MLX::NN::RMSNorm.new(args.model_dim, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = attn.call(attn_norm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = ffn.call(ffn_norm.call(h))
+          h + r
+        end
+      end
+
+      class OpenELMModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.token_embeddings = MLX::NN::Embedding.new(args.vocab_size, args.model_dim)
+          self.layers = Array.new(args.num_transformer_layers) do |layer_id|
+            TransformerBlock.new(args, layer_id: layer_id)
+          end
+          self.norm = MLX::NN::RMSNorm.new(args.model_dim, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          h = token_embeddings.call(inputs)
+          layer_cache = cache || [nil] * layers.length
+          mask = _create_attention_mask(h, layer_cache[0])
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+          norm.call(h)
+        end
+
+        private
+
+        def _create_attention_mask(h, cache)
+          return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask)
+          return nil if h.shape[1] == 1
+
+          "causal"
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.transformer = OpenELMModel.new(args)
+          unless args.share_input_output_layers
+            self.lm_head = MLX::NN::Linear.new(args.model_dim, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = transformer.call(inputs, cache: cache)
+          if @args.share_input_output_layers
+            transformer.token_embeddings.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def layers
+          transformer.layers
+        end
+      end
+
+      Models.register("openelm", Model, ModelArgs)
+    end
+  end
+end
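
OpenELM varies head counts and FFN widths per layer; make_divisible snaps each computed FFN width to a multiple of ffn_dim_divisor and bumps up one step whenever rounding would land below 90% of the requested width. Two calls worked through by hand against the definition above, assuming the gem is loaded:

MlxLm::Models::OpenELM.make_divisible(100, 8)  # => 104 (rounds up to the next multiple of 8)
MlxLm::Models::OpenELM.make_divisible(10, 8)   # => 16 (8 < 0.9 * 10, so bump by one divisor)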