fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
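
An inventory like the one above can be reproduced locally by inspecting the wheel directly; a small sketch, assuming the wheel has been downloaded to the working directory under its canonical filename:

import zipfile

# Canonical wheel filename for this release (name and version taken from the
# dist-info entries listed above)
wheel_path = "fount_vlm_nell_02-0.3.11-py3-none-any.whl"

with zipfile.ZipFile(wheel_path) as wheel:
    for info in wheel.infolist():
        print(f"{info.filename}  ({info.file_size} bytes)")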
@@ -0,0 +1,340 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..interpolate import resize_bilinear
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dims: int,
+         num_heads: int,
+         query_input_dims: Optional[int] = None,
+         key_input_dims: Optional[int] = None,
+         value_input_dims: Optional[int] = None,
+         value_dims: Optional[int] = None,
+         value_output_dims: Optional[int] = None,
+         bias: bool = True,
+     ):
+         super().__init__()
+
+         if (dims % num_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({dims} % {num_heads}) != 0"
+             )
+
+         query_input_dims = query_input_dims or dims
+         key_input_dims = key_input_dims or dims
+         value_input_dims = value_input_dims or key_input_dims
+         value_dims = value_dims or dims
+         value_output_dims = value_output_dims or dims
+
+         self.num_heads = num_heads
+         head_dim = dims // num_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+         self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+         self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+         self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+     def __call__(self, x, mask=None):
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         num_heads = self.num_heads
+         B, L, D = queries.shape
+         _, S, _ = keys.shape
+         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.out_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.activation_fn = nn.GELU(approx="precise")
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.fc1(x)
+         x = self.activation_fn(x)
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = Attention(
+             config.hidden_size, config.num_attention_heads, bias=True
+         )
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = MLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         r = self.self_attn(self.layer_norm1(x), mask)
+         h = x + r
+         r = self.mlp(self.layer_norm2(h))
+         return h + r
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+         mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         encoder_states = (x,) if output_hidden_states else None
+         h = x
+         for l in self.layers:
+             x = l(x, mask=mask)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+             h = x
+
+         return (h, encoder_states)
+
+
+ def gaussian_blur_axis(image, sigma, axis):
+     """
+     Applies a 1D Gaussian blur along the given axis.
+     This version works for arrays with any number of dimensions.
+     """
+     radius = int(3 * sigma)
+     if radius < 1:
+         return image
+     x = mx.arange(-radius, radius + 1)
+     kernel = mx.exp(-(x**2) / (2 * sigma**2))
+     kernel = kernel / mx.sum(kernel)
+
+     # MLX doesn't have a direct apply_along_axis equivalent,
+     # so we'll implement the convolution differently based on the axis
+
+     # Helper function to apply 1D convolution along specific axis
+     def conv_1d(array, kernel, axis):
+         # Reshape kernel to broadcast along the right dimensions
+         kernel_shape = [1] * image.ndim
+         kernel_shape[axis] = len(kernel)
+         kernel_reshaped = kernel.reshape(kernel_shape)
+
+         # Pad the array
+         pad_width = [(0, 0)] * image.ndim
+         pad_width[axis] = (radius, radius)
+         padded = mx.pad(array, pad_width, mode="edge")
+
+         # Perform convolution via sliding window sum
+         result = mx.zeros_like(array)
+         slices = [slice(None)] * padded.ndim
+
+         for i in range(2 * radius + 1):
+             slices[axis] = slice(i, i + array.shape[axis])
+             result = result + padded[tuple(slices)] * kernel_reshaped
+
+         return result
+
+     return conv_1d(image, kernel, axis)
+
+
+ class VisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             config.num_channels,
+             config.hidden_size,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     @staticmethod
+     def resize_positional_embeddings(
+         positional_embeddings: mx.array,
+         spatial_shapes: mx.array,
+         max_length: int,
+     ) -> mx.array:
+         """
+         Resize positional embeddings to image-specific size and pad to a fixed size.
+
+         Args:
+             positional_embeddings (`torch.Tensor`):
+                 Position embeddings of shape (height, width, embed_dim)
+             spatial_shapes (`torch.LongTensor`):
+                 Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
+             max_length (`int`):
+                 Maximum length of the positional embeddings to pad resized positional embeddings to
+
+         Returns:
+             `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
+         """
+         batch_size = spatial_shapes.shape[0]
+         embed_dim = positional_embeddings.shape[-1]
+         source_dtype = positional_embeddings.dtype
+
+         resulted_positional_embeddings = mx.zeros(
+             (batch_size, max_length, embed_dim)
+         ).astype(source_dtype)
+
+         # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
+         positional_embeddings = positional_embeddings.transpose(2, 0, 1).reshape(
+             1, embed_dim, -1
+         )
+
+         # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
+         if positional_embeddings.device.type == "cpu":
+             positional_embeddings = positional_embeddings.astype(mx.float32)
+
+         for i in range(batch_size):
+             # (1, dim, height, width) -> (1, dim, target_height, target_width)
+             height, width = spatial_shapes[i]
+             # Then upsample width dimension
+             resized_embeddings = resize_bilinear(
+                 positional_embeddings,
+                 (height, width),
+                 align_corners=False,
+                 antialias=True,
+             )
+
+             # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
+             resized_embeddings = resized_embeddings.reshape(
+                 embed_dim, height * width
+             ).transpose(0, 1)
+
+             # Cast to original dtype
+             resized_embeddings = resized_embeddings.astype(source_dtype)
+
+             resulted_positional_embeddings[i, : height * width] = resized_embeddings
+             resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
+
+         return resulted_positional_embeddings
+
+     def __call__(
+         self, x: mx.array, spatial_shapes: Optional[mx.array] = None
+     ) -> mx.array:
+         batch_size = x.shape[0]
+         patch_embeddings = self.patch_embedding(x)
+         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+         if spatial_shapes is None:
+             position_ids = mx.array(np.arange(self.num_positions)[None, :])
+             embeddings = patch_embeddings
+             embeddings += self.position_embedding(position_ids)
+
+         else:
+             # Get positional resized and padded positional embeddings
+             positional_embeddings = self.position_embedding.weight.reshape(
+                 self.position_embedding_size, self.position_embedding_size, -1
+             )
+
+             resized_positional_embeddings = self.resize_positional_embeddings(
+                 positional_embeddings, spatial_shapes, max_length=x.shape[1]
+             )
+
+             # Add positional embeddings to patch embeddings
+             embeddings = patch_embeds + resized_positional_embeddings
+         return embeddings
+
+
+ class SigLipVisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+
+         self.embeddings = VisionEmbeddings(config)
+         self.encoder = Encoder(config)
+         self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+     def __call__(
+         self,
+         x: mx.array,
+         spatial_shapes: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+         x = self.embeddings(x, spatial_shapes)
+         x = x.astype(self.embeddings.patch_embedding.weight.dtype)
+         encoder_outputs = self.encoder(
+             x=x, output_hidden_states=output_hidden_states, mask=None
+         )
+         pooler_output = self.post_layernorm(encoder_outputs[0])
+         return pooler_output, x, encoder_outputs[-1]
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         if self.model_type not in ["siglip_vision_model"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.vision_model = SigLipVisionModel(config)
+
+     def __call__(
+         self,
+         x: mx.array,
+         spatial_shapes: Optional[mx.array] = None,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+         return self.vision_model(x, spatial_shapes, output_hidden_states)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             elif "patch_embedding.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
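
For context, the `sanitize` method at the end of the vision module above exists because PyTorch checkpoints store Conv2d weights as [out_channels, in_channels, kH, kW] while MLX's Conv2d expects [out_channels, kH, kW, in_channels]. The short sketch below is illustrative only (the shapes are hypothetical, not taken from this package) and shows the transposition that `sanitize` applies to a patch-embedding weight:

import mlx.core as mx

# Hypothetical patch-embedding weight in PyTorch layout:
# [out_channels, in_channels, kH, kW]
pt_weight = mx.zeros((1152, 3, 14, 14))

# check_array_shape() returns False for this layout (kH != kW here), so
# sanitize() moves the input channels last to match MLX's Conv2d layout:
# [out_channels, kH, kW, in_channels]
mlx_weight = pt_weight.transpose(0, 2, 3, 1)
print(mlx_weight.shape)  # (1152, 14, 14, 3)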
mlx_vlm/models/base.py ADDED
@@ -0,0 +1,356 @@
+ import inspect
+ import math
+ from abc import abstractmethod
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention
+ from PIL import Image
+
+
+ @dataclass
+ class LanguageModelOutput:
+     logits: mx.array
+     hidden_states: Optional[List[mx.array]] = None
+     cross_attention_states: Optional[List[mx.array]] = None
+     encoder_outputs: Optional[List[mx.array]] = None
+
+
+ @dataclass
+ class InputEmbeddingsFeatures:
+     inputs_embeds: mx.array
+     attention_mask_4d: Optional[mx.array] = None
+     visual_pos_masks: Optional[mx.array] = None
+     deepstack_visual_embeds: Optional[mx.array] = None
+     per_layer_inputs: Optional[mx.array] = None
+     cross_attention_states: Optional[mx.array] = None
+     cross_attention_mask: Optional[mx.array] = None
+     full_text_row_masked_out_mask: Optional[mx.array] = None
+     decoder_inputs_embeds: Optional[mx.array] = None
+     attention_mask: Optional[mx.array] = None  # For encoder-decoder models
+
+     def to_dict(self):
+         return {
+             "inputs_embeds": self.inputs_embeds,
+             "attention_mask_4d": self.attention_mask_4d,
+             "visual_pos_masks": self.visual_pos_masks,
+             "deepstack_visual_embeds": self.deepstack_visual_embeds,
+             "per_layer_inputs": self.per_layer_inputs,
+             "cross_attention_states": self.cross_attention_states,
+             "cross_attention_mask": self.cross_attention_mask,
+             "full_text_row_masked_out_mask": self.full_text_row_masked_out_mask,
+             "decoder_inputs_embeds": self.decoder_inputs_embeds,
+             "attention_mask": self.attention_mask,
+         }
+
+
+ @dataclass
+ class BaseModelConfig:
+     @classmethod
+     def from_dict(cls, params):
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
+
+     def to_dict(self):
+         return {k: v for k, v in self.__dict__.items() if v is not None}
+
+
+ class BaseImageProcessor:
+     """
+     Base image processor class. Subclasses should implement preprocess().
+     Transformers imports are deferred to __init__ for faster module loading.
+     """
+
+     def __init__(
+         self,
+         image_mean=(0.5, 0.5, 0.5),
+         image_std=(0.5, 0.5, 0.5),
+         size=(384, 384),
+         crop_size: Dict[str, int] = None,
+         resample=None,
+         rescale_factor=1 / 255,
+         data_format=None,
+     ):
+         from transformers.image_processing_utils import get_size_dict
+         from transformers.image_utils import ChannelDimension, PILImageResampling
+
+         if resample is None:
+             resample = PILImageResampling.BICUBIC
+         if data_format is None:
+             data_format = ChannelDimension.FIRST
+
+         crop_size = (
+             crop_size if crop_size is not None else {"height": 384, "width": 384}
+         )
+         crop_size = get_size_dict(
+             crop_size, default_to_square=True, param_name="crop_size"
+         )
+
+         self.image_mean = image_mean
+         self.image_std = image_std
+         self.size = size
+         self.resample = resample
+         self.rescale_factor = rescale_factor
+         self.data_format = data_format
+         self.crop_size = crop_size
+
+     def rescale(
+         self,
+         image,
+         scale: float,
+         input_data_format: str = "channels_first",
+     ):
+         """Rescale an image by a scale factor."""
+         return image * scale
+
+     def normalize(
+         self,
+         image,
+         mean,
+         std,
+         input_data_format: str = "channels_first",
+     ):
+         """Normalize an image with mean and std."""
+         import numpy as np
+
+         mean = np.array(mean, dtype=image.dtype)
+         std = np.array(std, dtype=image.dtype)
+
+         if input_data_format == "channels_first":
+             # Image shape: [C, H, W]
+             mean = mean[:, None, None]
+             std = std[:, None, None]
+         else:
+             # Image shape: [H, W, C]
+             pass  # mean and std are already in correct shape
+
+         return (image - mean) / std
+
+     @abstractmethod
+     def preprocess(self, images):
+         pass
+
+
+ def expand2square(pil_img, background_color):
+     width, height = pil_img.size
+     if width == height:
+         return pil_img
+     elif width > height:
+         result = Image.new(pil_img.mode, (width, width), background_color)
+         result.paste(pil_img, (0, (width - height) // 2))
+         return result
+     else:
+         result = Image.new(pil_img.mode, (height, height), background_color)
+         result.paste(pil_img, ((height - width) // 2, 0))
+         return result
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) == 4:
+         out_channels, kH, KW, _ = shape
+         # Check if out_channels is the largest, and kH and KW are the same
+         if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+             return True
+         else:
+             return False
+     # Check if the shape has 3 dimensions
+     elif len(shape) == 3:
+         _, kW, out_channels = shape
+         # Check if out_channels is the largest
+         if kW >= out_channels:
+             return True
+         else:
+             return False
+     else:
+         return False
+
+
+ def check_activation_stats(name, tensor):
+     """Helper function to check for anomalies and log stats."""
+
+     print(f"--- Activation Stats: {name} ---")
+     # Check for NaNs/Infs
+     has_nan = mx.isnan(tensor).any()
+     has_inf = mx.isinf(tensor).any()
+     if has_nan:
+         print(f"WARNING: Found NaN in {name}")
+     if has_inf:
+         print(f"WARNING: Found Inf in {name}")
+
+     # Calculate and print stats (ensure computation happens)
+     min_val = mx.min(tensor).item()
+     max_val = mx.max(tensor).item()
+     mean_val = mx.mean(tensor).item()
+     std_val = mx.std(tensor).item()
+     print(f" Shape: {tensor.shape}")
+     print(f" Min: {min_val:.4f}, Max: {max_val:.4f}")
+     print(f" Mean: {mean_val:.4f}, Std: {std_val:.4f}")
+     print("-" * (len(name) + 24))
+
+
+ def pixel_shuffle(input_tensor, shuffle_ratio):
+     # input_tensor: [batch_size, num_patches, channels]
+     batch_size, num_patches, channels = input_tensor.shape
+     patch_size = int(math.sqrt(num_patches))
+
+     input_tensor = input_tensor.reshape(batch_size, patch_size, patch_size, -1)
+     batch_size, height, width, channels = input_tensor.shape
+
+     reshaped_tensor = input_tensor.reshape(
+         batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
+     )
+     reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3)
+
+     reshaped_tensor = reshaped_tensor.reshape(
+         batch_size,
+         int(height * shuffle_ratio),
+         int(width * shuffle_ratio),
+         int(channels / (shuffle_ratio**2)),
+     )
+     reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3)
+
+     output_tensor = reshaped_tensor.reshape(batch_size, -1, reshaped_tensor.shape[-1])
+     return output_tensor
+
+
+ def interpolate(pos_embed, size, mode="cubic", align_corners=False):
+     """
+     MLX implementation of PyTorch's F.interpolate with bicubic mode
+
+     Args:
+         pos_embed: MLX array with shape [B, C, H_src, W_src] or [C, H_src, W_src]
+         size: Tuple (H_dst, W_dst) - target size
+         align_corners: Boolean - whether to align corners
+
+     Returns:
+         Interpolated array with shape [B, C, H_dst, W_dst] or [C, H_dst, W_dst]
+     """
+     # Handle different input shapes
+     input_dim = pos_embed.ndim
+     original_shape = pos_embed.shape
+
+     if input_dim == 3:
+         # [C, H, W] -> [1, C, H, W]
+         pos_embed = pos_embed.reshape(1, *original_shape)
+
+     # Get source dimensions
+     h_src, w_src = pos_embed.shape[-2:]
+     h_dst, w_dst = size
+
+     # Calculate scale factors
+     scale_h = h_dst / h_src
+     scale_w = w_dst / w_src
+
+     # Create upsampler
+     upsampler = nn.Upsample(
+         scale_factor=(scale_h, scale_w), mode=mode, align_corners=align_corners
+     )
+
+     # Apply upsampling
+     result = upsampler(pos_embed)
+
+     # Return in the original dimension format
+     if input_dim == 3:
+         return result.reshape(original_shape[0], *size)
+     return result
+
+
+ @mx.compile
+ def chunked_attention(
+     queries: mx.array,
+     keys: mx.array,
+     values: mx.array,
+     scale: float,
+     chunk_size: int,
+ ) -> mx.array:
+
+     L = queries.shape[2]
+
+     outputs = []
+     for i in range(0, L, chunk_size):
+         end_idx = min(i + chunk_size, L)
+         q_chunk = queries[:, :, i:end_idx, :]  # (B, n_heads, chunk, head_dim)
+
+         chunk_output = mx.fast.scaled_dot_product_attention(
+             q_chunk, keys, values, scale=scale
+         )
+
+         outputs.append(chunk_output)
+
+     return mx.concatenate(outputs, axis=2)  # (B, n_heads, L, head_dim)
+
+
+ def install_auto_processor_patch(target_model_types, processor_cls):
+     """
+     Install a composable patch on transformers.AutoProcessor.from_pretrained
+
+     Args:
+         target_model_types (Union[str, List[str]]): Model types to intercept.
+         processor_cls (type): Processor class exposing `from_pretrained`.
+
+     Returns:
+         The previous `AutoProcessor.from_pretrained` for reference.
+     """
+     from transformers import AutoProcessor as _HF_AutoProcessor
+
+     if isinstance(target_model_types, str):
+         target_model_types = [target_model_types]
+     target_model_types = {t.lower() for t in target_model_types}
+
+     previous_from_pretrained = _HF_AutoProcessor.from_pretrained
+
+     @classmethod
+     def _patched_auto_processor_from_pretrained(
+         cls, pretrained_model_name_or_path, **kwargs
+     ):
+         import json as _json
+         from pathlib import Path
+
+         try:
+             model_path = Path(pretrained_model_name_or_path)
+             is_local = model_path.exists() and model_path.is_dir()
+
+             cfg = {}
+             if is_local:
+                 config_path = model_path / "config.json"
+                 if config_path.exists():
+                     with open(config_path, "r", encoding="utf-8") as f:
+                         cfg = _json.load(f)
+             else:
+                 try:
+                     from huggingface_hub import hf_hub_download
+
+                     cfg_path = hf_hub_download(
+                         pretrained_model_name_or_path, "config.json"
+                     )
+                     with open(cfg_path, "r", encoding="utf-8") as f:
+                         cfg = _json.load(f)
+                 except Exception:
+                     cfg = {}
+
+             model_type = str(cfg.get("model_type", "")).lower()
+             if model_type in target_model_types:
+                 return processor_cls.from_pretrained(
+                     pretrained_model_name_or_path, **kwargs
+                 )
+         except Exception:
+             # On any failure, fall back to previous behavior
+             pass
+
+         # Chain to the prior from_pretrained (which may already be patched)
+         return previous_from_pretrained.__func__(
+             cls, pretrained_model_name_or_path, **kwargs
+         )
+
+     _HF_AutoProcessor.from_pretrained = _patched_auto_processor_from_pretrained
+     return previous_from_pretrained
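
For context, `install_auto_processor_patch` above wraps `transformers.AutoProcessor.from_pretrained` so that checkpoints whose config.json declares one of the target model types are routed to a custom processor class, while all other checkpoints fall through to the previous (possibly already patched) implementation. A minimal usage sketch follows; the model type, processor class, and repo id are hypothetical and not names shipped in this package:

from transformers import AutoProcessor

from mlx_vlm.models.base import install_auto_processor_patch


class MyProcessor:  # hypothetical stand-in for a model-specific processor
    @classmethod
    def from_pretrained(cls, path, **kwargs):
        # Build and return the processor for this model family
        return cls()


# Checkpoints whose config.json declares model_type == "my_model" are now
# handled by MyProcessor; everything else chains to the previous
# AutoProcessor.from_pretrained.
previous = install_auto_processor_patch("my_model", MyProcessor)

# Hypothetical repo id, for illustration only
processor = AutoProcessor.from_pretrained("org/my-model-checkpoint")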