mlx-ruby-lm 0.30.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. checksums.yaml +7 -0
  2. data/LICENSE.txt +21 -0
  3. data/README.md +83 -0
  4. data/exe/mlx_lm +7 -0
  5. data/lib/mlx_lm/benchmark.rb +67 -0
  6. data/lib/mlx_lm/chat_template.rb +41 -0
  7. data/lib/mlx_lm/cli.rb +113 -0
  8. data/lib/mlx_lm/config.rb +30 -0
  9. data/lib/mlx_lm/convert_utils.rb +51 -0
  10. data/lib/mlx_lm/generate.rb +204 -0
  11. data/lib/mlx_lm/load_utils.rb +87 -0
  12. data/lib/mlx_lm/model_args.rb +54 -0
  13. data/lib/mlx_lm/models/activations.rb +46 -0
  14. data/lib/mlx_lm/models/afm7.rb +131 -0
  15. data/lib/mlx_lm/models/afmoe.rb +421 -0
  16. data/lib/mlx_lm/models/apertus.rb +179 -0
  17. data/lib/mlx_lm/models/baichuan_m1.rb +306 -0
  18. data/lib/mlx_lm/models/bailing_moe.rb +399 -0
  19. data/lib/mlx_lm/models/bailing_moe_linear.rb +91 -0
  20. data/lib/mlx_lm/models/bitlinear_layers.rb +108 -0
  21. data/lib/mlx_lm/models/bitnet.rb +176 -0
  22. data/lib/mlx_lm/models/cache.rb +792 -0
  23. data/lib/mlx_lm/models/cohere.rb +150 -0
  24. data/lib/mlx_lm/models/cohere2.rb +224 -0
  25. data/lib/mlx_lm/models/dbrx.rb +286 -0
  26. data/lib/mlx_lm/models/deepseek.rb +239 -0
  27. data/lib/mlx_lm/models/deepseek_v2.rb +108 -0
  28. data/lib/mlx_lm/models/deepseek_v3.rb +34 -0
  29. data/lib/mlx_lm/models/deepseek_v32.rb +45 -0
  30. data/lib/mlx_lm/models/dots1.rb +292 -0
  31. data/lib/mlx_lm/models/ernie4_5.rb +165 -0
  32. data/lib/mlx_lm/models/ernie4_5_moe.rb +97 -0
  33. data/lib/mlx_lm/models/exaone.rb +169 -0
  34. data/lib/mlx_lm/models/exaone4.rb +233 -0
  35. data/lib/mlx_lm/models/exaone_moe.rb +421 -0
  36. data/lib/mlx_lm/models/falcon_h1.rb +102 -0
  37. data/lib/mlx_lm/models/gated_delta.rb +136 -0
  38. data/lib/mlx_lm/models/gemma.rb +159 -0
  39. data/lib/mlx_lm/models/gemma2.rb +198 -0
  40. data/lib/mlx_lm/models/gemma3.rb +85 -0
  41. data/lib/mlx_lm/models/gemma3_text.rb +270 -0
  42. data/lib/mlx_lm/models/gemma3n.rb +79 -0
  43. data/lib/mlx_lm/models/glm.rb +164 -0
  44. data/lib/mlx_lm/models/glm4.rb +180 -0
  45. data/lib/mlx_lm/models/glm4_moe.rb +343 -0
  46. data/lib/mlx_lm/models/glm4_moe_lite.rb +131 -0
  47. data/lib/mlx_lm/models/glm_moe_dsa.rb +26 -0
  48. data/lib/mlx_lm/models/gpt2.rb +166 -0
  49. data/lib/mlx_lm/models/gpt_bigcode.rb +154 -0
  50. data/lib/mlx_lm/models/gpt_neox.rb +178 -0
  51. data/lib/mlx_lm/models/gpt_oss.rb +319 -0
  52. data/lib/mlx_lm/models/granite.rb +170 -0
  53. data/lib/mlx_lm/models/granitemoe.rb +58 -0
  54. data/lib/mlx_lm/models/granitemoehybrid.rb +178 -0
  55. data/lib/mlx_lm/models/helium.rb +158 -0
  56. data/lib/mlx_lm/models/hunyuan.rb +378 -0
  57. data/lib/mlx_lm/models/hunyuan_v1_dense.rb +235 -0
  58. data/lib/mlx_lm/models/internlm2.rb +160 -0
  59. data/lib/mlx_lm/models/internlm3.rb +237 -0
  60. data/lib/mlx_lm/models/iquestloopcoder.rb +261 -0
  61. data/lib/mlx_lm/models/jamba.rb +158 -0
  62. data/lib/mlx_lm/models/kimi_k25.rb +98 -0
  63. data/lib/mlx_lm/models/kimi_linear.rb +124 -0
  64. data/lib/mlx_lm/models/kimi_vl.rb +93 -0
  65. data/lib/mlx_lm/models/klear.rb +283 -0
  66. data/lib/mlx_lm/models/lfm2.rb +120 -0
  67. data/lib/mlx_lm/models/lfm2_moe.rb +421 -0
  68. data/lib/mlx_lm/models/lfm2_vl.rb +67 -0
  69. data/lib/mlx_lm/models/lille_130m.rb +148 -0
  70. data/lib/mlx_lm/models/llama.rb +183 -0
  71. data/lib/mlx_lm/models/llama4.rb +357 -0
  72. data/lib/mlx_lm/models/llama4_text.rb +195 -0
  73. data/lib/mlx_lm/models/longcat_flash.rb +153 -0
  74. data/lib/mlx_lm/models/longcat_flash_ngram.rb +137 -0
  75. data/lib/mlx_lm/models/mamba.rb +301 -0
  76. data/lib/mlx_lm/models/mamba2.rb +292 -0
  77. data/lib/mlx_lm/models/mimo.rb +174 -0
  78. data/lib/mlx_lm/models/mimo_v2_flash.rb +491 -0
  79. data/lib/mlx_lm/models/minicpm.rb +169 -0
  80. data/lib/mlx_lm/models/minicpm3.rb +237 -0
  81. data/lib/mlx_lm/models/minimax.rb +282 -0
  82. data/lib/mlx_lm/models/ministral3.rb +304 -0
  83. data/lib/mlx_lm/models/mistral3.rb +84 -0
  84. data/lib/mlx_lm/models/mixtral.rb +192 -0
  85. data/lib/mlx_lm/models/mla.rb +75 -0
  86. data/lib/mlx_lm/models/nanochat.rb +167 -0
  87. data/lib/mlx_lm/models/nemotron.rb +202 -0
  88. data/lib/mlx_lm/models/nemotron_h.rb +212 -0
  89. data/lib/mlx_lm/models/nemotron_nas.rb +404 -0
  90. data/lib/mlx_lm/models/olmo.rb +165 -0
  91. data/lib/mlx_lm/models/olmo2.rb +169 -0
  92. data/lib/mlx_lm/models/olmo3.rb +254 -0
  93. data/lib/mlx_lm/models/olmoe.rb +64 -0
  94. data/lib/mlx_lm/models/openelm.rb +208 -0
  95. data/lib/mlx_lm/models/phi.rb +156 -0
  96. data/lib/mlx_lm/models/phi3.rb +171 -0
  97. data/lib/mlx_lm/models/phi3small.rb +196 -0
  98. data/lib/mlx_lm/models/phimoe.rb +206 -0
  99. data/lib/mlx_lm/models/phixtral.rb +208 -0
  100. data/lib/mlx_lm/models/pipeline.rb +37 -0
  101. data/lib/mlx_lm/models/pixtral.rb +47 -0
  102. data/lib/mlx_lm/models/plamo.rb +169 -0
  103. data/lib/mlx_lm/models/plamo2.rb +173 -0
  104. data/lib/mlx_lm/models/qwen.rb +175 -0
  105. data/lib/mlx_lm/models/qwen2.rb +162 -0
  106. data/lib/mlx_lm/models/qwen2_moe.rb +189 -0
  107. data/lib/mlx_lm/models/qwen2_vl.rb +48 -0
  108. data/lib/mlx_lm/models/qwen3.rb +167 -0
  109. data/lib/mlx_lm/models/qwen3_5.rb +69 -0
  110. data/lib/mlx_lm/models/qwen3_5_moe.rb +54 -0
  111. data/lib/mlx_lm/models/qwen3_moe.rb +166 -0
  112. data/lib/mlx_lm/models/qwen3_next.rb +147 -0
  113. data/lib/mlx_lm/models/qwen3_vl.rb +48 -0
  114. data/lib/mlx_lm/models/qwen3_vl_moe.rb +92 -0
  115. data/lib/mlx_lm/models/recurrent_gemma.rb +444 -0
  116. data/lib/mlx_lm/models/rope_utils.rb +316 -0
  117. data/lib/mlx_lm/models/rwkv7.rb +101 -0
  118. data/lib/mlx_lm/models/seed_oss.rb +167 -0
  119. data/lib/mlx_lm/models/smollm3.rb +89 -0
  120. data/lib/mlx_lm/models/solar_open.rb +79 -0
  121. data/lib/mlx_lm/models/ssm.rb +162 -0
  122. data/lib/mlx_lm/models/stablelm.rb +160 -0
  123. data/lib/mlx_lm/models/starcoder2.rb +161 -0
  124. data/lib/mlx_lm/models/step3p5.rb +479 -0
  125. data/lib/mlx_lm/models/switch_layers.rb +221 -0
  126. data/lib/mlx_lm/models/telechat3.rb +192 -0
  127. data/lib/mlx_lm/models/youtu_llm.rb +230 -0
  128. data/lib/mlx_lm/models.rb +33 -0
  129. data/lib/mlx_lm/perplexity.rb +48 -0
  130. data/lib/mlx_lm/quantize.rb +131 -0
  131. data/lib/mlx_lm/sample_utils.rb +159 -0
  132. data/lib/mlx_lm/server.rb +190 -0
  133. data/lib/mlx_lm/tokenizer_utils.rb +158 -0
  134. data/lib/mlx_lm/tuner/lora.rb +165 -0
  135. data/lib/mlx_lm/version.rb +3 -0
  136. data/lib/mlx_lm/weight_utils.rb +170 -0
  137. data/lib/mlx_lm.rb +135 -0
  138. metadata +272 -0
@@ -0,0 +1,159 @@
+module MlxLm
+  module Models
+    module Gemma
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "gemma"
+        field :hidden_size, default: 3072
+        field :num_hidden_layers, default: 28
+        field :num_attention_heads, default: 16
+        field :num_key_value_heads, default: 16
+        field :intermediate_size, default: 24576
+        field :vocab_size, default: 256000
+        field :rms_norm_eps, default: 1e-6
+        field :rope_theta, default: 10000.0
+        field :rope_traditional, default: false
+        field :head_dim, default: 256
+        field :tie_word_embeddings, default: true
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = @head_dim**(-0.5)
+
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+          self.rope = MLX::NN::RoPE.new(
+            @head_dim,
+            traditional: args.rope_traditional,
+            base: args.rope_theta
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask)
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          hidden_dim = args.intermediate_size
+
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+        end
+
+        def call(x)
+          down_proj.call(MLX::NN.gelu(gate_proj.call(x)) * up_proj.call(x))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args)
+          self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + r
+          r = mlp.call(post_attention_layernorm.call(h))
+          h + r
+        end
+      end
+
+      class GemmaModel < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          mx = MLX::Core
+          h = embed_tokens.call(inputs)
+          # Gemma scales embeddings by sqrt(hidden_size)
+          h = h * Math.sqrt(@args.hidden_size)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model = GemmaModel.new(args)
+          unless args.tie_word_embeddings
+            self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false)
+          end
+        end
+
+        def call(inputs, cache: nil)
+          out = model.call(inputs, cache: cache)
+          if @args.tie_word_embeddings
+            model.embed_tokens.as_linear(out)
+          else
+            lm_head.call(out)
+          end
+        end
+
+        def sanitize(weights)
+          weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("gemma", Model, ModelArgs)
+    end
+  end
+end
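Note on gemma.rb above: tie_word_embeddings defaults to true, so Model#call produces logits by reusing the token-embedding matrix via embed_tokens.as_linear instead of a separate lm_head. The snippet below is a standalone plain-Ruby sketch (not part of the gem) of what that tied head amounts to, assuming as_linear multiplies by the transposed embedding matrix:

# Toy illustration of the tied-embedding output head used in Model#call above.
# Assumption: embed_tokens.as_linear(out) is equivalent to out @ W_embed^T.
embed  = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # vocab_size = 3, hidden_size = 2
hidden = [0.7, 0.8]                            # hidden state for one position

logits = embed.map { |row| row.zip(hidden).sum { |w, h| w * h } }
# => approximately [0.23, 0.53, 0.83] -- one logit per vocabulary entry

Because the same matrix serves as both input embedding and output projection, sanitize only needs to drop the unused self_attn.rotary_emb.inv_freq buffers when loading converted weights.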
@@ -0,0 +1,198 @@
+module MlxLm
+  module Models
+    module Gemma2
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "gemma2"
+        field :hidden_size, default: 3072
+        field :num_hidden_layers, default: 28
+        field :num_attention_heads, default: 16
+        field :num_key_value_heads, default: 16
+        field :intermediate_size, default: 24576
+        field :vocab_size, default: 256000
+        field :head_dim, default: 256
+        field :rms_norm_eps, default: 1e-6
+        field :rope_theta, default: 10000.0
+        field :rope_traditional, default: false
+        field :attn_logit_softcapping, default: 50.0
+        field :final_logit_softcapping, default: 30.0
+        field :query_pre_attn_scalar, default: 144.0
+
+        def initialize(**kwargs)
+          super
+          @num_key_value_heads ||= @num_attention_heads
+        end
+      end
+
+      # Gemma2 custom RMSNorm: uses (1 + weight) instead of weight
+      class Gemma2RMSNorm < MLX::NN::Module
+        def initialize(dims, eps: 1e-6)
+          super()
+          self.weight = MLX::Core.ones([dims])
+          @eps = eps
+        end
+
+        def call(x)
+          mx = MLX::Core
+          # RMS normalization: x / sqrt(mean(x^2) + eps) * (1 + weight)
+          x_sq = x * x
+          mean_sq = mx.mean(x_sq, -1, keepdims: true)
+          norm = x * mx.rsqrt(mean_sq + @eps)
+          norm * (weight + 1.0)
+        end
+      end
+
+      class Attention < MLX::NN::Module
+        def initialize(args)
+          super()
+          dim = args.hidden_size
+          @n_heads = args.num_attention_heads
+          @n_kv_heads = args.num_key_value_heads
+          @head_dim = args.head_dim
+          @scale = 1.0 / (args.query_pre_attn_scalar**0.5)
+          @attn_logit_softcapping = args.attn_logit_softcapping
+
+          self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false)
+          self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+          self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false)
+          self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false)
+
+          self.rope = MLX::NN::RoPE.new(
+            @head_dim,
+            traditional: args.rope_traditional,
+            base: args.rope_theta
+          )
+        end
+
+        def call(x, mask: nil, cache: nil)
+          mx = MLX::Core
+          b, l, _d = x.shape
+
+          queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3])
+          keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+          values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3])
+
+          if cache
+            queries = rope.call(queries, offset: cache.offset)
+            keys = rope.call(keys, offset: cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+          else
+            queries = rope.call(queries)
+            keys = rope.call(keys)
+          end
+
+          # Custom attention with softcapping
+          queries = queries * @scale
+
+          # Manual attention computation for softcapping
+          scores = mx.matmul(queries, mx.transpose(keys, [0, 1, 3, 2]))
+
+          # Apply attention logit softcapping
+          scores = mx.tanh(scores / @attn_logit_softcapping) * @attn_logit_softcapping
+
+          # Apply causal mask
+          if mask == "causal"
+            n = scores.shape[-1]
+            causal_mask = mx.triu(mx.full([n, n], -Float::INFINITY), 1)
+            scores = scores + causal_mask
+          end
+
+          scores = mx.softmax(scores, -1)
+          output = mx.matmul(scores, values)
+
+          output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim])
+          o_proj.call(output)
+        end
+      end
+
+      class MLP < MLX::NN::Module
+        def initialize(dim, hidden_dim)
+          super()
+          self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+          self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false)
+          self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false)
+        end
+
+        def call(x)
+          # Gemma2 uses gelu_approx instead of silu
+          down_proj.call(MLX::NN.gelu_approx(gate_proj.call(x)) * up_proj.call(x))
+        end
+      end
+
+      class TransformerBlock < MLX::NN::Module
+        def initialize(args)
+          super()
+          self.self_attn = Attention.new(args)
+          self.mlp = MLP.new(args.hidden_size, args.intermediate_size)
+          # Gemma2 has 4 norms per block
+          self.input_layernorm = Gemma2RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_attention_layernorm = Gemma2RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.pre_feedforward_layernorm = Gemma2RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+          self.post_feedforward_layernorm = Gemma2RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(x, mask: nil, cache: nil)
+          r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache)
+          h = x + post_attention_layernorm.call(r)
+          r = mlp.call(pre_feedforward_layernorm.call(h))
+          h + post_feedforward_layernorm.call(r)
+        end
+      end
+
+      class Gemma2Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size)
+          self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) }
+          self.norm = Gemma2RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps)
+        end
+
+        def call(inputs, cache: nil)
+          mx = MLX::Core
+          h = embed_tokens.call(inputs)
+          # Gemma2 scales embeddings by sqrt(hidden_size)
+          h = h * Math.sqrt(@args.hidden_size)
+          layer_cache = cache || [nil] * layers.length
+
+          mask = nil
+          mask = "causal" if h.shape[1] > 1
+
+          layers.each_with_index do |layer, i|
+            h = layer.call(h, mask: mask, cache: layer_cache[i])
+          end
+
+          norm.call(h)
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          @final_logit_softcapping = args.final_logit_softcapping
+          self.model = Gemma2Model.new(args)
+        end
+
+        def call(inputs, cache: nil)
+          mx = MLX::Core
+          out = model.call(inputs, cache: cache)
+          # Tied embeddings
+          out = model.embed_tokens.as_linear(out)
+          # Final logit softcapping
+          out = mx.tanh(out / @final_logit_softcapping) * @final_logit_softcapping
+          out
+        end
+
+        def sanitize(weights)
+          weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") }
+        end
+
+        def layers
+          model.layers
+        end
+      end
+
+      Models.register("gemma2", Model, ModelArgs)
+    end
+  end
+end
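Note on gemma2.rb above: the query scale is 1 / sqrt(query_pre_attn_scalar) = 1 / 12 rather than head_dim ** -0.5, and a tanh soft cap is applied twice, to the attention scores (cap 50.0) and to the final logits (cap 30.0). A minimal standalone plain-Ruby sketch of the cap (not part of the gem):

# capped = cap * tanh(x / cap): near-identity for small x, saturates at +/- cap.
def softcap(x, cap)
  cap * Math.tanh(x / cap)
end

softcap(5.0, 50.0)    # => ~4.98, small scores pass through almost unchanged
softcap(500.0, 50.0)  # => ~50.0, large scores are squashed to the cap

Soft capping is also why this model computes attention by hand (matmul, tanh, softmax) instead of calling scaled_dot_product_attention as gemma.rb does.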
@@ -0,0 +1,85 @@
+require_relative "gemma3_text"
+
+module MlxLm
+  module Models
+    module Gemma3
+      class ModelArgs < BaseModelArgs
+        field :model_type, default: "gemma3"
+        field :text_config, default: nil
+        field :vocab_size, default: 262208
+
+        def self.from_dict(params)
+          has_text_config = params.key?("text_config") || params.key?(:text_config)
+          return super if has_text_config
+
+          model_type = params["model_type"] || params[:model_type] || "gemma3"
+          vocab_size = params["vocab_size"] || params[:vocab_size] || 262208
+          new(model_type: model_type, text_config: params, vocab_size: vocab_size)
+        end
+
+        def initialize(**kwargs)
+          super
+          @text_config = _stringify_keys(@text_config || {})
+          @text_config["vocab_size"] = @vocab_size
+          @text_config["num_attention_heads"] ||= 8
+          @text_config["num_key_value_heads"] ||= 4
+          @text_config["model_type"] ||= "gemma3_text"
+        end
+
+        private
+
+        def _stringify_keys(hash)
+          hash.each_with_object({}) do |(key, value), out|
+            out[key.to_s] = value
+          end
+        end
+      end
+
+      class Model < MLX::NN::Module
+        def initialize(args)
+          super()
+          @args = args
+          self.model_type = args.model_type
+          self.language_model = Gemma3Text::Model.new(
+            Gemma3Text::ModelArgs.from_dict(args.text_config)
+          )
+        end
+
+        def call(inputs, cache: nil, input_embeddings: nil)
+          language_model.call(
+            inputs,
+            cache: cache,
+            input_embeddings: input_embeddings
+          )
+        end
+
+        def sanitize(weights)
+          flat_weights = weights.is_a?(Hash) ? weights : weights.to_h
+          nested = MLX::Utils.tree_unflatten(flat_weights.to_a)
+
+          if nested.is_a?(Hash)
+            nested.delete("vision_tower")
+            nested.delete("multi_modal_projector")
+
+            language_tree = nested["language_model"] || {}
+            language_weights = MLX::Utils.tree_flatten(language_tree, destination: {})
+            sanitized_language = language_model.sanitize(language_weights)
+            nested["language_model"] = MLX::Utils.tree_unflatten(sanitized_language.to_a)
+          end
+
+          MLX::Utils.tree_flatten(nested, destination: {})
+        end
+
+        def layers
+          language_model.layers
+        end
+
+        def make_cache
+          language_model.make_cache
+        end
+      end
+
+      Models.register("gemma3", Model, ModelArgs)
+    end
+  end
+end
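Note on gemma3.rb above: the full Gemma 3 checkpoint is multimodal, so ModelArgs.from_dict accepts either a config with a nested text_config or a bare text-only config (which it wraps), and Model#sanitize prunes the vision weights before delegating to the inner Gemma3Text model. The snippet below is a standalone plain-Ruby sketch of the net effect of sanitize (the key names are illustrative; the real method round-trips through MLX::Utils.tree_unflatten and tree_flatten):

# Net effect of Model#sanitize: vision_tower.* and multi_modal_projector.* keys
# are dropped, language_model.* keys survive (minus the rotary_emb.inv_freq
# buffers removed by the inner text model's own sanitize).
weights = {
  "language_model.model.embed_tokens.weight"         => :kept,
  "vision_tower.patch_embedding.weight"              => :dropped,
  "multi_modal_projector.mm_input_projection.weight" => :dropped,
}

kept = weights.reject { |k, _| k.start_with?("vision_tower.", "multi_modal_projector.") }
# => {"language_model.model.embed_tokens.weight"=>:kept}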