fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/multi_modality/vision.py
@@ -0,0 +1,450 @@
+ import copy
+ from functools import partial
+ from math import sqrt
+ from typing import Dict, Optional, Union
+
+ import cv2
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from .config import MLPConfig, VisionConfig
+ from .sam import SAMEncoder
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class AttentionPoolLatent(nn.Module):
+     """Attention pooling w/ latent query"""
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int = None,
+         embed_dim: int = None,
+         num_heads: int = 8,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = True,
+         qk_norm: bool = False,
+         latent_len: int = 1,
+         latent_dim: int = None,
+         pos_embed: str = "",
+         pool_type: str = "token",
+         norm_layer: Optional[nn.Module] = None,
+         drop: float = 0.0,
+     ):
+         super().__init__()
+
+         embed_dim = embed_dim or in_features
+         out_features = out_features or in_features
+         assert embed_dim % num_heads == 0
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scale = self.head_dim**-0.5
+         self.pool = pool_type
+
+         self.latent_dim = latent_dim or embed_dim
+         self.latent_len = latent_len
+         self.latent = mx.zeros((self.latent_len, embed_dim))[None, :]
+
+         self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+         self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.proj_drop = nn.Dropout(drop)
+
+         if pos_embed == "abs":
+             spatial_len = self.feat_size
+             self.pos_embed = mx.zeros((spatial_len, in_features))
+         else:
+             self.pos_embed = None
+
+         self.norm = nn.LayerNorm(out_features)
+         config = MLPConfig(
+             hidden_size=embed_dim, intermediate_size=int(embed_dim * mlp_ratio)
+         )
+         self.mlp = MLP(config)
+
+     def __call__(self, x: mx.array):
+         B, N, C = x.shape
+
+         if self.pos_embed is not None:
+             # mx.array has no unsqueeze()/to(); add a batch dim and match dtype directly
+             x = x + self.pos_embed[None].astype(x.dtype)
+
+         q_latent = mx.array(self.latent)
+
+         q = (
+             self.q(q_latent)
+             .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+             .transpose(0, 2, 1, 3)
+         )
+
+         kv = (
+             self.kv(x)
+             .reshape(B, N, 2, self.num_heads, self.head_dim)
+             .transpose(2, 0, 3, 1, 4)
+         )
+         k, v = mx.split(kv, 2, axis=0)
+
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         x = mx.fast.scaled_dot_product_attention(
+             q, k[0], v[0], scale=(1.0 / sqrt(q.shape[-1])), mask=None
+         )
+
+         x = x.transpose(0, 2, 1, 3).reshape(B, self.latent_len, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         x = x + self.mlp(self.norm(x))
+
+         # optional pool if latent seq_len > 1 and pooled output is desired
+         if self.pool == "token":
+             x = x[:, 0]
+         elif self.pool == "avg":
+             x = x.mean(1)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dims: int,
+         num_heads: int,
+         qkv_bias: bool = False,
+     ):
+         super().__init__()
+
+         if (dims % num_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({dims} % {num_heads}) != 0"
+             )
+
+         self.num_heads = num_heads
+         head_dim = dims // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dims, dims * 3, bias=qkv_bias)
+         self.proj = nn.Linear(dims, dims, bias=True)
+
+     def __call__(self, x, mask=None):
+         qkv = self.qkv(x)
+         queries, keys, values = mx.split(qkv, 3, axis=-1)
+
+         num_heads = self.num_heads
+         B, L, D = queries.shape
+         _, S, _ = keys.shape
+         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+         return self.proj(output)
+
+
+ class FastGELUActivation(nn.Module):
+     """
+     Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
+     """
+
+     def __call__(self, input: mx.array) -> mx.array:
+         return (
+             0.5
+             * input
+             * (1.0 + mx.tanh(np.sqrt(2 / np.pi) * (input + 0.044715 * (input**3))))
+         ).astype(input.dtype)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: Union[VisionConfig, Dict], bias: bool = True):
+         super().__init__()
+         self.activation_fn = FastGELUActivation()
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=bias)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=bias)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.activation_fn(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.attn = Attention(
+             config.hidden_size, config.num_attention_heads, qkv_bias=True
+         )
+         self.norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = MLP(config)
+         self.norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         y = self.norm1(x)
+         y = self.attn(y, mask)
+         x = x + y
+         y = self.norm2(x)
+         y = self.mlp(y)
+         return x + y
+
+
+ class VisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig, norm_layer: bool = False):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.proj = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+
+         self.norm = (
+             nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+             if norm_layer
+             else nn.Identity()
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         patch_embeddings = self.proj(x)
+         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+         return self.norm(patch_embeddings)
+
+
+ class SigLipVisionModel(nn.Module):
+     def __init__(
+         self,
+         config: VisionConfig,
+         ignore_head: bool,
+         pre_norm: bool = False,
+         no_embed_class: bool = True,
+     ):
+         super().__init__()
+         self.num_prefix_tokens = 1
+         self.no_embed_class = False
+         self.dynamic_img_size = False
+         self.ignore_head = ignore_head
+         self.cls_token = None
+         self.reg_token = None
+         self.patch_embed = VisionEmbeddings(config)
+         self.norm_pre = nn.LayerNorm(config.hidden_size) if pre_norm else nn.Identity()
+         self.blocks = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+         self.norm = nn.LayerNorm(config.hidden_size)
+         num_patches = self.patch_embed.num_patches
+         embed_len = (
+             num_patches if no_embed_class else num_patches + self.num_prefix_tokens
+         )
+         self.pos_embed = mx.random.normal((embed_len, config.hidden_size))[None, :]
+
+         norm_layer = partial(nn.LayerNorm, eps=1e-5)
+         self.attn_pool = AttentionPoolLatent(
+             config.hidden_size,
+             num_heads=config.num_attention_heads,
+             norm_layer=norm_layer,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+         x = self.patch_embed(x)
+         x += self.pos_embed
+         x = self.norm_pre(x)
+
+         encoder_states = (x,) if output_hidden_states else None
+         for block in self.blocks:
+             x = block(x, mask=None)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+         pooler_output = self.norm(x)
+
+         if not self.ignore_head:
+             pooler_output = self.attn_pool(pooler_output)
+         return pooler_output, x, encoder_states
+
+
+ class HybridVisionModel(nn.Module):
+     def __init__(self, config: VisionConfig, resolution: str, ignore_head: bool = True):
+         super().__init__()
+
+         self.model_type = config.model_type
+         self.resolution = resolution
+         if self.model_type != "vision":
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         if resolution == "high":
+             self.vision_tower = SAMEncoder()
+         else:
+             self.vision_tower = SigLipVisionModel(config, ignore_head)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         if self.resolution == "high":
+             return self.vision_tower(x)
+         else:
+             return self.vision_tower(x)[0]
+
+
+ def resize_image(image, size, antialias=True):
+     """
+     Resize an image with OpenCV.
+
+     Args:
+         image (numpy.ndarray): The input image array. Supports H × W or H × W × C.
+             If you pass in a batch (N × H × W × C), just slice the element you
+             want, e.g. image[0]; a stray leading batch dim of size 1 is squeezed.
+         size (tuple or float): Target size as (width, height) in pixels — the same
+             order that cv2.resize expects — or a scale factor (a scalar, or a pair
+             of small values treated as per-axis factors).
+         antialias (bool):
+             * True → Lanczos interpolation (high quality)
+             * False → nearest-neighbor (fast, blocky)
+
+     Returns:
+         mlx.core.array: The resized image.
+     """
+     img = np.ascontiguousarray(np.asarray(image))
+     if img.ndim == 4 and img.shape[0] == 1:  # squeeze stray batch dim
+         img = img[0]
+     h0, w0 = img.shape[:2]
+
+     # --- work out dsize vs fx/fy ---------------------------------------------
+     dsize = None
+     fx = fy = 0.0
+     if isinstance(size, (int, float)):  # uniform scale
+         fx = fy = float(size)
+     elif isinstance(size, (tuple, list)) and len(size) == 2:
+         a, b = size
+         # Heuristic: treat "small" values as scale factors
+         if all(isinstance(x, (int, float)) and x < 10 for x in (a, b)):
+             fx, fy = float(a), float(b)  # scale factors
+         else:
+             dsize = (int(a), int(b))  # absolute pixels
+     else:
+         raise ValueError("size must be a scalar or a 2-tuple")
+
+     # Guard against zeros after int-casting
+     if dsize:
+         if dsize[0] <= 0 or dsize[1] <= 0:
+             raise ValueError(f"dsize became {dsize}")
+     else:
+         if fx <= 0 or fy <= 0:
+             raise ValueError(f"fx,fy became {(fx, fy)}")
+
+     # --- choose interpolation -------------------------------------------------
+     if antialias:
+         # Use Lanczos interpolation for potentially better detail preservation
+         interp = cv2.INTER_LANCZOS4
+     else:
+         interp = cv2.INTER_NEAREST
+
+     # --- call OpenCV ----------------------------------------------------------
+     return mx.array(cv2.resize(img, dsize=dsize, fx=fx, fy=fy, interpolation=interp))
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig, ignore_head: bool = True):
+         super().__init__()
+
+         self.model_type = config.model_type
+         self.config = config
+         if self.model_type != "vision":
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         if config.cls == "HybridVisionTower":
+             self.high_layer_norm = nn.LayerNorm(
+                 config.params["high_res_cfg"]["output_dim"]
+             )
+             self.low_layer_norm = nn.LayerNorm(
+                 config.params["low_res_cfg"]["output_dim"]
+             )
+
+             high_res_cfg = copy.deepcopy(config)
+             high_res_cfg.image_size = config.params["high_res_cfg"]["image_size"]
+             self.vision_tower_high = HybridVisionModel(
+                 high_res_cfg, "high", ignore_head
+             )
+
+             low_res_cfg = copy.deepcopy(config)
+             low_res_cfg.image_size = config.params["low_res_cfg"]["image_size"]
+
+             self.vision_tower_low = HybridVisionModel(low_res_cfg, "low", ignore_head)
+             self.low_res_size = config.params["low_res_cfg"]["image_size"]
+             self.resize = lambda image: resize_image(
+                 image, (self.low_res_size, self.low_res_size), antialias=True
+             )
+
+         else:
+             self.vision_tower = SigLipVisionModel(config, ignore_head)
+
+     def __call__(
+         self, x: mx.array, output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         if self.config.cls == "HybridVisionTower":
+             high_images = x
+             low_images = mx.array(self.resize(np.array(x)))[None, :]
+
+             high_res = self.vision_tower_high(high_images)
+             low_res = self.vision_tower_low(low_images)
+
+             return (high_res, low_res)
+         else:
+             return self.vision_tower(x, output_hidden_states)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         weight_keys = {
+             "neck.0.weight",
+             "neck.2.weight",
+             "neck_hd.0.weight",
+             "neck_hd.2.weight",
+             "downsamples.0.weight",
+             "downsamples.1.weight",
+             "patch_embed.proj.weight",
+         }
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+
+             elif ".".join(k.split(".")[-3:]) in weight_keys:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, kW]
+                 # MLX conv2d expects the weight to be of shape:
+                 # [out_channels, kH, kW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
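
As an aside on the sanitize pass that closes this file: below is a minimal sketch of the layout conversion it performs, using a dummy weight tensor (the shapes are illustrative, not taken from the package).

import mlx.core as mx

# PyTorch-style conv2d weight: [out_channels, in_channels, kH, kW]
w_torch = mx.zeros((1152, 3, 14, 14))
# check_array_shape() returns False for this layout (the unpacked kH and kW differ),
# so sanitize() transposes it into the [out_channels, kH, kW, in_channels] layout
# that MLX's nn.Conv2d expects.
w_mlx = w_torch.transpose(0, 2, 3, 1)
assert w_mlx.shape == (1152, 14, 14, 3)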
mlx_vlm/models/paddleocr_vl/__init__.py
@@ -0,0 +1,3 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .paddleocr_vl import LanguageModel, Model, VisionModel
+ from .processing_paddleocr_vl import PaddleOCRVLProcessor
mlx_vlm/models/paddleocr_vl/config.py
@@ -0,0 +1,93 @@
+ import inspect
+ from dataclasses import dataclass, field
+ from typing import Dict, Optional
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str = "paddleocr_vl"
+     hidden_size: int = 1152
+     intermediate_size: int = 4304
+     num_hidden_layers: int = 27
+     num_attention_heads: int = 16
+     num_channels: int = 3
+     image_size: int = 384
+     patch_size: int = 14
+     hidden_act: str = "gelu_pytorch_tanh"
+     layer_norm_eps: float = 1e-6
+     attention_dropout: float = 0.0
+     spatial_merge_size: int = 2
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str = "paddleocr_vl"
+     hidden_size: int = 1024
+     num_hidden_layers: int = 18
+     intermediate_size: int = 3072
+     num_attention_heads: int = 16
+     rms_norm_eps: float = 1e-05
+     vocab_size: int = 103424
+     num_key_value_heads: Optional[int] = 2
+     max_position_embeddings: Optional[int] = 131072
+     rope_theta: float = 500000.0
+     rope_traditional: bool = False
+     use_cache: bool = True
+     hidden_act: str = "silu"
+     pad_token_id: int = 0
+     bos_token_id: int = 1
+     eos_token_id: int = 2
+     use_bias: bool = False
+     head_dim: int = 128
+     rope_parameters: Optional[Dict] = None
+     rope_scaling: Dict = field(
+         default_factory=lambda: {
+             "rope_type": "default",
+             "type": "default",
+             "mrope_section": [16, 24, 24],
+         }
+     )
+
+     def __post_init__(self):
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+         if self.rope_scaling:
+             required_keys = {"mrope_section", "type"}
+             if not all(key in self.rope_scaling for key in required_keys):
+                 raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+             if self.rope_scaling["type"] not in ("mrope", "default"):
+                 raise ValueError("rope_scaling type must be 'mrope' or 'default'")
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str = "paddleocr_vl"
+     ignore_index: int = -100
+     image_token_id: int = 100295
+     video_token_id: int = 100296
+     vision_start_token_id: int = 101305
+     vision_end_token_id: int = 101306
+     eos_token_id: int = 2
+
+     @classmethod
+     def from_dict(cls, params):
+         # Copy text config parameters from root level
+         excluded_keys = {"vision_config"}
+         params["text_config"] = dict(
+             filter(lambda x: x[0] not in excluded_keys, params.items())
+         )
+
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
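
A side note on the from_dict override above: here is a small self-contained sketch of the same signature-filtering pattern, using a hypothetical Demo dataclass (not part of this package) in place of ModelConfig.

import inspect
from dataclasses import dataclass

@dataclass
class Demo:  # hypothetical stand-in for ModelConfig
    model_type: str = "paddleocr_vl"
    image_token_id: int = 100295

params = {"model_type": "paddleocr_vl", "image_token_id": 100295, "extra_key": 1}
# Only keys that appear in the dataclass __init__ signature survive the filter,
# so unrecognized checkpoint fields like "extra_key" are dropped silently.
demo = Demo(**{k: v for k, v in params.items() if k in inspect.signature(Demo).parameters})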