fount-vlm-nell-02 0.3.11 (fount_vlm_nell_02-0.3.11-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/glm4v_moe/processing.py
@@ -0,0 +1,229 @@
+"""Processor for GLM-4V-MoE model.
+
+Handles image/video token expansion based on grid dimensions and merge size.
+Based on the HuggingFace transformers GLM-4.6V processor implementation.
+"""
+
+from typing import List, Union
+
+import numpy as np
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.processing_utils import ProcessorMixin
+
+
+class Glm46VMoEProcessor(ProcessorMixin):
+    """
+    Processor for GLM-4V-MoE that wraps an image processor and tokenizer.
+
+    Handles:
+    - Image preprocessing via image_processor
+    - Token replacement for image/video placeholders based on grid dimensions
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+    valid_kwargs = ["chat_template"]
+    image_processor_class = "AutoImageProcessor"
+    tokenizer_class = "AutoTokenizer"
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        chat_template=None,
+        **kwargs,
+    ):
+        self.tokenizer = tokenizer
+        self.image_processor = image_processor
+
+        # Get image/video tokens from tokenizer or use defaults
+        self.image_token = "<|image|>"
+        self.video_token = "<|video|>"
+
+        if tokenizer is not None:
+            self.image_token = getattr(tokenizer, "image_token", "<|image|>")
+            self.video_token = getattr(tokenizer, "video_token", "<|video|>")
+
+            # Get token IDs
+            self.image_token_id = getattr(tokenizer, "image_token_id", None)
+            if self.image_token_id is None:
+                self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
+
+            self.video_token_id = getattr(tokenizer, "video_token_id", None)
+            if self.video_token_id is None:
+                self.video_token_id = tokenizer.convert_tokens_to_ids(self.video_token)
+        else:
+            self.image_token_id = None
+            self.video_token_id = None
+
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+    def __call__(
+        self,
+        images=None,
+        text: Union[str, List[str]] = None,
+        videos=None,
+        **kwargs,
+    ) -> BatchFeature:
+        """
+        Process images/videos and text for the model.
+
+        Args:
+            images: Single image or list of images (PIL.Image, np.ndarray, etc.)
+            text: Single text or list of texts
+            videos: Video inputs (optional)
+            **kwargs: Additional arguments passed to image_processor and tokenizer
+
+        Returns:
+            BatchFeature with:
+            - input_ids: Token IDs with image/video placeholders expanded
+            - attention_mask: Attention mask
+            - pixel_values: Processed image/video patches
+            - image_grid_thw: Grid dimensions for each image
+            - video_grid_thw: Grid dimensions for each video (if videos provided)
+        """
+        image_inputs = {}
+        video_inputs = {}
+        image_grid_thw = None
+        video_grid_thw = None
+
+        # Pop tokenizer-specific kwargs that shouldn't go to image processor
+        padding = kwargs.pop("padding", False)
+        return_token_type_ids = kwargs.pop("return_token_type_ids", False)
+        return_tensors = kwargs.pop("return_tensors", None)
+
+        # Process images
+        if images is not None and self.image_processor is not None:
+            image_inputs = self.image_processor(images=images)
+            image_grid_thw = image_inputs.get("image_grid_thw")
+
+        # Process videos
+        if videos is not None:
+            if hasattr(self, "video_processor") and self.video_processor is not None:
+                video_inputs = self.video_processor(videos=videos, **kwargs)
+                video_grid_thw = video_inputs.get("video_grid_thw")
+
+        # Handle text input
+        if text is None:
+            text = [""]
+        elif not isinstance(text, list):
+            text = [text]
+
+        # Make a copy to avoid modifying original
+        text = [t for t in text]
+
+        # Get merge_size from image_processor
+        merge_size = getattr(self.image_processor, "merge_size", 2)
+        if hasattr(self.image_processor, "spatial_merge_size"):
+            merge_size = self.image_processor.spatial_merge_size
+        merge_length = merge_size**2
+
+        # Expand image tokens based on grid dimensions
+        if image_grid_thw is not None:
+            index = 0
+            for i in range(len(text)):
+                while self.image_token in text[i]:
+                    # Calculate number of image tokens: prod(grid_thw) / merge_size^2
+                    grid = image_grid_thw[index]
+                    if hasattr(grid, "tolist"):
+                        grid = grid.tolist()
+                    num_image_tokens = int(np.prod(grid) // merge_length)
+
+                    # Replace single image token with correct number of placeholder tokens
+                    text[i] = text[i].replace(
+                        self.image_token,
+                        "<|placeholder|>" * num_image_tokens,
+                        1,
+                    )
+                    index += 1
+                # Replace placeholders back to image tokens
+                text[i] = text[i].replace("<|placeholder|>", self.image_token)
+
+        # Expand video tokens based on grid dimensions
+        if video_grid_thw is not None:
+            video_index = 0
+            for i in range(len(text)):
+                while self.video_token in text[i]:
+                    grid = video_grid_thw[video_index]
+                    if hasattr(grid, "tolist"):
+                        grid = grid.tolist()
+
+                    num_frames = grid[0]
+                    # Calculate tokens per frame
+                    num_tokens_per_frame = int(
+                        np.prod(grid) // merge_length // num_frames
+                    )
+
+                    # Build video structure with frame tokens
+                    video_structure = ""
+                    for frame_idx in range(num_frames):
+                        # Add image tokens for this frame
+                        frame_structure = self.image_token * num_tokens_per_frame
+                        video_structure += frame_structure
+
+                    text[i] = text[i].replace(self.video_token, video_structure, 1)
+                    video_index += 1
+
+        # Tokenize text
+        text_inputs = self.tokenizer(
+            text,
+            padding=padding,
+            return_token_type_ids=return_token_type_ids,
+            **kwargs,
+        )
+
+        return BatchFeature(
+            data={**text_inputs, **image_inputs, **video_inputs},
+            tensor_type=return_tensors,
+        )
+
+    def batch_decode(self, *args, **kwargs):
+        """Decode token IDs to text."""
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """Decode token IDs to text."""
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def apply_chat_template(self, *args, **kwargs):
+        """Apply chat template using the tokenizer."""
+        return self.tokenizer.apply_chat_template(*args, **kwargs)
+
+    @property
+    def model_input_names(self):
+        """Return combined input names from tokenizer and image processor."""
+        tokenizer_input_names = (
+            self.tokenizer.model_input_names if self.tokenizer else []
+        )
+        image_processor_input_names = (
+            self.image_processor.model_input_names
+            if hasattr(self.image_processor, "model_input_names")
+            else []
+        )
+        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """Load processor from pretrained model path."""
+        from transformers import AutoTokenizer, Glm4vImageProcessor
+
+        trust_remote_code = kwargs.pop("trust_remote_code", True)
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        image_processor = Glm4vImageProcessor.from_pretrained(
+            pretrained_model_name_or_path,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        return cls(image_processor=image_processor, tokenizer=tokenizer, **kwargs)
+
+
+__all__ = ["Glm46VMoEProcessor"]
+
+# Alias for backwards compatibility
+Glm4VMoEProcessor = Glm46VMoEProcessor
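Note: the expansion rule in the processor above replaces each image placeholder with prod(image_grid_thw) // merge_size**2 copies of the image token before tokenization, and videos get the same treatment per frame. The following is a minimal sketch of that arithmetic in plain Python; the grid and merge values are illustrative examples, not values shipped in the package.

import numpy as np

# Hypothetical grid for one image: 1 temporal step, 12 x 12 patches,
# processed with a spatial merge size of 2 (example values only).
image_grid_thw = [1, 12, 12]
merge_size = 2
merge_length = merge_size**2

num_image_tokens = int(np.prod(image_grid_thw) // merge_length)  # 144 // 4 = 36

# The processor would expand a prompt like this before tokenizing it.
prompt = "<|image|>Describe this picture."
expanded = prompt.replace("<|image|>", "<|image|>" * num_image_tokens, 1)
assert expanded.count("<|image|>") == num_image_tokens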
mlx_vlm/models/glm4v_moe/vision.py
@@ -0,0 +1,405 @@
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..kernels import grid_sample
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 or 5 dimensions
+    if len(shape) == 4:
+        out_channels, kH, KW, _ = shape
+        # Check if out_channels is the largest, and kH and KW are the same
+        return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
+    elif len(shape) == 5:
+        B, out_channels, kH, KW, t = shape
+        # Special case for temporal dimension
+        if t == 3:
+            return True
+        # Check if out_channels is the largest, and kH and KW are the same
+        return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
+    else:
+        return False
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
+    orig_dtype = tensor.dtype
+
+    cos = mx.cos(freqs)
+    sin = mx.sin(freqs)
+
+    cos = mx.expand_dims(cos, axis=1)  # Equivalent to unsqueeze(1)
+    cos = mx.tile(cos, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+    cos = mx.expand_dims(cos, axis=0)  # Equivalent to [None, ...]
+
+    sin = mx.expand_dims(sin, axis=1)  # Equivalent to unsqueeze(1)
+    sin = mx.tile(sin, (1, 1, 2))  # Equivalent to repeat(1, 1, 2)
+    sin = mx.expand_dims(sin, axis=0)  # Equivalent to [None, ...]
+
+    output = (tensor * cos) + (rotate_half(tensor) * sin)
+    return output.astype(orig_dtype)
+
+
+class Glm4vMoeVisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        self.dim = dim
+        self.theta = theta
+
+    def __call__(self, seqlen: int) -> mx.array:
+        inv_freq = 1.0 / (
+            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
+        )
+        seq = mx.arange(seqlen.item(), dtype=inv_freq.dtype)
+        freqs = mx.outer(seq, inv_freq)
+        return freqs
+
+
+class Glm4vVisionEmbeddings(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+    def __call__(self, embeddings, lengths, image_shapes, h_coords, w_coords):
+
+        # Get position embedding parameters
+        pos_embed_weight = self.position_embedding.weight
+        hidden_size = pos_embed_weight.shape[1]
+        total_seq = h_coords.shape[0]
+
+        # Handle empty sequence case
+        if total_seq == 0:
+            adapted_pos_embed = mx.empty(0, hidden_size, dtype=pos_embed_weight.dtype)
+        else:
+            # Convert inputs to tensors if needed
+            if isinstance(lengths, list):
+                lengths = mx.array(lengths, dtype=mx.int32)
+            if not isinstance(image_shapes, mx.array):
+                image_shapes = mx.array(image_shapes, dtype=mx.int32)
+
+            # Prepare 2D position embedding
+            orig_size_sq = pos_embed_weight.shape[0]
+            orig_size = int(orig_size_sq**0.5)
+            pos_embed_2d = (
+                pos_embed_weight.reshape(orig_size, orig_size, hidden_size)
+                .transpose(2, 0, 1)[None, ...]
+                .astype(mx.float32)
+            )
+
+            # Calculate target dimensions for each patch
+            target_h = mx.concatenate(
+                [mx.repeat(image_shapes[i, 1], lengths[i]) for i in range(len(lengths))]
+            ).astype(mx.float32)
+            target_w = mx.concatenate(
+                [mx.repeat(image_shapes[i, 2], lengths[i]) for i in range(len(lengths))]
+            ).astype(mx.float32)
+
+            # Normalize coordinates to [-1, 1] range for grid_sample
+            h_coords = h_coords.astype(mx.float32)
+            w_coords = w_coords.astype(mx.float32)
+            norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
+            norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
+
+            # Create sampling grid
+            grid = mx.stack((norm_w, norm_h), axis=-1)[None, :, None, ...]
+
+            # Perform bicubic interpolation
+            interpolated_embed_fp32 = grid_sample(
+                pos_embed_2d.transpose(0, 2, 3, 1),
+                grid,
+            )
+
+            # Reshape and convert back to original dtype
+            adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(1)
+            adapted_pos_embed = adapted_pos_embed_fp32.astype(pos_embed_weight.dtype)
+
+        # Add adapted position encoding to embeddings
+        embeddings = embeddings + adapted_pos_embed
+        return embeddings
+
+
+class Glm4vMoeVisionPatchEmbed(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.in_channels = config.in_channels
+        self.embed_dim = config.hidden_size
+
+        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+        self.proj = nn.Conv3d(
+            self.in_channels,
+            self.embed_dim,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+        )
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = hidden_states.reshape(
+            -1,
+            self.in_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        ).moveaxis(1, 4)
+
+        hidden_states = self.proj(hidden_states)
+        hidden_states = hidden_states.reshape(-1, self.embed_dim)
+        return hidden_states
+
+
+class Glm4vMoeVisionPatchMerger(nn.Module):
+    def __init__(self, dim: int, context_dim: int, bias: bool = False) -> None:
+        super().__init__()
+
+        self.proj = nn.Linear(dim, dim, bias=bias)
+        self.post_projection_norm = nn.LayerNorm(dim)
+        self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
+        self.up_proj = nn.Linear(dim, context_dim, bias=bias)
+        self.down_proj = nn.Linear(context_dim, dim, bias=bias)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.proj(x)
+        x = nn.gelu(self.post_projection_norm(x))
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Glm4vMoeVisionAttention(nn.Module):
+    def __init__(self, dim: int, num_heads: int = 16) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.head_dim = head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=False)
+        self.proj = nn.Linear(dim, dim, bias=False)
+
+    def __call__(
+        self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
+    ) -> mx.array:
+        seq_length = x.shape[0]
+        qkv = (
+            self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
+        )
+        q, k, v = mx.split(qkv, 3)
+
+        q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
+        k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]
+
+        attention_mask = mx.full(
+            (1, seq_length, seq_length), mx.finfo(q.dtype).min, dtype=q.dtype
+        )
+
+        for i in range(1, len(cu_seqlens)):
+            start = int(cu_seqlens[i - 1])
+            end = int(cu_seqlens[i])
+            attention_mask[..., start:end, start:end] = 0
+
+        q = q.transpose(0, 2, 1, 3)
+        k = k.transpose(0, 2, 1, 3)
+        v = v.transpose(0, 2, 1, 3)
+
+        output = mx.fast.scaled_dot_product_attention(
+            q, k, v, scale=self.scale, mask=attention_mask
+        )
+        output = output.transpose(0, 2, 1, 3)
+        output = output.reshape(seq_length, -1)
+        return self.proj(output)
+
+
+class Glm4vMoeVisionMLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
+
+
+class Glm4vMoeVisionBlock(nn.Module):
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.norm1 = nn.RMSNorm(config.hidden_size, eps=1e-6)
+        self.norm2 = nn.RMSNorm(config.hidden_size, eps=1e-6)
+
+        self.attn = Glm4vMoeVisionAttention(
+            dim=config.hidden_size, num_heads=config.num_heads
+        )
+        self.mlp = Glm4vMoeVisionMLP(
+            dim=config.hidden_size, hidden_dim=config.out_hidden_size
+        )
+
+    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            rotary_pos_emb=rotary_pos_emb,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionModel(nn.Module):
+
+    def __init__(self, config: VisionConfig) -> None:
+        super().__init__()
+        self.config = config
+        self.model_type = config.model_type
+        if self.model_type not in ["glm4v_moe", "glm4v_moe_vision"]:
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.embeddings = Glm4vVisionEmbeddings(config)
+        self.patch_embed = Glm4vMoeVisionPatchEmbed(
+            config=config,
+        )
+
+        self.window_size = config.window_size
+        self.patch_size = config.patch_size
+        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
+
+        head_dim = config.hidden_size // config.num_heads
+        self.rotary_pos_emb = Glm4vMoeVisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = [Glm4vMoeVisionBlock(config) for _ in range(config.depth)]
+        self.merger = Glm4vMoeVisionPatchMerger(
+            dim=config.out_hidden_size, context_dim=config.intermediate_size
+        )
+
+        self.post_conv_layernorm = nn.RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.downsample = nn.Conv2d(
+            in_channels=config.hidden_size,
+            out_channels=config.out_hidden_size,
+            kernel_size=config.spatial_merge_size,
+            stride=config.spatial_merge_size,
+        )
+        self.post_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = []
+
+        for t, h, w in grid_thw.tolist():
+            hpos_ids = mx.expand_dims(mx.arange(h), 1)
+            hpos_ids = mx.repeat(hpos_ids, w, axis=1)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = mx.transpose(hpos_ids, (0, 2, 1, 3))
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = mx.expand_dims(mx.arange(w), 0)
+            wpos_ids = mx.repeat(wpos_ids, h, axis=0)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = mx.transpose(wpos_ids, (0, 2, 1, 3))
+            wpos_ids = wpos_ids.flatten()
+
+            stacked_pos_ids = mx.stack([hpos_ids, wpos_ids], axis=-1)
+            pos_ids.append(mx.tile(stacked_pos_ids, (t, 1)))
+
+        pos_ids = mx.concatenate(pos_ids, axis=0)
+        max_grid_size = mx.max(grid_thw[:, 1:])
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids]
+
+        return rotary_pos_emb.reshape(pos_ids.shape[0], -1), pos_ids
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        grid_thw: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+
+        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.post_conv_layernorm(hidden_states)
+        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
+
+        seq_lens = grid_thw[:, 1] * grid_thw[:, 2]
+        repeats = grid_thw[:, 0]
+        repeated_values = []
+        for i, (seq_len, repeat_count) in enumerate(
+            zip(seq_lens.tolist(), repeats.tolist())
+        ):
+            repeated_values.extend([seq_len] * repeat_count)
+
+        cu_seqlens = mx.array(repeated_values).cumsum(axis=0)
+        cu_seqlens = mx.pad(cu_seqlens, (1, 0), constant_values=0)
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        hidden_states = self.embeddings(
+            hidden_states, seqlens, grid_thw, image_type_ids[:, 0], image_type_ids[:, 1]
+        )
+
+        for blk in self.blocks:
+            hidden_states = blk(
+                hidden_states,
+                cu_seqlens=cu_seqlens,
+                rotary_pos_emb=rotary_pos_emb,
+            )
+
+        hidden_states = self.post_layernorm(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            -1,
+            self.spatial_merge_size,
+            self.spatial_merge_size,
+            hidden_states.shape[-1],
+        )
+        hidden_states = self.downsample(hidden_states).reshape(
+            -1, self.config.out_hidden_size
+        )
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embed.proj.weight" in k or "downsample.weight" in k:
+                # PyTorch conv2d weight tensors have shape:
+                # [out_channels, in_channels, kH, KW]
+                # MLX conv2d expects the weight be of shape:
+                # [out_channels, kH, KW, in_channels]
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    if v.ndim == 5:
+                        sanitized_weights[k] = v.transpose(0, 2, 3, 4, 1)
+                    if v.ndim == 4:
+                        sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
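Note: attention in the vision tower above is block-diagonal over images, driven by the cumulative sequence lengths built in VisionModel.__call__ and consumed in Glm4vMoeVisionAttention. The sketch below rebuilds cu_seqlens and the additive mask in plain NumPy for two made-up patch grids; it mirrors that logic but is not part of the package.

import numpy as np

# Two hypothetical images: 4 x 6 and 8 x 8 patch grids, one temporal step each.
grid_thw = np.array([[1, 4, 6], [1, 8, 8]])

seq_lens = grid_thw[:, 1] * grid_thw[:, 2]   # patches per frame: [24, 64]
repeats = grid_thw[:, 0]                     # frames per image:  [1, 1]
cu_seqlens = np.concatenate(([0], np.repeat(seq_lens, repeats).cumsum()))
print(cu_seqlens)                            # [ 0 24 88]

# Patches may only attend to patches of the same image (block-diagonal mask).
total = int(cu_seqlens[-1])
mask = np.full((total, total), -np.inf)
for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
    mask[start:end, start:end] = 0.0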