fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0

mlx_vlm/models/phi3_v/vision.py
@@ -0,0 +1,294 @@
+from types import SimpleNamespace
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        head_dim = dims // num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, queries, keys, values, mask=None):
+        queries = self.q_proj(queries)
+        keys = self.k_proj(keys)
+        values = self.v_proj(values)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.activation_fn = nn.GELU(approx="fast")
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.activation_fn(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = Attention(
+            config.hidden_size, config.num_attention_heads, bias=True
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        y = self.layer_norm1(x)
+        y = self.self_attn(y, y, y, mask)
+        x = x + y
+        y = self.layer_norm2(x)
+        y = self.mlp(y)
+        return x + y
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+
+class VisionEmbeddings(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = mx.zeros((config.hidden_size,))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        batch_size = x.shape[0]
+        patch_embeddings = self.patch_embedding(x)
+        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+        embed_dim = patch_embeddings.shape[-1]
+        cls_embeddings = mx.broadcast_to(
+            self.class_embedding, (batch_size, 1, embed_dim)
+        )
+        position_ids = mx.array(np.arange(self.num_positions)[None, :])
+
+        embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
+        embeddings += self.position_embedding(position_ids)
+        return embeddings
+
+
+class ClipModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        self.embeddings = VisionEmbeddings(config)
+        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+        self.encoder = Encoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+        x = self.embeddings(x)
+        x = self.pre_layrnorm(x)
+
+        encoder_states = (x,) if output_hidden_states else None
+
+        for l in self.encoder.layers:
+            x = l(x, mask=None)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)
+
+        pooler_output = self.post_layernorm(x[:, 0, :])
+        return pooler_output, x, encoder_states
+
+
+class ClipVModel(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.model_type = config.model_type
+        self.vision_model = ClipModel(config)
+
+
+class VisionModel(nn.Module):
+    CLIP_VIT_LARGE_PATCH14_336_CONFIG = SimpleNamespace(
+        model_type="phi3_v",
+        hidden_size=1024,
+        image_size=336,
+        intermediate_size=4096,
+        layer_norm_eps=1e-05,
+        num_attention_heads=16,
+        num_channels=3,
+        num_hidden_layers=24,
+        patch_size=14,
+    )
+
+    def __init__(self, config):
+        super().__init__()
+        self.model_type = config.model_type
+        self.img_processor = ClipVModel(self.CLIP_VIT_LARGE_PATCH14_336_CONFIG)
+        self.image_dim_out = image_dim_out = 1024
+        self.glb_GN = mx.zeros([1, 1, image_dim_out * 4])
+        self.sub_GN = mx.zeros([1, 1, 1, image_dim_out * 4])
+        self.img_projection = [
+            nn.Linear(image_dim_out * 4, config.hidden_size),
+            nn.GELU(),
+            nn.Linear(config.hidden_size, config.hidden_size),
+        ]
+
+    def __call__(
+        self,
+        img_embeds,
+        txt_embeds=None,
+        img_sizes=None,
+        positions=None,
+        output_hidden_states=None,
+    ):
+        if output_hidden_states:
+            return self.img_processor.vision_model(
+                img_embeds, output_hidden_states=output_hidden_states
+            )
+        img_embeds = mx.array(img_embeds)
+        img_sizes = mx.array(img_sizes)
+        B = img_embeds.shape[0]
+        img_sizes = (img_sizes // 336).tolist()
+        img_features = self.img_processor.vision_model(
+            img_embeds.reshape(-1, *img_embeds.shape[2:]).transpose(0, 2, 3, 1), True
+        )[-1][-2][:, 1:]
+        img_features = img_features.reshape(B, -1, *img_features.shape[1:])
+        C, H = self.image_dim_out, int(img_features.shape[2] ** 0.5)
+        output_imgs, output_len = [], []
+        for _bs in range(B):
+            h, w = img_sizes[_bs]
+            B_ = h * w
+
+            def _reshape_and_concatenate(img, shape, tile_shape):
+                return mx.concatenate(
+                    [
+                        img.reshape(shape)
+                        .transpose(0, 1, 3, 2, 4, 5)
+                        .reshape(tile_shape),
+                        mx.tile(self.sub_GN, (1, tile_shape[1], 1, 1)),
+                    ],
+                    axis=2,
+                ).reshape(1, -1, 4 * C)
+
+            glb_img = _reshape_and_concatenate(
+                img_features[_bs, :1],
+                (1, H // 2, 2, H // 2, 2, C),
+                (1, H // 2, H // 2, 4 * C),
+            )
+            sub_img = _reshape_and_concatenate(
+                img_features[_bs, 1 : B_ + 1],
+                (B_, H // 2, 2, H // 2, 2, C),
+                (1, h * 12, w * 12, 4 * C),
+            )
+            x = mx.concatenate([sub_img, self.glb_GN, glb_img], axis=1)
+            for l in self.img_projection:
+                x = l(x)
+            output_imgs.append(np.array(x.astype(mx.float32)))
+            output_len.append(int((h * w + 1) * 144 + 1 + (h + 1) * 12))
+        idx = 0
+        txt_embeds = np.array(txt_embeds.astype(mx.float32))
+        for i, cnt in enumerate(output_len):
+            txt_embeds[
+                positions[idx][0], positions[idx][1] : positions[idx][1] + cnt
+            ] = output_imgs[i]
+            idx += cnt
+        txt_embeds = mx.array(txt_embeds)
+        return txt_embeds
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                continue
+            elif "patch_embedding.weight" in k:
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
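
Note on the sanitize hook above: PyTorch checkpoints store Conv2d weights as (out_channels, in_channels, kH, kW), while MLX's nn.Conv2d expects channels-last (out_channels, kH, kW, in_channels), so loading a CLIP patch embedding needs the check-and-transpose step that sanitize performs. A minimal sketch of that step, assuming only mlx is installed (check_array_shape is restated here rather than imported from the wheel):

import mlx.core as mx

def check_array_shape(arr):
    # True when the weight already looks channels-last: (O, kH, kW, C) with kH == kW.
    if len(arr.shape) != 4:
        return False
    out_channels, kH, KW, _ = arr.shape
    return out_channels >= kH and out_channels >= KW and kH == KW

# CLIP ViT-L/14 patch embedding in PyTorch layout: (1024, 3, 14, 14).
weight = mx.zeros((1024, 3, 14, 14))
if not check_array_shape(weight):
    weight = weight.transpose(0, 2, 3, 1)  # move channels last, as sanitize() does
print(weight.shape)  # (1024, 14, 14, 3)
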
mlx_vlm/models/pixtral/__init__.py
@@ -0,0 +1,4 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .language import LanguageModel
+from .pixtral import Model
+from .vision import VisionModel

mlx_vlm/models/pixtral/config.py
@@ -0,0 +1,69 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Union
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: "TextConfig" = field(default_factory=lambda: TextConfig())
+    vision_config: "VisionConfig" = field(default_factory=lambda: VisionConfig())
+    model_type: str = "pixtral"
+    ignore_index: int = -100
+    image_token_index: int = None
+    image_token_id: int = None
+    vision_feature_select_strategy: str = "full"
+    vision_feature_layer: int = -1
+    vocab_size: int = 32000
+    eos_token_id: Optional[List[int]] = None
+
+    def __post_init__(self):
+        if self.image_token_index is None:
+            self.image_token_index = self.image_token_id
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str = "mistral"
+    hidden_size: int = 5120
+    head_dim: int = 128
+    num_hidden_layers: int = 40
+    intermediate_size: int = 14336
+    num_attention_heads: int = 32
+    rms_norm_eps: float = 1e-06
+    vocab_size: int = 131072
+    num_key_value_heads: int = 8
+    rope_theta: float = 1000000000.0
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    max_position_embeddings: int = 4096
+    use_qk_norm: bool = False
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+        if self.rope_scaling:
+            required_keys = {"factor", "type"}
+            if not all(key in self.rope_scaling for key in required_keys):
+                raise ValueError(f"rope_scaling must contain keys {required_keys}")
+
+            if self.rope_scaling["type"] != "linear":
+                raise ValueError("rope_scaling 'type' currently only supports 'linear'")
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str = "pixtral"
+    num_hidden_layers: int = 24
+    hidden_size: int = 1024
+    head_dim: int = 64
+    intermediate_size: int = 4096
+    num_attention_heads: int = 16
+    image_size: int = 336
+    patch_size: int = 14
+    projection_dim: int = 768
+    vocab_size: int = 32000
+    num_channels: int = 3
+    rms_norm_eps: float = 1e-5
+    rope_theta: float = 10000.0
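
For reference, a small sketch of how the TextConfig.__post_init__ validation above behaves; it assumes this wheel is installed so that the mlx_vlm.models.pixtral.config module listed earlier is importable:

from mlx_vlm.models.pixtral.config import TextConfig

# Accepted: "linear" scaling with an explicit factor.
cfg = TextConfig(rope_scaling={"type": "linear", "factor": 2.0})
print(cfg.rope_scaling)         # {'type': 'linear', 'factor': 2.0}
print(cfg.num_key_value_heads)  # 8 (default kept because it is not None)

# Rejected: any type other than "linear" (or a missing key) raises ValueError.
try:
    TextConfig(rope_scaling={"type": "dynamic", "factor": 2.0})
except ValueError as err:
    print(err)  # rope_scaling 'type' currently only supports 'linear'
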

mlx_vlm/models/pixtral/language.py
@@ -0,0 +1,195 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from ..cache import KVCache
+from .config import TextConfig
+
+
+class Attention(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+
+        dim = config.hidden_size
+        self.n_heads = n_heads = config.num_attention_heads
+        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+        head_dim = config.head_dim
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        rope_scale = (
+            1 / config.rope_scaling["factor"]
+            if config.rope_scaling is not None
+            and config.rope_scaling["type"] == "linear"
+            else 1
+        )
+        self.rope = nn.RoPE(
+            head_dim,
+            traditional=config.rope_traditional,
+            base=config.rope_theta,
+            scale=rope_scale,
+        )
+        self.use_qk_norm = config.use_qk_norm
+
+        if self.use_qk_norm:
+            self.q_norm = nn.RMSNorm(dims=head_dim, eps=config.rms_norm_eps)
+            self.k_norm = nn.RMSNorm(dims=head_dim, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if self.use_qk_norm:
+            queries = self.q_norm(queries)
+            keys = self.k_norm(keys)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.num_attention_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.self_attn = Attention(config)
+        self.mlp = MLP(config.hidden_size, config.intermediate_size)
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.config = config
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class Mistral(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.num_hidden_layers = config.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = [
+            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
+        # for passing merged input embeddings
+        if inputs_embeds is None:
+            h = self.embed_tokens(inputs)
+        else:
+            h = inputs_embeds
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.model = Mistral(config)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        **kwargs,
+    ):
+        out = self.model(inputs, mask=mask, cache=cache, inputs_embeds=inputs_embeds)
+        logits = self.lm_head(out)
+        return LanguageModelOutput(logits=logits)
+
+    @staticmethod
+    def sanitize(weights):
+        # Remove unused precomputed rotary freqs
+        return {
+            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+        }
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.config.head_dim
+
+    @property
+    def n_kv_heads(self):
+        return self.config.num_key_value_heads
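
To close out the pixtral text stack, a minimal forward-pass sketch of the LanguageModel above with deliberately tiny, made-up hyperparameters (the real defaults are 40 layers of width 5120); it assumes the wheel and mlx are installed and relies on the cache=None path that Mistral.__call__ sets up itself:

import mlx.core as mx
from mlx_vlm.models.pixtral.config import TextConfig
from mlx_vlm.models.pixtral.language import LanguageModel

# Tiny, hypothetical config so the randomly initialized model is cheap to build.
cfg = TextConfig(
    hidden_size=64,
    head_dim=16,
    num_hidden_layers=2,
    intermediate_size=128,
    num_attention_heads=4,
    num_key_value_heads=2,
    vocab_size=512,
)
model = LanguageModel(cfg)

tokens = mx.array([[1, 2, 3, 4]])  # (batch=1, sequence=4)
out = model(tokens)                # no cache: a causal mask is created internally
print(out.logits.shape)            # (1, 4, 512) -- one logit row per position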