fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
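The per-file diffs shown below cover two of the listed files: `mlx_vlm/models/lfm2_vl/lfm2_vl.py` (+223) and `mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py` (+320). A file listing like the one above can also be generated locally from the wheel itself, since a wheel is a zip archive. The following is a minimal sketch, assuming the artifact has already been downloaded (for example with `pip download fount-vlm-nell-02==0.3.11 --no-deps`) and that its filename follows the canonical wheel naming shown in the title; the local path is an assumption, not part of this diff.

```python
# Minimal sketch: list the contents of a downloaded wheel (a wheel is a zip archive).
# WHEEL_PATH is a hypothetical local filename inferred from the package name and version.
import zipfile

WHEEL_PATH = "fount_vlm_nell_02-0.3.11-py3-none-any.whl"

with zipfile.ZipFile(WHEEL_PATH) as wheel:
    for info in wheel.infolist():
        # Print each archived path with its uncompressed size in bytes.
        print(f"{info.filename}  ({info.file_size} bytes)")
```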
mlx_vlm/models/lfm2_vl/lfm2_vl.py
@@ -0,0 +1,223 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import InputEmbeddingsFeatures
+from . import processing_lfm2_vl  # noqa: F401
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import VisionModel
+
+
+class Lfm2VlMultiModalProjector(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        in_channels = config.vision_config.hidden_size * (config.downsample_factor**2)
+        if config.projector_use_layernorm:
+            self.layer_norm = nn.LayerNorm(in_channels)
+        else:
+            self.layer_norm = nn.Identity()
+        self.linear_1 = nn.Linear(
+            in_channels,
+            config.projector_hidden_size,
+            bias=config.projector_bias,
+        )
+
+        self.linear_2 = nn.Linear(
+            config.projector_hidden_size,
+            config.text_config.hidden_size,
+            bias=config.projector_bias,
+        )
+
+    def __call__(self, x):
+        x = self.linear_1(self.layer_norm(x))
+        x = self.linear_2(nn.gelu(x))
+        return x
+
+
+class PixelUnshuffleBlock(nn.Module):
+    def __init__(self, factor: int):
+        super().__init__()
+        self.factor = factor
+
+    def __call__(self, x):
+        n, w, h, c = x.shape
+        if w % self.factor != 0:
+            x = mx.concatenate(
+                [
+                    x,
+                    mx.zeros((n, self.factor - (w % self.factor), h, c), dtype=x.dtype),
+                ],
+                axis=1,
+            )
+            n, w, h, c = x.shape
+
+        if h % self.factor != 0:
+            x = mx.concatenate(
+                [
+                    x,
+                    mx.zeros((n, w, self.factor - (h % self.factor), c), dtype=x.dtype),
+                ],
+                axis=2,
+            )
+            n, w, h, c = x.shape
+        x = x.reshape(n, w, int(h / self.factor), int(c * self.factor))
+        x = x.transpose(0, 2, 1, 3)
+        x = x.reshape(
+            n, int(h / self.factor), int(w / self.factor), int(c * self.factor**2)
+        )
+        x = x.transpose(0, 2, 1, 3)
+        return x
+
+
+def masked_scatter(
+    final_embedding: mx.array,
+    image_mask_expanded: mx.array,
+    scaled_image_features: mx.array,
+):
+    # Reshape the tensors to 1D
+    final_embedding_shape = final_embedding.shape
+    scaled_image_features_flattened = mx.flatten(scaled_image_features)
+    final_embedding_flattened = mx.flatten(final_embedding)
+    image_mask_expanded_flattened = mx.flatten(image_mask_expanded)
+
+    # Scatter the scaled image features into the special image token positions
+    image_positions = mx.array(np.where(image_mask_expanded_flattened)[0], mx.uint32)
+    final_embedding_flattened[image_positions] = scaled_image_features_flattened
+
+    # Reshape back to the original shape
+    final_embedding = mx.reshape(final_embedding_flattened, final_embedding_shape)
+
+    return final_embedding
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        self.config = config
+        self.vision_tower = VisionModel(config.vision_config)
+
+        if config.vision_feature_layer != -1:
+            self.vision_tower.encoder.layers = self.vision_tower.encoder.layers[
+                : config.vision_feature_layer + 1
+            ]
+        if config.downsample_factor > 1:
+            self.pixel_unshuffle = PixelUnshuffleBlock(config.downsample_factor)
+        else:
+            self.pixel_unshuffle = nn.Identity()
+
+        self.multi_modal_projector = Lfm2VlMultiModalProjector(config)
+        self.language_model = LanguageModel(config.text_config)
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        spatial_shapes = kwargs.get("spatial_shapes", None)
+        pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(inputs_embeds=inputs_embeds)
+
+        # Get the ouptut hidden states from the vision model
+        *_, hidden_states = self.vision_tower(
+            pixel_values, output_hidden_states=True, spatial_shapes=spatial_shapes
+        )
+
+        img_feature_lengths = pixel_attention_mask.sum(axis=1).tolist()
+        image_features = []
+
+        for img_idx in range(hidden_states.shape[0]):
+            feature = hidden_states[img_idx]
+
+            feature = feature[: img_feature_lengths[img_idx], :][None, ...]
+
+            feature_org_h, feature_org_w = spatial_shapes[img_idx]
+            feature = feature.reshape(1, feature_org_h, feature_org_w, -1)
+            feature = self.pixel_unshuffle(feature)
+
+            img_embedding = self.multi_modal_projector(feature)
+
+            img_embedding = img_embedding.reshape(-1, img_embedding.shape[-1])
+            image_features.append(img_embedding)
+
+        image_features = mx.concatenate(image_features, axis=0)
+
+        final_inputs_embeds = self.merge_input_ids_with_image_features(
+            image_features, inputs_embeds, input_ids, self.config.image_token_index
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    @staticmethod
+    def merge_input_ids_with_image_features(
+        image_features, inputs_embeds, input_ids, image_token_index
+    ):
+        special_image_mask = input_ids == image_token_index
+        n_image_tokens = special_image_mask.sum()
+        special_image_mask = special_image_mask[..., None]
+        special_image_mask = mx.broadcast_to(special_image_mask, inputs_embeds.shape)
+
+        n_image_features = image_features.shape[0]
+        n_image_mask_elements = special_image_mask.sum()
+        if n_image_mask_elements != image_features.size:
+            raise ValueError(
+                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+            )
+
+        inputs_embeds = masked_scatter(
+            inputs_embeds, special_image_mask, image_features
+        )
+
+        return inputs_embeds
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
+    ):
+        spatial_shapes = kwargs.get("spatial_shapes", None)
+        pixel_attention_mask = kwargs.get("pixel_attention_mask", None)
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, spatial_shapes, pixel_attention_mask
+        )
+
+        logits = self.language_model(
+            input_ids, mask=None, cache=cache, inputs_embeds=input_embeddings_features
+        )
+        return logits
+
+    def sanitize(self, weights):
+        def transform_key(key):
+            if "vision_tower" in key:
+                key = (
+                    key.replace("model.", "")
+                    .replace("vision_encoder", "encoder")
+                    .replace("vision_embeddings", "embeddings")
+                    .replace("vision_post_layernorm", "post_layernorm")
+                )
+
+            if "language_model" in key:
+                key = key.replace("model.language_model", "language_model.model")
+
+            if "multi_modal_projector" in key:
+                key = key.replace(
+                    "model.multi_modal_projector", "multi_modal_projector"
+                )
+
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}
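In the file above, `masked_scatter` splices the projected image features into the token-embedding sequence at the positions of the image placeholder token. The following is a minimal NumPy sketch of that idea, not code from the package; the token id, shapes, and values are invented purely for illustration.

```python
import numpy as np

IMAGE_TOKEN_ID = 99  # hypothetical placeholder id, for illustration only

input_ids = np.array([[1, 99, 99, 2]])                     # one sequence with two image slots
inputs_embeds = np.zeros((1, 4, 3))                        # (batch, seq_len, hidden_size)
image_features = np.arange(6, dtype=float).reshape(2, 3)   # two image patches, hidden_size=3

# Mark image-token positions, then broadcast the mask across the hidden dimension,
# mirroring merge_input_ids_with_image_features above.
mask = (input_ids == IMAGE_TOKEN_ID)[..., None]
mask = np.broadcast_to(mask, inputs_embeds.shape)

# Flatten, overwrite the masked positions in order, and reshape back.
flat = inputs_embeds.reshape(-1).copy()
flat[np.where(mask.reshape(-1))[0]] = image_features.reshape(-1)
merged = flat.reshape(inputs_embeds.shape)

print(merged[0, 1])  # [0. 1. 2.] -> first image feature row
print(merged[0, 2])  # [3. 4. 5.] -> second image feature row
```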
mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py
@@ -0,0 +1,320 @@
+"""
+Compatibility patch for Lfm2VlProcessor.
+
+The Lfm2VlProcessorKwargs has a default `return_row_col_info: True` in images_kwargs,
+but this parameter is only supported by the FAST image processor (Lfm2VlImageProcessorFast).
+When using the slow image processor (Siglip2ImageProcessor), this causes a validation error.
+
+This patch:
+1. Removes the unsupported `return_row_col_info` parameter from the defaults
+2. Enables `do_resize: True` to ensure images are properly resized for patch processing
+3. Patches the `__call__` method to handle the slow image processor case, computing
+   `image_rows`, `image_cols`, `image_sizes` when missing and providing sensible
+   defaults for tile-related parameters
+4. Patches the `__init__` to add missing attributes to the slow image processor
+5. Forces the use of the slow image processor to avoid PyTorch tensor requirements
+"""
+
+import math
+
+import numpy as np
+from transformers.models.lfm2_vl.processing_lfm2_vl import (
+    Lfm2VlProcessor,
+    Lfm2VlProcessorKwargs,
+)
+
+# Try to import the slow image processor to force its use
+try:
+    from transformers.models.siglip2.image_processing_siglip2 import (
+        Siglip2ImageProcessor,
+    )
+
+    _SLOW_PROCESSOR_AVAILABLE = True
+except ImportError:
+    _SLOW_PROCESSOR_AVAILABLE = False
+
+# Remove return_row_col_info from the defaults since the slow image processor
+# (Siglip2ImageProcessor) doesn't support it - only the fast version does.
+# Also enable do_resize to ensure images are properly resized to be divisible by patch_size.
+if hasattr(Lfm2VlProcessorKwargs, "_defaults"):
+    if "images_kwargs" in Lfm2VlProcessorKwargs._defaults:
+        Lfm2VlProcessorKwargs._defaults["images_kwargs"].pop(
+            "return_row_col_info", None
+        )
+        # Enable resizing for the slow image processor (model config has do_resize: False
+        # which is intended for the fast processor that handles resizing differently)
+        Lfm2VlProcessorKwargs._defaults["images_kwargs"]["do_resize"] = True
+
+
+# Store the original __init__ method
+_original_init = Lfm2VlProcessor.__init__
+
+
+def _patched_init(self, image_processor, tokenizer, chat_template=None, **kwargs):
+    """Patched __init__ that adds missing attributes to the slow image processor."""
+    # Check if we got the fast image processor and need to replace it with the slow one
+    # The fast processor requires PyTorch tensors which we don't have
+    processor_class_name = type(image_processor).__name__
+    if "Fast" in processor_class_name and _SLOW_PROCESSOR_AVAILABLE:
+        # Replace with slow processor using the same config
+        if hasattr(image_processor, "to_dict"):
+            # Use the config dict to create the slow processor
+            slow_processor = Siglip2ImageProcessor(**image_processor.to_dict())
+        else:
+            # Fallback to copying attributes
+            slow_processor = Siglip2ImageProcessor(
+                **{
+                    k: v
+                    for k, v in image_processor.__dict__.items()
+                    if not k.startswith("_") and k not in ["name_or_path"]
+                }
+            )
+        image_processor = slow_processor
+
+    # Call original __init__
+    _original_init(
+        self, image_processor, tokenizer, chat_template=chat_template, **kwargs
+    )
+
+    # Add missing attributes for the slow image processor (Siglip2ImageProcessor)
+    # These are needed by expand_text_with_placeholders and _get_image_num_tokens
+    if not hasattr(self.image_processor, "tile_size"):
+        self.image_processor.tile_size = 512
+    if not hasattr(self.image_processor, "max_image_tokens"):
+        self.image_processor.max_image_tokens = 256
+    if not hasattr(self.image_processor, "min_image_tokens"):
+        self.image_processor.min_image_tokens = 64
+    if not hasattr(self.image_processor, "downsample_factor"):
+        self.image_processor.downsample_factor = 2
+    if not hasattr(self.image_processor, "encoder_patch_size"):
+        self.image_processor.encoder_patch_size = 16
+    if not hasattr(self.image_processor, "do_image_splitting"):
+        self.image_processor.do_image_splitting = (
+            False  # Disable tiling for slow processor
+        )
+    if not hasattr(self.image_processor, "use_thumbnail"):
+        self.image_processor.use_thumbnail = False
+
+
+# Apply the __init__ patch
+Lfm2VlProcessor.__init__ = _patched_init
+
+
+def _compute_image_grid_info(pixel_values, patch_size: int = 16):
+    """
+    Compute image_rows, image_cols, and image_sizes from pixel_values.
+
+    When using the slow image processor (Siglip2ImageProcessor), these values
+    are not returned. This function computes them from the pixel_values tensor.
+
+    Args:
+        pixel_values: Array of shape (batch, num_patches, patch_dim)
+        patch_size: The patch size used for image processing
+
+    Returns:
+        image_rows: List of rows per image
+        image_cols: List of cols per image
+        image_sizes: List of total patches per image
+    """
+    # pixel_values shape: (batch, num_patches, patch_dim)
+    # For Siglip2, each image is processed independently and has its own num_patches
+    if hasattr(pixel_values, "shape"):
+        batch_size = pixel_values.shape[0]
+        num_patches = pixel_values.shape[1]
+
+        # Estimate rows/cols from num_patches (assuming roughly square)
+        # The actual image was resized to fit max_num_patches while maintaining aspect ratio
+        side_length = int(math.sqrt(num_patches))
+
+        # Return as nested lists (one list per batch, one value per image in batch)
+        image_rows = [[side_length] for _ in range(batch_size)]
+        image_cols = [[side_length] for _ in range(batch_size)]
+        image_sizes = [[num_patches] for _ in range(batch_size)]
+
+        return image_rows, image_cols, image_sizes
+
+    return [[1]], [[1]], [[1]]
+
+
+# Store the original __call__ method
+_original_call = Lfm2VlProcessor.__call__
+
+
+def _ensure_slow_processor(processor_instance):
+    """
+    Ensure we're using the slow image processor, not the fast one.
+    The fast processor only supports PyTorch tensors which we can't use without PyTorch.
+    """
+    image_processor = processor_instance.image_processor
+    processor_class_name = type(image_processor).__name__
+
+    if "Fast" in processor_class_name and _SLOW_PROCESSOR_AVAILABLE:
+        # Need to replace with slow processor
+        # Get the config from the fast processor
+        config = (
+            image_processor.to_dict() if hasattr(image_processor, "to_dict") else {}
+        )
+        # Remove keys that might cause issues
+        config.pop("image_processor_type", None)
+        config.pop("auto_map", None)
+        config.pop("_processor_class", None)
+
+        # Create slow processor with the same config
+        slow_processor = Siglip2ImageProcessor(**config)
+        processor_instance.image_processor = slow_processor
+
+        # Re-add missing attributes
+        if not hasattr(processor_instance.image_processor, "tile_size"):
+            processor_instance.image_processor.tile_size = 512
+        if not hasattr(processor_instance.image_processor, "downsample_factor"):
+            processor_instance.image_processor.downsample_factor = 2
+        if not hasattr(processor_instance.image_processor, "do_image_splitting"):
+            processor_instance.image_processor.do_image_splitting = False
+        if not hasattr(processor_instance.image_processor, "use_thumbnail"):
+            processor_instance.image_processor.use_thumbnail = False
+
+    return processor_instance.image_processor
+
+
+def _patched_call(self, images=None, text=None, **kwargs):
+    """
+    Patched __call__ that handles the slow image processor case.
+
+    The slow Siglip2ImageProcessor doesn't return image_rows, image_cols, image_sizes
+    which are required by expand_text_with_placeholders. This patch intercepts the call
+    and computes these values when they're missing.
+    """
+    from transformers.feature_extraction_utils import BatchFeature
+    from transformers.image_utils import make_nested_list_of_images
+
+    # Ensure we're using the slow processor (fast requires PyTorch tensors)
+    if images is not None:
+        _ensure_slow_processor(self)
+
+    if images is None and text is not None:
+        # Text-only case
+        output_kwargs = self._merge_kwargs(
+            Lfm2VlProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        output_kwargs["text_kwargs"].pop("use_image_special_tokens", None)
+        return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+        text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+        return BatchFeature(text_inputs, tensor_type=return_tensors)
+
+    if text is None and images is None:
+        raise ValueError("You must provide one of `text` or `images`.")
+
+    if images is not None and text is None:
+        raise ValueError(
+            "You must provide `text` when `images` is provided. Minimal text consists of a single image token."
+        )
+
+    # Merge kwargs to get the final settings
+    output_kwargs = self._merge_kwargs(
+        Lfm2VlProcessorKwargs,
+        tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+        **kwargs,
+    )
+
+    if isinstance(text, str):
+        text = [text]
+    elif text is not None and not isinstance(text, list):
+        raise TypeError(
+            "Invalid input text. Please provide a string, or a list of strings"
+        )
+
+    n_images_in_text = [sample.count(self.image_token) for sample in text]
+
+    inputs = {}
+    use_image_special_tokens = output_kwargs["text_kwargs"].pop(
+        "use_image_special_tokens", True
+    )
+
+    # Process images
+    images = self.image_processor.fetch_images(images)
+    batched_images = make_nested_list_of_images(images)
+
+    # Override return_tensors for image processing to avoid PyTorch dependency
+    images_kwargs = output_kwargs["images_kwargs"].copy()
+    images_kwargs["return_tensors"] = "np"  # Use numpy instead of pt
+
+    vision_inputs = self.image_processor(batched_images, **images_kwargs)
+
+    n_images_in_images = [len(sublist) for sublist in batched_images]
+    if n_images_in_images != n_images_in_text:
+        raise ValueError(
+            f"The number of images in the text {n_images_in_text} and images {n_images_in_images} should be the same."
+        )
+
+    # Check if image_rows/cols/sizes are present (fast processor case)
+    if "image_rows" in vision_inputs:
+        image_rows = vision_inputs.pop("image_rows")
+        image_cols = vision_inputs.pop("image_cols")
+        image_sizes = vision_inputs.pop("image_sizes")
+    else:
+        # Slow processor case - compute from spatial_shapes or pixel_attention_mask
+        # The spatial_shapes gives the actual (height, width) in patches for each image
+        spatial_shapes = vision_inputs.get("spatial_shapes")
+        if spatial_shapes is not None:
+            # spatial_shapes is array of shape (batch, 2) with [height, width] in patches
+            image_rows = [[int(ss[0])] for ss in spatial_shapes]
+            image_cols = [[int(ss[1])] for ss in spatial_shapes]
+            image_sizes = [[int(ss[0] * ss[1])] for ss in spatial_shapes]
+        else:
+            # Fallback to computing from pixel_values
+            pixel_values = vision_inputs.get("pixel_values")
+            patch_size = getattr(self.image_processor, "patch_size", 16)
+            image_rows, image_cols, image_sizes = _compute_image_grid_info(
+                pixel_values, patch_size
+            )
+
+    # For slow processor, use simplified text expansion
+    # (no tiling support, just add image tokens)
+    # Account for downsample_factor: the vision tower reduces patches by factor^2
+    downsample_factor = getattr(self.image_processor, "downsample_factor", 2)
+
+    expanded_text = []
+    for sample_text, sample_images, rows, cols, sizes in zip(
+        text, batched_images, image_rows, image_cols, image_sizes
+    ):
+        split_sample = sample_text.split(self.image_token)
+        result = ""
+        for i, _ in enumerate(sample_images):
+            result += split_sample[i]
+            if use_image_special_tokens:
+                result += self.image_start_token
+            # Add image tokens based on the number of patches AFTER downsampling
+            # The vision tower downsamples by factor^2, so divide by that
+            num_patches = sizes[i] if i < len(sizes) else sizes[0]
+            num_image_tokens = num_patches // (downsample_factor**2)
+            result += self.image_token * num_image_tokens
+            if use_image_special_tokens:
+                result += self.image_end_token
+        # Add any remaining text after the last image
+        if len(split_sample) > len(sample_images):
+            result += split_sample[-1]
+        expanded_text.append(result)
+
+    inputs.update(vision_inputs)
+
+    return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
+
+    text_inputs = self.tokenizer(expanded_text, **output_kwargs["text_kwargs"])
+    inputs.update(text_inputs)
+
+    # Convert lists to numpy arrays for proper handling by mlx_vlm
+    # The tokenizer returns lists but mlx_vlm expects numpy arrays
+    if isinstance(inputs.get("input_ids"), list):
+        inputs["input_ids"] = np.array(inputs["input_ids"])
+    if isinstance(inputs.get("attention_mask"), list):
+        inputs["attention_mask"] = np.array(inputs["attention_mask"])
+
+    return BatchFeature(
+        inputs, tensor_type=None
+    )  # Don't convert, let mlx_vlm handle it
+
+
+# Apply the patch
+Lfm2VlProcessor.__call__ = _patched_call