fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/hunyuan_vl/vision.py
@@ -0,0 +1,322 @@
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import chunked_attention
+ from .config import VisionConfig
+
+
+ class VisionMLP(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.dense_h_to_4h = nn.Linear(
+             self.hidden_size, self.intermediate_size, bias=True
+         )
+         self.dense_4h_to_h = nn.Linear(
+             self.intermediate_size, self.hidden_size, bias=True
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.dense_h_to_4h(x)
+         x = nn.gelu(x)
+         x = self.dense_4h_to_h(x)
+         return x
+
+
+ class VisionAttention(nn.Module):
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.scale = self.head_dim**-0.5
+
+         self.q_proj = nn.Linear(
+             config.hidden_size, self.num_heads * self.head_dim, bias=True
+         )
+         self.k_proj = nn.Linear(
+             config.hidden_size, self.num_heads * self.head_dim, bias=True
+         )
+         self.v_proj = nn.Linear(
+             config.hidden_size, self.num_heads * self.head_dim, bias=True
+         )
+         self.o_proj = nn.Linear(
+             self.num_heads * self.head_dim, config.hidden_size, bias=True
+         )
+
+     def __call__(self, x: mx.array, chunk_size: int = 1024) -> mx.array:
+         B, L, _ = x.shape
+
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         # Reshape to (B, n_heads, L, head_dim)
+         queries = queries.reshape(B, L, self.num_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+         keys = keys.reshape(B, L, self.num_heads, self.head_dim).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.num_heads, self.head_dim).transpose(
+             0, 2, 1, 3
+         )
+
+         output = chunked_attention(
+             queries,
+             keys,
+             values,
+             scale=self.scale,
+             chunk_size=chunk_size,
+         )
+
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class VisionBlock(nn.Module):
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.LayerNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.self_attn = VisionAttention(config)
+         self.mlp = VisionMLP(config)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         # Self-attention with residual
+         h = x + self.self_attn(self.input_layernorm(x))
+         # MLP with residual
+         out = h + self.mlp(self.post_attention_layernorm(h))
+         return out
+
+
+ class PatchEmbed(nn.Module):
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.patch_size = config.patch_size
+         self.num_channels = config.num_channels
+         self.spatial_merge_size = config.spatial_merge_size
+         self.interpolate_mode = config.interpolate_mode
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             bias=True,
+         )
+
+         self.max_num_patches = (config.max_image_size // self.patch_size) ** 2
+         self.num_positions = self.max_num_patches + 1
+         self.position_edge = int(self.num_positions**0.5)
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     def __call__(self, pixel_values: mx.array, grid_thw: list) -> mx.array:
+         num_patches = pixel_values.shape[0]
+         # Reshape: (num_patches, C*P*P) -> (num_patches, C, P, P) -> (num_patches, P, P, C) for MLX conv
+         pixel_values = pixel_values.reshape(
+             num_patches, self.num_channels, self.patch_size, self.patch_size
+         )
+         pixel_values = pixel_values.transpose(0, 2, 3, 1)  # NCHW -> NHWC for MLX
+
+         # Apply patch embedding
+         patch_embeds = self.patch_embedding(pixel_values)  # (N, 1, 1, embed_dim)
+         patch_embeds = patch_embeds.reshape(1, num_patches, self.embed_dim)
+
+         # Get position embeddings and interpolate for each grid
+         pos_embed_weights = self.position_embedding.weight[1:, :]  # Skip cls token
+         base_pos_embed = pos_embed_weights.reshape(
+             1, self.position_edge, self.position_edge, self.embed_dim
+         )
+
+         patch_pos_embed_list = []
+         for grid in grid_thw:
+             t, h, w = grid
+             h_float = float(h) + 0.1
+             w_float = float(w) + 0.1
+
+             target_h = int(h)
+             target_w = int(w)
+
+             # Simple bilinear interpolation
+             pos_embed = self._interpolate_pos_embed(base_pos_embed, target_h, target_w)
+             pos_embed = pos_embed.reshape(1, -1, self.embed_dim)
+             patch_pos_embed_list.append(pos_embed)
+
+         patch_pos_embed = mx.concatenate(patch_pos_embed_list, axis=1)
+         embeddings = patch_embeds + patch_pos_embed
+
+         return embeddings
+
+     def _interpolate_pos_embed(
+         self, pos_embed: mx.array, target_h: int, target_w: int
+     ) -> mx.array:
+         dtype = pos_embed.dtype
+         src_h, src_w = pos_embed.shape[1], pos_embed.shape[2]
+
+         if src_h == target_h and src_w == target_w:
+             return pos_embed
+
+         # Create coordinate grids
+         h_scale = src_h / (target_h + 0.1)
+         w_scale = src_w / (target_w + 0.1)
+         h_coords = (mx.arange(target_h) + 0.5) * h_scale - 0.5
+         w_coords = (mx.arange(target_w) + 0.5) * w_scale - 0.5
+
+         i0 = h_coords.astype(mx.int32)
+         j0 = w_coords.astype(mx.int32)
+         i1 = mx.minimum(i0 + 1, src_h - 1)
+         j1 = mx.minimum(j0 + 1, src_w - 1)
+
+         di = (h_coords - i0.astype(mx.float32))[:, None, None]
+         dj = (w_coords - j0.astype(mx.float32))[None, :, None]
+
+         # Gather corners and interpolate
+         p00 = pos_embed[0, i0][:, j0]
+         p01 = pos_embed[0, i0][:, j1]
+         p10 = pos_embed[0, i1][:, j0]
+         p11 = pos_embed[0, i1][:, j1]
+
+         result = (
+             (1 - di) * (1 - dj) * p00
+             + (1 - di) * dj * p01
+             + di * (1 - dj) * p10
+             + di * dj * p11
+         )
+
+         return result[None].astype(dtype)
+
+
+ class PatchMerger(nn.Module):
+     def __init__(
+         self,
+         config: VisionConfig,
+     ):
+         super().__init__()
+         self.config = config
+         self.spatial_merge_size = config.spatial_merge_size
+         self.hidden_size = config.hidden_size
+         self.out_hidden_size = config.out_hidden_size
+
+         merge_hidden = config.hidden_size * 2  # 2304
+         final_hidden = config.hidden_size * 4  # 4608
+
+         self.before_rms = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.after_rms = nn.RMSNorm(config.out_hidden_size, eps=config.rms_norm_eps)
+
+         self.proj = [
+             nn.Conv2d(
+                 config.hidden_size,
+                 merge_hidden,
+                 kernel_size=config.spatial_merge_size,
+                 stride=config.spatial_merge_size,
+                 bias=True,
+             ),
+             nn.GELU(),
+             nn.Conv2d(merge_hidden, final_hidden, kernel_size=1, bias=True),
+         ]
+
+         self.mlp = nn.Linear(final_hidden, config.out_hidden_size, bias=True)
+
+         self.image_newline = mx.zeros((final_hidden,))
+         self.image_begin = mx.zeros((config.out_hidden_size,))
+         self.image_end = mx.zeros((config.out_hidden_size,))
+         self.image_sep = mx.zeros((config.out_hidden_size,))
+
+     def __call__(self, hidden_states: mx.array, grid_h: int, grid_w: int) -> mx.array:
+
+         B = hidden_states.shape[0]
+         final_hidden = self.config.hidden_size * 4  # 4608
+
+         x = self.before_rms(hidden_states)
+
+         x = x.reshape(B, grid_h, grid_w, self.hidden_size)
+
+         for layer in self.proj:
+             x = layer(x)
+
+         merged_h = grid_h // self.spatial_merge_size
+         merged_w = grid_w // self.spatial_merge_size
+
+         x = x.reshape(B, merged_h, merged_w, final_hidden)
+
+         newlines = mx.broadcast_to(
+             self.image_newline[None, None, None, :], (B, merged_h, 1, final_hidden)
+         )
+
+         x = mx.concatenate(
+             [x, newlines], axis=2
+         )  # (B, merged_h, merged_w+1, final_hidden)
+         x = x.reshape(B, merged_h * (merged_w + 1), final_hidden)
+
+         x = self.mlp(x)
+
+         begin = mx.broadcast_to(
+             self.image_begin[None, None, :], (B, 1, self.out_hidden_size)
+         )
+         end = mx.broadcast_to(
+             self.image_end[None, None, :], (B, 1, self.out_hidden_size)
+         )
+
+         x = mx.concatenate([begin, x, end], axis=1)
+
+         x = self.after_rms(x)
+
+         return x
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         if self.model_type != "hunyuan_vl":
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+         self.embeddings = PatchEmbed(config)
+         self.layers = [VisionBlock(config) for _ in range(config.num_hidden_layers)]
+         self.perceive = PatchMerger(
+             config=config,
+         )
+
+     def __call__(
+         self,
+         pixel_values: mx.array,
+         grid_thw: list,
+     ) -> mx.array:
+         """
+         Args:
+             pixel_values: Flattened pixel values of shape (total_patches, C*P*P)
+             grid_thw: List of [t, h, w] for each image
+
+         Returns:
+             Image features of shape (1, total_tokens, text_hidden_size)
+         """
+         hidden_states = self.embeddings(pixel_values, grid_thw)
+
+         for layer in self.layers:
+             hidden_states = layer(hidden_states)
+
+         # Calculate cumulative sequence lengths
+         cu_seqlens = [0]
+         for t, h, w in grid_thw:
+             cu_seqlens.append(int(h * w))
+         cu_seqlens = mx.cumsum(mx.array(cu_seqlens, dtype=mx.int32))
+
+         # Split and process each image
+         processed_items = []
+         for i, grid in enumerate(grid_thw):
+             t, h, w = grid
+             start_idx = int(cu_seqlens[i])
+             end_idx = int(cu_seqlens[i + 1])
+             item = hidden_states[:, start_idx:end_idx, :]
+             processed = self.perceive(item, int(h), int(w))
+             processed_items.append(processed)
+
+         hidden_states = mx.concatenate(processed_items, axis=1)
+         return hidden_states
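A note on the vision tower above: it consumes one flattened patch sequence plus a grid_thw list, splits the sequence back into per-image grids via cumulative sums, and PatchMerger then emits merged_h * (merged_w + 1) + 2 tokens per image (the spatially merged grid plus one "newline" token per merged row, wrapped in begin/end tokens). The short sketch below is not part of the package; it only reproduces that token bookkeeping in plain Python so the shapes are easy to check, and it assumes a spatial_merge_size of 2 (the hunyuan_vl config is not shown in this diff).

# Hypothetical illustration of the token bookkeeping in VisionModel/PatchMerger.
# Not from the package: plain Python, no mlx required.

def vision_token_counts(grid_thw, spatial_merge_size=2):
    """For each [t, h, w] grid, return (input_patches, output_tokens)."""
    counts = []
    for t, h, w in grid_thw:
        input_patches = h * w                  # patches fed to the transformer
        merged_h = h // spatial_merge_size     # PatchMerger downsamples h and w
        merged_w = w // spatial_merge_size
        # one "newline" token per merged row, plus image begin/end tokens
        output_tokens = merged_h * (merged_w + 1) + 2
        counts.append((input_patches, output_tokens))
    return counts

# Two images: a 16x16 patch grid and a 32x24 patch grid (t is 1 for still images).
print(vision_token_counts([[1, 16, 16], [1, 32, 24]]))
# [(256, 74), (768, 210)]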
mlx_vlm/models/idefics2/__init__.py
@@ -0,0 +1,2 @@
+ from .config import ModelConfig, PerceiverConfig, TextConfig, VisionConfig
+ from .idefics2 import LanguageModel, Model, VisionModel
mlx_vlm/models/idefics2/config.py
@@ -0,0 +1,65 @@
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int = 4096
+     intermediate_size: int = 14336
+     num_hidden_layers: int = 32
+     num_attention_heads: int = 32
+     num_key_value_heads: int = 8
+     num_channels: int = 3
+     image_size: int = 224
+     patch_size: int = 32
+     layer_norm_eps: float = 1e-6
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int = 4096
+     intermediate_size: int = 14336
+     num_hidden_layers: int = 32
+     num_attention_heads: int = 32
+     num_key_value_heads: int = 8
+     rms_norm_eps: float = 1e-5
+     vocab_size: int = 32003
+     rope_theta: float = 1000000.0
+     rope_traditional: bool = False
+     max_position_embeddings: int = 32768
+     tie_word_embeddings: bool = False
+
+     def __post_init__(self):
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+
+ @dataclass
+ class PerceiverConfig(BaseModelConfig):
+     model_type: str
+     num_key_value_heads: int = 4
+     resampler_depth: int = 3
+     resampler_head_dim: int = 96
+     resampler_n_heads: int = 16
+     resampler_n_latents: int = 64
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     perceiver_config: PerceiverConfig
+     model_type: str
+     ignore_index: int = -100
+     image_token_id: int = 32001
+     vocab_size: int = 151936
+     image_token_index: Optional[int] = None
+     eos_token_id: Optional[List[int]] = None
+
+     def __post_init__(self):
+         if self.image_token_index is None:
+             self.image_token_index = self.image_token_id
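The ModelConfig.__post_init__ hook above simply mirrors image_token_id into image_token_index when the latter is not supplied. A hypothetical construction is shown below; it assumes BaseModelConfig (not shown in this diff) adds no required fields of its own, and the model_type strings are placeholders, not values taken from the package.

# Hypothetical usage of the config dataclasses above (not part of the package).
config = ModelConfig(
    model_type="idefics2",
    text_config=TextConfig(model_type="mistral"),
    vision_config=VisionConfig(model_type="idefics2"),
    perceiver_config=PerceiverConfig(model_type="idefics2"),
)
print(config.image_token_index)  # 32001, inherited from image_token_id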
mlx_vlm/models/idefics2/idefics2.py
@@ -0,0 +1,321 @@
+ import re
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import InputEmbeddingsFeatures
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ def masked_scatter(
+     final_embedding: mx.array,
+     image_mask_expanded: mx.array,
+     scaled_image_features: mx.array,
+ ):
+     # Reshape the tensors to 1D
+     final_embedding_shape = final_embedding.shape
+     scaled_image_features_flattened = mx.flatten(scaled_image_features)
+     final_embedding_flattened = mx.flatten(final_embedding)
+     image_mask_expanded_flattened = mx.flatten(image_mask_expanded)
+
+     # Scatter the scaled image features into the special image token positions
+     image_positions = mx.array(np.where(image_mask_expanded_flattened)[0], mx.uint32)
+     final_embedding_flattened[image_positions] = scaled_image_features_flattened
+
+     # Reshape back to the original shape
+     final_embedding = mx.reshape(final_embedding_flattened, final_embedding_shape)
+
+     return final_embedding
+
+
+ class Idefics2PerceiverAttention(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+
+         dim = config.text_config.hidden_size
+         self.n_heads = n_heads = config.perceiver_config.resampler_n_heads
+         self.n_kv_heads = n_kv_heads = config.perceiver_config.num_key_value_heads
+
+         head_dim = config.perceiver_config.resampler_head_dim
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+     def __call__(
+         self,
+         x: mx.array,
+         kv: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[Tuple[mx.array, mx.array]] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+         kv_seq_len = L + kv.shape[1]
+         hidden_states = mx.concatenate([kv, x], axis=-2)
+
+         queries = self.q_proj(x)
+         keys = self.k_proj(hidden_states)
+         values = self.v_proj(hidden_states)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, kv_seq_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, kv_seq_len, self.n_kv_heads, -1).transpose(
+             0, 2, 1, 3
+         )
+
+         if cache is not None:
+             key_cache, value_cache = cache
+             keys = mx.concatenate([key_cache, keys], axis=2)
+             values = mx.concatenate([value_cache, values], axis=2)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class Idefics2PerceiverLayer(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.hidden_size = config.text_config.hidden_size
+         self.n_latents = config.perceiver_config.resampler_n_latents
+         self.depth = config.perceiver_config.resampler_depth
+         self.rms_norm_eps = config.text_config.rms_norm_eps
+
+         self.input_latents_norm = nn.RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
+         self.input_context_norm = nn.RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
+         self.self_attn = Idefics2PerceiverAttention(config)
+         self.post_attention_layernorm = nn.RMSNorm(
+             self.hidden_size, eps=self.rms_norm_eps
+         )
+         self.mlp = MLP(self.hidden_size, self.hidden_size * 4, self.hidden_size)
+
+     def __call__(
+         self,
+         x: mx.array,
+         hidden_states: mx.array,
+         mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         latents = self.input_latents_norm(x)
+         context = self.input_context_norm(hidden_states)
+
+         latents = self.self_attn(latents, context, mask=mask)
+
+         latents = x + latents
+         r = latents
+
+         latents = self.post_attention_layernorm(latents)
+         latents = self.mlp(latents)
+         latents = r + latents
+         return latents
+
+
+ class Idefics2PerceiverResampler(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.hidden_size = config.text_config.hidden_size
+         self.n_latents = config.perceiver_config.resampler_n_latents
+
+         self.latents = mx.ones((self.n_latents, self.hidden_size))
+         self.layers = [
+             Idefics2PerceiverLayer(config)
+             for _ in range(config.perceiver_config.resampler_depth)
+         ]
+         self.norm = nn.RMSNorm(self.hidden_size, eps=config.text_config.rms_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None):
+
+         h = mx.expand_dims(self.latents, axis=0)
+         h = mx.repeat(h, x.shape[0], axis=0)
+
+         for layer in self.layers:
+             h = layer(h, x, mask=mask)
+
+         return self.norm(h)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim, output_size):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, output_size, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class Idefics2Connector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.modality_projection = MLP(
+             config.vision_config.hidden_size,
+             config.text_config.intermediate_size,
+             config.text_config.hidden_size,
+         )
+
+         self.perceiver_resampler = Idefics2PerceiverResampler(config)
+
+     def __call__(self, x: mx.array, mask=None) -> mx.array:
+         x = self.modality_projection(x)
+         x = self.perceiver_resampler(x, mask=mask)
+         return x
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.config = config
+
+         self.vision_model = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.connector = Idefics2Connector(config)
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.embed_tokens(input_ids)
+             )
+
+         inputs_embeds = self.language_model.embed_tokens(input_ids)
+
+         batch_size, num_images, num_channels, height, width = pixel_values.shape
+         pixel_values = pixel_values.reshape(
+             batch_size * num_images, num_channels, height, width
+         )
+
+         # Remove padding images - padding images are all zeros.
+         nb_values_per_image = np.prod(pixel_values.shape[1:])
+         real_images_mask = (pixel_values == 0.0).sum(
+             axis=(-1, -2, -3)
+         ) != nb_values_per_image
+         real_images_inds = np.where(real_images_mask)[0].tolist()
+         pixel_values = pixel_values[real_images_inds, ...]
+
+         if pixel_attention_mask is None:
+             pixel_attention_mask = mx.ones(
+                 (pixel_values.shape[0], pixel_values.shape[2], pixel_values.shape[3]),
+                 dtype=mx.bool_,
+             )
+         else:
+             # Remove padding images from the mask
+             pixel_attention_mask = pixel_attention_mask.reshape(
+                 batch_size * num_images, height, width
+             )
+             pixel_attention_mask = pixel_attention_mask[real_images_inds]
+
+         patch_size = self.config.vision_config.patch_size
+         batch_size, height, width = pixel_attention_mask.shape
+
+         # Calculate number of patches
+         patches_h = height // patch_size
+         patches_w = width // patch_size
+
+         # Reshape to extract patches
+         reshaped = pixel_attention_mask[
+             :, : patches_h * patch_size, : patches_w * patch_size
+         ]
+         reshaped = reshaped.reshape(
+             batch_size, patches_h, patch_size, patches_w, patch_size
+         )
+         reshaped = reshaped.transpose(
+             0, 1, 3, 2, 4
+         )  # (batch, patches_h, patches_w, patch_size, patch_size)
+
+         # Sum over patch dimensions and check if any pixels are active
+         patch_attention_mask = reshaped.sum(axis=(-1, -2)) > 0
+
+         pooler_output, *_ = self.vision_model(
+             pixel_values.transpose(0, 2, 3, 1),
+             patch_attention_mask=patch_attention_mask,
+             output_hidden_states=True,
+         )
+
+         image_features = pooler_output.astype(pixel_values.dtype)
+         image_features = self.connector(image_features)
+
+         final_inputs_embeds = self._prepare_inputs_for_multimodal(
+             image_features, inputs_embeds, input_ids
+         )
+         return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+     def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
+         special_image_mask = input_ids == self.config.image_token_index
+         n_image_tokens = special_image_mask.sum()
+         special_image_mask = special_image_mask[..., None]
+         special_image_mask = mx.broadcast_to(special_image_mask, inputs_embeds.shape)
+
+         n_image_features = image_features.shape[0]
+         n_image_mask_elements = special_image_mask.sum()
+         if n_image_mask_elements != image_features.size:
+             raise ValueError(
+                 f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+             )
+
+         inputs_embeds = masked_scatter(
+             inputs_embeds, special_image_mask, image_features
+         )
+
+         return inputs_embeds
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         mask: mx.array,
+         cache=None,
+         **kwargs,
+     ):
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, **kwargs
+         )
+         logits = self.language_model(
+             inputs=input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+         )
+         return logits
+
+     def sanitize(self, weights):
+         weights = {
+             (
+                 f"{k.split('.', 1)[1]}"
+                 if re.match(r"^model\.", k)
+                 else (f"language_model.{k}" if re.match(r"^lm_head\.", k) else k)
+             ): v
+             for k, v in weights.items()
+         }
+
+         weights = {
+             (
+                 f"language_model.{k.split('.', 1)[1]}"
+                 if re.match(
+                     r"^text_model\.",
+                     k,
+                 )
+                 else k
+             ): v
+             for k, v in weights.items()
+         }
+
+         return weights
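The masked_scatter helper at the top of this file is the core of the text/image fusion: it flattens the embedding tensor, finds the positions covered by the broadcast image-token mask, and overwrites exactly those scalars with the flattened image features. The small NumPy re-implementation below is illustrative only (it is not the package's code path and uses no mlx); it makes the invariant checked by _prepare_inputs_for_multimodal explicit: the mask must select exactly image_features.size elements.

import numpy as np

def masked_scatter_np(embeds, image_mask, image_features):
    """Replace embedding rows at image-token positions with image features.

    embeds:         (B, L, D) token embeddings
    image_mask:     (B, L) boolean, True where input_ids == image_token_index
    image_features: (N, D) with N == image_mask.sum()
    """
    mask = np.broadcast_to(image_mask[..., None], embeds.shape)
    assert mask.sum() == image_features.size, "token/feature count mismatch"
    out = embeds.reshape(-1).copy()
    out[np.flatnonzero(mask.reshape(-1))] = image_features.reshape(-1)
    return out.reshape(embeds.shape)

# Toy check: 1 sequence of 5 tokens, 2 of them image tokens, D = 3.
embeds = np.zeros((1, 5, 3))
mask = np.array([[False, True, True, False, False]])
feats = np.arange(6, dtype=float).reshape(2, 3)
print(masked_scatter_np(embeds, mask, feats)[0, 1:3])
# [[0. 1. 2.]
#  [3. 4. 5.]]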