fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
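The wheel installs its code under the top-level module mlx_vlm (see top_level.txt above), so the files listed here are normally exercised through that package's Python API. The following is a minimal, hedged sketch assuming this distribution keeps upstream mlx-vlm's load/generate interface; the model repository and image path are placeholders, not part of this package:

    from mlx_vlm import load, generate
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config

    # Placeholder checkpoint; any MLX-format VLM matching one of the
    # architectures under mlx_vlm/models/ above should work.
    model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"

    model, processor = load(model_path)
    config = load_config(model_path)

    # Build a single-image chat prompt, then run generation.
    prompt = apply_chat_template(processor, config, "Describe this image.", num_images=1)
    output = generate(model, processor, prompt, image=["path/to/image.png"], verbose=False)
    print(output)

The hunk below shows one of the added files in full.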
mlx_vlm/models/qwen3_vl_moe/language.py
@@ -0,0 +1,656 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from mlx_lm.models.switch_layers import SwitchGLU
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from ..cache import KVCache
+from .config import ModelConfig, TextConfig
+
+
+class Qwen3VLMoERotaryEmbedding:
+    def __init__(
+        self, dim, max_position_embeddings=2048, base=10000, rope_scaling=None
+    ):
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+
+        inv_freq = 1.0 / (
+            self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+        )
+        self.inv_freq = inv_freq
+
+        self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 20])
+
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THTHWHTHW...TT], preserving frequency continuity.
+        args:
+            x: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+        returns:
+            x_t: (bs, seq_len, head_dim // 2)
+        """
+        freqs_t = freqs[0]  # just overwrite the first dimension T
+        for dim, offset in enumerate((1, 2), start=1):  # H, W
+            length = mrope_section[dim] * 3
+            idx = slice(offset, length, 3)
+            freqs_t[..., idx] = freqs[dim, ..., idx]
+        return freqs_t
+
+    def __call__(self, x, position_ids):
+
+        # In contrast to other models, Qwen3VLMoe has different position ids for the grids
+        # So we expand the inv_freq to shape (3, ...)
+        if position_ids.ndim == 2:
+            position_ids = mx.broadcast_to(
+                position_ids[None, ...],
+                (3, position_ids.shape[0], position_ids.shape[1]),
+            )
+
+        inv_freq_expanded = mx.broadcast_to(
+            self.inv_freq[None, None, :, None].astype(mx.float32),
+            (3, position_ids.shape[1], self.inv_freq.shape[0], 1),
+        )
+        position_ids_expanded = position_ids[:, :, None, :].astype(
+            mx.float32
+        )  # shape (3, bs, 1, positions)
+
+        freqs = inv_freq_expanded @ position_ids_expanded
+        freqs = mx.swapaxes(freqs, 2, 3)
+        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+        emb = mx.concatenate([freqs, freqs], axis=-1)
+        cos = mx.cos(emb)
+        sin = mx.sin(emb)
+
+        return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_multimodal_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    """
+    Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors.
+    Args:
+        q (mx.array): The query tensor.
+        k (mx.array): The key tensor.
+        cos (mx.array): The cosine part of the rotary embedding.
+        sin (mx.array): The sine part of the rotary embedding.
+        unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.
+    Returns:
+        tuple(mx.array): The rotated query and key tensors.
+    """
+
+    cos = mx.expand_dims(cos, axis=unsqueeze_dim)
+    sin = mx.expand_dims(sin, axis=unsqueeze_dim)
+
+    # Apply rotary embedding
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+
+    return q_embed, k_embed
+
+
+class Attention(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        self.head_dim = head_dim = getattr(
+            args, "head_dim", args.hidden_size // args.num_attention_heads
+        )
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+        self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+
+        self.rope_scaling = args.rope_scaling
+
+        self.rotary_emb = Qwen3VLMoERotaryEmbedding(
+            head_dim,
+            max_position_embeddings=args.max_position_embeddings,
+            base=args.rope_theta,
+            rope_scaling=self.rope_scaling,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = self.q_norm(
+            queries.reshape(B, L, self.n_heads, self.head_dim)
+        ).transpose(0, 2, 1, 3)
+        keys = self.k_norm(
+            keys.reshape(B, L, self.n_kv_heads, self.head_dim)
+        ).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+            0, 2, 1, 3
+        )
+
+        kv_seq_len = keys.shape[-2]
+
+        if position_ids is None:
+            kv_seq_len += cache.offset + 1
+            position_ids = mx.arange(cache.offset, cache.offset + L)
+            position_ids = mx.expand_dims(position_ids, axis=0)
+            position_ids = mx.tile(position_ids, (3, 1, 1))
+        else:
+            kv_seq_len += cache.offset + 1 if cache is not None else 0
+
+        cos, sin = self.rotary_emb(values, position_ids)
+
+        if mask is not None and isinstance(mask, mx.array):
+            mask = mask[..., :kv_seq_len]
+
+        queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin)
+
+        if cache is not None:
+            keys, values = cache.update_and_fetch(keys, values)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Qwen3MoeSparseMoeBlock(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        dim = args.hidden_size
+        intermediate_size = args.moe_intermediate_size
+
+        self.num_experts = num_experts = args.num_experts
+        self.top_k = args.num_experts_per_tok
+        self.norm_topk_prob = args.norm_topk_prob
+
+        self.gate = nn.Linear(dim, num_experts, bias=False)
+        self.switch_mlp = SwitchGLU(dim, intermediate_size, num_experts)
+
+    def __call__(
+        self,
+        x: mx.array,
+    ):
+        gates = self.gate(x)
+        gates = mx.softmax(gates, axis=-1, precise=True)
+
+        k = self.top_k
+        inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:]
+        scores = mx.take_along_axis(gates, inds, axis=-1)
+        if self.norm_topk_prob:
+            scores /= mx.sum(scores, axis=-1, keepdims=True)
+
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+
+        return y
+
+
+class Qwen3VLMoEDecoderLayer(nn.Module):
+    def __init__(self, args: TextConfig, layer_idx: int):
+        super().__init__()
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.args = args
+
+        if (layer_idx not in args.mlp_only_layers) and (
+            args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0
+        ):
+            self.mlp = Qwen3MoeSparseMoeBlock(args)
+        else:
+            self.mlp = MLP(args.hidden_size, args.intermediate_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class Qwen3VLMoEModel(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            Qwen3VLMoEDecoderLayer(args=args, layer_idx=layer_idx)
+            for layer_idx in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        position_ids: Optional[mx.array] = None,
+        # args for deepstack
+        visual_pos_masks: Optional[mx.array] = None,
+        deepstack_visual_embeds: Optional[mx.array] = None,
+    ):
+        if inputs_embeds is None:
+            h = self.embed_tokens(inputs)
+        else:
+            h = inputs_embeds
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        for layer_idx, (layer, c) in enumerate(zip(self.layers, cache)):
+            h = layer(h, mask, c, position_ids)
+
+            # Add deepstack visual embeds
+            # add visual features to the hidden states of first several layers
+            if deepstack_visual_embeds is not None and layer_idx in range(
+                len(deepstack_visual_embeds)
+            ):
+                h = self._deepstack_process(
+                    h,
+                    visual_pos_masks,
+                    deepstack_visual_embeds[layer_idx],
+                )
+
+        return self.norm(h)
+
+    def _deepstack_process(
+        self,
+        hidden_states: mx.array,
+        visual_pos_masks: mx.array,
+        visual_embeds: mx.array,
+    ):
+        batch_size = hidden_states.shape[0]
+
+        updated_batches = []
+        for b in range(batch_size):
+            batch_mask = visual_pos_masks[b]
+            batch_hidden = hidden_states[b]
+
+            batch_indices = mx.array(np.where(batch_mask)[0], dtype=mx.uint32)
+
+            if len(batch_indices) == 0:
+                updated_batches.append(batch_hidden)
+                continue
+
+            batch_result = mx.array(batch_hidden)  # avoid modifying in-place
+            batch_result = batch_result.at[batch_indices].add(visual_embeds)
+
+            updated_batches.append(batch_result)
+
+        return mx.stack(updated_batches, axis=0)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, args: TextConfig, config: ModelConfig = None):
+        super().__init__()
+        self.args = args
+        self.config = config
+        self.model_type = args.model_type
+        self.model = Qwen3VLMoEModel(args)
+        self._rope_deltas = None
+        self._position_ids = None
+
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def get_rope_index(
+        self,
+        input_ids: mx.array,
+        image_grid_thw: Optional[mx.array] = None,
+        video_grid_thw: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+    ):
+        # Calculate RoPE index for image/video tokens
+        batch_size, seq_length = input_ids.shape
+        position_ids = mx.arange(seq_length, dtype=mx.int32)
+        position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+        spatial_merge_size = self.config.vision_config.spatial_merge_size
+        image_token_id = self.config.image_token_id
+        video_token_id = self.config.video_token_id
+        vision_start_token_id = self.config.vision_start_token_id
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = mx.ones_like(input_ids)
+            position_ids = mx.ones(
+                (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+            )
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                input_ids = mx.where(
+                    attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                )
+                image_nums, video_nums = 0, 0
+                vision_start_indices = mx.sum(
+                    mx.where(
+                        input_ids == vision_start_token_id,
+                        mx.arange(input_ids.shape[0]),
+                        mx.zeros_like(input_ids),
+                    )
+                )
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum().item()
+                video_nums = (vision_tokens == video_token_id).sum().item()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    index = mx.arange(text_len).reshape(1, text_len)
+                    index = mx.broadcast_to(index, (3, text_len))
+                    index = index + st_idx
+                    llm_pos_ids_list.append(index)
+                    t_index = mx.arange(llm_grid_t).reshape(
+                        llm_grid_t, 1
+                    )  # Equivalent to .view(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                    )  # Equivalent to expand()
+                    t_index = t_index.flatten()  # Flattens to 1D
+
+                    h_index = mx.arange(llm_grid_h).reshape(
+                        1, llm_grid_h, 1
+                    )  # Equivalent to .view(1, -1)
+                    h_index = mx.broadcast_to(
+                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    )  # Equivalent to expand()
+                    h_index = h_index.flatten()  # Flattens to 1D
+
+                    w_index = mx.arange(llm_grid_w).reshape(
+                        1, 1, llm_grid_w
+                    )  # Equivalent to .view(1, -1)
+                    w_index = mx.broadcast_to(
+                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    )  # Equivalent to expand()
+                    w_index = w_index.flatten()  # Flattens to 1D
+
+                    llm_pos_ids_list.append(
+                        mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+
+                    t_index = mx.arange(text_len).reshape(
+                        1, text_len
+                    )  # Equivalent to .view(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (3, text_len)
+                    )  # Equivalent to expand(3, -1)
+
+                    llm_pos_ids_list.append(t_index + st_idx)
+
+                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                mask = mx.array(attention_mask[i] == 1)
+                expanded_mask = mx.expand_dims(mask, axis=0)
+                expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                new_positions = mx.where(
+                    expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                )
+                updated_position_ids = mx.concatenate(
+                    [
+                        position_ids[:, :i, :],
+                        new_positions,
+                        position_ids[:, i + 1 :, :],
+                    ],
+                    axis=1,
+                )
+                position_ids = updated_position_ids
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                position_ids = mx.where(
+                    attention_mask == 0, mx.ones_like(position_ids), position_ids
+                )
+                position_ids = mx.expand_dims(position_ids[0], axis=0)
+                position_ids = mx.tile(position_ids, (3, 1, 1))
+                max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                    -1, keepdims=True
+                )[0]
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+            else:
+                position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                )
+                mrope_position_deltas = mx.zeros(
+                    [input_ids.shape[0], 1],
+                    dtype=input_ids.dtype,
+                )
+            return position_ids, mrope_position_deltas
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        # args for deepstack
+        visual_pos_masks: Optional[mx.array] = None,
+        deepstack_visual_embeds: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        # Slicing visual_pos_masks when prefilling
+        n_to_process = kwargs.get("n_to_process", None)
+        if n_to_process is not None:
+            visual_pos_masks = visual_pos_masks[:, n_to_process:]
+
+        position_ids = kwargs.pop("position_ids", None)
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_grid_thw = kwargs.pop("image_grid_thw", None)
+        video_grid_thw = kwargs.pop("video_grid_thw", None)
+        # reset rope_deltas when processing a new image/video
+        if pixel_values is not None:
+            self._rope_deltas = None
+
+        cache_offset = 0
+        if cache and cache[0] is not None:
+            offset = cache[0].offset
+            if isinstance(offset, int):
+                cache_offset = offset
+            elif isinstance(offset, mx.array):
+                cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+            else:
+                raise ValueError(f"Unexpected cache offset type: {type(offset)}")
+
+        # Check if mask shape matches input shape (for chunked prefill compatibility)
+        rope_mask = mask
+        if mask is not None and mask.shape[-1] != inputs.shape[-1]:
+            rope_mask = None
+
+        if position_ids is None and (rope_mask is None or rope_mask.ndim == 2):
+            # Calculate RoPE index once per generation in the pre-fill stage only
+            if (
+                (cache is not None and cache[0] is not None and (cache_offset == 0))
+                or self._rope_deltas is None
+                or cache is None
+            ):
+                if self._position_ids is not None:
+                    seq_length = inputs.shape[1]
+                    position_ids = self._position_ids[
+                        :, :, cache_offset : cache_offset + seq_length
+                    ]
+                else:
+                    position_ids, rope_deltas = self.get_rope_index(
+                        inputs, image_grid_thw, video_grid_thw, rope_mask
+                    )
+                    self._rope_deltas = rope_deltas
+                    self._position_ids = position_ids
+            else:
+                # Use the prev pre-calculated rope-deltas to get the correct position ids
+                batch_size, seq_length = inputs.shape
+                delta = mx.array(
+                    cache_offset + self._rope_deltas if cache is not None else 0
+                )
+                position_ids = mx.arange(seq_length).reshape(1, -1)
+                position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+
+                if cache_offset is not None:
+                    if delta.ndim == 0:
+                        delta = mx.expand_dims(delta, axis=0)
+
+                    if delta.shape[0] < batch_size:
+                        delta = mx.tile(delta, (batch_size, 1))
+                    else:
+                        # Slice delta to match batch
+                        delta = delta[:batch_size]
+
+                position_ids = mx.add(position_ids, delta)[None, ...]
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, batch_size, seq_length)
+                )
+
+        out = self.model(
+            inputs,
+            cache=cache,
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            visual_pos_masks=visual_pos_masks,
+            deepstack_visual_embeds=deepstack_visual_embeds,
+        )
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return LanguageModelOutput(logits=out)
+
+    def sanitize(self, weights):
+        for l in range(self.args.num_hidden_layers):
+            prefix = f"language_model.model.layers.{l}.mlp"
+            # Only sanitize MoE layer weights
+            if f"{prefix}.experts.up_proj" in weights:
+                for n in ["up_proj", "down_proj", "gate_proj"]:
+                    to_join = weights.pop(f"{prefix}.experts.{n}")
+                    weights[f"{prefix}.switch_mlp.{n}.weight"] = to_join
+        return weights
+
+    @property
+    def quant_predicate(self):
+        def predicate(path, _):
+            if path.endswith("mlp.gate"):
+                return {"group_size": 64, "bits": 8}
+            return True
+
+        return predicate
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
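To make the interleaved-MRoPE layout implemented by Qwen3VLMoERotaryEmbedding.apply_interleaved_mrope above concrete, here is a small standalone sketch; the helper body is copied from the diff, while the head dimension and mrope_section values are assumptions matching the defaults used there. Tagging the temporal, height, and width frequency grids with the constants 0, 1, and 2 makes the reorganized layout visible:

    import mlx.core as mx

    def apply_interleaved_mrope(freqs, mrope_section):
        # Copied from the diff above: start from the temporal (T) frequencies,
        # then overwrite every third slot with H and W frequencies respectively.
        freqs_t = freqs[0]
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

    head_half = 64  # assumed head_dim // 2 for head_dim == 128
    # Shape (3, bs=1, seq=1, head_half): grid 0 -> T, 1 -> H, 2 -> W.
    freqs = mx.stack([mx.full((1, 1, head_half), float(i)) for i in range(3)])

    out = apply_interleaved_mrope(freqs, [24, 20, 20])
    print(out[0, 0].tolist()[:9])   # [0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0]
    print(out[0, 0].tolist()[-4:])  # [0.0, 0.0, 0.0, 0.0]

With the assumed section sizes [24, 20, 20] and 64 frequency slots, the first 60 slots cycle T, H, W per frequency and the last 4 stay purely temporal, which is the chunked-to-interleaved reorganization described in the method's docstring.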