fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/fastvlm/vision.py
@@ -0,0 +1,692 @@
+ from typing import List, Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from .config import VisionConfig
+
+
+ class NamedSequential(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self._order = []
+
+     def add_module(self, name, module):
+         setattr(self, name, module)
+         self._order.append(name)
+
+     def __call__(self, x):
+         for name in self._order:
+             x = getattr(self, name)(x)
+         return x
+
+
+ class CallableModuleList(list):
+     def __call__(self, x: mx.array):
+         for item in self:
+             x = item(x)
+         return x
+
+
+ class MHSA(nn.Module):
+     """Multi-headed Self Attention module.
+
+     Source modified from:
+     https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         head_dim: int = 32,
+         qkv_bias: bool = False,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         assert dim % head_dim == 0, "dim should be divisible by head_dim"
+         self.head_dim = head_dim
+         self.num_heads = dim // head_dim
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         # Source: https://github.com/apple/ml-fastvlm/blob/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/model/multimodal_encoder/mobileclip/mci.py#L661
+         x = x.transpose(0, 3, 1, 2)
+         B, C, H, W = x.shape
+         N = H * W
+         x = x.flatten(start_axis=2).transpose(0, 2, 1)  # (B, N, C)
+         qkv = (
+             self.qkv(x)
+             .reshape(B, N, 3, self.num_heads, self.head_dim)
+             .transpose(2, 0, 3, 1, 4)
+         )
+         q, k, v = qkv
+
+         x = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=None)
+         x = x.transpose(0, 2, 1, 3).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         x = x.reshape(B, H, W, C)
+         return x
+
+
+ class ConvFFN(nn.Module):
+     """Convolutional FFN Module."""
+
+     def __init__(
+         self,
+         in_channels: int,
+         hidden_channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         act_layer: nn.Module = nn.GELU,
+     ) -> None:
+         super().__init__()
+         out_channels = out_channels or in_channels
+         hidden_channels = hidden_channels or in_channels
+         self.conv = NamedSequential()
+         self.conv.add_module(
+             "conv",
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=7,
+                 padding=3,
+                 groups=in_channels,
+                 bias=False,
+             ),
+         )
+         self.conv.add_module(
+             "bn",
+             nn.BatchNorm(num_features=out_channels),
+         )
+         self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
+         self.act = act_layer()
+         self.fc2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=1)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.conv(x)
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.fc2(x)
+         return x
+
+
+ class LayerNormChannel(nn.Module):
+     """
+     LayerNorm only for Channel Dimension.
+     Input: tensor in shape [B, H, W, C]
+     """
+
+     def __init__(self, num_features, eps=1e-05) -> None:
+         super().__init__()
+         self.weight = mx.ones(num_features)
+         self.bias = mx.zeros(num_features)
+         self.eps = eps
+
+     def __call__(self, x: mx.array) -> mx.array:
+         u = x.mean(-1, keepdims=True)
+         s = mx.power(x - u, 2).mean(-1, keepdims=True)
+         x = (x - u) / mx.sqrt(s + self.eps)
+         x = self.weight * x + self.bias
+         return x
+
+
+ class AttentionBlock(nn.Module):
+     """Implementation of metaformer block with MHSA as token mixer.
+
+     For more details on Metaformer structure, please refer to:
+     `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         mlp_ratio: float = 4.0,
+         act_layer: nn.Module = nn.GELU,
+         norm_layer: nn.Module = nn.BatchNorm,
+     ):
+         super().__init__()
+
+         self.norm = norm_layer(num_features=dim)
+         self.token_mixer = MHSA(dim=dim)
+
+         assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
+             mlp_ratio
+         )
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.convffn = ConvFFN(
+             in_channels=dim,
+             hidden_channels=mlp_hidden_dim,
+             act_layer=act_layer,
+         )
+
+         self.layer_scale_1 = mx.ones((1, 1, dim))
+         self.layer_scale_2 = mx.ones((1, 1, dim))
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = x + self.layer_scale_1 * self.token_mixer(self.norm(x))
+         x = x + self.layer_scale_2 * self.convffn(x)
+         return x
+
+
+ class RepCPE(nn.Module):
+     """Implementation of conditional positional encoding.
+
+     For more details refer to paper:
+     `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         embed_dim: int = 768,
+         spatial_shape=(7, 7),
+     ) -> None:
+         super().__init__()
+         if isinstance(spatial_shape, int):
+             spatial_shape = tuple([spatial_shape] * 2)
+         assert isinstance(spatial_shape, Tuple), (
+             f'"spatial_shape" must by a sequence or int, '
+             f"get {type(spatial_shape)} instead."
+         )
+         assert len(spatial_shape) == 2, (
+             f'Length of "spatial_shape" should be 2, '
+             f"got {len(spatial_shape)} instead."
+         )
+
+         self.reparam_conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=embed_dim,
+             kernel_size=spatial_shape,
+             stride=1,
+             padding=int(spatial_shape[0] // 2),
+             groups=embed_dim,
+             bias=True,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.reparam_conv(x)
+
+
+ class ReparamLargeKernelConv(nn.Module):
+     """Building Block of RepLKNet
+
+     This class defines overparameterized large kernel conv block
+     introduced in `RepLKNet <https://arxiv.org/abs/2203.06717>`_
+
+     Reference: https://github.com/DingXiaoH/RepLKNet-pytorch
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int,
+         groups: int,
+         activation: nn.Module = nn.GELU(),
+     ) -> None:
+         super(ReparamLargeKernelConv, self).__init__()
+         self.activation = activation
+         self.lkb_reparam = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=kernel_size // 2,
+             dilation=1,
+             groups=groups,
+             bias=True,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.activation(self.lkb_reparam(x))
+
+
+ class PatchEmbed(nn.Module):
+     """Convolutional patch embedding layer."""
+
+     def __init__(
+         self,
+         patch_size: int,
+         stride: int,
+         in_channels: int,
+         embed_dim: int,
+     ) -> None:
+         super().__init__()
+         self.proj = CallableModuleList()
+         self.proj.append(
+             ReparamLargeKernelConv(
+                 in_channels=in_channels,
+                 out_channels=embed_dim,
+                 kernel_size=patch_size,
+                 stride=stride,
+                 groups=in_channels,
+             )
+         )
+         self.proj.append(
+             MobileOneBlock(
+                 in_channels=embed_dim,
+                 out_channels=embed_dim,
+                 kernel_size=1,
+                 stride=1,
+                 padding=0,
+                 groups=1,
+                 use_se=False,
+             )
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.proj(x)
+
+
+ class RepMixer(nn.Module):
+     """Reparameterizable token mixer.
+
+     For more details, please refer to Apple's paper:
+     `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization <https://arxiv.org/pdf/2303.14189.pdf>`_
+     """
+
+     def __init__(
+         self,
+         dim,
+         kernel_size=3,
+     ):
+         super().__init__()
+         self.dim = dim
+         self.kernel_size = kernel_size
+
+         self.reparam_conv = nn.Conv2d(
+             in_channels=self.dim,
+             out_channels=self.dim,
+             kernel_size=self.kernel_size,
+             stride=1,
+             padding=self.kernel_size // 2,
+             groups=self.dim,
+             bias=True,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.reparam_conv(x)
+
+
+ class RepMixerBlock(nn.Module):
+     """Implementation of Metaformer block with RepMixer as token mixer.
+
+     For more details on Metaformer structure, please refer to:
+     `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         kernel_size: int = 3,
+         mlp_ratio: float = 4.0,
+         act_layer: nn.Module = nn.GELU,
+     ):
+         super().__init__()
+
+         self.token_mixer = RepMixer(dim, kernel_size=kernel_size)
+
+         assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
+             mlp_ratio
+         )
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.convffn = ConvFFN(
+             in_channels=dim,
+             hidden_channels=mlp_hidden_dim,
+             act_layer=act_layer,
+         )
+         self.layer_scale = mx.ones((1, 1, dim))
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.token_mixer(x)
+         x = x + self.layer_scale * self.convffn(x)
+         return x
+
+
+ def basic_blocks(
+     dim: int,
+     block_index: int,
+     num_blocks: List[int],
+     token_mixer_type: str,
+     kernel_size: int = 3,
+     mlp_ratio: float = 4.0,
+     act_layer: nn.Module = nn.GELU,
+     norm_layer: nn.Module = nn.BatchNorm,
+ ):
+     blocks = CallableModuleList()
+     for _ in range(num_blocks[block_index]):
+         if token_mixer_type == "repmixer":
+             blocks.append(
+                 RepMixerBlock(
+                     dim,
+                     kernel_size=kernel_size,
+                     mlp_ratio=mlp_ratio,
+                     act_layer=act_layer,
+                 )
+             )
+         elif token_mixer_type == "attention":
+             blocks.append(
+                 AttentionBlock(
+                     dim,
+                     mlp_ratio=mlp_ratio,
+                     act_layer=act_layer,
+                     norm_layer=norm_layer,
+                 )
+             )
+         else:
+             raise ValueError(
+                 "Token mixer type: {} not supported".format(token_mixer_type)
+             )
+     return blocks
+
+
+ def build_fast_vit_network(config: VisionConfig):
+     network = []
+     for i in range(len(config.layers)):
+         spatial_shape = config.pos_embs_shapes[i]
+         if spatial_shape is not None:
+             position_embeddings = RepCPE(
+                 in_channels=config.embed_dims[i],
+                 embed_dim=config.embed_dims[i],
+                 spatial_shape=spatial_shape,
+             )
+             network.append(position_embeddings)
+
+         stage = basic_blocks(
+             config.embed_dims[i],
+             i,
+             config.layers,
+             token_mixer_type=config.token_mixers[i],
+             kernel_size=config.repmixer_kernel_size,
+             mlp_ratio=config.mlp_ratios[i],
+             norm_layer=LayerNormChannel,
+         )
+         network.append(stage)
+
+         if i >= len(config.layers) - 1:
+             break
+
+         # Patch merging/downsampling between stages.
+         if config.downsamples[i] or config.embed_dims[i] != config.embed_dims[i + 1]:
+             network.append(
+                 PatchEmbed(
+                     patch_size=config.down_patch_size,
+                     stride=config.down_stride,
+                     in_channels=config.embed_dims[i],
+                     embed_dim=config.embed_dims[i + 1],
+                 )
+             )
+     return network
+
+
+ class SEBlock(nn.Module):
+     """Squeeze and Excite module.
+
+     MLX implementation of `Squeeze-and-Excitation Networks` -
+     https://arxiv.org/pdf/1709.01507.pdf
+     """
+
+     def __init__(self, in_channels: int, rd_ratio: float = 0.0625):
+         """Construct a Squeeze and Excite Module.
+
+         Args:
+             in_channels: Number of input channels.
+             rd_ratio: Input channel reduction ratio.
+         """
+         super().__init__()
+         self.reduce = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=int(in_channels * rd_ratio),
+             kernel_size=1,
+             stride=1,
+             bias=True,
+         )
+         self.expand = nn.Conv2d(
+             in_channels=int(in_channels * rd_ratio),
+             out_channels=in_channels,
+             kernel_size=1,
+             stride=1,
+             bias=True,
+         )
+
+     def __call__(self, inputs: mx.array) -> mx.array:
+         _, h, w, c = inputs.shape
+         x = nn.AvgPool2d(kernel_size=[h, w])(inputs)
+         x = self.reduce(x)
+         x = nn.layers.relu(x)
+         x = self.expand(x)
+         x = mx.sigmoid(x)
+         x = x.reshape(-1, 1, 1, c)
+         return inputs * x
+
+
+ class MobileOneBlock(nn.Module):
+     """MobileOne building block.
+
+     This implementation only uses the inference time CNN architecture and uses FastViTHD conventions.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: int = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         use_se: bool = False,
+     ):
+         super().__init__()
+         self.groups = groups
+         self.stride = stride
+         self.padding = padding
+         self.dilation = dilation
+         self.kernel_size = kernel_size
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+
+         # Check if SE-ReLU is requested
+         if use_se:
+             self.se = SEBlock(out_channels)
+         else:
+             self.se = nn.Identity()
+
+         self.activation = nn.GELU()
+         self.reparam_conv = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+             bias=True,
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.activation(self.se(self.reparam_conv(x)))
+
+
+ class ConvolutionalStem(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         in_channels = 3
+         out_channels = config.embed_dims[0]
+         self.blocks = CallableModuleList(
+             [
+                 MobileOneBlock(
+                     in_channels=in_channels,
+                     out_channels=out_channels,
+                     kernel_size=3,
+                     stride=2,
+                     padding=1,
+                     groups=1,
+                 ),
+                 MobileOneBlock(
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     kernel_size=3,
+                     stride=2,
+                     padding=1,
+                     groups=out_channels,
+                 ),
+                 MobileOneBlock(
+                     in_channels=out_channels,
+                     out_channels=out_channels,
+                     kernel_size=1,
+                     stride=1,
+                     padding=0,
+                     groups=1,
+                 ),
+             ]
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         return self.blocks(x)
+
+
+ class FastViTHDModel(nn.Module):
+     """
+     Based on https://github.com/apple/ml-fastvlm/blob/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/model/multimodal_encoder/mobileclip/mci.py
+     Hardcoded, for now, for:
+     - FastViTHD variant
+     - Use inference_mode (i.e., modules contain the convolutional reparameterized versions of the architecture)
+     """
+
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         if config.pos_embs_shapes is None:
+             config.pos_embs_shapes = [None] * len(config.layers)
+         self.config = config
+
+         # We follow the nomenclature from mci.py
+         self.patch_embed = ConvolutionalStem(config)
+         self.network = build_fast_vit_network(config)
+         self.conv_exp = MobileOneBlock(
+             in_channels=config.embed_dims[-1],
+             out_channels=int(config.embed_dims[-1] * config.cls_ratio),
+             kernel_size=3,
+             stride=1,
+             padding=1,
+             groups=config.embed_dims[-1],
+             use_se=True,
+         )
+         self.head = nn.Linear(
+             int(config.embed_dims[-1] * config.cls_ratio), config.num_classes
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ):
+         x = self.patch_embed(x)
+
+         encoder_states = (x,) if output_hidden_states else None
+         for layer in self.network:
+             x = layer(x)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+         x = self.conv_exp(x)
+         cls_out = self.head(x)
+
+         return cls_out, x, encoder_states
+
+
+ class GlobalPool2D(nn.Module):
+     """This class implements global pooling with linear projection."""
+
+     def __init__(self, in_dim: int, out_dim: int) -> None:
+         super().__init__()
+         self.proj = mx.zeros((in_dim, out_dim))
+
+     def __call__(self, x: mx.array) -> mx.array:
+         assert (
+             x.ndim == 4
+         ), "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
+             x.shape
+         )
+
+         # [batch, in_height, in_width, in_dim] --> [batch, in_dim]
+         x = x.mean(axis=[1, 2])
+         # [batch, in_dim] x [in_dim, out_dim] --> [batch, out_dim]
+         x = x @ self.proj
+         return x
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+
+         self.model_type = config.model_type
+         if self.model_type not in ["llava_qwen2", "fastvlm"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.vision_model = FastViTHDModel(config)
+
+         # Replace projection head, same as in
+         # https://github.com/apple/ml-fastvlm/blob/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/model/multimodal_encoder/mobileclip/__init__.py#L49
+         if config.projection_dim is not None:
+             in_dim = int(config.embed_dims[-1] * config.cls_ratio)
+             self.vision_model.head = GlobalPool2D(in_dim, config.projection_dim)
+
+     def __call__(
+         self, x: mx.array, output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         return self.vision_model(x, output_hidden_states)
+
+     def sanitize(self, weights):
+         # Only transpose during conversion from transformers
+         W, C = weights[
+             "vision_tower.vision_model.patch_embed.blocks.1.reparam_conv.weight"
+         ].shape[-2:]
+         skip_transpose = W > C
+
+         def is_conv(k):
+             if skip_transpose:
+                 return False
+             if ".reparam_conv.weight" in k:
+                 return True
+             if ".conv.weight" in k:
+                 return True
+             if ".fc1.weight" in k:
+                 return True
+             if ".fc2.weight" in k:
+                 return True
+             if ".lkb_reparam.weight" in k:
+                 return True
+             if ".reduce.weight" in k:
+                 return True
+             if ".expand.weight" in k:
+                 return True
+             return False
+
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if is_conv(k):
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if v.ndim == 4:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+                 else:
+                     sanitized_weights[k] = v
+             elif "layer_scale" in k and not skip_transpose:
+                 sanitized_weights[k] = v.transpose(1, 2, 0)
+             elif "num_batches_tracked" in k:
+                 # I don't think we need this
+                 continue
+             else:
+                 sanitized_weights[k] = v
+         return sanitized_weights
mlx_vlm/models/florence2/__init__.py
@@ -0,0 +1,2 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .florence2 import LanguageModel, Model, VisionModel