fount-vlm-nell-02 0.3.11 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,263 @@
+ import math
+ from typing import Dict, Optional, Union
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import interpolate
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dims: int,
+         num_heads: int,
+         qkv_bias: bool = True,
+     ):
+         super().__init__()
+
+         if (dims % num_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({dims} % {num_heads}) != 0"
+             )
+
+         self.num_heads = num_heads
+         head_dim = dims // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv_proj = nn.Linear(dims, dims * 3, bias=qkv_bias)
+         self.out_proj = nn.Linear(dims, dims, bias=True)
+
+     def __call__(self, x, mask=None):
+         qkv = self.qkv_proj(x)
+         queries, keys, values = mx.split(qkv, 3, axis=-1)
+
+         num_heads = self.num_heads
+         B, L, D = queries.shape
+         _, S, _ = keys.shape
+         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+         return self.out_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: Union[VisionConfig, Dict], bias: bool = True):
+         super().__init__()
+         self.activation_fn = nn.GELU()
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=bias)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=bias)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.activation_fn(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = Attention(
+             config.hidden_size, config.num_attention_heads, qkv_bias=True
+         )
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = MLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         y = self.layer_norm1(x)
+         y = self.self_attn(y, mask)
+         x = x + y
+         y = self.layer_norm2(x)
+         y = self.mlp(y)
+         return x + y
+
+
+ class VisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = 224
+         self.patch_size = config.patch_size
+
+         self.class_embedding = mx.random.normal((self.embed_dim,))
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             bias=False,
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches + 1
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     def _get_abs_pos(self, abs_pos, tgt_size):
+         """
+         Resize absolute positional embeddings
+
+         Args:
+             abs_pos: Tensor of shape (L, C) - absolute position embeddings
+             tgt_size: int - target size M
+
+         Returns:
+             Tensor of shape (M, C) - resized position embeddings
+         """
+         dim = abs_pos.shape[-1]
+         abs_pos_new = mx.squeeze(abs_pos, axis=0)
+         cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
+         src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
+         tgt_size_2d = int(math.sqrt(tgt_size))
+         dtype = abs_pos.dtype
+
+         if src_size != tgt_size_2d:
+             # Reshape to (1, src_size, src_size, dim) then transpose to (1, dim, src_size, src_size)
+             old_pos_embed = mx.reshape(old_pos_embed, (1, src_size, src_size, dim))
+             old_pos_embed = mx.transpose(old_pos_embed, (0, 3, 1, 2))
+             old_pos_embed = old_pos_embed.astype(mx.float32)
+
+             new_pos_embed = interpolate(old_pos_embed, (tgt_size_2d, tgt_size_2d))
+
+             new_pos_embed = new_pos_embed.astype(dtype)
+             new_pos_embed = mx.transpose(new_pos_embed, (0, 2, 3, 1))
+             new_pos_embed = mx.reshape(new_pos_embed, (tgt_size_2d * tgt_size_2d, dim))
+             vision_pos_embed = mx.concatenate([cls_token, new_pos_embed], axis=0)
+             vision_pos_embed = mx.reshape(
+                 vision_pos_embed, (1, tgt_size_2d * tgt_size_2d + 1, dim)
+             )
+             return vision_pos_embed
+         else:
+             return abs_pos
+
+     def __call__(
+         self, x: mx.array, patch_embeds: Optional[mx.array] = None
+     ) -> mx.array:
+         batch_size, height, width, _ = x.shape
+         target_dtype = self.position_embedding.weight.dtype
+
+         if patch_embeds is not None:
+             patch_embeddings = patch_embeds
+         else:
+             patch_embeddings = self.patch_embedding(x)
+
+         # Flatten patch embeddings properly
+         patch_embeds = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+
+         # Broadcast class embedding
+         class_embeds = mx.broadcast_to(
+             self.class_embedding, (batch_size, 1, self.embed_dim)
+         ).astype(target_dtype)
+
+         # Concatenate class and patch embeddings
+         embeddings = mx.concatenate([class_embeds, patch_embeds], axis=1)
+
+         # Create position IDs
+         position_ids = mx.array(np.arange(self.num_positions)[None, :])
+
+         # Add positional embeddings
+         embeddings = embeddings + self._get_abs_pos(
+             self.position_embedding(position_ids), embeddings.shape[1]
+         ).astype(target_dtype)
+
+         return embeddings
+
+
+ class NoTPTransformer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.num_layers = config.layers
+         self.layers = [EncoderLayer(config) for _ in range(config.layers)]
+
+     def __call__(
+         self,
+         x: mx.array,
+     ) -> mx.array:
+         for l in self.layers:
+             x = l(x, mask=None)
+         return x
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+
+         self.model_type = config.model_type
+         self.config = config
+         if self.model_type != "vision":
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.embeddings = VisionEmbeddings(config)
+         self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+         self.transformer = NoTPTransformer(config)
+
+     def __call__(
+         self,
+         x: mx.array,
+         patch_embeds: mx.array = None,
+     ) -> mx.array:
+         x = self.embeddings(x, patch_embeds)
+         x = self.pre_layrnorm(x)
+         return self.transformer(x)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         weight_keys = {
+             "neck.0.weight",
+             "neck.2.weight",
+             "neck_hd.0.weight",
+             "neck_hd.2.weight",
+             "sam_model.net_2.weight",
+             "sam_model.net_3.weight",
+             "downsamples.0.weight",
+             "downsamples.1.weight",
+             "patch_embed.proj.weight",
+             "embeddings.patch_embedding.weight",
+         }
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+
+             elif ".".join(k.split(".")[-3:]) in weight_keys:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight to be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
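
Note: the sanitize() hook above exists because PyTorch stores Conv2d weights as [out_channels, in_channels, kH, kW] while MLX expects [out_channels, kH, kW, in_channels], and check_array_shape() is the heuristic that guesses which layout a 4-D tensor is already in. A minimal sketch of that decision follows; the 1152x3x14x14 patch-embedding shape is an illustrative assumption, not a value read from the checkpoint.

import mlx.core as mx

def needs_transpose(w):
    # Mirrors check_array_shape(): a 4-D tensor whose middle axes form a square
    # kernel no larger than out_channels is assumed to already be in MLX layout.
    if len(w.shape) != 4:
        return False
    out_channels, kh, kw, _ = w.shape
    return not (out_channels >= kh and out_channels >= kw and kh == kw)

torch_layout = mx.zeros((1152, 3, 14, 14))       # [O, I, kH, kW] as saved by PyTorch
mlx_layout = torch_layout.transpose(0, 2, 3, 1)  # -> [O, kH, kW, I]

print(needs_transpose(torch_layout))  # True:  sanitize() would permute the axes
print(needs_transpose(mlx_layout))    # False: already in MLX layout, kept as-is
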
@@ -0,0 +1,12 @@
+ # Import shared LanguageModel from deepseekocr
+ from ..deepseekocr.language import LanguageModel
+ from .config import (
+     MLPConfig,
+     ModelConfig,
+     ProjectorConfig,
+     Qwen2EncoderConfig,
+     TextConfig,
+     VisionConfig,
+ )
+ from .deepseekocr_2 import DeepseekOCR2Processor, Model
+ from .vision import Qwen2Decoder2Encoder, VisionModel
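
For orientation only, a hypothetical import sketch: the re-exports above let downstream code pull everything from the subpackage in one place, with LanguageModel coming from the sibling deepseekocr package rather than being duplicated.

# Hypothetical usage of the re-exports above (illustrative, not from the package docs).
from mlx_vlm.models.deepseekocr_2 import (
    LanguageModel,  # defined in mlx_vlm.models.deepseekocr.language and shared here
    Model,
    ModelConfig,
    VisionModel,
)

print(LanguageModel.__module__)  # "mlx_vlm.models.deepseekocr.language"
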
@@ -0,0 +1,216 @@
+ import inspect
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Tuple, Union
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str = "deepseek_v2"
+     vocab_size: int = 102400
+     hidden_size: int = 1280
+     intermediate_size: int = 6848
+     moe_intermediate_size: int = 896
+     num_hidden_layers: int = 30
+     num_attention_heads: int = 32
+     num_key_value_heads: int = 32
+     n_shared_experts: Optional[int] = 2
+     n_routed_experts: Optional[int] = 64
+     routed_scaling_factor: float = 1.0
+     kv_lora_rank: int = 512
+     q_lora_rank: int = 1536
+     qk_rope_head_dim: int = 0
+     v_head_dim: int = 128
+     qk_nope_head_dim: int = 0
+     topk_method: str = "greedy"
+     n_group: Optional[int] = 1
+     topk_group: Optional[int] = 1
+     num_experts_per_tok: Optional[int] = 6
+     moe_layer_freq: int = 1
+     first_k_dense_replace: int = 0
+     max_position_embeddings: int = 2048
+     rms_norm_eps: float = 1e-6
+     rope_theta: float = 10000.0
+     rope_traditional: bool = False
+     rope_scaling: Dict = None
+     attention_bias: bool = False
+     scoring_func: str = "softmax"
+     attn_type: str = "DeepseekV2Attention"
+
+     def __post_init__(self):
+         if self.qk_nope_head_dim == 0:
+             self.attn_type = "LlamaAttention"
+
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+
+ @dataclass
+ class Qwen2EncoderConfig(BaseModelConfig):
+     """Configuration for the Qwen2 decoder-as-encoder in the vision model."""
+
+     dim: int = 896
+     layers: int = 24
+     heads: int = 14
+     kv_heads: int = 2
+     intermediate_size: int = 4864
+     rms_norm_eps: float = 1e-6
+     rope_theta: float = 1000000.0
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str
+     layers: int = 24
+     width: int = 1152
+     hidden_size: int = 896
+     intermediate_size: int = 4096
+     num_attention_heads: int = 16
+     image_size: int = 1024
+     patch_size: int = 14
+     num_channels: int = 3
+     layer_norm_eps: float = 1e-6
+     mlp_ratio: float = 3.7362
+     cls: str = None
+     params: dict = None
+
+     @classmethod
+     def from_dict(cls, params):
+         # Parse width configuration for SAM and Qwen2
+         width = params.get("width", {})
+         qwen2_config = width.get("qwen2-0-5b", {})
+         sam_config = width.get("sam_vit_b", {})
+
+         # Build qwen2 params for VisionModel
+         qwen2_params = {
+             "dim": qwen2_config.get("dim", 896),
+             "layers": 24,  # Default for Qwen2 encoder
+             "heads": 14,
+             "kv_heads": 2,
+             "intermediate_size": 4864,
+             "rms_norm_eps": 1e-6,
+             "rope_theta": 1000000.0,
+         }
+
+         # Update params to include qwen2 config
+         if params.get("params") is None:
+             params["params"] = {}
+         params["params"]["qwen2"] = qwen2_params
+         params["params"]["sam"] = sam_config
+
+         # Set hidden_size from qwen2 dim
+         if "hidden_size" not in params:
+             params["hidden_size"] = qwen2_config.get("dim", 896)
+
+         return super().from_dict(params)
+
+
+ @dataclass
+ class MLPConfig(BaseModelConfig):
+     hidden_size: int
+     intermediate_size: int
+     hidden_act: str = "gelu"
+
+
+ @dataclass
+ class ProjectorConfig(BaseModelConfig):
+     projector_type: str = "linear"
+     input_dim: int = 2048
+     n_embed: int = 1280
+     depth: int = 2
+     mlp_ratio: int = 1
+     downsample_ratio: int = 2
+     token_pooling: bool = False
+
+
+ @dataclass
+ class SAMViTConfig(BaseModelConfig):
+     image_size: Union[Tuple[int, int], int] = 1024
+     width: int = 768
+     layers: int = 12
+     heads: int = 12
+     patch_size: int = 16
+     window_size: int = 14
+     prompt_embed_dim: int = 256
+     global_attn_indexes: Union[List[int], Tuple[int]] = (2, 5, 8, 11)
+     downsample_channels: Union[List[int], Tuple[int]] = (512, 1024)
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     projector_config: ProjectorConfig
+     model_type: str
+     ignore_index: int = -100
+     image_token_index: int = 128815
+     vision_feature_select_strategy: str = "default"
+     select_layer: int = -1
+     pad_id: int = 100001
+     num_image_tokens: int = 576
+     vocab_size: int = 32000
+     tile_tag: str = "2D"
+     global_view_pos: str = "head"
+     eos_token_id: Optional[List[int]] = None
+     quantization: Optional[Dict] = None
+
+     @classmethod
+     def from_dict(cls, params):
+         if "language_config" in params:
+             params["text_config"] = params["language_config"]
+             del params["language_config"]
+
+         return cls(
+             text_config=TextConfig.from_dict(params["text_config"]),
+             vision_config=VisionConfig.from_dict(params["vision_config"]),
+             projector_config=ProjectorConfig.from_dict(params["projector_config"]),
+             **{
+                 k: v
+                 for k, v in params.items()
+                 if k in inspect.signature(cls).parameters
+                 and k not in ["text_config", "vision_config", "projector_config"]
+             },
+         )
+
+
+ @dataclass
+ class Conversation:
+     """A class that represents a conversation."""
+
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: int
+     sep: str
+     sep2: str
+     version: str = "Unknown"
+
+
+ @dataclass
+ class VLChatProcessorOutput:
+     """
+     Output of the VL chat processor.
+     """
+
+     sft_format: str
+     input_ids: List[int]
+     pixel_values: List
+     num_image_tokens: List[int]
+     image_grid_thw: List[List[int]]
+     image_sizes: Optional[List[List[int]]] = None
+     videos: Optional[List] = None
+     aspect_ratio_ids: Optional[List[int]] = None
+     aspect_ratio_mask: Optional[List[List[int]]] = None
+     cross_attention_mask: Optional[List[List[List[int]]]] = None
+     attention_mask: Optional[List[int]] = None
+     labels: Optional[List[int]] = None
+
+
+ @dataclass
+ class BatchCollateOutput:
+     input_ids: List
+     labels: List
+     attention_mask: List
+     pixel_values: List
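
To make the mapping above concrete, here is a hedged sketch of how ModelConfig.from_dict() would consume a minimal, invented config dict. It assumes BaseModelConfig.from_dict() simply filters keys against the dataclass signature, as the other mlx_vlm config classes do; the key names mirror the code above, but the values are made up for illustration.

from mlx_vlm.models.deepseekocr_2.config import ModelConfig

# Minimal, invented config dict; real checkpoints carry many more keys.
raw = {
    "model_type": "deepseekocr_2",  # illustrative value
    "language_config": {"model_type": "deepseek_v2"},  # renamed to text_config by from_dict
    "vision_config": {
        "model_type": "vision",
        "width": {"qwen2-0-5b": {"dim": 896}, "sam_vit_b": {}},
    },
    "projector_config": {"projector_type": "linear"},
}

config = ModelConfig.from_dict(raw)
print(config.text_config.model_type)                # "deepseek_v2"
print(config.vision_config.hidden_size)             # 896, filled in from the qwen2 "dim"
print(config.vision_config.params["qwen2"]["dim"])  # 896, injected by VisionConfig.from_dict
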