fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/qwen3_omni_moe/vision.py
@@ -0,0 +1,419 @@
+ from itertools import accumulate
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     # Heuristic that distinguishes a PyTorch Conv3d weight (O, C, T, H, W)
+     # from an already-converted MLX channels-last weight (O, T, H, W, C).
+     shape = arr.shape
+
+     if len(shape) != 5:
+         return False
+
+     B, out_channels, kH, kW, t = shape
+
+     # A trailing dim of 3 means the channels already sit last (MLX layout).
+     if t == 3:
+         return True
+
+     if (out_channels >= kH) and (out_channels >= kW) and (kH == kW):
+         return True
+     else:
+         return False
+
+
+ def rotate_half(x):
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
+     orig_dtype = tensor.dtype
+
+     cos = mx.cos(freqs)
+     sin = mx.sin(freqs)
+
+     cos = mx.expand_dims(cos, axis=1)
+     cos = mx.tile(cos, (1, 1, 2))
+     cos = mx.expand_dims(cos, axis=0)
+
+     sin = mx.expand_dims(sin, axis=1)
+     sin = mx.tile(sin, (1, 1, 2))
+     sin = mx.expand_dims(sin, axis=0)
+
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     return output.astype(orig_dtype)
+
+
+ class VisionRotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+
+     def __call__(self, seqlen: int) -> mx.array:
+         inv_freq = 1.0 / (
+             self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+         )
+         seq = mx.arange(seqlen, dtype=inv_freq.dtype)
+         freqs = mx.outer(seq, inv_freq)
+         return freqs
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(
+         self,
+         patch_size: int = 14,
+         temporal_patch_size: int = 2,
+         in_channels: int = 3,
+         hidden_size: int = 1152,
+     ) -> None:
+         super().__init__()
+         self.patch_size = patch_size
+         self.temporal_patch_size = temporal_patch_size
+         self.in_channels = in_channels
+         self.hidden_size = hidden_size
+
+         kernel_size = [temporal_patch_size, patch_size, patch_size]
+         self.proj = nn.Conv3d(
+             in_channels,
+             hidden_size,
+             kernel_size=kernel_size,
+             stride=kernel_size,
+             bias=True,
+         )
+
+     def __call__(self, hidden_states: mx.array) -> mx.array:
+         hidden_states = hidden_states.reshape(
+             -1,
+             self.in_channels,
+             self.temporal_patch_size,
+             self.patch_size,
+             self.patch_size,
+         ).moveaxis(1, 4)
+
+         hidden_states = self.proj(hidden_states)
+         hidden_states = hidden_states.reshape(-1, self.hidden_size)
+         return hidden_states
+
+
+ class PatchMerger(nn.Module):
+     def __init__(self, config: VisionConfig, use_postshuffle_norm=False) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+         self.use_postshuffle_norm = use_postshuffle_norm
+         self.norm = nn.LayerNorm(
+             self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6
+         )
+         self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+         self.act_fn = nn.GELU()
+         self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.norm(
+             x.reshape(-1, self.hidden_size) if self.use_postshuffle_norm else x
+         ).reshape(-1, self.hidden_size)
+         x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim: int, num_heads: int = 16) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+         self.qkv = nn.Linear(dim, dim * 3, bias=True)
+         self.proj = nn.Linear(dim, dim)
+
+     def __call__(
+         self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
+     ) -> mx.array:
+         seq_length = x.shape[0]
+         qkv = (
+             self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
+         )
+         q, k, v = mx.split(qkv, 3)
+
+         q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
+         k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
+
+         q = q.transpose(0, 2, 1, 3)
+         k = k.transpose(0, 2, 1, 3)
+         v = v.transpose(0, 2, 1, 3)
+
+         # Attend within each image independently by splitting q, k, v at the
+         # cumulative sequence boundaries.
+         splits = [
+             mx.split(tensor, cu_seqlens[1:-1].tolist(), axis=2) for tensor in (q, k, v)
+         ]
+
+         attn_outputs = []
+         for q, k, v in zip(*splits):
+             output = mx.fast.scaled_dot_product_attention(
+                 q, k, v, scale=self.scale, mask=None
+             )
+             attn_outputs.append(output)
+
+         output = mx.concatenate(attn_outputs, axis=2)
+         output = output.transpose(0, 2, 1, 3).reshape(seq_length, -1)
+         return self.proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.linear_fc1 = nn.Linear(dim, hidden_dim, bias=True)
+         self.linear_fc2 = nn.Linear(hidden_dim, dim, bias=True)
+         self.act_fn = nn.GELU()
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+
+
+ class Qwen3VLMoEVisionBlock(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+         self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+
+         self.attn = Attention(dim=config.hidden_size, num_heads=config.num_heads)
+         self.mlp = MLP(dim=config.hidden_size, hidden_dim=config.intermediate_size)
+
+     def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
+         hidden_states = hidden_states + self.attn(
+             self.norm1(hidden_states),
+             cu_seqlens=cu_seqlens,
+             rotary_pos_emb=rotary_pos_emb,
+         )
+         hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+         return hidden_states
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+
+         if self.model_type not in ["qwen3_vl_moe", "qwen3_omni_moe_vision_encoder"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.spatial_merge_size = config.spatial_merge_size
+
+         self.patch_embed = PatchEmbed(
+             patch_size=config.patch_size,
+             temporal_patch_size=config.temporal_patch_size,
+             in_channels=config.in_channels,
+             hidden_size=config.hidden_size,
+         )
+
+         head_dim = config.hidden_size // config.num_heads
+         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+         self.pos_embed = nn.Embedding(
+             config.num_position_embeddings, config.hidden_size
+         )
+         self.num_grid_per_side = int(config.num_position_embeddings**0.5)
+
+         self.blocks = [Qwen3VLMoEVisionBlock(config) for _ in range(config.depth)]
+         self.merger = PatchMerger(config=config, use_postshuffle_norm=False)
+
+         self.deepstack_visual_indexes = config.deepstack_visual_indexes
+         self.deepstack_merger_list = [
+             PatchMerger(
+                 config=config,
+                 use_postshuffle_norm=True,
+             )
+             for _ in range(len(config.deepstack_visual_indexes))
+         ]
+
+     def rot_pos_emb(self, grid_thw: mx.array) -> mx.array:
+         merge_size = self.spatial_merge_size
+
+         max_hw = int(mx.max(grid_thw[:, 1:]).item())
+         freq_table = self.rotary_pos_emb(max_hw)
+
+         pos_ids = []
+
+         for num_frames, height, width in grid_thw.tolist():
+             num_frames, height, width = int(num_frames), int(height), int(width)
+             merged_h, merged_w = height // merge_size, width // merge_size
+
+             block_rows = mx.arange(merged_h)
+             block_cols = mx.arange(merged_w)
+
+             intra_row = mx.arange(merge_size)
+             intra_col = mx.arange(merge_size)
+
+             row_idx = (
+                 block_rows[:, None, None, None] * merge_size
+                 + intra_row[None, None, :, None]
+             )
+             col_idx = (
+                 block_cols[None, :, None, None] * merge_size
+                 + intra_col[None, None, None, :]
+             )
+
+             row_idx = mx.broadcast_to(
+                 row_idx, (merged_h, merged_w, merge_size, merge_size)
+             ).reshape(-1)
+             col_idx = mx.broadcast_to(
+                 col_idx, (merged_h, merged_w, merge_size, merge_size)
+             ).reshape(-1)
+
+             coords = mx.stack([row_idx, col_idx], axis=-1)
+
+             if num_frames > 1:
+                 coords = mx.tile(coords, (num_frames, 1))
+
+             pos_ids.append(coords)
+
+         pos_ids = mx.concatenate(pos_ids, axis=0)
+
+         h_embeddings = freq_table[pos_ids[:, 0]]
+         w_embeddings = freq_table[pos_ids[:, 1]]
+
+         embeddings = mx.concatenate([h_embeddings, w_embeddings], axis=-1)
+
+         return embeddings
+
+     def fast_pos_embed_interpolate(self, grid_thw):
+         # Bilinearly interpolate the learned position-embedding grid to the
+         # actual (h, w) patch grid of each image.
+         grid_thw_list = grid_thw.tolist()
+         idx_list = [[] for _ in range(4)]
+         weight_list = [[] for _ in range(4)]
+
+         for t, h, w in grid_thw_list:
+             t, h, w = int(t), int(h), int(w)
+
+             h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
+             w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
+
+             h_idxs_floor = h_idxs.astype(mx.int32)
+             w_idxs_floor = w_idxs.astype(mx.int32)
+             h_idxs_ceil = mx.minimum(h_idxs_floor + 1, self.num_grid_per_side - 1)
+             w_idxs_ceil = mx.minimum(w_idxs_floor + 1, self.num_grid_per_side - 1)
+
+             dh = h_idxs - h_idxs_floor.astype(mx.float32)
+             dw = w_idxs - w_idxs_floor.astype(mx.float32)
+
+             base_h = h_idxs_floor * self.num_grid_per_side
+             base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+             indices = [
+                 (base_h[:, None] + w_idxs_floor[None, :]).flatten(),
+                 (base_h[:, None] + w_idxs_ceil[None, :]).flatten(),
+                 (base_h_ceil[:, None] + w_idxs_floor[None, :]).flatten(),
+                 (base_h_ceil[:, None] + w_idxs_ceil[None, :]).flatten(),
+             ]
+
+             weights = [
+                 ((1 - dh)[:, None] * (1 - dw)[None, :]).flatten(),
+                 ((1 - dh)[:, None] * dw[None, :]).flatten(),
+                 (dh[:, None] * (1 - dw)[None, :]).flatten(),
+                 (dh[:, None] * dw[None, :]).flatten(),
+             ]
+
+             for i in range(4):
+                 idx_list[i].extend(indices[i].tolist())
+                 weight_list[i].extend(weights[i].tolist())
+
+         idx_tensor = mx.array(idx_list, dtype=mx.int32)
+         weight_tensor = mx.array(weight_list, dtype=self.pos_embed.weight.dtype)
+
+         pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None]
+         patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3]
+
+         split_sizes = [int(h * w) for t, h, w in grid_thw_list]
+         if len(split_sizes) > 1:
+             split_indices = list(accumulate(split_sizes[:-1]))
+             patch_pos_embeds_split = mx.split(patch_pos_embeds, split_indices, axis=0)
+         else:
+             patch_pos_embeds_split = [patch_pos_embeds]
+
+         patch_pos_embeds_permute = []
+         merge_size = self.config.spatial_merge_size
+
+         for pos_embed, (t, h, w) in zip(patch_pos_embeds_split, grid_thw_list):
+             t, h, w = int(t), int(h), int(w)
+             feature_dim = pos_embed.shape[-1]
+             pos_embed = mx.tile(pos_embed, (t, 1))
+             pos_embed = pos_embed.reshape(t, h, w, feature_dim)
+             pos_embed = (
+                 pos_embed.reshape(
+                     t,
+                     h // merge_size,
+                     merge_size,
+                     w // merge_size,
+                     merge_size,
+                     feature_dim,
+                 )
+                 .transpose(0, 1, 3, 2, 4, 5)
+                 .reshape(-1, feature_dim)
+             )
+             patch_pos_embeds_permute.append(pos_embed)
+
+         patch_pos_embeds = mx.concatenate(patch_pos_embeds_permute)
+         return patch_pos_embeds
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         grid_thw: mx.array,
+         **kwargs,
+     ) -> tuple:
+         hidden_states = self.patch_embed(hidden_states)
+         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
+         hidden_states = hidden_states + pos_embeds
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+         seq_len = hidden_states.shape[0]
+         hidden_states = hidden_states.reshape(seq_len, -1)
+         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+
+         batch_size = grid_thw.shape[0]
+
+         # Per-image token counts: each of the t frames contributes h * w patches.
+         cu_seqlens = []
+         for i in range(batch_size):
+             frame_len = grid_thw[i, 1] * grid_thw[i, 2]
+             cu_seqlens.append(mx.repeat(frame_len, int(grid_thw[i, 0])))
+
+         cu_seqlens = mx.concatenate(cu_seqlens)
+
+         cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0)
+         cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)
+
+         deepstack_feature_lists = []
+         for layer_num, blk in enumerate(self.blocks):
+             hidden_states = blk(
+                 hidden_states,
+                 cu_seqlens=cu_seqlens,
+                 rotary_pos_emb=rotary_pos_emb,
+             )
+             if layer_num in self.deepstack_visual_indexes:
+                 deepstack_feature = self.deepstack_merger_list[
+                     self.deepstack_visual_indexes.index(layer_num)
+                 ](hidden_states)
+                 deepstack_feature_lists.append(deepstack_feature)
+
+         hidden_states = self.merger(hidden_states)
+
+         return hidden_states, deepstack_feature_lists
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 continue
+             elif "patch_embed.proj.weight" in k:
+                 # Move PyTorch Conv3d weights (O, C, T, H, W) to MLX
+                 # channels-last (O, T, H, W, C) unless already converted.
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
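
Reviewer sketch (not part of the published diff): the VisionModel above consumes flattened patches of shape (num_patches, in_channels * temporal_patch_size * patch_size**2) plus a (num_images, 3) grid_thw of (frames, height, width) in patch units, and returns the merged hidden states together with the deepstack features. A minimal smoke test, assuming the omni VisionConfig mirrors the qwen3_vl VisionConfig shown at the end of this diff; the tiny config values are illustrative, not the shipped defaults:

    import mlx.core as mx

    from mlx_vlm.models.qwen3_omni_moe.config import VisionConfig
    from mlx_vlm.models.qwen3_omni_moe.vision import VisionModel

    # Illustrative, shrunken config (assumed fields; real defaults differ).
    cfg = VisionConfig(
        model_type="qwen3_omni_moe_vision_encoder",
        depth=2,
        hidden_size=64,
        intermediate_size=128,
        out_hidden_size=32,
        num_heads=4,
        spatial_merge_size=2,
        num_position_embeddings=2304,
        deepstack_visual_indexes=[0],
    )
    model = VisionModel(cfg)

    # One image, one temporal group, a 4x4 patch grid (divisible by merge size).
    grid_thw = mx.array([[1, 4, 4]])
    patch_dim = cfg.in_channels * cfg.temporal_patch_size * cfg.patch_size**2
    pixels = mx.random.normal((1 * 4 * 4, patch_dim))

    hidden, deepstack = model(pixels, grid_thw)
    print(hidden.shape)                     # (16 // merge**2, out_hidden_size) == (4, 32)
    print(len(deepstack), deepstack[0].shape)  # 1, (4, 32)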
mlx_vlm/models/qwen3_vl/__init__.py
@@ -0,0 +1,2 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .qwen3_vl import LanguageModel, Model, VisionModel
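
These two re-exports define the public surface of the qwen3_vl subpackage, which the loader in mlx_vlm/utils.py resolves from the model_type in a checkpoint's config. End users normally go through the top-level API instead; a hedged sketch following the upstream mlx-vlm README pattern (the model id and image path are illustrative):

    from mlx_vlm import load, generate
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config

    model_path = "mlx-community/Qwen3-VL-8B-Instruct-4bit"  # illustrative repo id
    model, processor = load(model_path)
    config = load_config(model_path)

    image = ["example.jpg"]  # illustrative local file
    prompt = apply_chat_template(
        processor, config, "Describe this image.", num_images=len(image)
    )
    output = generate(model, processor, prompt, image, verbose=False)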
mlx_vlm/models/qwen3_vl/config.py
@@ -0,0 +1,103 @@
+ import inspect
+ from dataclasses import dataclass, field
+ from typing import Dict, List, Optional, Union
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str = "qwen3_vl"
+     depth: int = 32
+     hidden_size: int = 1280
+     intermediate_size: int = 3420
+     out_hidden_size: int = 1536
+     num_heads: int = 16
+     image_size: int = 384
+     patch_size: int = 14
+     vocab_size: int = 32000
+     mlp_ratio: float = 4.0
+     in_channels: int = 3
+     layer_norm_eps: float = 1e-6
+     spatial_patch_size: int = 14
+     spatial_merge_size: int = 2
+     tokens_per_second: int = 2
+     temporal_patch_size: int = 2
+     num_position_embeddings: int = 2304
+     window_size: int = 112
+     fullatt_block_indexes: List[int] = field(default_factory=lambda: [7, 15, 23, 31])
+     deepstack_visual_indexes: List[int] = field(default_factory=list)
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str
+     num_hidden_layers: int
+     hidden_size: int
+     intermediate_size: int
+     num_attention_heads: int
+     rms_norm_eps: float
+     vocab_size: int
+     num_key_value_heads: Optional[int]
+     head_dim: int
+     rope_theta: float
+     max_position_embeddings: int
+     norm_topk_prob: bool = True
+     rope_scaling: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field(
+         default_factory=lambda: {"type": "default", "mrope_section": [24, 20, 20]}
+     )
+     tie_word_embeddings: bool = False
+     attention_bias: bool = False
+     hidden_act: str = "silu"
+
+     def __post_init__(self):
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+         if self.rope_scaling:
+             # Normalize rope_scaling keys (accept both 'rope_type' and 'type').
+             if "type" not in self.rope_scaling and "rope_type" in self.rope_scaling:
+                 self.rope_scaling["type"] = self.rope_scaling.pop("rope_type")
+
+             required_keys = {"mrope_section", "type"}
+             if not all(key in self.rope_scaling for key in required_keys):
+                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+             if self.rope_scaling["type"] not in ["mrope", "default"]:
+                 raise ValueError("rope_scaling type must be 'mrope' or 'default'")
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str
+     ignore_index: int = -100
+     image_token_id: int = 151655
+     video_token_id: int = 151656
+     image_token_index: Optional[int] = None
+     video_token_index: Optional[int] = None
+     vision_start_token_id: int = 151652
+     vision_end_token_id: int = 151653
+     vision_token_id: int = 151654
+     vision_feature_select_strategy: str = "default"
+     vision_feature_layer: int = -2
+     vocab_size: int = 32000
+     eos_token_id: Optional[List[int]] = None
+
+     def __post_init__(self):
+         if self.image_token_index is None:
+             self.image_token_index = self.image_token_id
+         if self.video_token_index is None:
+             self.video_token_index = self.video_token_id
+
+     @classmethod
+     def from_dict(cls, params):
+         # Keep only keys that match declared dataclass fields so extra
+         # entries in a raw config.json are ignored.
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
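
ModelConfig.from_dict keeps only keys that match declared dataclass fields, so unrecognized entries in a checkpoint's config.json are dropped instead of raising TypeError. A small sketch against the classes above (values abridged; the nested configs are built directly here for brevity):

    from mlx_vlm.models.qwen3_vl.config import ModelConfig, TextConfig, VisionConfig

    text_cfg = TextConfig(
        model_type="qwen3_vl",
        num_hidden_layers=2,
        hidden_size=64,
        intermediate_size=128,
        num_attention_heads=4,
        rms_norm_eps=1e-6,
        vocab_size=151936,
        num_key_value_heads=None,  # __post_init__ falls back to num_attention_heads
        head_dim=16,
        rope_theta=1_000_000.0,
        max_position_embeddings=4096,
    )

    cfg = ModelConfig.from_dict(
        {
            "model_type": "qwen3_vl",
            "text_config": text_cfg,
            "vision_config": VisionConfig(),
            "some_new_hf_field": True,  # not a declared field -> silently dropped
        }
    )
    assert cfg.image_token_index == 151655  # filled from image_token_id in __post_init__
    assert text_cfg.num_key_value_heads == 4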