fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,161 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import TextConfig
+
+
+ class Attention(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+
+         dim = config.hidden_size
+         self.n_heads = n_heads = config.num_attention_heads
+         self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+         head_dim = config.hidden_size // n_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+         self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+         self.rope = nn.RoPE(
+             head_dim,
+             traditional=config.rope_traditional,
+             base=config.rope_theta,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         B, L, D = x.shape
+
+         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+         # Prepare the queries, keys and values for the attention computation
+         queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+         if cache is not None:
+             queries = self.rope(queries, offset=cache.offset)
+             keys = self.rope(keys, offset=cache.offset)
+             keys, values = cache.update_and_fetch(keys, values)
+         else:
+             queries = self.rope(queries)
+             keys = self.rope(keys)
+
+         output = scaled_dot_product_attention(
+             queries, keys, values, cache, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.o_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, dim, hidden_dim):
+         super().__init__()
+         self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+         self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+         self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+     def __call__(self, x) -> mx.array:
+         return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.num_attention_heads = config.num_attention_heads
+         self.hidden_size = config.hidden_size
+         self.self_attn = Attention(config)
+         self.mlp = MLP(config.hidden_size, config.intermediate_size)
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.config = config
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         r = self.self_attn(self.input_layernorm(x), mask, cache)
+         h = x + r
+         r = self.mlp(self.post_attention_layernorm(h))
+         out = h + r
+         return out
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         self.vocab_size = config.vocab_size
+         self.num_hidden_layers = config.num_hidden_layers
+         assert self.vocab_size > 0
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = [
+             TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+     def __call__(
+         self,
+         inputs: mx.array,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache=None,
+         **kwargs,
+     ):
+         # for passing merged input embeddings
+         if inputs_embeds is None:
+             h = self.embed_tokens(inputs)
+         else:
+             h = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(h, cache)
+
+         for layer, c in zip(self.layers, cache):
+             h = layer(h, mask, c)
+
+         logits = self.lm_head(self.norm(h))
+         return LanguageModelOutput(logits=logits)
+
+     def sanitize(self, weights):
+         # Remove unused precomputed rotary freqs
+         return {
+             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+         }
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.hidden_size // self.config.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
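
Aside (not part of the package diff): the Attention block above reshapes the flat q/k/v projections into per-head tensors and, when num_key_value_heads is smaller than num_attention_heads, relies on grouped-query attention. The following is a minimal sketch of that reshape and the GQA broadcast, calling MLX's mx.fast.scaled_dot_product_attention directly instead of the package's scaled_dot_product_attention helper; all shapes are illustrative assumptions, not values from the package.

import mlx.core as mx

# Illustrative sizes: 8 query heads share 2 KV heads (grouped-query attention).
B, L, n_heads, n_kv_heads, head_dim = 1, 6, 8, 2, 16

queries = mx.random.normal((B, L, n_heads * head_dim))
keys = mx.random.normal((B, L, n_kv_heads * head_dim))
values = mx.random.normal((B, L, n_kv_heads * head_dim))

# (B, L, heads * head_dim) -> (B, heads, L, head_dim), as in Attention.__call__
q = queries.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)
k = keys.reshape(B, L, n_kv_heads, -1).transpose(0, 2, 1, 3)
v = values.reshape(B, L, n_kv_heads, -1).transpose(0, 2, 1, 3)

# The fused SDPA kernel broadcasts the 2 KV heads across the 8 query heads.
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=head_dim**-0.5)
print(out.shape)  # (1, 8, 6, 16)
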
@@ -0,0 +1,244 @@
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from .config import VisionConfig
+
+
+ def check_array_shape(arr):
+     shape = arr.shape
+
+     # Check if the shape has 4 dimensions
+     if len(shape) != 4:
+         return False
+
+     out_channels, kH, KW, _ = shape
+
+     # Check if out_channels is the largest, and kH and KW are the same
+     if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+         return True
+     else:
+         return False
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dims: int,
+         num_heads: int,
+         query_input_dims: Optional[int] = None,
+         key_input_dims: Optional[int] = None,
+         value_input_dims: Optional[int] = None,
+         value_dims: Optional[int] = None,
+         value_output_dims: Optional[int] = None,
+     ):
+         super().__init__()
+
+         if (dims % num_heads) != 0:
+             raise ValueError(
+                 "The input feature dimensions should be divisible by the "
+                 f"number of heads ({dims} % {num_heads}) != 0"
+             )
+
+         query_input_dims = query_input_dims or dims
+         key_input_dims = key_input_dims or dims
+         value_input_dims = value_input_dims or key_input_dims
+         value_dims = value_dims or dims
+         value_output_dims = value_output_dims or dims
+
+         self.num_heads = num_heads
+         head_dim = dims // num_heads
+         self.scale = head_dim**-0.5
+
+         self.q_proj = nn.Linear(query_input_dims, dims, bias=True)
+         self.k_proj = nn.Linear(key_input_dims, dims, bias=True)
+         self.v_proj = nn.Linear(value_input_dims, value_dims, bias=True)
+         self.out_proj = nn.Linear(value_dims, value_output_dims, bias=True)
+
+     def __call__(self, x: mx.array, mask=None):
+         B, L, _ = x.shape
+         queries = self.q_proj(x)
+         keys = self.k_proj(x)
+         values = self.v_proj(x)
+
+         num_heads = self.num_heads
+
+         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+         if mask is not None:
+             mask = mask[:, :, mask.shape[-2] :, :]
+
+         output = mx.fast.scaled_dot_product_attention(
+             queries, keys, values, scale=self.scale, mask=mask
+         )
+         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+         return self.out_proj(output)
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.activation_fn = nn.GELU(approx="fast")
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def __call__(self, x: mx.array) -> mx.array:
+         x = self.activation_fn(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+
+ class EncoderLayer(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.embed_dim = config.hidden_size
+         self.self_attn = Attention(config.hidden_size, config.num_attention_heads)
+         self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+         self.mlp = MLP(config)
+         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         y = self.layer_norm1(x)
+         y = self.self_attn(y, mask)
+         x = x + y
+         y = self.layer_norm2(x)
+         y = self.mlp(y)
+         return x + y
+
+
+ class Encoder(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+     def __call__(
+         self,
+         x: mx.array,
+         output_hidden_states: Optional[bool] = None,
+         mask: Optional[mx.array] = None,
+     ) -> mx.array:
+         encoder_states = (x,) if output_hidden_states else None
+         h = x
+         for l in self.layers:
+             x = l(x, mask=mask)
+             if output_hidden_states:
+                 encoder_states = encoder_states + (x,)
+
+             h = x
+
+         return (h, encoder_states)
+
+
+ class VisionEmbeddings(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+         )
+
+         self.num_patches = self.image_size // self.patch_size
+         self.num_positions = self.num_patches**2
+         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+     def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+         B, H, W, C = x.shape
+         patch_embeddings = self.patch_embedding(x)
+         patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+         max_nb_patches_h, max_nb_patches_w = (
+             H // self.patch_size,
+             W // self.patch_size,
+         )
+         boundaries = np.linspace(
+             1 / self.num_patches, 1.0, self.num_patches, endpoint=False
+         )
+         position_ids = np.zeros((B, max_nb_patches_h * max_nb_patches_w), dtype=int)
+
+         for batch_idx, p_attn_mask in enumerate(mask):
+             p_attn_mask = np.array(p_attn_mask)
+             nb_patches_h = p_attn_mask[:, 0].sum()
+             nb_patches_w = p_attn_mask[0, :].sum()
+
+             fractional_coords_h = np.linspace(0, 1, nb_patches_h, endpoint=False)
+             fractional_coords_w = np.linspace(0, 1, nb_patches_w, endpoint=False)
+
+             bucket_coords_h = (
+                 np.digitize(fractional_coords_h, boundaries, right=True) - 1
+             )
+             bucket_coords_w = (
+                 np.digitize(fractional_coords_w, boundaries, right=True) - 1
+             )
+
+             pos_ids = (
+                 bucket_coords_h[:, None] * self.num_patches + bucket_coords_w
+             ).flatten()
+             position_ids[batch_idx][p_attn_mask.reshape(-1)] = pos_ids
+
+         embeddings = patch_embeddings
+         embeddings += self.position_embedding(mx.array(position_ids))
+         return embeddings
+
+
+ class VisionModel(nn.Module):
+     def __init__(self, config: VisionConfig):
+         super().__init__()
+         self.config = config
+         self.model_type = config.model_type
+         if self.model_type not in ["idefics2", "idefics2_vision"]:
+             raise ValueError(f"Unsupported model type: {self.model_type}")
+         self.embeddings = VisionEmbeddings(config)
+         self.encoder = Encoder(config)
+         self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+     def __call__(
+         self,
+         x: mx.array,
+         patch_attention_mask: Optional[mx.array] = None,
+         output_hidden_states: Optional[bool] = None,
+     ) -> mx.array:
+
+         B, L, D, C = x.shape
+         if patch_attention_mask is None:
+             patch_size = self.config.patch_size
+             patch_attention_mask = mx.ones(
+                 (
+                     B,
+                     L // patch_size,
+                     D // patch_size,
+                 ),
+                 dtype=mx.bool_,
+             )
+
+         x = self.embeddings(x, mask=patch_attention_mask)
+         encoder_outputs = self.encoder(x=x, output_hidden_states=output_hidden_states)
+
+         pooler_output = self.post_layernorm(encoder_outputs[0])
+
+         return pooler_output, x, encoder_outputs[-1]
+
+     def sanitize(self, weights):
+         sanitized_weights = {}
+         for k, v in weights.items():
+             if "patch_embedding.weight" in k:
+                 # PyTorch conv2d weight tensors have shape:
+                 # [out_channels, in_channels, kH, KW]
+                 # MLX conv2d expects the weight be of shape:
+                 # [out_channels, kH, KW, in_channels]
+                 if check_array_shape(v):
+                     sanitized_weights[k] = v
+                 else:
+                     sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+             else:
+                 sanitized_weights[k] = v
+
+         return sanitized_weights
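
Aside (not part of the package diff): sanitize() above converts PyTorch Conv2d patch-embedding weights, stored as (out_channels, in_channels, kH, kW), into MLX's (out_channels, kH, kW, in_channels) layout, and check_array_shape() is the heuristic that skips weights that already look converted. A minimal sketch with illustrative sizes (not values taken from the package):

import mlx.core as mx

# A PyTorch-layout patch embedding: 768 output channels, 3 input channels, 14x14 kernel.
pt_weight = mx.zeros((768, 3, 14, 14))

# Same transpose that sanitize() applies: -> (768, 14, 14, 3), the MLX layout.
mlx_weight = pt_weight.transpose(0, 2, 3, 1)
print(mlx_weight.shape)

# check_array_shape() returns True only when the two spatial dims are equal and
# no larger than out_channels, so an already converted (768, 14, 14, 3) tensor
# is passed through unchanged on a second sanitize() pass.
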
@@ -0,0 +1,4 @@
+ from .config import ModelConfig, TextConfig, VisionConfig
+ from .idefics3 import Model
+ from .language import LanguageModel
+ from .vision import VisionModel
@@ -0,0 +1,54 @@
+ from dataclasses import dataclass
+ from typing import List, Optional
+
+ from ..base import BaseModelConfig
+
+
+ @dataclass
+ class TextConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int
+     intermediate_size: int
+     num_attention_heads: int
+     rms_norm_eps: float
+     vocab_size: int
+     num_key_value_heads: int
+     rope_theta: float = 1000000.0
+     num_hidden_layers: int = 32
+     rope_traditional: bool = False
+     max_position_embeddings: int = 4096
+     tie_word_embeddings: bool = False
+
+     def __post_init__(self):
+         if self.num_key_value_heads is None:
+             self.num_key_value_heads = self.num_attention_heads
+
+
+ @dataclass
+ class VisionConfig(BaseModelConfig):
+     model_type: str
+     hidden_size: int
+     num_attention_heads: int
+     patch_size: int
+     num_hidden_layers: int = 12
+     intermediate_size: int = 3072
+     image_size: int = 224
+     num_channels: int = 3
+     layer_norm_eps: float = 1e-6
+
+
+ @dataclass
+ class ModelConfig(BaseModelConfig):
+     text_config: TextConfig
+     vision_config: VisionConfig
+     model_type: str
+     ignore_index: int = -100
+     vocab_size: int = 128259
+     scale_factor: int = 2
+     image_token_id: int = 49153
+     image_token_index: Optional[int] = None
+     eos_token_id: Optional[List[int]] = None
+
+     def __post_init__(self):
+         if self.image_token_index is None:
+             self.image_token_index = self.image_token_id
@@ -0,0 +1,221 @@
+ import re
+ from typing import Optional
+
+ import mlx.core as mx
+ import mlx.nn as nn
+ import numpy as np
+
+ from ..base import InputEmbeddingsFeatures
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ def masked_scatter(
+     final_embedding: mx.array,
+     image_mask_expanded: mx.array,
+     scaled_image_features: mx.array,
+ ):
+     # Reshape the tensors to 1D
+     final_embedding_shape = final_embedding.shape
+     scaled_image_features_flattened = mx.flatten(scaled_image_features)
+     final_embedding_flattened = mx.flatten(final_embedding)
+     image_mask_expanded_flattened = mx.flatten(image_mask_expanded)
+
+     # Scatter the scaled image features into the special image token positions
+     image_positions = mx.array(np.where(image_mask_expanded_flattened)[0], mx.uint32)
+     final_embedding_flattened[image_positions] = scaled_image_features_flattened
+
+     # Reshape back to the original shape
+     final_embedding = mx.reshape(final_embedding_flattened, final_embedding_shape)
+
+     return final_embedding
+
+
+ class MLP(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         input_size = config.vision_config.hidden_size * (config.scale_factor**2)
+         output_size = config.text_config.hidden_size
+         self.proj = nn.Linear(input_size, output_size, bias=False)
+
+     def __call__(self, x):
+         return self.proj(x)
+
+
+ class Idefics3Connector(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.scale_factor = config.scale_factor
+         self.modality_projection = MLP(config)
+
+     def pixel_shuffle(self, x, scale_factor=2):
+         bsz, seq, embed_dim = x.shape
+         height = width = int(seq**0.5)
+         x = x.reshape(bsz, height, width, embed_dim)
+         x = x.reshape(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
+         x = x.transpose(0, 2, 1, 3)
+         x = x.reshape(
+             bsz,
+             int(width / scale_factor),
+             int(height / scale_factor),
+             embed_dim * (scale_factor**2),
+         )
+         x = x.transpose(0, 2, 1, 3)
+         x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
+         return x
+
+     def __call__(self, image_hidden_states):
+         image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
+         image_hidden_states = self.modality_projection(image_hidden_states)
+         return image_hidden_states
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.model_type = config.model_type
+         self.config = config
+
+         self.vision_model = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.connector = Idefics3Connector(config)
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+
+         if pixel_values is None:
+             return InputEmbeddingsFeatures(
+                 inputs_embeds=self.language_model.embed_tokens(input_ids)
+             )
+
+         inputs_embeds = self.language_model.embed_tokens(input_ids)
+
+         batch_size, num_images, num_channels, height, width = pixel_values.shape
+         pixel_values = pixel_values.reshape(
+             batch_size * num_images, num_channels, height, width
+         )
+
+         # Remove padding images - padding image are full 0.
+         nb_values_per_image = np.prod(pixel_values.shape[1:])
+         real_images_mask = (pixel_values == 0.0).sum(
+             axis=(-1, -2, -3)
+         ) != nb_values_per_image
+         real_images_inds = np.where(real_images_mask)[0].tolist()
+         pixel_values = pixel_values[real_images_inds, ...]
+
+         if pixel_attention_mask is None:
+             pixel_attention_mask = mx.ones(
+                 (pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
+                 dtype=mx.bool,
+             )
+         else:
+             # Remove padding images from the mask
+             pixel_attention_mask = pixel_attention_mask.reshape(
+                 batch_size * num_images, height, width
+             )
+             pixel_attention_mask = pixel_attention_mask[real_images_inds]
+
+         patch_size = self.config.vision_config.patch_size
+         batch_size, height, width = pixel_attention_mask.shape
+
+         # Calculate number of patches
+         patches_h = height // patch_size
+         patches_w = width // patch_size
+
+         # Reshape to extract patches
+         reshaped = pixel_attention_mask[
+             :, : patches_h * patch_size, : patches_w * patch_size
+         ]
+         reshaped = reshaped.reshape(
+             batch_size, patches_h, patch_size, patches_w, patch_size
+         )
+         reshaped = reshaped.transpose(
+             0, 1, 3, 2, 4
+         )  # (batch, patches_h, patches_w, patch_size, patch_size)
+
+         # Sum over patch dimensions and check if any pixels are active
+         patch_attention_mask = reshaped.sum(axis=(-1, -2)) > 0
+
+         pooler_output, *_ = self.vision_model(
+             pixel_values.transpose(0, 2, 3, 1),
+             patch_attention_mask=patch_attention_mask,
+             output_hidden_states=True,
+         )
+
+         image_features = pooler_output.astype(pixel_values.dtype)
+         image_features = self.connector(image_features)
+
+         final_inputs_embeds = self._prepare_inputs_for_multimodal(
+             image_features, inputs_embeds, input_ids
+         )
+         return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+     def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
+         special_image_mask = input_ids == self.config.image_token_index
+         n_image_tokens = special_image_mask.sum()
+         special_image_mask = special_image_mask[..., None]
+         special_image_mask = mx.broadcast_to(special_image_mask, inputs_embeds.shape)
+
+         n_image_features = image_features.shape[0]
+         n_image_mask_elements = special_image_mask.sum()
+         if n_image_mask_elements != image_features.size:
+             raise ValueError(
+                 f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+             )
+
+         inputs_embeds = masked_scatter(
+             inputs_embeds, special_image_mask, image_features
+         )
+
+         return inputs_embeds
+
+     @property
+     def layers(self):
+         return self.language_model.layers
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: mx.array,
+         cache=None,
+         **kwargs,
+     ):
+         input_embeddings_features = self.get_input_embeddings(
+             input_ids, pixel_values, **kwargs
+         )
+         logits = self.language_model(
+             inputs=input_ids,
+             cache=cache,
+             inputs_embeds=input_embeddings_features.inputs_embeds,
+         )
+         return logits
+
+     def sanitize(self, weights):
+         weights = {
+             (
+                 f"{k.split('.', 1)[1]}"
+                 if re.match(r"^model\.", k)
+                 else (f"language_model.{k}" if re.match(r"^lm_head\.", k) else k)
+             ): v
+             for k, v in weights.items()
+         }
+
+         weights = {
+             (
+                 f"language_model.{k.split('.', 1)[1]}"
+                 if re.match(
+                     r"^text_model\.",
+                     k,
+                 )
+                 else k
+             ): v
+             for k, v in weights.items()
+         }
+
+         return weights
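
Aside (not part of the package diff): Idefics3Connector.pixel_shuffle above trades spatial resolution for channel width, shrinking the vision token sequence by scale_factor**2 (the ModelConfig default is scale_factor=2) while widening the embedding by the same factor before the modality projection. A minimal sketch that mirrors those reshapes with illustrative shapes (not values from the package):

import mlx.core as mx

bsz, seq, embed_dim, scale = 1, 16, 8, 2   # seq must be a perfect square
x = mx.random.normal((bsz, seq, embed_dim))

side = int(seq**0.5)
x = x.reshape(bsz, side, side, embed_dim)                         # (1, 4, 4, 8)
x = x.reshape(bsz, side, side // scale, embed_dim * scale)        # (1, 4, 2, 16)
x = x.transpose(0, 2, 1, 3)                                       # (1, 2, 4, 16)
x = x.reshape(bsz, side // scale, side // scale, embed_dim * scale * scale)
x = x.transpose(0, 2, 1, 3)
x = x.reshape(bsz, seq // (scale * scale), embed_dim * scale * scale)
print(x.shape)  # (1, 4, 32): 4x fewer tokens, 4x wider features

The widened features then pass through the connector's MLP, whose input size is vision hidden_size * scale_factor**2, which is why the projection and the shuffle must agree on scale_factor.
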