fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,408 @@
+ from typing import List, Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from .config import VisionConfig
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: VisionConfig, input_dim: int):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.w1 = nn.Linear(
+             input_dim,
+             self.hidden_size,
+             bias=False,
+         )
+         self.w2 = nn.Linear(
+             self.hidden_size,
+             config.d_model,
+             bias=False,
+         )
+         self.w3 = nn.Linear(
+             input_dim,
+             self.hidden_size,
+             bias=False,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
+         return x
+
+
+ class ViTMLP(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.w1 = nn.Linear(config.image_emb_dim, config.image_mlp_dim, bias=True)
+         self.w2 = nn.Linear(config.image_mlp_dim, config.image_emb_dim, bias=True)
+         self.act = nn.GELU(approx="fast")
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.w1(x)
+         x = self.act(x)
+         x = self.w2(x)
+         return x
+
+
+ class MultiHeadDotProductAttention(nn.Module):
+     def __init__(self, config: VisionConfig, is_vit_layer: Optional[bool] = True):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.image_emb_dim
+         self.num_heads = config.image_num_heads
+         self.head_dim = config.image_head_dim
+         self.num_key_value_heads = config.image_num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.scale = self.head_dim**-0.5
+         self.is_vit_layer = is_vit_layer
+
+         n_layers = (
+             1 if (is_vit_layer or config.vit_layers is None) else len(config.vit_layers)
+         )
+
+         self.wq = nn.Linear(
+             n_layers * self.embed_dim, self.num_heads * self.head_dim, bias=True
+         )
+         self.wk = nn.Linear(
+             n_layers * self.embed_dim,
+             self.num_key_value_heads * self.head_dim,
+             bias=True,
+         )
+         self.wv = nn.Linear(
+             n_layers * self.embed_dim,
+             self.num_key_value_heads * self.head_dim,
+             bias=True,
+         )
+         self.wo = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=True)
+
+     def _split_heads(self, hidden_states, num_heads) -> mx.array:
+         return hidden_states.reshape(
+             hidden_states.shape[:2] + (num_heads, self.head_dim)
+         )
+
+     def _merge_heads(self, hidden_states) -> mx.array:
+         return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
+
+     def __call__(self, x: mx.array, kv: mx.array = None) -> mx.array:
+         batch_size, seq_len, _ = x.shape
+
+         if kv is None:
+             k = x
+             v = x
+         else:
+             k = kv
+             v = kv
+         q = self._split_heads(self.wq(x), self.num_heads).transpose(0, 2, 1, 3)
+
+         k = self._split_heads(self.wk(k), self.num_key_value_heads).transpose(
+             0, 2, 1, 3
+         )
+         v = self._split_heads(self.wv(v), self.num_key_value_heads).transpose(
+             0, 2, 1, 3
+         )
+
+         attn = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
+         out = attn.transpose(0, 2, 1, 3)
+         out = self._merge_heads(out)
+         out = self.wo(out)
+         return out
+
+
+ class ResidualAttentionBlock(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.attention = MultiHeadDotProductAttention(config)
+         self.feed_forward = ViTMLP(config)
+         self.attention_norm = nn.LayerNorm(
+             config.image_emb_dim, eps=config.image_norm_eps
+         )
+         self.ffn_norm = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = x + self.attention(self.attention_norm(x))
+         x = x + self.feed_forward(self.ffn_norm(x))
+         return x
+
+
+ class ResidualAttentionBlocks(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.resblocks = [
+             ResidualAttentionBlock(config) for _ in range(config.image_num_layers)
+         ]
+
+     def __call__(self, x: mx.array) -> mx.array:
+         h = []
+         for block in self.resblocks:
+             x = block(x)
+             h.append(x)
+         return h
+
+
+ def _expand_token(token, batch_size: int):
+     return mx.broadcast_to(
+         mx.reshape(token, (1, 1, -1)), (batch_size, 1, token.shape[-1])
+     )
+
+
+ def pad_to_multiple(x, target_size, pad_mode="edge", pad_value=0):
+     """
+     Pad the last dimension of input tensor to match target size.
+
+     Args:
+         x: Input tensor with shape [..., D]
+         target_size: Desired size for the last dimension
+         pad_mode: Padding mode ('constant', 'reflect', etc.)
+         pad_value: Value to use for constant padding
+
+     Returns:
+         Padded tensor with shape [..., target_size]
+     """
+     current_size = x.shape[-1]
+
+     # Return early if no padding needed
+     if current_size == target_size:
+         return x
+
+     # Ensure target size is larger
+     if current_size > target_size:
+         raise ValueError(
+             f"Current size {current_size} is larger than target size {target_size}"
+         )
+
+     # Calculate padding needed
+     pad_size = target_size - current_size
+
+     # Create padding configuration
+     # No padding for batch and channel dimensions (0,0), only pad the last dim
+     pad_config = [(0, 0)] * (len(x.shape) - 1) + [(0, pad_size)]
+
+     return mx.pad(x, pad_width=pad_config, mode=pad_mode, constant_values=pad_value)
+
+
+ class VisionTransformer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.class_embedding = mx.zeros((config.image_emb_dim,))
+         self.positional_embedding = mx.zeros(
+             (config.image_num_pos, config.image_emb_dim)
+         )
+         self.patch_embedding = nn.Linear(
+             config.intermediate_size,
+             config.image_emb_dim,
+             bias=False,
+         )
+         self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
+         self.transformer = ResidualAttentionBlocks(config)
+
+     def add_pos_emb(self, x: mx.array, patch_num: int) -> mx.array:
+         cls_emb = self.positional_embedding[0:1]
+         pos_emb = self.positional_embedding[1:]
+
+         # Reshape into 2D grid
+         pos_emb_size = int(pos_emb.shape[0] ** 0.5)
+         pos_emb = mx.reshape(pos_emb, (pos_emb_size, pos_emb_size, pos_emb.shape[1]))
+
+         (patch_num_0, patch_num_1) = patch_num
+
+         if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
+             # Reshape for upsampling (add batch and channel dims)
+             pos_emb = mx.expand_dims(pos_emb, 0)
+             pos_emb = mx.transpose(pos_emb, (0, 3, 1, 2))
+
+             # Create and apply upsampler
+             upsampler = nn.Upsample(
+                 scale_factor=(
+                     patch_num_0 / pos_emb.shape[2],
+                     patch_num_1 / pos_emb.shape[3],
+                 ),
+                 mode="linear",  # MLX doesn't have bicubic, using linear as closest alternative
+                 align_corners=False,
+             )
+             pos_emb = upsampler(pos_emb)
+
+             # Restore original dimensions
+             pos_emb = mx.transpose(pos_emb, (0, 2, 3, 1))
+             pos_emb = mx.squeeze(pos_emb, 0)
+
+         pos_emb = mx.reshape(pos_emb, (-1, pos_emb.shape[-1]))
+
+         # Expand cls_emb and pos_emb
+         expanded_cls = cls_emb[None, :, :]
+         expanded_pos = pos_emb[None, :, :]
+
+         # Concatenate and add to x
+         pos_embedding = mx.concatenate([expanded_cls, expanded_pos], axis=1)
+         x = x + pos_embedding
+         return x
+
+     def __call__(self, x: mx.array, patch_num: int = None) -> List[mx.array]:
+         """
+         : param x: (batch_size, num_patch, n_pixels)
+         """
+         if patch_num is None:
+             patch_num = self.config.image_num_patch
+         B, N, D = x.shape
+
+         # (Optional) Due to quantization, pad around the image to match intermediate_size
+         x = pad_to_multiple(x, self.config.intermediate_size)
+
+         x = self.patch_embedding(x)
+
+         # class embeddings and positional embeddings
+         expanded_class_emb = _expand_token(self.class_embedding, x.shape[0])
+         expanded_class_emb = expanded_class_emb
+
+         x = mx.concatenate([expanded_class_emb, x], axis=1)
+         x = self.add_pos_emb(x, patch_num)
+
+         x = self.pre_ln(x)
+
+         hidden_states = self.transformer(x)
+         return hidden_states
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         if self.model_type != "molmo":
+             raise ValueError(
+                 f"Model type {self.model_type} not supported. Currently only 'molmo' is supported"
+             )
+         self.image_vit = VisionTransformer(config)
+         self.num_prefix_tokens = 1
+
+         self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False)
+         self.image_projector = MLP(config, config.image_emb_dim)
+         self.pad_embed = mx.zeros((2, config.image_emb_dim * 2))
+
+     def encode_image(self, images: mx.array) -> mx.array:
+         """
+         : param images: (batch_size, num_crops, num_patch, n_pixels)
+         """
+         cfg = self.config
+         B, T, N, D = images.shape
+
+         # Check for -1 values across dimensions 1 and 2
+         reshaped_images = mx.reshape(images, (B * T, N, D))
+         mask = ~mx.all(reshaped_images == -1, axis=(1, 2), keepdims=True)
+
+         # Output all hidden states
+         images = reshaped_images
+         image_features = self.image_vit(images)
+
+         if cfg.vit_layers is not None:
+             features = []
+             for layer in cfg.vit_layers:
+                 features.append(image_features[layer])
+             image_features = mx.concatenate(features, axis=-1)
+         else:
+             image_features = image_features[-1]
+
+         cls_embed = None
+         if self.num_prefix_tokens > 0:
+             cls_embed = image_features[:, 0]
+             image_features = image_features[:, 1:]
+
+         image_features = image_features * mask
+         image_features = mx.reshape(image_features, (B, T, N, -1))
+
+         cls_embed = mx.reshape(cls_embed, (B, T, -1)) if cls_embed is not None else None
+
+         return image_features, cls_embed
+
+     def __call__(
+         self, images: mx.array, image_masks: mx.array
+     ) -> Tuple[mx.array, Optional[mx.array]]:
+         cfg = self.config
+
+         batch_size, num_image = images.shape[:2]
+         image_features, cls_embed = self.encode_image(images)
+
+         if cfg.image_padding_embed:
+             assert image_masks is not None
+             if cfg.image_padding_embed == "pad_embed":
+                 all_pad = image_masks == 0
+                 pad_embed = mx.reshape(self.pad_embed, (1, 1, 1, -1))
+                 image_features = image_features + pad_embed * mx.expand_dims(
+                     all_pad, -1
+                 )
+             elif cfg.image_padding_embed == "regress":
+                 pad_embed = mx.reshape(self.pad_embed, (1, 1, 1, -1))
+                 image_features = image_features + pad_embed * mx.expand_dims(
+                     mx.maximum(image_masks, mx.zeros_like(image_masks)), -1
+                 )
+             elif cfg.image_padding_embed == "pad_and_partial_pad":
+                 pad_embed = mx.reshape(self.pad_embed, (2, 1, 1, 1, -1))
+                 all_pad = image_masks == 0
+                 partial_pad = mx.logical_and(image_masks < 1, mx.logical_not(all_pad))
+                 partial_pad = partial_pad
+                 all_pad = all_pad
+                 image_features = image_features + pad_embed[0] * mx.expand_dims(
+                     all_pad, -1
+                 )
+                 image_features = image_features + pad_embed[1] * mx.expand_dims(
+                     partial_pad, -1
+                 )
+             else:
+                 raise ValueError(cfg.image_padding_embed)
+
+         image_features = mx.reshape(
+             image_features, (batch_size, num_image) + cfg.image_num_patch + (-1,)
+         )
+
+         if cfg.image_num_patch[0] % cfg.image_pooling_h == 1:
+             # Pad so we can still pool 2x2 patches
+             image_features = mx.pad(
+                 image_features, [(0, 0), (0, 0), (0, 1), (0, 1), (0, 0)]
+             )
+
+         # image pooling
+         # MLX equivalent of einops rearrange
+         h_blocks = image_features.shape[2] // cfg.image_pooling_h
+         w_blocks = image_features.shape[3] // cfg.image_pooling_w
+         image_features = mx.reshape(
+             mx.transpose(
+                 mx.reshape(
+                     image_features,
+                     (
+                         batch_size,
+                         num_image,
+                         h_blocks,
+                         cfg.image_pooling_h,
+                         w_blocks,
+                         cfg.image_pooling_w,
+                         -1,
+                     ),
+                 ),
+                 (0, 1, 2, 4, 3, 5, 6),
+             ),
+             (
+                 batch_size * num_image * h_blocks * w_blocks,
+                 cfg.image_pooling_h * cfg.image_pooling_w,
+                 -1,
+             ),
+         )
+
+         if cfg.image_pooling_2d == "attention-meanq":
+             query = mx.mean(image_features, axis=-2, keepdims=True)
+             image_features = self.image_pooling_2d(query, image_features)
+         elif cfg.image_pooling_2d not in {"none", "stack"}:
+             image_features = self.image_pooling_2d(
+                 image_features[:, :1, :], image_features
+             )
+
+         h, w = cfg.llm_patches_per_crop
+         image_features = mx.reshape(image_features, (batch_size, num_image, h * w, -1))
+
+         # # MLP layer to map the feature
+         image_features = self.image_projector(image_features)
+
+         return image_features, cls_embed
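
For reference, a minimal sketch (not part of the packaged file above) of the pooling rearrangement marked "MLX equivalent of einops rearrange": the toy shapes and the einops pattern in the comment are assumptions for illustration only.

import mlx.core as mx

# Toy shapes chosen only to demonstrate the reshape/transpose/reshape pattern.
batch_size, num_image, H, W, C = 1, 2, 24, 24, 64
pool_h = pool_w = 2
x = mx.zeros((batch_size, num_image, H, W, C))

h_blocks, w_blocks = H // pool_h, W // pool_w
pooled = mx.reshape(
    mx.transpose(
        mx.reshape(x, (batch_size, num_image, h_blocks, pool_h, w_blocks, pool_w, -1)),
        (0, 1, 2, 4, 3, 5, 6),
    ),
    (batch_size * num_image * h_blocks * w_blocks, pool_h * pool_w, -1),
)
# Equivalent einops pattern (assumption; einops is not used by the module above):
#   rearrange(x, "b n (hb ph) (wb pw) c -> (b n hb wb) (ph pw) c", ph=pool_h, pw=pool_w)
print(pooled.shape)  # (288, 4, 64): one row per 2x2 pooling window
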
@@ -0,0 +1,6 @@
+ from .config import AdapterConfig, ModelConfig, TextConfig, VisionConfig, VitConfig
+ from .language import LanguageModel
+ from .molmo2 import Model
+ from .processing import Molmo2ImageProcessor as ImageProcessor
+ from .processing import Molmo2Processor as Processor
+ from .vision import VisionModel
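
A hypothetical import sketch for the re-exported names above; only the aliases come from this `__init__`, and loading real weights is outside this diff.

# Hypothetical usage; only the alias names are taken from the __init__ above.
from mlx_vlm.models.molmo2 import ImageProcessor, Model, ModelConfig, Processor

config = ModelConfig()  # dataclass defaults; real checkpoints override these via from_dict
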
@@ -0,0 +1,137 @@
+ import inspect
+ from dataclasses import dataclass, field
+ from typing import List, Optional
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class VitConfig(BaseModelConfig):
+     model_type: str = "molmo2"
+     hidden_size: int = 1152
+     intermediate_size: int = 4304
+     num_hidden_layers: int = 25  # Note: HF config says 27 but weights only have 25
+     num_attention_heads: int = 16
+     num_key_value_heads: int = 16
+     head_dim: int = 72
+     image_patch_size: int = 14
+     image_num_pos: int = 729
+     image_default_input_size: List[int] = field(default_factory=lambda: [378, 378])
+     hidden_act: str = "gelu_pytorch_tanh"
+     layer_norm_eps: float = 1e-6
+     attention_dropout: float = 0.0
+     residual_dropout: float = 0.0
+     float32_attention: bool = True
+     attn_implementation: str = "sdpa"
+
+     @classmethod
+     def from_dict(cls, params):
+         # Workaround: HuggingFace config says 27 layers but weights only have 25
+         # Override to use 25 layers
+         if params.get("num_hidden_layers", 25) > 25:
+             params = dict(params)  # Don't modify original
+             params["num_hidden_layers"] = 25
+         return super().from_dict(params)
+
+     @property
+     def image_num_patch(self):
+         h, w = self.image_default_input_size
+         return h // self.image_patch_size, w // self.image_patch_size
+
+
+ @dataclass
+ class AdapterConfig(BaseModelConfig):
+     model_type: str = "molmo2"
+     hidden_size: int = 1152
+     intermediate_size: int = 9728
+     text_hidden_size: int = 2560
+     num_attention_heads: int = 16
+     num_key_value_heads: int = 16
+     head_dim: int = 72
+     hidden_act: str = "silu"
+     vit_layers: List[int] = field(default_factory=lambda: [-3, -9])
+     image_feature_dropout: float = 0.0
+     pooling_attention_mask: bool = True
+     attention_dropout: float = 0.0
+     residual_dropout: float = 0.0
+     float32_attention: bool = True
+     attn_implementation: str = "sdpa"
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     vit_config: VitConfig = field(default_factory=VitConfig)
+     adapter_config: AdapterConfig = field(default_factory=AdapterConfig)
+
+     @classmethod
+     def from_dict(cls, params):
+         vit_cfg = params.get("vit_config", {})
+         adapter_cfg = params.get("adapter_config", {})
+         return cls(
+             vit_config=VitConfig.from_dict(vit_cfg),
+             adapter_config=AdapterConfig.from_dict(adapter_cfg),
+         )
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str = "molmo2"
+     hidden_size: int = 2560
+     intermediate_size: int = 9728
+     num_hidden_layers: int = 36
+     num_attention_heads: int = 32
+     num_key_value_heads: int = 8
+     head_dim: int = 128
+     vocab_size: int = 151936
+     additional_vocab_size: int = 128
+     hidden_act: str = "silu"
+     layer_norm_eps: float = 1e-6
+     attention_dropout: float = 0.0
+     residual_dropout: float = 0.0
+     embedding_dropout: float = 0.0
+     max_position_embeddings: int = 36864
+     rope_theta: float = 5000000.0
+     rope_scaling: Optional[dict] = None
+     use_qk_norm: bool = True
+     qk_norm_type: str = "qwen3"
+     qkv_bias: bool = False
+     use_cache: bool = True
+     norm_after: bool = False
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig = field(default_factory=TextConfig)
+     vision_config: VisionConfig = field(default_factory=VisionConfig)
+     model_type: str = "molmo2"
+
+     image_start_token_id: int = 151936
+     low_res_image_start_token_id: int = 151940
+     image_end_token_id: int = 151937
+     image_low_res_id: int = 151942
+     image_patch_id: int = 151938
+     image_col_id: int = 151939
+     frame_start_token_id: int = 151943
+     frame_end_token_id: int = 151944
+     use_frame_special_tokens: bool = False
+
+     tie_word_embeddings: bool = False
+     initializer_range: float = 0.02
+     eos_token_id: Optional[List[int]] = None
+
+     @classmethod
+     def from_dict(cls, params):
+         # Normalize how the repo loads configs: always provide `vision_config`.
+         if not params.get("vision_config"):
+             params["vision_config"] = {
+                 "vit_config": params.get("vit_config", {}),
+                 "adapter_config": params.get("adapter_config", {}),
+             }
+
+         return cls(
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+             }
+         )
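
A minimal sketch (not part of the package) of the two behaviors in the config above: the 27-to-25 layer clamp in `VitConfig.from_dict` and the flat-to-nested normalization in `ModelConfig.from_dict`. It assumes `BaseModelConfig.from_dict` builds the dataclass from the filtered dict, the same pattern `ModelConfig.from_dict` spells out explicitly; the example dicts are hypothetical, not a real checkpoint config.

# Hypothetical values; not taken from a real checkpoint config.
vit = VitConfig.from_dict({"num_hidden_layers": 27, "image_patch_size": 14})
print(vit.num_hidden_layers)  # 25: clamped to match the shipped weights
print(vit.image_num_patch)    # (27, 27): 378 // 14 patches per side

# A flat HF-style dict gains a nested vision_config before ModelConfig is built.
flat = {"model_type": "molmo2", "vit_config": {}, "adapter_config": {}}
cfg = ModelConfig.from_dict(flat)
print("vision_config" in flat)  # True: from_dict injected the nested section
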