fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/gemma3n/language.py
@@ -0,0 +1,631 @@
+ import math
+ from functools import partial
+ from typing import Any, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache, RotatingKVCache
+ from .config import TextConfig
+
+
+ class Gemma3nRMSNorm(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         eps: float = 1e-6,
+         scale_shift: float = 0.0,
+         with_scale: bool = True,
+     ):
+         super().__init__()
+         self.eps = eps
+         self.scale_shift = scale_shift
+         self.with_scale = with_scale
+
+         if self.with_scale:
+             # Make weight a proper parameter
+             self.weight = mx.ones(dim)
+         else:
+             self.weight = None
+
+     def _norm(self, x):
+         # Match PyTorch's normalization exactly
+         return x * mx.rsqrt(x.square().mean(axis=-1, keepdims=True) + self.eps)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         # Match PyTorch implementation
+         output = self._norm(x.astype(mx.float32))
+
+         if self.with_scale:
+             output = output * (self.weight + self.scale_shift)
+
+         return output.astype(x.dtype)
+
+
+ class RMSNoScale(nn.Module):
+     def __init__(self, eps: float = 1e-5):
+         super().__init__()
+         self.eps = eps
+
+     def __call__(self, x):
+         return mx.fast.rms_norm(x, None, self.eps)
+
+
+ class Gemma3nLaurelBlock(nn.Module):
+     """Learned Augmented Residual Layer"""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+
+         self.linear_left = nn.Linear(
+             self.config.hidden_size, self.config.laurel_rank, bias=False
+         )
+         self.linear_right = nn.Linear(
+             self.config.laurel_rank, self.config.hidden_size, bias=False
+         )
+         self.post_laurel_norm = nn.RMSNorm(
+             dims=self.config.hidden_size,
+             eps=self.config.rms_norm_eps,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         laurel_x = self.linear_left(x)
+         laurel_x = self.linear_right(laurel_x)
+         normed_laurel_x = self.post_laurel_norm(laurel_x)
+         return x + normed_laurel_x
+
+
+ class Gemma3nAttention(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int, is_kv_shared_layer: bool):
+         super().__init__()
+         self.is_sliding = config.layer_types[layer_idx] == "sliding_attention"
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+         self.repeats = n_heads // n_kv_heads
+         self.head_dim = head_dim = config.head_dim
+         self.layer_idx = layer_idx
+
+         self.scale = 1.0
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         self.q_norm = nn.RMSNorm(dims=config.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = nn.RMSNorm(dims=config.head_dim, eps=config.rms_norm_eps)
+         self.v_norm = RMSNoScale(eps=config.rms_norm_eps)
+
+         self.is_kv_shared_layer = is_kv_shared_layer
+
+         self.rope = nn.RoPE(
+             head_dim,
+             traditional=False,
+             base=(
+                 config.rope_local_base_freq if self.is_sliding else config.rope_theta
+             ),
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+     ) -> mx.array:
+         B, L, _ = x.shape
+
+         queries = self.q_proj(x)
+         queries = queries.reshape(B, L, -1, self.head_dim)
+         queries = self.q_norm(queries)
+
+         offset = 0
+         if self.is_kv_shared_layer and cache is not None:
+             # For shared layers, retrieve KV from the designated cache layer
+             keys, values = cache.state
+             offset = cache.offset
+
+         else:
+
+             if cache is not None:
+                 offset = cache.offset
+
+             keys = self.k_proj(x).reshape(B, L, -1, self.head_dim)
+             keys = self.k_norm(keys)
+             keys = keys.transpose(0, 2, 1, 3)
+             keys = self.rope(keys, offset=offset)
+
+             values = self.v_proj(x).reshape(B, L, -1, self.head_dim)
+             values = self.v_norm(values)
+             values = values.transpose(0, 2, 1, 3)
+
+             if cache is not None:
+                 keys, values = cache.update_and_fetch(keys, values)
+
+         queries = queries.transpose(0, 2, 1, 3)
+         queries = self.rope(queries, offset=offset)
+
+         if isinstance(mask, mx.array) and mask.shape[-1] != keys.shape[-2]:
+             mask = mask[:, : keys.shape[-2]]
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache=cache, scale=self.scale, mask=mask
+         )
+
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+         return self.o_proj(output)
+
+
+ @partial(mx.compile, shapeless=True)
+ def gelu_topk(inputs, std_multiplier):
+     inputs_mean = mx.mean(inputs, axis=-1, keepdims=True)
+     inputs_std = mx.std(inputs, axis=-1, keepdims=True)
+     cutoff_x = inputs_mean + inputs_std * std_multiplier.astype(inputs_std.dtype)
+     return nn.gelu_approx(mx.maximum(0, inputs - cutoff_x))
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int = 0):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size[0], bias=False
+         )
+         self.up_proj = nn.Linear(
+             self.hidden_size, self.intermediate_size[0], bias=False
+         )
+         self.down_proj = nn.Linear(
+             self.intermediate_size[0], self.hidden_size, bias=False
+         )
+         if config.activation_sparsity_pattern is not None:
+             self.activation_sparsity = config.activation_sparsity_pattern[layer_idx]
+         else:
+             self.activation_sparsity = 0.0
+         if self.activation_sparsity > 0:
+             self._std_multiplier = math.sqrt(2.0) * mx.erfinv(
+                 2 * self.activation_sparsity - 1
+             )
+
+     def __call__(self, x: mx.array):
+         gate_proj = self.gate_proj(x)
+         if self.activation_sparsity > 0.0:
+             activations = gelu_topk(gate_proj, self._std_multiplier)
+         else:
+             activations = nn.gelu_approx(gate_proj)
+         up_proj = self.up_proj(x)
+         down_proj = self.down_proj(activations * up_proj)
+         return down_proj
+
+
+ class Gemma3nAltUp(nn.Module):
+     """Alternating Updates (AltUp)"""
+
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+
+         self.correct_output_scale = mx.zeros((self.config.hidden_size,))
+         self.correction_coefs = nn.Linear(
+             self.config.altup_num_inputs, self.config.altup_num_inputs, bias=False
+         )
+         self.prediction_coefs = nn.Linear(
+             self.config.altup_num_inputs, self.config.altup_num_inputs**2, bias=False
+         )
+         self.modality_router = nn.Linear(
+             self.config.hidden_size, self.config.altup_num_inputs, bias=False
+         )
+         self.router_norm = nn.RMSNorm(
+             dims=self.config.hidden_size,
+             eps=self.config.rms_norm_eps,
+         )
+
+     def compute_router_modalities(self, x: mx.array) -> mx.array:
+         router_inputs = self.router_norm(x) * (self.config.hidden_size**-1.0)
+         routed = self.modality_router(router_inputs).astype(mx.float32)
+         return mx.tanh(routed)
+
+     def predict(self, x: mx.array) -> mx.array:
+         modalities = self.compute_router_modalities(x[self.config.altup_active_idx])
+
+         self.prediction_coefs.weight = self.prediction_coefs.weight.astype(mx.float32)
+
+         if self.config.altup_coef_clip is not None:
+             self.prediction_coefs.weight = mx.clip(
+                 self.prediction_coefs.weight,
+                 -self.config.altup_coef_clip,
+                 self.config.altup_coef_clip,
+             )
+
+         all_coefs = (
+             self.prediction_coefs(modalities)
+             .reshape(
+                 *modalities.shape[:-1],
+                 self.config.altup_num_inputs,
+                 self.config.altup_num_inputs,
+             )
+             .transpose(0, 1, 3, 2)
+         )
+
+         x_up = x.astype(mx.float32)
+         x_permuted = x_up.transpose(1, 2, 3, 0)
+         predictions = mx.matmul(x_permuted, all_coefs)
+         predictions = predictions.transpose(3, 0, 1, 2)
+         predictions += x_up
+         return predictions.astype(x.dtype)
+
+     def correct(self, predictions: mx.array, activated: mx.array):
+         modalities = self.compute_router_modalities(activated)
+
+         self.correction_coefs.weight = self.correction_coefs.weight.astype(mx.float32)
+
+         if self.config.altup_coef_clip is not None:
+             self.correction_coefs.weight = mx.clip(
+                 self.correction_coefs.weight,
+                 -self.config.altup_coef_clip,
+                 self.config.altup_coef_clip,
+             )
+
+         all_coefs = self.correction_coefs(modalities) + 1.0
+
+         active_x = predictions[self.config.altup_active_idx]
+         innovation = activated - active_x
+
+         all_coefs = all_coefs.transpose(2, 1, 0)
+         corrected = innovation[None] * all_coefs[:, None]
+         corrected += predictions
+
+         return corrected.astype(activated.dtype)
+
+
+ class Gemma3nDecoderLayer(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int, is_kv_shared_layer: bool):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.layer_idx = layer_idx
+         self.self_attn = Gemma3nAttention(config, layer_idx, is_kv_shared_layer)
+         self.mlp = MLP(config, layer_idx=layer_idx)
+         self.input_layernorm = nn.RMSNorm(
+             self.hidden_size,
+             eps=config.rms_norm_eps,
+         )
+
+         self.post_attention_layernorm = nn.RMSNorm(
+             self.hidden_size,
+             eps=config.rms_norm_eps,
+         )
+         self.pre_feedforward_layernorm = nn.RMSNorm(
+             self.hidden_size,
+             eps=config.rms_norm_eps,
+         )
+         self.post_feedforward_layernorm = nn.RMSNorm(
+             self.hidden_size,
+             eps=config.rms_norm_eps,
+         )
+         self.is_sliding = self.self_attn.is_sliding
+         self.sliding_window = config.sliding_window
+
+         self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+
+         self.altup = Gemma3nAltUp(config)
+         self.laurel = Gemma3nLaurelBlock(config)
+         self.per_layer_input_gate = nn.Linear(
+             self.hidden_size, self.hidden_size_per_layer_input, bias=False
+         )
+         self.per_layer_projection = nn.Linear(
+             self.hidden_size_per_layer_input, self.hidden_size, bias=False
+         )
+         self.post_per_layer_input_norm = nn.RMSNorm(
+             self.hidden_size,
+             eps=config.rms_norm_eps,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Any] = None,
+         per_layer_input: Optional[mx.array] = None,
+     ):
+         predictions = self.altup.predict(x)
+         active_prediction = predictions[self.config.altup_active_idx]
+
+         active_prediction_normed = self.input_layernorm(active_prediction)
+         laurel_output = self.laurel(active_prediction_normed)
+
+         attn = self.self_attn(
+             active_prediction_normed,
+             mask,
+             cache,
+         )
+
+         attn = self.post_attention_layernorm(attn)
+
+         attn_gated = active_prediction + attn
+         attn_laurel = (attn_gated + laurel_output) * (2.0**-0.5)
+
+         attn_norm = self.pre_feedforward_layernorm(attn_laurel)
+         attn_ffw = self.mlp(attn_norm)
+         attn_ffw_norm = self.post_feedforward_layernorm(attn_ffw)
+         attn_ffw_laurel_gated = attn_laurel + attn_ffw_norm
+
+         corrected_predictions = self.altup.correct(predictions, attn_ffw_laurel_gated)
+
+         first_prediction = corrected_predictions[self.config.altup_active_idx]
+         if self.config.altup_correct_scale:
+             first_prediction = first_prediction * self.altup.correct_output_scale
+
+         first_prediction = self.per_layer_input_gate(first_prediction)
+         first_prediction = nn.gelu_approx(first_prediction)
+
+         first_prediction = mx.multiply(first_prediction, per_layer_input)
+
+         first_prediction = self.per_layer_projection(first_prediction)
+         first_prediction = self.post_per_layer_input_norm(first_prediction)
+
+         corrected_predictions[1:] = corrected_predictions[1:] + first_prediction
+
+         return corrected_predictions
+
+
+ class Gemma3Model(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+         self.vocab_size = config.vocab_size
+         self.vocab_size_per_layer_input = config.vocab_size_per_layer_input
+         self.num_hidden_layers = config.num_hidden_layers
+         self.first_kv_shared_layer_idx = (
+             config.num_hidden_layers - config.num_kv_shared_layers
+         )
+
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             Gemma3nDecoderLayer(
+                 config=config,
+                 layer_idx=layer_idx,
+                 is_kv_shared_layer=layer_idx >= self.first_kv_shared_layer_idx,
+             )
+             for layer_idx in range(config.num_hidden_layers)
+         ]
+
+         self.embed_tokens_per_layer = nn.Embedding(
+             config.vocab_size_per_layer_input,
+             config.num_hidden_layers * config.hidden_size_per_layer_input,
+         )
+
+         self.per_layer_model_projection = nn.Linear(
+             config.hidden_size,
+             config.num_hidden_layers * config.hidden_size_per_layer_input,
+             bias=False,
+         )
+
+         self.per_layer_projection_norm = nn.RMSNorm(
+             dims=config.hidden_size_per_layer_input,
+             eps=config.rms_norm_eps,
+         )
+
+         self.altup_projections = [
+             nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+             for _ in range(1, self.config.altup_num_inputs)
+         ]
+
+         self.altup_unembed_projections = [
+             nn.Linear(config.hidden_size, config.hidden_size, bias=False)
+             for _ in range(1, self.config.altup_num_inputs)
+         ]
+
+         self.norm = nn.RMSNorm(
+             config.hidden_size,
+             eps=config.rms_norm_eps,
+         )
+
+         self.first_sliding_idx = self.config.layer_types.index("sliding_attention")
+         self.first_full_idx = self.config.layer_types.index("full_attention")
+
+         concrete_layers = self.config.layer_types[: self.first_kv_shared_layer_idx]
+         shared_full_idx = (
+             len(concrete_layers) - 1 - concrete_layers[::-1].index("full_attention")
+         )
+         shared_sliding_idx = (
+             len(concrete_layers) - 1 - concrete_layers[::-1].index("sliding_attention")
+         )
+
+         self.layer_idx_to_cache_idx = []
+         for i, layer_type in enumerate(self.config.layer_types):
+             if i < self.first_kv_shared_layer_idx:
+                 self.layer_idx_to_cache_idx.append(i)
+             else:
+                 if layer_type == "full_attention":
+                     self.layer_idx_to_cache_idx.append(shared_full_idx)
+                 elif layer_type == "sliding_attention":
+                     self.layer_idx_to_cache_idx.append(shared_sliding_idx)
+                 else:
+                     raise NotImplementedError(f"Unknown layer type: {layer_type}")
+
+     def __call__(
+         self,
+         inputs: mx.array = None,
+         inputs_embeds: mx.array = None,
+         mask: mx.array = None,
+         cache=None,
+         **kwargs,
+     ):
+         per_layer_inputs = kwargs.pop("per_layer_inputs", None)
+         n_to_process = kwargs.pop("n_to_process", None)
+         if per_layer_inputs is not None and n_to_process is not None:
+             per_layer_inputs = per_layer_inputs[:, :n_to_process]
+
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs) * (self.hidden_size**0.5)
+         else:
+             h = inputs_embeds
+
+         if per_layer_inputs is None and inputs is not None:
+             per_layer_inputs = self.get_per_layer_inputs(inputs)
+
+         per_layer_inputs = self.project_per_layer_inputs(h, per_layer_inputs)
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             full_mask = create_attention_mask(
+                 h,
+                 cache[self.first_full_idx :],
+             )
+             sliding_window_mask = create_attention_mask(
+                 h,
+                 cache[self.first_sliding_idx :],
+             )
+         h0 = h
+
+         # Expand hidden_states to support per-layer inputs
+         target_magnitude = mx.mean(h0**2, axis=-1, keepdims=True) ** 0.5
+
+         h_list = [h0]
+         h_list.extend([proj(h0) for proj in self.altup_projections])
+         h = mx.stack(h_list, axis=0)
+         mags = mx.mean(h[1:] ** 2, axis=-1, keepdims=True) ** 0.5
+         h[1:] = h[1:] * (target_magnitude / mx.maximum(mags, mx.finfo(h0.dtype).min))
+
+         for i, layer in enumerate(self.layers):
+             per_layer_input = per_layer_inputs[:, :, i, :]
+
+             is_global = self.config.layer_types[i] == "full_attention"
+
+             local_mask = mask
+             if mask is None and is_global:
+                 local_mask = full_mask
+             elif mask is None:
+                 local_mask = sliding_window_mask
+
+             h = layer(
+                 h,
+                 local_mask,
+                 cache[self.layer_idx_to_cache_idx[i]],
+                 per_layer_input,
+             )
+
+         # Per-layer inputs to single output
+         target_magnitude = mx.mean(h[0] ** 2, axis=-1, keepdims=True) ** 0.5
+         for i, proj in enumerate(self.altup_unembed_projections):
+             h[i + 1] = proj(h[i + 1])
+         mags = mx.mean(h[1:] ** 2, axis=-1, keepdims=True) ** 0.5
+         h[1:] = h[1:] * (target_magnitude / mx.maximum(mags, mx.finfo(h0.dtype).min))
+
+         h = mx.mean(h, axis=0)
+
+         return self.norm(h)
+
+     def get_per_layer_inputs(self, input_ids: mx.array) -> mx.array:
+         per_layer_inputs_mask = input_ids < self.vocab_size_per_layer_input
+         tokens = mx.where(per_layer_inputs_mask, input_ids, mx.zeros_like(input_ids))
+         result = self.embed_tokens_per_layer(tokens) * (
+             self.hidden_size_per_layer_input**0.5
+         )
+         return result.reshape(
+             *input_ids.shape,
+             self.num_hidden_layers,
+             self.hidden_size_per_layer_input,
+         )
+
+     def project_per_layer_inputs(
+         self,
+         inputs_embeds: mx.array,
+         per_layer_inputs: mx.array,
+     ) -> mx.array:
+         per_layer_projection = self.per_layer_model_projection(inputs_embeds) * (
+             self.hidden_size**-0.5
+         )
+         per_layer_projection = per_layer_projection.reshape(
+             *inputs_embeds.shape[:-1],
+             self.config.num_hidden_layers,
+             self.config.hidden_size_per_layer_input,
+         )
+         per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
+         return (per_layer_projection + per_layer_inputs) * (2.0**-0.5)
+
+
+ @partial(mx.compile, shapeless=True)
+ def logit_softcap(softcap, x):
+     out = mx.tanh(x / softcap)
+     out = out * softcap
+     return out
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         self.model = Gemma3Model(config)
+         self.final_logit_softcapping = config.final_logit_softcapping
+
+     def __call__(
+         self,
+         inputs: mx.array = None,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         out = self.model(
+             inputs, inputs_embeds=inputs_embeds, mask=mask, cache=cache, **kwargs
+         )
+         out = self.model.embed_tokens.as_linear(out)
+         if self.final_logit_softcapping is not None:
+             out = logit_softcap(self.final_logit_softcapping, out)
+         return LanguageModelOutput(logits=out)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+
+         for k, v in weights.items():
+             if "language_model.model" not in k and "language_model.lm_head" not in k:
+                 new_key = k.replace("language_model", "language_model.model")
+                 sanitized_weights[new_key] = v
+             elif "self_attn.rotary_emb.inv_freq" in k:
+                 continue
+             else:
+                 sanitized_weights[k] = v
+         return sanitized_weights
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.head_dim
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
+
+     def make_cache(self):
+         caches = []
+         for layer_type in self.config.layer_types[
+             : self.model.first_kv_shared_layer_idx
+         ]:
+             if layer_type == "full_attention":
+                 caches.append(KVCache())
+             elif layer_type == "sliding_attention":
+                 caches.append(
+                     RotatingKVCache(max_size=self.config.sliding_window, keep=0)
+                 )
+             else:
+                 raise NotImplementedError(f"Unknown layer type: {layer_type}")
+         return caches