fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/ernie4_5_moe_vl/vision.py
@@ -0,0 +1,322 @@
+"""DFNRope Vision Transformer for ERNIE 4.5 VL."""
+
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from .config import VisionConfig
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb_vision(tensor: mx.array, freqs: mx.array) -> mx.array:
+    """Applies Rotary Position Embedding to the input tensors.
+
+    Args:
+        tensor: The input tensor.
+        freqs: The frequencies used for the rotation.
+
+    Returns:
+        output: the tensor rotated using the Rotary Position Embedding.
+    """
+    orig_dtype = tensor.dtype
+    tensor = tensor.astype(mx.float32)
+    cos = mx.cos(freqs)
+    sin = mx.sin(freqs)
+    # cos [θ0,θ1,θ2......θd/2-1] -> cos_pos [θ0,θ0,θ1,θ1,θ2,θ2......θd/2-1,θd/2-1]
+    cos = mx.expand_dims(cos, axis=1)
+    cos = mx.tile(cos, (1, 1, 2))
+    cos = mx.expand_dims(cos, axis=0)
+
+    sin = mx.expand_dims(sin, axis=1)
+    sin = mx.tile(sin, (1, 1, 2))
+    sin = mx.expand_dims(sin, axis=0)
+
+    output = tensor * cos + rotate_half(tensor) * sin
+    return output.astype(orig_dtype)
+
+
+class VisionRotaryEmbedding(nn.Module):
+    """Rotary position embedding for vision transformer."""
+
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+
+    def __call__(self, seqlen: int) -> mx.array:
+        inv_freq = 1.0 / (
+            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+        )
+        if isinstance(seqlen, mx.array):
+            seqlen = seqlen.item()
+        seq = mx.arange(seqlen, dtype=inv_freq.dtype)
+        freqs = mx.outer(seq, inv_freq)
+        return freqs
+
+
+class PatchEmbed(nn.Module):
+    """Linear patch embedding for DFNRope Vision Transformer."""
+
+    def __init__(
+        self,
+        patch_size: int = 14,
+        in_channels: int = 3,
+        embed_dim: int = 1152,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+        self.in_channels = in_channels
+        self.embed_dim = embed_dim
+        # Linear projection: in_channels * patch_size * patch_size -> embed_dim
+        self.proj = nn.Linear(
+            in_channels * patch_size * patch_size, embed_dim, bias=False
+        )
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        """
+        Args:
+            hidden_states: Input tensor of shape [num_patches, in_channels * patch_size * patch_size]
+        Returns:
+            Output tensor of shape [num_patches, embed_dim]
+        """
+        target_dtype = self.proj.weight.dtype
+        hidden_states = self.proj(hidden_states.astype(target_dtype))
+        return hidden_states
+
+
+class VisionMLP(nn.Module):
+    """MLP for vision transformer block."""
+
+    def __init__(
+        self, dim: int, hidden_dim: int, hidden_act: str = "quick_gelu"
+    ) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(dim, hidden_dim)
+        self.fc2 = nn.Linear(hidden_dim, dim)
+        self.hidden_act = hidden_act
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.fc1(x)
+        if self.hidden_act == "quick_gelu":
+            x = x * mx.sigmoid(1.702 * x)
+        elif self.hidden_act == "gelu":
+            x = nn.gelu(x)
+        elif self.hidden_act == "silu":
+            x = nn.silu(x)
+        else:
+            x = nn.gelu(x)
+        return self.fc2(x)
+
+
+class VisionAttention(nn.Module):
+    """Multi-head attention for vision transformer."""
+
+    def __init__(self, dim: int, num_heads: int = 16) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=True)
+        self.proj = nn.Linear(dim, dim)
+
+    def __call__(
+        self,
+        x: mx.array,
+        cu_seqlens: mx.array,
+        rotary_pos_emb: Optional[mx.array] = None,
+    ) -> mx.array:
+        """Forward function for vision attention."""
+        seq_length = x.shape[0]
+        qkv = (
+            self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
+        )
+        q, k, v = mx.split(qkv, 3)
+
+        q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
+        k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
+
+        q = q.transpose(0, 2, 1, 3)
+        k = k.transpose(0, 2, 1, 3)
+        v = v.transpose(0, 2, 1, 3)
+
+        lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        splits = [
+            mx.split(tensor, [lengths[0], sum(lengths[:2])], axis=2)
+            for tensor in (q, k, v)
+        ]
+
+        attn_outputs = []
+        for q, k, v in zip(*splits):
+            output = mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=self.scale, mask=None
+            )
+            attn_outputs.append(output)
+
+        output = mx.concatenate(attn_outputs, axis=2)
+        output = output.transpose(0, 2, 1, 3).reshape(seq_length, -1)
+        return self.proj(output)
+
+
+class DFNRopeVisionBlock(nn.Module):
+    """DFNRope Vision Transformer block."""
+
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.norm1 = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
+        self.norm2 = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
+
+        mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
+        self.attn = VisionAttention(config.embed_dim, num_heads=config.num_heads)
+        self.mlp = VisionMLP(
+            dim=config.embed_dim,
+            hidden_dim=mlp_hidden_dim,
+            hidden_act=config.hidden_act,
+        )
+
+    def __call__(
+        self, hidden_states: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array
+    ) -> mx.array:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionModel(nn.Module):
+    """DFNRope Vision Transformer for ERNIE 4.5 VL."""
+
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.patch_embed = PatchEmbed(
+            patch_size=config.patch_size,
+            in_channels=config.in_channels,
+            embed_dim=config.embed_dim,
+        )
+
+        head_dim = config.embed_dim // config.num_heads
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = [DFNRopeVisionBlock(config) for _ in range(config.depth)]
+        self.ln = nn.LayerNorm(config.embed_dim, eps=config.layer_norm_eps)
+
+    def rot_pos_emb(self, grid_thw: mx.array, num_pad: int = 0) -> mx.array:
+        """Compute rotary position embedding for vision.
+
+        Args:
+            grid_thw: Grid dimensions [batch, 3] containing (t, h, w)
+            num_pad: Number of padding tokens
+
+        Returns:
+            Rotary position embedding tensor
+        """
+        pos_ids = []
+        grid_hw_array = np.array(grid_thw.tolist(), dtype=np.int64)
+
+        for t, h, w in grid_hw_array:
+            hpos_ids = np.arange(h).reshape(-1, 1)
+            hpos_ids = np.tile(hpos_ids, (1, w))
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = np.transpose(hpos_ids, (0, 2, 1, 3))
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = np.arange(w).reshape(1, -1)
+            wpos_ids = np.tile(wpos_ids, (h, 1))
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = np.transpose(wpos_ids, (0, 2, 1, 3))
+            wpos_ids = wpos_ids.flatten()
+
+            stacked_ids = np.stack([hpos_ids, wpos_ids], axis=-1)
+            tiled_ids = np.tile(stacked_ids, (t, 1))
+            pos_ids.append(tiled_ids)
+
+        pos_ids = np.concatenate(pos_ids, axis=0)
+        if num_pad > 0:
+            pos_ids = np.concatenate(
+                [pos_ids, np.zeros((num_pad, 2), dtype=pos_ids.dtype)], axis=0
+            )
+
+        max_grid_size = int(np.max(grid_hw_array[:, 1:]))
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        pos_ids_mx = mx.array(pos_ids, dtype=mx.int32)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids_mx].reshape(pos_ids.shape[0], -1)
+
+        return rotary_pos_emb
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        grid_thw: mx.array,
+        output_hidden_states: Optional[bool] = None,
+        num_pad: int = 0,
+    ) -> mx.array:
+        """Forward pass through the vision model.
+
+        Args:
+            hidden_states: Input pixel values [num_patches, channels * patch_h * patch_w]
+            grid_thw: Grid dimensions [batch, 3]
+            output_hidden_states: Whether to output hidden states
+            num_pad: Number of padding tokens
+
+        Returns:
+            Vision features
+        """
+        hidden_states = self.patch_embed(hidden_states)
+        rotary_pos_emb = self.rot_pos_emb(grid_thw, num_pad=num_pad)
+
+        # Compute cumulative sequence lengths
+        cu_seqlens = mx.zeros(1, dtype=mx.int32)
+        for i in range(grid_thw.shape[0]):
+            t, h, w = grid_thw[i].tolist()
+            seq_len = t * h * w
+            cu_seqlens = mx.concatenate([cu_seqlens, cu_seqlens[-1:] + seq_len])
+
+        if num_pad > 0:
+            cu_seqlens = mx.concatenate([cu_seqlens, cu_seqlens[-1:] + num_pad])
+
+        encoder_states = (hidden_states,) if output_hidden_states else None
+
+        for blk in self.blocks:
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb=rotary_pos_emb,
+            )
+            if output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+
+        hidden_states = self.ln(hidden_states)
+        return hidden_states
+
+    def sanitize(self, weights):
+        """Sanitize weights for loading."""
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                continue
+            sanitized_weights[k] = v
+        return sanitized_weights
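The rotary embedding in this vision stack duplicates each frequency across adjacent channel pairs (`mx.tile(cos, (1, 1, 2))`) to match the `rotate_half` convention, rather than interleaving sin/cos per channel. The following is a minimal sketch of how the pieces above compose; the shapes are invented for illustration and are not taken from the package's configs:

```python
import mlx.core as mx

from mlx_vlm.models.ernie4_5_moe_vl.vision import (
    VisionRotaryEmbedding,
    apply_rotary_pos_emb_vision,
)

# Toy dimensions chosen only for illustration.
seq_len, num_heads, head_dim = 4, 2, 8

rope = VisionRotaryEmbedding(head_dim // 2)  # one frequency per channel pair
freqs = rope(seq_len)                        # shape [seq_len, head_dim // 2]

q = mx.random.normal((seq_len, num_heads, head_dim))
q_rot = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), freqs)[0]

# Position 0 has angle 0 (cos=1, sin=0), so its vectors pass through unchanged.
print(q_rot.shape)                  # (4, 2, 8)
print(mx.allclose(q_rot[0], q[0]))  # True
```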
mlx_vlm/models/fastvlm/__init__.py
@@ -0,0 +1,2 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .fastvlm import LanguageModel, Model, VisionModel
mlx_vlm/models/fastvlm/config.py
@@ -0,0 +1,79 @@
+import inspect
+from dataclasses import dataclass
+from typing import Dict, Optional, Union
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str
+    hidden_size: int = 896
+    num_hidden_layers: int = 24
+    intermediate_size: int = 4864
+    num_attention_heads: int = 14
+    rms_norm_eps: float = 1e-06
+    vocab_size: int = 151936
+    num_key_value_heads: int = 2
+    max_position_embeddings: int = 32768
+    rope_theta: float = 1000000
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = True
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str = "llava_qwen2"  # fastvlm?
+    hidden_size: int = 1024
+    intermediate_size: int = 3072
+    image_size: int = 1024
+    patch_size: int = 64
+    projection_dim: int = 768
+    num_classes = 1000
+    down_patch_size = 7
+    down_stride = 2
+    layer_scale_init_value = 1e-5
+    cls_ratio = 2.0
+    # FastViTHD variant
+    layers = [2, 12, 24, 4, 2]
+    embed_dims = [96, 192, 384, 768, 1536]
+    mlp_ratios = [4, 4, 4, 4, 4]
+    downsamples = [True, True, True, True, True]
+    pos_embs_shapes = [None, None, None, (7, 7), (7, 7)]
+    token_mixers = ("repmixer", "repmixer", "repmixer", "attention", "attention")
+    repmixer_kernel_size = 3
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig
+    vision_config: VisionConfig
+    model_type: str = "llava_qwen2"  # fastvlm?
+    ignore_index: int = -100
+    image_token_index: int = -200
+    eos_token_id: int = 151645
+    mm_projector_type: str = "mlp2x_gelu"
+    mm_hidden_size: int = 3072
+    tokenizer_model_max_length: int = 8192
+    tokenizer_padding_side: str = "right"
+
+    @classmethod
+    def from_dict(cls, params):
+        if not params.get("text_config", {}):
+            # Copy text config parameters from root level
+            excluded_keys = {"vision_config"}
+            params["text_config"] = dict(
+                filter(lambda x: x[0] not in excluded_keys, params.items())
+            )
+
+        if not params.get("vision_config", {}):
+            params["vision_config"] = {}
+
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
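For reference, the `from_dict` hook above lets FastVLM checkpoints keep their text-model fields at the top level of `config.json`: everything except `vision_config` is mirrored into `text_config` before filtering on the dataclass signature. A short usage sketch with an invented, abbreviated config dict (not a real checkpoint config):

```python
from mlx_vlm.models.fastvlm.config import ModelConfig

# Hypothetical flat config with no explicit "text_config" section.
params = {
    "model_type": "llava_qwen2",
    "hidden_size": 896,
    "num_hidden_layers": 24,
    "mm_hidden_size": 3072,
    "vision_config": {},
}

cfg = ModelConfig.from_dict(params)
print(cfg.text_config["hidden_size"])  # 896 (root-level fields were copied over)
print(cfg.mm_hidden_size)              # 3072
```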
mlx_vlm/models/fastvlm/fastvlm.py
@@ -0,0 +1,198 @@
+import re
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import InputEmbeddingsFeatures
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import CallableModuleList, VisionModel
+
+
+def build_vision_projector(config):
+    hidden_size = config.text_config.hidden_size
+    projector_type = getattr(config, "mm_projector_type", "mlp2x_gelu")
+    if projector_type == "linear":
+        return nn.Linear(config.mm_hidden_size, hidden_size)
+
+    mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
+    if mlp_gelu_match:
+        mlp_depth = int(mlp_gelu_match.group(1))
+        modules = CallableModuleList()
+        modules.append(nn.Linear(config.mm_hidden_size, hidden_size))
+        for _ in range(1, mlp_depth):
+            modules.append(nn.GELU())
+            modules.append(nn.Linear(hidden_size, hidden_size))
+        return modules
+    raise ValueError(f"Unknown projector type: {projector_type}")
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.vision_tower = VisionModel(config.vision_config)
+        self.language_model = LanguageModel(config.text_config)
+        self.mm_projector = build_vision_projector(config)
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        _, image_features, _ = self.vision_tower(pixel_values.transpose(0, 2, 3, 1))
+        B, H, W, C = image_features.shape
+        image_features = image_features.reshape(B, H * W, C)
+        image_features = self.mm_projector(image_features)
+
+        final_inputs_embeds = self.prepare_inputs_for_multimodal(
+            image_features, input_ids, mask
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    # Source: https://github.com/apple/ml-fastvlm/blob/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/model/llava_arch.py#L146
+    def prepare_inputs_for_multimodal(self, image_features, input_ids, mask):
+        if mask is not None:
+            input_ids = [
+                cur_input_ids[
+                    (start := mx.argmax(cur_mask).item()) : start
+                    + cur_mask.sum().item()
+                ]
+                for cur_input_ids, cur_mask in zip(input_ids, mask)
+            ]
+
+        new_input_embeds = []
+        cur_image_idx = 0
+        for batch_idx, cur_input_ids in enumerate(input_ids):
+            num_images = (cur_input_ids == self.config.image_token_index).sum()
+            if num_images == 0:
+                cur_image_features = image_features[cur_image_idx]
+                cur_input_embeds_1 = self.language_model.model.embed_tokens(
+                    cur_input_ids
+                )
+                cur_input_embeds = mx.concatenate(
+                    [cur_input_embeds_1, cur_image_features[0:0]], axis=0
+                )
+                new_input_embeds.append(cur_input_embeds)
+                cur_image_idx += 1
+                continue
+
+            image_token_indices = (
+                [-1]
+                + np.where(np.array(cur_input_ids == self.config.image_token_index))[
+                    0
+                ].tolist()
+                + [cur_input_ids.shape[0]]
+            )
+            cur_input_ids_noim = []
+            for i in range(len(image_token_indices) - 1):
+                cur_input_ids_noim.append(
+                    cur_input_ids[
+                        image_token_indices[i] + 1 : image_token_indices[i + 1]
+                    ]
+                )
+            split_sizes = image_token_indices[1:]
+            cur_input_embeds = self.language_model.model.embed_tokens(
+                mx.concatenate(cur_input_ids_noim)
+            )
+            cur_input_embeds_no_im = mx.split(cur_input_embeds, split_sizes)
+
+            cur_new_input_embeds = []
+            for i in range(num_images.item() + 1):
+                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
+                if i < num_images:
+                    cur_image_features = image_features[cur_image_idx]
+                    cur_image_idx += 1
+                    cur_new_input_embeds.append(cur_image_features)
+            cur_new_input_embeds = mx.concatenate(cur_new_input_embeds)
+
+            new_input_embeds.append(cur_new_input_embeds)
+
+        if self.config.tokenizer_model_max_length is not None:
+            new_input_embeds = [
+                x[: self.config.tokenizer_model_max_length] for x in new_input_embeds
+            ]
+
+        max_len = max(x.shape[0] for x in new_input_embeds)
+        new_input_embeds_padded = []
+        for i, cur_new_embed in enumerate(new_input_embeds):
+            cur_len = cur_new_embed.shape[0]
+            padded = cur_new_embed
+            if max_len > cur_len:
+                if self.config.tokenizer_padding_side == "left":
+                    padded = mx.concatenate(
+                        (
+                            mx.zeros(
+                                (max_len - cur_len, cur_new_embed.shape[1]),
+                                dtype=cur_new_embed.dtype,
+                            ),
+                            cur_new_embed,
+                        ),
+                        axis=0,
+                    )
+                else:
+                    padded = mx.concatenate(
+                        (
+                            cur_new_embed,
+                            mx.zeros(
+                                (max_len - cur_len, cur_new_embed.shape[1]),
+                                dtype=cur_new_embed.dtype,
+                            ),
+                        ),
+                        axis=0,
+                    )
+            new_input_embeds_padded.append(padded)
+        new_input_embeds = mx.stack(new_input_embeds_padded)
+        return new_input_embeds
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
+    ):
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, mask
+        )
+        logits = self.language_model(
+            input_ids,
+            mask=mask,
+            cache=cache,
+            inputs_embeds=input_embeddings_features.inputs_embeds,
+        )
+        return logits
+
+    def sanitize(self, weights):
+        def transform_key(key):
+            if "vision_tower" in key:
+                if "model.vision_tower" in key:
+                    key = key.replace(
+                        "model.vision_tower.vision_tower.model",
+                        "vision_tower.vision_model",
+                    )
+                    key = key.replace("patch_embed", "patch_embed.blocks")
+                return key
+            if "lm_head" in key:
+                return key
+            if "mm_projector" in key:
+                return key.replace("model.", "")
+            if "language_model" not in key:
+                return "language_model." + key
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}
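`prepare_inputs_for_multimodal` follows the upstream LLaVA recipe linked in the comment: occurrences of `image_token_index` (-200) are cut out of the prompt, the surrounding text spans are embedded, and the projected image features are spliced into the gaps before padding to a common length. Below is a toy sketch of just the splitting step, using invented token ids and plain NumPy rather than the model itself:

```python
import numpy as np

IMAGE_TOKEN_INDEX = -200  # same placeholder id as ModelConfig.image_token_index

# Hypothetical prompt: text tokens with a single image placeholder in the middle.
input_ids = np.array([101, 2023, IMAGE_TOKEN_INDEX, 3004, 102])

# Same bookkeeping as above: sentinel indices at both ends, then slice between them.
image_token_indices = (
    [-1] + np.where(input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [len(input_ids)]
)
text_spans = [
    input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
    for i in range(len(image_token_indices) - 1)
]
print(text_spans)  # two text spans: [101, 2023] and [3004, 102]
# One image's projected features are inserted into each gap between consecutive spans.
```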
mlx_vlm/models/fastvlm/language.py
@@ -0,0 +1,49 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.qwen2 import Qwen2Model
+
+from ..base import LanguageModelOutput
+from .config import TextConfig
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.model = Qwen2Model(config)
+        if not config.tie_word_embeddings:
+            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    # TODO: mask is going away in mlx-lm, see https://github.com/ml-explore/mlx-lm/pull/430
+    def __call__(
+        self,
+        inputs: mx.array,
+        mask: mx.array = None,
+        cache=None,
+        inputs_embeds: Optional[mx.array] = None,
+    ):
+        out = self.model(inputs, cache=cache, input_embeddings=inputs_embeds)
+        out = self.model.embed_tokens.as_linear(out)
+        return LanguageModelOutput(out)
+
+    def sanitize(self, weights):
+        if self.config.tie_word_embeddings:
+            weights.pop("lm_head.weight", None)
+        return {
+            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+        }
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.config.hidden_size // self.config.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.config.num_key_value_heads
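FastVLM's Qwen2-based text config defaults to `tie_word_embeddings=True`, so the wrapper above reuses the input embedding matrix as the output head via `embed_tokens.as_linear`, and `sanitize` drops any stray `lm_head.weight` in that case. A minimal standalone sketch of that weight-tying pattern with a bare MLX embedding, independent of the package classes:

```python
import mlx.core as mx
import mlx.nn as nn

vocab_size, hidden_size = 16, 8
embed_tokens = nn.Embedding(vocab_size, hidden_size)

tokens = mx.array([[1, 2, 3]])
hidden = embed_tokens(tokens)            # [1, 3, hidden_size] token embeddings
logits = embed_tokens.as_linear(hidden)  # same weight reused as the output projection

print(logits.shape)  # (1, 3, 16)
```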