fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
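Despite the distribution name, the module paths above (and the single-line top_level.txt) indicate that this wheel installs one top-level package, mlx_vlm, plus console entry points (entry_points.txt). The snippet below is a minimal usage sketch, assuming the bundled code preserves the upstream mlx_vlm load/generate interface; the exact signatures should be checked against mlx_vlm/generate.py and mlx_vlm/utils.py in this wheel, and the checkpoint path is a placeholder, not a real reference.

# Minimal sketch, not part of the package diff. Assumes the upstream
# mlx_vlm API (load, generate) is unchanged in this redistribution.
from mlx_vlm import load, generate

# "path/to/vlm-checkpoint" is a placeholder; substitute a real local path or hub id.
model, processor = load("path/to/vlm-checkpoint")
output = generate(model, processor, "Describe this image.", "example.jpg")
print(output)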
@@ -0,0 +1,322 @@
+ from functools import partial
+ from math import sqrt
+ from typing import Dict, Optional, Union
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from .config import MLPConfig, VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class AttentionPoolLatent(nn.Module):
+     """Attention pooling w/ latent query"""
+
+     def __init__(
+         self,
+         in_features: int,
+         out_features: int = None,
+         embed_dim: int = None,
+         num_heads: int = 8,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = True,
+         qk_norm: bool = False,
+         latent_len: int = 1,
+         latent_dim: int = None,
+         pos_embed: str = "",
+         pool_type: str = "token",
+         norm_layer: Optional[nn.Module] = None,
+         drop: float = 0.0,
+     ):
+         super().__init__()
+
+         embed_dim = embed_dim or in_features
+         out_features = out_features or in_features
+         assert embed_dim % num_heads == 0
+         self.num_heads = num_heads
+         self.head_dim = embed_dim // num_heads
+         self.scale = self.head_dim**-0.5
+         self.pool = pool_type
+
+         self.latent_dim = latent_dim or embed_dim
+         self.latent_len = latent_len
+         self.latent = mx.zeros((self.latent_len, embed_dim))[None, :]
+
+         self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
+         self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
+         self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
+         self.proj = nn.Linear(embed_dim, embed_dim)
+         self.proj_drop = nn.Dropout(drop)
+
+         if pos_embed == "abs":
+             spatial_len = self.feat_size
+             self.pos_embed = mx.zeros((spatial_len, in_features))
+         else:
+             self.pos_embed = None
+
+         self.norm = nn.LayerNorm(out_features)
+         config = MLPConfig(
+             width=embed_dim, intermediate_size=int(embed_dim * mlp_ratio)
+         )
+         self.mlp = MLP(config)
+
+     def __call__(self, x: mx.array):
+         B, N, C = x.shape
+
+         if self.pos_embed is not None:
+             x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
+
+         q_latent = mx.array(self.latent)
+
+         q = (
+             self.q(q_latent)
+             .reshape(B, self.latent_len, self.num_heads, self.head_dim)
+             .transpose(0, 2, 1, 3)
+         )
+
+         kv = (
+             self.kv(x)
+             .reshape(B, N, 2, self.num_heads, self.head_dim)
+             .transpose(2, 0, 3, 1, 4)
+         )
+         k, v = mx.split(kv, 2, axis=0)
+
+         q, k = self.q_norm(q), self.k_norm(k)
+
+         x = mx.fast.scaled_dot_product_attention(
+             q, k[0], v[0], scale=(1.0 / sqrt(q.shape[-1])), mask=None
+         )
+
+         x = x.transpose(0, 2, 1, 3).reshape(B, self.latent_len, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+
+         x = x + self.mlp(self.norm(x))
+
+         # optional pool if latent seq_len > 1 and pooled output is desired
+         if self.pool == "token":
+             x = x[:, 0]
+         elif self.pool == "avg":
+             x = x.mean(1)
+         return x
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dims: int,
+         num_heads: int,
+         qkv_bias: bool = True,
+     ):
+         super().__init__()
+
+         if (dims % num_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({dims} % {num_heads}) != 0"
+             )
+
+         self.num_heads = num_heads
+         head_dim = dims // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dims, dims * 3, bias=qkv_bias)
+         self.proj = nn.Linear(dims, dims, bias=True)
+
+     def __call__(self, x, mask=None):
+         qkv = self.qkv(x)
+         queries, keys, values = mx.split(qkv, 3, axis=-1)
+
+         num_heads = self.num_heads
+         B, L, D = queries.shape
+         _, S, _ = keys.shape
+         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+         return self.proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: Union[VisionConfig, Dict], bias: bool = True):
+         super().__init__()
+         self.activation_fn = nn.GELU(approx="precise")
+         self.fc1 = nn.Linear(config.width, config.intermediate_size, bias=bias)
+         self.fc2 = nn.Linear(config.intermediate_size, config.width, bias=bias)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.activation_fn(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embed_dim = config.width
+         self.attn = Attention(config.width, config.num_attention_heads, qkv_bias=True)
+         self.norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = MLP(config)
+         self.norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         y = self.norm1(x)
+         y = self.attn(y, mask)
+         x = x + y
+         y = self.norm2(x)
+         y = self.mlp(y)
+         return x + y
+
+
+ class VisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig, norm_layer: bool = False):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.width
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.proj = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+
+         self.norm = (
+             nn.LayerNorm(config.width, eps=config.layer_norm_eps)
+             if norm_layer
+             else nn.Identity()
+         )
+
+     def __call__(self, x: mx.array) -> mx.array:
+         patch_embeddings = self.proj(x)
+         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+         return self.norm(patch_embeddings)
+
+
+ class SigLipVisionModel(nn.Module):
+     def __init__(
+         self,
+         config: VisionConfig,
+         ignore_head: bool,
+         pre_norm: bool = False,
+         no_embed_class: bool = True,
+     ):
+         super().__init__()
+         self.num_prefix_tokens = 1
+         self.no_embed_class = False
+         self.dynamic_img_size = False
+         self.ignore_head = ignore_head
+         self.cls_token = None
+         self.reg_token = None
+         self.patch_embed = VisionEmbeddings(config)
+         self.norm_pre = nn.LayerNorm(config.width) if pre_norm else nn.Identity()
+         self.blocks = [EncoderLayer(config) for _ in range(config.layers)]
+         self.norm = nn.LayerNorm(config.width)
+         num_patches = self.patch_embed.num_patches
+         embed_len = (
+             num_patches if no_embed_class else num_patches + self.num_prefix_tokens
+         )
+         self.pos_embed = mx.random.normal((embed_len, config.width))[None, :]
+
+         norm_layer = partial(nn.LayerNorm, eps=1e-5)
+         self.attn_pool = AttentionPoolLatent(
+             config.width,
+             num_heads=config.num_attention_heads,
+             norm_layer=norm_layer,
+             mlp_ratio=config.mlp_ratio,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+         x = self.patch_embed(x)
+         x += self.pos_embed
+         x = self.norm_pre(x)
+
+         encoder_states = (x,) if output_hidden_states else None
+         for l in self.blocks:
+             x = l(x, mask=None)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+         pooler_output = self.norm(x)
+
+         if not self.ignore_head:
+             pooler_output = self.attn_pool(pooler_output)
+         return pooler_output, x, encoder_states
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig, ignore_head: bool = True):
+         super().__init__()
+
+         self.model_type = config.model_type
+         self.config = config
+         if self.model_type != "vision":
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.vision_tower = SigLipVisionModel(config, ignore_head)
+
+     def __call__(
+         self, x: mx.array, output_hidden_states: Optional[bool] = None
+     ) -> mx.array:
+         return self.vision_tower(x, output_hidden_states)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         weight_keys = {
+             "neck.0.weight",
+             "neck.2.weight",
+             "neck_hd.0.weight",
+             "neck_hd.2.weight",
+             "downsamples.0.weight",
+             "downsamples.1.weight",
+             "patch_embed.proj.weight",
+         }
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+
+             elif ".".join(k.split(".")[-3:]) in weight_keys:
+                 # PyTorch conv2d weight tensors have shape:
+                 #   [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight to be of shape:
+                 #   [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
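The sanitize step at the end of the hunk above converts convolution weights from the PyTorch layout to the layout MLX's Conv2d expects. Below is a tiny standalone illustration of that same transpose, using the VisionConfig defaults that appear later in this diff (width=1152, num_channels=3, patch_size=14); it is not part of the package.

# Not part of the package: demonstrates the layout conversion used in sanitize().
import mlx.core as mx

w_pt = mx.zeros((1152, 3, 14, 14))   # PyTorch conv2d weight: [out_channels, in_channels, kH, kW]
w_mlx = w_pt.transpose(0, 2, 3, 1)   # MLX conv2d weight: [out_channels, kH, kW, in_channels]
print(w_mlx.shape)                   # (1152, 14, 14, 3)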
@@ -0,0 +1,2 @@
+ from .config import MLPConfig, ModelConfig, ProjectorConfig, TextConfig, VisionConfig
+ from .deepseekocr import DeepseekOCRProcessor, LanguageModel, Model, VisionModel
@@ -0,0 +1,173 @@
+ import inspect
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Union
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str = "deepseek_v2"
+     vocab_size: int = 102400
+     hidden_size: int = 1280
+     intermediate_size: int = 6848
+     moe_intermediate_size: int = 896
+     num_hidden_layers: int = 30
+     num_attention_heads: int = 32
+     num_key_value_heads: int = 32
+     n_shared_experts: Optional[int] = 2
+     n_routed_experts: Optional[int] = 64
+     routed_scaling_factor: float = 1.0
+     kv_lora_rank: int = 512
+     q_lora_rank: int = 1536
+     qk_rope_head_dim: int = 0
+     v_head_dim: int = 128
+     qk_nope_head_dim: int = 0
+     topk_method: str = "greedy"
+     n_group: Optional[int] = 1
+     topk_group: Optional[int] = 1
+     num_experts_per_tok: Optional[int] = 6
+     moe_layer_freq: int = 1
+     first_k_dense_replace: int = 0
+     max_position_embeddings: int = 2048
+     rms_norm_eps: float = 1e-6
+     rope_theta: float = 10000.0
+     rope_traditional: bool = False
+     rope_scaling: Dict = None
+     attention_bias: bool = False
+     scoring_func: str = "softmax"
+     attn_type: str = "DeepseekV2Attention"
+
+     def __post_init__(self):
+         if self.qk_nope_head_dim == 0:
+             self.attn_type = "LlamaAttention"
+
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str
+     layers: int = 24
+     width: int = 1152
+     hidden_size: int = 1024
+     intermediate_size: int = 4096
+     num_attention_heads: int = 16
+     image_size: int = 224
+     patch_size: int = 14
+     num_channels: int = 3
+     layer_norm_eps: float = 1e-6
+     mlp_ratio: float = 3.7362
+     cls: str = None
+     params: dict = None
+
+
+ @dataclass
+ class MLPConfig(BaseModelConfig):
+     hidden_size: int
+     intermediate_size: int
+     hidden_act: str = "gelu"
+
+
+ @dataclass
+ class ProjectorConfig(BaseModelConfig):
+     projector_type: str = "linear"
+     input_dim: int = 2048
+     n_embed: int = 1280
+     depth: int = 2
+     mlp_ratio: int = 1
+     downsample_ratio: int = 2
+     token_pooling: bool = False
+
+
+ @dataclass
+ class SAMViTConfig(BaseModelConfig):
+     image_size: Union[Tuple[int, int], int] = 1024
+     width: int = 768
+     layers: int = 12
+     heads: int = 12
+     patch_size: int = 16
+     window_size: int = 14
+     prompt_embed_dim: int = 256
+     global_attn_indexes: Union[List[int], Tuple[int]] = (2, 5, 8, 11)
+     downsample_channels: Union[List[int], Tuple[int]] = (512, 1024)
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     projector_config: ProjectorConfig
+     model_type: str
+     ignore_index: int = -100
+     image_token_index: int = 128815
+     vision_feature_select_strategy: str = "default"
+     select_layer: int = -1
+     pad_id: int = 100001
+     num_image_tokens: int = 576
+     vocab_size: int = 32000
+     tile_tag: str = "2D"
+     global_view_pos: str = "head"
+     eos_token_id: Optional[List[int]] = None
+     quantization: Optional[Dict] = None
+
+     @classmethod
+     def from_dict(cls, params):
+         if "language_config" in params:
+             params["text_config"] = params["language_config"]
+             del params["language_config"]
+
+         return cls(
+             text_config=TextConfig.from_dict(params["text_config"]),
+             vision_config=VisionConfig.from_dict(params["vision_config"]),
+             projector_config=ProjectorConfig.from_dict(params["projector_config"]),
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+                 and k not in ["text_config", "vision_config", "projector_config"]
+             },
+         )
+
+
+ @dataclass
+ class Conversation:
+     """A class that represents a conversation."""
+
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: int
+     sep: str
+     sep2: str
+     version: str = "Unknown"
+
+
+ @dataclass
+ class VLChatProcessorOutput:
+     """
+     Output of the VL chat processor.
+     """
+
+     sft_format: str
+     input_ids: List[int]
+     pixel_values: List
+     num_image_tokens: List[int]
+     image_grid_thw: List[List[int]]
+     image_sizes: Optional[List[List[int]]] = None
+     videos: Optional[List] = None
+     aspect_ratio_ids: Optional[List[int]] = None
+     aspect_ratio_mask: Optional[List[List[int]]] = None
+     cross_attention_mask: Optional[List[List[List[int]]]] = None
+     attention_mask: Optional[List[int]] = None
+     labels: Optional[List[int]] = None
+
+
+ @dataclass
+ class BatchCollateOutput:
+     input_ids: List
+     labels: List
+     attention_mask: List
+     pixel_values: List
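For reference, the ModelConfig.from_dict helper above accepts either a text_config key or a legacy language_config key and remaps the latter. The sketch below shows building a config from a parsed config.json-style dict, assuming this hunk corresponds to mlx_vlm/models/deepseekocr/config.py (the +173 entry in the file list) and that BaseModelConfig.from_dict filters unknown keys; the module path and field values are assumptions, not taken from the package.

# Illustrative sketch only; module path and dict values are assumed.
from mlx_vlm.models.deepseekocr.config import ModelConfig

params = {
    "model_type": "deepseek-ocr",              # placeholder value
    "language_config": {"hidden_size": 1280},  # legacy key, remapped to text_config
    "vision_config": {"model_type": "vision"},
    "projector_config": {"projector_type": "linear"},
}
config = ModelConfig.from_dict(params)
print(config.text_config.hidden_size)  # 1280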