fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import mlx.core as mx
|
|
2
|
+
import numpy as np
|
|
3
|
+
|
|
4
|
+
from ..idefics3 import Model as Idefics3Model
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Model(Idefics3Model):
    """SmolVLM variant of the Idefics3 model.

    Only the step that merges image features into the text embeddings is
    overridden here.
    """

    def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
        """Replace each ``<image>`` token embedding with one image feature row.

        Only the first sample of ``inputs_embeds`` is used (batch size 1 is
        assumed). The ``<image>`` positions are consumed in order, in groups
        of ``S`` (the second dimension of ``image_features``); group ``k``
        takes its rows from ``image_features[k]``.

        Args:
            image_features: 3-D array; the comments in the original code
                describe each ``image_features[k]`` as an ``(S, D)`` block.
            inputs_embeds: 3-D array of text embeddings; only index 0 along
                the first axis is read.
            input_ids: Token ids compared against
                ``self.config.image_token_index``; indexed as 2-D by
                ``np.where(...)[1]``.

        Returns:
            The merged embeddings with a leading axis of size 1 added. When
            the input contains no ``<image>`` token, the first sample's
            embeddings are returned without that leading axis.

        Raises:
            ValueError: If the number of ``<image>`` tokens is not a
                multiple of ``S``.
        """
        # Unpacking doubles as a rank check: both inputs must be 3-D.
        _, seq_len, _ = inputs_embeds.shape
        _, rows_per_image, _ = image_features.shape

        text_embeds = inputs_embeds[0]

        # Column indices of every <image> token in the prompt.
        image_token_index = self.config.image_token_index
        slots = np.where(input_ids == image_token_index)[1].tolist()
        slot_count = len(slots)

        if slot_count == 0:
            # Text-only input: concatenate with a zero-row slice of the image
            # features, which leaves the text embeddings unchanged.
            # NOTE(review): unlike the main path, this result has no leading
            # batch axis - confirm that callers expect this.
            zero_rows = image_features[0][:0, :]
            return mx.concatenate([text_embeds, zero_rows], axis=0)

        if slot_count % rows_per_image != 0:
            raise ValueError(
                f"Input has {slot_count} <image> tokens, not a multiple of S={rows_per_image}. "
                "Cannot map them to blocks of shape (S, D)."
            )

        pieces = []
        cursor = 0  # first text position not yet copied into `pieces`
        for slot_number, position in enumerate(slots):
            # The k-th <image> slot maps to row (k mod S) of image (k div S).
            image_number, row = divmod(slot_number, rows_per_image)
            if position > cursor:
                pieces.append(text_embeds[cursor:position])
            pieces.append(image_features[image_number][row : row + 1, :])
            cursor = position + 1

        # Text that follows the last <image> token.
        if cursor < seq_len:
            pieces.append(text_embeds[cursor:])

        return mx.expand_dims(mx.concatenate(pieces, axis=0), axis=0)
|
mlx_vlm/prompt_utils.py
ADDED
|
@@ -0,0 +1,565 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from functools import partial
|
|
3
|
+
from typing import Any, Dict, List, Union
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class MessageFormat(Enum):
    """Identifiers for the message layouts a model type can be mapped to."""

    # Content is a list of typed parts (text plus image placeholders).
    LIST_WITH_IMAGE = "list_with_image"
    LIST_WITH_IMAGE_FIRST = "list_with_image_first"
    LIST_WITH_IMAGE_URL_FIRST = "list_with_image_url_first"
    LIST_WITH_IMAGE_TYPE = "list_with_image_type"
    LIST_WITH_IMAGE_TYPE_TEXT = "list_with_image_type_text"
    LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST = "list_with_image_type_text_image_last"

    # Content is a single string with image tokens spliced into the text.
    IMAGE_TOKEN = "image_token"
    IMAGE_TOKEN_PIPE = "image_token_pipe"
    START_IMAGE_TOKEN = "start_image_token"
    IMAGE_TOKEN_NEWLINE = "image_token_newline"
    NUMBERED_IMAGE_TOKENS = "numbered_image_tokens"

    # The result is a bare prompt string rather than a role/content dict.
    PROMPT_ONLY = "prompt_only"
    PROMPT_WITH_IMAGE_TOKEN = "prompt_with_image_token"
    PROMPT_WITH_START_IMAGE_TOKEN = "prompt_with_start_image_token"

    # Video part followed by a text part.
    VIDEO_WITH_TEXT = "video_with_text"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Maps a model name to the message layout used for it.
MODEL_CONFIG = {
    # --- Pipe-style image token in the text ---
    "jina_vlm": MessageFormat.IMAGE_TOKEN_PIPE,
    "jvlm": MessageFormat.IMAGE_TOKEN_PIPE,
    # --- List-of-parts layouts ---
    "idefics2": MessageFormat.LIST_WITH_IMAGE,
    "idefics3": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "lfm2-vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "lfm2_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "aya_vision": MessageFormat.LIST_WITH_IMAGE,
    "cohere2_vision": MessageFormat.LIST_WITH_IMAGE,
    "paddleocr_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "qwen2_vl": MessageFormat.LIST_WITH_IMAGE,
    "qwen2_5_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "qwen3_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "qwen3_vl_moe": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "mistral3": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "glm4v": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "glm4v_moe": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "glm_ocr": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "ernie4_5_moe_vl": MessageFormat.LIST_WITH_IMAGE_URL_FIRST,
    "internvl_chat": MessageFormat.LIST_WITH_IMAGE_TYPE,
    "kimi_vl": MessageFormat.LIST_WITH_IMAGE,
    "gemma3": MessageFormat.START_IMAGE_TOKEN,
    "gemma3n": MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST,
    "llama4": MessageFormat.LIST_WITH_IMAGE,
    "smolvlm": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "llava": MessageFormat.LIST_WITH_IMAGE,
    "llava_next": MessageFormat.LIST_WITH_IMAGE,
    "mllama": MessageFormat.LIST_WITH_IMAGE,
    "pixtral": MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT,
    "molmo2": MessageFormat.LIST_WITH_IMAGE_FIRST,
    # --- Image token embedded in the text ---
    "llava-qwen2": MessageFormat.IMAGE_TOKEN_NEWLINE,
    "llava_qwen2": MessageFormat.IMAGE_TOKEN_NEWLINE,  # fastvlm
    "bunny-llama": MessageFormat.IMAGE_TOKEN_NEWLINE,
    "phi3_v": MessageFormat.NUMBERED_IMAGE_TOKENS,
    "multi_modality": MessageFormat.IMAGE_TOKEN,
    "deepseek_vl_v2": MessageFormat.IMAGE_TOKEN_NEWLINE,
    "deepseekocr_2": MessageFormat.IMAGE_TOKEN_NEWLINE,
    "deepseekocr": MessageFormat.IMAGE_TOKEN_NEWLINE,
    # (list-of-parts layout, kept at its original position in the mapping)
    "hunyuan_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
    # --- Plain prompt string, optionally prefixed with image tokens ---
    "florence2": MessageFormat.PROMPT_ONLY,
    "molmo": MessageFormat.PROMPT_ONLY,
    "paligemma": MessageFormat.PROMPT_WITH_IMAGE_TOKEN,
    "moondream1": MessageFormat.PROMPT_WITH_IMAGE_TOKEN,
}

# Model names that do not support more than one image per message.
SINGLE_IMAGE_ONLY_MODELS = {
    "llava_next",
    "llava-qwen2",
    "bunny-llama",
    "paligemma",
    "multi_modality",
    "mllama",
    "moondream1",
}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def extract_text_from_content(content: Any) -> str:
    """
    Return only the textual part of a message's content.

    With an OpenAI-compatible multimodal API, content may be a list such as:
        [
            {"type": "text", "text": "Describe this image"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
        ]

    Keeping only the text parts prevents base64 image payloads from being
    tokenized as text.

    Args:
        content: A string, a list of content items, or any other value.

    Returns:
        - ``content`` itself when it is a string;
        - for a list, the non-empty texts of its ``"text"`` / ``"input_text"``
          dict items joined by single spaces and stripped (``""`` if none);
        - otherwise ``str(content)``, or ``""`` when ``content`` is falsy.
    """
    if isinstance(content, str):
        return content

    if not isinstance(content, list):
        # Fallback for unexpected types.
        return str(content) if content else ""

    fragments = []
    for entry in content:
        # Non-dict entries and non-text parts (images, audio, ...) are skipped.
        if not isinstance(entry, dict):
            continue
        if entry.get("type", "") not in ("text", "input_text"):
            continue
        # The text may be stored under either "text" or "content".
        fragment = entry.get("text", "") or entry.get("content", "")
        if fragment:
            fragments.append(fragment)

    if not fragments:
        return ""
    return " ".join(fragments).strip()
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class MessageBuilder:
    """Factory helpers that return the individual parts of a chat message."""

    @staticmethod
    def text_message(text: str) -> Dict[str, str]:
        """Return a text part; the text is stored under both "text" and "content"."""
        return dict(type="text", text=text, content=text)

    @staticmethod
    def content_message(content: str) -> Dict[str, str]:
        """Return a text part for ``content`` (same shape as ``text_message``)."""
        part = {"type": "text"}
        part["text"] = content
        part["content"] = content
        return part

    @staticmethod
    def image_message() -> Dict[str, str]:
        """Return an image placeholder part."""
        return dict(type="image")

    @staticmethod
    def image_url_message() -> Dict[str, str]:
        """Return an image_url placeholder part (for models like ERNIE that expect this format)."""
        return dict(type="image_url")

    @staticmethod
    def audio_message() -> Dict[str, str]:
        """Return an audio placeholder part."""
        return dict(type="audio")

    @staticmethod
    def video_message(
        video_path: str, max_pixels: int = 224 * 224, fps: int = 1
    ) -> Dict[str, Any]:
        """Return a video part carrying the path, pixel budget and frame rate."""
        part = dict(type="video", video=video_path)
        part["max_pixels"] = max_pixels
        part["fps"] = fps
        return part
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
class MessageFormatter:
    """Build chat messages in the layout registered for a model.

    The layout is looked up in ``MODEL_CONFIG`` using the lower-cased model
    name.
    """

    # Model names for which a truthy ``video`` kwarg selects the video layout.
    _VIDEO_MODELS = ("qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe")

    def __init__(self, model_name: str):
        """Resolve the message layout for ``model_name``.

        Raises:
            ValueError: If the lower-cased name has no entry in ``MODEL_CONFIG``.
        """
        self.model_name = model_name.lower()
        self.format_type = MODEL_CONFIG.get(self.model_name)
        if not self.format_type:
            raise ValueError(f"Unsupported model: {model_name}")

    def format_message(
        self,
        prompt: str,
        role: str = "user",
        skip_image_token: bool = False,
        skip_audio_token: bool = False,
        num_images: int = 1,
        num_audios: int = 1,
        **kwargs,
    ) -> Union[str, Dict[str, Any]]:
        """Format a message based on the model type.

        Args:
            prompt: The text of the message.
            role: Message role; image/audio placeholders are only added for
                ``"user"``.
            skip_image_token: When true, no image placeholders are added.
            skip_audio_token: When true, no audio placeholders are added.
            num_images: Number of image placeholders to add.
            num_audios: Number of audio placeholders to add.
            **kwargs: Extra options. ``video`` (plus optional ``max_pixels``
                and ``fps``) selects the video layout for the model names in
                ``_VIDEO_MODELS``.

        Returns:
            A ``{"role": ..., "content": ...}`` dict, or a plain string for
            the prompt-only layouts.

        Raises:
            ValueError: If more than one image is requested for a model
                listed in ``SINGLE_IMAGE_ONLY_MODELS``.
        """
        # Check multi-image support
        if num_images > 1 and self.model_name in SINGLE_IMAGE_ONLY_MODELS:
            raise ValueError(
                f"Model {self.model_name} does not support multi-image chat. "
                f"Please only use 1 image."
            )

        # Handle video format for specific models
        if self.model_name in self._VIDEO_MODELS and kwargs.get("video"):
            # `_format_video_message` reads `video`, `max_pixels` and `fps`
            # from its own **kwargs, so they must be unpacked here. Passing
            # the dict positionally would bind it to `role` and leave
            # `kwargs["video"]` missing.
            return self._format_video_message(
                prompt,
                role,
                skip_image_token,
                skip_audio_token,
                num_images,
                num_audios,
                **kwargs,
            )

        # Route to appropriate formatter
        formatter_map = {
            MessageFormat.LIST_WITH_IMAGE: self._format_list_with_image,
            MessageFormat.LIST_WITH_IMAGE_FIRST: partial(
                self._format_list_with_image, image_first=True
            ),
            MessageFormat.LIST_WITH_IMAGE_URL_FIRST: partial(
                self._format_list_with_image, image_first=True, use_image_url=True
            ),
            MessageFormat.LIST_WITH_IMAGE_TYPE: self._format_list_with_image_type,
            MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT: partial(
                self._format_list_with_image_type, message_type="text"
            ),
            MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST: partial(
                self._format_list_with_image_type,
                message_type="text",
                image_first=False,
            ),
            MessageFormat.IMAGE_TOKEN: partial(
                self._format_with_token, token="<image>"
            ),
            MessageFormat.IMAGE_TOKEN_PIPE: partial(
                self._format_with_token, token="<|image|>"
            ),
            MessageFormat.START_IMAGE_TOKEN: partial(
                self._format_with_token, token="<start_of_image>", image_first=False
            ),
            MessageFormat.IMAGE_TOKEN_NEWLINE: partial(
                self._format_with_token, token="<image>\n"
            ),
            MessageFormat.NUMBERED_IMAGE_TOKENS: self._format_numbered_tokens,
            # The prompt-only layouts ignore role and flags and return a string.
            MessageFormat.PROMPT_ONLY: lambda *args, **kw: prompt,
            MessageFormat.PROMPT_WITH_IMAGE_TOKEN: lambda *args, **kw: "<image>"
            * num_images
            + prompt,
            MessageFormat.PROMPT_WITH_START_IMAGE_TOKEN: lambda *args, **kw: prompt
            + "<start_of_image>" * num_images,
            MessageFormat.VIDEO_WITH_TEXT: self._format_video_message,
        }

        formatter = formatter_map.get(self.format_type)
        return formatter(
            prompt,
            role,
            skip_image_token,
            skip_audio_token,
            num_images,
            num_audios,
            **kwargs,
        )

    def _format_list_with_image(
        self,
        prompt: str,
        role: str,
        skip_image_token: bool,
        skip_audio_token: bool,
        num_images: int,
        num_audios: int,
        image_first: bool = False,
        use_image_url: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """Format as a list of parts: one text part plus image placeholders.

        Image placeholders are added only for the ``"user"`` role, before the
        text when ``image_first`` is true and after it otherwise. The audio
        arguments are accepted for signature compatibility but not used.
        """
        content = [MessageBuilder.text_message(prompt)]

        if role == "user" and not skip_image_token and num_images > 0:
            image_builder = (
                MessageBuilder.image_url_message
                if use_image_url
                else MessageBuilder.image_message
            )
            # One independent dict per image, so that mutating one
            # placeholder cannot affect the others.
            image_tokens = [image_builder() for _ in range(num_images)]
            content = image_tokens + content if image_first else content + image_tokens

        return {"role": role, "content": content}

    def _format_list_with_image_type(
        self,
        prompt: str,
        role: str,
        skip_image_token: bool,
        skip_audio_token: bool,
        num_images: int,
        num_audios: int,
        message_type: str = "content",
        image_first: bool = True,
        **kwargs,
    ) -> Dict[str, Any]:
        """Format as a list of typed parts, with image and audio placeholders.

        For the ``"user"`` role, image placeholders go before or after the
        text (per ``image_first``) and audio placeholders are appended last.
        For the ``"assistant"`` role the content collapses to the bare text.
        """
        msg_func = (
            MessageBuilder.content_message
            if message_type == "content"
            else MessageBuilder.text_message
        )
        message = {"role": role, "content": [msg_func(prompt)]}

        if role == "user":
            if not skip_image_token and num_images > 0:
                images = [MessageBuilder.image_message() for _ in range(num_images)]
                message["content"] = (
                    images + message["content"]
                    if image_first
                    else message["content"] + images
                )
            if not skip_audio_token and num_audios > 0:
                audios = [MessageBuilder.audio_message() for _ in range(num_audios)]
                message["content"] = message["content"] + audios

        if role == "assistant":
            text_part = message["content"][0]
            message["content"] = text_part.get("content", text_part.get("text"))

        return message

    def _format_with_token(
        self,
        prompt: str,
        role: str,
        skip_image_token: bool,
        skip_audio_token: bool,
        num_images: int,
        num_audios: int,
        token: str,
        image_first: bool = True,
        **kwargs,
    ) -> Dict[str, Any]:
        """Format with ``token`` repeated once per image inside the text.

        The tokens are added only for the ``"user"`` role, before the prompt
        when ``image_first`` is true and after it otherwise.
        """
        content = prompt

        if role == "user" and not skip_image_token and num_images > 0:
            prefix = token * num_images
            content = f"{prefix}{content}" if image_first else f"{content}{prefix}"

        return {"role": role, "content": content}

    def _format_numbered_tokens(
        self,
        prompt: str,
        role: str,
        skip_image_token: bool,
        skip_audio_token: bool,
        num_images: int,
        num_audios: int,
        **kwargs,
    ) -> Dict[str, Any]:
        """Format with numbered image tokens (``<|image_1|>``, ...) before the prompt."""
        content = prompt

        if role == "user" and not skip_image_token and num_images > 0:
            # phi3_v uses single token regardless of num_images
            prefix = (
                "<|image_1|>"
                if self.model_name == "phi3_v"
                else " ".join([f"<|image_{i+1}|>" for i in range(num_images)])
            )
            content = f"{prefix}{content}"

        return {"role": role, "content": content}

    def _format_video_message(
        self,
        prompt: str,
        role: str = "user",
        skip_image_token: bool = False,
        skip_audio_token: bool = False,
        num_images: int = 0,
        num_audios: int = 0,
        **kwargs,
    ) -> Dict[str, Any]:
        """Format a video part followed by a text part.

        Reads ``video`` (required), ``max_pixels`` (default ``224 * 224``)
        and ``fps`` (default ``1``) from ``kwargs``.

        Raises:
            KeyError: If ``video`` is not supplied in ``kwargs``.
        """
        return {
            "role": role,
            "content": [
                MessageBuilder.video_message(
                    kwargs["video"],
                    kwargs.get("max_pixels", 224 * 224),
                    kwargs.get("fps", 1),
                ),
                MessageBuilder.text_message(prompt),
            ],
        }
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def get_message_json(
    model_name: str,
    prompt: str,
    role: str = "user",
    skip_image_token: bool = False,
    skip_audio_token: bool = False,
    num_images: int = 0,
    num_audios: int = 0,
    **kwargs,
) -> Union[str, Dict[str, Any]]:
    """
    Build the message for ``prompt`` in the layout of the given model.

    This creates a ``MessageFormatter`` for ``model_name`` and delegates to
    its ``format_message``.

    Args:
        model_name: The model for which to generate the message
        prompt: The text prompt to be included in the message
        role: The role of the message (default: "user")
        skip_image_token: Whether to skip adding image tokens
        skip_audio_token: Whether to skip adding audio tokens
        num_images: Number of image tokens to add
        num_audios: Number of audio tokens to add
        **kwargs: Additional arguments (e.g., video path, max_pixels, fps)

    Returns:
        A dictionary or string representing the message for the specified model
    """
    return MessageFormatter(model_name).format_message(
        prompt,
        role=role,
        skip_image_token=skip_image_token,
        skip_audio_token=skip_audio_token,
        num_images=num_images,
        num_audios=num_audios,
        **kwargs,
    )
|
|
429
|
+
|
|
430
|
+
|
|
431
|
+
def get_chat_template(
    processor,
    messages: List[Dict[str, Any]],
    add_generation_prompt: bool,
    tokenize: bool = False,
    **kwargs,
) -> Any:
    """Apply the chat template of the processor or of its tokenizer.

    The processor itself is used when its instance ``__dict__`` holds a
    truthy ``chat_template``; otherwise ``processor.tokenizer`` is used.

    Args:
        processor: Object exposing ``apply_chat_template`` directly or via a
            ``tokenizer`` attribute.
        messages: The chat messages to render.
        add_generation_prompt: Forwarded to ``apply_chat_template``.
        tokenize: Forwarded to ``apply_chat_template``.
        **kwargs: Forwarded to ``apply_chat_template``.

    Returns:
        Whatever the selected ``apply_chat_template`` returns.

    Raises:
        ValueError: If an ``AttributeError`` occurs while selecting the
            target or while applying the template. The original error is
            attached as ``__cause__``.
    """
    try:
        # Only instance attributes are inspected here; a class-level
        # `chat_template` would not be found by this lookup.
        target = (
            processor
            if processor.__dict__.get("chat_template")
            else processor.tokenizer
        )

        return target.apply_chat_template(
            messages,
            tokenize=tokenize,
            add_generation_prompt=add_generation_prompt,
            **kwargs,
        )
    except AttributeError as err:
        # Chain explicitly so the underlying AttributeError stays visible;
        # it may also come from inside the template call.
        raise ValueError(
            "Error: processor does not have 'chat_template' or 'tokenizer' attribute."
        ) from err
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def apply_chat_template(
    processor,
    config: Union[Dict[str, Any], Any],
    prompt: Union[str, Dict[str, Any], List[Any]],
    add_generation_prompt: bool = True,
    return_messages: bool = False,
    num_images: int = 0,
    num_audios: int = 0,
    **kwargs,
) -> Union[List[Dict[str, Any]], str, Any]:
    """
    Build model-specific chat messages from ``prompt`` and render them.

    Args:
        processor: Object handed to ``get_chat_template`` for rendering
        config: Model configuration; a dict, or an object whose ``__dict__``
            is used. Must contain ``"model_type"``
        prompt: A single string, a single ``{"role", "content"}`` dict, or a
            list mixing strings, dicts and ``BaseModel`` instances
        add_generation_prompt: Forwarded to ``get_chat_template``
        return_messages: If True, return the built message list unrendered
        num_images: Number of images in the input
        num_audios: Number of audio files in the input
        **kwargs: Extra arguments forwarded to ``get_message_json``

    Returns:
        The message list (``return_messages=True``), the last message alone
        for model types that only need it, or the rendered chat template
    """
    settings = config if isinstance(config, dict) else config.__dict__
    model_type = settings["model_type"]

    messages = []

    if isinstance(prompt, str):
        # Bare string: one message using get_message_json's default role.
        single = get_message_json(
            model_type,
            prompt,
            num_images=num_images,
            num_audios=num_audios,
            **kwargs,
        )
        messages.append(single)
    elif isinstance(prompt, dict):
        # Single message dict: keep its role, reduce its content to text.
        text = extract_text_from_content(prompt["content"])
        single = get_message_json(
            model_type,
            text,
            prompt["role"],
            num_images=num_images,
            num_audios=num_audios,
            **kwargs,
        )
        messages.append(single)
    elif isinstance(prompt, list):
        for position, entry in enumerate(prompt):
            if isinstance(entry, str):
                # For plain strings only the first entry carries media tokens.
                skip_media = position != 0
                messages.append(
                    get_message_json(
                        model_type,
                        entry,
                        skip_image_token=skip_media,
                        skip_audio_token=skip_media,
                        num_images=num_images,
                        num_audios=num_audios,
                        **kwargs,
                    )
                )
                continue

            entry_is_dict = isinstance(entry, dict)
            if not (entry_is_dict or isinstance(entry, BaseModel)):
                # Entries of any other type are ignored.
                continue

            if entry_is_dict:
                role = entry.get("role", "user")
                raw_content = entry.get("content")
            else:
                role = entry.role
                raw_content = entry.content

            # Keep only the text parts so image/audio payloads (e.g. base64
            # data) are never tokenized as text.
            text = extract_text_from_content(raw_content)

            # Media tokens go on entry 0, or on entry 1 when it is not a
            # system/assistant turn; system/assistant turns never get them.
            takes_media = position == 0 or (
                position == 1 and role not in ["system", "assistant"]
            )
            messages.append(
                get_message_json(
                    model_type,
                    text,
                    role,
                    skip_image_token=not takes_media
                    or role in ["system", "assistant"],
                    skip_audio_token=not takes_media
                    or role in ["system", "assistant"],
                    num_images=num_images,
                    num_audios=num_audios,
                    **kwargs,
                )
            )

    if return_messages:
        return messages

    # Some models only need the last message
    if model_type in ["paligemma", "molmo", "florence2", "moondream1"]:
        return messages[-1]

    return get_chat_template(processor, messages, add_generation_prompt)
|
mlx_vlm/sample_utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import mlx.core as mx
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
    """
    Sample a token from the top-p (nucleus) portion of the distribution.

    Args:
        logits: Model output logits. NOTE(review): axis 0 is squeezed below,
            so this presumably expects a single sequence (leading axis of
            size 1) — confirm against callers.
        top_p: Cumulative probability mass to keep.
        temperature: Divisor applied to the logits before the softmax.
    Returns:
        The sampled token id, mapped back from sorted order to vocabulary order.
    """
    # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
    if logits.dtype == mx.bfloat16:
        logits = logits.astype(mx.float32)

    # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
    probabilities = mx.softmax(logits / temperature, axis=-1)

    # Order token ids by ascending probability; the squeezed index vector is
    # reused both to gather probabilities and to map the sample back.
    ascending_ids = mx.argsort(probabilities, axis=-1)
    flat_ids = ascending_ids.squeeze(0)
    ascending_probs = probabilities[..., flat_ids]

    running_mass = mx.cumsum(ascending_probs, axis=-1)

    # Keep entries whose ascending cumulative mass exceeds 1 - top_p (the
    # high-probability tail); every other entry is zeroed out.
    in_nucleus = running_mass > 1 - top_p
    kept_probs = mx.where(
        in_nucleus,
        ascending_probs,
        mx.zeros_like(ascending_probs),
    )

    # Sample a position in sorted order, then translate it to a token id.
    sampled_position = mx.random.categorical(mx.log(kept_probs))
    return flat_ids[sampled_position]
|