fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
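The wheel installs a single top-level package, mlx_vlm (per top_level.txt), plus console entry points. The sketch below is a hypothetical usage example, assuming this fork keeps upstream mlx-vlm's load / generate / apply_chat_template / load_config API; the model path is a placeholder and nothing here is copied from the wheel itself.

    # Hypothetical usage sketch; API names assumed from upstream mlx-vlm.
    from mlx_vlm import load, generate
    from mlx_vlm.prompt_utils import apply_chat_template
    from mlx_vlm.utils import load_config

    model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"  # placeholder checkpoint
    model, processor = load(model_path)
    config = load_config(model_path)

    images = ["example.jpg"]
    prompt = apply_chat_template(processor, config, "Describe this image.", num_images=len(images))
    output = generate(model, processor, prompt, images, verbose=False)
    print(output)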
mlx_vlm/models/mistral3/mistral3.py
@@ -0,0 +1,383 @@
+from typing import List, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import InputEmbeddingsFeatures
+from ..pixtral import VisionModel
+from .config import ModelConfig
+from .language import LanguageModel
+
+
+def _pair(x) -> Tuple[int, int]:
+    """Convert input to a pair of values."""
+    if isinstance(x, (list, tuple)):
+        return tuple(x)
+    return (x, x)
+
+
+def unfold(
+    input: mx.array,
+    kernel_size: Union[int, Tuple[int, int], List[int]],
+    dilation: Union[int, Tuple[int, int], List[int]] = 1,
+    padding: Union[int, Tuple[int, int], List[int]] = 0,
+    stride: Union[int, Tuple[int, int], List[int]] = 1,
+) -> mx.array:
+    """
+    Extract sliding local blocks from a batched input tensor (MLX implementation).
+
+    This is equivalent to PyTorch's nn.functional.unfold or im2col operation.
+
+    Args:
+        input: Input tensor of shape (B, C, H, W)
+        kernel_size: Size of the sliding blocks
+        dilation: Controls the spacing between kernel elements
+        padding: Controls the amount of implicit padding
+        stride: Controls the stride between blocks
+
+    Returns:
+        Unfolded tensor of shape (B, C*kernel_height*kernel_width, L)
+        where L is the number of blocks
+    """
+    # Convert to pairs
+    kernel_size = _pair(kernel_size)
+    dilation = _pair(dilation)
+    padding = _pair(padding)
+    stride = _pair(stride)
+
+    # Input shape
+    batch_size, channels, height, width = input.shape
+
+    # Add padding if needed
+    if padding[0] > 0 or padding[1] > 0:
+        padding_shape = (
+            (0, 0),
+            (0, 0),
+            (padding[0], padding[0]),
+            (padding[1], padding[1]),
+        )
+        input = mx.pad(input, padding_shape)
+
+    # Calculate output dimensions
+    height_out = (
+        height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+    ) // stride[0] + 1
+    width_out = (
+        width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+    ) // stride[1] + 1
+
+    # Initialize output arrays
+    blocks = []
+
+    # Extract blocks
+    for i in range(
+        0, height + 2 * padding[0] - kernel_size[0] * dilation[0] + 1, stride[0]
+    ):
+        for j in range(
+            0, width + 2 * padding[1] - kernel_size[1] * dilation[1] + 1, stride[1]
+        ):
+            # Extract the block for all channels
+            block = []
+            for di in range(kernel_size[0]):
+                for dj in range(kernel_size[1]):
+                    h_idx = i + di * dilation[0]
+                    w_idx = j + dj * dilation[1]
+                    # Get the block for all channels and add to our list
+                    block.append(input[:, :, h_idx, w_idx])
+
+            # Stack the channel-blocks
+            block = mx.stack(block, axis=1)  # Shape: (B, k*k, C)
+            block = mx.transpose(block, [0, 2, 1])  # Shape: (B, C, k*k)
+            blocks.append(block)
+
+    # Stack all blocks together
+    result = mx.stack(blocks, axis=-1)  # Shape: (B, C, k*k, L)
+
+    # Reshape to match PyTorch's unfold output format: (B, C*k*k, L)
+    result = mx.reshape(
+        result,
+        (
+            batch_size,
+            channels * kernel_size[0] * kernel_size[1],
+            height_out * width_out,
+        ),
+    )
+
+    return result
+
+
+class Mistral3PatchMerger(nn.Module):
+    """
+    Learned merging of spatial_merge_size ** 2 patches
+    """
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+        hidden_size = config.vision_config.hidden_size
+        self.spatial_merge_size = config.spatial_merge_size
+        self.patch_size = self.config.vision_config.patch_size
+        self.merging_layer = nn.Linear(
+            hidden_size * self.spatial_merge_size**2, hidden_size, bias=False
+        )
+
+    def __call__(self, image_features: mx.array, image_sizes: mx.array) -> mx.array:
+
+        image_sizes = [
+            (image_size[0] // self.patch_size, image_size[1] // self.patch_size)
+            for image_size in image_sizes
+        ]
+
+        tokens_per_image = [h * w for h, w in image_sizes]
+        d = image_features.shape[-1]
+        image_features = image_features.astype(mx.bfloat16)
+        image_sizes = mx.array(image_sizes)
+
+        # Split the image features into chunks based on tokens_per_image
+        split_indices = []
+        current_index = 0
+        for tokens in tokens_per_image:
+            split_indices.append(current_index + tokens)
+            current_index += tokens
+
+        # Perform the split
+        chunks = mx.split(image_features, split_indices[:-1], axis=1)
+
+        permuted_tensor = []
+        for image_index, image_tokens in enumerate(chunks):
+
+            # Reshape image_tokens into a 2D grid
+            if image_tokens.shape[1] > 0:
+                h, w = image_sizes[image_index].tolist()
+
+                image_grid = image_tokens.reshape(h, w, d).transpose(2, 0, 1)[None, ...]
+
+                grid = unfold(
+                    image_grid,
+                    kernel_size=self.spatial_merge_size,
+                    stride=self.spatial_merge_size,
+                )
+                grid = grid.reshape(d * self.spatial_merge_size**2, -1).T
+                permuted_tensor.append(grid)
+
+        image_features = mx.concatenate(permuted_tensor, axis=0)
+        image_features = self.merging_layer(image_features)
+        return image_features[None, ...]
+
+
+class Mistral3MultiModalProjector(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+
+        self.norm = nn.RMSNorm(config.vision_config.hidden_size)
+        self.patch_merger = Mistral3PatchMerger(config)
+
+        num_feature_layers = (
+            1
+            if isinstance(config.vision_feature_layer, int)
+            else len(config.vision_feature_layer)
+        )
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size * num_feature_layers,
+            config.text_config.hidden_size,
+            bias=config.multimodal_projector_bias,
+        )
+        self.gelu = nn.GELU()
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size,
+            config.text_config.hidden_size,
+            bias=config.multimodal_projector_bias,
+        )
+
+    def __call__(self, x: mx.array, image_sizes: mx.array) -> mx.array:
+        x = self.norm(x)
+
+        x = self.patch_merger(x, image_sizes)
+        x = self.linear_1(x)
+        x = self.gelu(x)
+        x = self.linear_2(x)
+        return x
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+        self.multi_modal_projector = Mistral3MultiModalProjector(config)
+        self.vision_tower = VisionModel(config.vision_config)
+        self.language_model = LanguageModel(config.text_config)
+        self.vision_feature_layer = config.vision_feature_layer
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        image_sizes = kwargs.get("image_sizes", None)
+
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        # Get the input embeddings from the language model
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # Get the output hidden states from the vision model
+        if isinstance(pixel_values, list):
+            pixel_values = mx.concatenate(
+                [mx.array(pv)[None, ...] for pv in pixel_values], axis=0
+            )
+        if pixel_values.ndim == 3:
+            pixel_values = pixel_values[None, ...]
+
+        # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding
+        # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21
+        # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85
+        *_, hidden_states = self.vision_tower(
+            pixel_values.transpose(0, 2, 3, 1),
+            output_hidden_states=True,
+        )
+        # Select the hidden states from the desired layer
+        selected_image_feature = hidden_states[self.vision_feature_layer]
+
+        # Pass image features through the multi-modal projector
+        image_features = self.multi_modal_projector(selected_image_feature, image_sizes)
+
+        # Insert special image tokens in the input_ids
+        final_inputs_embeds = self.merge_input_ids_with_image_features(
+            self.config.image_token_index, image_features, inputs_embeds, input_ids
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    @staticmethod
+    def merge_input_ids_with_image_features(
+        image_token_index, image_features, inputs_embeds, input_ids
+    ):
+        """Merge image features into input embeddings at image token positions.
+
+        Args:
+            image_token_index: Token ID for image placeholder
+            image_features: Vision features from the projector [1, num_features, hidden_dim]
+            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
+            input_ids: Input token IDs [batch_size, seq_len]
+
+        Returns:
+            Updated input embeddings with image features inserted
+        """
+        # Remove the extra batch dimension from image_features if present
+        if image_features.ndim == 3 and image_features.shape[0] == 1:
+            image_features = image_features.squeeze(0)  # [num_features, hidden_dim]
+
+        # Positions of <image> tokens in input_ids
+        image_positions = input_ids == image_token_index
+
+        # Get dimensions
+        batch_size, seq_len = input_ids.shape
+
+        # Process each batch item
+        batch_outputs = []
+        feature_start_idx = 0
+
+        for batch_idx in range(batch_size):
+            # Get mask for this batch
+            image_mask = image_positions[batch_idx]
+            num_positions = mx.sum(image_mask).item()
+
+            if num_positions > 0:
+                # Extract features for this batch
+                batch_features = image_features[
+                    feature_start_idx : feature_start_idx + num_positions
+                ]
+
+                # Validate we have the right number of features
+                if batch_features.shape[0] != num_positions:
+                    raise ValueError(
+                        f"Number of image token positions ({num_positions}) does not match "
+                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
+                    )
+
+                # Create indices for gathering
+                cumsum = mx.cumsum(image_mask.astype(mx.int32))
+                feature_indices = mx.where(image_mask, cumsum - 1, 0)
+
+                # Gather features
+                gathered_features = batch_features[feature_indices]
+
+                # Combine with original embeddings
+                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
+                batch_output = mx.where(
+                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
+                )
+
+                feature_start_idx += num_positions
+            else:
+                # No image tokens in this batch item
+                batch_output = inputs_embeds[batch_idx]
+
+            batch_outputs.append(batch_output)
+
+        # Stack all batch outputs
+        return mx.stack(batch_outputs, axis=0)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
+    ):
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, **kwargs
+        )
+        logits = self.language_model(
+            input_ids,
+            cache=cache,
+            inputs_embeds=input_embeddings_features.inputs_embeds,
+        )
+        return logits
+
+    def sanitize(self, weights):
+        def transform_key(key):
+            if "vision_tower" in key and "vision_model" not in key:
+                if "transformer" in key:
+                    key = key.replace("vision_tower", "vision_tower.vision_model")
+                if "patch_conv" in key:
+                    key = key.replace("vision_tower", "vision_tower.vision_model")
+                if "ln_pre" in key:
+                    key = key.replace("vision_tower", "vision_tower.vision_model")
+
+            elif "vision_encoder" in key and "vision_tower" not in key:
+                if "transformer" in key:
+                    key = key.replace(
+                        "model.vision_encoder", "vision_tower.vision_model"
+                    )
+                if "patch_conv" in key:
+                    key = key.replace(
+                        "model.vision_encoder", "vision_tower.vision_model"
+                    )
+                if "ln_pre" in key:
+                    key = key.replace(
+                        "model.vision_encoder", "vision_tower.vision_model"
+                    )
+
+            elif "model.language_model" in key and "language_model.model" not in key:
+                key = key.replace("model.language_model", "language_model.model")
+
+            elif "lm_head" in key and "language_model" not in key:
+                key = key.replace("lm_head", "language_model.lm_head")
+
+            elif "model.vision_projection" in key:
+                key = key.replace("model.vision_projection", "multi_modal_projector")
+
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
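Aside (not part of the diff): a minimal sketch of the im2col contract implemented by the unfold helper above, assuming the package is installed and keeps the mlx_vlm.models.mistral3 module path shown in the file list. For a (B, C, H, W) input with a 2x2 kernel and stride 2, the result has shape (B, C*k*k, L), matching torch.nn.functional.unfold with the same arguments.

    import mlx.core as mx

    from mlx_vlm.models.mistral3.mistral3 import unfold  # path assumed from the file list

    x = mx.arange(1 * 3 * 4 * 4).reshape(1, 3, 4, 4)  # (B, C, H, W)
    patches = unfold(x, kernel_size=2, stride=2)      # non-overlapping 2x2 blocks

    # 2x2 output grid -> L = 4 blocks, each flattened to C * 2 * 2 = 12 values
    print(patches.shape)  # (1, 12, 4)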
mlx_vlm/models/mllama/__init__.py
@@ -0,0 +1,4 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .language import LanguageModel
+from .mllama import Model
+from .vision import VisionModel
mlx_vlm/models/mllama/config.py
@@ -0,0 +1,74 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple, Union
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str = "mllama"
+    vocab_size: int = 32000
+    hidden_size: int = 4096
+    intermediate_size: int = 14336
+    num_hidden_layers: int = 40
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    hidden_act: str = "silu"
+    max_position_embeddings: int = 131072
+    initializer_range: float = 0.02
+    rms_norm_eps: float = 1e-6
+    tie_word_embeddings: bool = False
+    rope_theta: float = 10000.0
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    cross_attention_layers: List[int] = field(
+        default_factory=lambda: [3, 8, 13, 18, 23, 28, 33, 38]
+    )
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    image_size: int = 560
+    patch_size: int = 14
+    num_channels: int = 3
+    hidden_size: int = 1280
+    intermediate_size: int = 5120
+    num_hidden_layers: int = 32
+    num_attention_heads: int = 16
+    max_num_tiles: int = 4
+    max_aspect_ratio_id: int = 8
+    num_global_layers: int = 8
+    norm_eps: float = 1e-5
+    attention_dropout: float = 0.0
+    hidden_dropout: float = 0.0
+    vision_output_dim: int = 7680
+    intermediate_layers_indices: List[int] = field(
+        default_factory=lambda: [3, 7, 15, 23, 30]
+    )
+    supported_aspect_ratios: Tuple[List[int]] = (
+        [1, 1],
+        [1, 2],
+        [1, 3],
+        [1, 4],
+        [2, 1],
+        [2, 2],
+        [3, 1],
+        [4, 1],
+    )
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig
+    vision_config: VisionConfig
+    model_type: str
+    ignore_index: int = -100
+    image_token_index: int = 128256
+    vision_feature_select_strategy: str = "default"
+    vision_feature_layer: int = -2
+    vocab_size: int = 32000
+    eos_token_id: Optional[List[int]] = None
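Aside (not part of the diff): an illustrative construction of the mllama config objects defined above, using only the defaults shown in the dataclasses; the import path assumes the package __init__ re-exports shown in the previous hunk.

    from mlx_vlm.models.mllama import ModelConfig, TextConfig, VisionConfig

    config = ModelConfig(
        text_config=TextConfig(),      # 40-layer decoder, cross-attention at layers [3, 8, ..., 38]
        vision_config=VisionConfig(),  # 560x560 images, 14x14 patches, up to 4 tiles
        model_type="mllama",
    )
    print(config.image_token_index)  # 128256, the default image placeholder token id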