fount-vlm-nell-02 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/molmo2/language.py
@@ -0,0 +1,206 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from ..cache import KVCache
+from .config import ModelConfig, TextConfig
+
+
+class Molmo2Embedding(nn.Module):
+    def __init__(
+        self,
+        num_embeddings: int,
+        num_new_embeddings: int,
+        features: int,
+    ):
+        super().__init__()
+        self.embedding = mx.zeros((num_embeddings, features))
+        self.new_embedding = mx.zeros((num_new_embeddings, features))
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return mx.concatenate([self.embedding, self.new_embedding], axis=0)[x]
+
+
+class LanguageModelMLP(nn.Module):
+    def __init__(self, input_dim: int, intermediate_size: int):
+        super().__init__()
+        self.ff_proj = nn.Linear(input_dim, intermediate_size * 2, bias=False)
+        self.ff_out = nn.Linear(intermediate_size, input_dim, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.ff_proj(x)
+        x, gate = mx.split(x, 2, axis=-1)
+        x = nn.silu(gate) * x
+        return self.ff_out(x)
+
+
+class Molmo2Attention(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.num_heads = config.num_attention_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.head_dim = config.head_dim
+        self.scale = self.head_dim**-0.5
+
+        self.fused_dims = (
+            config.num_attention_heads * config.head_dim,
+            config.head_dim * config.num_key_value_heads,
+            config.head_dim * config.num_key_value_heads,
+        )
+
+        self.att_proj = nn.Linear(
+            config.hidden_size,
+            sum(self.fused_dims),
+            bias=config.qkv_bias,
+        )
+
+        self.q_norm = nn.RMSNorm(dims=config.head_dim, eps=config.layer_norm_eps)
+        self.k_norm = nn.RMSNorm(dims=config.head_dim, eps=config.layer_norm_eps)
+
+        self.attn_out = nn.Linear(
+            config.head_dim * config.num_attention_heads,
+            config.hidden_size,
+            bias=False,
+        )
+
+        self.rotary_emb = nn.RoPE(self.head_dim, base=config.rope_theta)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        batch_size, seq_len, _ = hidden_states.shape
+
+        qkv = self.att_proj(hidden_states)
+        q, k, v = mx.split(
+            qkv,
+            [self.fused_dims[0], self.fused_dims[0] + self.fused_dims[1]],
+            axis=-1,
+        )
+
+        q = self.q_norm(q.reshape(batch_size, seq_len, self.num_heads, self.head_dim))
+        k = self.k_norm(
+            k.reshape(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+        )
+        v = v.reshape(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
+
+        q = q.transpose(0, 2, 1, 3)
+        k = k.transpose(0, 2, 1, 3)
+        v = v.transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            q = self.rotary_emb(q, offset=cache.offset)
+            k = self.rotary_emb(k, offset=cache.offset)
+            k, v = cache.update_and_fetch(k, v)
+        else:
+            q = self.rotary_emb(q)
+            k = self.rotary_emb(k)
+
+        att = scaled_dot_product_attention(q, k, v, cache, scale=self.scale, mask=mask)
+        att = att.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
+        return self.attn_out(att)
+
+
+class Molmo2DecoderLayer(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.self_attn = Molmo2Attention(config)
+        self.attn_norm = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.ff_norm = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.mlp = LanguageModelMLP(config.hidden_size, config.intermediate_size)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        residual = hidden_states
+        hidden_states = self.attn_norm(hidden_states)
+        hidden_states = residual + self.self_attn(hidden_states, mask, cache)
+
+        residual = hidden_states
+        hidden_states = self.ff_norm(hidden_states)
+        hidden_states = residual + self.mlp(hidden_states)
+        return hidden_states
+
+
+class Molmo2Transformer(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+
+        self.wte = Molmo2Embedding(
+            config.vocab_size, config.additional_vocab_size, config.hidden_size
+        )
+        self.blocks = [
+            Molmo2DecoderLayer(config) for _ in range(config.num_hidden_layers)
+        ]
+        self.ln_f = nn.RMSNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+        self.emb_drop = nn.Dropout(config.embedding_dropout)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache: Optional[list[KVCache]] = None,
+    ) -> mx.array:
+        if inputs_embeds is None:
+            hidden_states = self.wte(input_ids)
+        else:
+            hidden_states = inputs_embeds
+
+        if cache is None:
+            cache = [None] * len(self.blocks)
+
+        if mask is None:
+            mask = create_attention_mask(hidden_states, cache)
+
+        hidden_states = self.emb_drop(hidden_states)
+
+        for block, c in zip(self.blocks, cache):
+            hidden_states = block(hidden_states, mask, c)
+
+        return self.ln_f(hidden_states)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, args: TextConfig, config: ModelConfig = None):
+        super().__init__()
+        self.args = args
+        self.config = config
+        self.model_type = args.model_type
+        self.model = Molmo2Transformer(args)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache: Optional[list[KVCache]] = None,
+        **kwargs,
+    ) -> LanguageModelOutput:
+        hidden_states = self.model(inputs, inputs_embeds, mask, cache)
+        logits = self.lm_head(hidden_states)
+        return LanguageModelOutput(logits=logits)
+
+    @staticmethod
+    def sanitize(weights):
+        # Remove unused precomputed rotary freqs if present.
+        return {k: v for k, v in weights.items() if "rotary_emb.inv_freq" not in k}
+
+    @property
+    def layers(self):
+        return self.model.blocks
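
The attention block in language.py above projects queries, keys, and values with a single att_proj matrix and recovers the three tensors with one mx.split call at the cumulative widths of the query block and the query-plus-key block. A minimal standalone sketch of that split, using made-up sizes rather than anything from the package's TextConfig:

    import mlx.core as mx

    # Example sizes only; the real values come from the model config.
    num_heads, num_kv_heads, head_dim = 8, 2, 64
    q_dim = num_heads * head_dim                # width of the query block (512)
    kv_dim = num_kv_heads * head_dim            # width of each key/value block (128)

    # One fused projection output of shape (batch, seq, q_dim + 2 * kv_dim),
    # cut at the cumulative offsets, as in Molmo2Attention.__call__.
    qkv = mx.zeros((1, 5, q_dim + 2 * kv_dim))
    q, k, v = mx.split(qkv, [q_dim, q_dim + kv_dim], axis=-1)
    print(q.shape, k.shape, v.shape)            # (1, 5, 512) (1, 5, 128) (1, 5, 128)
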
mlx_vlm/models/molmo2/molmo2.py
@@ -0,0 +1,330 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import InputEmbeddingsFeatures, LanguageModelOutput
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import VisionModel
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.language_model = LanguageModel(config.text_config, config)
+        self.vision_tower = VisionModel(config.vision_config)
+
+    @property
+    def layers(self):
+        return self.language_model.layers
+
+    def build_batched_images(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        image_token_pooling: mx.array,
+        image_grids: mx.array,
+        image_num_crops: mx.array,
+    ) -> tuple[mx.array, mx.array]:
+        raw_counts = (input_ids == self.config.image_end_token_id).sum(axis=1)
+        counts = raw_counts // 2
+        batch_size = counts.shape[0]
+
+        num_images = int(counts.sum().item())
+
+        if image_grids.shape[0] != num_images:
+            raise ValueError(
+                f"Expected {num_images} image grids, got {image_grids.shape[0]}"
+            )
+        if image_num_crops.shape[0] != num_images:
+            raise ValueError(
+                f"Expected {num_images} image crop counts, got {image_num_crops.shape[0]}"
+            )
+
+        num_pooled_patches_per_image = (
+            (image_grids[:, :2].prod(axis=1) + image_grids[:, 2:].prod(axis=1))
+            .astype(image_num_crops.dtype)
+            .reshape(-1)
+        )
+
+        n_crops, n_patches, pixels_per_patch = pixel_values.shape
+
+        example_ids_for_image = mx.array(
+            np.repeat(
+                np.arange(batch_size), np.array(counts).astype(np.int32).tolist()
+            ),
+            dtype=mx.int32,
+        )
+
+        crops_per_example = mx.zeros((batch_size,), dtype=image_num_crops.dtype)
+        pooled_per_example = mx.zeros(
+            (batch_size,), dtype=num_pooled_patches_per_image.dtype
+        )
+        for image_idx in range(num_images):
+            ex = int(example_ids_for_image[image_idx].item())
+            crops_per_example[ex] = crops_per_example[ex] + image_num_crops[image_idx]
+            pooled_per_example[ex] = (
+                pooled_per_example[ex] + num_pooled_patches_per_image[image_idx]
+            )
+
+        total_crops = int(crops_per_example.sum().item())
+        if total_crops != n_crops:
+            raise ValueError(f"Expected {total_crops} crops, got {n_crops}")
+
+        total_pooled = int(pooled_per_example.sum().item())
+        if total_pooled != image_token_pooling.shape[0]:
+            raise ValueError(
+                f"Expected {total_pooled} pooled patches, got {image_token_pooling.shape[0]}"
+            )
+
+        max_crops = int(crops_per_example.max().item())
+        images = mx.full(
+            (batch_size, max_crops, n_patches, pixels_per_patch),
+            vals=-1,
+            dtype=pixel_values.dtype,
+        )
+
+        offset_crop = 0
+        for i in range(batch_size):
+            num = int(crops_per_example[i].item())
+            images[i, :num] = pixel_values[offset_crop : offset_crop + num]
+            offset_crop += num
+
+        max_pooled = int(pooled_per_example.max().item())
+        token_dim = image_token_pooling.shape[1]
+        new_token_pooling = mx.full(
+            (batch_size, max_pooled, token_dim),
+            vals=-1,
+            dtype=image_token_pooling.dtype,
+        )
+
+        patches_per_image = image_num_crops * n_patches
+        counts_list = counts.tolist()
+        image_idx = 0
+        pooled_offset = 0
+        patch_offset = 0
+        for ex, c in enumerate(counts_list):
+            num_pooled = int(pooled_per_example[ex].item())
+            cur = mx.array(
+                image_token_pooling[pooled_offset : pooled_offset + num_pooled]
+            )
+
+            per_img_patches = patches_per_image[image_idx : image_idx + c]
+            index_offsets = [0] + np.cumsum(per_img_patches.tolist()).tolist()[:-1]
+            per_img_pooled = num_pooled_patches_per_image[
+                image_idx : image_idx + c
+            ].tolist()
+
+            offset = 0
+            for j in range(c):
+                n = int(per_img_pooled[j])
+                idx_off = int(index_offsets[j])
+                cur_slice = cur[offset : offset + n]
+                cur[offset : offset + n] = mx.where(
+                    cur_slice >= 0,
+                    cur_slice + idx_off,
+                    cur_slice,
+                )
+                offset += n
+
+            new_token_pooling[ex, :num_pooled] = cur
+            pooled_offset += num_pooled
+            image_idx += c
+            patch_offset += num_pooled
+
+        return images, new_token_pooling
+
+    def build_batched_videos(
+        self,
+        input_ids: mx.array,
+        pixel_values_videos: mx.array,
+        video_token_pooling: mx.array,
+        video_grids: mx.array,
+    ) -> tuple[mx.array, mx.array]:
+        end_token_id = (
+            self.config.frame_end_token_id
+            if self.config.use_frame_special_tokens
+            else self.config.image_end_token_id
+        )
+        counts = mx.any(input_ids == end_token_id, axis=1).astype(mx.int32)
+        batch_size = counts.shape[0]
+        num_videos = int(counts.sum().item())
+
+        if video_grids.shape[0] != num_videos:
+            raise ValueError(
+                f"Expected {num_videos} videos, got {video_grids.shape[0]}"
+            )
+
+        num_pooled_patches_per_video = (video_grids[:, 1] * video_grids[:, 2]).astype(
+            video_token_pooling.dtype
+        )
+
+        n_frames, n_patches, pixels_per_patch = pixel_values_videos.shape
+
+        frames_per_example = mx.zeros((batch_size,), dtype=mx.int32)
+        pooled_per_example = mx.zeros((batch_size,), dtype=video_token_pooling.dtype)
+
+        video_index = 0
+        for i in range(batch_size):
+            if counts[i].item() == 1:
+                frames_per_example[i] = int(video_grids[video_index][0].item())
+                pooled_per_example[i] = num_pooled_patches_per_video[video_index]
+                video_index += 1
+
+        max_frames = int(frames_per_example.max().item()) if num_videos else 0
+        videos = mx.full(
+            (batch_size, max_frames, n_patches, pixels_per_patch),
+            vals=-1,
+            dtype=pixel_values_videos.dtype,
+        )
+
+        offset = 0
+        for i in range(batch_size):
+            num = int(frames_per_example[i].item())
+            if num > 0:
+                videos[i, :num] = pixel_values_videos[offset : offset + num]
+                offset += num
+
+        max_pooled = int(pooled_per_example.max().item()) if num_videos else 0
+        token_dim = video_token_pooling.shape[1]
+        new_token_pooling = mx.full(
+            (batch_size, max_pooled, token_dim),
+            vals=-1,
+            dtype=video_token_pooling.dtype,
+        )
+
+        pooled_offset = 0
+        for i in range(batch_size):
+            num = int(pooled_per_example[i].item())
+            if num > 0:
+                new_token_pooling[i, :num] = video_token_pooling[
+                    pooled_offset : pooled_offset + num
+                ]
+                pooled_offset += num
+
+        if offset != n_frames:
+            raise ValueError(f"Expected {n_frames} frames, got {offset}")
+        if pooled_offset != video_token_pooling.shape[0]:
+            raise ValueError(
+                f"Expected {video_token_pooling.shape[0]} pooled tokens, got {pooled_offset}"
+            )
+
+        return videos, new_token_pooling
+
+    def merge_visual_inputs(
+        self,
+        *,
+        input_ids: mx.array,
+        pixel_values: Optional[mx.array] = None,
+        image_token_pooling: Optional[mx.array] = None,
+        image_grids: Optional[mx.array] = None,
+        image_num_crops: Optional[mx.array] = None,
+        video_token_pooling: Optional[mx.array] = None,
+        video_grids: Optional[mx.array] = None,
+    ) -> tuple[Optional[mx.array], Optional[mx.array]]:
+        if pixel_values is None:
+            return None, None
+
+        if video_token_pooling is not None or video_grids is not None:
+            if video_token_pooling is None or video_grids is None:
+                raise ValueError(
+                    "video_token_pooling and video_grids are required for videos"
+                )
+            return self.build_batched_videos(
+                input_ids=input_ids,
+                pixel_values_videos=pixel_values,
+                video_token_pooling=video_token_pooling,
+                video_grids=video_grids,
+            )
+
+        if (
+            image_token_pooling is None
+            or image_grids is None
+            or image_num_crops is None
+        ):
+            raise ValueError(
+                "image_token_pooling, image_grids, and image_num_crops are required for images"
+            )
+
+        return self.build_batched_images(
+            input_ids=input_ids,
+            pixel_values=pixel_values,
+            image_token_pooling=image_token_pooling,
+            image_grids=image_grids,
+            image_num_crops=image_num_crops,
+        )
+
+    def get_input_embeddings(
+        self,
+        input_ids: mx.array,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ) -> mx.array:
+        input_ids = input_ids * (input_ids != -1).astype(input_ids.dtype)
+        x = self.language_model.model.wte(input_ids)
+
+        if pixel_values is not None:
+
+            pixel_values, token_pooling = self.merge_visual_inputs(
+                input_ids=input_ids,
+                pixel_values=pixel_values,
+                image_token_pooling=kwargs.get("image_token_pooling", None),
+                image_grids=kwargs.get("image_grids", None),
+                image_num_crops=kwargs.get("image_num_crops", None),
+                video_token_pooling=kwargs.get("video_token_pooling", None),
+                video_grids=kwargs.get("video_grids", None),
+            )
+
+            dtype = self.vision_tower.image_vit.patch_embedding.weight.dtype
+            pixel_values = pixel_values.astype(dtype)
+            image_features = self.vision_tower(pixel_values, token_pooling)
+            is_image_patch = mx.reshape(input_ids, (-1,)) == self.config.image_patch_id
+            if int(is_image_patch.sum().item()) != image_features.shape[0]:
+                raise ValueError(
+                    f"Expected {int(is_image_patch.sum().item())} image features, got {image_features.shape[0]}"
+                )
+            flat_x = mx.reshape(x, (-1, x.shape[-1]))
+            positions = mx.array(np.where(np.array(is_image_patch))[0], dtype=mx.uint32)
+            flat_x[positions] = flat_x[positions] + image_features
+            x = flat_x.reshape(x.shape)
+
+        return InputEmbeddingsFeatures(inputs_embeds=x)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        **kwargs,
+    ) -> LanguageModelOutput:
+        if input_ids.ndim == 1:
+            input_ids = input_ids[None, :]
+
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids=input_ids, pixel_values=pixel_values, **kwargs
+        )
+
+        return self.language_model(
+            input_ids,
+            inputs_embeds=input_embeddings_features.inputs_embeds,
+            mask=mask,
+            cache=cache,
+        )
+
+    def sanitize(self, weights):
+        def transform_key(key: str) -> str:
+            if key.startswith("model.transformer."):
+                key = key.replace("model.transformer.", "language_model.model.", 1)
+            if key.startswith("model.vision_backbone."):
+                key = key.replace("model.vision_backbone.", "vision_tower.", 1)
+            if key.startswith("lm_head."):
+                key = key.replace("lm_head.", "language_model.lm_head.", 1)
+            # Vision transformer uses list not named submodule
+            key = key.replace(".transformer.resblocks.", ".transformer.")
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}
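
get_input_embeddings in molmo2.py above splices vision features into the text sequence by flattening the (batch, seq, hidden) token embeddings and adding one feature row at every position whose token id equals config.image_patch_id. A minimal standalone sketch of that merge, using toy shapes and a made-up patch id rather than the package's real configuration:

    import mlx.core as mx
    import numpy as np

    image_patch_id = 9                          # made-up id for illustration
    input_ids = mx.array([[5, 9, 9, 7]])        # image-patch tokens at positions 1 and 2
    x = mx.zeros((1, 4, 3))                     # (batch, seq, hidden) token embeddings
    image_features = mx.ones((2, 3))            # one feature row per image patch

    is_image_patch = mx.reshape(input_ids, (-1,)) == image_patch_id
    flat_x = mx.reshape(x, (-1, x.shape[-1]))
    positions = mx.array(np.where(np.array(is_image_patch))[0], dtype=mx.uint32)
    flat_x[positions] = flat_x[positions] + image_features   # scatter-add the features
    x = flat_x.reshape(x.shape)
    print(x)                                    # rows 1 and 2 of the sequence are now all ones
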