fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258) hide show
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,430 @@
1
+ """Image processor for Jina VLM in MLX-VLM."""
2
+
3
+ import math
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+
6
+ import mlx.core as mx
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ # CLIP normalization constants
11
+ CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
12
+ CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
13
+
14
+ # Default special token IDs
15
+ DEFAULT_PATCH_TOKEN_ID = 151665 # <im_patch>
16
+ DEFAULT_START_TOKEN_ID = 151666 # <im_start>
17
+ DEFAULT_END_TOKEN_ID = 151667 # <im_end>
18
+ DEFAULT_COLUMN_TOKEN_ID = 151668 # <im_col>
19
+
20
+
21
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
) -> Tuple[int, int]:
    """Pick a target (height, width) that is a multiple of ``factor``.

    Both sides are first snapped to the nearest multiple of ``factor``. If the
    snapped area is larger than ``max_pixels`` the original size is scaled
    down (rounding each side down, but never below ``factor``); if it is
    smaller than ``min_pixels`` the original size is scaled up (rounding each
    side up). The scale is derived from the original, un-snapped area so the
    aspect ratio is approximately kept.

    Args:
        height: Original height in pixels.
        width: Original width in pixels.
        factor: Granularity both returned sides must be a multiple of.
        min_pixels: Lower bound on the snapped area before upscaling kicks in.
        max_pixels: Upper bound on the snapped area before downscaling kicks in.

    Returns:
        ``(new_height, new_width)``.
    """
    snapped_h = round(height / factor) * factor
    snapped_w = round(width / factor) * factor
    snapped_area = snapped_h * snapped_w

    if snapped_area > max_pixels:
        # Too large: shrink by the ratio of the original area to the budget.
        shrink = math.sqrt((height * width) / max_pixels)
        return (
            max(factor, math.floor(height / shrink / factor) * factor),
            max(factor, math.floor(width / shrink / factor) * factor),
        )

    if snapped_area < min_pixels:
        # Too small: grow until the original area reaches the minimum.
        grow = math.sqrt(min_pixels / (height * width))
        return (
            math.ceil(height * grow / factor) * factor,
            math.ceil(width * grow / factor) * factor,
        )

    return snapped_h, snapped_w
42
+
43
+
44
def patchify(array: np.ndarray, patch_size: int, batched: bool = False) -> np.ndarray:
    """Cut image(s) into non-overlapping square patches and flatten each one.

    Args:
        array: Either a single ``(h, w, c)`` image (3-D, ``batched=False``),
            a ``(bs, h, w, c)`` batch (4-D), or a ``(bs, h, w)`` batch of
            single-channel maps (3-D, ``batched=True``).
        patch_size: Side length of a patch in pixels.
        batched: Disambiguates 3-D inputs; ignored for 4-D inputs.

    Returns:
        ``(n_patches, patch_size * patch_size * c)`` for a single image, or
        ``(bs, n_patches, patch_size * patch_size * c)`` for a batch. When the
        batch has exactly one channel the last axis is collapsed (squeezed if
        it has length 1, otherwise averaged), giving ``(bs, n_patches)``.

    Raises:
        ValueError: If ``array`` is neither 3-D nor 4-D.
    """
    rank = len(array.shape)
    if rank not in (3, 4):
        raise ValueError(f"Unsupported array shape: {array.shape}")

    if rank == 3 and not batched:
        rows, cols, channels = array.shape
        grid_h = rows // patch_size
        grid_w = cols // patch_size
        tiles = array.reshape(grid_h, patch_size, grid_w, patch_size, channels)
        # Bring the two grid axes together so each patch is contiguous.
        tiles = tiles.transpose(0, 2, 1, 3, 4)
        return tiles.reshape(grid_h * grid_w, patch_size * patch_size * channels)

    if rank == 3:
        # Batch of single-channel maps: add an explicit channel axis.
        array = array[..., None]
    batch, rows, cols, channels = array.shape

    grid_h = rows // patch_size
    grid_w = cols // patch_size
    tiles = array.reshape(batch, grid_h, patch_size, grid_w, patch_size, channels)
    tiles = tiles.transpose(0, 1, 3, 2, 4, 5)
    flat = tiles.reshape(batch, grid_h * grid_w, patch_size * patch_size * channels)

    if channels != 1:
        return flat
    if flat.shape[-1] == 1:
        return flat[..., 0]
    return flat.mean(axis=-1)
70
+
71
+
72
class ImageProcessor:
    """Image processor for Jina VLM (standalone, not a BaseImageProcessor).

    An input image is resized, split into a grid of overlapping square crops
    plus one global thumbnail, and every crop is cut into flattened patches.
    Alongside the pixel data the processor emits the image token sequence
    (start / patch / column / end token ids) and an index that maps pooled
    patch features onto the positions of the patch tokens in that sequence.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
        base_input_size: Tuple[int, int] = (378, 378),
        patch_size: int = 14,
        max_crops: int = 12,
        min_pixels: int = 3136,
        max_pixels: int = 1003520,
        overlap_margins: Tuple[int, int] = (4, 4),
        pooling_h: int = 2,
        pooling_w: int = 2,
        use_column_tokens: bool = True,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        patch_token_id: int = DEFAULT_PATCH_TOKEN_ID,
        start_token_id: int = DEFAULT_START_TOKEN_ID,
        end_token_id: int = DEFAULT_END_TOKEN_ID,
        column_token_id: int = DEFAULT_COLUMN_TOKEN_ID,
    ):
        """Store the preprocessing settings and derive per-crop token counts.

        Args:
            config: Accepted for interface compatibility; not read here.
            base_input_size: (height, width) of a single crop in pixels. Only
                the first entry is used to size crops, i.e. crops are square.
            patch_size: Side length of one patch in pixels.
            max_crops: Maximum number of crops in the tiling grid.
            min_pixels: Minimum area passed to ``smart_resize``.
            max_pixels: Maximum area passed to ``smart_resize``.
            overlap_margins: (left, right) overlap between neighbouring
                crops, measured in patches.
            pooling_h: Vertical pooling factor applied to patches.
            pooling_w: Horizontal pooling factor applied to patches.
            use_column_tokens: Whether each token row ends with a column token.
            image_mean: Stored on the instance (defaults to ``CLIP_MEAN``).
                Not used by ``normalize``.
            image_std: Stored on the instance (defaults to ``CLIP_STD``).
                Not used by ``normalize``.
            patch_token_id: Token id standing in for one pooled patch.
            start_token_id: Token id opening an image token block.
            end_token_id: Token id closing an image token block.
            column_token_id: Token id appended after each row of patch tokens.
        """
        self.base_input_size = base_input_size
        self.patch_size = patch_size
        self.max_crops = max_crops
        self.min_pixels = min_pixels
        self.max_pixels = max_pixels
        self.overlap_margins = overlap_margins
        self.pooling_h = pooling_h
        self.pooling_w = pooling_w
        self.use_column_tokens = use_column_tokens
        self.image_mean = image_mean or CLIP_MEAN
        self.image_std = image_std or CLIP_STD

        self.patch_token_id = patch_token_id
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id
        self.column_token_id = column_token_id

        # Patches along one side of a crop, and the pooled token grid derived
        # from it (ceil division by the pooling factors).
        self.crop_patches = base_input_size[0] // patch_size
        self.token_length_h = (self.crop_patches + pooling_h - 1) // pooling_h
        self.token_length_w = (self.crop_patches + pooling_w - 1) // pooling_w
        self.tokens_per_image = self.token_length_h * self.token_length_w

    def normalize(self, x: np.ndarray) -> np.ndarray:
        """Map values from [0, 1] to [-1, 1]."""
        return (x - 0.5) * 2.0

    def resize_image(
        self,
        image: np.ndarray,
        size: Tuple[int, int],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Bicubic-resize a [0, 1] float image to ``size`` = (height, width).

        Returns:
            The resized image as float32 in [0, 1] and an all-True boolean
            mask of shape ``(height, width)``.
        """
        pil_image = Image.fromarray((image * 255).astype(np.uint8))
        # PIL takes (width, height), hence the swapped order.
        pil_image = pil_image.resize((size[1], size[0]), Image.BICUBIC)
        resized = np.array(pil_image, dtype=np.float32) / 255.0
        mask = np.ones((size[0], size[1]), dtype=np.bool_)
        return resized, mask

    def select_tiling(
        self,
        h: int,
        w: int,
        patch_size: int,
        max_crops: int,
    ) -> Tuple[int, int]:
        """Choose the (rows, cols) crop grid that best fits an ``h`` x ``w`` image.

        Every grid with at most ``max_crops`` cells is a candidate; a grid
        covers ``rows * patch_size`` by ``cols * patch_size`` pixels. If some
        candidate needs no downscaling, the one needing the least upscaling is
        chosen; otherwise the one needing the least downscaling is chosen.

        Returns:
            ``(rows, cols)`` as plain Python ints.
        """
        tilings = [
            (i, j)
            for i in range(1, max_crops + 1)
            for j in range(1, max_crops + 1)
            if i * j <= max_crops
        ]
        # Order by cell count, then by row count, so ties resolve to the
        # smallest grid first.
        tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
        candidate_tilings = np.array(tilings, dtype=np.int32)
        candidate_resolutions = candidate_tilings * patch_size

        original_size = np.array([h, w], dtype=np.float32)
        with np.errstate(divide="ignore"):
            required_scale = candidate_resolutions.astype(np.float32) / original_size
        # The limiting side determines the scale for each candidate.
        required_scale = np.min(required_scale, axis=-1, keepdims=True)

        if np.all(required_scale < 1):
            # Every candidate downscales: take the one that downscales least.
            ix = np.argmax(required_scale)
        else:
            # Ignore downscaling candidates, take the smallest upscale.
            required_scale = np.where(required_scale < 1.0, 1e9, required_scale)
            ix = np.argmin(required_scale)

        rows, cols = candidate_tilings[ix]
        # Convert NumPy scalars so the result matches the annotation and
        # downstream size arithmetic stays in plain Python ints.
        return int(rows), int(cols)

    def _get_patches_from_tiling(
        self,
        num_tiles: int,
        pooling_size: int,
        crop_patches: int,
        crop_window_patches: int,
        left_margin: int,
        right_margin: int,
    ) -> int:
        """Return the patch count along one axis, padded for pooling.

        Each tile's contribution is rounded up to a multiple of
        ``pooling_size``. With several tiles the first keeps its left margin,
        the last keeps its right margin, and the ones in between contribute
        only their window; a single tile contributes the full crop.
        """
        if num_tiles > 1:
            left_crop = (
                (crop_window_patches + left_margin + pooling_size - 1)
                // pooling_size
                * pooling_size
            )
            middle_crop = (
                (crop_window_patches + pooling_size - 1) // pooling_size * pooling_size
            )
            right_crop = (
                (crop_window_patches + right_margin + pooling_size - 1)
                // pooling_size
                * pooling_size
            )
            return left_crop + (num_tiles - 2) * middle_crop + right_crop
        else:
            return (crop_patches + pooling_size - 1) // pooling_size * pooling_size

    def build_image_input_idx(
        self,
        image_tokens: np.ndarray,
        patch_order: Optional[np.ndarray],
    ) -> np.ndarray:
        """Map each pooled patch slot to the position of its patch token.

        Args:
            image_tokens: 1-D token id sequence for the image.
            patch_order: Per-slot rank of the patch in token order, ``-1`` for
                slots that have no token; ``None`` keeps token order as is.

        Returns:
            An array reshaped to ``(-1, tokens_per_image)`` holding token
            positions, with ``-10000`` in slots that have no token.

        Raises:
            ValueError: If the number of patch tokens differs from the number
                of non-negative entries in ``patch_order``.
        """
        image_input_idx = image_tokens == self.patch_token_id
        image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)

        if patch_order is not None:
            patch_order = np.reshape(patch_order, [-1])
            valid = patch_order >= 0
            n_valid_patches = valid.sum()

            if len(image_input_idx) != n_valid_patches:
                raise ValueError(
                    f"Mismatch: {len(image_input_idx)} patch tokens but {n_valid_patches} valid patches"
                )

            # Scatter 0..n-1 to the positions named by patch_order, then
            # spread that over the full slot grid, leaving -1 in invalid slots.
            sorted_patch_ixs = np.zeros([image_input_idx.shape[0]], np.int32)
            sorted_patch_ixs[patch_order[valid]] = np.arange(
                n_valid_patches, dtype=np.int32
            )
            sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
            sorted_patch_ixs_ex[valid] = sorted_patch_ixs

            # Index with 0 in place of -1 so the gather is safe, then overwrite
            # those slots with the sentinel.
            valid_mask = (sorted_patch_ixs_ex >= 0).astype(np.int32)
            image_input_idx = image_input_idx[sorted_patch_ixs_ex * valid_mask]
            image_input_idx = image_input_idx * valid_mask - 10000 * (1 - valid_mask)

        return np.reshape(image_input_idx, [-1, self.tokens_per_image])

    def crop_image(
        self, image: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Split an image into overlapping crops plus a global thumbnail.

        Args:
            image: Float image in [0, 1]; the crop padding buffer assumes
                three channels.

        Returns:
            ``(patches, image_tokens, patch_ordering, img_masks)``: patchified
            crops with the thumbnail first, the image token id sequence
            (thumbnail block followed by the high-resolution block), the
            per-slot patch order (``-1`` for unused slots), and the per-patch
            mask with a leading all-ones row prepended.
        """
        left_margin, right_margin = self.overlap_margins
        total_margin_pixels = self.patch_size * (right_margin + left_margin)
        crop_patches = self.crop_patches
        crop_window_patches = crop_patches - (right_margin + left_margin)
        crop_window_size = crop_window_patches * self.patch_size

        original_h, original_w = image.shape[:2]

        tiling = self.select_tiling(
            original_h - total_margin_pixels,
            original_w - total_margin_pixels,
            crop_window_size,
            self.max_crops,
        )

        # Resize so the grid of crop windows plus the outer margins fits exactly.
        target_h = tiling[0] * crop_window_size + total_margin_pixels
        target_w = tiling[1] * crop_window_size + total_margin_pixels
        src, img_mask = self.resize_image(image, (target_h, target_w))
        src = self.normalize(src)

        patches_arr = []
        mask_arr = []
        patch_ordering_arr = []

        crop_size = self.base_input_size[0]
        on = 0  # running count of pooled patches assigned so far

        for i in range(tiling[0]):
            y0 = i * crop_window_size
            crop_y0 = 0 if i == 0 else left_margin // self.pooling_h
            # Border crops keep their outward-facing margin; inner ones do not.
            crop_h = crop_patches - (right_margin + left_margin)
            if i == 0:
                crop_h += left_margin
            if i == (tiling[0] - 1):
                crop_h += right_margin

            for j in range(tiling[1]):
                x0 = j * crop_window_size
                crop_x0 = 0 if j == 0 else left_margin // self.pooling_w
                crop_w = crop_patches - (right_margin + left_margin)
                if j == 0:
                    crop_w += left_margin
                if j == (tiling[1] - 1):
                    crop_w += right_margin

                pooled_w = (crop_w + self.pooling_w - 1) // self.pooling_w
                pooled_h = (crop_h + self.pooling_h - 1) // self.pooling_h
                after_padding_width = self.token_length_w - pooled_w - crop_x0
                after_padding_height = self.token_length_h - pooled_h - crop_y0

                # Number the pooled patches this crop contributes and pad the
                # rest of its token grid with -1.
                patch_ordering_arr.append(
                    np.pad(
                        np.reshape(
                            np.arange(on, on + pooled_h * pooled_w, dtype=np.int32),
                            (pooled_h, pooled_w),
                        ),
                        [
                            [crop_y0, after_padding_height],
                            [crop_x0, after_padding_width],
                        ],
                        constant_values=-1,
                        mode="constant",
                    )
                )

                crop = src[y0 : y0 + crop_size, x0 : x0 + crop_size]
                if crop.shape[0] < crop_size or crop.shape[1] < crop_size:
                    padded = np.zeros((crop_size, crop_size, 3), dtype=np.float32)
                    padded[: crop.shape[0], : crop.shape[1]] = crop
                    crop = padded
                patches_arr.append(crop)

                crop_mask = img_mask[y0 : y0 + crop_size, x0 : x0 + crop_size]
                if crop_mask.shape[0] < crop_size or crop_mask.shape[1] < crop_size:
                    padded_mask = np.zeros((crop_size, crop_size), dtype=np.bool_)
                    padded_mask[: crop_mask.shape[0], : crop_mask.shape[1]] = crop_mask
                    crop_mask = padded_mask
                mask_arr.append(crop_mask)

                on += pooled_h * pooled_w

        patches = np.stack(patches_arr)
        patch_ordering = np.stack(patch_ordering_arr)
        img_masks = np.stack(mask_arr)

        patches = patchify(patches, self.patch_size, batched=True)
        img_masks = patchify(
            img_masks.astype(np.float32), self.patch_size, batched=True
        )
        if img_masks.ndim == 3:
            img_masks = img_masks.mean(axis=-1)

        # Re-number valid slots so they follow row-major order over the whole
        # image instead of crop-by-crop order.
        patch_ordering = np.reshape(patch_ordering, [-1])
        valid = patch_ordering >= 0

        patch_ordering_rh = np.reshape(
            patch_ordering,
            [tiling[0], tiling[1], self.token_length_h, self.token_length_w],
        )
        patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
        patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
        patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]

        h = self._get_patches_from_tiling(
            tiling[0],
            self.pooling_h,
            crop_patches,
            crop_window_patches,
            left_margin,
            right_margin,
        )
        w = self._get_patches_from_tiling(
            tiling[1],
            self.pooling_w,
            crop_patches,
            crop_window_patches,
            left_margin,
            right_margin,
        )

        # Token block for the high-resolution crops.
        per_row = np.full((w // self.pooling_w,), self.patch_token_id, dtype=np.int32)
        if self.use_column_tokens:
            per_row = np.concatenate([per_row, [self.column_token_id]], 0)
        joint = np.tile(per_row, [h // self.pooling_h])
        joint = [[self.start_token_id], joint, [self.end_token_id]]

        # Global thumbnail: goes first in both the patches and the ordering.
        thumb, _ = self.resize_image(image, self.base_input_size)
        thumb = self.normalize(thumb)
        thumb_patches = patchify(thumb, self.patch_size, batched=False)
        patches = np.concatenate([np.expand_dims(thumb_patches, 0), patches], 0)

        patch_ordering = np.where(
            patch_ordering >= 0, patch_ordering + self.tokens_per_image, -1
        )
        patch_ordering = np.concatenate(
            [np.arange(0, self.tokens_per_image), patch_ordering], 0
        )

        per_row = np.full((self.token_length_w,), self.patch_token_id, dtype=np.int32)
        if self.use_column_tokens:
            per_row = np.concatenate([per_row, [self.column_token_id]], 0)
        extra_tokens = np.tile(per_row, [self.token_length_h])
        joint = [[self.start_token_id], extra_tokens, [self.end_token_id]] + joint

        image_tokens = np.concatenate(joint, 0).astype(np.int32)

        # Mask row for the thumbnail, which is never padded.
        img_masks = np.pad(img_masks, [[1, 0], [0, 0]], constant_values=1.0)

        return patches, image_tokens, patch_ordering, img_masks

    def process_image(
        self,
        image: Union[Image.Image, np.ndarray],
    ) -> Dict[str, np.ndarray]:
        """Run the full pipeline on one image.

        PIL images are converted to RGB. NumPy inputs are scaled to [0, 1] if
        they are uint8, and brought to three channels the same way the PIL
        path does: a 2-D grayscale array is replicated across channels and a
        fourth (alpha) channel is dropped.

        Returns:
            A dict with keys ``pixel_values``, ``image_tokens``,
            ``image_input_idx`` and ``image_masks``.
        """
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
            image = np.array(image, dtype=np.float32) / 255.0
        else:
            if image.dtype == np.uint8:
                image = image.astype(np.float32) / 255.0
            # crop_image pads crops into a 3-channel buffer, so array inputs
            # must be RGB as well.
            if image.ndim == 2:
                image = np.repeat(image[..., None], 3, axis=-1)
            elif image.ndim == 3 and image.shape[-1] == 4:
                image = image[..., :3]

        h, w = image.shape[:2]
        new_h, new_w = smart_resize(
            h,
            w,
            factor=self.patch_size,
            min_pixels=self.min_pixels,
            max_pixels=self.max_pixels,
        )
        if (new_h, new_w) != (h, w):
            image, _ = self.resize_image(image, (new_h, new_w))

        patches, image_tokens, patch_ordering, masks = self.crop_image(image)

        image_input_idx = self.build_image_input_idx(image_tokens, patch_ordering)

        return {
            "pixel_values": patches,
            "image_tokens": image_tokens,
            "image_input_idx": image_input_idx,
            "image_masks": masks,
        }

    def preprocess(
        self,
        images: Union[Image.Image, List[Image.Image], str, List[str]],
        **kwargs,
    ) -> Dict[str, List[np.ndarray]]:
        """Process one image or a sequence of images.

        Args:
            images: A PIL image, a file path, or a list/tuple of either.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with the same keys as ``process_image``; each value is a
            list holding one NumPy array per input image, in input order.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]

        results = {
            "pixel_values": [],
            "image_tokens": [],
            "image_input_idx": [],
            "image_masks": [],
        }

        for image in images:
            if isinstance(image, str):
                # convert() returns an independent image, so the file handle
                # can be released as soon as the pixels are decoded.
                with Image.open(image) as opened:
                    image = opened.convert("RGB")
            processed = self.process_image(image)
            for key in results:
                results[key].append(processed[key])

        return results
@@ -0,0 +1,280 @@
1
+ """Main Jina VLM model for MLX."""
2
+
3
+ from typing import Optional
4
+
5
+ import mlx.core as mx
6
+ import mlx.nn as nn
7
+ from transformers import AutoProcessor
8
+
9
+ from ..base import InputEmbeddingsFeatures
10
+ from .config import ModelConfig, VisionConfig
11
+ from .language import LanguageModel
12
+ from .processing_jinavlm import JinaVLMProcessor
13
+ from .vision import VisionModel
14
+
15
# Register JinaVLMProcessor with transformers' AutoProcessor under the "jvlm"
# key. This runs as a side effect of importing this module.
+ AutoProcessor.register("jvlm", JinaVLMProcessor)
16
+
17
+
18
class CrossAttention(nn.Module):
    """Multi-head cross-attention used for pooling.

    Sub-module names match the checkpoint weight naming:
    pooling.q, pooling.kv, pooling.out.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        heads = config.num_attention_heads
        per_head = config.head_dim
        # Features from every selected ViT layer are concatenated on input.
        in_features = config.hidden_size * len(config.vit_layers)
        inner_dim = heads * per_head

        self.num_heads = heads
        self.head_dim = per_head
        self.scale = per_head**-0.5

        # Attribute names must stay q / kv / out to line up with the weights.
        self.q = nn.Linear(in_features, inner_dim, bias=True)
        self.kv = nn.Linear(in_features, 2 * inner_dim, bias=True)
        self.out = nn.Linear(inner_dim, config.hidden_size, bias=True)

    def __call__(self, query: mx.array, key_value: mx.array) -> mx.array:
        """Attend from ``query`` over ``key_value``.

        Both inputs are unpacked as rank-3 arrays (batch, length, features);
        the result has the query's batch and length dimensions.
        """
        batch, q_len, _ = query.shape
        _, kv_len, _ = key_value.shape
        heads, per_head = self.num_heads, self.head_dim

        # (batch, heads, q_len, head_dim)
        queries = (
            self.q(query).reshape(batch, q_len, heads, per_head).transpose(0, 2, 1, 3)
        )

        # The fused projection packs keys and values along a size-2 axis.
        packed = self.kv(key_value).reshape(batch, kv_len, 2, heads, per_head)
        keys = packed[:, :, 0].transpose(0, 2, 1, 3)  # (batch, heads, kv_len, head_dim)
        values = packed[:, :, 1].transpose(0, 2, 1, 3)

        scores = (queries @ mx.swapaxes(keys, -1, -2)) * self.scale
        weights = mx.softmax(scores, axis=-1)
        context = weights @ values

        # Merge the heads back into a single feature axis.
        merged = context.transpose(0, 2, 1, 3).reshape(batch, q_len, -1)
        return self.out(merged)
57
+
58
+
59
class ConnectorMLP(nn.Module):
    """SwiGLU MLP projector.

    Sub-module names match the checkpoint weight naming:
    projector.gate_up, projector.down.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        in_features = config.hidden_size
        inner_features = config.connector_hidden_size
        out_features = config.output_size

        # One fused projection produces both the value and the gate halves.
        self.gate_up = nn.Linear(in_features, 2 * inner_features, bias=False)
        self.down = nn.Linear(inner_features, out_features, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Project ``x`` through the gated MLP."""
        fused = self.gate_up(x)
        # Jina VLM convention: the first half is the value, the second half
        # is the gate that gets the SiLU activation.
        value, gate = mx.split(fused, 2, axis=-1)
        gated = nn.silu(gate) * value
        return self.down(gated)
77
+
78
+
79
class VisionLanguageConnector(nn.Module):
    """Pools vision features and projects them for the language model.

    Attribute names match the checkpoint weight naming:
    vl_connector.pad_embed, vl_connector.pooling, vl_connector.projector.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config

        self.pooling_h = config.pooling_h
        self.pooling_w = config.pooling_w

        # Patches per crop side, and the pooled token grid (ceil division).
        self.crop_patches = config.image_size // config.patch_size
        self.token_length_h = (
            self.crop_patches + config.pooling_h - 1
        ) // config.pooling_h
        self.token_length_w = (
            self.crop_patches + config.pooling_w - 1
        ) // config.pooling_w
        self.tokens_per_image = self.token_length_h * self.token_length_w

        # Row 0 is added to fully padded patches, row 1 to partially padded
        # ones (see __call__).
        feature_dim = config.hidden_size * len(config.vit_layers)
        self.pad_embed = mx.zeros((2, feature_dim))

        self.pooling = CrossAttention(config)
        self.projector = ConnectorMLP(config)

    def __call__(
        self, image_features: mx.array, image_masks: Optional[mx.array] = None
    ) -> mx.array:
        """Pool each pooling_h x pooling_w patch window and project it.

        Args:
            image_features: Array whose first two axes are (batch, crops);
                the rest is reshaped to a crop_patches x crop_patches grid.
            image_masks: Optional per-patch mask. A value of 0 selects
                ``pad_embed[0]``; a value below 1 but not 0 selects
                ``pad_embed[1]``.

        Returns:
            The projector output, with one token per pooling window.
        """
        batch, n_crops = image_features.shape[:2]
        pool_h, pool_w = self.pooling_h, self.pooling_w
        grid_h = grid_w = self.crop_patches

        if image_masks is not None:
            fully_padded = (image_masks == 0).astype(mx.float32)
            partly_padded = mx.logical_and(
                image_masks < 1, image_masks != 0
            ).astype(mx.float32)

            # Add the matching pad embedding, weighted by its 0/1 indicator.
            for row, indicator in ((0, fully_padded), (1, partly_padded)):
                embed = self.pad_embed[row][None, None, None, :]
                image_features = image_features + embed * indicator[..., None]

        image_features = image_features.reshape(batch, n_crops, grid_h, grid_w, -1)

        # Pad the bottom/right so the grid divides evenly into windows.
        extra_h = (-grid_h) % pool_h
        extra_w = (-grid_w) % pool_w
        if extra_h or extra_w:
            image_features = mx.pad(
                image_features,
                [(0, 0), (0, 0), (0, extra_h), (0, extra_w), (0, 0)],
            )

        _, _, padded_h, padded_w, channels = image_features.shape
        out_h = padded_h // pool_h
        out_w = padded_w // pool_w

        # Gather every window's patches onto one axis:
        # (batch * crops * out_h * out_w, pool_h * pool_w, channels).
        windows = image_features.reshape(
            batch, n_crops, out_h, pool_h, out_w, pool_w, channels
        )
        windows = windows.transpose(0, 1, 2, 4, 3, 5, 6)
        windows = windows.reshape(
            batch * n_crops * out_h * out_w, pool_h * pool_w, channels
        )

        # The mean of each window is the query that attends over the window.
        window_mean = windows.mean(axis=1, keepdims=True)
        pooled = self.pooling(window_mean, windows)
        pooled = pooled.reshape(batch, n_crops, out_h * out_w, -1)

        return self.projector(pooled)
155
+
156
+
157
class Model(nn.Module):
    """Jina Vision-Language Model.

    Attribute names (``vision_model``, ``vl_connector``, ``language_model``)
    match the checkpoint weight naming structure.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Named to match weights: vision_model
        self.vision_model = VisionModel(config.vision_config)

        # Named to match weights: vl_connector
        self.vl_connector = VisionLanguageConnector(config.vision_config)

        # Named to match weights: language_model
        self.language_model = LanguageModel(config.text_config)

        # lm_head lives inside language_model; checkpoint keys are remapped
        # in sanitize().
        self.language_model.lm_head = nn.Linear(
            config.text_config.hidden_size, config.text_config.vocab_size, bias=False
        )

    @property
    def layers(self):
        """The language model's layers."""
        return self.language_model.layers

    def get_image_features(
        self,
        images: mx.array,
        image_masks: Optional[mx.array] = None,
    ) -> mx.array:
        """Encode image crops and run them through the VL connector.

        Args:
            images: Array of shape (batch, n_crops, n_patches, patch_dim).
            image_masks: Optional mask forwarded to the connector.

        Returns:
            The connector output for every crop.
        """
        B, n_crops, n_patches, patch_dim = images.shape
        dtype = self.vision_model.patch_embed.proj.weight.dtype

        images_flat = images.reshape(B * n_crops, n_patches, patch_dim).astype(dtype)

        # A crop whose every value is -1 is treated as padding; its features
        # are zeroed after encoding.
        valid_mask = ~mx.all(
            images_flat.reshape(B * n_crops, -1) == -1, axis=-1, keepdims=True
        )
        valid_mask = valid_mask[:, :, None]

        image_features = self.vision_model.get_features(images_flat)
        image_features = image_features * valid_mask

        n_output_patches = image_features.shape[1]
        image_features = image_features.reshape(B, n_crops, n_output_patches, -1)
        return self.vl_connector(image_features, image_masks)

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):
        """Embed ``input_ids`` and add image features at their token slots.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            pixel_values: Optional image patches; a rank-3 array is treated
                as a single unbatched example.
            **kwargs: ``image_masks`` and ``image_input_idx`` are read from
                here. ``image_input_idx`` gives, for each pooled image
                token, the sequence position its feature is added to;
                positions outside ``[0, seq_len)`` are ignored.

        Returns:
            ``InputEmbeddingsFeatures`` wrapping the combined embeddings.
        """
        batch_size, seq_len = input_ids.shape

        image_masks = kwargs.get("image_masks")
        image_input_idx = kwargs.get("image_input_idx")

        inputs_embeds = self.language_model.embedding(input_ids)

        if pixel_values is not None and image_input_idx is not None:
            if pixel_values.ndim == 3:
                # Unbatched input: add the leading batch axis everywhere.
                pixel_values = mx.expand_dims(pixel_values, 0)
                if image_masks is not None:
                    image_masks = mx.expand_dims(image_masks, 0)
                image_input_idx = mx.expand_dims(image_input_idx, 0)

            image_features = self.get_image_features(pixel_values, image_masks)

            num_image, num_patch = image_features.shape[1:3]
            image_features = image_features.reshape(
                batch_size, num_image * num_patch, -1
            )
            positions = image_input_idx.reshape(
                batch_size, num_image * num_patch
            ).astype(mx.int32)

            # Single vectorised scatter-add instead of a per-token Python
            # loop with a host sync (.item()) on every iteration.
            in_range = mx.logical_and(positions >= 0, positions < seq_len)
            # Out-of-range slots are redirected to position 0 and given an
            # all-zero update, so they leave the embeddings untouched.
            safe_positions = mx.where(in_range, positions, mx.zeros_like(positions))
            updates = mx.where(
                in_range[..., None], image_features, mx.zeros_like(image_features)
            )
            batch_index = mx.arange(batch_size)[:, None]
            # .at[...].add accumulates correctly when positions repeat.
            inputs_embeds = inputs_embeds.at[batch_index, safe_positions].add(updates)

        return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)

    def __call__(
        self,
        input_ids: mx.array,
        pixel_values: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
        **kwargs,
    ) -> mx.array:
        """Run the language model on the combined text/image embeddings."""
        input_embeddings_features = self.get_input_embeddings(
            input_ids, pixel_values, **kwargs
        )
        return self.language_model(
            input_ids,
            inputs_embeds=input_embeddings_features.inputs_embeds,
            mask=mask,
            cache=cache,
        )

    def sanitize(self, weights):
        """Sanitize weight names for loading.

        Keys starting with ``lm_head.`` are moved under ``language_model.``
        because the head is attached to the language model; every other key
        is kept unchanged.
        """
        return {
            ("language_model." + key if key.startswith("lm_head.") else key): value
            for key, value in weights.items()
        }