fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/smolvlm/smolvlm.py
@@ -0,0 +1,60 @@
+ import mlx.core as mx
+ import numpy as np
+
+ from ..idefics3 import Model as Idefics3Model
+
+
+ class Model(Idefics3Model):
+     def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
+         # Assumes bs == 1
+
+         B, T, D_text = inputs_embeds.shape
+         N, S, D_img = image_features.shape
+
+         image_offset = 0
+         cur_embeds = inputs_embeds[0]
+
+         # Find positions of <image> tokens in the text
+         image_token_index = self.config.image_token_index
+         image_positions = np.where(input_ids == image_token_index)[1].tolist()
+         num_image_tokens = len(image_positions)
+
+         # If no <image> => text-only
+         if num_image_tokens == 0:
+             empty_slice = image_features[0][:0, :]  # shape (0, D)
+             return mx.concatenate([cur_embeds, empty_slice], axis=0)
+
+         # Typically, if each image is S embeddings, we expect the total # of <image> tokens
+         # in this sample to be multiple of S => each group of S tokens = 1 image
+         if num_image_tokens % S != 0:
+             raise ValueError(
+                 f"Input has {num_image_tokens} <image> tokens, not a multiple of S={S}. "
+                 "Cannot map them to blocks of shape (S, D)."
+             )
+
+         chunks = [image_positions[i : i + S] for i in range(0, num_image_tokens, S)]
+
+         segments = []
+         text_start = 0
+
+         # For each chunk (each chunk => 1 image)
+         for chunk in chunks:
+             cur_block = image_features[image_offset]
+             image_offset += 1
+
+             # We'll iterate over the S positions in ascending order
+             for i_s, pos in enumerate(chunk):
+                 if pos > text_start:
+                     segments.append(cur_embeds[text_start:pos])
+                 # Then add one row from cur_block => shape (1, D)
+                 row_of_block = cur_block[i_s : i_s + 1, :]
+                 segments.append(row_of_block)
+                 text_start = pos + 1
+
+         # leftover text after the final <image> token
+         if text_start < T:
+             segments.append(cur_embeds[text_start:])
+
+         # cat them into a single (T_b, D) tensor
+         merged_sample = mx.concatenate(segments, axis=0)
+         return mx.expand_dims(merged_sample, axis=0)
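
For orientation, a small self-contained sketch of the splice performed by _prepare_inputs_for_multimodal above (not part of the package; the toy shapes and token id 9 standing in for <image> are assumptions): text embeddings are copied through, and each <image> position is replaced by one row of image features.

import mlx.core as mx
import numpy as np

S, D = 3, 4                                    # 3 feature rows per image, embedding dim 4
inputs_embeds = mx.zeros((1, 7, D))            # 7 text positions, batch size 1
image_features = mx.ones((1, S, D))            # one image => one (S, D) block
input_ids = np.array([[5, 9, 9, 9, 6, 7, 8]])  # id 9 plays the role of <image>

positions = np.where(input_ids == 9)[1].tolist()  # [1, 2, 3]
segments, text_start = [], 0
for i_s, pos in enumerate(positions):
    if pos > text_start:
        segments.append(inputs_embeds[0][text_start:pos])   # copy preceding text span
    segments.append(image_features[0][i_s : i_s + 1, :])    # one image-feature row
    text_start = pos + 1
segments.append(inputs_embeds[0][text_start:])              # trailing text
merged = mx.concatenate(segments, axis=0)
print(merged.shape)  # (7, 4): same length, image rows in place of the <image> ids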
mlx_vlm/prompt_utils.py
@@ -0,0 +1,565 @@
+ from enum import Enum
+ from functools import partial
+ from typing import Any, Dict, List, Union
+
+ from pydantic import BaseModel
+
+
+ class MessageFormat(Enum):
+     """Enum for different message format types."""
+
+     LIST_WITH_IMAGE = "list_with_image"
+     LIST_WITH_IMAGE_FIRST = "list_with_image_first"
+     LIST_WITH_IMAGE_URL_FIRST = "list_with_image_url_first"
+     LIST_WITH_IMAGE_TYPE = "list_with_image_type"
+     LIST_WITH_IMAGE_TYPE_TEXT = "list_with_image_type_text"
+     LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST = "list_with_image_type_text_image_last"
+     IMAGE_TOKEN = "image_token"
+     IMAGE_TOKEN_PIPE = "image_token_pipe"
+     START_IMAGE_TOKEN = "start_image_token"
+     IMAGE_TOKEN_NEWLINE = "image_token_newline"
+     NUMBERED_IMAGE_TOKENS = "numbered_image_tokens"
+     PROMPT_ONLY = "prompt_only"
+     PROMPT_WITH_IMAGE_TOKEN = "prompt_with_image_token"
+     PROMPT_WITH_START_IMAGE_TOKEN = "prompt_with_start_image_token"
+     VIDEO_WITH_TEXT = "video_with_text"
+
+
+ # Model configuration mapping
+ MODEL_CONFIG = {
+     # List with image format models
+     "jina_vlm": MessageFormat.IMAGE_TOKEN_PIPE,
+     "jvlm": MessageFormat.IMAGE_TOKEN_PIPE,
+     "idefics2": MessageFormat.LIST_WITH_IMAGE,
+     "idefics3": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "lfm2-vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "lfm2_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "aya_vision": MessageFormat.LIST_WITH_IMAGE,
+     "cohere2_vision": MessageFormat.LIST_WITH_IMAGE,
+     "paddleocr_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "qwen2_vl": MessageFormat.LIST_WITH_IMAGE,
+     "qwen2_5_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "qwen3_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "qwen3_vl_moe": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "mistral3": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "glm4v": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "glm4v_moe": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "glm_ocr": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "ernie4_5_moe_vl": MessageFormat.LIST_WITH_IMAGE_URL_FIRST,
+     "internvl_chat": MessageFormat.LIST_WITH_IMAGE_TYPE,
+     "kimi_vl": MessageFormat.LIST_WITH_IMAGE,
+     "gemma3": MessageFormat.START_IMAGE_TOKEN,
+     "gemma3n": MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST,
+     "llama4": MessageFormat.LIST_WITH_IMAGE,
+     "smolvlm": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     "llava": MessageFormat.LIST_WITH_IMAGE,
+     "llava_next": MessageFormat.LIST_WITH_IMAGE,
+     "mllama": MessageFormat.LIST_WITH_IMAGE,
+     "pixtral": MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT,
+     "molmo2": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     # Token-based models
+     "llava-qwen2": MessageFormat.IMAGE_TOKEN_NEWLINE,
+     "llava_qwen2": MessageFormat.IMAGE_TOKEN_NEWLINE,  # fastvlm
+     "bunny-llama": MessageFormat.IMAGE_TOKEN_NEWLINE,
+     "phi3_v": MessageFormat.NUMBERED_IMAGE_TOKENS,
+     "multi_modality": MessageFormat.IMAGE_TOKEN,
+     "deepseek_vl_v2": MessageFormat.IMAGE_TOKEN_NEWLINE,
+     "deepseekocr_2": MessageFormat.IMAGE_TOKEN_NEWLINE,
+     "deepseekocr": MessageFormat.IMAGE_TOKEN_NEWLINE,
+     "hunyuan_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
+     # Prompt-only models
+     "florence2": MessageFormat.PROMPT_ONLY,
+     "molmo": MessageFormat.PROMPT_ONLY,
+     "paligemma": MessageFormat.PROMPT_WITH_IMAGE_TOKEN,
+     "moondream1": MessageFormat.PROMPT_WITH_IMAGE_TOKEN,
+ }
+
+ # Models that don't support multi-image
+ SINGLE_IMAGE_ONLY_MODELS = {
+     "llava_next",
+     "llava-qwen2",
+     "bunny-llama",
+     "paligemma",
+     "multi_modality",
+     "mllama",
+     "moondream1",
+ }
+
+
+ def extract_text_from_content(content: Any) -> str:
+     """
+     Extract text from multimodal content.
+
+     When using OpenAI-compatible multimodal API, content can be a list like:
+     [
+         {"type": "text", "text": "Describe this image"},
+         {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
+     ]
+
+     This function extracts only the text parts, preventing base64 image data
+     from being tokenized as text (which would cause token explosion).
+
+     Args:
+         content: Either a string or a list of content items
+
+     Returns:
+         A string containing only the text content
+     """
+     if isinstance(content, str):
+         return content
+
+     if isinstance(content, list):
+         text_parts = []
+         for item in content:
+             if isinstance(item, dict):
+                 item_type = item.get("type", "")
+                 # Extract text from text-type items
+                 if item_type in ("text", "input_text"):
+                     text = item.get("text", "") or item.get("content", "")
+                     if text:
+                         text_parts.append(text)
+                 # Skip image_url, input_image, input_audio - these are handled separately
+         return " ".join(text_parts).strip() if text_parts else ""
+
+     # Fallback: convert to string (shouldn't happen in normal usage)
+     return str(content) if content else ""
+
+
+ class MessageBuilder:
+     """Builder for creating messages in various formats."""
+
+     @staticmethod
+     def text_message(text: str) -> Dict[str, str]:
+         """Create a simple text message."""
+         return {"type": "text", "text": text, "content": text}
+
+     @staticmethod
+     def content_message(content: str) -> Dict[str, str]:
+         """Create a content-type text message."""
+         return {"type": "text", "text": content, "content": content}
+
+     @staticmethod
+     def image_message() -> Dict[str, str]:
+         """Create an image message."""
+         return {"type": "image"}
+
+     @staticmethod
+     def image_url_message() -> Dict[str, str]:
+         """Create an image_url message (for models like ERNIE that expect this format)."""
+         return {"type": "image_url"}
+
+     @staticmethod
+     def audio_message() -> Dict[str, str]:
+         """Create an audio message."""
+         return {"type": "audio"}
+
+     @staticmethod
+     def video_message(
+         video_path: str, max_pixels: int = 224 * 224, fps: int = 1
+     ) -> Dict[str, Any]:
+         """Create a video message."""
+         return {
+             "type": "video",
+             "video": video_path,
+             "max_pixels": max_pixels,
+             "fps": fps,
+         }
+
+
+ class MessageFormatter:
+     """Handles formatting messages for different model types."""
+
+     def __init__(self, model_name: str):
+         self.model_name = model_name.lower()
+         self.format_type = MODEL_CONFIG.get(self.model_name)
+         if not self.format_type:
+             raise ValueError(f"Unsupported model: {model_name}")
+
+     def format_message(
+         self,
+         prompt: str,
+         role: str = "user",
+         skip_image_token: bool = False,
+         skip_audio_token: bool = False,
+         num_images: int = 1,
+         num_audios: int = 1,
+         **kwargs,
+     ) -> Union[str, Dict[str, Any]]:
+         """Format a message based on the model type."""
+
+         # Check multi-image support
+         if num_images > 1 and self.model_name in SINGLE_IMAGE_ONLY_MODELS:
+             raise ValueError(
+                 f"Model {self.model_name} does not support multi-image chat. "
+                 f"Please only use 1 image."
+             )
+
+         # Handle video format for specific models
+         if self.model_name in [
+             "qwen2_vl",
+             "qwen2_5_vl",
+             "qwen3_vl",
+             "qwen3_vl_moe",
+         ] and kwargs.get("video"):
+             return self._format_video_message(prompt, kwargs)
+
+         # Route to appropriate formatter
+         formatter_map = {
+             MessageFormat.LIST_WITH_IMAGE: self._format_list_with_image,
+             MessageFormat.LIST_WITH_IMAGE_FIRST: partial(
+                 self._format_list_with_image, image_first=True
+             ),
+             MessageFormat.LIST_WITH_IMAGE_URL_FIRST: partial(
+                 self._format_list_with_image, image_first=True, use_image_url=True
+             ),
+             MessageFormat.LIST_WITH_IMAGE_TYPE: self._format_list_with_image_type,
+             MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT: partial(
+                 self._format_list_with_image_type, message_type="text"
+             ),
+             MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST: partial(
+                 self._format_list_with_image_type,
+                 message_type="text",
+                 image_first=False,
+             ),
+             MessageFormat.IMAGE_TOKEN: partial(
+                 self._format_with_token, token="<image>"
+             ),
+             MessageFormat.IMAGE_TOKEN_PIPE: partial(
+                 self._format_with_token, token="<|image|>"
+             ),
+             MessageFormat.START_IMAGE_TOKEN: partial(
+                 self._format_with_token, token="<start_of_image>", image_first=False
+             ),
+             MessageFormat.IMAGE_TOKEN_NEWLINE: partial(
+                 self._format_with_token, token="<image>\n"
+             ),
+             MessageFormat.NUMBERED_IMAGE_TOKENS: self._format_numbered_tokens,
+             MessageFormat.PROMPT_ONLY: lambda *args, **kw: prompt,
+             MessageFormat.PROMPT_WITH_IMAGE_TOKEN: lambda *args, **kw: "<image>"
+             * num_images
+             + prompt,
+             MessageFormat.PROMPT_WITH_START_IMAGE_TOKEN: lambda *args, **kw: prompt
+             + "<start_of_image>" * num_images,
+             MessageFormat.VIDEO_WITH_TEXT: self._format_video_message,
+         }
+
+         formatter = formatter_map.get(self.format_type)
+         return formatter(
+             prompt,
+             role,
+             skip_image_token,
+             skip_audio_token,
+             num_images,
+             num_audios,
+             **kwargs,
+         )
+
+     def _format_list_with_image(
+         self,
+         prompt: str,
+         role: str,
+         skip_image_token: bool,
+         skip_audio_token: bool,
+         num_images: int,
+         num_audios: int,
+         image_first: bool = False,
+         use_image_url: bool = False,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format as a list with image tokens."""
+         content = [MessageBuilder.text_message(prompt)]
+
+         if role == "user" and not skip_image_token and num_images > 0:
+             image_builder = (
+                 MessageBuilder.image_url_message
+                 if use_image_url
+                 else MessageBuilder.image_message
+             )
+             image_tokens = [image_builder()] * num_images
+             content = image_tokens + content if image_first else content + image_tokens
+
+         return {"role": role, "content": content}
+
+     def _format_list_with_image_type(
+         self,
+         prompt: str,
+         role: str,
+         skip_image_token: bool,
+         skip_audio_token: bool,
+         num_images: int,
+         num_audios: int,
+         message_type: str = "content",
+         image_first: bool = True,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format as a list with typed messages."""
+         msg_func = (
+             MessageBuilder.content_message
+             if message_type == "content"
+             else MessageBuilder.text_message
+         )
+         message = {"role": role, "content": [msg_func(prompt)]}
+
+         if role == "user":
+             if not skip_image_token and num_images > 0:
+                 message["content"] = (
+                     [MessageBuilder.image_message()] * num_images + message["content"]
+                     if image_first
+                     else message["content"]
+                     + [MessageBuilder.image_message()] * num_images
+                 )
+             if not skip_audio_token and num_audios > 0:
+                 message["content"] = (
+                     message["content"] + [MessageBuilder.audio_message()] * num_audios
+                 )
+
+         if role == "assistant":
+             message["content"] = message["content"][0].get(
+                 "content", message["content"][0].get("text")
+             )
+
+         return message
+
+     def _format_with_token(
+         self,
+         prompt: str,
+         role: str,
+         skip_image_token: bool,
+         skip_audio_token: bool,
+         num_images: int,
+         num_audios: int,
+         token: str,
+         image_first: bool = True,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format with image tokens in the text."""
+         content = prompt
+
+         if role == "user" and not skip_image_token and num_images > 0:
+             prefix = token * num_images
+             content = f"{prefix}{content}" if image_first else f"{content}{prefix}"
+
+         return {"role": role, "content": content}
+
+     def _format_numbered_tokens(
+         self,
+         prompt: str,
+         role: str,
+         skip_image_token: bool,
+         skip_audio_token: bool,
+         num_images: int,
+         num_audios: int,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format with numbered image tokens."""
+         content = prompt
+
+         if role == "user" and not skip_image_token and num_images > 0:
+             # phi3_v uses single token regardless of num_images
+             prefix = (
+                 "<|image_1|>"
+                 if self.model_name == "phi3_v"
+                 else " ".join([f"<|image_{i+1}|>" for i in range(num_images)])
+             )
+             content = f"{prefix}{content}"
+
+         return {"role": role, "content": content}
+
+     def _format_video_message(
+         self,
+         prompt: str,
+         role: str = "user",
+         skip_image_token: bool = False,
+         skip_audio_token: bool = False,
+         num_images: int = 0,
+         num_audios: int = 0,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         """Format a video message with text."""
+         return {
+             "role": role,
+             "content": [
+                 MessageBuilder.video_message(
+                     kwargs["video"],
+                     kwargs.get("max_pixels", 224 * 224),
+                     kwargs.get("fps", 1),
+                 ),
+                 MessageBuilder.text_message(prompt),
+             ],
+         }
+
+
+ def get_message_json(
+     model_name: str,
+     prompt: str,
+     role: str = "user",
+     skip_image_token: bool = False,
+     skip_audio_token: bool = False,
+     num_images: int = 0,
+     num_audios: int = 0,
+     **kwargs,
+ ) -> Union[str, Dict[str, Any]]:
+     """
+     Get the appropriate JSON message based on the specified model.
+
+     Args:
+         model_name: The model for which to generate the message
+         prompt: The text prompt to be included in the message
+         role: The role of the message (default: "user")
+         skip_image_token: Whether to skip adding image tokens
+         skip_audio_token: Whether to skip adding audio tokens
+         num_images: Number of image tokens to add
+         num_audios: Number of audio tokens to add
+         **kwargs: Additional arguments (e.g., video path, max_pixels, fps)
+
+     Returns:
+         A dictionary or string representing the message for the specified model
+     """
+     formatter = MessageFormatter(model_name)
+
+     return formatter.format_message(
+         prompt,
+         role,
+         skip_image_token,
+         skip_audio_token,
+         num_images,
+         num_audios,
+         **kwargs,
+     )
+
+
+ def get_chat_template(
+     processor,
+     messages: List[Dict[str, Any]],
+     add_generation_prompt: bool,
+     tokenize: bool = False,
+     **kwargs,
+ ) -> Any:
+     """Apply chat template using processor's tokenizer."""
+     try:
+         processor = (
+             processor
+             if processor.__dict__.get("chat_template")
+             else processor.tokenizer
+         )
+
+         return processor.apply_chat_template(
+             messages,
+             tokenize=tokenize,
+             add_generation_prompt=add_generation_prompt,
+             **kwargs,
+         )
+     except AttributeError:
+         raise ValueError(
+             "Error: processor does not have 'chat_template' or 'tokenizer' attribute."
+         )
+
+
+ def apply_chat_template(
+     processor,
+     config: Union[Dict[str, Any], Any],
+     prompt: Union[str, Dict[str, Any], List[Any]],
+     add_generation_prompt: bool = True,
+     return_messages: bool = False,
+     num_images: int = 0,
+     num_audios: int = 0,
+     **kwargs,
+ ) -> Union[List[Dict[str, Any]], str, Any]:
+     """
+     Apply chat template to prompts.
+
+     Args:
+         processor: The processor with chat template functionality
+         config: Model configuration
+         prompt: Single prompt string, dict, or list of prompts
+         add_generation_prompt: Whether to add generation prompt
+         return_messages: Whether to return messages list instead of template
+         num_images: Number of images in the input
+         num_audios: Number of audio files in the input
+         **kwargs: Additional arguments for message formatting
+
+     Returns:
+         Formatted messages or chat template
+     """
+     config = config if isinstance(config, dict) else config.__dict__
+     model_type = config["model_type"]
+
+     # Build messages from prompts
+     messages = []
+
+     if isinstance(prompt, str):
+         # Single string prompt
+         messages.append(
+             get_message_json(
+                 model_type,
+                 prompt,
+                 num_images=num_images,
+                 num_audios=num_audios,
+                 **kwargs,
+             )
+         )
+     elif isinstance(prompt, dict):
+         # Single dict prompt
+         content = extract_text_from_content(prompt["content"])
+         messages.append(
+             get_message_json(
+                 model_type,
+                 content,
+                 prompt["role"],
+                 num_images=num_images,
+                 num_audios=num_audios,
+                 **kwargs,
+             )
+         )
+     elif isinstance(prompt, list):
+         # List of prompts
+         for i, p in enumerate(prompt):
+             if isinstance(p, str):
+                 is_first = i == 0
+                 messages.append(
+                     get_message_json(
+                         model_type,
+                         p,
+                         skip_image_token=not is_first,
+                         skip_audio_token=not is_first,
+                         num_images=num_images,
+                         num_audios=num_audios,
+                         **kwargs,
+                     )
+                 )
+             elif isinstance(p, dict) or isinstance(p, BaseModel):
+                 role = "user"
+                 content = ""
+                 if isinstance(p, dict):
+                     role = p.get("role", "user")
+                     content = p.get("content")
+                 else:
+                     role = p.role
+                     content = p.content
+                 # Handle multimodal content: extract only text, skip image/audio URLs
+                 # This prevents base64 image data from being tokenized as text
+                 content = extract_text_from_content(content)
+                 is_first = i == 0 or (i == 1 and role not in ["system", "assistant"])
+                 messages.append(
+                     get_message_json(
+                         model_type,
+                         content,
+                         role,
+                         skip_image_token=not is_first
+                         or role in ["system", "assistant"],
+                         skip_audio_token=not is_first
+                         or role in ["system", "assistant"],
+                         num_images=num_images,
+                         num_audios=num_audios,
+                         **kwargs,
+                     )
+                 )
+
+     if return_messages:
+         return messages
+
+     # Some models only need the last message
+     if model_type in ["paligemma", "molmo", "florence2", "moondream1"]:
+         return messages[-1]
+
+     return get_chat_template(processor, messages, add_generation_prompt)
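
As a usage sketch (illustrative only, not part of the wheel; the outputs shown are what the code above produces for these inputs), get_message_json routes a prompt through the formatter table and returns either a structured message or an inline-token string depending on the model type.

from mlx_vlm.prompt_utils import get_message_json

# LIST_WITH_IMAGE_FIRST models (e.g. qwen2_5_vl): image entries precede the text entry.
get_message_json("qwen2_5_vl", "Describe this image.", num_images=1)
# -> {"role": "user", "content": [{"type": "image"},
#                                 {"type": "text", "text": "Describe this image.",
#                                  "content": "Describe this image."}]}

# IMAGE_TOKEN_NEWLINE models (e.g. deepseekocr): the placeholder is spliced into the text.
get_message_json("deepseekocr", "Describe this image.", num_images=1)
# -> {"role": "user", "content": "<image>\nDescribe this image."}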
mlx_vlm/sample_utils.py
@@ -0,0 +1,39 @@
+ import mlx.core as mx
+
+
+ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
+     """
+     Apply top-p (nucleus) sampling to logits.
+
+     Args:
+         logits: The logits from the model's output.
+         top_p: The cumulative probability threshold for top-p filtering.
+         temperature: Temperature parameter for softmax distribution reshaping.
+     Returns:
+         token selected based on the top-p criterion.
+     """
+     if (
+         logits.dtype == mx.bfloat16
+     ):  # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
+         logits = logits.astype(mx.float32)
+
+     # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
+     probs = mx.softmax(logits / temperature, axis=-1)
+
+     # sort probs in ascending order
+     sorted_indices = mx.argsort(probs, axis=-1)
+     sorted_probs = probs[..., sorted_indices.squeeze(0)]
+
+     cumulative_probs = mx.cumsum(sorted_probs, axis=-1)
+
+     # select tokens with cumulative probs below threshold
+     top_probs = mx.where(
+         cumulative_probs > 1 - top_p,
+         sorted_probs,
+         mx.zeros_like(sorted_probs),
+     )
+
+     sorted_token = mx.random.categorical(mx.log(top_probs))
+     token = sorted_indices.squeeze(0)[sorted_token]
+
+     return token
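
A minimal usage sketch for the sampler above (illustrative only; the toy logits and hyperparameters are assumptions): tokens outside the top-p nucleus have their probability zeroed before one index is drawn.

import mlx.core as mx
from mlx_vlm.sample_utils import top_p_sampling

logits = mx.array([[2.0, 1.0, 0.5, -1.0]])  # shape (1, vocab_size)
token = top_p_sampling(logits, top_p=0.9, temperature=0.7)
print(token)  # index of the sampled token, drawn only from the top-p nucleus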