fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0

mlx_vlm/models/mllama/vision.py
@@ -0,0 +1,458 @@
+from typing import List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class MllamaVisionAttention(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.embed_dim, bias=False
+        )
+
+    def __call__(
+        self,
+        hidden_state: mx.array,
+        attention_mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        query = self.q_proj(hidden_state)
+        key = self.k_proj(hidden_state)
+        value = self.v_proj(hidden_state)
+
+        batch_size, q_seq_len, _ = query.shape
+        _, kv_seq_len, _ = key.shape
+
+        query = query.reshape(
+            batch_size, q_seq_len, self.num_heads, self.head_dim
+        ).transpose(0, 2, 1, 3)
+        key = key.reshape(
+            batch_size, kv_seq_len, self.num_heads, self.head_dim
+        ).transpose(0, 2, 1, 3)
+        value = value.reshape(
+            batch_size, kv_seq_len, self.num_heads, self.head_dim
+        ).transpose(0, 2, 1, 3)
+
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, :, : key.shape[-2], :]
+
+        attn_output = mx.fast.scaled_dot_product_attention(
+            query, key, value, scale=self.scale, mask=attention_mask
+        )
+
+        attn_output = attn_output.transpose(0, 2, 1, 3)
+        attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+
+        return self.o_proj(attn_output)
+
+
+class MllamaVisionMLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+        self.gelu = nn.GELU()
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.gelu(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class MllamaVisionEncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig, is_gated: bool = False):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.is_gated = is_gated
+
+        self.self_attn = MllamaVisionAttention(config)
+        self.mlp = MllamaVisionMLP(config)
+
+        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(
+            self.hidden_size, eps=config.norm_eps
+        )
+
+        if is_gated:
+            self.gate_attn = mx.zeros(1)
+            self.gate_ffn = mx.zeros(1)
+
+    def __call__(
+        self,
+        hidden_state: mx.array,
+        attention_mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        # Self Attention
+        residual = hidden_state
+        hidden_state = self.input_layernorm(hidden_state)
+        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
+        if self.is_gated:
+            hidden_state = mx.tanh(self.gate_attn) * hidden_state
+        hidden_state = residual + hidden_state
+
+        # Feed forward
+        residual = hidden_state
+        hidden_state = self.post_attention_layernorm(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        if self.is_gated:
+            hidden_state = mx.tanh(self.gate_ffn) * hidden_state
+        hidden_state = residual + hidden_state
+
+        return hidden_state
+
+
+class MllamaVisionEncoder(nn.Module):
+    def __init__(self, config: VisionConfig, num_layers=32, is_gated=False):
+        super().__init__()
+        self.layers = [
+            MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)
+        ]
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        attention_mask: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, List[mx.array]]:
+        encoder_states = ()
+        for layer in self.layers:
+            hidden_states = layer(hidden_states, attention_mask=attention_mask)
+            encoder_states = encoder_states + (hidden_states,)
+        return hidden_states, encoder_states
+
+
+class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
+    def __init__(self, config: VisionConfig, is_gated: bool = True):
+        super().__init__()
+        self.max_num_tiles = config.max_num_tiles
+        self.hidden_size = config.hidden_size
+        self.max_aspect_ratio_id = config.max_aspect_ratio_id
+        self.is_gated = is_gated
+
+        self.embedding = nn.Embedding(
+            self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size
+        )
+        if is_gated:
+            self.gate = mx.zeros(1)
+
+    def __call__(self, hidden_state: mx.array, aspect_ratio_ids: mx.array) -> mx.array:
+        embeddings = self.embedding(aspect_ratio_ids)
+        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
+
+        if self.is_gated:
+            embeddings = embeddings * mx.tanh(self.gate)
+
+        hidden_state = hidden_state + embeddings
+        return hidden_state
+
+
+class MllamaPrecomputedPositionEmbedding(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.max_num_tiles = config.max_num_tiles
+        self.max_aspect_ratio_id = config.max_aspect_ratio_id
+        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
+        self.hidden_size = config.hidden_size
+        self.scale = config.hidden_size**-0.5
+
+        self.gate = mx.zeros(1)
+
+        # position embedding
+        self.embedding = (
+            mx.random.normal((self.num_patches, self.hidden_size)) * self.scale
+        )
+
+        # tile position embedding
+        self.tile_embedding = nn.Embedding(
+            self.max_aspect_ratio_id + 1,
+            self.max_num_tiles * self.num_patches * self.hidden_size,
+        )
+
+    def __call__(self, hidden_state: mx.array, aspect_ratio_ids: mx.array) -> mx.array:
+        # position embeddings
+        gated_position_embedding = (1 - mx.tanh(self.gate)) * self.embedding
+        hidden_state = hidden_state + gated_position_embedding.reshape(
+            1, 1, self.num_patches, self.hidden_size
+        )
+
+        # precomputed tile position embeddings
+        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
+        batch_size = hidden_state.shape[0]
+        tile_position_embedding = tile_position_embedding.reshape(
+            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
+        )
+        gated_tile_position_embedding = mx.tanh(self.gate) * tile_position_embedding
+        hidden_state = hidden_state + gated_tile_position_embedding
+
+        return hidden_state
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        self.max_num_tiles = config.max_num_tiles
+        self.hidden_size = config.hidden_size
+        self.num_channels = config.num_channels
+        self.intermediate_layers_indices = config.intermediate_layers_indices
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+        self.scale = config.hidden_size**-0.5
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.hidden_size,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.class_embedding = mx.random.normal((self.hidden_size,)) * self.scale
+        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)
+
+        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
+            config, is_gated=True
+        )
+        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
+            config, is_gated=True
+        )
+
+        # layer norms
+        self.layernorm_pre = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
+        self.layernorm_post = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
+
+        # encoders
+        self.transformer = MllamaVisionEncoder(
+            config, config.num_hidden_layers, is_gated=False
+        )
+        self.global_transformer = MllamaVisionEncoder(
+            config, config.num_global_layers, is_gated=True
+        )
+
+    def __call__(
+        self,
+        pixel_values: mx.array,
+        aspect_ratio_ids: mx.array,
+        aspect_ratio_mask: mx.array,
+    ) -> mx.array:
+        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
+            pixel_values.shape
+        )
+        aspect_ratio_ids = aspect_ratio_ids.reshape(
+            batch_size * num_concurrent_media, -1
+        )
+
+        pixel_values = pixel_values.reshape(
+            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
+        )
+        # Patch embedding
+        patch_embeds = self.patch_embedding(pixel_values.moveaxis(1, 3)).moveaxis(3, 1)
+
+        hidden_state = patch_embeds.reshape(
+            patch_embeds.shape[0], patch_embeds.shape[1], -1
+        ).transpose(0, 2, 1)
+
+        # Tile embeddings
+        _, num_patches, dim = hidden_state.shape
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media, num_tiles, -1, dim
+        )
+        hidden_state = self.pre_tile_positional_embedding(
+            hidden_state, aspect_ratio_ids
+        )
+
+        # Add cls token
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media * num_tiles, num_patches, dim
+        )
+        class_embedding = mx.broadcast_to(
+            self.class_embedding,
+            (batch_size * num_concurrent_media * num_tiles, 1, dim),
+        )
+        hidden_state = mx.concatenate([class_embedding, hidden_state], axis=1)
+        num_patches += 1
+
+        # Position embeddings
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media, num_tiles, num_patches, dim
+        )
+        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
+
+        hidden_state = self.layernorm_pre(hidden_state)
+
+        # Compute the number of tokens to pad
+        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
+
+        # Pad the tensor
+        padding = [(0, 0), (0, 0), (0, num_padding_patches), (0, 0)]
+        hidden_state = mx.pad(hidden_state, padding)
+        slice_index = -num_padding_patches if num_padding_patches > 0 else None
+
+        # Prepare attention mask
+        attention_mask = aspect_ratio_mask.reshape(
+            batch_size * num_concurrent_media, -1
+        )
+        attention_mask = _prepare_aspect_ratio_attention_mask(
+            aspect_ratio_mask=attention_mask,
+            num_patches=self.num_patches,
+            target_length=hidden_state.shape[2],
+        )
+
+        # Apply encoder
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media, -1, self.hidden_size
+        )
+        output = self.transformer(hidden_state, attention_mask=attention_mask)
+
+        hidden_state = output[0]
+
+        hidden_state = self.layernorm_post(hidden_state)
+
+        # Apply global encoder
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles,
+            num_patches + num_padding_patches,
+            self.hidden_size,
+        )
+        hidden_state = self.post_tile_positional_embedding(
+            hidden_state, aspect_ratio_ids
+        )
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles * (num_patches + num_padding_patches),
+            self.hidden_size,
+        )
+        global_output = self.global_transformer(
+            hidden_state, attention_mask=attention_mask
+        )
+
+        hidden_state = global_output[0]
+
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles,
+            num_patches + num_padding_patches,
+            dim,
+        )
+
+        hidden_state = hidden_state[:, :, :slice_index]
+        hidden_state = hidden_state.reshape(
+            batch_size, num_concurrent_media, num_tiles, num_patches, dim
+        )
+
+        # Collect intermediate layer outputs from encoder output
+        all_intermediate_hidden_states = output[1]
+        intermediate_hidden_states = mx.stack(all_intermediate_hidden_states, axis=-1)
+        intermediate_hidden_states = intermediate_hidden_states[
+            ..., self.intermediate_layers_indices
+        ]
+
+        # Remove padding from intermediate hidden states
+        intermediate_hidden_states = intermediate_hidden_states.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles,
+            num_patches + num_padding_patches,
+            -1,
+        )
+        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
+        intermediate_hidden_states = intermediate_hidden_states.reshape(
+            batch_size, num_concurrent_media, num_tiles, num_patches, -1
+        )
+
+        # Concatenate final hidden state and intermediate hidden states
+        hidden_state = mx.concatenate(
+            [hidden_state, intermediate_hidden_states], axis=-1
+        )
+
+        return hidden_state
+
+    @staticmethod
+    def sanitize(weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embedding.weight" in k:
+                # PyTorch conv2d weight tensors have shape:
+                # [out_channels, in_channels, kH, KW]
+                # MLX conv2d expects the weight to be of shape:
+                # [out_channels, kH, KW, in_channels]
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
+
+
+def _prepare_aspect_ratio_attention_mask(
+    aspect_ratio_mask: mx.array,
+    num_patches: int,
+    target_length: int,
+) -> mx.array:
+    dtype = mx.float32
+    aspect_ratio_mask = aspect_ratio_mask.astype(dtype)
+
+    # Expand aspect ratio mask to target_length
+    batch_size, max_num_tiles = aspect_ratio_mask.shape
+    attention_mask = aspect_ratio_mask.reshape(batch_size, max_num_tiles, 1, 1).astype(
+        dtype
+    )
+    attention_mask = mx.tile(attention_mask, (1, 1, target_length, 1))
+
+    # Mask padding patches
+    pad_patches = target_length - num_patches
+    attention_mask[:, :, -pad_patches:] = 0
+
+    # Invert the mask (0 -> 1, 1 -> 0)
+    attention_mask = 1 - attention_mask
+
+    # Reshape to 2D and create 4D attention mask
+    # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
+    attention_mask = attention_mask.reshape(
+        batch_size, max_num_tiles * target_length, 1
+    )
+
+    min_value = -1e9
+    attention_mask = attention_mask @ attention_mask.transpose(0, 2, 1) * min_value
+    attention_mask = attention_mask[:, None, :, :]
+
+    return attention_mask
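
The `sanitize` hook above exists because PyTorch checkpoints store Conv2d weights as `[out_channels, in_channels, kH, kW]` while MLX expects `[out_channels, kH, kW, in_channels]`. A minimal sketch of that conversion, assuming the wheel installs importable as `mlx_vlm` (as its `top_level.txt` indicates) and using a made-up weight shape rather than a real checkpoint:

```python
import mlx.core as mx

from mlx_vlm.models.mllama.vision import VisionModel, check_array_shape

# Illustrative patch-embedding weight in PyTorch layout [out_channels, in_channels, kH, kW];
# the sizes here are invented for the example, not taken from an actual checkpoint.
pt_weight = mx.zeros((1280, 3, 14, 14))

print(check_array_shape(pt_weight))  # False: dims 1 and 2 differ, so this is not MLX layout

# sanitize() is a @staticmethod, so no VisionConfig instance is needed here.
weights = {"vision_model.patch_embedding.weight": pt_weight}
sanitized = VisionModel.sanitize(weights)
print(sanitized["vision_model.patch_embedding.weight"].shape)  # (1280, 14, 14, 3)
```

Keys containing `position_ids` are dropped outright, and every other entry passes through unchanged.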

mlx_vlm/models/molmo/__init__.py
@@ -0,0 +1,5 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .language import LanguageModel
+from .molmo import Model
+from .processing_molmo import MolmoImageProcessor, MolmoProcessor
+from .vision import VisionModel

mlx_vlm/models/molmo/config.py
@@ -0,0 +1,93 @@
+from dataclasses import dataclass, field
+from typing import List, Optional, Tuple
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: "TextConfig" = field(default_factory=lambda: TextConfig())
+    vision_config: "VisionConfig" = field(default_factory=lambda: VisionConfig())
+    model_type: str = "molmo"
+    image_feature_dropout: float = 0.0
+    image_pooling_h: int = 2
+    image_pooling_w: int = 2
+    image_pooling_2d: str = "attention"
+    image_projector: str = "mlp"
+    eos_token_id: Optional[List[int]] = None
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str = "molmo"
+    max_position_embeddings: int = 4096
+    d_model: int = 3584
+    n_heads: int = 28
+    n_kv_heads: int = 4
+    n_layers: int = 28
+    mlp_ratio: int = 4
+    max_sequence_length: int = 1024
+    act_output_multiplier: int = 0.5
+    mlp_hidden_size: int = 37888
+    vocab_size: int = 152064
+    embedding_size: Optional[int] = 152064
+    additional_vocab_size: Optional[int] = None
+    attention_dropout: float = 0.1
+    residual_dropout: float = 0.1
+    embedding_dropout: float = 0.1
+    layer_norm_eps: float = 1e-5
+    initializer_range: float = 0.02
+    pad_token_id: int = -1
+    rope: bool = True
+    rope_theta: float = 1000000.0
+    weight_tying: bool = False
+    rope_full_precision: bool = True
+    rope_impl: str = "interleave"
+    additional_vocab_size: Optional[int] = 128
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str = "molmo"
+    num_channels: int = 3
+    image_default_input_size: Tuple[int, int] = (336, 336)
+    image_patch_size: int = 14
+    image_pos_patch_size: int = 14
+    hidden_size: int = 18944
+    image_emb_dim: int = 1024
+    image_num_heads: int = 16
+    image_num_key_value_heads: int = 16
+    image_num_layers: int = 23
+    image_head_dim: int = 64
+    image_mlp_dim: int = 4096
+    image_mlp_activations: str = "gelu"
+    image_dropout_rate: float = 0.0
+    image_num_pos: int = 577
+    image_norm_eps: float = 1e-5
+    attention_dropout: float = 0.0
+    residual_dropout: float = 0.0
+    initializer_range: float = 0.02
+    d_model: int = 3584
+    image_pooling_h: int = 2
+    image_pooling_w: int = 2
+    vit_layers: Optional[List[int]] = field(default_factory=lambda: [-2, -9])
+    image_pooling_2d: str = "attention-meanq"
+    image_padding_embed: str = "pad_and_partial_pad"
+    intermediate_size: Optional[int] = None
+
+    def __post_init__(self):
+        if self.intermediate_size is None:
+            self.intermediate_size = self.image_patch_size * self.image_patch_size * 3
+
+    @property
+    def image_num_patch(self):
+        h, w = self.image_default_input_size
+        return h // self.image_patch_size, w // self.image_patch_size
+
+    @property
+    def llm_patches_per_crop(self):
+        h, w = self.image_num_patch
+        # Round up in case we need to pad the image features for pooling
+        h = (h + self.image_pooling_h - 1) // self.image_pooling_h
+        w = (w + self.image_pooling_w - 1) // self.image_pooling_w
+        return h, w
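
The derived properties at the end of this config determine how many vision tokens each image crop contributes. As a quick sanity check with the defaults above (336×336 input, 14-pixel patches, 2×2 pooling), the arithmetic below simply mirrors `image_num_patch`, `llm_patches_per_crop`, and the `__post_init__` default rather than importing the package:

```python
# Defaults from VisionConfig above.
image_size, patch_size, pool_h, pool_w = 336, 14, 2, 2

patches_per_side = image_size // patch_size            # 24, so image_num_patch == (24, 24)
pooled_h = (patches_per_side + pool_h - 1) // pool_h   # 12 (ceiling division, as in llm_patches_per_crop)
pooled_w = (patches_per_side + pool_w - 1) // pool_w   # 12

print((patches_per_side, patches_per_side))  # (24, 24) raw ViT patches per crop
print((pooled_h, pooled_w))                  # (12, 12) -> 144 pooled positions per crop
print(patch_size * patch_size * 3)           # 588, the intermediate_size default from __post_init__
```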