fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
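
Note that while the distribution is named fount-vlm-nell-02, everything in the wheel installs under a single top-level mlx_vlm package (see top_level.txt and the module paths above), so the import name differs from the install name. A minimal usage sketch follows; the __version__ attribute is an assumption based on the presence of mlx_vlm/version.py, whose contents are not shown in this diff:

    # Hypothetical usage: the distribution name and the import name differ.
    # Installed via: pip install fount-vlm-nell-02==0.3.11
    import mlx_vlm  # top-level package provided by this wheel

    print(getattr(mlx_vlm, "__version__", "unknown"))  # assumed to come from mlx_vlm/version.py

The hunks below appear to correspond, by their added-line counts and contents, to entries 109-113 above: idefics3/language.py (+157), idefics3/vision.py (+265), and the internvl_chat __init__.py (+3), config.py (+89), and internvl_chat.py (+115).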
@@ -0,0 +1,157 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from ..cache import KVCache
+from .config import TextConfig
+
+
+class Attention(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+
+        dim = config.hidden_size
+        self.n_heads = n_heads = config.num_attention_heads
+        head_dim = config.hidden_size // n_heads
+        self.n_kv_heads = n_kv_heads = config.num_key_value_heads
+
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        self.rope = nn.RoPE(
+            head_dim,
+            traditional=config.rope_traditional,
+            base=config.rope_theta,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.num_attention_heads = config.num_attention_heads
+        self.hidden_size = config.hidden_size
+        self.self_attn = Attention(config)
+        self.mlp = MLP(config.hidden_size, config.intermediate_size)
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.config = config
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, config: TextConfig):
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        self.vocab_size = config.vocab_size
+        self.num_hidden_layers = config.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.layers = [
+            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+        **kwargs,
+    ):
+        # for passing merged input embeddings
+        if inputs_embeds is None:
+            h = self.embed_tokens(inputs)
+        else:
+            h = inputs_embeds.astype(self.norm.weight.dtype)
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        logits = self.lm_head(self.norm(h))
+        return LanguageModelOutput(logits=logits)
+
+    def sanitize(self, weights):
+        # Remove unused precomputed rotary freqs
+        return {
+            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
+        }
+
+    @property
+    def head_dim(self):
+        return self.config.hidden_size // self.config.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.config.num_key_value_heads
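
The Attention block above projects to (B, L, n_heads * head_dim), splits the heads for the attention kernel, and flattens back before o_proj; on the cached path, RoPE is applied with the cache offset so positions continue across decode steps. Below is a minimal sketch of the reshape/transpose round trip, using made-up dimensions (B=1, L=8, n_heads=4, head_dim=16) that are not taken from any real config:

    import mlx.core as mx

    # Hypothetical sizes, chosen only to illustrate the shape bookkeeping.
    B, L, n_heads, head_dim = 1, 8, 4, 16
    x = mx.random.normal((B, L, n_heads * head_dim))

    # (B, L, n_heads * head_dim) -> (B, n_heads, L, head_dim), as in Attention.__call__
    heads_first = x.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)
    assert tuple(heads_first.shape) == (B, n_heads, L, head_dim)

    # ...the attention kernel runs on the heads-first layout...

    # back to (B, L, n_heads * head_dim) before the o_proj linear
    merged = heads_first.transpose(0, 2, 1, 3).reshape(B, L, -1)
    assert tuple(merged.shape) == (B, L, n_heads * head_dim)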
@@ -0,0 +1,265 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = True,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        head_dim = dims // num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, x, mask=None):
+        queries = self.q_proj(x)
+        keys = self.k_proj(x)
+        values = self.v_proj(x)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.out_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.activation_fn = nn.GELU(approx="precise")
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.fc1(x)
+        x = self.activation_fn(x)
+        x = self.fc2(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = Attention(
+            config.hidden_size, config.num_attention_heads, bias=True
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        r = self.self_attn(self.layer_norm1(x), mask)
+        h = x + r
+        r = self.mlp(self.layer_norm2(h))
+        return h + r
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+        mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        encoder_states = (x,) if output_hidden_states else None
+        h = x
+        for l in self.layers:
+            x = l(x, mask=mask)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)
+
+            h = x
+
+        return (h, encoder_states)
+
+
+class VisionEmbeddings(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+        )
+
+        self.num_patches_per_side = self.image_size // self.patch_size
+        self.num_patches = self.num_patches_per_side**2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+    def __call__(self, x: mx.array, patch_attention_mask: mx.array = None) -> mx.array:
+        batch_size, max_im_h, max_im_w, _ = x.shape
+        patch_embeds = self.patch_embedding(x)
+        embeddings = mx.flatten(patch_embeds, start_axis=1, end_axis=2)

+        seq_len = embeddings.shape[1]
+
+        if patch_attention_mask is None:
+            position_ids = mx.tile(mx.arange(seq_len), (batch_size, 1))
+        else:
+            boundaries = mx.arange(
+                1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
+            )
+
+            # Flatten mask to match sequence length (handles both (B,H,W) and (B,H,W,1))
+            if patch_attention_mask.ndim == 4:
+                flat_mask = patch_attention_mask.squeeze(-1).reshape(batch_size, -1)[
+                    :, :seq_len
+                ]
+            else:
+                flat_mask = patch_attention_mask.reshape(batch_size, -1)[:, :seq_len]
+
+            # Compute valid patches per image (channels-last indexing)
+            nb_patches_h = mx.maximum(patch_attention_mask[:, :, 0].sum(axis=1), 1)
+            nb_patches_w = mx.maximum(patch_attention_mask[:, 0, :].sum(axis=1), 1)
+
+            position_ids = mx.zeros((batch_size, seq_len), dtype=mx.int32)
+
+            for batch_idx in range(batch_size):
+                nb_h = int(nb_patches_h[batch_idx])
+                nb_w = int(nb_patches_w[batch_idx])
+
+                # Compute fractional coordinates
+                fractional_h = mx.arange(nb_h, dtype=mx.float32) / nb_h
+                fractional_w = mx.arange(nb_w, dtype=mx.float32) / nb_w
+                fractional_h = mx.clip(fractional_h, a_min=0.0, a_max=1.0 - 1e-6)
+                fractional_w = mx.clip(fractional_w, a_min=0.0, a_max=1.0 - 1e-6)
+
+                # Bucket into position IDs
+                bucket_h = mx.sum(fractional_h[:, None] >= boundaries[None, :], axis=1)
+                bucket_w = mx.sum(fractional_w[:, None] >= boundaries[None, :], axis=1)
+
+                # Create 2D grid: iterate over height, then width (row-major)
+                pos_ids = (
+                    bucket_h[:, None] * self.num_patches_per_side + bucket_w[None, :]
+                ).reshape(-1)
+
+                valid_len = min(pos_ids.shape[0], seq_len)
+                position_ids[batch_idx, :valid_len] = pos_ids[:valid_len]
+
+            # Zero out position embeddings for padding
+            mask_expanded = flat_mask[:, :, None]  # (batch, seq_len, 1)
+
+        pos_embeddings = self.position_embedding(position_ids)
+
+        # Apply mask to zero out padding position embeddings
+        if patch_attention_mask is not None:
+            pos_embeddings = pos_embeddings * mask_expanded
+
+        embeddings = embeddings + pos_embeddings
+        return embeddings
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        if self.model_type not in [
+            "siglip_vision_model",
+            "idefics3",
+            "idefics3_vision",
+            "smolvlm_vision",
+        ]:
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        self.embeddings = VisionEmbeddings(config)
+        self.encoder = Encoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        patch_attention_mask: Optional[mx.array] = None,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+        x = self.embeddings(x, patch_attention_mask)
+        x = x.astype(self.embeddings.patch_embedding.weight.dtype)
+        encoder_outputs = self.encoder(
+            x=x, output_hidden_states=output_hidden_states, mask=None
+        )
+        pooler_output = self.post_layernorm(encoder_outputs[0])
+        return pooler_output, x, encoder_outputs[-1]
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embedding.weight" in k:
+                # PyTorch conv2d weight tensors have shape:
+                #   [out_channels, in_channels, kH, KW]
+                # MLX conv2d expects the weight to be of shape:
+                #   [out_channels, kH, KW, in_channels]
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
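
The sanitize method above exists because exported checkpoints store convolution weights in PyTorch's [out_channels, in_channels, kH, kW] layout, while MLX's Conv2d expects channels-last [out_channels, kH, kW, in_channels]; check_array_shape is the heuristic that skips tensors that already look converted. A minimal sketch with hypothetical sizes (a 14x14 patch embedding over 3 input channels, not taken from any real checkpoint):

    import mlx.core as mx

    # Hypothetical PyTorch-layout patch embedding weight: [O, I, kH, kW]
    pt_weight = mx.zeros((1152, 3, 14, 14))

    # check_array_shape (defined above) returns False for this layout because the
    # values in the kH/kW slots (3 and 14) differ, so sanitize() transposes the axes.
    mlx_weight = pt_weight.transpose(0, 2, 3, 1)
    print(tuple(mlx_weight.shape))  # (1152, 14, 14, 3) -> [O, kH, kW, I]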
@@ -0,0 +1,3 @@
+from .config import ModelConfig, TextConfig, VisionConfig
+from .internvl_chat import LanguageModel, Model, VisionModel
+from .processor import InternVLChatProcessor, InternVLImageProcessor
@@ -0,0 +1,89 @@
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Union
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    model_type: str
+    hidden_size: int = 1024
+    num_attention_heads: int = 16
+    patch_size: int = 14
+    num_hidden_layers: int = 24
+    intermediate_size: int = 4096
+    image_size: int = 448
+    num_channels: int = 3
+    layer_norm_eps: float = 1e-6
+    drop_path_rate: float = 0.1
+    qkv_bias: bool = True
+    qk_normalization: bool = False
+    norm_type: str = "layer_norm"
+
+    @classmethod
+    def from_dict(cls, params):
+        def normalize_dim(v):
+            if isinstance(v, (list, tuple)):
+                if len(v) > 0:
+                    try:
+                        return int(v[0])
+                    except Exception:
+                        return v[0]
+            try:
+                return int(v)
+            except Exception:
+                return v
+
+        p = dict(params)
+        if "image_size" in p:
+            p["image_size"] = normalize_dim(p["image_size"])
+        if "patch_size" in p:
+            p["patch_size"] = normalize_dim(p["patch_size"])
+
+        import inspect as _inspect
+
+        return cls(
+            **{k: v for k, v in p.items() if k in _inspect.signature(cls).parameters}
+        )
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str
+    hidden_size: int
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+    vocab_size: int
+    max_window_layers: int
+    hidden_act: str
+    num_key_value_heads: Optional[int] = 8
+    head_dim: Optional[int] = None
+    max_position_embeddings: Optional[int] = 40960
+    rope_theta: float = 1000000.0
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    tie_word_embeddings: bool = False
+    sliding_window: int = 32768
+    use_sliding_window: bool = False
+    use_cache: bool = True
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig
+    vision_config: VisionConfig
+    model_type: str
+    ignore_index: int = -100
+    image_token_index: int = 151667
+    video_token_index: int = 151656
+    vision_feature_select_strategy: str = "default"
+    vision_feature_layer: int = -1
+    vocab_size: int = 32000
+    downsample_ratio: float = 0.5
+    eos_token_id: Optional[List[int]] = None
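
VisionConfig.from_dict above does two things worth noting: it collapses list-valued image_size/patch_size entries (as some upstream configs ship them) into plain ints via normalize_dim, and it drops any key that is not a dataclass field by filtering against inspect.signature. A minimal sketch with a made-up config dict; the import path assumes this hunk is mlx_vlm/models/internvl_chat/config.py, the +89 entry in the file list:

    from mlx_vlm.models.internvl_chat.config import VisionConfig  # assumed path

    cfg = VisionConfig.from_dict(
        {
            "model_type": "intern_vit",  # placeholder; any string works for this sketch
            "image_size": [448, 448],    # list value normalized to the int 448
            "patch_size": 14,
            "unrecognized_key": "x",     # not a dataclass field, filtered out
        }
    )
    print(cfg.image_size, cfg.patch_size)  # 448 14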
@@ -0,0 +1,115 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import InputEmbeddingsFeatures, pixel_shuffle
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import VisionModel
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.vision_model = VisionModel(config.vision_config)
+        self.language_model = LanguageModel(config.text_config)
+
+        self.downsample_ratio = config.downsample_ratio
+
+        vit_hidden_size = self.config.vision_config.hidden_size
+        llm_hidden_size = self.config.text_config.hidden_size
+
+        self.mlp1 = [
+            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+            nn.Linear(
+                vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
+            ),
+            nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size),
+        ]
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        dtype = self.vision_model.embeddings.patch_embedding.weight.dtype
+        pixel_values = pixel_values.astype(dtype)
+
+        # TODO: Remove this after transformers implementation is merged
+        if pixel_values.ndim == 5:
+            pixel_values = pixel_values[0]
+
+        # Get the input embeddings from the language model
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # Get the output hidden states from the vision model
+        hidden_states, _, _ = self.vision_model(
+            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
+        )
+
+        # Extract vision embeddings, removing the class token (first token)
+        hidden_states = hidden_states[:, 1:, :]
+
+        # Apply pixel shuffle with downsampling
+        hidden_states = pixel_shuffle(
+            hidden_states, shuffle_ratio=self.downsample_ratio
+        )
+
+        # Apply MLP transformation
+        for layer in self.mlp1:
+            hidden_states = layer(hidden_states)
+
+        # Insert special image tokens in the input_ids
+        final_inputs_embeds = self._merge_input_ids_with_image_features(
+            hidden_states, inputs_embeds, input_ids
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    def _merge_input_ids_with_image_features(
+        self, image_features, inputs_embeds, input_ids
+    ):
+        B, N, C = inputs_embeds.shape
+        image_token_index = self.config.image_token_index
+        video_token_index = self.config.video_token_index
+
+        # Positions of <image> tokens in input_ids, assuming batch size is 1
+        image_positions = input_ids == image_token_index
+        if mx.sum(image_positions) == 0:
+            image_positions = input_ids == video_token_index
+
+        image_indices = np.where(image_positions)[1].tolist()
+
+        image_features = image_features.reshape(-1, image_features.shape[-1])
+
+        inputs_embeds[:, image_indices, :] = image_features
+
+        return inputs_embeds.reshape(B, N, C)
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
+    ):
+        input_embeddings_features = self.get_input_embeddings(input_ids, pixel_values)
+        logits = self.language_model(
+            None, cache=cache, inputs_embeds=input_embeddings_features.inputs_embeds
+        )
+        return logits
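
The projector in Model.__init__ above sizes its first LayerNorm and Linear at vit_hidden_size * int(1 / downsample_ratio) ** 2 because pixel_shuffle (imported from ..base, not shown in this diff) trades spatial resolution for channels; with the default downsample_ratio of 0.5 it appears to fold a 2x2 neighborhood of vision tokens into one, quadrupling the feature width while cutting the token count by four. A minimal sketch of that arithmetic using the defaults shown above (vision hidden_size 1024, downsample_ratio 0.5); the token count of 1024 is a made-up illustration:

    # Projector width arithmetic from Model.__init__ above.
    downsample_ratio = 0.5  # ModelConfig default
    vit_hidden_size = 1024  # VisionConfig default

    projector_in = vit_hidden_size * int(1 / downsample_ratio) ** 2
    print(projector_in)  # 4096: input width of the first LayerNorm/Linear in mlp1

    # Token count drops by the same factor, e.g. a hypothetical 1024 vision tokens:
    print(1024 // int(1 / downsample_ratio) ** 2)  # 256 tokens after pixel shuffle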