fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
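
The diff hunks reproduced below correspond to mlx_vlm/models/molmo2/vision.py, mlx_vlm/models/moondream2/__init__.py, mlx_vlm/models/moondream2/config.py, and mlx_vlm/models/moondream2/image_crops.py. As an aside, the listing above can be reproduced from a local copy of the wheel (a .whl is a plain zip archive); the sketch below is illustrative only and assumes the wheel filename follows the dist-info entries shown.

import zipfile

# List every file in the wheel with its uncompressed size, mirroring the table above.
with zipfile.ZipFile("fount_vlm_nell_02-0.3.11-py3-none-any.whl") as whl:
    for info in whl.infolist():
        print(f"{info.file_size:>9}  {info.filename}")
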
mlx_vlm/models/molmo2/vision.py
@@ -0,0 +1,286 @@
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import interpolate
+ from .config import AdapterConfig, VisionConfig, VitConfig
+
+
+ def _gelu_from_name(name: str) -> nn.Module:
+     if name == "gelu_pytorch_tanh":
+         return nn.GELU(approx="fast")
+     return nn.GELU(approx="fast")
+
+
+ class ViTMLP(nn.Module):
+     def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+         super().__init__()
+         self.w1 = nn.Linear(hidden_size, intermediate_size, bias=True)
+         self.w2 = nn.Linear(intermediate_size, hidden_size, bias=True)
+         self.act = _gelu_from_name(hidden_act)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.w2(self.act(self.w1(x)))
+
+
+ class ViTMultiHeadDotProductAttention(nn.Module):
+     def __init__(
+         self,
+         *,
+         hidden_size: int,
+         num_heads: int,
+         num_key_value_heads: int,
+         head_dim: int,
+         input_dim: Optional[int] = None,
+         use_bias: bool = True,
+         float32_attention: bool = True,
+     ):
+         super().__init__()
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.head_dim = head_dim
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.scale = head_dim**-0.5
+         self.float32_attention = float32_attention
+
+         input_dim = input_dim or hidden_size
+         self.wq = nn.Linear(input_dim, self.num_heads * self.head_dim, bias=use_bias)
+         self.wk = nn.Linear(
+             input_dim, self.num_key_value_heads * self.head_dim, bias=use_bias
+         )
+         self.wv = nn.Linear(
+             input_dim, self.num_key_value_heads * self.head_dim, bias=use_bias
+         )
+         self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True)
+
+     def __call__(
+         self,
+         inputs_q: mx.array,
+         inputs_kv: Optional[mx.array] = None,
+         attn_mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         if inputs_kv is None:
+             inputs_k = inputs_q
+             inputs_v = inputs_q
+         else:
+             inputs_k = inputs_kv
+             inputs_v = inputs_kv
+
+         xq = self.wq(inputs_q)
+         xk = self.wk(inputs_k)
+         xv = self.wv(inputs_v)
+
+         bsz, q_len, _ = xq.shape
+         _, kv_len, _ = xk.shape
+
+         xq = xq.reshape(bsz, q_len, self.num_heads, self.head_dim)
+         xk = xk.reshape(bsz, kv_len, self.num_key_value_heads, self.head_dim)
+         xv = xv.reshape(bsz, kv_len, self.num_key_value_heads, self.head_dim)
+
+         if self.num_heads != self.num_key_value_heads:
+             xk = mx.repeat(xk, self.num_key_value_groups, axis=2)
+             xv = mx.repeat(xv, self.num_key_value_groups, axis=2)
+
+         q = xq.transpose(0, 2, 1, 3)
+         k = xk.transpose(0, 2, 1, 3)
+         v = xv.transpose(0, 2, 1, 3)
+
+         dtype = q.dtype
+         if self.float32_attention:
+             q = q.astype(mx.float32)
+             k = k.astype(mx.float32)
+             v = v.astype(mx.float32)
+
+         scores = mx.matmul(q, k.transpose(0, 1, 3, 2)) * self.scale
+         if attn_mask is not None:
+             scores = mx.where(
+                 attn_mask,
+                 scores,
+                 mx.full(scores.shape, vals=-1e9, dtype=scores.dtype),
+             )
+
+         weights = mx.softmax(scores, axis=-1)
+         out = mx.matmul(weights, v).astype(dtype)
+         out = out.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
+         return self.wo(out)
+
+
+ class Molmo2VisionBlock(nn.Module):
+     def __init__(self, config: VitConfig):
+         super().__init__()
+         self.attention = ViTMultiHeadDotProductAttention(
+             hidden_size=config.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_key_value_heads=config.num_key_value_heads,
+             head_dim=config.head_dim,
+             float32_attention=config.float32_attention,
+             input_dim=config.hidden_size,
+         )
+         self.feed_forward = ViTMLP(
+             config.hidden_size, config.intermediate_size, config.hidden_act
+         )
+         self.attention_norm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_eps
+         )
+         self.ffn_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = x + self.attention(self.attention_norm(x))
+         x = x + self.feed_forward(self.ffn_norm(x))
+         return x
+
+
+ class Molmo2VisionTransformer(nn.Module):
+     def __init__(self, config: VitConfig):
+         super().__init__()
+         self.config = config
+         self.num_prefix_tokens = 0
+
+         self.positional_embedding = mx.zeros((config.image_num_pos, config.hidden_size))
+         patch_dim = config.image_patch_size * config.image_patch_size * 3
+         self.patch_embedding = nn.Linear(patch_dim, config.hidden_size, bias=True)
+         self.transformer = [
+             Molmo2VisionBlock(config) for _ in range(config.num_hidden_layers)
+         ]
+
+     def add_pos_emb(self, x: mx.array, patch_num: Tuple[int, int]) -> mx.array:
+         pos_emb = self.positional_embedding
+         pos_emb_size = int(pos_emb.shape[0] ** 0.5)
+         pos_emb = mx.reshape(pos_emb, (pos_emb_size, pos_emb_size, pos_emb.shape[1]))
+
+         patch_h, patch_w = patch_num
+         if pos_emb.shape[0] != patch_h or pos_emb.shape[1] != patch_w:
+             pos_emb = mx.transpose(pos_emb[None, ...], (0, 3, 1, 2))
+             pos_emb = interpolate(
+                 pos_emb, (patch_h, patch_w), mode="cubic", align_corners=False
+             )
+             pos_emb = mx.transpose(pos_emb, (0, 2, 3, 1))[0]
+
+         pos_emb = mx.reshape(pos_emb, (-1, pos_emb.shape[-1]))
+         return x + pos_emb[None, :, :].astype(x.dtype)
+
+     def __call__(
+         self,
+         x: mx.array,
+         patch_num: Optional[Tuple[int, int]] = None,
+     ):
+         if patch_num is None:
+             patch_num = self.config.image_num_patch
+
+         x = self.patch_embedding(x)
+         x = self.add_pos_emb(x, patch_num)
+
+         hidden_states = []
+         for block in self.transformer:
+             x = block(x)
+             hidden_states.append(x)
+         return hidden_states
+
+
+ class ImageProjectorMLP(nn.Module):
+     def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
+         super().__init__()
+         self.w1 = nn.Linear(input_dim, hidden_dim, bias=False)
+         self.w2 = nn.Linear(hidden_dim, output_dim, bias=False)
+         self.w3 = nn.Linear(input_dim, hidden_dim, bias=False)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.w2(nn.silu(self.w1(x)) * self.w3(x))
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = "molmo2"
+         self.vit_config: VitConfig = config.vit_config
+         self.adapter_config: AdapterConfig = config.adapter_config
+
+         self.image_vit = Molmo2VisionTransformer(self.vit_config)
+
+         self.vit_layers = []
+         for layer in self.adapter_config.vit_layers:
+             self.vit_layers.append(
+                 layer if layer >= 0 else layer + self.vit_config.num_hidden_layers
+             )
+
+         pool_dim = self.vit_config.hidden_size * len(self.vit_layers)
+
+         self.image_pooling_2d = ViTMultiHeadDotProductAttention(
+             hidden_size=self.adapter_config.hidden_size,
+             num_heads=self.adapter_config.num_attention_heads,
+             num_key_value_heads=self.adapter_config.num_key_value_heads,
+             head_dim=self.adapter_config.head_dim,
+             input_dim=pool_dim,
+             float32_attention=self.adapter_config.float32_attention,
+         )
+
+         self.image_projector = ImageProjectorMLP(
+             self.adapter_config.hidden_size,
+             self.adapter_config.intermediate_size,
+             self.adapter_config.text_hidden_size,
+         )
+
+     def encode_image(self, images: mx.array) -> mx.array:
+         batch_size, num_crops, num_patch, patch_dim = images.shape
+         images = images.reshape(batch_size * num_crops, num_patch, patch_dim)
+         hidden_states = self.image_vit(images)
+
+         features = [hidden_states[layer] for layer in self.vit_layers]
+         image_features = mx.concatenate(features, axis=-1)
+         image_features = image_features.reshape(batch_size, num_crops, num_patch, -1)
+         return image_features
+
+     def __call__(
+         self,
+         images: mx.array,
+         pooled_patches_idx: mx.array,
+     ) -> mx.array:
+         batch_size, num_crops = images.shape[:2]
+
+         image_features = self.encode_image(images)
+         dim = image_features.shape[-1]
+
+         valid = pooled_patches_idx >= 0
+         valid_token = mx.any(valid, axis=-1)
+
+         flat_features = image_features.reshape(batch_size, -1, dim)
+         idx = mx.clip(pooled_patches_idx, 0, None)
+         batch_idx = mx.arange(batch_size)[:, None, None]
+         batch_idx = mx.broadcast_to(batch_idx, idx.shape)
+
+         gathered = flat_features[mx.reshape(batch_idx, (-1,)), mx.reshape(idx, (-1,))]
+         to_pool = gathered.reshape(
+             pooled_patches_idx.shape[0],
+             pooled_patches_idx.shape[1],
+             pooled_patches_idx.shape[2],
+             dim,
+         )
+
+         to_pool = to_pool * valid[..., None].astype(to_pool.dtype)
+         to_pool = to_pool.reshape(-1, pooled_patches_idx.shape[-1], dim)
+
+         if self.adapter_config.pooling_attention_mask:
+             attn_mask = valid.reshape(-1, 1, 1, valid.shape[-1])
+             denom = valid.reshape(-1, to_pool.shape[-2]).astype(mx.float32).sum(axis=-1)
+             denom = mx.where(denom == 0, mx.ones_like(denom), denom)
+             query = to_pool.sum(axis=-2, keepdims=True) / denom[:, None, None].astype(
+                 to_pool.dtype
+             )
+         else:
+             attn_mask = None
+             query = mx.mean(to_pool, axis=-2, keepdims=True)
+
+         pooled = self.image_pooling_2d(query, to_pool, attn_mask=attn_mask)
+         pooled = pooled.reshape(batch_size, -1, pooled.shape[-1])
+         pooled = self.image_projector(pooled)
+
+         pooled = pooled.reshape(-1, pooled.shape[-1])
+
+         # MLX doesn't support boolean indexing, so convert to integer indices
+         valid_flat = np.array(valid_token).flatten()
+         valid_indices = np.where(valid_flat)[0]
+         return pooled[mx.array(valid_indices)]
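
Not part of the package diff: a minimal sketch exercising the standalone attention block above on random data, assuming the module path mlx_vlm/models/molmo2/vision.py from the file list. Shapes are illustrative; with more query heads than key/value heads, the k/v tensors are repeated internally (grouped-query attention).

import mlx.core as mx
from mlx_vlm.models.molmo2.vision import ViTMultiHeadDotProductAttention

# 8 query heads sharing 2 key/value heads; input_dim defaults to hidden_size.
attn = ViTMultiHeadDotProductAttention(
    hidden_size=64, num_heads=8, num_key_value_heads=2, head_dim=8
)
x = mx.random.normal((1, 16, 64))  # (batch, tokens, hidden)
print(attn(x).shape)  # (1, 16, 64)
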
mlx_vlm/models/moondream2/__init__.py
@@ -0,0 +1,11 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .image_crops import (
+     adaptive_avg_pool2d,
+     overlap_crop_image,
+     reconstruct_from_crops,
+     select_tiling,
+ )
+ from .moondream2 import ImageProcessor, Model
+ from .vision import VisionModel
+ from .language import LanguageModel
+ from . import processing_moondream  # Registers the AutoProcessor patch
mlx_vlm/models/moondream2/config.py
@@ -0,0 +1,92 @@
+ import inspect
+ from dataclasses import dataclass
+ from typing import Optional
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str = "phi"
+     hidden_size: int = 2048
+     num_hidden_layers: int = 24
+     intermediate_size: int = 8192
+     num_attention_heads: int = 32
+     num_key_value_heads: int = 32
+     vocab_size: int = 51200
+     max_position_embeddings: int = 2048
+     rope_theta: float = 10000.0
+     layer_norm_eps: float = 1e-5
+     # Moondream uses partial RoPE - only first 32 dims (head_dim // 2)
+     partial_rotary_factor: float = 0.5
+     # Prefix attention length: BOS (1) + image patches (729) = 730
+     prefix_attn_len: int = 730
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str = "moondream_vision"
+     hidden_size: int = 1152  # enc_dim
+     num_hidden_layers: int = 27  # enc_n_layers
+     intermediate_size: int = 4304  # enc_ff_dim
+     num_attention_heads: int = 16  # enc_n_heads
+     image_size: int = 378  # crop_size
+     patch_size: int = 14  # enc_patch_size
+     num_channels: int = 3  # in_channels
+     layer_norm_eps: float = 1e-5
+     # Multi-crop settings (for future full implementation)
+     max_crops: int = 12
+     overlap_margin: int = 4
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig = None
+     vision_config: VisionConfig = None
+     model_type: str = "moondream1"
+     # Projection MLP inner dimension
+     proj_inner_dim: int = 8192
+     # Image features are prepended after BOS token
+     image_token_index: int = -200
+     vocab_size: int = 51200
+     # Prefix attention length: BOS (1) + image patches (729) = 730
+     prefix_attn_len: int = 730
+     # Token IDs (EOS and BOS are the same for moondream)
+     eos_token_id: int = 0
+     bos_token_id: int = 0
+
+     def __post_init__(self):
+         if self.text_config is None:
+             self.text_config = TextConfig()
+         if self.vision_config is None:
+             self.vision_config = VisionConfig()
+
+     @classmethod
+     def from_dict(cls, params):
+         # Extract nested configs
+         text_config_dict = params.get("text_config", {})
+         vision_config_dict = params.get("vision_config", {})
+
+         # If text_config is empty, try to get from root level
+         if not text_config_dict:
+             text_config_dict = {
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(TextConfig).parameters
+             }
+
+         # Create nested config objects
+         text_config = TextConfig.from_dict(text_config_dict)
+         vision_config = VisionConfig.from_dict(vision_config_dict)
+
+         # Build the main config
+         return cls(
+             text_config=text_config,
+             vision_config=vision_config,
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+                 and k not in ("text_config", "vision_config")
+             },
+         )
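
Not part of the package diff: a minimal sketch of ModelConfig.from_dict above with a flat parameter dict, assuming the module path from the file list and that BaseModelConfig.from_dict filters unknown keys the same way from_dict does here. Root-level text fields are folded into text_config when no nested text_config is provided.

from mlx_vlm.models.moondream2.config import ModelConfig

params = {"hidden_size": 2048, "vision_config": {"image_size": 378}}
cfg = ModelConfig.from_dict(params)
print(cfg.text_config.hidden_size)   # 2048, picked up from the root level
print(cfg.vision_config.image_size)  # 378
print(cfg.prefix_attn_len)           # 730 default: BOS (1) + 729 image patches
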
mlx_vlm/models/moondream2/image_crops.py
@@ -0,0 +1,269 @@
+ """
+ Multi-crop image processing utilities for Moondream2.
+
+ Reference implementation: moondream2/image_crops.py
+ """
+
+ import math
+ from typing import Tuple
+
+ import mlx.core as mx
+ import numpy as np
+ from PIL import Image
+
+
+ def select_tiling(
+     height: int, width: int, crop_size: int, max_crops: int
+ ) -> Tuple[int, int]:
+     """
+     Determine the optimal number of tiles to cover an image with overlapping crops.
+
+     Ported from HF reference: moondream2/image_crops.py:17-50
+     """
+     if height <= crop_size or width <= crop_size:
+         return (1, 1)
+
+     # Minimum required tiles in each dimension
+     min_h = math.ceil(height / crop_size)
+     min_w = math.ceil(width / crop_size)
+
+     # If minimum required tiles exceed max_crops, return proportional distribution
+     if min_h * min_w > max_crops:
+         ratio = math.sqrt(max_crops / (min_h * min_w))
+         return (max(1, math.floor(min_h * ratio)), max(1, math.floor(min_w * ratio)))
+
+     # Perfect aspect-ratio tiles that satisfy max_crops
+     h_tiles = math.floor(math.sqrt(max_crops * height / width))
+     w_tiles = math.floor(math.sqrt(max_crops * width / height))
+
+     # Ensure we meet minimum tile requirements
+     h_tiles = max(h_tiles, min_h)
+     w_tiles = max(w_tiles, min_w)
+
+     # If we exceeded max_crops, scale down the larger dimension
+     if h_tiles * w_tiles > max_crops:
+         if w_tiles > h_tiles:
+             w_tiles = math.floor(max_crops / h_tiles)
+         else:
+             h_tiles = math.floor(max_crops / w_tiles)
+
+     return (max(1, h_tiles), max(1, w_tiles))
+
+
+ def overlap_crop_image(
+     image: np.ndarray,
+     max_crops: int = 12,
+     overlap_margin: int = 4,
+     base_size: Tuple[int, int] = (378, 378),
+     patch_size: int = 14,
+ ) -> Tuple[np.ndarray, Tuple[int, int]]:
+     """
+     Create overlapping crops from an image for multi-scale processing.
+
+     Ported from HF reference: moondream2/image_crops.py:58-167
+
+     Args:
+         image: Input image as numpy array [H, W, C] in range [0, 255]
+         max_crops: Maximum number of local crops allowed (default 12)
+         overlap_margin: Number of patches to overlap between adjacent crops (default 4)
+         base_size: Size of each crop (default (378, 378))
+         patch_size: Size of each patch for the vision encoder (default 14)
+
+     Returns:
+         crops: numpy array [n_crops, H, W, C] - crops[0] is global, rest are local
+         tiling: (h_tiles, w_tiles) tuple describing the local crop layout
+     """
+     original_h, original_w = image.shape[:2]
+
+     # Convert margin from patch units to pixels
+     margin_pixels = patch_size * overlap_margin
+     total_margin_pixels = margin_pixels * 2  # Both sides
+
+     # Calculate crop parameters
+     crop_patches = base_size[0] // patch_size  # patches per crop dimension
+     crop_window_patches = crop_patches - (2 * overlap_margin)  # usable patches
+     crop_window_size = crop_window_patches * patch_size  # usable size in pixels
+
+     # Determine tiling using margin-adjusted dimensions and effective crop size
+     tiling = select_tiling(
+         original_h - total_margin_pixels,
+         original_w - total_margin_pixels,
+         crop_window_size,
+         max_crops,
+     )
+
+     # Pre-allocate crops
+     n_crops = tiling[0] * tiling[1] + 1  # +1 for global crop
+     crops = np.zeros(
+         (n_crops, base_size[0], base_size[1], image.shape[2]), dtype=np.uint8
+     )
+
+     # Resize image to fit tiling
+     target_size = (
+         tiling[0] * crop_window_size + total_margin_pixels,
+         tiling[1] * crop_window_size + total_margin_pixels,
+     )
+
+     pil_image = Image.fromarray(image.astype(np.uint8))
+
+     # Resize for local crops
+     resized = pil_image.resize(
+         (int(target_size[1]), int(target_size[0])),
+         resample=Image.Resampling.LANCZOS,
+     )
+     image = np.asarray(resized)
+
+     # Create global crop
+     global_crop = pil_image.resize(
+         (int(base_size[1]), int(base_size[0])),
+         resample=Image.Resampling.LANCZOS,
+     )
+     crops[0] = np.asarray(global_crop)
+
+     # Extract local crops
+     for i in range(tiling[0]):
+         for j in range(tiling[1]):
+             y0 = i * crop_window_size
+             x0 = j * crop_window_size
+
+             y_end = min(y0 + base_size[0], image.shape[0])
+             x_end = min(x0 + base_size[1], image.shape[1])
+
+             crop_region = image[y0:y_end, x0:x_end]
+             crops[
+                 1 + i * tiling[1] + j, : crop_region.shape[0], : crop_region.shape[1]
+             ] = crop_region
+
+     return crops, tiling
+
+
+ def reconstruct_from_crops(
+     local_features: mx.array,
+     tiling: Tuple[int, int],
+     overlap_margin: int = 4,
+ ) -> mx.array:
+     """
+     Reconstruct a unified feature map from local crop features.
+
+     This function stitches together the features from local crops,
+     handling the overlap regions by trimming interior margins.
+
+     Args:
+         local_features: [n_local, 27, 27, 1152] features from local crops
+             (27x27 patches per crop, each with 1152-dim features)
+         tiling: (h_tiles, w_tiles) describing the crop layout
+         overlap_margin: Number of patches that overlap between adjacent crops (default 4)
+
+     Returns:
+         Reconstructed feature map [H, W, 1152] where:
+             H = h_tiles * (27 - 2*overlap_margin) + 2*overlap_margin
+             W = w_tiles * (27 - 2*overlap_margin) + 2*overlap_margin
+     """
+     h_tiles, w_tiles = tiling
+     n_local = h_tiles * w_tiles
+     patches_per_side = 27  # 378 / 14 = 27 patches per crop side
+     hidden_size = local_features.shape[-1]  # 1152
+
+     # Effective patches per crop after removing interior overlaps
+     effective_patches = patches_per_side - 2 * overlap_margin  # 27 - 8 = 19
+
+     # Output feature map size
+     out_h = h_tiles * effective_patches + 2 * overlap_margin
+     out_w = w_tiles * effective_patches + 2 * overlap_margin
+
+     # Initialize output
+     # Use numpy for easier slicing, convert to mx at the end
+     local_np = np.array(local_features)
+     output = np.zeros((out_h, out_w, hidden_size), dtype=local_np.dtype)
+
+     crop_idx = 0
+     for i in range(h_tiles):
+         for j in range(w_tiles):
+             crop_features = local_np[crop_idx]  # [27, 27, 1152]
+
+             # Determine which margins to keep based on position
+             top_margin = overlap_margin if i == 0 else 0
+             bottom_margin = overlap_margin if i == h_tiles - 1 else 0
+             left_margin = overlap_margin if j == 0 else 0
+             right_margin = overlap_margin if j == w_tiles - 1 else 0
+
+             # Trim interior margins
+             start_y = 0 if i == 0 else overlap_margin
+             end_y = patches_per_side if i == h_tiles - 1 else patches_per_side - overlap_margin
+             start_x = 0 if j == 0 else overlap_margin
+             end_x = patches_per_side if j == w_tiles - 1 else patches_per_side - overlap_margin
+
+             trimmed = crop_features[start_y:end_y, start_x:end_x]
+
+             # Calculate output position
+             out_y = 0 if i == 0 else (patches_per_side - overlap_margin) + (i - 1) * effective_patches
+             out_x = 0 if j == 0 else (patches_per_side - overlap_margin) + (j - 1) * effective_patches
+
+             out_h_slice = end_y - start_y
+             out_w_slice = end_x - start_x
+
+             output[out_y : out_y + out_h_slice, out_x : out_x + out_w_slice] = trimmed
+
+             crop_idx += 1
+
+     return mx.array(output)
+
+
+ def adaptive_avg_pool2d(
+     x: mx.array,
+     output_size: Tuple[int, int],
+ ) -> mx.array:
+     """
+     Adaptive average pooling that pools input to a fixed output size.
+
+     Args:
+         x: Input tensor [H, W, C] or [C, H, W]
+         output_size: Target (H_out, W_out)
+
+     Returns:
+         Pooled tensor with spatial dimensions matching output_size
+     """
+     # Assume input is [H, W, C] (channel last)
+     H, W, C = x.shape
+     out_h, out_w = output_size
+
+     if H == out_h and W == out_w:
+         return x
+
+     # Calculate kernel and stride sizes for adaptive pooling
+     # Kernel size = ceil(input_size / output_size)
+     # Stride = floor(input_size / output_size)
+     kernel_h = (H + out_h - 1) // out_h
+     kernel_w = (W + out_w - 1) // out_w
+     stride_h = H // out_h
+     stride_w = W // out_w
+
+     # Pad if necessary to ensure we can cover the output size
+     pad_h = max(0, (out_h - 1) * stride_h + kernel_h - H)
+     pad_w = max(0, (out_w - 1) * stride_w + kernel_w - W)
+
+     if pad_h > 0 or pad_w > 0:
+         # Pad with zeros
+         x = mx.pad(x, [(0, pad_h), (0, pad_w), (0, 0)])
+
+     # Perform pooling using a simple averaging approach
+     # Convert to [1, H, W, C] for batch processing
+     x = x[None, :, :, :]  # [1, H, W, C]
+
+     # Use reshape and mean for pooling
+     result = np.zeros((out_h, out_w, C), dtype=np.float32)
+     x_np = np.array(x[0])  # [H, W, C]
+
+     for i in range(out_h):
+         for j in range(out_w):
+             # Calculate the input region for this output pixel
+             h_start = i * stride_h
+             h_end = min(h_start + kernel_h, x_np.shape[0])
+             w_start = j * stride_w
+             w_end = min(w_start + kernel_w, x_np.shape[1])
+
+             # Average pool
+             region = x_np[h_start:h_end, w_start:w_end, :]
+             result[i, j, :] = region.mean(axis=(0, 1))
+
+     return mx.array(result)
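
Not part of the package diff: a minimal sketch running the crop utilities above on a synthetic image, assuming the module path from the file list. Only shapes are checked, no model weights are involved. For an 800x1200 input with the defaults, select_tiling caps the local grid at (2, 4), so overlap_crop_image returns 8 local crops plus one global crop, and reconstruct_from_crops stitches the 27x27 patch grids back into a 46x84 grid (19 effective patches per tile plus the outer 4-patch margins).

import mlx.core as mx
import numpy as np
from mlx_vlm.models.moondream2.image_crops import (
    overlap_crop_image,
    reconstruct_from_crops,
)

image = (np.random.rand(800, 1200, 3) * 255).astype(np.uint8)
crops, tiling = overlap_crop_image(image, max_crops=12, overlap_margin=4)
print(crops.shape, tiling)  # (9, 378, 378, 3) (2, 4); crops[0] is the global crop

# Stand-in ViT features for the 8 local crops: 27x27 patches, 1152 dims each.
local = mx.zeros((tiling[0] * tiling[1], 27, 27, 1152))
grid = reconstruct_from_crops(local, tiling, overlap_margin=4)
print(grid.shape)  # (46, 84, 1152)
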