fount_vlm_nell_02-0.3.11-py3-none-any.whl

This diff represents the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
@@ -0,0 +1,187 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import (
+    LanguageModelOutput,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from ..cache import KVCache
+from .config import TextConfig
+
+
+class Attention(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        assert args.num_key_value_heads is not None
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+
+        # Allow overriding head_dim to support architectures where
+        # n_heads * head_dim != hidden_size.
+        self.head_dim = head_dim = getattr(args, "head_dim", None) or (
+            args.hidden_size // n_heads
+        )
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=True)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=True)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        self.rotary_emb = nn.RoPE(
+            head_dim,
+            base=args.rope_theta,
+            traditional=args.rope_traditional,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        # Prepare the queries, keys and values for the attention computation
+        queries = queries.reshape(B, L, self.n_heads, self.head_dim).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
+            0, 2, 1, 3
+        )
+
+        offset = cache.offset if cache else 0
+
+        if mask is not None and isinstance(mask, mx.array):
+            mask = mask[..., : keys.shape[-2]]
+
+        queries = self.rotary_emb(queries, offset=offset)
+        keys = self.rotary_emb(keys, offset=offset)
+
+        if cache is not None:
+            keys, values = cache.update_and_fetch(keys, values)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Qwen2VLDecoderLayer(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.args = args
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[KVCache] = None,
+    ) -> mx.array:
+        r = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out
+
+
+class Qwen2Model(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        assert self.vocab_size > 0
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.layers = [
+            Qwen2VLDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
+        if inputs_embeds is None:
+            h = self.embed_tokens(inputs)
+        else:
+            h = inputs_embeds
+
+        if cache is None:
+            cache = [None] * len(self.layers)
+
+        if mask is None:
+            mask = create_attention_mask(h, cache)
+
+        for layer, c in zip(self.layers, cache):
+            h = layer(h, mask, c)
+
+        return self.norm(h)
+
+
+class LanguageModel(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = Qwen2Model(args)
+
+        if not args.tie_word_embeddings:
+            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        inputs_embeds: Optional[mx.array] = None,
+        mask: Optional[mx.array] = None,
+        cache=None,
+    ):
+        out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds)
+        if self.args.tie_word_embeddings:
+            out = self.model.embed_tokens.as_linear(out)
+        else:
+            out = self.lm_head(out)
+        return LanguageModelOutput(logits=out)
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.args.hidden_size // self.args.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.args.num_key_value_heads
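The 187-line diff above, which appears to correspond to mlx_vlm/models/internvl_chat/language.py (+187) in the file list, defines a Qwen2-style text decoder: pre-norm residual blocks around grouped-query attention and a gated MLP. The sketch below is not part of the package; it only illustrates the head-dimension bookkeeping that Attention.__call__ performs, with hypothetical batch, sequence, and head counts.

# Standalone shape sketch (not from the package) of the grouped-query attention
# reshape/transpose used in Attention.__call__. All sizes here are hypothetical.
import mlx.core as mx

B, L = 1, 8                  # batch size, sequence length
hidden_size = 64
n_heads, n_kv_heads = 8, 2   # grouped-query attention: fewer key/value heads than query heads
head_dim = hidden_size // n_heads

q = mx.random.normal((B, L, n_heads * head_dim))
k = mx.random.normal((B, L, n_kv_heads * head_dim))

# Same layout change as the module above: (B, L, H * D) -> (B, H, L, D)
q = q.reshape(B, L, n_heads, head_dim).transpose(0, 2, 1, 3)
k = k.reshape(B, L, n_kv_heads, head_dim).transpose(0, 2, 1, 3)
print(q.shape, k.shape)  # (1, 8, 8, 8) (1, 2, 8, 8)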
@@ -0,0 +1,395 @@
+from typing import List, Optional, Union
+
+import mlx.core as mx
+import numpy as np
+from PIL import Image
+from transformers import (
+    AutoImageProcessor,
+    AutoProcessor,
+    AutoTokenizer,
+    BatchFeature,
+    ProcessorMixin,
+)
+from transformers.image_processing_utils import BaseImageProcessor
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+# Constants for image processing (from internvl_chat.py)
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
+IMAGENET_STD = np.array([0.229, 0.224, 0.225])
+# chat_template = get_conv_template("internvl2_5")
+chat_template = "{% for message in messages %}{{message['role'].capitalize() + ': '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['content'] }}{% endfor %}{{'\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:\n' }}{% endif %}"
+
+IMG_START_TOKEN = "<img>"
+IMG_END_TOKEN = "</img>"
+IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+
+
+def build_transform(input_size):
+    """
+    Builds a transformation pipeline for images.
+
+    Args:
+        input_size (int): The target size for the image (height and width).
+
+    Returns:
+        function: A function that takes a PIL image and returns a normalized mx.array.
+    """
+    mean = mx.array(IMAGENET_MEAN)
+    std = mx.array(IMAGENET_STD)
+
+    def transform(img: Image.Image) -> mx.array:
+        # Ensure image is RGB
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+
+        # Resize using PIL - BICUBIC interpolation is default in Pillow >= 9.1.0 for resize
+        # For older versions, you might need Pillow-SIMD or explicitly set
+        # resampling=Image.BICUBIC if available.
+        img = img.resize((input_size, input_size), resample=Image.Resampling.BICUBIC)
+
+        # Convert PIL image to NumPy array (H, W, C) and scale to [0, 1]
+        img_np = np.array(img).astype(np.float32) / 255.0
+
+        # Convert to MLX array and transpose to (C, H, W)
+        img_mx = mx.array(img_np).transpose(2, 0, 1)
+
+        # Normalize
+        img_mx = (img_mx - mean[:, None, None]) / std[:, None, None]
+
+        return img_mx
+
+    return transform
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    """Finds the closest aspect ratio from a list of targets."""
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            # Prioritize ratios closer to the original image area if diffs are equal
+            target_area = image_size * image_size * ratio[0] * ratio[1]
+            if abs(area - target_area) < abs(
+                area - image_size * image_size * best_ratio[0] * best_ratio[1]
+            ):
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(
+    image: Image.Image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+):
+    """
+    Preprocesses the image by splitting it into blocks based on the closest aspect ratio.
+
+    Args:
+        image (PIL.Image.Image): Input image.
+        min_num (int): Minimum number of blocks.
+        max_num (int): Maximum number of blocks.
+        image_size (int): Target size for each block.
+        use_thumbnail (bool): Whether to include a thumbnail of the original image.
+
+    Returns:
+        list[PIL.Image.Image]: A list of processed image blocks (as PIL images).
+    """
+    orig_width, orig_height = image.size
+    if orig_width == 0 or orig_height == 0:
+        # Handle potential zero dimensions
+        return []
+
+    aspect_ratio = orig_width / orig_height
+
+    # Calculate the possible target aspect ratios
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if min_num <= i * j <= max_num
+    )
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # Find the closest target aspect ratio
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # Calculate the target dimensions for resizing
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # Resize the image to fit the target block structure
+    # Using BICUBIC resampling
+    resized_img = image.resize(
+        (target_width, target_height), resample=Image.Resampling.BICUBIC
+    )
+
+    processed_images = []
+    # Crop the resized image into blocks
+    for i in range(blocks):
+        # Calculate crop box for the i-th block
+        row_idx = i // target_aspect_ratio[0]
+        col_idx = i % target_aspect_ratio[0]
+        left = col_idx * image_size
+        top = row_idx * image_size
+        right = (col_idx + 1) * image_size
+        bottom = (row_idx + 1) * image_size
+        box = (left, top, right, bottom)
+
+        # Crop and add the block
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+
+    assert (
+        len(processed_images) == blocks
+    ), f"Expected {blocks} blocks, but got {len(processed_images)}"
+
+    # Add a thumbnail if requested and if the image was split
+    if use_thumbnail and blocks > 1:
+        thumbnail_img = image.resize(
+            (image_size, image_size), resample=Image.Resampling.BICUBIC
+        )
+        processed_images.append(thumbnail_img)
+
+    return processed_images
+
+
+class InternVLImageProcessor(BaseImageProcessor):
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: int = 448,  # Default image size from dynamic_preprocess
+        resample=Image.Resampling.BICUBIC,
+        do_center_crop: bool = False,  # Not used in original, but standard HF param
+        crop_size=None,
+        do_rescale: bool = True,  # Original code scales by 1/255.0
+        rescale_factor: float = 1 / 255.0,
+        do_normalize: bool = True,
+        image_mean=IMAGENET_MEAN.tolist(),
+        image_std=IMAGENET_STD.tolist(),
+        do_dynamic_preprocess: bool = True,
+        dynamic_min_num: int = 1,
+        dynamic_max_num: int = 12,
+        dynamic_use_thumbnail: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.do_resize = (
+            do_resize  # Although dynamic_preprocess handles resizing internally
+        )
+        self.size = size
+        self.resample = resample
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean
+        self.image_std = image_std
+        # Custom dynamic processing params
+        self.do_dynamic_preprocess = do_dynamic_preprocess
+        self.dynamic_min_num = dynamic_min_num
+        self.dynamic_max_num = dynamic_max_num
+        self.dynamic_use_thumbnail = dynamic_use_thumbnail
+
+    def preprocess(
+        self,
+        images: List[Image.Image],
+        do_dynamic_preprocess: Optional[bool] = None,
+        size: Optional[int] = None,
+        # ... other params matching __init__ ...
+        return_tensors: Optional[str] = None,
+        **kwargs,
+    ) -> List[mx.array]:
+
+        do_dynamic_preprocess = (
+            do_dynamic_preprocess
+            if do_dynamic_preprocess is not None
+            else self.do_dynamic_preprocess
+        )
+        size = size if size is not None else self.size
+        # ... handle other overrides ...
+
+        if not isinstance(images, list):
+            images = [images]
+
+        if not all(isinstance(image, Image.Image) for image in images):
+            raise ValueError("Input must be a list of PIL Images.")
+
+        processed_images_batch = []
+        for image in images:
+            # Apply dynamic preprocessing
+            if do_dynamic_preprocess:
+                processed_images = dynamic_preprocess(
+                    image,
+                    min_num=self.dynamic_min_num,
+                    max_num=self.dynamic_max_num,
+                    image_size=size,
+                    use_thumbnail=self.dynamic_use_thumbnail,
+                )
+            else:
+                # Fallback or alternative simpler preprocessing if needed
+                # e.g., simple resize + normalize
+                processed_images = [image.resize((size, size), resample=self.resample)]
+
+            # Create transform function
+            transform = build_transform(input_size=size)
+
+            # Apply transform to each image block and collect arrays
+            pixel_values_list = [transform(img) for img in processed_images]
+
+            # Stack the arrays along a new dimension (batch dimension)
+            pixel_values = mx.stack(pixel_values_list, axis=0)
+
+            processed_images_batch.append(pixel_values)
+
+        # At this point, processed_images_batch contains a list of mx arrays,
+        # each array corresponding to an input image with stacked blocks.
+
+        data = {"pixel_values": mx.array(processed_images_batch)}
+        return BatchFeature(data=data, tensor_type=None)
+
+
+class InternVLChatProcessor(ProcessorMixin):
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "InternVLImageProcessor"
+    tokenizer_class = (
+        "AutoTokenizer",
+        "Qwen2TokenizerFast",
+    )  # Specify possible classes
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        chat_template=chat_template,
+        **kwargs,
+    ):
+        if image_processor is None:
+            image_processor = InternVLImageProcessor(**kwargs)
+        if isinstance(tokenizer, str):
+            # Defaulting to the likely repo ID found earlier
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer, trust_remote_code=True, **kwargs
+            )
+
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+        self.num_image_token = int((448 // 14) ** 2 * (0.5**2))
+
+    def __call__(
+        self,
+        text: Union[str, List[str]] = None,
+        images: List[Image.Image] = None,
+        padding: Union[bool, str] = True,
+        truncation: bool = True,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[str] = "pt",  # Default to PyTorch tensors
+        **kwargs,
+    ):
+        processed_inputs = {}
+        if text is not None:
+            if isinstance(text, str):
+                text = [text]
+
+            if len(text) == 1 and images is not None and len(images) > 1:
+                raise ValueError("Multi-image inference is not supported.")
+
+        if images is not None:
+            image_features = self.image_processor.preprocess(
+                images, return_tensors=return_tensors, **kwargs
+            )
+            processed_inputs.update(image_features)  # Should contain 'pixel_values'
+
+        if text is not None:
+            queries = []
+
+            for idx in range(len(images)):
+                question = text[idx]
+
+                if images is not None and "<image>" not in question:
+                    question = "<image>\n" + question
+
+                num_patches = image_features["pixel_values"][idx].shape[0]
+                image_tokens = (
+                    IMG_START_TOKEN
+                    + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                    + IMG_END_TOKEN
+                )
+                question = question.replace("<image>", image_tokens, 1)
+                queries.append(question)
+
+            self.tokenizer.padding_side = "left"
+            text_inputs = self.tokenizer(
+                queries,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                **kwargs,
+            )
+            processed_inputs.update(text_inputs)  # 'input_ids', 'attention_mask'
+
+        return processed_inputs
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to the tokenizer's batch_decode method.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to the tokenizer's decode method.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def save_pretrained(self, save_directory, **kwargs):
+        pass
+
+    @staticmethod
+    def from_pretrained(pretrained_model_name_or_path, **kwargs):
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path, **kwargs
+        )
+        image_processor = InternVLImageProcessor(**kwargs)
+        return InternVLChatProcessor(
+            image_processor=image_processor, tokenizer=tokenizer
+        )
+
+    # Need save_pretrained and from_pretrained
+    # save_pretrained should save both tokenizer and image_processor configs/files
+    # from_pretrained should load both
+
+    # Example:
+    # def save_pretrained(self, save_directory, **kwargs):
+    #     self.tokenizer.save_pretrained(save_directory, **kwargs)
+    #     self.image_processor.save_pretrained(save_directory, **kwargs)
+
+    # def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+    #     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+    #     image_processor = InternVLImageProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+    #     return cls(image_processor=image_processor, tokenizer=tokenizer)
+
+
+# Registration
+MODEL_TYPE = "internvl_chat"  # Verify this from the model's config.json
+
+AutoImageProcessor.register(
+    MODEL_TYPE, slow_image_processor_class=InternVLImageProcessor
+)
+AutoProcessor.register(MODEL_TYPE, InternVLChatProcessor)
+
+logger.info(f"Registered custom processor classes for model type '{MODEL_TYPE}'.")