fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/mllama/language.py
@@ -0,0 +1,377 @@
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import (
+     LanguageModelOutput,
+     create_attention_mask,
+     scaled_dot_product_attention,
+ )
+ from ..cache import KVCache
+ from .config import TextConfig
+
+
+ class MllamaTextCrossAttention(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: Optional[int] = None):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.layer_idx = layer_idx
+         self.scale = self.head_dim**-0.5
+         self.q_proj = nn.Linear(
+             self.hidden_size, self.num_heads * self.head_dim, bias=False
+         )
+         self.k_proj = nn.Linear(
+             self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+         )
+         self.v_proj = nn.Linear(
+             self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+         )
+         self.o_proj = nn.Linear(
+             self.num_heads * self.head_dim, self.hidden_size, bias=False
+         )
+
+         self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         cross_attention_states: Optional[mx.array] = None,
+         attention_mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+
+         bsz, q_len, _ = hidden_states.shape
+         query = (
+             self.q_proj(hidden_states)
+             .reshape(bsz, q_len, self.num_heads, self.head_dim)
+             .transpose(0, 2, 1, 3)
+         )
+         query_states = self.q_norm(query)
+
+         if cross_attention_states is not None:
+             key_states = (
+                 self.k_proj(cross_attention_states)
+                 .reshape(bsz, -1, self.num_key_value_heads, self.head_dim)
+                 .transpose(0, 2, 1, 3)
+             )
+             value_states = (
+                 self.v_proj(cross_attention_states)
+                 .reshape(bsz, -1, self.num_key_value_heads, self.head_dim)
+                 .transpose(0, 2, 1, 3)
+             )
+             key_states = self.k_norm(key_states)
+         elif cache is not None and cache.offset > 0:
+             key_states, value_states = cache.fetch()
+         else:
+             key_states, value_states = mx.split(query, 2, axis=1)
+             key_states = self.k_norm(key_states)
+
+         attn_output = scaled_dot_product_attention(
+             query_states,
+             key_states,
+             value_states,
+             cache,
+             scale=self.scale,
+             mask=attention_mask,  # add a dim for batch processing
+         )
+         attn_output = attn_output.transpose(0, 2, 1, 3).reshape(
+             bsz, q_len, self.hidden_size
+         )
+         return self.o_proj(attn_output)
+
+
+ class MllamaTextSelfAttention(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.head_dim = self.hidden_size // self.num_heads
+         self.num_key_value_heads = config.num_key_value_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.scale = self.head_dim**-0.5
+         self.layer_idx = layer_idx
+
+         self.q_proj = nn.Linear(
+             self.hidden_size, self.num_heads * self.head_dim, bias=False
+         )
+         self.k_proj = nn.Linear(
+             self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+         )
+         self.v_proj = nn.Linear(
+             self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
+         )
+         self.o_proj = nn.Linear(
+             self.num_heads * self.head_dim, self.hidden_size, bias=False
+         )
+
+         self.rope = nn.RoPE(
+             self.head_dim,
+             traditional=config.rope_traditional,
+             base=config.rope_theta,
+             scale=1,
+         )
+
+     def __call__(
+         self,
+         x: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         bsz, q_len, _ = x.shape
+         query_states = (
+             self.q_proj(x).reshape(bsz, q_len, self.num_heads, -1).transpose(0, 2, 1, 3)
+         )
+         key_states = (
+             self.k_proj(x)
+             .reshape(bsz, q_len, self.num_key_value_heads, -1)
+             .transpose(0, 2, 1, 3)
+         )
+         value_states = (
+             self.v_proj(x)
+             .reshape(bsz, q_len, self.num_key_value_heads, -1)
+             .transpose(0, 2, 1, 3)
+         )
+
+         if cache is not None:
+             query_states = self.rope(query_states, offset=cache.offset)
+             key_states = self.rope(key_states, offset=cache.offset)
+             key_states, value_states = cache.update_and_fetch(key_states, value_states)
+         else:
+             query_states = self.rope(query_states)
+             key_states = self.rope(key_states)
+
+         attn_output = scaled_dot_product_attention(
+             query_states, key_states, value_states, cache, scale=self.scale, mask=mask
+         )
+         attn_output = attn_output.transpose(0, 2, 1, 3).reshape(
+             bsz, q_len, self.hidden_size
+         )
+         return self.o_proj(attn_output)
+
+
+ class MllamaTextMLP(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.gate_proj = nn.Linear(
+             config.hidden_size, config.intermediate_size, bias=False
+         )
+         self.up_proj = nn.Linear(
+             config.hidden_size, config.intermediate_size, bias=False
+         )
+         self.down_proj = nn.Linear(
+             config.intermediate_size, config.hidden_size, bias=False
+         )
+         self.act_fn = lambda x: x * mx.sigmoid(x)
+
+     def __call__(self, x):
+         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+
+ class MllamaSelfAttentionDecoderLayer(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.self_attn = MllamaTextSelfAttention(config, layer_idx=layer_idx)
+         self.mlp = MllamaTextMLP(config)
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states = self.self_attn(
+             x=hidden_states,
+             mask=mask,
+             cache=cache,
+         )
+         hidden_states = residual + hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return hidden_states
+
+
+ class MllamaCrossAttentionDecoderLayer(nn.Module):
+     def __init__(self, config: TextConfig, layer_idx: int):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.cross_attn = MllamaTextCrossAttention(config, layer_idx=layer_idx)
+         self.mlp = MllamaTextMLP(config)
+         self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = nn.RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+         self.cross_attn_attn_gate = mx.zeros(1)
+         self.cross_attn_mlp_gate = mx.zeros(1)
+
+     def __call__(
+         self,
+         hidden_states: mx.array,
+         cross_attention_states: mx.array,
+         attention_mask: Optional[mx.array] = None,
+         full_text_row_masked_out_mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+         hidden_states = self.cross_attn(
+             hidden_states=hidden_states,
+             cross_attention_states=cross_attention_states,
+             attention_mask=attention_mask,
+             cache=cache,
+         )
+         hidden_states = residual + mx.tanh(self.cross_attn_attn_gate) * hidden_states
+
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         if full_text_row_masked_out_mask is not None:
+             hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states
+         hidden_states = residual + mx.tanh(self.cross_attn_mlp_gate) * hidden_states
+
+         return hidden_states
+
+
+ class MllamaTextModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.hidden_size = config.hidden_size
+
+         self.embed_tokens = nn.Embedding(config.vocab_size + 8, config.hidden_size)
+         self.layers = [
+             (
+                 MllamaCrossAttentionDecoderLayer(config, layer_idx)
+                 if layer_idx in config.cross_attention_layers
+                 else MllamaSelfAttentionDecoderLayer(config, layer_idx)
+             )
+             for layer_idx in range(config.num_hidden_layers)
+         ]
+         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def __call__(
+         self,
+         input_ids: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         position_ids: Optional[mx.array] = None,
+         cross_attention_states: Optional[mx.array] = None,
+         cross_attention_mask: Optional[mx.array] = None,
+         full_text_row_masked_out_mask: Optional[mx.array] = None,
+         inputs_embeds: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+     ) -> mx.array:
+         # Prioritize inputs_embeds if provided
+         if inputs_embeds is not None:
+             batch_size, seq_length, _ = inputs_embeds.shape
+         elif input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+             inputs_embeds = self.embed_tokens(input_ids)
+         else:
+             raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+         if position_ids is None:
+             position_ids = mx.expand_dims(mx.arange(seq_length), 0)
+             position_ids = mx.repeat(position_ids, batch_size, axis=0)
+
+         hidden_states = inputs_embeds
+
+         if cache is None:
+             cache = [None] * len(self.layers)
+
+         if mask is None:
+             mask = create_attention_mask(hidden_states, cache)
+
+         for idx, (decoder_layer, c) in enumerate(zip(self.layers, cache)):
+             if idx in self.config.cross_attention_layers:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     cross_attention_states=cross_attention_states,
+                     attention_mask=cross_attention_mask,
+                     full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                     cache=c,
+                 )
+             else:
+                 layer_outputs = decoder_layer(
+                     hidden_states,
+                     mask=mask,
+                     cache=c,
+                 )
+             hidden_states = layer_outputs
+
+         hidden_states = self.norm(hidden_states)
+
+         return hidden_states
+
+
+ class LanguageModel(nn.Module):
+     def __init__(self, config: TextConfig):
+         super().__init__()
+         self.config = config
+         self.model = MllamaTextModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+     def __call__(
+         self,
+         inputs: Optional[mx.array] = None,
+         inputs_embeds: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         cross_attention_states: Optional[mx.array] = None,
+         cross_attention_mask: Optional[mx.array] = None,
+         full_text_row_masked_out_mask: Optional[mx.array] = None,
+         **kwargs,
+     ) -> Tuple[mx.array, Optional[mx.array]]:
+
+         hidden_states = self.model(
+             input_ids=inputs,
+             mask=mask,
+             cross_attention_states=cross_attention_states,
+             cross_attention_mask=cross_attention_mask,
+             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+             inputs_embeds=inputs_embeds,
+             cache=cache,
+         )
+
+         logits = self.lm_head(hidden_states)
+
+         return LanguageModelOutput(
+             logits=logits, cross_attention_states=cross_attention_states
+         )
+
+     @staticmethod
+     def sanitize(weights):
+         # Remove unused precomputed rotary freqs
+         return {
+             k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+         }
+
+     @property
+     def layers(self):
+         return self.model.layers
+
+     @property
+     def head_dim(self):
+         return self.config.hidden_size // self.config.num_attention_heads
+
+     @property
+     def n_kv_heads(self):
+         return self.config.num_key_value_heads
mlx_vlm/models/mllama/mllama.py
@@ -0,0 +1,210 @@
+ from typing import Optional, Tuple
+
+ import mlx.core as mx
+ import mlx.nn as nn
+
+ from ..base import InputEmbeddingsFeatures
+ from ..cache import KVCache
+ from .config import ModelConfig
+ from .language import LanguageModel
+ from .vision import VisionModel
+
+
+ class Model(nn.Module):
+     def __init__(self, config: ModelConfig):
+         super().__init__()
+         self.config = config
+         self.vision_tower = VisionModel(config.vision_config)
+         self.language_model = LanguageModel(config.text_config)
+         self.multi_modal_projector = nn.Linear(
+             config.vision_config.vision_output_dim,
+             config.text_config.hidden_size,
+             bias=True,
+         )
+
+     @property
+     def layers(self):
+         return self.language_model.model.layers
+
+     def get_input_embeddings(
+         self,
+         input_ids: Optional[mx.array] = None,
+         pixel_values: Optional[mx.array] = None,
+         **kwargs,
+     ):
+         aspect_ratio_ids = kwargs.get("aspect_ratio_ids", None)
+         aspect_ratio_mask = kwargs.get("aspect_ratio_mask", None)
+         cross_attention_mask = kwargs.get("cross_attention_mask", None)
+
+         # Get text embeddings
+         inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+         cross_attention_states = None
+         full_text_row_masked_out_mask = None
+
+         # Process vision input if provided
+         if pixel_values is not None:
+             if aspect_ratio_ids is None:
+                 raise ValueError(
+                     "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
+                 )
+
+             vision_outputs = self.vision_tower(
+                 pixel_values=pixel_values,
+                 aspect_ratio_ids=aspect_ratio_ids,
+                 aspect_ratio_mask=aspect_ratio_mask,
+             )
+             cross_attention_states = vision_outputs[0]
+
+             cross_attention_states = self.multi_modal_projector(
+                 cross_attention_states
+             ).reshape(
+                 -1,
+                 cross_attention_states.shape[-2],
+                 self.config.text_config.hidden_size,
+             )
+
+         # Prepare cross attention mask
+         if cross_attention_mask is not None:
+             cross_attention_mask, full_text_row_masked_out_mask = (
+                 self._prepare_cross_attention_mask(
+                     cross_attention_mask,
+                     num_vision_tokens=(
+                         self.config.vision_config.image_size
+                         // self.config.vision_config.patch_size
+                     )
+                     ** 2
+                     + 1,
+                 )
+             )
+
+             cache_position = mx.arange(input_ids.shape[1], dtype=mx.int32)
+             cross_attention_mask = cross_attention_mask[:, :, cache_position]
+             full_text_row_masked_out_mask = full_text_row_masked_out_mask[
+                 :, :, cache_position
+             ]
+
+         return InputEmbeddingsFeatures(
+             inputs_embeds=inputs_embeds,
+             cross_attention_states=cross_attention_states,
+             cross_attention_mask=cross_attention_mask,
+             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+         )
+
+     def __call__(
+         self,
+         input_ids: mx.array,
+         pixel_values: Optional[mx.array] = None,
+         mask: Optional[mx.array] = None,
+         cache: Optional[KVCache] = None,
+         **kwargs,
+     ) -> Tuple[mx.array, Optional[mx.array]]:
+
+         aspect_ratio_ids = kwargs.pop("aspect_ratio_ids", None)
+         aspect_ratio_mask = kwargs.pop("aspect_ratio_mask", None)
+         cross_attention_mask = kwargs.pop("cross_attention_mask", None)
+
+         inputs_embeds = None
+
+         # Process vision input if provided
+         if pixel_values is not None:
+             if aspect_ratio_ids is None:
+                 raise ValueError(
+                     "`aspect_ratio_ids` must be provided if `pixel_values` is provided"
+                 )
+
+             vision_outputs = self.vision_tower(
+                 pixel_values=pixel_values,
+                 aspect_ratio_ids=aspect_ratio_ids,
+                 aspect_ratio_mask=aspect_ratio_mask,
+             )
+             cross_attention_states = vision_outputs[0]
+
+             cross_attention_states = self.multi_modal_projector(
+                 cross_attention_states
+             ).reshape(
+                 -1,
+                 cross_attention_states.shape[-2],
+                 self.config.text_config.hidden_size,
+             )
+
+         else:
+             cross_attention_states = None
+
+         # Prepare cross attention mask
+         if cross_attention_mask is not None:
+             cross_attention_mask, full_text_row_masked_out_mask = (
+                 self._prepare_cross_attention_mask(
+                     cross_attention_mask,
+                     num_vision_tokens=(
+                         self.config.vision_config.image_size
+                         // self.config.vision_config.patch_size
+                     )
+                     ** 2
+                     + 1,
+                 )
+             )
+         else:
+             full_text_row_masked_out_mask = None
+
+         if cross_attention_mask is not None:
+             cache_position = mx.arange(input_ids.shape[1], dtype=mx.int32)
+             cross_attention_mask = cross_attention_mask[:, :, cache_position]
+             full_text_row_masked_out_mask = full_text_row_masked_out_mask[
+                 :, :, cache_position
+             ]
+
+         # Process language input
+         outputs = self.language_model(
+             inputs=input_ids,
+             mask=mask,
+             cross_attention_states=cross_attention_states,
+             cross_attention_mask=cross_attention_mask,
+             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+             inputs_embeds=inputs_embeds,
+             cache=cache,
+         )
+
+         return outputs
+
+     def _prepare_cross_attention_mask(
+         self,
+         cross_attention_mask: mx.array,
+         num_vision_tokens: int,
+     ) -> Tuple[mx.array, mx.array]:
+         batch_size, text_total_length, *_ = cross_attention_mask.shape
+         cross_attention_mask = mx.repeat(
+             cross_attention_mask, num_vision_tokens, axis=3
+         )
+         cross_attention_mask = cross_attention_mask.reshape(
+             batch_size, text_total_length, -1
+         )
+         cross_attention_mask = mx.expand_dims(cross_attention_mask, 1)
+
+         # Invert the mask
+         inverted_cross_attn_mask = 1.0 - cross_attention_mask
+         fill_array = mx.array(-1e9)
+         fill_array = mx.broadcast_to(fill_array, inverted_cross_attn_mask.shape)
+         cross_attention_mask = mx.where(
+             inverted_cross_attn_mask,
+             fill_array,
+             cross_attention_mask,
+         )
+
+         # Apply full-row bias
+         full_text_row_masked_out_mask = mx.any(
+             cross_attention_mask != -1e9,
+             axis=-1,
+             keepdims=True,
+         )
+         cross_attention_mask *= full_text_row_masked_out_mask
+
+         return cross_attention_mask, full_text_row_masked_out_mask
+
+     def sanitize(self, weights):
+         def transform_key(key):
+             if "vision_tower" not in key:
+                 key = key.replace("vision_model", "vision_tower")
+             return key
+
+         return {transform_key(k): v for k, v in weights.items()}
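
A note on the tanh-gated cross-attention visible in mlx_vlm/models/mllama/language.py above: MllamaCrossAttentionDecoderLayer creates cross_attn_attn_gate and cross_attn_mlp_gate with mx.zeros(1) and scales each residual branch by mx.tanh(gate), so with freshly initialized gates the vision branch contributes nothing and the layer behaves like a plain text decoder layer until trained gate values are loaded. A minimal illustrative sketch of that mechanism follows; the toy shapes and random tensors are assumptions for demonstration only, not part of the package:

    import mlx.core as mx

    # Toy shapes, chosen only for illustration.
    hidden_states = mx.random.normal((1, 4, 8))   # decoder hidden states
    cross_attn_out = mx.random.normal((1, 4, 8))  # stand-in for the cross-attention output

    gate = mx.zeros(1)                            # cross_attn_attn_gate at initialization
    fused = hidden_states + mx.tanh(gate) * cross_attn_out

    # tanh(0) == 0, so the vision branch is a no-op until the gate moves away from zero.
    print(mx.allclose(fused, hidden_states))      # True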