fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/internvl_chat/vision.py
@@ -0,0 +1,265 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import interpolate
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+
+         if (config.hidden_size % config.num_attention_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({config.hidden_size} % {config.num_attention_heads}) != 0"
+             )
+
+         self.dims = dims = config.hidden_size
+
+         self.num_heads = config.num_attention_heads
+         head_dim = config.hidden_size // config.num_attention_heads
+         self.scale = head_dim**-0.5
+         self.qkv_bias = config.qkv_bias
+
+         self.qkv = nn.Linear(dims, 3 * dims, bias=config.qkv_bias)
+         self.proj = nn.Linear(dims, dims)
+
+         self.qk_normalization = config.qk_normalization
+
+         if self.qk_normalization:
+             self.q_norm = nn.RMSNorm(dims, eps=config.layer_norm_eps)
+             self.k_norm = nn.RMSNorm(dims, eps=config.layer_norm_eps)
+
+     def __call__(self, x, mask=None):
+         B, L, C = x.shape
+         qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, C // self.num_heads)
+         qkv = qkv.transpose(2, 0, 3, 1, 4)
+         queries, keys, values = (
+             qkv[0],
+             qkv[1],
+             qkv[2],
+         )  # Each has shape (B, groups, N, C//groups)
+
+         if self.qk_normalization:
+             B_, H_, N_, D_ = queries.shape
+             queries = (
+                 self.q_norm(queries.transpose(0, 2, 1, 3).flatten(-2, -1))
+                 .reshape(B_, N_, H_, D_)
+                 .transpose(0, 2, 1, 3)
+             )
+             keys = (
+                 self.k_norm(keys.transpose(0, 2, 1, 3).flatten(-2, -1))
+                 .reshape(B_, N_, H_, D_)
+                 .transpose(0, 2, 1, 3)
+             )
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.activation_fn = nn.GELU(approx="precise")
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.fc1(x)
+         x = self.activation_fn(x)
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig, drop_path_rate: float = 0.0):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.norm_type = getattr(config, "norm_type", "layer_norm")
+
+         self.attn = Attention(config)
+         self.mlp = MLP(config)
+
+         if self.norm_type == "layer_norm":
+             self.norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+             self.norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         elif self.norm_type == "rms_norm":
+             self.norm1 = nn.RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+             self.norm2 = nn.RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+         else:
+             raise ValueError(f"Unsupported normalization type: {self.norm_type}")
+
+         self.ls1 = mx.ones((self.embed_dim,))
+         self.ls2 = mx.ones((self.embed_dim,))
+
+         self.drop_path1 = (
+             nn.Dropout(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+         )
+         self.drop_path2 = (
+             nn.Dropout(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+         )
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         dtype = x.dtype
+         x = x + self.drop_path1(self.attn(self.norm1(x).astype(dtype)) * self.ls1)
+
+         x = x + self.drop_path2(self.mlp(self.norm2(x).astype(dtype)) * self.ls2)
+
+         return x.astype(dtype)
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         dpr = [
+             mx.array(x)
+             for x in np.linspace(0, config.drop_path_rate, config.num_hidden_layers)
+         ]
+         self.layers = [
+             EncoderLayer(config, dpr[i]) for i in range(config.num_hidden_layers)
+         ]
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+         mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         encoder_states = (x,) if output_hidden_states else None
+         h = x
+         for l in self.layers:
+             x = l(x, mask=mask)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+             h = x
+
+         return (h, encoder_states)
+
+
+ class VisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.class_embedding = mx.random.normal((1, 1, self.embed_dim))
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=3,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches + 1
+
+         self.position_embedding = mx.random.normal(
+             (1, self.num_positions, self.embed_dim)
+         )
+
+     def _get_pos_embed(self, pos_embed, H, W):
+         target_dtype = pos_embed.dtype
+         pos_embed = pos_embed.reshape(
+             1,
+             self.image_size // self.patch_size,
+             self.image_size // self.patch_size,
+             -1,
+         ).transpose(0, 3, 1, 2)
+         pos_embed = interpolate(pos_embed, (H, W))
+         pos_embed = (
+             pos_embed.reshape(1, -1, H * W).transpose(0, 2, 1).astype(target_dtype)
+         )
+         return pos_embed
+
+     def __call__(self, x: mx.array) -> mx.array:
+         target_dtype = self.patch_embedding.weight.dtype
+         patch_embeds = self.patch_embedding(x).transpose(
+             0, 3, 1, 2
+         )  # shape = [*, channel, width, height]
+         batch_size, _, height, width = patch_embeds.shape
+         patch_embeds = mx.flatten(patch_embeds, start_axis=2).transpose(0, 2, 1)
+         class_embeds = mx.broadcast_to(
+             self.class_embedding, (batch_size, 1, self.embed_dim)
+         ).astype(target_dtype)
+         embeddings = mx.concatenate([class_embeds, patch_embeds], axis=1)
+         position_embedding = mx.concatenate(
+             [
+                 self.position_embedding[:, :1, :],
+                 self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
+             ],
+             axis=1,
+         )
+         embeddings = embeddings + position_embedding.astype(target_dtype)
+
+         return embeddings
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         if self.model_type not in ["siglip_vision_model", "intern_vit_6b"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.embeddings = VisionEmbeddings(config)
+         self.encoder = Encoder(config)
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+         x = self.embeddings(x)
+         last_hidden_state, encoder_outputs = self.encoder(
+             x=x, output_hidden_states=output_hidden_states, mask=None
+         )
+         pooler_output = last_hidden_state[:, 0, :]
+         return last_hidden_state, pooler_output, encoder_outputs[1:]
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             elif "patch_embedding.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 #   [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 #   [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
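Note on the weight sanitization above: a minimal sketch (not part of the package diff) of the layout conversion that sanitize() performs on the patch-embedding weight, assuming an MLX installation; the shapes below are purely illustrative.

import mlx.core as mx

# Hypothetical PyTorch-style conv weight: [out_channels, in_channels, kH, kW]
pt_weight = mx.zeros((1024, 3, 14, 14))

# MLX's nn.Conv2d stores weights as [out_channels, kH, kW, in_channels], so
# sanitize() applies transpose(0, 2, 3, 1) when the tensor is still in the
# PyTorch layout; check_array_shape() is the heuristic that detects weights
# that were already converted (kH == kW and out_channels dominates).
mlx_weight = pt_weight.transpose(0, 2, 3, 1)
print(mlx_weight.shape)  # (1024, 14, 14, 3)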
mlx_vlm/models/interpolate.py
@@ -0,0 +1,183 @@
+ import mlx.core as mx
+ import numpy as np
+
+
+ def gaussian_blur_axis(image, sigma, axis):
+     """
+     Applies a 1D Gaussian blur along the given axis.
+     This version works for arrays with any number of dimensions.
+     """
+     radius = int(3 * sigma)
+     if radius < 1:
+         return image
+     x = mx.arange(-radius, radius + 1)
+     kernel = mx.exp(-(x**2) / (2 * sigma**2))
+     kernel = kernel / mx.sum(kernel)
+
+     # MLX doesn't have a direct apply_along_axis equivalent,
+     # so we'll implement the convolution differently based on the axis
+
+     # Helper function to apply 1D convolution along specific axis
+     def conv_1d(array, kernel, axis):
+         # Reshape kernel to broadcast along the right dimensions
+         kernel_shape = [1] * image.ndim
+         kernel_shape[axis] = len(kernel)
+         kernel_reshaped = kernel.reshape(kernel_shape)
+
+         # Pad the array
+         pad_width = [(0, 0)] * image.ndim
+         pad_width[axis] = (radius, radius)
+         padded = mx.pad(array, pad_width, mode="edge")
+
+         # Perform convolution via sliding window sum
+         result = mx.zeros_like(array)
+         slices = [slice(None)] * padded.ndim
+
+         for i in range(2 * radius + 1):
+             slices[axis] = slice(i, i + array.shape[axis])
+             result = result + padded[tuple(slices)] * kernel_reshaped
+
+         return result
+
+     return conv_1d(image, kernel, axis)
+
+
+ def bilinear_interpolate(image, new_height, new_width, align_corners=False):
+     """
+     Performs bilinear interpolation on an array whose spatial dimensions are the first two.
+     It supports extra dimensions (e.g. channels or batch dimensions that have been moved to the trailing axes).
+     """
+     # image is assumed to have shape (H, W, ...) where H and W are spatial dimensions.
+     H_in, W_in = image.shape[0], image.shape[1]
+
+     # Compute sampling positions in the input image.
+     if new_height == 1:
+         row_positions = mx.array([0.0])
+     else:
+         if align_corners:
+             row_positions = mx.linspace(0, H_in - 1, new_height)
+         else:
+             row_positions = (mx.arange(new_height) + 0.5) * H_in / new_height - 0.5
+
+     if new_width == 1:
+         col_positions = mx.array([0.0])
+     else:
+         if align_corners:
+             col_positions = mx.linspace(0, W_in - 1, new_width)
+         else:
+             col_positions = (mx.arange(new_width) + 0.5) * W_in / new_width - 0.5
+
+     # Compute floor and ceil indices.
+     row_floor = mx.floor(row_positions).astype(mx.int32)
+     col_floor = mx.floor(col_positions).astype(mx.int32)
+     row_ceil = row_floor + 1
+     col_ceil = col_floor + 1
+
+     row_floor = mx.clip(row_floor, 0, H_in - 1)
+     row_ceil = mx.clip(row_ceil, 0, H_in - 1)
+     col_floor = mx.clip(col_floor, 0, W_in - 1)
+     col_ceil = mx.clip(col_ceil, 0, W_in - 1)
+
+     row_weight = row_positions - row_floor  # shape (new_height,)
+     col_weight = col_positions - col_floor  # shape (new_width,)
+
+     # Use advanced indexing for gather operations
+     # Create meshgrid for coordinates
+     row_floor_grid, col_floor_grid = mx.meshgrid(row_floor, col_floor, indexing="ij")
+     row_ceil_grid, col_floor_grid = mx.meshgrid(row_ceil, col_floor, indexing="ij")
+     row_floor_grid, col_ceil_grid = mx.meshgrid(row_floor, col_ceil, indexing="ij")
+     row_ceil_grid, col_ceil_grid = mx.meshgrid(row_ceil, col_ceil, indexing="ij")
+
+     # Gather the four surrounding pixels using take_along_axis
+     # For higher dimensional arrays, we'll need to reshape and broadcast
+     extra_dims = image.ndim - 2
+
+     def gather_pixels(row_indices, col_indices):
+         # Flatten the spatial dimensions for gathering
+         flat_indices = row_indices * W_in + col_indices
+         flat_image = mx.reshape(image, (-1,) + image.shape[2:])
+         # Gather and reshape back
+         gathered = mx.take(flat_image, flat_indices.reshape(-1), axis=0)
+         return mx.reshape(gathered, (new_height, new_width) + image.shape[2:])
+
+     top_left = gather_pixels(row_floor_grid, col_floor_grid)
+     top_right = gather_pixels(row_floor_grid, col_ceil_grid)
+     bottom_left = gather_pixels(row_ceil_grid, col_floor_grid)
+     bottom_right = gather_pixels(row_ceil_grid, col_ceil_grid)
+
+     # Expand the weights to have shape (new_height, new_width, *[1]*extra_dims)
+     r_weight = row_weight.reshape(new_height, 1, *([1] * extra_dims))
+     c_weight = col_weight.reshape(1, new_width, *([1] * extra_dims))
+
+     # Perform bilinear interpolation.
+     result = (
+         (1 - r_weight) * (1 - c_weight) * top_left
+         + (1 - r_weight) * c_weight * top_right
+         + r_weight * (1 - c_weight) * bottom_left
+         + r_weight * c_weight * bottom_right
+     )
+     return result
+
+
+ def resize_bilinear(image, new_size, align_corners=False, antialias=True):
+     """
+     Resizes an image (or embedding tensor) to new_size=(new_height, new_width)
+     using bilinear interpolation with MLX.
+
+     Supports:
+       - 2D: (H, W)
+       - 3D: (H, W, C)
+       - 4D: (B, C, H, W) (assumed for typical image batches)
+     """
+     new_height, new_width = new_size
+
+     # Convert numpy arrays to MLX arrays if needed
+     if isinstance(image, np.ndarray):
+         image = mx.array(image)
+
+     if image.ndim == 2 or image.ndim == 3:
+         # Assume spatial dims are the first two.
+         resized = image
+         H_in, W_in = image.shape[:2]
+         if antialias:
+             if new_height < H_in:
+                 scale_y = new_height / H_in
+                 sigma_y = (1 / scale_y - 1) / 2.0  # heuristic
+                 if sigma_y > 0:
+                     resized = gaussian_blur_axis(resized, sigma_y, axis=0)
+             if new_width < W_in:
+                 scale_x = new_width / W_in
+                 sigma_x = (1 / scale_x - 1) / 2.0
+                 if sigma_x > 0:
+                     resized = gaussian_blur_axis(resized, sigma_x, axis=1)
+         resized = bilinear_interpolate(
+             resized, new_height, new_width, align_corners=align_corners
+         )
+         return resized
+
+     elif image.ndim == 4:
+         # Assume shape is (B, C, H, W) (typical PyTorch/MLX format).
+         B, C, H_in, W_in = image.shape
+         # Permute to bring spatial dims to the front: (H, W, B, C)
+         image_perm = mx.transpose(image, (2, 3, 0, 1))
+         resized = image_perm
+         if antialias:
+             if new_height < H_in:
+                 scale_y = new_height / H_in
+                 sigma_y = (1 / scale_y - 1) / 2.0
+                 if sigma_y > 0:
+                     resized = gaussian_blur_axis(resized, sigma_y, axis=0)
+             if new_width < W_in:
+                 scale_x = new_width / W_in
+                 sigma_x = (1 / scale_x - 1) / 2.0
+                 if sigma_x > 0:
+                     resized = gaussian_blur_axis(resized, sigma_x, axis=1)
+         resized = bilinear_interpolate(
+             resized, new_height, new_width, align_corners=align_corners
+         )
+         # Permute back to (B, C, new_height, new_width)
+         resized = mx.transpose(resized, (2, 3, 0, 1))
+         return resized
+
+     else:
+         raise ValueError("Unsupported image dimensions.")
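For context, a minimal usage sketch (not part of the package diff) of resize_bilinear on a 4D batch; it assumes the module is importable under the path shown in the file list, and the sizes are arbitrary.

import mlx.core as mx
from mlx_vlm.models.interpolate import resize_bilinear

# Arbitrary batch in the (B, C, H, W) layout the 4D branch expects.
x = mx.random.uniform(shape=(1, 3, 336, 336))

# Downscaling triggers the Gaussian pre-blur (antialias=True by default),
# then bilinear sampling; the output keeps the (B, C, H, W) layout.
y = resize_bilinear(x, (224, 224), align_corners=False)
print(y.shape)  # (1, 3, 224, 224)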
mlx_vlm/models/jina_vlm/__init__.py
@@ -0,0 +1,3 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .image_processor import ImageProcessor
+ from .jina_vlm import JinaVLMProcessor, LanguageModel, Model, VisionModel
mlx_vlm/models/jina_vlm/config.py
@@ -0,0 +1,142 @@
+ from dataclasses import dataclass, field
+ from typing import Tuple
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     """Vision encoder configuration for Jina VLM."""
+
+     model_type: str = "jina_vlm"
+     hidden_size: int = 1152
+     num_hidden_layers: int = 27
+     num_attention_heads: int = 16
+     head_dim: int = 72
+     patch_size: int = 14
+     image_size: int = 378
+     num_channels: int = 3
+     intermediate_size: int = 4304
+     layer_norm_eps: float = 1e-6
+     use_bias: bool = True
+     use_cls_token: bool = False
+     post_layer_norm: bool = True
+     activation: str = "gelu_pytorch_tanh"
+     vit_layers: Tuple[int, ...] = (-4, -10)
+     output_size: int = 2048
+     # Connector config
+     pooling_h: int = 2
+     pooling_w: int = 2
+     connector_hidden_size: int = 6144
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     """Text decoder configuration for Jina VLM."""
+
+     model_type: str = "jina_vlm"
+     hidden_size: int = 2048
+     num_hidden_layers: int = 28
+     num_attention_heads: int = 16
+     num_key_value_heads: int = 8
+     head_dim: int = 128
+     vocab_size: int = 151936
+     additional_vocab_size: int = 128
+     intermediate_size: int = 6144
+     rms_norm_eps: float = 1e-6
+     rope_theta: float = 1000000.0
+     max_position_embeddings: int = 40960
+     use_qk_norm: bool = True
+     tie_word_embeddings: bool = False
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     """Full Jina VLM configuration."""
+
+     text_config: TextConfig = field(default_factory=TextConfig)
+     vision_config: VisionConfig = field(default_factory=VisionConfig)
+     model_type: str = "jina_vlm"
+     vocab_size: int = 151936
+     bos_token_id: int = 151643
+     eos_token_id: int = 151643
+     pad_token_id: int = 151643
+     image_token_index: int = 151940  # <|image|>
+     image_token_id: int = 151940  # <|image|>
+     image_start_token_id: int = 151666  # <im_start>
+     image_end_token_id: int = 151667  # <im_end>
+     image_patch_token_id: int = 151665  # <im_patch>
+     image_column_token_id: int = 151668  # <im_col>
+     ignore_index: int = -100
+     tie_word_embeddings: bool = False
+
+     @classmethod
+     def from_dict(cls, params):
+         # Parse vision config
+         vision_cfg = params.get("vision_config", {})
+         vision_block = vision_cfg.get("block_config", {})
+         vision_attn = vision_block.get("attn_config", {})
+         vision_ffn = vision_block.get("ffn_config", {})
+         vl_connector = vision_cfg.get("vl_connector_config", {})
+         connector_mlp = vl_connector.get("mlp_projector_config", {})
+
+         vision_config = VisionConfig(
+             hidden_size=vision_cfg.get("hidden_size", 1152),
+             num_hidden_layers=vision_cfg.get("n_layers", 27),
+             num_attention_heads=vision_attn.get("n_heads", 16),
+             head_dim=vision_attn.get("head_dim", 72),
+             patch_size=vision_cfg.get("patch_size", 14),
+             image_size=(
+                 vision_cfg.get("input_size", [378, 378])[0]
+                 if isinstance(vision_cfg.get("input_size"), list)
+                 else 378
+             ),
+             num_channels=vision_cfg.get("n_channels", 3),
+             intermediate_size=vision_ffn.get("size", 4304),
+             use_bias=vision_attn.get("q_bias", True),
+             use_cls_token=vision_cfg.get("use_cls_token", False),
+             post_layer_norm=vision_cfg.get("post_lnorm", True),
+             activation=vision_ffn.get("activation_type", "gelu_pytorch_tanh"),
+             vit_layers=tuple(vision_cfg.get("vit_layers", [-4, -10])),
+             output_size=vision_cfg.get("output_size", 2048),
+             pooling_h=vl_connector.get("pooling_h", 2),
+             pooling_w=vl_connector.get("pooling_w", 2),
+             connector_hidden_size=connector_mlp.get("size", 6144),
+         )
+
+         # Parse text config
+         text_cfg = params.get("text_config", {})
+         text_block = text_cfg.get("block_config", {})
+         text_attn = text_block.get("attn_config", {})
+         text_ffn = text_block.get("ffn_config", {})
+         text_lnorm = text_block.get("lnorm_config", {})
+
+         text_config = TextConfig(
+             hidden_size=text_cfg.get("hidden_size", 2048),
+             num_hidden_layers=text_cfg.get(
+                 "n_layers", text_cfg.get("num_hidden_layers", 28)
+             ),
+             num_attention_heads=text_attn.get("n_heads", 16),
+             num_key_value_heads=text_attn.get("n_kv_heads", 8),
+             head_dim=text_attn.get("head_dim", 128),
+             vocab_size=text_cfg.get("vocab_size", 151936),
+             additional_vocab_size=text_cfg.get("additional_vocab_size", 128),
+             intermediate_size=text_ffn.get("size", 6144),
+             rms_norm_eps=text_lnorm.get("eps", 1e-6),
+             rope_theta=text_cfg.get("rope_theta", 1000000.0),
+             max_position_embeddings=text_cfg.get("max_sequence_length", 40960),
+             use_qk_norm=text_attn.get("q_lnorm", True),
+             tie_word_embeddings=text_cfg.get("tie_word_embeddings", False),
+         )
+
+         return cls(
+             text_config=text_config,
+             vision_config=vision_config,
+             model_type=params.get("model_type", "jina_vlm"),
+             vocab_size=params.get("vocab_size", text_config.vocab_size),
+             bos_token_id=params.get("bos_token_id", 151643),
+             eos_token_id=params.get("eos_token_id", 151643),
+             pad_token_id=params.get("pad_token_id", 151643),
+             image_token_index=params.get("image_token_index", 151940),
+             tie_word_embeddings=params.get("tie_word_embeddings", False),
+         )
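As a rough illustration of the nested parsing in from_dict above, a sketch (not part of the package diff) of the kind of dictionary it expects; the keys mirror those read in the method and the values shown are just the defaults it would fall back to anyway.

from mlx_vlm.models.jina_vlm.config import ModelConfig

# Illustrative nested config in the shape from_dict reads; omitted sub-dicts
# (ffn_config, lnorm_config, mlp_projector_config) simply take their defaults.
params = {
    "model_type": "jina_vlm",
    "vision_config": {
        "hidden_size": 1152,
        "n_layers": 27,
        "patch_size": 14,
        "input_size": [378, 378],
        "block_config": {"attn_config": {"n_heads": 16, "head_dim": 72}},
        "vl_connector_config": {"pooling_h": 2, "pooling_w": 2},
    },
    "text_config": {
        "hidden_size": 2048,
        "n_layers": 28,
        "vocab_size": 151936,
        "block_config": {"attn_config": {"n_heads": 16, "n_kv_heads": 8}},
    },
}

config = ModelConfig.from_dict(params)
print(config.vision_config.image_size, config.text_config.num_key_value_heads)  # 378 8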