fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/deepseekocr_2/vision.py
@@ -0,0 +1,439 @@
+from typing import Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import Qwen2EncoderConfig, VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class Qwen2RMSNorm(nn.Module):
+    """RMSNorm for Qwen2 encoder."""
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = mx.ones((hidden_size,))
+        self.variance_epsilon = eps
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.astype(mx.float32)
+        variance = mx.mean(hidden_states**2, axis=-1, keepdims=True)
+        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.astype(input_dtype)
+
+
+class Qwen2RotaryEmbedding(nn.Module):
+    """Rotary position embeddings for Qwen2."""
+
+    def __init__(
+        self, dim: int, max_position_embeddings: int = 2048, base: float = 1000000.0
+    ):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        # Note: inv_freq is computed on-the-fly, not stored as a parameter
+
+    def __call__(
+        self, x: mx.array, position_ids: mx.array
+    ) -> Tuple[mx.array, mx.array]:
+        # Compute inv_freq on the fly (not stored as parameter)
+        inv_freq = 1.0 / (
+            self.base ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+        )
+
+        # position_ids: [batch_size, seq_len]
+        # inv_freq: [head_dim // 2]
+        # We want freqs of shape [batch_size, seq_len, head_dim // 2]
+
+        # Outer product: position_ids[:, :, None] * inv_freq[None, None, :]
+        position_ids_float = position_ids[:, :, None].astype(mx.float32)  # [B, S, 1]
+        inv_freq_expanded = inv_freq[None, None, :]  # [1, 1, D//2]
+        freqs = position_ids_float * inv_freq_expanded  # [B, S, D//2]
+
+        emb = mx.concatenate([freqs, freqs], axis=-1)  # [B, S, D]
+        cos = mx.cos(emb)
+        sin = mx.sin(emb)
+        return cos.astype(x.dtype), sin.astype(x.dtype)
+
+
+def rotate_half(x: mx.array) -> mx.array:
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb(
+    q: mx.array, k: mx.array, cos: mx.array, sin: mx.array
+) -> Tuple[mx.array, mx.array]:
+    """Apply rotary position embedding to query and key tensors."""
+    cos = cos[:, None, :, :]
+    sin = sin[:, None, :, :]
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class Qwen2MLP(nn.Module):
+    """MLP for Qwen2 encoder."""
+
+    def __init__(self, config: Qwen2EncoderConfig):
+        super().__init__()
+        self.hidden_size = config.dim
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Qwen2Attention(nn.Module):
+    """Multi-head attention for Qwen2 encoder with GQA support."""
+
+    def __init__(self, config: Qwen2EncoderConfig, layer_idx: int = 0):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+
+        self.hidden_size = config.dim
+        self.num_heads = config.heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.kv_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+        self.q_proj = nn.Linear(
+            self.hidden_size, self.num_heads * self.head_dim, bias=True
+        )
+        self.k_proj = nn.Linear(
+            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
+        )
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.hidden_size, bias=False
+        )
+
+        self.rotary_emb = Qwen2RotaryEmbedding(
+            self.head_dim,
+            max_position_embeddings=2048,
+            base=config.rope_theta,
+        )
+        self.scale = self.head_dim**-0.5
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        attention_mask: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        bsz, q_len, _ = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.reshape(
+            bsz, q_len, self.num_heads, self.head_dim
+        ).transpose(0, 2, 1, 3)
+        key_states = key_states.reshape(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(0, 2, 1, 3)
+        value_states = value_states.reshape(
+            bsz, q_len, self.num_key_value_heads, self.head_dim
+        ).transpose(0, 2, 1, 3)
+
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
+
+        # Repeat KV heads for GQA
+        if self.num_key_value_groups > 1:
+            key_states = mx.repeat(key_states, self.num_key_value_groups, axis=1)
+            value_states = mx.repeat(value_states, self.num_key_value_groups, axis=1)
+
+        attn_output = mx.fast.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            scale=self.scale,
+            mask=attention_mask,
+        )
+
+        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output
+
+
+class Qwen2DecoderLayer(nn.Module):
+    """Transformer layer for Qwen2 encoder."""
+
+    def __init__(self, config: Qwen2EncoderConfig, layer_idx: int = 0):
+        super().__init__()
+        self.hidden_size = config.dim
+        self.self_attn = Qwen2Attention(config, layer_idx)
+        self.mlp = Qwen2MLP(config)
+        self.input_layernorm = Qwen2RMSNorm(config.dim, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = Qwen2RMSNorm(
+            config.dim, eps=config.rms_norm_eps
+        )
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        attention_mask: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class Qwen2Decoder2Encoder(nn.Module):
+    """Qwen2-based decoder used as encoder for vision features.
+
+    Takes SAM features and processes them through Qwen2 transformer layers
+    using learnable queries to produce fixed-size output.
+    """
+
+    def __init__(self, config: Qwen2EncoderConfig):
+        super().__init__()
+        self.config = config
+
+        # Learnable queries for cross-attention
+        # query_1024: (256, dim) - for 1024x1024 images (SAM outputs 16x16=256 features)
+        # query_768: (144, dim) - for 768x768 images (SAM outputs 12x12=144 features)
+        # Initialized with zeros, will be loaded from weights
+        self.query_1024 = mx.zeros((256, config.dim))
+        self.query_768 = mx.zeros((144, config.dim))
+
+        # Transformer layers
+        self.layers = [
+            Qwen2DecoderLayer(config, layer_idx=i) for i in range(config.layers)
+        ]
+
+        # Final layer norm
+        self.norm = Qwen2RMSNorm(config.dim, eps=config.rms_norm_eps)
+
+    def __call__(self, sam_features: mx.array) -> mx.array:
+        """Process SAM features through Qwen2 encoder.
+
+        Args:
+            sam_features: SAM encoder output of shape (B, H, W, C) where C=896
+
+        Returns:
+            Encoded features of shape (B, 256, C)
+        """
+        batch_size = sam_features.shape[0]
+
+        # Flatten spatial dimensions: (B, H, W, C) -> (B, H*W, C)
+        sam_features_flat = sam_features.reshape(batch_size, -1, self.config.dim)
+        num_image_tokens = sam_features_flat.shape[1]
+
+        # Select appropriate query based on number of image tokens
+        # 256 tokens -> use query_1024 (for 1024x1024 images, SAM outputs 16x16)
+        # 144 tokens -> use query_768 (for 768x768 images, SAM outputs 12x12)
+        if num_image_tokens == 256:
+            query_embed = self.query_1024
+            num_queries = 256
+        elif num_image_tokens == 144:
+            query_embed = self.query_768
+            num_queries = 144
+        else:
+            # Default to query_1024 for unexpected sizes
+            query_embed = self.query_1024
+            num_queries = 256
+
+        # Expand queries for batch: (num_queries, C) -> (B, num_queries, C)
+        queries = mx.broadcast_to(
+            query_embed[None, :, :], (batch_size, num_queries, self.config.dim)
+        )
+
+        # Concatenate: image tokens + query tokens
+        # Shape: (B, num_image_tokens + num_queries, C)
+        hidden_states = mx.concatenate([sam_features_flat, queries], axis=1)
+        seq_len = hidden_states.shape[1]
+
+        # Create mixed attention mask:
+        # - Image tokens can attend to all image tokens (bidirectional)
+        # - Image tokens CANNOT attend to query tokens (blocked)
+        # - Query tokens can attend to all image tokens
+        # - Query tokens use causal attention within queries (can attend to self + previous)
+        # Shape: (1, 1, seq_len, seq_len) - will be broadcast across batch and heads
+        mask_dtype = hidden_states.dtype
+
+        # Start with all positions blocked (large negative value)
+        mask = mx.full((seq_len, seq_len), -1e9, dtype=mx.float32)
+
+        # 1. Image tokens can attend to all image tokens (bidirectional)
+        # mask[0:num_image_tokens, 0:num_image_tokens] = 0
+        image_to_image = mx.zeros(
+            (num_image_tokens, num_image_tokens), dtype=mx.float32
+        )
+        mask = mx.concatenate(
+            [
+                mx.concatenate(
+                    [image_to_image, mask[:num_image_tokens, num_image_tokens:]], axis=1
+                ),
+                mask[num_image_tokens:, :],
+            ],
+            axis=0,
+        )
+
+        # 2. Query tokens can attend to all image tokens
+        # mask[num_image_tokens:, 0:num_image_tokens] = 0
+        query_to_image = mx.zeros((num_queries, num_image_tokens), dtype=mx.float32)
+        mask = mx.concatenate(
+            [
+                mask[:num_image_tokens, :],
+                mx.concatenate(
+                    [query_to_image, mask[num_image_tokens:, num_image_tokens:]], axis=1
+                ),
+            ],
+            axis=0,
+        )
+
+        # 3. Query tokens use causal attention (can attend to self + previous queries)
+        # Create lower triangular mask for query-query region
+        query_causal = mx.tril(mx.zeros((num_queries, num_queries), dtype=mx.float32))
+        query_causal = query_causal + mx.triu(
+            mx.full((num_queries, num_queries), -1e9, dtype=mx.float32), k=1
+        )
+
+        # Update query-query region in mask
+        mask = mx.concatenate(
+            [
+                mask[:, :num_image_tokens],
+                mx.concatenate(
+                    [mask[:num_image_tokens, num_image_tokens:], query_causal], axis=0
+                ),
+            ],
+            axis=1,
+        )
+
+        # Cast to input dtype and reshape for attention: (1, 1, seq_len, seq_len)
+        attention_mask = mask.astype(mask_dtype)[None, None, :, :]
+
+        # Create position IDs
+        position_ids = mx.broadcast_to(
+            mx.arange(seq_len)[None, :], (batch_size, seq_len)
+        )
+
+        # Process through transformer layers
+        for layer in self.layers:
+            hidden_states = layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+            )
+
+        # Apply final layer norm
+        hidden_states = self.norm(hidden_states)
+
+        # Return only the query tokens (last num_queries tokens)
+        return hidden_states[:, -num_queries:, :]
+
+
+class VisionModel(nn.Module):
+    """Vision model for DeepSeek-OCR-2 using Qwen2 encoder."""
+
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        self.config = config
+
+        if self.model_type != "vision":
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        # Get Qwen2 config from params
+        qwen2_params = config.params.get("qwen2", {})
+        qwen2_config = Qwen2EncoderConfig(
+            dim=qwen2_params.get("dim", 896),
+            layers=qwen2_params.get("layers", 24),
+            heads=qwen2_params.get("heads", 14),
+            kv_heads=qwen2_params.get("kv_heads", 2),
+            intermediate_size=qwen2_params.get("intermediate_size", 4864),
+            rms_norm_eps=qwen2_params.get("rms_norm_eps", 1e-6),
+            rope_theta=qwen2_params.get("rope_theta", 1000000.0),
+        )
+
+        self.qwen2_encoder = Qwen2Decoder2Encoder(qwen2_config)
+
+    def __call__(self, x: mx.array, sam_features: mx.array) -> mx.array:
+        """Process vision input through Qwen2 encoder.
+
+        Args:
+            x: Original image tensor (not used, kept for API compatibility)
+            sam_features: SAM encoder output of shape (B, H, W, C)
+
+        Returns:
+            Encoded features of shape (B, 256, C)
+        """
+        return self.qwen2_encoder(sam_features)
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        weight_keys = {
+            "neck.0.weight",
+            "neck.2.weight",
+            "neck_hd.0.weight",
+            "neck_hd.2.weight",
+            "sam_model.net_2.weight",
+            "sam_model.net_3.weight",
+            "downsamples.0.weight",
+            "downsamples.1.weight",
+            "patch_embed.proj.weight",
+            "embeddings.patch_embedding.weight",
+        }
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+
+            elif ".".join(k.split(".")[-3:]) in weight_keys:
+                # PyTorch conv2d weight tensors have shape:
+                #     [out_channels, in_channels, kH, KW]
+                # MLX conv2d expects the weight to be of shape:
+                #     [out_channels, kH, KW, in_channels]
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
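Editor's note: the sanitize hook at the end of this module is what lets converted checkpoints load, since PyTorch stores conv2d weights as [out_channels, in_channels, kH, kW] while MLX expects [out_channels, kH, kW, in_channels], and check_array_shape is used to skip tensors that are already in MLX layout. A minimal illustrative sketch of that conversion (not part of the package; the tensor below is a placeholder, not a real weight):

import mlx.core as mx

# PyTorch-style conv2d weight: [out_channels, in_channels, kH, kW]
w_torch = mx.zeros((64, 3, 7, 7))

# MLX-style conv2d weight: [out_channels, kH, kW, in_channels]
w_mlx = w_torch.transpose(0, 2, 3, 1)
print(w_mlx.shape)  # (64, 7, 7, 3)

# For the converted layout (64 >= 7 and kH == kW), check_array_shape above
# returns True, so sanitize leaves such tensors untouched instead of
# transposing them a second time.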
mlx_vlm/models/ernie4_5_moe_vl/__init__.py
@@ -0,0 +1,5 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .ernie4_5_moe_vl import Model, VariableResolutionResamplerModel
+from .language import LanguageModel
+from .processor import Ernie4_5_VLProcessor, Ernie4_5_VLTokenizer, ImageProcessor
+from .vision import VisionModel
mlx_vlm/models/ernie4_5_moe_vl/config.py
@@ -0,0 +1,139 @@
+import inspect
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    """DFNRopeVisionTransformer configuration."""
+
+    model_type: str = "DFNRope_vision_transformer"
+    depth: int = 32
+    embed_dim: int = 1280
+    hidden_size: int = 3584  # This should match embed_dim for DFNRope
+    hidden_act: str = "quick_gelu"
+    mlp_ratio: float = 4.0
+    num_heads: int = 16
+    in_channels: int = 3
+    patch_size: int = 14
+    spatial_merge_size: int = 2
+    layer_norm_eps: float = 1e-6
+
+    def __post_init__(self):
+        # hidden_size should equal embed_dim for this architecture
+        if self.hidden_size != self.embed_dim:
+            self.hidden_size = self.embed_dim
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    hidden_size: int = 3584
+    intermediate_size: int = 18944
+    model_type: str = "ernie"
+    max_position_embeddings: int = 131072
+    num_attention_heads: int = 28
+    num_key_value_heads: int = 4
+    num_hidden_layers: int = 56
+    rms_norm_eps: float = 1e-6
+    vocab_size: int = 151936
+    rope_theta: float = 1000000.0
+    use_bias: bool = False
+    tie_word_embeddings: bool = False
+    compression_ratio: float = 1.0
+    # MoE config
+    moe_num_experts: Union[int, List[int]] = 128
+    moe_layer_start_index: Union[int, List[int]] = 3
+    moe_layer_end_index: Optional[Union[int, List[int]]] = 53
+    moe_intermediate_size: Union[int, List[int]] = 1408
+    moe_capacity: List[float] = field(default_factory=lambda: [1.2, 2.0, 2.0])
+    moe_k: int = 2
+    moe_layer_interval: int = 1
+    moe_use_aux_free: bool = True
+    moe_num_shared_experts: int = 0
+    moe_gate_act: str = "softmax"
+    moe_norm_gate_logits: bool = True
+    head_dim: Optional[int] = None
+    # 3D RoPE config
+    rope_3d: bool = True
+    freq_allocation: int = 20
+    mrope_section: List[int] = field(default_factory=lambda: [22, 22, 20])
+    rope_scaling: Optional[Dict[str, Union[str, List[int]]]] = None
+    rope_parameters: Optional[Dict[str, Union[str, float, List[int]]]] = None
+    moe_norm_min: float = 1e-12
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.num_attention_heads
+        # Normalize rope_scaling keys
+        if self.rope_scaling:
+            if "type" not in self.rope_scaling and "rope_type" in self.rope_scaling:
+                self.rope_scaling["type"] = self.rope_scaling.pop("rope_type")
+            # Extract mrope_section from rope_scaling if present
+            if "mrope_section" in self.rope_scaling:
+                self.mrope_section = list(self.rope_scaling["mrope_section"])
+        # Also check rope_parameters (HuggingFace format)
+        if self.rope_parameters:
+            if "mrope_section" in self.rope_parameters:
+                self.mrope_section = list(self.rope_parameters["mrope_section"])
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig = None
+    vision_config: VisionConfig = None
+    model_type: str = "ernie4_5_moe_vl"
+    ignore_index: int = -100
+    # Token IDs (defaults will be overridden by from_dict / __post_init__)
+    im_patch_id: int = 100295
+    image_token_id: int = 100295
+    image_start_token_id: int = 101304
+    image_end_token_id: int = 101305
+    video_token_id: int = 100295
+    video_start_token_id: int = 101306
+    video_end_token_id: int = 101307
+    vision_start_token_id: int = 101304
+    vision_end_token_id: int = 101305
+    vision_token_id: int = 100295
+    vocab_size: int = 103424
+    eos_token_id: Optional[List[int]] = None
+    # Vision-language integration
+    pixel_hidden_size: int = 1280
+    hidden_size: int = 2560
+    # Resampler config
+    spatial_conv_size: int = 2
+    temporal_conv_size: int = 2
+    use_temporal_conv: bool = True
+    # 3D RoPE config
+    rope_3d: bool = True
+    freq_allocation: int = 20
+
+    def __post_init__(self):
+        # Derive image_token_id from im_patch_id if not explicitly set differently
+        if self.image_token_id != self.im_patch_id:
+            self.image_token_id = self.im_patch_id
+        # vision_start/end should match image_start/end
+        if self.vision_start_token_id != self.image_start_token_id:
+            self.vision_start_token_id = self.image_start_token_id
+        if self.vision_end_token_id != self.image_end_token_id:
+            self.vision_end_token_id = self.image_end_token_id
+
+    @classmethod
+    def from_dict(cls, params):
+        # Copy text config parameters from root level (like qwen2_vl does)
+        # This ensures update_module_configs works correctly
+        excluded_keys = {"vision_config"}
+        params["text_config"] = dict(
+            filter(lambda x: x[0] not in excluded_keys, params.items())
+        )
+
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
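Editor's note on the from_dict override above: per its own comments, ERNIE 4.5 VL keeps the text-model hyperparameters at the root of the config rather than under a nested text_config, so the classmethod first copies every root-level key except "vision_config" into params["text_config"] and then filters the remainder down to the declared dataclass fields. A hypothetical usage sketch (the dictionary values are invented for illustration, and the import path assumes the mlx_vlm layout shown in the file list):

from mlx_vlm.models.ernie4_5_moe_vl.config import ModelConfig

raw = {
    "model_type": "ernie4_5_moe_vl",
    "hidden_size": 2560,
    "num_hidden_layers": 28,  # text-model field living at the root
    "vision_config": {"depth": 32, "embed_dim": 1280},
}

cfg = ModelConfig.from_dict(raw)
# raw["text_config"] now mirrors the root-level keys (minus "vision_config"),
# while only keys matching ModelConfig's dataclass fields reach the constructor,
# so "num_hidden_layers" survives inside text_config but is dropped from
# ModelConfig itself.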