fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/qwen3_omni_moe/language.py
@@ -0,0 +1,622 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+ from mlx_lm.models.switch_layers import SwitchGLU
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import TextConfig, ThinkerConfig
+
+
+ class Qwen3OmniMoeThinkerTextRotaryEmbedding:
+     def __init__(
+         self, dim, max_position_embeddings=2048, base=10000, rope_scaling=None
+     ):
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+
+         inv_freq = 1.0 / (
+             self.base ** (mx.arange(0, self.dim, 2).astype(mx.float32) / self.dim)
+         )
+         self.inv_freq = inv_freq
+         self.attention_scaling = 1.0
+
+         rope_scaling = rope_scaling or {}
+         self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 20])
+
+     def apply_interleaved_mrope(self, freqs, mrope_section):
+         D = freqs.shape[-1]
+         indices = mx.arange(D)
+
+         freqs_t = freqs[0]
+
+         limit1 = mrope_section[1] * 3
+         mask1 = (indices % 3 == 1) & (indices < limit1)
+         freqs_t = mx.where(mask1, freqs[1], freqs_t)
+
+         limit2 = mrope_section[2] * 3
+         mask2 = (indices % 3 == 2) & (indices < limit2)
+         freqs_t = mx.where(mask2, freqs[2], freqs_t)
+
+         return freqs_t
+
+     def __call__(self, x, position_ids):
+
+         if position_ids.ndim == 2:
+             position_ids = mx.broadcast_to(
+                 position_ids[None, ...],
+                 (3, position_ids.shape[0], position_ids.shape[1]),
+             )
+
+         inv_freq_expanded = mx.broadcast_to(
+             self.inv_freq[None, None, :, None].astype(mx.float32),
+             (3, position_ids.shape[1], self.inv_freq.shape[0], 1),
+         )
+         position_ids_expanded = position_ids[:, :, None, :].astype(mx.float32)
+
+         freqs = inv_freq_expanded @ position_ids_expanded
+         freqs = mx.swapaxes(freqs, 2, 3)
+         freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+         emb = mx.concatenate([freqs, freqs], axis=-1)
+         cos = mx.cos(emb) * self.attention_scaling
+         sin = mx.sin(emb) * self.attention_scaling
+
+         return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_multimodal_rotary_pos_emb(q, k, cos, sin, unqueeze_dim=1):
+     cos = mx.expand_dims(cos, axis=unqueeze_dim)
+     sin = mx.expand_dims(sin, axis=unqueeze_dim)
+
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+
+     return q_embed, k_embed
+
+
+ class Attention(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+
+         dim = args.hidden_size
+         self.n_heads = n_heads = args.num_attention_heads
+         assert args.num_key_value_heads is not None
+         self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+         self.head_dim = head_dim = getattr(
+             args, "head_dim", args.hidden_size // args.num_attention_heads
+         )
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         self.q_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+         self.k_norm = nn.RMSNorm(dims=head_dim, eps=args.rms_norm_eps)
+
+         self.rope_scaling = args.rope_scaling
+
+         self.rotary_emb = Qwen3OmniMoeThinkerTextRotaryEmbedding(
+             head_dim,
+             max_position_embeddings=args.max_position_embeddings,
+             base=args.rope_theta,
+             rope_scaling=self.rope_scaling,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         queries = self.q_norm(
+             queries.reshape(B, L, self.n_heads, self.head_dim)
+         ).transpose(0, 2, 1, 3)
+         keys = self.k_norm(
+             keys.reshape(B, L, self.n_kv_heads, self.head_dim)
+         ).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+
+         kv_seq_len = keys.shape[-2]
+
+         if position_ids is None:
+             kv_seq_len += cache.offset + 1
+             position_ids = mx.arange(cache.offset, cache.offset + L)
+             position_ids = mx.expand_dims(position_ids, axis=0)
+             position_ids = mx.tile(position_ids, (3, 1, 1))
+         else:
+             kv_seq_len += cache.offset + 1 if cache is not None else 0
+
+         cos, sin = self.rotary_emb(values, position_ids)
+
+         if mask is not None and isinstance(mask, mx.array):
+             mask = mask[..., :kv_seq_len]
+
+         queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin)
+
+         if cache is not None:
+             keys, values = cache.update_and_fetch(keys, values)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.num_experts = args.num_experts
+         self.top_k = args.num_experts_per_tok
+         self.norm_topk_prob = args.norm_topk_prob
+
+         self.gate = nn.Linear(args.hidden_size, args.num_experts, bias=False)
+         self.switch_mlp = SwitchGLU(
+             args.hidden_size, args.moe_intermediate_size, args.num_experts
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+     ) -> mx.array:
+         gates = self.gate(x)
+         gates = mx.softmax(gates, axis=-1, precise=True)
+
+         k = self.top_k
+         inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:]
+         scores = mx.take_along_axis(gates, inds, axis=-1)
+         if self.norm_topk_prob:
+             scores /= mx.sum(scores, axis=-1, keepdims=True)
+
+         y = self.switch_mlp(x, inds)
+         y = (y * scores[..., None]).sum(axis=-2)
+
+         return y
+
+
+ class Qwen3OmniMoEThinkerTextDecoderLayer(nn.Module):
+     def __init__(self, args: TextConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = args.hidden_size
+         self.self_attn = Attention(args)
+
+         self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             args.hidden_size, eps=args.rms_norm_eps
+         )
+         self.args = args
+
+         if (layer_idx not in args.mlp_only_layers) and (
+             args.num_experts > 0 and (layer_idx + 1) % args.decoder_sparse_step == 0
+         ):
+             self.mlp = Qwen3OmniMoeThinkerTextSparseMoeBlock(args)
+         else:
+             self.mlp = MLP(args.hidden_size, args.intermediate_size)
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         position_ids: Optional[mx.array] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache, position_ids)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class Qwen3VLMoEModel(nn.Module):
+     def __init__(self, args: TextConfig):
+         super().__init__()
+         self.args = args
+         self.vocab_size = args.vocab_size
+         self.num_hidden_layers = args.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+         self.layers = [
+             Qwen3OmniMoEThinkerTextDecoderLayer(args=args, layer_idx=layer_idx)
+             for layer_idx in range(args.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         position_ids: Optional[mx.array] = None,
+         visual_pos_masks: Optional[mx.array] = None,
+         deepstack_visual_embeds: Optional[mx.array] = None,
+         output_hidden_states: bool = False,
+     ):
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         all_hidden_states = [] if output_hidden_states else None
+
+         for layer_idx, (layer, c) in enumerate(zip(self.layers, cache)):
+             if output_hidden_states:
+                 all_hidden_states.append(h)
+             h = layer(h, mask, c, position_ids)
+
+             if deepstack_visual_embeds is not None and layer_idx in range(
+                 len(deepstack_visual_embeds)
+             ):
+                 h = self._deepstack_process(
+                     h,
+                     visual_pos_masks,
+                     deepstack_visual_embeds[layer_idx],
+                 )
+
+             if layer_idx % 4 == 0:
+                 mx.eval(h)
+
+         if output_hidden_states:
+             all_hidden_states.append(h)
+
+         return (
+             (self.norm(h), all_hidden_states) if output_hidden_states else self.norm(h)
+         )
+
+     def _deepstack_process(
+         self,
+         hidden_states: mx.array,
+         visual_pos_masks: mx.array,
+         visual_embeds: mx.array,
+     ):
+         if visual_pos_masks.ndim == 3:
+             visual_pos_masks = visual_pos_masks[..., 0]
+         visual_embeds = visual_embeds.astype(hidden_states.dtype)
+         visual_indices = np.where(visual_pos_masks)[0].tolist()
+         local_this = hidden_states[:, visual_indices, :] + visual_embeds
+         hidden_states[:, visual_indices, :] = local_this
+         return hidden_states
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, args: TextConfig, config: ThinkerConfig = None):
+         super().__init__()
+         self.args = args
+         self.config = config
+         self.model_type = args.model_type
+         self.model = Qwen3VLMoEModel(args)
+         self._rope_deltas = None
+
+         if not args.tie_word_embeddings:
+             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+     def get_rope_index(
+         self,
+         input_ids: mx.array,
+         image_grid_thw: Optional[mx.array] = None,
+         video_grid_thw: Optional[mx.array] = None,
+         attention_mask: Optional[mx.array] = None,
+     ):
+         batch_size, seq_length = input_ids.shape
+         position_ids = mx.arange(seq_length, dtype=mx.int32)
+         position_ids = mx.broadcast_to(position_ids[None, :], (batch_size, seq_length))
+         spatial_merge_size = self.config.vision_config.spatial_merge_size
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+         mrope_position_deltas = []
+         if input_ids is not None and (
+             image_grid_thw is not None or video_grid_thw is not None
+         ):
+             total_input_ids = input_ids
+             if attention_mask is None:
+                 attention_mask = mx.ones_like(input_ids)
+             position_ids = mx.ones(
+                 (3, input_ids.shape[0], input_ids.shape[1]), dtype=input_ids.dtype
+             )
+             image_index, video_index = 0, 0
+             for i, input_ids in enumerate(total_input_ids):
+                 input_ids = mx.where(
+                     attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids)
+                 )
+                 image_nums, video_nums = 0, 0
+                 vision_start_indices = mx.sum(
+                     mx.where(
+                         input_ids == vision_start_token_id,
+                         mx.arange(input_ids.shape[0]),
+                         mx.zeros_like(input_ids),
+                     )
+                 )
+                 vision_tokens = input_ids[vision_start_indices + 1]
+                 image_nums = (vision_tokens == image_token_id).sum().item()
+                 video_nums = (vision_tokens == video_token_id).sum().item()
+                 input_tokens = input_ids.tolist()
+                 llm_pos_ids_list: list = []
+                 st = 0
+                 remain_images, remain_videos = image_nums, video_nums
+                 for _ in range(image_nums + video_nums):
+                     if image_token_id in input_tokens and remain_images > 0:
+                         ed_image = input_tokens.index(image_token_id, st)
+                     else:
+                         ed_image = len(input_tokens) + 1
+                     if video_token_id in input_tokens and remain_videos > 0:
+                         ed_video = input_tokens.index(video_token_id, st)
+                     else:
+                         ed_video = len(input_tokens) + 1
+                     if ed_image < ed_video:
+                         t, h, w = (
+                             image_grid_thw[image_index][0],
+                             image_grid_thw[image_index][1],
+                             image_grid_thw[image_index][2],
+                         )
+                         image_index += 1
+                         remain_images -= 1
+                         ed = ed_image
+                     else:
+                         t, h, w = (
+                             video_grid_thw[video_index][0],
+                             video_grid_thw[video_index][1],
+                             video_grid_thw[video_index][2],
+                         )
+                         video_index += 1
+                         remain_videos -= 1
+                         ed = ed_video
+                     llm_grid_t, llm_grid_h, llm_grid_w = (
+                         t.item(),
+                         h.item() // spatial_merge_size,
+                         w.item() // spatial_merge_size,
+                     )
+                     text_len = ed - st
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     index = mx.arange(text_len).reshape(1, text_len)
+                     index = mx.broadcast_to(index, (3, text_len))
+                     index = index + st_idx
+                     llm_pos_ids_list.append(index)
+                     t_index = mx.arange(llm_grid_t).reshape(llm_grid_t, 1)
+                     t_index = mx.broadcast_to(
+                         t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                     )
+                     t_index = t_index.flatten()
+
+                     h_index = mx.arange(llm_grid_h).reshape(1, llm_grid_h, 1)
+                     h_index = mx.broadcast_to(
+                         h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )
+                     h_index = h_index.flatten()
+
+                     w_index = mx.arange(llm_grid_w).reshape(1, 1, llm_grid_w)
+                     w_index = mx.broadcast_to(
+                         w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                     )
+                     w_index = w_index.flatten()
+
+                     llm_pos_ids_list.append(
+                         mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                     )
+                     st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+                 if st < len(input_tokens):
+                     st_idx = (
+                         llm_pos_ids_list[-1].max() + 1
+                         if len(llm_pos_ids_list) > 0
+                         else 0
+                     )
+                     text_len = len(input_tokens) - st
+
+                     t_index = mx.arange(text_len).reshape(1, text_len)
+                     t_index = mx.broadcast_to(t_index, (3, text_len))
+
+                     llm_pos_ids_list.append(t_index + st_idx)
+
+                 llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+                 mask = mx.array(attention_mask[i] == 1)
+                 expanded_mask = mx.expand_dims(mask, axis=0)
+                 expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0]))
+                 expanded_positions = mx.expand_dims(llm_positions, axis=1)
+                 new_positions = mx.where(
+                     expanded_mask, expanded_positions, position_ids[:, i : i + 1, :]
+                 )
+                 updated_position_ids = mx.concatenate(
+                     [
+                         position_ids[:, :i, :],
+                         new_positions,
+                         position_ids[:, i + 1 :, :],
+                     ],
+                     axis=1,
+                 )
+                 position_ids = updated_position_ids
+                 mrope_position_deltas.append(
+                     llm_positions.max() + 1 - len(total_input_ids[i])
+                 )
+             mrope_position_deltas = mx.array(mrope_position_deltas)[0]
+             return position_ids, mrope_position_deltas
+         else:
+             if attention_mask is not None:
+                 position_ids = mx.cumsum(attention_mask.astype(mx.int64), axis=-1) - 1
+                 position_ids = mx.where(
+                     attention_mask == 0, mx.ones_like(position_ids), position_ids
+                 )
+                 position_ids = mx.expand_dims(position_ids[0], axis=0)
+                 position_ids = mx.tile(position_ids, (3, 1, 1))
+                 max_position_ids = position_ids.max(0, keepdims=False)[0].max(
+                     -1, keepdims=True
+                 )[0]
+                 mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
+             else:
+                 position_ids = mx.arange(input_ids.shape[1]).reshape(1, -1)
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, input_ids.shape[0], input_ids.shape[1])
+                 )
+                 mrope_position_deltas = mx.zeros(
+                     [input_ids.shape[0], 1],
+                     dtype=input_ids.dtype,
+                 )
+             return position_ids, mrope_position_deltas
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         visual_pos_masks: Optional[mx.array] = None,
+         deepstack_visual_embeds: Optional[mx.array] = None,
+         **kwargs,
+     ):
+
+         position_ids = kwargs.pop("position_ids", None)
+         pixel_values = kwargs.pop("pixel_values", None)
+         image_grid_thw = kwargs.pop("image_grid_thw", None)
+         video_grid_thw = kwargs.pop("video_grid_thw", None)
+         if pixel_values is not None:
+             self._rope_deltas = None
+
+         cache_offset = 0
+         if cache and cache[0] is not None:
+             offset = cache[0].offset
+             if isinstance(offset, int):
+                 cache_offset = offset
+             elif isinstance(offset, mx.array):
+                 cache_offset = (offset if offset.ndim == 0 else offset[0]).item()
+             else:
+                 raise ValueError(f"Unexpected cache offset type: {type(offset)}")
+
+         if position_ids is None and (mask is None or mask.ndim == 2):
+             if (
+                 (cache is not None and cache[0] is not None and (cache_offset == 0))
+                 or self._rope_deltas is None
+                 or cache is None
+             ):
+                 position_ids, rope_deltas = self.get_rope_index(
+                     inputs, image_grid_thw, video_grid_thw, mask
+                 )
+                 self._rope_deltas = rope_deltas
+             else:
+                 batch_size, seq_length = inputs.shape
+                 delta = mx.array(
+                     cache_offset + self._rope_deltas if cache is not None else 0
+                 )
+                 position_ids = mx.arange(seq_length).reshape(1, -1)
+                 position_ids = mx.broadcast_to(position_ids, (batch_size, seq_length))
+
+                 if cache_offset is not None:
+                     if delta.ndim == 0:
+                         delta = mx.expand_dims(delta, axis=0)
+
+                     if delta.shape[0] < batch_size:
+                         delta = mx.tile(delta, (batch_size, 1))
+                     else:
+                         delta = delta[:batch_size]
+
+                 position_ids = mx.add(position_ids, delta)[None, ...]
+                 position_ids = mx.broadcast_to(
+                     position_ids, (3, batch_size, seq_length)
+                 )
+
+         visual_pos_masks = kwargs.get("visual_pos_masks", None)
+         deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds", None)
+         output_hidden_states = kwargs.pop("output_hidden_states", False)
+
+         out = self.model(
+             inputs,
+             cache=cache,
+             inputs_embeds=inputs_embeds,
+             position_ids=position_ids,
+             visual_pos_masks=visual_pos_masks,
+             deepstack_visual_embeds=deepstack_visual_embeds,
+             output_hidden_states=output_hidden_states,
+         )
+
+         if output_hidden_states:
+             hidden_states, all_hidden_states = out
+             out = hidden_states
+
+         if self.args.tie_word_embeddings:
+             logits = self.model.embed_tokens.as_linear(out)
+         else:
+             logits = self.lm_head(out)
+
+         return LanguageModelOutput(
+             logits=logits,
+             hidden_states=all_hidden_states if output_hidden_states else None,
+         )
+
+     def sanitize(self, weights):
+         for l in range(self.args.num_hidden_layers):
+             prefix = f"thinker.language_model.model.layers.{l}.mlp"
+             for n in ["gate_proj", "down_proj", "up_proj"]:
+                 experts_weights = []
+                 for e in range(self.args.num_experts):
+                     key = f"{prefix}.experts.{e}.{n}.weight"
+                     if key in weights:
+                         experts_weights.append(weights.pop(key))
+
+                 if experts_weights:
+                     weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(
+                         experts_weights, axis=0
+                     )
+         return weights
+
+     @property
+     def quant_predicate(self):
+         def predicate(path, _):
+             if path.endswith("mlp.gate"):
+                 return {"group_size": 64, "bits": 8}
+             return True
+
+         return predicate
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.args.hidden_size // self.args.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.args.num_key_value_heads
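
Note: the following sketch is not part of the packaged file above. It isolates the top-k expert routing that Qwen3OmniMoeThinkerTextSparseMoeBlock performs before dispatching tokens to SwitchGLU, using toy shapes so the gating math can be checked on its own; the helper name route_top_k is made up for illustration.

import mlx.core as mx

def route_top_k(gate_logits: mx.array, k: int, norm_topk_prob: bool = True):
    # Softmax over expert logits, then keep the k largest gates per token,
    # mirroring the block above (which feeds `inds` to SwitchGLU afterwards).
    gates = mx.softmax(gate_logits, axis=-1)
    inds = mx.argpartition(gates, kth=-k, axis=-1)[..., -k:]
    scores = mx.take_along_axis(gates, inds, axis=-1)
    if norm_topk_prob:
        # Renormalize the selected probabilities so they sum to 1 per token.
        scores = scores / mx.sum(scores, axis=-1, keepdims=True)
    return inds, scores

# Toy example: 4 tokens routed over 8 experts, keeping the top 2.
logits = mx.random.normal((4, 8))
inds, scores = route_top_k(logits, k=2)
print(inds.shape, scores.shape)  # (4, 2) (4, 2)
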
mlx_vlm/models/qwen3_omni_moe/omni_utils.py
@@ -0,0 +1,69 @@
+ import mlx.core as mx
+ import numpy as np
+
+ from mlx_vlm.utils import load_audio
+
+
+ def process_multimodal_info(conversation, use_audio_in_video=False):
+     audios = []
+     images = []
+     videos = []
+     for msg in conversation:
+         if "content" in msg:
+             if isinstance(msg["content"], str):
+                 continue
+             for part in msg["content"]:
+                 if part["type"] == "audio":
+                     audios.append(part["audio"])
+                 elif part["type"] == "image":
+                     images.append(part["image"])
+                 elif part["type"] == "video":
+                     videos.append(part["video"])
+     return audios, images, videos
+
+
+ def prepare_omni_inputs(
+     processor,
+     conversation,
+     use_audio_in_video=False,
+ ):
+     audios, images, videos = process_multimodal_info(conversation, use_audio_in_video)
+
+     text = processor.apply_chat_template(
+         conversation, add_generation_prompt=True, tokenize=False
+     )
+
+     loaded_audios = []
+     if audios:
+         sr = processor.feature_extractor.sampling_rate
+         for audio_path in audios:
+             loaded_audios.append(load_audio(audio_path, sr=sr))
+
+     inputs = processor(
+         text=[text],
+         audio=loaded_audios if loaded_audios else None,
+         images=images if images else None,
+         videos=videos if videos else None,
+         return_tensors="pt",
+         padding=True,
+         use_audio_in_video=use_audio_in_video,
+     )
+
+     model_inputs = {}
+     for k, v in inputs.items():
+         if hasattr(v, "numpy"):
+             model_inputs[k] = mx.array(v.numpy())
+         elif isinstance(v, np.ndarray):
+             model_inputs[k] = mx.array(v)
+         else:
+             model_inputs[k] = v
+
+     if (
+         "feature_attention_mask" in model_inputs
+         and "audio_feature_lengths" not in model_inputs
+     ):
+         model_inputs["audio_feature_lengths"] = (
+             model_inputs["feature_attention_mask"].sum(axis=1).astype(mx.int32)
+         )
+
+     return model_inputs, text
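
Note: a minimal usage sketch, not part of the packaged file above. It assumes `processor` is an already-loaded Hugging Face-style omni processor that supports apply_chat_template plus image/audio/video inputs; the image path is a placeholder.

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "example.jpg"},  # placeholder path
            {"type": "text", "text": "Describe the image."},
        ],
    }
]

# process_multimodal_info collects the "image" entry; the text flows through
# processor.apply_chat_template. The returned dict holds mx.array tensors
# (input_ids, pixel_values, ...) alongside the rendered prompt string.
model_inputs, prompt = prepare_omni_inputs(processor, conversation)
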