fount-vlm-nell-02 0.3.11 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/mistral3/mistral3.py

@@ -0,0 +1,383 @@
+from typing import List, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import InputEmbeddingsFeatures
+from ..pixtral import VisionModel
+from .config import ModelConfig
+from .language import LanguageModel
+
+
+def _pair(x) -> Tuple[int, int]:
+    """Convert input to a pair of values."""
+    if isinstance(x, (list, tuple)):
+        return tuple(x)
+    return (x, x)
+
+
+def unfold(
+    input: mx.array,
+    kernel_size: Union[int, Tuple[int, int], List[int]],
+    dilation: Union[int, Tuple[int, int], List[int]] = 1,
+    padding: Union[int, Tuple[int, int], List[int]] = 0,
+    stride: Union[int, Tuple[int, int], List[int]] = 1,
+) -> mx.array:
+    """
+    Extract sliding local blocks from a batched input tensor (MLX implementation).
+
+    This is equivalent to PyTorch's nn.functional.unfold or im2col operation.
+
+    Args:
+        input: Input tensor of shape (B, C, H, W)
+        kernel_size: Size of the sliding blocks
+        dilation: Controls the spacing between kernel elements
+        padding: Controls the amount of implicit padding
+        stride: Controls the stride between blocks
+
+    Returns:
+        Unfolded tensor of shape (B, C*kernel_height*kernel_width, L)
+        where L is the number of blocks
+    """
+    # Convert to pairs
+    kernel_size = _pair(kernel_size)
+    dilation = _pair(dilation)
+    padding = _pair(padding)
+    stride = _pair(stride)
+
+    # Input shape
+    batch_size, channels, height, width = input.shape
+
+    # Add padding if needed
+    if padding[0] > 0 or padding[1] > 0:
+        padding_shape = (
+            (0, 0),
+            (0, 0),
+            (padding[0], padding[0]),
+            (padding[1], padding[1]),
+        )
+        input = mx.pad(input, padding_shape)
+
+    # Calculate output dimensions
+    height_out = (
+        height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+    ) // stride[0] + 1
+    width_out = (
+        width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+    ) // stride[1] + 1
+
+    # Initialize output arrays
+    blocks = []
+
+    # Extract blocks
+    for i in range(
+        0, height + 2 * padding[0] - kernel_size[0] * dilation[0] + 1, stride[0]
+    ):
+        for j in range(
+            0, width + 2 * padding[1] - kernel_size[1] * dilation[1] + 1, stride[1]
+        ):
+            # Extract the block for all channels
+            block = []
+            for di in range(kernel_size[0]):
+                for dj in range(kernel_size[1]):
+                    h_idx = i + di * dilation[0]
+                    w_idx = j + dj * dilation[1]
+                    # Get the block for all channels and add to our list
+                    block.append(input[:, :, h_idx, w_idx])
+
+            # Stack the channel-blocks
+            block = mx.stack(block, axis=1)  # Shape: (B, k*k, C)
+            block = mx.transpose(block, [0, 2, 1])  # Shape: (B, C, k*k)
+            blocks.append(block)
+
+    # Stack all blocks together
+    result = mx.stack(blocks, axis=-1)  # Shape: (B, C, k*k, L)
+
+    # Reshape to match PyTorch's unfold output format: (B, C*k*k, L)
+    result = mx.reshape(
+        result,
+        (
+            batch_size,
+            channels * kernel_size[0] * kernel_size[1],
+            height_out * width_out,
+        ),
+    )
+
+    return result
+
+
+class Mistral3PatchMerger(nn.Module):
+    """
+    Learned merging of spatial_merge_size ** 2 patches
+    """
+
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+        hidden_size = config.vision_config.hidden_size
+        self.spatial_merge_size = config.spatial_merge_size
+        self.patch_size = self.config.vision_config.patch_size
+        self.merging_layer = nn.Linear(
+            hidden_size * self.spatial_merge_size**2, hidden_size, bias=False
+        )
+
+    def __call__(self, image_features: mx.array, image_sizes: mx.array) -> mx.array:
+
+        image_sizes = [
+            (image_size[0] // self.patch_size, image_size[1] // self.patch_size)
+            for image_size in image_sizes
+        ]
+
+        tokens_per_image = [h * w for h, w in image_sizes]
+        d = image_features.shape[-1]
+        image_features = image_features.astype(mx.bfloat16)
+        image_sizes = mx.array(image_sizes)
+
+        # Split the image features into chunks based on tokens_per_image
+        split_indices = []
+        current_index = 0
+        for tokens in tokens_per_image:
+            split_indices.append(current_index + tokens)
+            current_index += tokens
+
+        # Perform the split
+        chunks = mx.split(image_features, split_indices[:-1], axis=1)
+
+        permuted_tensor = []
+        for image_index, image_tokens in enumerate(chunks):
+
+            # Reshape image_tokens into a 2D grid
+            if image_tokens.shape[1] > 0:
+                h, w = image_sizes[image_index].tolist()
+
+                image_grid = image_tokens.reshape(h, w, d).transpose(2, 0, 1)[None, ...]
+
+                grid = unfold(
+                    image_grid,
+                    kernel_size=self.spatial_merge_size,
+                    stride=self.spatial_merge_size,
+                )
+                grid = grid.reshape(d * self.spatial_merge_size**2, -1).T
+                permuted_tensor.append(grid)
+
+        image_features = mx.concatenate(permuted_tensor, axis=0)
+        image_features = self.merging_layer(image_features)
+        return image_features[None, ...]
+
+
+class Mistral3MultiModalProjector(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+
+        self.norm = nn.RMSNorm(config.vision_config.hidden_size)
+        self.patch_merger = Mistral3PatchMerger(config)
+
+        num_feature_layers = (
+            1
+            if isinstance(config.vision_feature_layer, int)
+            else len(config.vision_feature_layer)
+        )
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size * num_feature_layers,
+            config.text_config.hidden_size,
+            bias=config.multimodal_projector_bias,
+        )
+        self.gelu = nn.GELU()
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size,
+            config.text_config.hidden_size,
+            bias=config.multimodal_projector_bias,
+        )
+
+    def __call__(self, x: mx.array, image_sizes: mx.array) -> mx.array:
+        x = self.norm(x)
+
+        x = self.patch_merger(x, image_sizes)
+        x = self.linear_1(x)
+        x = self.gelu(x)
+        x = self.linear_2(x)
+        return x
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+        self.multi_modal_projector = Mistral3MultiModalProjector(config)
+        self.vision_tower = VisionModel(config.vision_config)
+        self.language_model = LanguageModel(config.text_config)
+        self.vision_feature_layer = config.vision_feature_layer
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        image_sizes = kwargs.get("image_sizes", None)
+
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        # Get the input embeddings from the language model
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        # Get the output hidden states from the vision model
+        if isinstance(pixel_values, list):
+            pixel_values = mx.concatenate(
+                [mx.array(pv)[None, ...] for pv in pixel_values], axis=0
+            )
+        if pixel_values.ndim == 3:
+            pixel_values = pixel_values[None, ...]
+
+        # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding
+        # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21
+        # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85
+        *_, hidden_states = self.vision_tower(
+            pixel_values.transpose(0, 2, 3, 1),
+            output_hidden_states=True,
+        )
+        # Select the hidden states from the desired layer
+        selected_image_feature = hidden_states[self.vision_feature_layer]
+
+        # Pass image features through the multi-modal projector
+        image_features = self.multi_modal_projector(selected_image_feature, image_sizes)
+
+        # Insert special image tokens in the input_ids
+        final_inputs_embeds = self.merge_input_ids_with_image_features(
+            self.config.image_token_index, image_features, inputs_embeds, input_ids
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    @staticmethod
+    def merge_input_ids_with_image_features(
+        image_token_index, image_features, inputs_embeds, input_ids
+    ):
+        """Merge image features into input embeddings at image token positions.
+
+        Args:
+            image_token_index: Token ID for image placeholder
+            image_features: Vision features from the projector [1, num_features, hidden_dim]
+            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
+            input_ids: Input token IDs [batch_size, seq_len]
+
+        Returns:
+            Updated input embeddings with image features inserted
+        """
+        # Remove the extra batch dimension from image_features if present
+        if image_features.ndim == 3 and image_features.shape[0] == 1:
+            image_features = image_features.squeeze(0)  # [num_features, hidden_dim]
+
+        # Positions of <image> tokens in input_ids
+        image_positions = input_ids == image_token_index
+
+        # Get dimensions
+        batch_size, seq_len = input_ids.shape
+
+        # Process each batch item
+        batch_outputs = []
+        feature_start_idx = 0
+
+        for batch_idx in range(batch_size):
+            # Get mask for this batch
+            image_mask = image_positions[batch_idx]
+            num_positions = mx.sum(image_mask).item()
+
+            if num_positions > 0:
+                # Extract features for this batch
+                batch_features = image_features[
+                    feature_start_idx : feature_start_idx + num_positions
+                ]
+
+                # Validate we have the right number of features
+                if batch_features.shape[0] != num_positions:
+                    raise ValueError(
+                        f"Number of image token positions ({num_positions}) does not match "
+                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
+                    )
+
+                # Create indices for gathering
+                cumsum = mx.cumsum(image_mask.astype(mx.int32))
+                feature_indices = mx.where(image_mask, cumsum - 1, 0)
+
+                # Gather features
+                gathered_features = batch_features[feature_indices]
+
+                # Combine with original embeddings
+                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
+                batch_output = mx.where(
+                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
+                )
+
+                feature_start_idx += num_positions
+            else:
+                # No image tokens in this batch item
+                batch_output = inputs_embeds[batch_idx]
+
+            batch_outputs.append(batch_output)
+
+        # Stack all batch outputs
+        return mx.stack(batch_outputs, axis=0)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        mask: mx.array,
+        cache=None,
+        **kwargs,
+    ):
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, **kwargs
+        )
+        logits = self.language_model(
+            input_ids,
+            cache=cache,
+            inputs_embeds=input_embeddings_features.inputs_embeds,
+        )
+        return logits

+    def sanitize(self, weights):
+        def transform_key(key):
+            if "vision_tower" in key and "vision_model" not in key:
+                if "transformer" in key:
+                    key = key.replace("vision_tower", "vision_tower.vision_model")
+                if "patch_conv" in key:
+                    key = key.replace("vision_tower", "vision_tower.vision_model")
+                if "ln_pre" in key:
+                    key = key.replace("vision_tower", "vision_tower.vision_model")
+
+            elif "vision_encoder" in key and "vision_tower" not in key:
+                if "transformer" in key:
+                    key = key.replace(
+                        "model.vision_encoder", "vision_tower.vision_model"
+                    )
+                if "patch_conv" in key:
+                    key = key.replace(
+                        "model.vision_encoder", "vision_tower.vision_model"
+                    )
+                if "ln_pre" in key:
+                    key = key.replace(
+                        "model.vision_encoder", "vision_tower.vision_model"
+                    )
+
+            elif "model.language_model" in key and "language_model.model" not in key:
+                key = key.replace("model.language_model", "language_model.model")
+
+            elif "lm_head" in key and "language_model" not in key:
+                key = key.replace("lm_head", "language_model.lm_head")
+
+            elif "model.vision_projection" in key:
+                key = key.replace("model.vision_projection", "multi_modal_projector")
+
+            return key
+
+        return {transform_key(k): v for k, v in weights.items()}
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
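For orientation: the `unfold` helper above mirrors PyTorch's `torch.nn.functional.unfold` (im2col), and `Mistral3PatchMerger` uses it to group non-overlapping spatial_merge_size x spatial_merge_size patch blocks before the merging linear layer. The following is a minimal sketch of that call pattern; the import path and all sizes are illustrative assumptions, not part of this diff.

import mlx.core as mx

from mlx_vlm.models.mistral3.mistral3 import unfold  # assumed module path

# Hypothetical sizes: d feature channels on a 4x4 patch grid, spatial_merge_size = 2.
d, h, w, merge = 8, 4, 4, 2
image_grid = mx.random.normal((1, d, h, w))

# Non-overlapping merge x merge blocks, as Mistral3PatchMerger calls unfold.
grid = unfold(image_grid, kernel_size=merge, stride=merge)
print(grid.shape)  # (1, d * merge**2, (h // merge) * (w // merge)) == (1, 32, 4)

# Each column becomes one merged token of size d * merge**2, ready for merging_layer.
merged_tokens = grid.reshape(d * merge**2, -1).T
print(merged_tokens.shape)  # (4, 32)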
mlx_vlm/models/mllama/config.py

@@ -0,0 +1,74 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple, Union
+
+from ..base import BaseModelConfig
+
+
+@dataclass
+class TextConfig(BaseModelConfig):
+    model_type: str = "mllama"
+    vocab_size: int = 32000
+    hidden_size: int = 4096
+    intermediate_size: int = 14336
+    num_hidden_layers: int = 40
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    hidden_act: str = "silu"
+    max_position_embeddings: int = 131072
+    initializer_range: float = 0.02
+    rms_norm_eps: float = 1e-6
+    tie_word_embeddings: bool = False
+    rope_theta: float = 10000.0
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    cross_attention_layers: List[int] = field(
+        default_factory=lambda: [3, 8, 13, 18, 23, 28, 33, 38]
+    )
+
+    def __post_init__(self):
+        if self.num_key_value_heads is None:
+            self.num_key_value_heads = self.num_attention_heads
+
+
+@dataclass
+class VisionConfig(BaseModelConfig):
+    image_size: int = 560
+    patch_size: int = 14
+    num_channels: int = 3
+    hidden_size: int = 1280
+    intermediate_size: int = 5120
+    num_hidden_layers: int = 32
+    num_attention_heads: int = 16
+    max_num_tiles: int = 4
+    max_aspect_ratio_id: int = 8
+    num_global_layers: int = 8
+    norm_eps: float = 1e-5
+    attention_dropout: float = 0.0
+    hidden_dropout: float = 0.0
+    vision_output_dim: int = 7680
+    intermediate_layers_indices: List[int] = field(
+        default_factory=lambda: [3, 7, 15, 23, 30]
+    )
+    supported_aspect_ratios: Tuple[List[int]] = (
+        [1, 1],
+        [1, 2],
+        [1, 3],
+        [1, 4],
+        [2, 1],
+        [2, 2],
+        [3, 1],
+        [4, 1],
+    )
+
+
+@dataclass
+class ModelConfig(BaseModelConfig):
+    text_config: TextConfig
+    vision_config: VisionConfig
+    model_type: str
+    ignore_index: int = -100
+    image_token_index: int = 128256
+    vision_feature_select_strategy: str = "default"
+    vision_feature_layer: int = -2
+    vocab_size: int = 32000
+    eos_token_id: Optional[List[int]] = None
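These dataclasses carry plain defaults, so they can be exercised directly. A minimal sketch follows; the import path and the overridden values are illustrative assumptions, and it presumes BaseModelConfig adds no required fields of its own.

from mlx_vlm.models.mllama.config import (  # assumed module path
    ModelConfig,
    TextConfig,
    VisionConfig,
)

# Build a config by hand, overriding a couple of defaults for illustration only.
text_config = TextConfig(rope_theta=500000.0)
vision_config = VisionConfig(max_num_tiles=4)
config = ModelConfig(
    text_config=text_config,
    vision_config=vision_config,
    model_type="mllama",
)

assert config.image_token_index == 128256
assert config.vision_feature_layer == -2
# __post_init__ only backfills num_key_value_heads when it is None.
assert text_config.num_key_value_heads == 8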