fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
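
The layout above matches the upstream mlx-vlm project (model implementations under mlx_vlm/models, plus CLI modules such as generate.py, chat.py, server.py and convert.py, and the top-level package installs as mlx_vlm). For orientation only, here is a minimal usage sketch written in the style of upstream mlx-vlm; the load/generate/apply_chat_template calls and the model path are assumptions about this particular distribution, not taken from its own documentation.

# Minimal sketch, assuming this package exposes the usual mlx-vlm API.
# The model path and image file are placeholders.
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"  # placeholder
model, processor = load(model_path)
config = load_config(model_path)

prompt = apply_chat_template(processor, config, "Describe this image.", num_images=1)
output = generate(model, processor, prompt, image=["example.jpg"], verbose=False)
print(output)
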
mlx_vlm/models/paddleocr_vl/vision.py
@@ -0,0 +1,358 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..interpolate import bilinear_interpolate
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     out_channels, kH, KW, t = arr.shape
+
+     if t == 3:
+         return True
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
+     orig_dtype = tensor.dtype
+
+     cos = mx.cos(freqs)
+     sin = mx.sin(freqs)
+
+     cos = mx.expand_dims(cos, axis=1)  # Equivalent to unsqueeze(1)
+     cos = mx.tile(cos, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+     cos = mx.expand_dims(cos, axis=0)  # Equivalent to [None, ...]
+
+     sin = mx.expand_dims(sin, axis=1)  # Equivalent to unsqueeze(1)
+     sin = mx.tile(sin, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+     sin = mx.expand_dims(sin, axis=0)  # Equivalent to [None, ...]
+
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     return output.astype(orig_dtype)
+
+
+ class VisionRotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+
+     def __call__(self, seqlen: int) -> mx.array:
+         inv_freq = 1.0 / (
+             self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+         )
+         seq = mx.arange(seqlen.tolist(), dtype=inv_freq.dtype)
+         freqs = mx.outer(seq, inv_freq)
+         return freqs
+
+
+ class PaddleOCRVisionEmbeddings(nn.Module):
+     def __init__(
+         self,
+         patch_size: int = 14,
+         image_size: int = 384,
+         in_channels: int = 3,
+         embed_dim: int = 1152,
+     ) -> None:
+         super().__init__()
+         self.patch_size = patch_size
+         self.in_channels = in_channels
+         self.embed_dim = embed_dim
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=in_channels,
+             out_channels=embed_dim,
+             kernel_size=patch_size,
+             stride=patch_size,
+         )
+
+         num_patches = (image_size // patch_size) ** 2
+         self.position_embedding = nn.Embedding(num_patches, self.embed_dim)
+
+     def interpolate_pos_encoding(self, height: int, width: int) -> mx.array:
+         # Get the number of positions and embedding dimension
+         num_positions = self.position_embedding.weight.shape[0]
+
+         # Get all position embeddings (this will dequantize if quantized)
+         position_ids = mx.arange(num_positions)
+         patch_pos_embed = self.position_embedding(position_ids)
+         dim = patch_pos_embed.shape[-1]
+
+         # Reshape to 2D grid
+         sqrt_num_positions = int(num_positions**0.5)
+         patch_pos_embed = patch_pos_embed.reshape(
+             1, sqrt_num_positions, sqrt_num_positions, dim
+         )
+
+         # Interpolate to target size
+         patch_pos_embed = bilinear_interpolate(
+             patch_pos_embed[0],
+             height,
+             width,
+         ).astype(patch_pos_embed.dtype)
+         patch_pos_embed = patch_pos_embed.reshape(-1, dim)
+         return patch_pos_embed
+
+     def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> mx.array:
+         batch_size, squence_len, channel, patch_size, patch_size = hidden_states.shape
+         target_dtype = self.patch_embedding.weight.dtype
+         hidden_states = hidden_states.reshape(
+             batch_size * squence_len, channel, patch_size, patch_size
+         )
+         # For MLX-Conv2d
+         hidden_states = hidden_states.transpose(0, 2, 3, 1)
+         patch_embeds = self.patch_embedding(hidden_states).astype(target_dtype)
+         patch_embeds = patch_embeds.transpose(0, 3, 1, 2)
+         embeddings = patch_embeds.flatten(-2).squeeze(-1)
+         embeddings = embeddings.reshape(batch_size, squence_len, -1)
+
+         start = 0
+         embeddings = embeddings.squeeze(0)
+         tmp_embeddings = []
+         for image_grid in grid_thw:
+             t, h, w = image_grid.tolist()
+             end = start + t * h * w
+             image_embeddings = embeddings[start:end, :]
+             position_embedding = self.interpolate_pos_encoding(h, w)
+             image_embeddings = image_embeddings + position_embedding
+             tmp_embeddings.append(image_embeddings)
+             start = end
+         embeddings = mx.concatenate(tmp_embeddings, axis=0)
+
+         return embeddings
+
+
+ class PaddleOCRProjector(nn.Module):
+     def __init__(self, dim, context_dim, spatial_merge_size) -> None:
+         super().__init__()
+
+         hidden_size = dim * (spatial_merge_size**2)
+         self.spatial_merge_size = spatial_merge_size
+         self.pre_norm = nn.LayerNorm(dim, eps=1e-6)
+         self.linear_1 = nn.Linear(hidden_size, hidden_size, bias=True)
+         self.act = nn.GELU()
+         self.linear_2 = nn.Linear(hidden_size, context_dim, bias=True)
+
+     def __call__(self, x: mx.array, grid_thw: mx.array) -> mx.array:
+         x_chunks = x.split(grid_thw.prod(axis=1).tolist(), axis=0)
+
+         processed_features = []
+         for x, image_grid in zip(x_chunks, grid_thw):
+             x = self.pre_norm(x)
+             t, h, w = image_grid.tolist()
+             d = x.shape[-1]
+             h_block = h // self.spatial_merge_size
+             w_block = w // self.spatial_merge_size
+
+             x = x.reshape(
+                 t, h_block, self.spatial_merge_size, w_block, self.spatial_merge_size, d
+             )
+             x = x.transpose(0, 1, 3, 2, 4, 5)
+             x = x.reshape(
+                 t * h_block * w_block,
+                 self.spatial_merge_size * self.spatial_merge_size * d,
+             )
+
+             hidden_states = self.linear_1(x)
+             hidden_states = self.act(hidden_states)
+             hidden_states = self.linear_2(hidden_states)
+             processed_features.append(hidden_states)
+
+         return mx.concatenate(processed_features, axis=0)
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim: int, num_heads: int = 16) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+         self.qkv = nn.Linear(dim, dim * 3, bias=True)
+         self.out_proj = nn.Linear(dim, dim)
+
+     def __call__(
+         self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
+     ) -> mx.array:
+         seq_length = x.shape[0]
+         qkv = (
+             self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
+         )
+         q, k, v = mx.split(qkv, 3)
+
+         q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
+         k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
+
+         attention_mask = mx.ones((1, seq_length, seq_length), dtype=x.dtype)
+
+         for i in range(1, len(cu_seqlens)):
+             start = int(cu_seqlens[i - 1])
+             end = int(cu_seqlens[i])
+             attention_mask[start:end, start:end] = 0
+
+         q = q.transpose(0, 2, 1, 3)
+         k = k.transpose(0, 2, 1, 3)
+         v = v.transpose(0, 2, 1, 3)
+
+         output = mx.fast.scaled_dot_product_attention(
+             q, k, v, scale=self.scale, mask=attention_mask
+         )
+         output = output.transpose(0, 2, 1, 3)
+         output = output.reshape(seq_length, -1)
+         return self.out_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.activation_fn = nn.GELU(approx="precise")
+         self.fc1 = nn.Linear(dim, hidden_dim)
+         self.fc2 = nn.Linear(hidden_dim, dim)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.activation_fn(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class PaddleOCRVisionEncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+         self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+
+         self.self_attn = Attention(
+             dim=config.hidden_size, num_heads=config.num_attention_heads
+         )
+         self.mlp = MLP(dim=config.hidden_size, hidden_dim=config.intermediate_size)
+
+     def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
+         hidden_states = hidden_states + self.self_attn(
+             self.layer_norm1(hidden_states),
+             cu_seqlens=cu_seqlens,
+             rotary_pos_emb=rotary_pos_emb,
+         )
+         hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))
+         return hidden_states
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         if self.model_type != "paddleocr_vl":
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+
+         self.embeddings = PaddleOCRVisionEmbeddings(
+             patch_size=config.patch_size,
+             image_size=config.image_size,
+             in_channels=config.num_channels,
+             embed_dim=config.hidden_size,
+         )
+
+         head_dim = config.hidden_size // config.num_attention_heads
+         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+         self.layers = [
+             PaddleOCRVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
+         ]
+         self.post_layernorm = nn.LayerNorm(
+             config.hidden_size, eps=config.layer_norm_eps
+         )
+         self.projector = PaddleOCRProjector(
+             dim=config.hidden_size,
+             context_dim=1024,
+             spatial_merge_size=config.spatial_merge_size,
+         )
+
+     def rot_pos_emb(self, grid_thw):
+         pos_ids = []
+
+         split_hids = []
+         split_wids = []
+         for t, h, w in grid_thw:
+             image_pids = mx.arange(int(t * h * w)) % (h * w)
+             sample_hids = image_pids // w
+             sample_wids = image_pids % w
+             split_hids.append(sample_hids)
+             split_wids.append(sample_wids)
+
+         height_position_ids = mx.concatenate(split_hids, axis=0)
+         width_position_ids = mx.concatenate(split_wids, axis=0)
+
+         pos_ids = mx.stack([height_position_ids, width_position_ids], axis=-1)
+         max_grid_size = mx.max(grid_thw[:, 1:])
+         rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+         rotary_pos_emb_full = rotary_pos_emb_full[pos_ids]
+
+         return rotary_pos_emb_full.reshape(pos_ids.shape[0], -1)
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         grid_thw: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+         hidden_states = self.embeddings(hidden_states, grid_thw)
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+         # Assuming grid_thw has shape (batch_size, 3)
+         batch_size = grid_thw.shape[0]
+
+         # Calculate cu_seqlens for each item in the batch
+         cu_seqlens = []
+         for i in range(batch_size):
+             seq_len = grid_thw[i, 1] * grid_thw[i, 2]
+             cu_seqlens.append(mx.repeat(seq_len, grid_thw[i, 0]))
+
+         # Concatenate the cu_seqlens for all items in the batch
+         cu_seqlens = mx.concatenate(cu_seqlens)
+
+         cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32), axis=0)
+         cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)
+
+         encoder_states = (hidden_states,) if output_hidden_states else None
+         for layer in self.layers:
+             hidden_states = layer(
+                 hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+             )
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+
+         hidden_states = self.post_layernorm(hidden_states)
+         hidden_states = self.projector(hidden_states, grid_thw)
+         return hidden_states
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             elif "patch_embedding.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
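
The sanitize/check_array_shape pair at the end of this file converts PyTorch-layout convolution weights into the layout MLX expects, as the inline comments describe. A small stand-alone sketch of that axis move (dummy array, shapes only):

import mlx.core as mx

# Dummy patch-embedding weight in PyTorch layout:
# [out_channels, in_channels, kH, kW] = (1152, 3, 14, 14)
w_pt = mx.zeros((1152, 3, 14, 14))

# MLX Conv2d stores weights as [out_channels, kH, kW, in_channels],
# so the channel axis is moved to the end when it is not already last.
w_mlx = w_pt.transpose(0, 2, 3, 1)
print(w_mlx.shape)  # (1152, 14, 14, 3)
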
mlx_vlm/models/paligemma/__init__.py
@@ -0,0 +1,4 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .language import LanguageModel
+ from .paligemma import Model
+ from .vision import VisionModel
mlx_vlm/models/paligemma/config.py
@@ -0,0 +1,50 @@
+ from dataclasses import dataclass, field
+ from typing import List, Optional
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: "TextConfig" = field(default_factory=lambda: TextConfig())
+     vision_config: "VisionConfig" = field(default_factory=lambda: VisionConfig())
+     model_type: str = "paligemma"
+     vocab_size: int = 257152
+     ignore_index: int = -100
+     image_token_index: int = 257152
+     hidden_size: int = 2048
+     pad_token_id: int = 0
+     eos_token_id: Optional[List[int]] = None
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str = "paligemma"
+     hidden_size: int = 2048
+     num_hidden_layers: int = 18
+     intermediate_size: int = 8192
+     num_attention_heads: int = 16
+     num_key_value_heads: int = 16
+     vocab_size: int = 256000
+     head_dim: int = 256
+     rms_norm_eps: float = 1e-6
+     rope_theta: float = 10000
+     rope_traditional: bool = False
+     attn_logit_softcapping: Optional[float] = None
+     final_logit_softcapping: Optional[float] = None
+     query_pre_attn_scalar: Optional[float] = None
+     max_position_embeddings: int = 4096
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str = "siglip_vision_model"
+     num_hidden_layers: int = 27
+     hidden_size: int = 1152
+     intermediate_size: int = 4304
+     num_attention_heads: int = 16
+     patch_size: int = 14
+     projection_dim: int = 2048
+     image_size: int = 224
+     num_channels: int = 3
+     layer_norm_eps: float = 1e-6
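
For orientation, the dataclasses above can be instantiated directly with their defaults; the nested text and vision configs are built first and passed into ModelConfig. A minimal sketch using only the classes shown here (the import path follows the package layout above; no loader helpers on BaseModelConfig are assumed):

from mlx_vlm.models.paligemma import ModelConfig, TextConfig, VisionConfig

# Build the nested configs from their defaults, then the top-level config.
text_cfg = TextConfig()        # Gemma-style decoder settings
vision_cfg = VisionConfig()    # SigLIP-style vision tower settings
model_cfg = ModelConfig(text_config=text_cfg, vision_config=vision_cfg)

# The vision defaults imply a 16 x 16 patch grid:
# (image_size // patch_size) ** 2 = (224 // 14) ** 2 = 256 patches.
print((vision_cfg.image_size // vision_cfg.patch_size) ** 2)  # 256
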
mlx_vlm/models/paligemma/language.py
@@ -0,0 +1,253 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import TextConfig
+
+
+ class RMSNorm(nn.Module):
+     def __init__(self, dims: int, eps: float = 1e-6):
+         super().__init__()
+         self.weight = mx.ones((dims,))
+         self.eps = eps
+
+     def __call__(self, x):
+         return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+         self.model_type = config.model_type
+         self.attn_logit_softcapping = config.attn_logit_softcapping
+         self.repeats = n_heads // n_kv_heads
+         self.head_dim = head_dim = (
+             config.hidden_size // n_heads
+             if self.model_type == "gemma"
+             else config.head_dim
+         )
+         self.scale = (
+             head_dim**-0.5
+             if self.model_type == "gemma"
+             else 1.0 / (config.query_pre_attn_scalar**0.5)
+         )
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         self.rope = nn.RoPE(
+             head_dim,
+             traditional=config.rope_traditional,
+             base=config.rope_theta,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         if cache is not None:
+             queries = self.rope(queries, offset=cache.offset)
+             keys = self.rope(keys, offset=cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+         else:
+             queries = self.rope(queries)
+             keys = self.rope(keys)
+
+         if self.model_type == "gemma":
+             output = scaled_dot_product_attention(
+                 queries, keys, values, cache, scale=self.scale, mask=mask
+             )
+         else:
+             queries = queries * self.scale
+
+             if self.repeats > 1:
+                 queries = queries.reshape(
+                     B, self.n_kv_heads, self.repeats, L, self.head_dim
+                 )
+                 keys = mx.expand_dims(keys, 2)
+                 values = mx.expand_dims(values, 2)
+
+             scores = queries @ keys.swapaxes(-1, -2)
+             scores = mx.tanh(scores / self.attn_logit_softcapping)
+             scores *= self.attn_logit_softcapping
+
+             if mask is not None and isinstance(mask, mx.array):
+                 scores = scores + mask
+             scores = mx.softmax(scores, precise=True, axis=-1)
+             output = scores @ values
+             if self.repeats > 1:
+                 output = output.reshape(B, self.n_heads, L, self.head_dim)
+
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim, model_type):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.gelu = nn.GELU() if model_type == "gemma" else nn.GELU(approx="precise")
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(self.gelu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size = config.hidden_size
+         self.self_attn = Attention(config)
+         self.mlp = MLP(config.hidden_size, config.intermediate_size, config.model_type)
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.config = config
+
+         if config.model_type == "gemma2":
+             self.pre_feedforward_layernorm = RMSNorm(
+                 config.hidden_size, eps=config.rms_norm_eps
+             )
+             self.post_feedforward_layernorm = RMSNorm(
+                 config.hidden_size, eps=config.rms_norm_eps
+             )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         # Self attention block
+         r = self.self_attn(self.input_layernorm(x), mask, cache)
+
+         if self.model_type == "gemma":
+             # Gemma: Post-attention residual connection then MLP
+             h = x + r
+             r = self.mlp(self.post_attention_layernorm(h))
+             out = h + r
+         else:
+             # Gemma2: Normalized residual connections with pre/post norms
+             h = x + self.post_attention_layernorm(r)
+             r = self.mlp(self.pre_feedforward_layernorm(h))
+             out = h + self.post_feedforward_layernorm(r)
+         return out
+
+
+ class GemmaModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.num_hidden_layers = config.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
+         ]
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+     ):
+         # for passing merged input embeddings
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         h *= self.config.hidden_size**0.5
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None or cache[0].offset > 0:
+             mask = create_attention_mask(h, cache, return_array=True)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c)
+
+         return self.norm(h)
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.final_logit_softcapping = config.final_logit_softcapping
+         self.model_type = config.model_type
+         self.model = GemmaModel(config)
+
+         if self.model_type not in ["gemma", "gemma2"]:
+             raise ValueError(
+                 f"Model type {self.model_type} not supported. Currently only 'gemma' is supported"
+             )
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         out = self.model(inputs, mask=mask, cache=cache, inputs_embeds=inputs_embeds)
+         out = self.model.embed_tokens.as_linear(out)
+
+         if self.model_type == "gemma2":
+             out = mx.tanh(out / self.final_logit_softcapping)
+             out = out * self.final_logit_softcapping
+         return LanguageModelOutput(logits=out)
+
+     def sanitize(self, weights):
+         return {
+             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+         }
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return (
+             self.config.hidden_size // self.config.num_attention_heads
+             if self.model_type == "gemma"
+             else self.config.head_dim
+         )
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
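
Both gemma2 branches above apply the same tanh soft-capping form, capped = tanh(x / cap) * cap, to attention scores and final logits. A tiny stand-alone numeric sketch of that operation (the cap value 50.0 is illustrative only, not a default from this file):

import mlx.core as mx

def softcap(x: mx.array, cap: float) -> mx.array:
    # Same form as the gemma2 branches above: squash values into (-cap, cap).
    return mx.tanh(x / cap) * cap

scores = mx.array([-200.0, -10.0, 0.0, 10.0, 200.0])
print(softcap(scores, 50.0))
# Large magnitudes saturate near +/-50; small values pass through nearly unchanged.
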