fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Release version string of the mlx_vlm package (matches the 0.3.11 wheel).
__version__ = "0.3.11"
|
|
@@ -0,0 +1,611 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import base64
|
|
5
|
+
import logging
|
|
6
|
+
import math
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
from io import BytesIO
|
|
10
|
+
from typing import List
|
|
11
|
+
|
|
12
|
+
import cv2
|
|
13
|
+
import mlx.core as mx
|
|
14
|
+
import numpy as np
|
|
15
|
+
import requests
|
|
16
|
+
from PIL import Image
|
|
17
|
+
|
|
18
|
+
from .generate import generate
|
|
19
|
+
from .utils import load, load_image, process_inputs_with_fallback
|
|
20
|
+
|
|
21
|
+
# This is a beta version of the video generation script.
# It is not fully tested and may not work as expected.

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Attach a console handler only if this logger has none yet. An unconditional
# ``addHandler`` at import time stacks one more handler every time the module
# body is re-executed (e.g. ``importlib.reload``), and each extra handler
# prints every message again.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

logger.info(
    "This is a beta version of the video understanding. It may not work as expected."
)
|
|
31
|
+
|
|
32
|
+
# Resizing granularity in pixels: resized image/frame sides are multiples of this.
IMAGE_FACTOR = 28
# Pixel-count bounds for a single image, in IMAGE_FACTOR x IMAGE_FACTOR units.
MIN_PIXELS = 4 * IMAGE_FACTOR * IMAGE_FACTOR
MAX_PIXELS = 16384 * IMAGE_FACTOR * IMAGE_FACTOR
# Largest long-side / short-side ratio accepted by smart_resize.
MAX_RATIO = 200

# Per-frame pixel-count bounds used when resizing video frames.
VIDEO_MIN_PIXELS = 128 * IMAGE_FACTOR * IMAGE_FACTOR
VIDEO_MAX_PIXELS = 768 * IMAGE_FACTOR * IMAGE_FACTOR
# Sampled video frame counts are kept multiples of this.
FRAME_FACTOR = 2
# Default sampling rate (frames per second) and frame-count bounds for videos.
FPS = 2.0
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768

# Pixel budget shared across all sampled frames of one video (fetch_video divides
# it by the frame count). Presumably this bounds the number of video tokens fed
# to the model — confirm against the model's processor.
# NOTE: the *environment variable* named VIDEO_MAX_PIXELS overrides this total
# budget; it does NOT override the per-frame VIDEO_MAX_PIXELS constant above.
_DEFAULT_VIDEO_TOTAL_PIXELS = 128000 * IMAGE_FACTOR * IMAGE_FACTOR * 0.9
VIDEO_TOTAL_PIXELS = int(
    float(os.environ.get("VIDEO_MAX_PIXELS", _DEFAULT_VIDEO_TOTAL_PIXELS))
)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def round_by_factor(number: int, factor: int) -> int:
    """Snap ``number`` to the nearest multiple of ``factor``.

    Ties follow Python's built-in ``round`` on the quotient (round half to
    even), so ``round_by_factor(14, 28) == 0`` while
    ``round_by_factor(42, 28) == 56``.
    """
    quotient = number / factor
    return factor * round(quotient)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def ceil_by_factor(number: int, factor: int) -> int:
    """Round ``number`` up to a multiple of ``factor``.

    Values that are already multiples are returned unchanged.
    """
    multiples = math.ceil(number / factor)
    return factor * multiples
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def floor_by_factor(number: int, factor: int) -> int:
    """Round ``number`` down to a multiple of ``factor``.

    Values that are already multiples are returned unchanged; anything below
    ``factor`` (and non-negative) collapses to 0.
    """
    multiples = math.floor(number / factor)
    return factor * multiples
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """Compute a target ``(height, width)`` for resizing an image or frame.

    The result aims to satisfy:

    1. Both dimensions are multiples of ``factor`` and at least ``factor``.
    2. The pixel count lies within ``[min_pixels, max_pixels]``. For extremely
       elongated inputs the short side is clamped up to ``factor``, so the
       pixel count may slightly exceed ``max_pixels`` in that case.
    3. The aspect ratio of the input is preserved as closely as possible.

    Args:
        height: Source height in pixels.
        width: Source width in pixels.
        factor: Granularity each output side is snapped to.
        min_pixels: Lower bound on the output pixel count.
        max_pixels: Upper bound on the output pixel count.

    Returns:
        ``(resized_height, resized_width)``.

    Raises:
        ValueError: If a dimension is not positive, or the aspect ratio exceeds
            ``MAX_RATIO``.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got {height}x{width}")
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Shrink both sides by the same factor beta so the area fits max_pixels.
        beta = math.sqrt((height * width) / max_pixels)
        # Clamp to ``factor``: for thin inputs the short side can scale below one
        # factor and floor to 0, which would request a zero-sized image.
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    elif h_bar * w_bar < min_pixels:
        # Enlarge both sides by the same factor beta so the area reaches min_pixels.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def to_rgb(pil_image: Image.Image) -> Image.Image:
    """Return an RGB version of ``pil_image``, compositing transparency onto white.

    Images with an alpha channel ("RGBA", "LA") or palette images ("P") that
    carry a ``"transparency"`` entry are normalized to RGBA and pasted onto a
    white background using the alpha channel as the mask. Transparent regions
    therefore become white instead of whatever colour is stored underneath.
    All other modes are converted directly with ``convert("RGB")``.
    """
    has_alpha = pil_image.mode in ("RGBA", "LA") or (
        pil_image.mode == "P" and "transparency" in pil_image.info
    )
    if not has_alpha:
        return pil_image.convert("RGB")
    rgba = pil_image if pil_image.mode == "RGBA" else pil_image.convert("RGBA")
    white_background = Image.new("RGB", rgba.size, (255, 255, 255))
    # Band 3 of an RGBA image is the alpha channel; use it as the paste mask.
    white_background.paste(rgba, mask=rgba.split()[3])
    return white_background
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
# Timeout (seconds) passed to ``requests.get`` when downloading images, so an
# unresponsive server cannot block the caller indefinitely.
_IMAGE_DOWNLOAD_TIMEOUT = 30


def _open_image_source(image) -> Image.Image | None:
    """Open an image reference and return it as a PIL image.

    Accepts a PIL image (returned as-is), an ``http(s)://`` URL, a ``file://``
    URI, a ``data:image...base64,`` URI, or a plain local path. Returns ``None``
    for inputs it cannot interpret (non-string values, or a ``data:image`` URI
    without a base64 payload).
    """
    if isinstance(image, Image.Image):
        return image
    if not isinstance(image, str):
        return None
    if image.startswith(("http://", "https://")):
        response = requests.get(image, timeout=_IMAGE_DOWNLOAD_TIMEOUT)
        # Surface HTTP errors directly instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    if image.startswith("file://"):
        return Image.open(image[7:])
    if image.startswith("data:image"):
        if "base64," not in image:
            return None
        _, base64_data = image.split("base64,", 1)
        return Image.open(BytesIO(base64.b64decode(base64_data)))
    return Image.open(image)


def fetch_image(
    ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR
) -> Image.Image:
    """Load the image referenced by a content element and resize it.

    The reference is read from ``ele["image"]`` or, if absent,
    ``ele["image_url"]``. A dict of the form ``{"url": ...}`` (e.g. OpenAI-style
    ``image_url`` content parts) is unwrapped to its URL. The image is converted
    to RGB (see ``to_rgb``). It is then resized with ``smart_resize``, using
    either ``ele["resized_height"]``/``ele["resized_width"]`` or the image's own
    size bounded by ``ele["min_pixels"]``/``ele["max_pixels"]`` (defaults
    ``MIN_PIXELS``/``MAX_PIXELS``).

    Raises:
        ValueError: If the reference is not a supported kind of input.
        requests.HTTPError: If an HTTP(S) download returns an error status.
    """
    image = ele["image"] if "image" in ele else ele["image_url"]
    if isinstance(image, dict) and "url" in image:
        image = image["url"]
    image_obj = _open_image_source(image)
    if image_obj is None:
        raise ValueError(
            f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
        )
    image = to_rgb(image_obj)
    ## resize
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels", MIN_PIXELS)
        max_pixels = ele.get("max_pixels", MAX_PIXELS)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    image = image.resize((resized_width, resized_height))
    return image
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Decide how many frames to sample from a video for the model input.

    ``ele`` may give either a fixed ``"nframes"`` (rounded to the nearest
    multiple of ``FRAME_FACTOR``) or a sampling rate ``"fps"`` (default
    ``FPS``). In the fps case the count is ``total_frames / video_fps * fps``.
    That count is clamped to ``[min_frames, max_frames]`` and to
    ``total_frames``, then floored to a multiple of ``FRAME_FACTOR``.
    ``ele["min_frames"]`` and ``ele["max_frames"]`` override the defaults
    ``FPS_MIN_FRAMES`` and ``min(FPS_MAX_FRAMES, total_frames)``.

    Args:
        ele: Video content element carrying the sampling options above.
        total_frames: Number of frames reported for the video.
        video_fps: Native frame rate of the video.

    Returns:
        The number of frames to sample.

    Raises:
        ValueError: If both ``fps`` and ``nframes`` are given, or the resulting
            count falls outside ``[FRAME_FACTOR, total_frames]``.
    """
    # Explicit check instead of ``assert`` so validation still runs under ``python -O``.
    if "fps" in ele and "nframes" in ele:
        raise ValueError("Only accept either `fps` or `nframes`")
    if "nframes" in ele:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        fps = ele.get("fps", FPS)
        min_frames = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        max_frames = floor_by_factor(
            ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR
        )
        nframes = total_frames / video_fps * fps
        if nframes > total_frames:
            logger.warning(
                f"smart_nframes: nframes[{nframes}] > total_frames[{total_frames}]"
            )
        nframes = min(max(nframes, min_frames), max_frames, total_frames)
        nframes = floor_by_factor(nframes, FRAME_FACTOR)
    if not (FRAME_FACTOR <= nframes <= total_frames):
        raise ValueError(
            f"nframes should be in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}."
        )
    return nframes
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def load_video(
    ele: dict,
) -> tuple[np.ndarray, float]:
    """Read a uniformly sampled subset of a video's frames with cv2.VideoCapture.

    ``ele["video"]`` is a path, optionally prefixed with ``file://``. The frame
    count comes from ``smart_nframes``, and frame indices are spread evenly over
    ``[0, total_frames - 1]``. Each frame is seeked and read individually and
    converted from BGR to RGB. Reading stops at the first frame that fails to
    decode.

    Returns:
        ``(video, sample_fps)``. ``video`` is a NumPy array of shape
        ``(T, C, H, W)`` with RGB channels. Its dtype is whatever OpenCV decoded,
        presumably uint8 — confirm. ``sample_fps`` is the effective sampling rate
        ``nframes / total_frames * video_fps``.

    Raises:
        ValueError: If the video cannot be opened, no frames could be read, or
            ``smart_nframes`` rejects the sampling options.
    """
    video_path = ele["video"]
    if video_path.startswith("file://"):
        video_path = video_path[7:]
    st = time.time()
    cap = cv2.VideoCapture(video_path)
    # Release the capture on every path, including errors raised mid-read.
    try:
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        video_fps = cap.get(cv2.CAP_PROP_FPS) or 1.0  # default to 1.0 if fps returns 0
        nframes = smart_nframes(ele, total_frames=total_frames, video_fps=video_fps)
        indices = np.linspace(0, total_frames - 1, nframes).round().astype(int)
        frames = []
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
    # Logged after decoding so ``time`` reflects the actual open + read cost.
    logger.info(
        f"numpy reader: video_path={video_path}, total_frames={total_frames}, video_fps={video_fps}, time={time.time()-st:.3f}s"
    )
    if not frames:
        raise ValueError("No frames read from the video.")
    if len(frames) < nframes:
        logger.warning(
            f"load_video: requested {nframes} frames but only read {len(frames)} from {video_path}"
        )
    # Stack frames into (T, H, W, C), then rearrange to (T, C, H, W).
    video_np = np.transpose(np.stack(frames, axis=0), (0, 3, 1, 2))
    sample_fps = nframes / max(total_frames, 1e-6) * video_fps
    return video_np, sample_fps
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def _resize_video_frames(
    video: np.ndarray, resized_height: int, resized_width: int
) -> np.ndarray:
    """Resize every frame of a ``(T, C, H, W)`` array with bicubic interpolation.

    Uses ``cv2.INTER_CUBIC`` (comparable to torchvision's BICUBIC resize).
    Returns a float32 array of shape ``(T, C, resized_height, resized_width)``.
    """
    resized_frames = []
    for frame in video:
        # cv2 works on (H, W, C) and takes the target size as (width, height).
        frame_hwc = np.transpose(frame, (1, 2, 0))
        resized = cv2.resize(
            frame_hwc, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC
        )
        resized_frames.append(np.transpose(resized, (2, 0, 1)))
    return np.stack(resized_frames, axis=0).astype(np.float32)


def fetch_video(
    ele: dict, image_factor: int = IMAGE_FACTOR, return_video_sample_fps: bool = False
) -> (
    np.ndarray
    | list[Image.Image]
    | tuple[np.ndarray | list[Image.Image], float]
):
    """Load the video referenced by a content element as model-ready frames.

    If ``ele["video"]`` is a string, it is read with ``load_video``. Each
    sampled frame is then resized to a size chosen by ``smart_resize``: either
    from ``ele["resized_height"]``/``ele["resized_width"]``, or from the frame
    size under a per-frame pixel budget. That budget is derived from
    ``total_pixels`` (default ``VIDEO_TOTAL_PIXELS``), ``min_pixels`` (default
    ``VIDEO_MIN_PIXELS``) and an optional ``max_pixels``. The result is a
    float32 ``(T, C, H, W)`` array.

    Otherwise ``ele["video"]`` is treated as a sequence of image references.
    Each is loaded with ``fetch_image`` using the remaining keys of ``ele``.
    The list is padded by repeating the last image up to a multiple of
    ``FRAME_FACTOR``.

    Returns:
        The frames, or ``(frames, sample_fps)`` when
        ``return_video_sample_fps`` is true. For image sequences, ``sample_fps``
        is ``ele.get("fps", FPS)``.
    """
    if isinstance(ele["video"], str):
        video, sample_fps = load_video(ele)
        nframes, _, height, width = video.shape
        min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
        total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
        # Split the total budget across frames (scaled by FRAME_FACTOR, presumably
        # because frames are consumed FRAME_FACTOR at a time downstream — confirm).
        # It is capped at VIDEO_MAX_PIXELS per frame and kept >= 1.05 * min_pixels.
        max_pixels = max(
            min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR),
            int(min_pixels * 1.05),
        )
        max_pixels_supposed = ele.get("max_pixels", max_pixels)
        if max_pixels_supposed > max_pixels:
            logger.warning(
                f"The given max_pixels[{max_pixels_supposed}] exceeds limit[{max_pixels}]."
            )
        max_pixels = min(max_pixels_supposed, max_pixels)
        if "resized_height" in ele and "resized_width" in ele:
            resized_height, resized_width = smart_resize(
                ele["resized_height"],
                ele["resized_width"],
                factor=image_factor,
            )
        else:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=image_factor,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
        video = _resize_video_frames(video, resized_height, resized_width)
        if return_video_sample_fps:
            return video, sample_fps
        return video

    # Otherwise the video is given as a list/tuple of image references.
    process_info = ele.copy()
    process_info.pop("type", None)
    process_info.pop("video", None)
    images = [
        fetch_image({"image": video_element, **process_info}, size_factor=image_factor)
        for video_element in ele["video"]
    ]
    # Pad with copies of the last frame up to a multiple of FRAME_FACTOR.
    nframes = ceil_by_factor(len(images), FRAME_FACTOR)
    if len(images) < nframes:
        images.extend([images[-1]] * (nframes - len(images)))
    if return_video_sample_fps:
        return images, process_info.pop("fps", FPS)
    return images
|
|
281
|
+
else:
|
|
282
|
+
# Assume video is provided as a list/tuple of image objects.
|
|
283
|
+
process_info = ele.copy()
|
|
284
|
+
process_info.pop("type", None)
|
|
285
|
+
process_info.pop("video", None)
|
|
286
|
+
images = [
|
|
287
|
+
fetch_image(
|
|
288
|
+
{"image": video_element, **process_info}, size_factor=image_factor
|
|
289
|
+
)
|
|
290
|
+
for video_element in ele["video"]
|
|
291
|
+
]
|
|
292
|
+
nframes = ceil_by_factor(len(images), FRAME_FACTOR)
|
|
293
|
+
if len(images) < nframes:
|
|
294
|
+
images.extend([images[-1]] * (nframes - len(images)))
|
|
295
|
+
if return_video_sample_fps:
|
|
296
|
+
return images, process_info.pop("fps", 2.0)
|
|
297
|
+
return images
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
    """Collect the image/video content elements from one or more conversations.

    ``conversations`` is either a single conversation (a list of message dicts)
    or a list of such conversations. For every message whose ``"content"`` is
    a list, each element is kept if it has an ``"image"``, ``"image_url"`` or
    ``"video"`` key, or if its ``"type"`` is one of those names. Messages with
    non-list content (e.g. plain strings) are skipped, and so are elements
    without a ``"type"`` key.

    Returns:
        The matching content elements, in order. An empty input yields ``[]``.
    """
    if not conversations:
        return []
    if isinstance(conversations[0], dict):
        conversations = [conversations]
    vision_keys = ("image", "image_url", "video")
    vision_infos = []
    for conversation in conversations:
        for message in conversation:
            content = message["content"]
            if not isinstance(content, list):
                continue
            for ele in content:
                # ``.get`` so text parts without a "type" key don't raise KeyError.
                if any(key in ele for key in vision_keys) or ele.get("type") in vision_keys:
                    vision_infos.append(ele)
    return vision_infos
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def process_vision_info(
    conversations: list[dict] | list[list[dict]],
    return_video_kwargs: bool = False,
) -> tuple[
    list[Image.Image] | None, list[np.ndarray | list[Image.Image]] | None, dict | None
]:
    """Load every image and video referenced in ``conversations``.

    Elements selected by ``extract_vision_info`` are dispatched as follows:
    those with an ``"image"`` or ``"image_url"`` key go to ``fetch_image``;
    those with a ``"video"`` key go to ``fetch_video`` (which also reports the
    sampling rate); anything else is rejected.

    Returns:
        ``(image_inputs, video_inputs)``, or
        ``(image_inputs, video_inputs, {"fps": [...]})`` when
        ``return_video_kwargs`` is true. Each inputs list is ``None`` rather
        than empty when nothing of that kind was found. Note that the default
        call returns a 2-tuple despite the 3-tuple annotation.

    Raises:
        ValueError: If a selected element has none of the image/video keys.
    """
    images = []
    videos = []
    fps_per_video = []
    for info in extract_vision_info(conversations):
        if "image" in info or "image_url" in info:
            images.append(fetch_image(info))
            continue
        if "video" not in info:
            raise ValueError("Content must include image, image_url, or video.")
        clip, clip_fps = fetch_video(info, return_video_sample_fps=True)
        fps_per_video.append(clip_fps)
        videos.append(clip)
    image_inputs = images or None
    video_inputs = videos or None
    if return_video_kwargs:
        return image_inputs, video_inputs, {"fps": fps_per_video}
    return image_inputs, video_inputs
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
class VideoFrameExtractor:
    """Sample frames from a video file as square RGB PIL images.

    Frames are taken every ``int(fps)`` frames (about one per second). If that
    yields more than ``max_frames`` frames, ``max_frames`` of them are picked
    evenly. Each frame is resized so its short side equals the target size,
    then center-cropped to a square.
    """

    def __init__(self, max_frames: int = 50):
        # Upper bound on the number of frames returned by extract_frames.
        self.max_frames = max_frames

    def resize_and_center_crop(
        self, image: Image.Image, target_size: int
    ) -> Image.Image:
        """Resize ``image`` so its shorter side is ``target_size`` (keeping the
        aspect ratio), then return the central ``target_size`` square."""
        width, height = image.size
        if width < height:
            new_width = target_size
            # max(): float rounding must never leave the long side shorter than
            # target_size, or the crop box would extend past the image edge.
            new_height = max(target_size, int(height * (target_size / width)))
        else:
            new_height = target_size
            new_width = max(target_size, int(width * (target_size / height)))

        image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        left = (new_width - target_size) // 2
        top = (new_height - target_size) // 2
        return image.crop((left, top, left + target_size, top + target_size))

    def extract_frames(
        self, video_path: str, target_size: int = 384
    ) -> List[Image.Image]:
        """Read sampled frames from ``video_path`` as RGB PIL images.

        Args:
            video_path: Path accepted by ``cv2.VideoCapture``.
            target_size: Side length of the square crop for each frame.

        Returns:
            The decoded frames in temporal order. Frames that fail to decode
            are skipped.

        Raises:
            ValueError: If the video cannot be opened.
        """
        cap = cv2.VideoCapture(video_path)
        # Release the capture even if decoding or resizing raises.
        try:
            if not cap.isOpened():
                raise ValueError(f"Could not open video: {video_path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # OpenCV reports 0 when the frame rate is unknown; a step of 0 would
            # make range() raise, so fall back to sampling every frame.
            step = max(1, int(cap.get(cv2.CAP_PROP_FPS)))
            frame_indices = list(range(0, total_frames, step))

            # Too many candidates: keep max_frames of them, evenly spaced.
            if len(frame_indices) > self.max_frames:
                picks = np.linspace(0, len(frame_indices) - 1, self.max_frames, dtype=int)
                frame_indices = [frame_indices[i] for i in picks]

            frames = []
            for frame_idx in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                ret, frame = cap.read()
                if not ret:
                    continue
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(
                    self.resize_and_center_crop(Image.fromarray(rgb), target_size)
                )
        finally:
            cap.release()
        return frames
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
def is_video_model(model):
    """Report whether ``model.config`` declares a video token attribute.

    Only the presence of ``video_token_id`` or ``video_token_index`` is
    checked; their values are not inspected.
    """
    config = model.config
    return any(
        hasattr(config, attr_name)
        for attr_name in ("video_token_id", "video_token_index")
    )
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
def is_video_file(video_path: List[str] | str) -> bool:
    """Return True if every path ends with a recognised video extension.

    Recognised extensions are ``.mp4``, ``.avi`` and ``.mov``, matched
    case-insensitively (so ``clip.MP4`` is accepted). A single path string is
    treated as a one-element list rather than being iterated character by
    character. An empty list returns True (vacuous truth).
    """
    video_extensions = (".mp4", ".avi", ".mov")
    paths = [video_path] if isinstance(video_path, str) else video_path
    return all(path.lower().endswith(video_extensions) for path in paths)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def _build_arg_parser():
    """Build the argument parser for the video description CLI."""
    parser = argparse.ArgumentParser(description="Video Description CLI")
    parser.add_argument(
        "--video", type=str, nargs="+", required=True, help="Path to the video file"
    )
    parser.add_argument(
        "--max-pixels",
        type=int,
        nargs=2,
        default=224 * 224,
        help="Maximum number of pixels",
    )
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Maximum number of frames"
    )
    parser.add_argument("--fps", type=float, default=1.0, help="Frames per second")
    parser.add_argument(
        "--prompt", default="Describe this video.", help="Text prompt for the model"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Temperature for generation"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--model",
        default="mlx-community/Qwen2.5-VL-7B-Instruct-4bit",
        help="Select the model to use",
    )
    # Verbose output is the default. --verbose agrees with its help text
    # (it no longer turns verbosity off); --no-verbose prints only the final
    # response.
    parser.add_argument(
        "--verbose",
        dest="verbose",
        action="store_true",
        default=True,
        help="Print verbose output",
    )
    parser.add_argument(
        "--no-verbose",
        dest="verbose",
        action="store_false",
        help="Disable verbose output and print only the final response",
    )
    return parser


def _prepare_native_video_inputs(args, processor, max_pixels):
    """Build the prompt and model inputs for a model with native video support.

    Returns:
        ``(text, input_ids, pixel_values, mask, extra)``. ``extra`` holds
        ``video_grid_thw`` / ``image_grid_thw`` as mx arrays when the
        processor produced them.

    Raises:
        ValueError: If the processor produced no pixel values.
    """
    if is_video_file(args.video):
        content = [
            {
                "type": "video",
                "video": args.video[0],
                "max_pixels": max_pixels,
                "fps": args.fps,
            },
            {"type": "text", "text": args.prompt},
        ]
    else:
        content = [
            *[{"type": "image", "image": image} for image in args.video],
            {"type": "text", "text": args.prompt},
        ]
    messages = [{"role": "user", "content": content}]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs, _ = process_vision_info(messages, True)

    # NOTE(review): video_inputs presumably holds one frame sequence per video
    # (None when the message has no video). Truncate each video's frames
    # rather than the list of videos, and skip the step when there is none.
    if args.max_frames is not None and video_inputs is not None:
        video_inputs = [video[: args.max_frames] for video in video_inputs]

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    input_ids = mx.array(inputs["input_ids"])
    pixel_values = inputs.get(
        "pixel_values_videos", inputs.get("pixel_values", None)
    )
    if pixel_values is None:
        raise ValueError("Please provide a valid video or image input.")
    pixel_values = mx.array(pixel_values)
    mask = mx.array(inputs["attention_mask"])

    extra = {}
    for key in ("video_grid_thw", "image_grid_thw"):
        if inputs.get(key, None) is not None:
            extra[key] = mx.array(inputs[key])
    return text, input_ids, pixel_values, mask, extra


def _prepare_frame_inputs(args, processor):
    """Build the prompt and model inputs by passing frames as still images.

    Used for models without native video support. A single video file is
    decoded into frames with ``VideoFrameExtractor``; otherwise every path is
    loaded with ``load_image``.

    Returns:
        ``(text, input_ids, pixel_values, mask, extra)``. ``extra`` holds any
        other non-string, non-list processor outputs converted to mx arrays.

    Raises:
        ValueError: If more than one video file is given.
    """
    if is_video_file(args.video):
        if len(args.video) > 1:
            raise ValueError("Only one video is supported for non-video models.")
        # NOTE(review): args.max_frames defaults to None; confirm that
        # VideoFrameExtractor / extract_frames treat None as "no cap".
        frame_extractor = VideoFrameExtractor(args.max_frames)
        frames = frame_extractor.extract_frames(args.video[0])
    else:
        frames = [load_image(image) for image in args.video]

    # One image placeholder per frame, between a fixed preamble and the prompt.
    image_tokens = [{"type": "image"} for _ in frames]
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer briefly."},
                *image_tokens,
                {"type": "text", "text": args.prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Configure the processor for video frames. argparse's nargs=2 yields a
    # list, so accept lists as well as tuples here.
    image_processor = processor.image_processor
    if isinstance(args.max_pixels, (tuple, list)):
        image_processor.size = tuple(args.max_pixels)
    else:
        image_processor.size = (args.max_pixels, args.max_pixels)
    if hasattr(image_processor, "do_resize"):
        image_processor.do_resize = False
    if hasattr(image_processor, "do_image_splitting"):
        image_processor.do_image_splitting = False

    inputs = process_inputs_with_fallback(
        processor,
        images=list(frames),
        prompts=text,
    )

    input_ids = mx.array(inputs["input_ids"])
    pixel_values = mx.array(inputs["pixel_values"])
    mask = mx.array(inputs["attention_mask"])
    reserved = {"input_ids", "pixel_values", "attention_mask"}
    extra = {
        key: mx.array(value)
        for key, value in inputs.items()
        if key not in reserved and not isinstance(value, (str, list))
    }
    return text, input_ids, pixel_values, mask, extra


def main():
    """Command-line entry point: describe a video or an image sequence with a VLM.

    The steps are:
    1. Parse the arguments and load the model.
    2. Build inputs either through the model's native video path or by
       feeding frames as images.
    3. Call ``generate``.

    When verbose output is disabled, the final response is printed.
    """
    args = _build_arg_parser().parse_args()

    print(f"\033[32mLoading model:\033[0m {args.model}")
    model, processor = load(args.model)

    # Validate the model once and reuse the answer below.
    native_video = is_video_model(model)
    if not native_video:
        logger.warning(
            "Warning: The model selected doesn't natively support video inputs. Performance may be degraded."
        )

    if isinstance(args.max_pixels, (tuple, list)):
        max_pixels = args.max_pixels[0] * args.max_pixels[1]
    else:
        max_pixels = args.max_pixels

    if native_video:
        text, input_ids, pixel_values, mask, kwargs = _prepare_native_video_inputs(
            args, processor, max_pixels
        )
    else:
        text, input_ids, pixel_values, mask, kwargs = _prepare_frame_inputs(
            args, processor
        )

    logger.info("\033[32mGenerating response...\033[0m")

    kwargs["video"] = args.video
    kwargs["input_ids"] = input_ids
    kwargs["pixel_values"] = pixel_values
    kwargs["mask"] = mask
    kwargs["temperature"] = args.temperature
    kwargs["max_tokens"] = args.max_tokens

    response = generate(
        model,
        processor,
        prompt=text,
        verbose=args.verbose,
        **kwargs,
    )

    if not args.verbose:
        print(response)
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not on import.
    main()
|