fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (258)
  1. fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
  2. fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
  3. fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
  4. fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
  5. fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
  6. fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
  7. mlx_vlm/__init__.py +16 -0
  8. mlx_vlm/__main__.py +24 -0
  9. mlx_vlm/chat.py +234 -0
  10. mlx_vlm/chat_ui.py +508 -0
  11. mlx_vlm/convert.py +284 -0
  12. mlx_vlm/deprecation.py +52 -0
  13. mlx_vlm/evals/__init__.py +0 -0
  14. mlx_vlm/evals/math_vista.py +565 -0
  15. mlx_vlm/evals/mmmu.py +528 -0
  16. mlx_vlm/evals/mmstar.py +343 -0
  17. mlx_vlm/evals/ocrbench.py +453 -0
  18. mlx_vlm/evals/utils.py +37 -0
  19. mlx_vlm/generate.py +1457 -0
  20. mlx_vlm/lora.py +207 -0
  21. mlx_vlm/models/__init__.py +0 -0
  22. mlx_vlm/models/aya_vision/__init__.py +2 -0
  23. mlx_vlm/models/aya_vision/aya_vision.py +188 -0
  24. mlx_vlm/models/aya_vision/config.py +52 -0
  25. mlx_vlm/models/aya_vision/language.py +202 -0
  26. mlx_vlm/models/aya_vision/vision.py +340 -0
  27. mlx_vlm/models/base.py +356 -0
  28. mlx_vlm/models/cache.py +238 -0
  29. mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
  30. mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
  31. mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
  32. mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
  33. mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
  34. mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
  35. mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
  36. mlx_vlm/models/deepseekocr/__init__.py +2 -0
  37. mlx_vlm/models/deepseekocr/config.py +173 -0
  38. mlx_vlm/models/deepseekocr/conversation.py +264 -0
  39. mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
  40. mlx_vlm/models/deepseekocr/language.py +547 -0
  41. mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
  42. mlx_vlm/models/deepseekocr/sam.py +489 -0
  43. mlx_vlm/models/deepseekocr/vision.py +263 -0
  44. mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
  45. mlx_vlm/models/deepseekocr_2/config.py +216 -0
  46. mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
  47. mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
  48. mlx_vlm/models/deepseekocr_2/vision.py +439 -0
  49. mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
  50. mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
  51. mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
  52. mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
  53. mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
  54. mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
  55. mlx_vlm/models/fastvlm/__init__.py +2 -0
  56. mlx_vlm/models/fastvlm/config.py +79 -0
  57. mlx_vlm/models/fastvlm/fastvlm.py +198 -0
  58. mlx_vlm/models/fastvlm/language.py +49 -0
  59. mlx_vlm/models/fastvlm/vision.py +692 -0
  60. mlx_vlm/models/florence2/__init__.py +2 -0
  61. mlx_vlm/models/florence2/config.py +84 -0
  62. mlx_vlm/models/florence2/florence2.py +383 -0
  63. mlx_vlm/models/florence2/language.py +452 -0
  64. mlx_vlm/models/florence2/processing_florence2.py +30 -0
  65. mlx_vlm/models/florence2/vision.py +552 -0
  66. mlx_vlm/models/gemma3/__init__.py +2 -0
  67. mlx_vlm/models/gemma3/config.py +52 -0
  68. mlx_vlm/models/gemma3/gemma3.py +194 -0
  69. mlx_vlm/models/gemma3/language.py +293 -0
  70. mlx_vlm/models/gemma3/vision.py +215 -0
  71. mlx_vlm/models/gemma3n/__init__.py +2 -0
  72. mlx_vlm/models/gemma3n/audio.py +1038 -0
  73. mlx_vlm/models/gemma3n/config.py +130 -0
  74. mlx_vlm/models/gemma3n/gemma3n.py +322 -0
  75. mlx_vlm/models/gemma3n/language.py +631 -0
  76. mlx_vlm/models/gemma3n/vision.py +994 -0
  77. mlx_vlm/models/glm4v/__init__.py +3 -0
  78. mlx_vlm/models/glm4v/config.py +79 -0
  79. mlx_vlm/models/glm4v/glm4v.py +188 -0
  80. mlx_vlm/models/glm4v/language.py +574 -0
  81. mlx_vlm/models/glm4v/processing.py +220 -0
  82. mlx_vlm/models/glm4v/vision.py +406 -0
  83. mlx_vlm/models/glm4v_moe/__init__.py +3 -0
  84. mlx_vlm/models/glm4v_moe/config.py +81 -0
  85. mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
  86. mlx_vlm/models/glm4v_moe/language.py +674 -0
  87. mlx_vlm/models/glm4v_moe/processing.py +229 -0
  88. mlx_vlm/models/glm4v_moe/vision.py +405 -0
  89. mlx_vlm/models/glm_ocr/__init__.py +3 -0
  90. mlx_vlm/models/glm_ocr/config.py +93 -0
  91. mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
  92. mlx_vlm/models/glm_ocr/language.py +585 -0
  93. mlx_vlm/models/glm_ocr/processing.py +208 -0
  94. mlx_vlm/models/glm_ocr/vision.py +342 -0
  95. mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
  96. mlx_vlm/models/hunyuan_vl/config.py +136 -0
  97. mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
  98. mlx_vlm/models/hunyuan_vl/language.py +509 -0
  99. mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
  100. mlx_vlm/models/hunyuan_vl/vision.py +322 -0
  101. mlx_vlm/models/idefics2/__init__.py +2 -0
  102. mlx_vlm/models/idefics2/config.py +65 -0
  103. mlx_vlm/models/idefics2/idefics2.py +321 -0
  104. mlx_vlm/models/idefics2/language.py +161 -0
  105. mlx_vlm/models/idefics2/vision.py +244 -0
  106. mlx_vlm/models/idefics3/__init__.py +4 -0
  107. mlx_vlm/models/idefics3/config.py +54 -0
  108. mlx_vlm/models/idefics3/idefics3.py +221 -0
  109. mlx_vlm/models/idefics3/language.py +157 -0
  110. mlx_vlm/models/idefics3/vision.py +265 -0
  111. mlx_vlm/models/internvl_chat/__init__.py +3 -0
  112. mlx_vlm/models/internvl_chat/config.py +89 -0
  113. mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
  114. mlx_vlm/models/internvl_chat/language.py +187 -0
  115. mlx_vlm/models/internvl_chat/processor.py +395 -0
  116. mlx_vlm/models/internvl_chat/vision.py +265 -0
  117. mlx_vlm/models/interpolate.py +183 -0
  118. mlx_vlm/models/jina_vlm/__init__.py +3 -0
  119. mlx_vlm/models/jina_vlm/config.py +142 -0
  120. mlx_vlm/models/jina_vlm/image_processor.py +430 -0
  121. mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
  122. mlx_vlm/models/jina_vlm/language.py +272 -0
  123. mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
  124. mlx_vlm/models/jina_vlm/vision.py +202 -0
  125. mlx_vlm/models/kernels.py +447 -0
  126. mlx_vlm/models/kimi_vl/__init__.py +4 -0
  127. mlx_vlm/models/kimi_vl/config.py +84 -0
  128. mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
  129. mlx_vlm/models/kimi_vl/language.py +460 -0
  130. mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
  131. mlx_vlm/models/kimi_vl/vision.py +485 -0
  132. mlx_vlm/models/lfm2_vl/__init__.py +2 -0
  133. mlx_vlm/models/lfm2_vl/config.py +94 -0
  134. mlx_vlm/models/lfm2_vl/language.py +49 -0
  135. mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
  136. mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
  137. mlx_vlm/models/lfm2_vl/vision.py +223 -0
  138. mlx_vlm/models/llama4/__init__.py +2 -0
  139. mlx_vlm/models/llama4/config.py +83 -0
  140. mlx_vlm/models/llama4/language.py +334 -0
  141. mlx_vlm/models/llama4/llama4.py +146 -0
  142. mlx_vlm/models/llama4/vision.py +526 -0
  143. mlx_vlm/models/llava/__init__.py +2 -0
  144. mlx_vlm/models/llava/config.py +61 -0
  145. mlx_vlm/models/llava/language.py +200 -0
  146. mlx_vlm/models/llava/llava.py +132 -0
  147. mlx_vlm/models/llava/vision.py +233 -0
  148. mlx_vlm/models/llava_bunny/__init__.py +2 -0
  149. mlx_vlm/models/llava_bunny/config.py +85 -0
  150. mlx_vlm/models/llava_bunny/language.py +194 -0
  151. mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
  152. mlx_vlm/models/llava_bunny/vision.py +278 -0
  153. mlx_vlm/models/llava_next/__init__.py +2 -0
  154. mlx_vlm/models/llava_next/config.py +60 -0
  155. mlx_vlm/models/llava_next/language.py +192 -0
  156. mlx_vlm/models/llava_next/llava_next.py +138 -0
  157. mlx_vlm/models/llava_next/vision.py +217 -0
  158. mlx_vlm/models/mistral3/__init__.py +2 -0
  159. mlx_vlm/models/mistral3/config.py +59 -0
  160. mlx_vlm/models/mistral3/language.py +269 -0
  161. mlx_vlm/models/mistral3/mistral3.py +383 -0
  162. mlx_vlm/models/mllama/__init__.py +4 -0
  163. mlx_vlm/models/mllama/config.py +74 -0
  164. mlx_vlm/models/mllama/language.py +377 -0
  165. mlx_vlm/models/mllama/mllama.py +210 -0
  166. mlx_vlm/models/mllama/vision.py +458 -0
  167. mlx_vlm/models/molmo/__init__.py +5 -0
  168. mlx_vlm/models/molmo/config.py +93 -0
  169. mlx_vlm/models/molmo/language.py +208 -0
  170. mlx_vlm/models/molmo/molmo.py +108 -0
  171. mlx_vlm/models/molmo/processing_molmo.py +763 -0
  172. mlx_vlm/models/molmo/vision.py +408 -0
  173. mlx_vlm/models/molmo2/__init__.py +6 -0
  174. mlx_vlm/models/molmo2/config.py +137 -0
  175. mlx_vlm/models/molmo2/language.py +206 -0
  176. mlx_vlm/models/molmo2/molmo2.py +330 -0
  177. mlx_vlm/models/molmo2/processing.py +773 -0
  178. mlx_vlm/models/molmo2/vision.py +286 -0
  179. mlx_vlm/models/moondream2/__init__.py +11 -0
  180. mlx_vlm/models/moondream2/config.py +92 -0
  181. mlx_vlm/models/moondream2/image_crops.py +269 -0
  182. mlx_vlm/models/moondream2/language.py +267 -0
  183. mlx_vlm/models/moondream2/moondream2.py +522 -0
  184. mlx_vlm/models/moondream2/processing_moondream.py +144 -0
  185. mlx_vlm/models/moondream2/vision.py +200 -0
  186. mlx_vlm/models/multi_modality/__init__.py +4 -0
  187. mlx_vlm/models/multi_modality/config.py +108 -0
  188. mlx_vlm/models/multi_modality/language.py +191 -0
  189. mlx_vlm/models/multi_modality/multi_modality.py +338 -0
  190. mlx_vlm/models/multi_modality/sam.py +543 -0
  191. mlx_vlm/models/multi_modality/vision.py +450 -0
  192. mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
  193. mlx_vlm/models/paddleocr_vl/config.py +93 -0
  194. mlx_vlm/models/paddleocr_vl/language.py +522 -0
  195. mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
  196. mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
  197. mlx_vlm/models/paddleocr_vl/vision.py +358 -0
  198. mlx_vlm/models/paligemma/__init__.py +4 -0
  199. mlx_vlm/models/paligemma/config.py +50 -0
  200. mlx_vlm/models/paligemma/language.py +253 -0
  201. mlx_vlm/models/paligemma/paligemma.py +140 -0
  202. mlx_vlm/models/paligemma/vision.py +218 -0
  203. mlx_vlm/models/phi3_v/__init__.py +5 -0
  204. mlx_vlm/models/phi3_v/config.py +55 -0
  205. mlx_vlm/models/phi3_v/language.py +2 -0
  206. mlx_vlm/models/phi3_v/phi3_v.py +239 -0
  207. mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
  208. mlx_vlm/models/phi3_v/vision.py +294 -0
  209. mlx_vlm/models/pixtral/__init__.py +4 -0
  210. mlx_vlm/models/pixtral/config.py +69 -0
  211. mlx_vlm/models/pixtral/language.py +195 -0
  212. mlx_vlm/models/pixtral/pixtral.py +208 -0
  213. mlx_vlm/models/pixtral/vision.py +293 -0
  214. mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
  215. mlx_vlm/models/qwen2_5_vl/config.py +90 -0
  216. mlx_vlm/models/qwen2_5_vl/language.py +541 -0
  217. mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
  218. mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
  219. mlx_vlm/models/qwen2_vl/__init__.py +2 -0
  220. mlx_vlm/models/qwen2_vl/config.py +86 -0
  221. mlx_vlm/models/qwen2_vl/language.py +539 -0
  222. mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
  223. mlx_vlm/models/qwen2_vl/vision.py +308 -0
  224. mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
  225. mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
  226. mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
  227. mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
  228. mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
  229. mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
  230. mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
  231. mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
  232. mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
  233. mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
  234. mlx_vlm/models/qwen3_vl/__init__.py +2 -0
  235. mlx_vlm/models/qwen3_vl/config.py +103 -0
  236. mlx_vlm/models/qwen3_vl/language.py +596 -0
  237. mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
  238. mlx_vlm/models/qwen3_vl/vision.py +441 -0
  239. mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
  240. mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
  241. mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
  242. mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
  243. mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
  244. mlx_vlm/models/smolvlm/__init__.py +4 -0
  245. mlx_vlm/models/smolvlm/config.py +59 -0
  246. mlx_vlm/models/smolvlm/smolvlm.py +60 -0
  247. mlx_vlm/prompt_utils.py +565 -0
  248. mlx_vlm/sample_utils.py +39 -0
  249. mlx_vlm/server.py +1107 -0
  250. mlx_vlm/smolvlm_video_generate.py +109 -0
  251. mlx_vlm/tokenizer_utils.py +371 -0
  252. mlx_vlm/trainer/__init__.py +9 -0
  253. mlx_vlm/trainer/lora.py +70 -0
  254. mlx_vlm/trainer/trainer.py +299 -0
  255. mlx_vlm/trainer/utils.py +160 -0
  256. mlx_vlm/utils.py +1339 -0
  257. mlx_vlm/version.py +1 -0
  258. mlx_vlm/video_generate.py +611 -0
mlx_vlm/version.py ADDED
@@ -0,0 +1 @@
+ __version__ = "0.3.11"
mlx_vlm/video_generate.py ADDED
@@ -0,0 +1,611 @@
+ from __future__ import annotations
+
+ import argparse
+ import base64
+ import logging
+ import math
+ import os
+ import time
+ from io import BytesIO
+ from typing import List
+
+ import cv2
+ import mlx.core as mx
+ import numpy as np
+ import requests
+ from PIL import Image
+
+ from .generate import generate
+ from .utils import load, load_image, process_inputs_with_fallback
+
+ # This is a beta version of the video generation script.
+ # It is not fully tested and may not work as expected.
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)
+ logger.addHandler(logging.StreamHandler())
+
+ logger.info(
+     "This is a beta version of the video understanding. It may not work as expected."
+ )
+
+ IMAGE_FACTOR = 28
+ MIN_PIXELS = 4 * 28 * 28
+ MAX_PIXELS = 16384 * 28 * 28
+ MAX_RATIO = 200
+
+ VIDEO_MIN_PIXELS = 128 * 28 * 28
+ VIDEO_MAX_PIXELS = 768 * 28 * 28
+ FRAME_FACTOR = 2
+ FPS = 2.0
+ FPS_MIN_FRAMES = 4
+ FPS_MAX_FRAMES = 768
+
+ # Set the maximum number of video token inputs.
+ VIDEO_TOTAL_PIXELS = int(
+     float(os.environ.get("VIDEO_MAX_PIXELS", 128000 * 28 * 28 * 0.9))
+ )
+
+
+ def round_by_factor(number: int, factor: int) -> int:
+     """Returns the closest integer to 'number' that is divisible by 'factor'."""
+     return round(number / factor) * factor
+
+
+ def ceil_by_factor(number: int, factor: int) -> int:
+     """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+     return math.ceil(number / factor) * factor
+
+
+ def floor_by_factor(number: int, factor: int) -> int:
+     """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+     return math.floor(number / factor) * factor
+
+
+ def smart_resize(
+     height: int,
+     width: int,
+     factor: int = IMAGE_FACTOR,
+     min_pixels: int = MIN_PIXELS,
+     max_pixels: int = MAX_PIXELS,
+ ) -> tuple[int, int]:
+     """
+     Rescales the image so that the following conditions are met:
+
+     1. Both dimensions (height and width) are divisible by 'factor'.
+     2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+     3. The aspect ratio of the image is maintained as closely as possible.
+     """
+     if max(height, width) / min(height, width) > MAX_RATIO:
+         raise ValueError(
+             f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+         )
+     h_bar = max(factor, round_by_factor(height, factor))
+     w_bar = max(factor, round_by_factor(width, factor))
+     if h_bar * w_bar > max_pixels:
+         beta = math.sqrt((height * width) / max_pixels)
+         h_bar = floor_by_factor(height / beta, factor)
+         w_bar = floor_by_factor(width / beta, factor)
+     elif h_bar * w_bar < min_pixels:
+         beta = math.sqrt(min_pixels / (height * width))
+         h_bar = ceil_by_factor(height * beta, factor)
+         w_bar = ceil_by_factor(width * beta, factor)
+     return h_bar, w_bar
+
+
+ def to_rgb(pil_image: Image.Image) -> Image.Image:
+     if pil_image.mode == "RGBA":
+         white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+         white_background.paste(
+             pil_image, mask=pil_image.split()[3]
+         )  # Use alpha channel as mask
+         return white_background
+     else:
+         return pil_image.convert("RGB")
+
+
+ def fetch_image(
+     ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR
+ ) -> Image.Image:
+     if "image" in ele:
+         image = ele["image"]
+     else:
+         image = ele["image_url"]
+     image_obj = None
+     if isinstance(image, Image.Image):
+         image_obj = image
+     elif image.startswith("http://") or image.startswith("https://"):
+         response = requests.get(image, stream=True)
+         image_obj = Image.open(BytesIO(response.content))
+     elif image.startswith("file://"):
+         image_obj = Image.open(image[7:])
+     elif image.startswith("data:image"):
+         if "base64," in image:
+             _, base64_data = image.split("base64,", 1)
+             data = base64.b64decode(base64_data)
+             image_obj = Image.open(BytesIO(data))
+     else:
+         image_obj = Image.open(image)
+     if image_obj is None:
+         raise ValueError(
+             f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
+         )
+     image = to_rgb(image_obj)
+     ## resize
+     if "resized_height" in ele and "resized_width" in ele:
+         resized_height, resized_width = smart_resize(
+             ele["resized_height"],
+             ele["resized_width"],
+             factor=size_factor,
+         )
+     else:
+         width, height = image.size
+         min_pixels = ele.get("min_pixels", MIN_PIXELS)
+         max_pixels = ele.get("max_pixels", MAX_PIXELS)
+         resized_height, resized_width = smart_resize(
+             height,
+             width,
+             factor=size_factor,
+             min_pixels=min_pixels,
+             max_pixels=max_pixels,
+         )
+     image = image.resize((resized_width, resized_height))
+     return image
+
+
+ def smart_nframes(
+     ele: dict,
+     total_frames: int,
+     video_fps: int | float,
+ ) -> int:
+     """Calculate the number of frames for the video to be used as model inputs.
+
+     Either a fixed 'nframes' is provided in ele or 'fps' is used to calculate how many frames to sample.
+     """
+     assert not (
+         "fps" in ele and "nframes" in ele
+     ), "Only accept either `fps` or `nframes`"
+     if "nframes" in ele:
+         nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
+     else:
+         fps = ele.get("fps", FPS)
+         min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
+         max_frames = floor_by_factor(
+             ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
+         )
+         nframes = total_frames / video_fps * fps
+         if nframes > total_frames:
+             logger.warning(
+                 f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]"
+             )
+         nframes = min(min(max(nframes, min_frames), max_frames), total_frames)
+         nframes = floor_by_factor(nframes, FRAME_FACTOR)
+     if not (FRAME_FACTOR <= nframes and nframes <= total_frames):
+         raise ValueError(
+             f"nframes should be in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
+         )
+     return nframes
+
+
+ def load_video(
+     ele: dict,
+ ) -> (np.ndarray, float):
+     """
+     Read video using cv2.VideoCapture.
+
+     The video is read as a NumPy array with shape (T, C, H, W) where T is the number of frames,
+     C is the number of channels, and H, W are the frame dimensions.
+     """
+     video_path = ele["video"]
+     if video_path.startswith("file://"):
+         video_path = video_path[7:]
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise ValueError(f"Cannot open video: {video_path}")
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     video_fps = cap.get(cv2.CAP_PROP_FPS) or 1.0  # default to 1.0 if fps returns 0
+     st = time.time()
+     logger.info(
+         f"numpy reader: video_path={video_path}, total_frames={total_frames}, video_fps={video_fps}, time={time.time()-st:.3f}s"
+     )
+     nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
+     indices = np.linspace(0, total_frames - 1, nframes).round().astype(int)
+     frames = []
+     for idx in indices:
+         cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+         ret, frame = cap.read()
+         if not ret:
+             break
+         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         frames.append(frame)
+     cap.release()
+     if not frames:
+         raise ValueError("No frames read from the video.")
+     # Stack frames into a numpy array: (T, H, W, C)
+     video_np = np.stack(frames, axis=0)
+     # Rearrange to (T, C, H, W)
+     video_np = np.transpose(video_np, (0, 3, 1, 2))
+     sample_fps = nframes / max(total_frames, 1e-6) * video_fps
+     return video_np, sample_fps
+
+
+ def fetch_video(
+     ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False
+ ) -> np.ndarray | list[Image.Image]:
+     if isinstance(ele["video"], str):
+         video, sample_fps = load_video(ele)
+         nframes, _, height, width = video.shape
+         min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
+         total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
+         max_pixels = max(
+             min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
+             int(min_pixels * 1.05),
+         )
+         max_pixels_supposed = ele.get("max_pixels", max_pixels)
+         if max_pixels_supposed > max_pixels:
+             logger.warning(
+                 f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
+             )
+         max_pixels = min(max_pixels_supposed, max_pixels)
+         if "resized_height" in ele and "resized_width" in ele:
+             resized_height, resized_width = smart_resize(
+                 ele["resized_height"],
+                 ele["resized_width"],
+                 factor=image_factor,
+             )
+         else:
+             resized_height, resized_width = smart_resize(
+                 height,
+                 width,
+                 factor=image_factor,
+                 min_pixels=min_pixels,
+                 max_pixels=max_pixels,
+             )
+         # Resize each frame using OpenCV (similar to torchvision.transforms.functional.resize with BICUBIC)
+         resized_frames = []
+         # video is (T, C, H, W) so we need to process each frame
+         for frame in video:
+             # Rearrange from (C, H, W) to (H, W, C)
+             frame_np = np.transpose(frame, (1, 2, 0))
+             # cv2.resize expects size as (width, height)
+             resized = cv2.resize(
+                 frame_np, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC
+             )
+             # Convert back to (C, H, W)
+             resized = np.transpose(resized, (2, 0, 1))
+             resized_frames.append(resized)
+         video = np.stack(resized_frames, axis=0).astype(np.float32)
+         if return_video_sample_fps:
+             return video, sample_fps
+         return video
+     else:
+         # Assume video is provided as a list/tuple of image objects.
+         process_info = ele.copy()
+         process_info.pop("type", None)
+         process_info.pop("video", None)
+         images = [
+             fetch_image(
+                 {"image": video_element, **process_info}, size_factor=image_factor
+             )
+             for video_element in ele["video"]
+         ]
+         nframes = ceil_by_factor(len(images), FRAME_FACTOR)
+         if len(images) < nframes:
+             images.extend([images[-1]] * (nframes - len(images)))
+         if return_video_sample_fps:
+             return images, process_info.pop("fps", 2.0)
+         return images
+
+
+ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
+     vision_infos = []
+     if isinstance(conversations[0], dict):
+         conversations = [conversations]
+     for conversation in conversations:
+         for message in conversation:
+             if isinstance(message["content"], list):
+                 for ele in message["content"]:
+                     if (
+                         "image" in ele
+                         or "image_url" in ele
+                         or "video" in ele
+                         or ele["type"] in ("image", "image_url", "video")
+                     ):
+                         vision_infos.append(ele)
+     return vision_infos
+
+
+ def process_vision_info(
+     conversations: list[dict] | list[list[dict]],
+     return_video_kwargs: bool = False,
+ ) -> tuple[
+     list[Image.Image] | None, list[np.ndarray | list[Image.Image]] | None, dict | None
+ ]:
+     vision_infos = extract_vision_info(conversations)
+     ## Read images or videos
+     image_inputs = []
+     video_inputs = []
+     video_sample_fps_list = []
+     for vision_info in vision_infos:
+         if "image" in vision_info or "image_url" in vision_info:
+             image_inputs.append(fetch_image(vision_info))
+         elif "video" in vision_info:
+             video_input, video_sample_fps = fetch_video(
+                 vision_info, return_video_sample_fps=True
+             )
+             video_sample_fps_list.append(video_sample_fps)
+             video_inputs.append(video_input)
+         else:
+             raise ValueError("Content must include image, image_url, or video.")
+     if len(image_inputs) == 0:
+         image_inputs = None
+     if len(video_inputs) == 0:
+         video_inputs = None
+     if return_video_kwargs:
+         return image_inputs, video_inputs, {"fps": video_sample_fps_list}
+     return image_inputs, video_inputs
+
+
+ class VideoFrameExtractor:
+     def __init__(self, max_frames: int = 50):
+         self.max_frames = max_frames
+
+     def resize_and_center_crop(
+         self, image: Image.Image, target_size: int
+     ) -> Image.Image:
+         # Get current dimensions
+         width, height = image.size
+
+         # Calculate new dimensions keeping aspect ratio
+         if width < height:
+             new_width = target_size
+             new_height = int(height * (target_size / width))
+         else:
+             new_height = target_size
+             new_width = int(width * (target_size / height))
+
+         # Resize
+         image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+         # Center crop
+         left = (new_width - target_size) // 2
+         top = (new_height - target_size) // 2
+         right = left + target_size
+         bottom = top + target_size
+
+         return image.crop((left, top, right, bottom))
+
+     def extract_frames(self, video_path: str) -> List[Image.Image]:
+         cap = cv2.VideoCapture(video_path)
+         if not cap.isOpened():
+             raise ValueError(f"Could not open video: {video_path}")
+
+         # Get video properties
+         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+         fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+         # Calculate frame indices to extract (1fps)
+         frame_indices = list(range(0, total_frames, fps))
+
+         # If we have more frames than max_frames, sample evenly
+         if len(frame_indices) > self.max_frames:
+             indices = np.linspace(0, len(frame_indices) - 1, self.max_frames, dtype=int)
+             frame_indices = [frame_indices[i] for i in indices]
+
+         frames = []
+         for frame_idx in frame_indices:
+             cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
+             ret, frame = cap.read()
+             if ret:
+                 frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                 pil_image = Image.fromarray(frame)
+                 pil_image = self.resize_and_center_crop(pil_image, 384)
+                 frames.append(pil_image)
+
+         cap.release()
+         return frames
+
+
+ def is_video_model(model):
+     return hasattr(model.config, "video_token_id") or hasattr(
+         model.config, "video_token_index"
+     )
+
+
+ def is_video_file(video_path: List[str]) -> bool:
+     video_extensions = [".mp4", ".avi", ".mov"]
+     for path in video_path:
+         if not any(path.endswith(ext) for ext in video_extensions):
+             return False
+     return True
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Video Description CLI")
+     parser.add_argument(
+         "--video", type=str, nargs="+", required=True, help="Path to the video file"
+     )
+     parser.add_argument(
+         "--max-pixels",
+         type=int,
+         nargs=2,
+         default=224 * 224,
+         help="Maximum number of pixels",
+     )
+     parser.add_argument(
+         "--max-frames", type=int, default=None, help="Maximum number of frames"
+     )
+     parser.add_argument("--fps", type=float, default=1.0, help="Frames per second")
+     parser.add_argument(
+         "--prompt", default="Describe this video.", help="Text prompt for the model"
+     )
+     parser.add_argument(
+         "--temperature", type=float, default=0.7, help="Temperature for generation"
+     )
+     parser.add_argument(
+         "--max-tokens",
+         type=int,
+         default=100,
+         help="Maximum number of tokens to generate",
+     )
+     parser.add_argument(
+         "--model",
+         default="mlx-community/Qwen2.5-VL-7B-Instruct-4bit",
+         help="Select the model to use",
+     )
+     parser.add_argument("--verbose", action="store_false", help="Print verbose output")
+
+     args = parser.parse_args()
+
+     print(f"\033[32mLoading model:\033[0m {args.model}")
+     model, processor = load(args.model)
+
+     # Validate the model
+     if not is_video_model(model):
+         logger.warning(
+             "Warning: The model selected doesn't natively support video inputs. Performance may be degraded."
+         )
+
+     if isinstance(args.max_pixels, tuple) or isinstance(args.max_pixels, list):
+         max_pixels = args.max_pixels[0] * args.max_pixels[1]
+     else:
+         max_pixels = args.max_pixels
+
+     kwargs = {}
+     if is_video_model(model):
+
+         # Check if video is image or video
+         if is_video_file(args.video):
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [
+                         {
+                             "type": "video",
+                             "video": args.video[0],
+                             "max_pixels": max_pixels,
+                             "fps": args.fps,
+                         },
+                         {"type": "text", "text": args.prompt},
+                     ],
+                 }
+             ]
+         else:
+             messages = [
+                 {
+                     "role": "user",
+                     "content": [
+                         *[{"type": "image", "image": image} for image in args.video],
+                         {"type": "text", "text": args.prompt},
+                     ],
+                 }
+             ]
+
+         text = processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         image_inputs, video_inputs, fps = process_vision_info(messages, True)
+
+         if args.max_frames is not None:
+             video_inputs = video_inputs[: args.max_frames]
+         inputs = processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+
+         input_ids = mx.array(inputs["input_ids"])
+         pixel_values = inputs.get(
+             "pixel_values_videos", inputs.get("pixel_values", None)
+         )
+         if pixel_values is None:
+             raise ValueError("Please provide a valid video or image input.")
+         pixel_values = mx.array(pixel_values)
+
+         mask = mx.array(inputs["attention_mask"])
+         if inputs.get("video_grid_thw", None) is not None:
+             kwargs["video_grid_thw"] = mx.array(inputs["video_grid_thw"])
+         if inputs.get("image_grid_thw", None) is not None:
+             kwargs["image_grid_thw"] = mx.array(inputs["image_grid_thw"])
+
+     else:
+         if is_video_file(args.video):
+             if len(args.video) > 1:
+                 raise ValueError("Only one video is supported for video models.")
+             else:
+                 frame_extractor = VideoFrameExtractor(args.max_frames)
+                 frames = frame_extractor.extract_frames(args.video[0])
+         else:
+             frames = [load_image(image) for image in args.video]
+
+         # Create prompt with frames
+         image_tokens = [{"type": "image"} for _ in range(len(frames))]
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "text", "text": "Answer briefly."},
+                     *image_tokens,
+                     {"type": "text", "text": args.prompt},
+                 ],
+             }
+         ]
+
+         text = processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+
+         # Configure processor for video frames
+         processor.image_processor.size = (
+             args.max_pixels
+             if isinstance(args.max_pixels, tuple)
+             else (args.max_pixels, args.max_pixels)
+         )
+         if hasattr(processor.image_processor, "do_resize"):
+             processor.image_processor.do_resize = False
+         if hasattr(processor.image_processor, "do_image_splitting"):
+             processor.image_processor.do_image_splitting = False
+
+         # Process inputs
+         inputs = process_inputs_with_fallback(
+             processor,
+             images=[img for img in frames],
+             prompts=text,
+         )
+
+         input_ids = mx.array(inputs["input_ids"])
+         pixel_values = mx.array(inputs["pixel_values"])
+         mask = mx.array(inputs["attention_mask"])
+         for key, value in inputs.items():
+             if key not in [
+                 "input_ids",
+                 "pixel_values",
+                 "attention_mask",
+             ] and not isinstance(value, (str, list)):
+                 kwargs[key] = mx.array(value)
+
+     logger.info("\033[32mGenerating response...\033[0m")
+
+     kwargs["video"] = args.video
+     kwargs["input_ids"] = input_ids
+     kwargs["pixel_values"] = pixel_values
+     kwargs["mask"] = mask
+     kwargs["temperature"] = args.temperature
+     kwargs["max_tokens"] = args.max_tokens
+
+     response = generate(
+         model,
+         processor,
+         prompt=text,
+         verbose=args.verbose,
+         **kwargs,
+     )
+
+     if not args.verbose:
+         print(response)
+
+
+ if __name__ == "__main__":
+     main()
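
Usage note (illustration only, not part of the packaged diff): once the wheel is installed as mlx_vlm, the helpers in video_generate.py can be exercised directly. The sketch below is a minimal example under stated assumptions: sample.mp4 is a hypothetical local video file, and the message layout mirrors what main() builds.

# Minimal usage sketch, assuming the wheel is installed as mlx_vlm and that
# sample.mp4 exists locally; not shipped in the package itself.
from mlx_vlm.video_generate import process_vision_info, smart_resize

# smart_resize snaps dimensions to multiples of IMAGE_FACTOR (28) while keeping
# the pixel count between MIN_PIXELS and MAX_PIXELS.
print(smart_resize(480, 640))  # -> (476, 644), both divisible by 28

messages = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": "sample.mp4", "fps": 1.0},
            {"type": "text", "text": "Describe this video."},
        ],
    }
]

# With return_video_kwargs=True this returns (image_inputs, video_inputs, {"fps": [...]}).
image_inputs, video_inputs, video_kwargs = process_vision_info(
    messages, return_video_kwargs=True
)
print(video_inputs[0].shape, video_kwargs)  # (T, C, H, W) float32 frames and sampled fps

The same message structure, with the video entry swapped for image entries, is what main() feeds to processor.apply_chat_template before calling generate.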