fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/kimi_vl/vision.py
@@ -0,0 +1,485 @@
+ from typing import List, Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..kernels import bicubic_interpolate
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return mx.concatenate([-x2, x1], axis=-1)
+
+
+ def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
+     orig_dtype = tensor.dtype
+
+     cos = mx.cos(freqs)
+     sin = mx.sin(freqs)
+
+     cos = mx.expand_dims(cos, axis=1)  # Equivalent to unsqueeze(1)
+     cos = mx.tile(cos, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+     cos = mx.expand_dims(cos, axis=0)  # Equivalent to [None, ...]
+
+     sin = mx.expand_dims(sin, axis=1)  # Equivalent to unsqueeze(1)
+     sin = mx.tile(sin, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+     sin = mx.expand_dims(sin, axis=0)  # Equivalent to [None, ...]
+
+     output = (tensor * cos) + (rotate_half(tensor) * sin)
+     return output.astype(orig_dtype)
+
+
+ class VisionRotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, theta: float = 10000.0) -> None:
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+
+     def __call__(self, seqlen: int) -> mx.array:
+         inv_freq = 1.0 / (
+             self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+         )
+         seq = mx.arange(seqlen.tolist(), dtype=inv_freq.dtype)
+         freqs = mx.outer(seq, inv_freq)
+         return freqs
+
+
+ class Learnable2DInterpPosEmb(nn.Module):
+     def __init__(
+         self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
+     ) -> None:
+         super().__init__()
+         self.height = height
+         self.width = width
+         self.interpolation_mode = interpolation_mode
+         self.weight = mx.ones((height, width, dim))
+
+     def __call__(self, x: mx.array, grid_hws: mx.array) -> mx.array:
+         pos_embs = []
+         for shape in grid_hws.tolist():
+             if shape == self.weight.shape[:-1]:
+                 pos_embs.append(self.weight.flatten(end_axis=1))
+             else:
+                 result = (
+                     bicubic_interpolate(
+                         mx.expand_dims(self.weight.transpose(2, 0, 1), axis=0),
+                         size=shape,
+                     )
+                     .squeeze(0)
+                     .transpose(1, 2, 0)
+                     .flatten(end_axis=1)
+                 )
+
+                 pos_embs.append(result)
+
+         out = x + mx.concatenate(pos_embs).astype(x.dtype)
+         return out
+
+
+ class PatchEmbed(nn.Module):
+     def __init__(
+         self,
+         patch_size: int = 14,
+         num_channels: int = 3,
+         embed_dim: int = 1152,
+         init_pos_emb_height: int = 64,
+     ) -> None:
+         super().__init__()
+         self.patch_size = patch_size
+         self.num_channels = num_channels
+         self.embed_dim = embed_dim
+         self.init_pos_emb_height = init_pos_emb_height
+
+         self.proj = nn.Conv2d(
+             num_channels,
+             embed_dim,
+             kernel_size=patch_size,
+             stride=patch_size,
+             bias=True,
+         )
+         self.pos_emb = Learnable2DInterpPosEmb(
+             height=init_pos_emb_height, width=init_pos_emb_height, dim=embed_dim
+         )
+
+     def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> mx.array:
+         hidden_states = self.proj(hidden_states).swapaxes(1, 3)
+         hidden_states = hidden_states.reshape(hidden_states.shape[0], -1)
+         hidden_states = self.pos_emb(hidden_states, grid_thw)
+         return hidden_states
+
+
+ def _apply_rope_input_validation(x, freqs_cis):
+     assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
+     assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
+     assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
+     assert freqs_cis.dtype == mx.complex64, freqs_cis.dtype
+
+
+ def view_as_complex(x):
+     """
+     Convert a tensor with shape (..., 2) to a complex tensor with shape (...).
+     """
+     # Get real and imaginary parts
+     real, imag = x[..., 0], x[..., 1]
+     # Create complex tensor
+     return real + 1j * imag
+
+
+ def view_as_real(x):
+     """
+     Convert a complex tensor with shape (...) to a real tensor with shape (..., 2).
+     """
+     # Get real and imaginary parts
+     real = mx.real(x)
+     imag = mx.imag(x)
+     # Combine into a tensor with last dimension 2
+     return mx.stack([real, imag], axis=-1)
+
+
+ def apply_rope(
+     q: mx.array, k: mx.array, freqs_cis: mx.array
+ ) -> tuple[mx.array, mx.array]:
+     """
+     Args: (The leading dimensions of all inputs should be the same)
+         q: query, array of shape (..., num_heads, head_dim)
+         k: key, array of shape (..., num_heads, head_dim)
+         freqs_cis: array of shape (..., head_dim/2), dtype=mx.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
+     Returns:
+         xq_out, xk_out: arrays of shape (..., num_heads, head_dim)
+     """
+     _apply_rope_input_validation(q, freqs_cis)
+     _apply_rope_input_validation(k, freqs_cis)
+
+     freqs_cis = mx.expand_dims(freqs_cis, axis=-2)  # ..., 1, head_dim/2
+     # ..., num_heads, head_dim/2
+     q_ = view_as_complex(q.astype(mx.float32).reshape(*q.shape[:-1], -1, 2))
+     k_ = view_as_complex(k.astype(mx.float32).reshape(*k.shape[:-1], -1, 2))
+     q_out = view_as_real(q_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
+     k_out = view_as_real(k_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
+     return q_out.astype(q.dtype), k_out.astype(k.dtype)
+
+
+ class Attention(nn.Module):
+     def __init__(self, dim: int, num_heads: int = 16) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+         self.wqkv = nn.Linear(dim, dim * 3, bias=True)
+         self.wo = nn.Linear(dim, dim, bias=True)
+
+     def __call__(
+         self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
+     ) -> mx.array:
+         seq_length = x.shape[0]
+         qkv = self.wqkv(x)
+
+         qkv_shape = qkv.shape[:-1] + (
+             3,
+             self.num_heads,
+             self.head_dim,
+         )
+         # xqkv: (batch_size, seqlen, 3, nheads, headdim)
+         qkv = qkv.reshape(*qkv_shape)
+
+         q, k, v = mx.split(qkv, 3, axis=1)
+         q = q.squeeze(1)
+         k = k.squeeze(1)
+         v = v.squeeze(1)
+
+         q, k = apply_rope(q, k, rotary_pos_emb)
+
+         attention_mask = mx.zeros((1, seq_length, seq_length), dtype=x.dtype)
+
+         # Create attention mask for each sequence in the batch
+         for i in range(1, len(cu_seqlens)):
+             start = int(cu_seqlens[i - 1])
+             end = int(cu_seqlens[i])
+             attention_mask[..., start:end, start:end] = 1
+
+         q = q.transpose(1, 0, 2)
+         k = k.transpose(1, 0, 2)
+         v = v.transpose(1, 0, 2)
+
+         attn_weight = q @ k.swapaxes(-2, -1) / mx.sqrt(q.shape[-1])
+         attn_weight += attention_mask
+         attn_weight = mx.softmax(attn_weight, axis=-1).astype(q.dtype)
+
+         attn_output = attn_weight @ v
+         attn_output = attn_output.transpose(1, 0, 2)
+         attn_output = attn_output.reshape(seq_length, -1)
+         return self.wo(attn_output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.activation_fn = nn.GELU()
+         self.fc0 = nn.Linear(dim, hidden_dim)
+         self.fc1 = nn.Linear(hidden_dim, dim)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.activation_fn(self.fc0(x))
+         x = self.fc1(x)
+         return x
+
+
+ class Qwen2VLVisionBlock(nn.Module):
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.norm0 = nn.LayerNorm(config.embed_dim, eps=1e-6)
+         self.norm1 = nn.LayerNorm(config.embed_dim, eps=1e-6)
+
+         self.attn = Attention(dim=config.embed_dim, num_heads=config.num_heads)
+         self.mlp = MLP(dim=config.embed_dim, hidden_dim=config.intermediate_size)
+
+     def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
+         hidden_states = hidden_states + self.attn(
+             self.norm0(hidden_states),
+             cu_seqlens=cu_seqlens,
+             rotary_pos_emb=rotary_pos_emb,
+         )
+         hidden_states = hidden_states + self.mlp(self.norm1(hidden_states))
+         return hidden_states
+
+
+ class Rope2DPosEmb(nn.Module):
+     """2D rotary position embedding with multi-resolution support.
+
+     This class is intended to be used in the following way:
+     1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
+     2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
+     3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
+     The rope is shared across all attention layers and all heads.
+
+     Refs:
+     - RoFormer: https://arxiv.org/abs/2104.09864
+     - VisionLLaMA: https://arxiv.org/abs/2403.00522
+     - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
+
+     Args:
+         dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
+         max_height (int): the maximum height of the 2D grid
+         max_width (int): the maximum width of the 2D grid
+         theta_base (float): the base of the theta
+     """
+
+     def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
+         super().__init__()
+         self.dim = dim
+         assert self.dim % 4 == 0, "dim must be divisible by 4"
+         self.max_height = max_height
+         self.max_width = max_width
+         self.theta_base = theta_base
+
+         self._freqs_cis = None
+
+     def extra_repr(self):
+         return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
+
+     def _precompute_freqs_cis(self) -> mx.array:
+         """Calculate the cis(freqs) for each position in the 2D grid.
+
+         Return: complex array of shape (max_height, max_width, dim//2) and value:
+             height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
+             weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
+         note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
+         """
+         N = self.max_height * self.max_width
+         flat_pos = mx.arange(0, N, dtype=mx.float32)
+         x_pos = flat_pos % self.max_width
+         y_pos = flat_pos // self.max_width
+         dim_range = mx.arange(0, self.dim, 4)[: (self.dim // 4)].astype(
+             mx.float32
+         )  # C/4
+         freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
+         x_freqs = mx.outer(x_pos, freqs)  # N, C/4
+         y_freqs = mx.outer(y_pos, freqs)  # N, C/4
+
+         # Create complex numbers using cos and sin
+         x_cos = mx.cos(x_freqs)
+         x_sin = mx.sin(x_freqs)
+         y_cos = mx.cos(y_freqs)
+         y_sin = mx.sin(y_freqs)
+
+         # Create complex numbers
+         x_cis = x_cos + 1j * x_sin  # N, C/4
+         y_cis = y_cos + 1j * y_sin  # N, C/4
+
+         # N, C/4, 2
+         freqs_cis = mx.stack([x_cis, y_cis], axis=-1)
+
+         # max_height, max_width, C/2
+         freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
+         return freqs_cis
+
+     def get_freqs_cis(self, grid_hws: mx.array) -> mx.array:
+         """
+         Args:
+             grid_hws (mx.array): grid height and width
+
+         Returns:
+             freqs_cis: array of shape (sum(t * height * width), dim//2)
+         """
+         if self._freqs_cis is None:
+             self._freqs_cis = self._precompute_freqs_cis()
+
+         shapes = grid_hws.tolist()
+         assert all(
+             1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
+         ), (
+             shapes,
+             self.max_height,
+             self.max_width,
+         )
+
+         freqs_cis_list = []
+         for h, w in shapes:
+             # Get the slice of precomputed frequencies for this shape
+             shape_freqs = self._freqs_cis[:h, :w]
+             # Reshape to flatten the spatial dimensions
+             shape_freqs = shape_freqs.reshape(-1, self.dim // 2)
+             freqs_cis_list.append(shape_freqs)
+
+         freqs_cis = mx.concatenate(freqs_cis_list, axis=0)
+         return freqs_cis
+
+
+ def patch_merger(
+     x: mx.array,
+     grid_hws: mx.array,
+     merge_kernel_size: list[int, int] = (2, 2),
+ ) -> List[mx.array]:
+     d_model = x.shape[-1]
+
+     outputs = []
+     pre_sum = 0
+     for x_shape in grid_hws.tolist():
+         height, width = x_shape[0], x_shape[1]
+         # Get the current sequence
+         seq = x[pre_sum : pre_sum + height * width]
+         # Reshape along self.merge_kernel_size and concat to the last dimension
+         kernel_height, kernel_width = merge_kernel_size
+         new_height, new_width = height // kernel_height, width // kernel_width
+         reshaped_seq = seq.reshape(
+             new_height, kernel_height, new_width, kernel_width, d_model
+         )
+         reshaped_seq = mx.transpose(reshaped_seq, (0, 2, 1, 3, 4))
+         padded_seq = reshaped_seq.reshape(
+             new_height * new_width, kernel_height * kernel_width, -1
+         )
+         outputs.append(padded_seq)
+         pre_sum += height * width
+
+     return outputs
+
+
+ class VisionModel(nn.Module):
+
+     def __init__(self, config: VisionConfig) -> None:
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         if self.model_type not in ["qwen2_vl", "moonvit"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+         self.spatial_merge_size = config.spatial_merge_size
+         self.merge_kernel_size = config.merge_kernel_size
+
+         self.patch_embed = PatchEmbed(
+             patch_size=config.patch_size,
+             num_channels=config.num_channels,
+             embed_dim=config.embed_dim,
+             init_pos_emb_height=config.init_pos_emb_height,
+         )
+
+         head_dim = config.embed_dim // config.num_heads
+         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+         self.blocks = [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
+         self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=1e-6)
+         self.rope_pos_emb = Rope2DPosEmb(head_dim, 512, 512)
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         grid_thw: mx.array,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+
+         hidden_states = self.patch_embed(hidden_states, grid_thw)
+         rotary_pos_emb = self.rope_pos_emb.get_freqs_cis(grid_thw)
+
+         # Assuming grid_thw has shape (batch_size, 3)
+         batch_size = grid_thw.shape[0]
+
+         # Calculate cu_seqlens for each item in the batch
+         lengths = mx.concatenate(
+             (
+                 mx.zeros((1,), dtype=grid_thw.dtype),
+                 grid_thw[:, 0] * grid_thw[:, 1],
+             )
+         )
+         cu_seqlens = mx.cumsum(lengths.astype(mx.int32), axis=0)
+
+         encoder_states = (hidden_states,) if output_hidden_states else None
+
+         for blk in self.blocks:
+             hidden_states = blk(
+                 hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+             )
+             if output_hidden_states:
+                 encoder_states = encoder_states + (hidden_states,)
+
+         hidden_states = self.final_layernorm(hidden_states)
+
+         hidden_states = patch_merger(
+             hidden_states, grid_thw, merge_kernel_size=self.merge_kernel_size
+         )
+
+         return hidden_states
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "position_ids" in k:
+                 # Remove unused position_ids
+                 continue
+             elif "patch_embed.proj.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+
+             elif "vision_tower.blocks" in k:
+                 if "attn" not in k and ("wqkv" in k or "wo" in k):
+                     new_key = k.replace("wqkv", "attn.wqkv").replace("wo", "attn.wo")
+                     sanitized_weights[new_key] = v
+                 else:
+                     sanitized_weights[k] = v
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
mlx_vlm/models/lfm2_vl/__init__.py
@@ -0,0 +1,2 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .lfm2_vl import LanguageModel, Model, VisionModel
mlx_vlm/models/lfm2_vl/config.py
@@ -0,0 +1,94 @@
+ from dataclasses import dataclass
+ from typing import List, Tuple
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str = "lfm2"
+     hidden_size: int = 1024
+     num_hidden_layers: int = 16
+     intermediate_size: int = 6656
+     num_attention_heads: int = 16
+     num_key_value_heads: int = 8
+     max_position_embeddings: int = 128000
+     rope_theta: float = 1000000.0
+     vocab_size: int = 65536
+     eos_token_id: int = 7
+     initializer_range: float = 0.02
+     norm_eps: float = 1e-05
+     use_cache: bool = True
+     use_pos_enc: bool = True
+     block_auto_adjust_ff_dim: bool = True
+     block_dim: int = 1024
+     block_ff_dim: int = 6656
+     block_ffn_dim_multiplier: float = 1.0
+     block_mlp_init_scale: float = 1.0
+     block_multiple_of: int = 256
+     block_norm_eps: float = 1e-05
+     block_out_init_scale: float = 1.0
+     block_use_swiglu: bool = True
+     block_use_xavier_init: bool = True
+     conv_L_cache: int = 3
+     conv_bias: bool = False
+     conv_dim: int = 1024
+     conv_dim_out: int = 1024
+     conv_use_xavier_init: bool = True
+     layer_types: List[str] = None
+     num_heads: int = 16
+     full_attn_idxs: List[int] = None
+
+     def __post_init__(self):
+
+         if self.full_attn_idxs is None:
+             self.full_attn_idxs = [
+                 i
+                 for i, layer_type in enumerate(self.layer_types)
+                 if layer_type == "full_attention"
+             ]
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str = "lfm2_vl"
+     hidden_size: int = 768
+     intermediate_size: int = 3072
+     num_hidden_layers: int = 12
+     num_attention_heads: int = 12
+     num_channels: int = 3
+     image_size: int = 224
+     patch_size: int = 16
+     num_patches: int = 256
+     attention_dropout: float = 0.0
+     layer_norm_eps: float = 1e-06
+     hidden_act: str = "gelu_pytorch_tanh"
+     vision_use_head: bool = False
+     num_positions: int = None
+     spatial_shapes: List[Tuple[int, int]] = None
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str = "lfm2-vl"
+     do_image_splitting: bool = True
+     downsample_factor: int = 2
+     encoder_patch_size: int = 16
+     image_token_index: int = 396
+     max_image_tokens: int = 256
+     max_num_patches: int = 1024
+     max_pixels_tolerance: float = 2.0
+     max_tiles: int = 10
+     min_image_tokens: int = 64
+     min_tiles: int = 2
+     tile_size: int = 512
+     use_image_special_tokens: bool = True
+     use_thumbnail: bool = False
+     vision_feature_layer: int = -1
+     projector_bias: bool = True
+     projector_hidden_act: str = "gelu"
+     projector_hidden_size: int = 2560
+     eos_token_id: int = 7
+     projector_use_layernorm: bool = True
mlx_vlm/models/lfm2_vl/language.py
@@ -0,0 +1,49 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ from mlx_lm.models.cache import ArraysCache, KVCache
+ from mlx_lm.models.lfm2 import Lfm2Model
+
+ from ..base import LanguageModelOutput
+ from .config import TextConfig
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         self.model = Lfm2Model(config)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         mask: mx.array = None,
+         cache=None,
+         inputs_embeds: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         out = self.model(inputs, cache, inputs_embeds)
+         out = self.model.embed_tokens.as_linear(out)
+         return LanguageModelOutput(out)
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for name, param in weights.items():
+             if "conv.weight" in name:
+                 if param.shape[-1] > param.shape[1]:
+                     param = param.transpose(0, 2, 1)
+
+             sanitized_weights[name] = param
+         return sanitized_weights
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     def make_cache(self):
+         return [
+             KVCache() if l.is_attention_layer else ArraysCache(size=1)
+             for l in self.layers
+         ]