fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff reflects the contents of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
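
Before the per-file content, a minimal usage sketch for orientation. It assumes this fork keeps the upstream mlx-vlm style public API (load and generate exported from the mlx_vlm package, per mlx_vlm/__init__.py and mlx_vlm/generate.py in the listing above); the model path, image file, and keyword names below are placeholders and assumptions to verify against this version.

# Hedged sketch: names follow upstream mlx-vlm conventions; check
# mlx_vlm/__init__.py and mlx_vlm/generate.py in this wheel before relying on them.
from mlx_vlm import load, generate

model_path = "path-or-hub-id-of-a-supported-vlm"  # placeholder, not a real model id
model, processor = load(model_path)

# The `image` keyword is assumed from upstream mlx-vlm; "example.png" is a placeholder.
output = generate(model, processor, "Describe this image.", image=["example.png"])
print(output)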
mlx_vlm/models/deepseekocr/sam.py
@@ -0,0 +1,489 @@
+from typing import Optional, Tuple, Type
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+def get_abs_pos_sam(abs_pos, tgt_size):
+    """Interpolate absolute positional embeddings to target size."""
+    dtype = abs_pos.dtype
+    src_size = abs_pos.shape[1]
+
+    if src_size != tgt_size:
+        # Transpose to (B, C, H, W) for interpolation
+        old_pos_embed = abs_pos.transpose(0, 3, 1, 2)
+        old_pos_embed = old_pos_embed.astype(mx.float32)
+
+        # Bicubic interpolation
+        from ..kernels import bicubic_interpolate
+
+        new_pos_embed = bicubic_interpolate(
+            old_pos_embed, size=(tgt_size, tgt_size), antialias=True
+        ).astype(dtype)
+
+        # Transpose back to (B, H, W, C)
+        new_pos_embed = new_pos_embed.transpose(0, 2, 3, 1)
+        return new_pos_embed
+    else:
+        return abs_pos
+
+
+class MLPBlock(nn.Module):
+    """MLP block with GELU activation."""
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        mlp_dim: int,
+        act: Type[nn.Module] = nn.GELU,
+    ) -> None:
+        super().__init__()
+        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+        self.act = act()
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.lin2(self.act(self.lin1(x)))
+
+
+class Attention(nn.Module):
+    """Multi-head Attention block with relative position embeddings."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = True,
+        use_rel_pos: bool = False,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+        self.use_rel_pos = use_rel_pos
+        if self.use_rel_pos:
+            assert (
+                input_size is not None
+            ), "Input size must be provided if using relative positional encoding."
+            # Initialize relative positional embeddings
+            self.rel_pos_h = mx.zeros((2 * input_size[0] - 1, head_dim))
+            self.rel_pos_w = mx.zeros((2 * input_size[1] - 1, head_dim))
+
+    def __call__(self, x: mx.array) -> mx.array:
+        B, H, W, _ = x.shape
+
+        # QKV projection and reshape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, H * W, 3, self.num_heads, -1)
+            .transpose(2, 0, 3, 1, 4)
+        )
+
+        # Separate q, k, v
+        qkv_reshaped = qkv.reshape(3, B * self.num_heads, H * W, -1)
+        q, k, v = qkv_reshaped[0], qkv_reshaped[1], qkv_reshaped[2]
+
+        # Compute relative positional embeddings if needed
+        rel_h, rel_w = None, None
+        if self.use_rel_pos:
+            rel_h, rel_w = add_decomposed_rel_pos(
+                q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
+            )
+
+        # Reshape for attention
+        q = q.reshape(B, self.num_heads, H * W, -1)
+        k = k.reshape(B, self.num_heads, H * W, -1)
+        v = v.reshape(B, self.num_heads, H * W, -1)
+
+        # Apply scaled dot product attention
+        if self.use_rel_pos:
+            rel_h = rel_h.reshape(
+                B, self.num_heads, rel_h.shape[1], rel_h.shape[2], rel_h.shape[3]
+            )
+            rel_w = rel_w.reshape(
+                B, self.num_heads, rel_w.shape[1], rel_w.shape[2], rel_w.shape[3]
+            )
+            attn_bias = (rel_h + rel_w).reshape(
+                B, self.num_heads, rel_h.shape[2], rel_h.shape[3] * rel_w.shape[4]
+            )
+            x = mx.fast.scaled_dot_product_attention(
+                q, k, v, scale=self.scale, mask=attn_bias
+            )
+        else:
+            x = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
+
+        # Reshape output
+        x = (
+            x.reshape(B, self.num_heads, H, W, -1)
+            .transpose(0, 2, 3, 1, 4)
+            .reshape(B, H, W, -1)
+        )
+
+        x = self.proj(x)
+        return x
+
+
+class Block(nn.Module):
+    """Transformer block with support for window attention and residual propagation."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_rel_pos: bool = False,
+        window_size: int = 0,
+        input_size: Optional[Tuple[int, int]] = None,
+    ) -> None:
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            window_size (int): Window size for window attention blocks. If it equals 0, then
+                use global attention.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim, eps=1e-6)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_rel_pos=use_rel_pos,
+            input_size=input_size if window_size == 0 else (window_size, window_size),
+        )
+
+        self.norm2 = norm_layer(dim, eps=1e-6)
+        self.mlp = MLPBlock(
+            embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
+        )
+
+        self.window_size = window_size
+
+    def __call__(self, x: mx.array) -> mx.array:
+        shortcut = x
+        x = self.norm1(x)
+
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + x
+        x = x + self.mlp(self.norm2(x))
+
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """Image to Patch Embedding."""
+
+    def __init__(
+        self,
+        kernel_size: Tuple[int, int] = (16, 16),
+        stride: Tuple[int, int] = (16, 16),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ) -> None:
+        """
+        Args:
+            kernel_size (Tuple): kernel size of the projection layer.
+            stride (Tuple): stride of the projection layer.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+        """
+        super().__init__()
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=kernel_size, stride=stride
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.proj(x)
+        return x
+
+
+class SAMEncoder(nn.Module):
+    """Vision Transformer encoder based on SAM architecture."""
+
+    def __init__(
+        self,
+        img_size: int = 1024,
+        patch_size: int = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        depth: int = 12,
+        num_heads: int = 12,
+        mlp_ratio: float = 4.0,
+        out_chans: int = 256,
+        qkv_bias: bool = True,
+        norm_layer: Type[nn.Module] = nn.LayerNorm,
+        act_layer: Type[nn.Module] = nn.GELU,
+        use_abs_pos: bool = True,
+        use_rel_pos: bool = True,
+        window_size: int = 14,
+        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
+        final_out_chans: int = 1024,
+    ) -> None:
+        """
+        Args:
+            img_size (int): Input image size.
+            patch_size (int): Patch size.
+            in_chans (int): Number of input image channels.
+            embed_dim (int): Patch embedding dimension.
+            depth (int): Depth of ViT.
+            num_heads (int): Number of attention heads in each ViT block.
+            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+            out_chans (int): Output channels for neck.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+            norm_layer (nn.Module): Normalization layer.
+            act_layer (nn.Module): Activation layer.
+            use_abs_pos (bool): If True, use absolute positional embeddings.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+            window_size (int): Window size for window attention blocks.
+            global_attn_indexes (tuple): Indexes for blocks using global attention.
+            final_out_chans (int): Final output channels after net_3 (1024 for OCR, 896 for OCR-2).
+        """
+        super().__init__()
+        self.img_size = img_size
+
+        self.patch_embed = PatchEmbed(
+            kernel_size=(patch_size, patch_size),
+            stride=(patch_size, patch_size),
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+        )
+
+        self.use_abs_pos = use_abs_pos
+        if use_abs_pos:
+            # Initialize absolute positional embedding with pretrain image size
+            self.pos_embed = mx.zeros(
+                (1, img_size // patch_size, img_size // patch_size, embed_dim)
+            )
+
+        self.blocks = []
+        for i in range(depth):
+            block = Block(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                use_rel_pos=use_rel_pos,
+                window_size=window_size if i not in global_attn_indexes else 0,
+                input_size=(img_size // patch_size, img_size // patch_size),
+            )
+            self.blocks.append(block)
+
+        # Neck layers for output processing
+        self.neck = [
+            nn.Conv2d(embed_dim, out_chans, kernel_size=1, bias=False),
+            nn.LayerNorm(out_chans, eps=1e-6),
+            nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1, bias=False),
+            nn.LayerNorm(out_chans, eps=1e-6),
+        ]
+
+        # Additional downsampling layers
+        self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
+        self.net_3 = nn.Conv2d(
+            512, final_out_chans, kernel_size=3, stride=2, padding=1, bias=False
+        )
+
+    def __call__(self, x: mx.array) -> mx.array:
+        # Patch embedding
+        x = self.patch_embed(x)
+
+        # Add positional embeddings
+        if self.use_abs_pos:
+            x = x + get_abs_pos_sam(self.pos_embed, x.shape[1])
+
+        # Apply transformer blocks
+        for blk in self.blocks:
+            x = blk(x)
+
+        # Apply neck layers
+        for n in self.neck:
+            x = n(x)
+
+        # Additional downsampling
+        x = self.net_2(x)
+        x = self.net_3(x)
+
+        return x
+
+
+# Utility functions
+
+
+def window_partition(x: mx.array, window_size: int) -> Tuple[mx.array, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed.
+
+    Args:
+        x (mx.array): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+
+    if pad_h > 0 or pad_w > 0:
+        x = mx.pad(x, [(0, 0), (0, pad_h), (0, pad_w), (0, 0)])
+
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.reshape(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
+
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(
+    windows: mx.array,
+    window_size: int,
+    pad_hw: Tuple[int, int],
+    hw: Tuple[int, int],
+) -> mx.array:
+    """
+    Window unpartition into original sequences and removing padding.
+
+    Args:
+        windows (mx.array): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+
+    x = windows.reshape(
+        B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+    )
+    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :]
+
+    return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: mx.array) -> mx.array:
+    """
+    Get relative positional embeddings according to the relative positions of
+    query and key sizes.
+
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (mx.array): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+
+    # Interpolate rel pos if needed
+    if rel_pos.shape[0] != max_rel_dist:
+        dtype = rel_pos.dtype
+        rel_pos = rel_pos.astype(mx.float32)
+        rel_pos_resized = rel_pos.reshape(1, rel_pos.shape[0], -1).transpose(0, 2, 1)
+
+        # Linear interpolation
+        scale = rel_pos_resized.shape[2] / max_rel_dist
+        indices = mx.arange(max_rel_dist, dtype=mx.float32) * scale
+        idx_floor = mx.floor(indices).astype(mx.int32)
+        idx_ceil = mx.minimum(idx_floor + 1, rel_pos_resized.shape[2] - 1)
+        weight = indices - idx_floor.astype(mx.float32)
+
+        rel_pos_resized = (
+            mx.take(rel_pos_resized, idx_floor, axis=2) * (1 - weight)
+            + mx.take(rel_pos_resized, idx_ceil, axis=2) * weight
+        ).astype(dtype)
+
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).transpose(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different
+    q_coords = mx.arange(q_size, dtype=mx.float32)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = mx.arange(k_size, dtype=mx.float32)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.astype(mx.int32)]
+
+
+def add_decomposed_rel_pos(
+    q: mx.array,
+    rel_pos_h: mx.array,
+    rel_pos_w: mx.array,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> Tuple[mx.array, mx.array]:
+    """
+    Calculate decomposed Relative Positional Embeddings.
+
+    Args:
+        q (mx.array): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (mx.array): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (mx.array): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        Tuple of (rel_h, rel_w): relative position biases for height and width.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+
+    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    B, _, dim = q.shape
+    r_q = q.reshape(B, q_h, q_w, dim)
+
+    rel_h = mx.einsum("bhwc,hkc->bhwk", r_q, Rh)
+    rel_w = mx.einsum("bhwc,wkc->bhwk", r_q, Rw)
+    rel_h = rel_h[..., None]
+    rel_w = rel_w[..., None, :]
+    rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
+    rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
+
+    return rel_h, rel_w
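
For orientation, a minimal sketch exercising the encoder and window helpers defined in this file. The import path comes from the file listing above (mlx_vlm/models/deepseekocr/sam.py); the tiny hyperparameters are chosen only to keep the example light and are not the shipped defaults (img_size=1024, embed_dim=768, depth=12, and so on).

import mlx.core as mx

# Import path taken from the file listing above; the names are defined in this diff.
from mlx_vlm.models.deepseekocr.sam import (
    SAMEncoder,
    window_partition,
    window_unpartition,
)

# Deliberately small configuration so the sketch runs quickly; not the defaults.
encoder = SAMEncoder(
    img_size=224,
    patch_size=16,
    embed_dim=64,
    depth=2,
    num_heads=2,
    window_size=7,
    global_attn_indexes=(1,),
    final_out_chans=64,
)

# MLX convolutions are channels-last, so images are NHWC.
image = mx.random.normal((1, 224, 224, 3))
features = encoder(image)
print(features.shape)  # (1, 4, 4, 64): 224/16 = 14 patches per side, then two stride-2 convs

# window_partition / window_unpartition are an exact round trip when nothing is cropped.
x = mx.random.normal((1, 14, 14, 64))
windows, pad_hw = window_partition(x, window_size=7)
restored = window_unpartition(windows, 7, pad_hw, (14, 14))
print(mx.allclose(x, restored).item())  # True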