fount_vlm_nell_02-0.3.11-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
mlx_vlm/models/llama4/llama4.py
@@ -0,0 +1,146 @@
+from typing import List, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import InputEmbeddingsFeatures
+from .config import ModelConfig
+from .language import LanguageModel
+from .vision import Llama4MultiModalProjector, VisionModel
+
+
+class Model(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+        self.vision_model = VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+        self.language_model = LanguageModel(config.text_config)
+        self.vocab_size = config.text_config.vocab_size
+
+    def set_input_embeddings(self, value):
+        self.language_model.set_input_embeddings(value)
+
+    def get_output_embeddings(self):
+        return self.language_model.get_output_embeddings()
+
+    def set_output_embeddings(self, new_embeddings):
+        self.language_model.set_output_embeddings(new_embeddings)
+
+    def set_decoder(self, decoder):
+        self.language_model.set_decoder(decoder)
+
+    def get_decoder(self):
+        return self.language_model.get_decoder()
+
+    def get_image_features(
+        self,
+        pixel_values: mx.array,
+        vision_feature_layer: Union[int, List[int]],
+        vision_feature_select_strategy: str,
+        **kwargs,
+    ):
+        if vision_feature_select_strategy not in ["default", "full"]:
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
+            )
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        hidden_state = self.vision_model(
+            pixel_values, output_hidden_states=False, **kwargs
+        )
+        return hidden_state
+
+    def get_input_embeddings(
+        self,
+        input_ids: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
+        **kwargs,
+    ):
+        if pixel_values is None:
+            return InputEmbeddingsFeatures(
+                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
+            )
+
+        # Get the input embeddings from the language model
+        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
+
+        image_features = self.get_image_features(
+            pixel_values=pixel_values,
+            vision_feature_layer=kwargs.get("vision_feature_layer", -1),
+            vision_feature_select_strategy=kwargs.get(
+                "vision_feature_select_strategy", "default"
+            ),
+        )
+
+        vision_flat = image_features.reshape(-1, image_features.shape[-1])
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+
+        # Insert special image tokens in the input_ids
+        final_inputs_embeds = self._prepare_inputs_for_multimodal(
+            projected_vision_flat, inputs_embeds, input_ids
+        )
+        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)
+
+    def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids):
+        image_token_index = self.config.image_token_index
+
+        # Find positions of <image> tokens
+        image_mask = input_ids == image_token_index
+
+        batch_size, seq_len = input_ids.shape
+
+        # Process each batch item
+        batch_outputs = []
+        feature_start_idx = 0
+
+        for batch_idx in range(batch_size):
+            batch_mask = image_mask[batch_idx]
+            num_positions = mx.sum(batch_mask).item()
+
+            if num_positions > 0:
+                batch_features = image_features[
+                    feature_start_idx : feature_start_idx + num_positions
+                ]
+
+                # Create indices for gathering
+                cumsum = mx.cumsum(batch_mask.astype(mx.int32))
+                feature_indices = mx.where(batch_mask, cumsum - 1, 0)
+
+                # Gather features
+                gathered_features = batch_features[feature_indices]
+
+                # Combine with original embeddings
+                batch_mask_expanded = mx.expand_dims(batch_mask, axis=-1)
+                batch_output = mx.where(
+                    batch_mask_expanded, gathered_features, inputs_embeds[batch_idx]
+                )
+
+                feature_start_idx += num_positions
+            else:
+                batch_output = inputs_embeds[batch_idx]
+
+            batch_outputs.append(batch_output)
+
+        return mx.stack(batch_outputs, axis=0)
+
+    @property
+    def layers(self):
+        return self.language_model.model.layers
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        pixel_values: mx.array,
+        cache=None,
+        **kwargs,
+    ):
+
+        input_embeddings_features = self.get_input_embeddings(
+            input_ids, pixel_values, **kwargs
+        )
+        logits = self.language_model(
+            inputs=input_ids,
+            inputs_embeds=input_embeddings_features.inputs_embeds,
+            cache=cache,
+        )
+        return logits
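Note: `_prepare_inputs_for_multimodal` above splices the projected image features into the text embeddings by counting `<image>` tokens with a cumulative sum and selecting per position with `mx.where`. The standalone sketch below is not part of the package; the token id 7 and the toy shapes are invented for illustration, but it reproduces the same gather-and-merge pattern for a single sequence.

import mlx.core as mx

image_token_index = 7                                  # hypothetical image-token id
input_ids = mx.array([1, 7, 7, 2, 7, 3])               # one toy sequence
hidden = 4
text_embeds = mx.zeros((input_ids.shape[0], hidden))   # stand-in text embeddings
image_feats = mx.arange(3 * hidden, dtype=mx.float32).reshape(3, hidden)

mask = input_ids == image_token_index                  # True at <image> positions
cumsum = mx.cumsum(mask.astype(mx.int32))              # running count of image tokens
feature_idx = mx.where(mask, cumsum - 1, 0)            # image-feature row per position
gathered = image_feats[feature_idx]                    # rows only matter where mask is True
merged = mx.where(mask[:, None], gathered, text_embeds)
print(merged.shape)                                    # (6, 4)

Positions where the mask is False gather an arbitrary row (index 0) that is then discarded by the final `mx.where`, mirroring the batched loop in the diff.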
mlx_vlm/models/llama4/vision.py
@@ -0,0 +1,526 @@
+from typing import Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from ..base import pixel_shuffle
+from .config import VisionConfig
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class Llama4MultiModalProjector(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.linear_1 = nn.Linear(
+            config.vision_config.vision_output_dim,
+            config.text_config.hidden_size,
+            bias=False,
+        )
+
+    def __call__(self, image_features):
+        hidden_states = self.linear_1(image_features)
+        return hidden_states
+
+
+class Llama4VisionPixelShuffleMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.pixel_shuffle_ratio = config.pixel_shuffle_ratio
+        self.inner_dim = int(
+            config.projector_input_dim // (self.pixel_shuffle_ratio**2)
+        )
+        self.output_dim = config.projector_output_dim
+        self.mlp = Llama4VisionMLP(config, bias=False, is_projector=True)
+
+    def __call__(self, encoded_patches: mx.array) -> mx.array:
+        encoded_patches = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
+        return self.mlp(encoded_patches)
+
+
+# TODO there is a different RoPE for vision encoder, defined as below
+def reshape_for_broadcast(freqs_ci: mx.array, query: mx.array):
+    ndim = query.ndim
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(query.shape)]
+    return freqs_ci.reshape(*shape)
+
+
+def view_as_complex(x):
+    """
+    Convert a tensor with shape (..., 2) to a complex tensor with shape (...).
+
+    Args:
+        x: A real tensor with last dimension of size 2.
+
+    Returns:
+        A complex tensor with size one less than the input.
+    """
+    # Ensure the last dimension is size 2
+    assert x.shape[-1] == 2, f"Last dimension must be 2, got {x.shape[-1]}"
+
+    # Get real and imaginary parts
+    real, imag = x[..., 0], x[..., 1]
+
+    # Create complex tensor
+    return real + 1j * imag
+
+
+def view_as_real(x):
+    """
+    Convert a complex tensor with shape (...) to a real tensor with shape (..., 2).
+
+    Args:
+        x: A complex tensor.
+
+    Returns:
+        A real tensor with an extra dimension of size 2.
+    """
+    # Get real and imaginary parts
+    real = mx.real(x)
+    imag = mx.imag(x)
+
+    # Combine into a tensor with last dimension 2
+    return mx.stack([real, imag], axis=-1)
+
+
+def vision_apply_rotary_emb(
+    query: mx.array,
+    key: mx.array,
+    freqs_ci: mx.array,
+) -> Tuple[mx.array, mx.array]:
+
+    query_ = view_as_complex(query.astype(mx.float32).reshape(*query.shape[:-1], -1, 2))
+    key_ = view_as_complex(key.astype(mx.float32).reshape(*key.shape[:-1], -1, 2))
+    freqs_ci = reshape_for_broadcast(freqs_ci=freqs_ci, query=query_)
+    query_out = view_as_real(query_ * freqs_ci).flatten(3)
+    key_out = view_as_real(key_ * freqs_ci).flatten(3)
+    return query_out.astype(query.dtype), key_out.astype(key.dtype)
+
+
+class Llama4VisionAttention(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+        self.num_key_value_groups = 1
+        self.scale = self.head_dim**-0.5
+
+        self.q_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=True
+        )
+        self.k_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=True
+        )
+        self.v_proj = nn.Linear(
+            self.embed_dim, self.num_heads * self.head_dim, bias=True
+        )
+        self.o_proj = nn.Linear(
+            self.num_heads * self.head_dim, self.embed_dim, bias=True
+        )
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        freqs_ci: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[mx.array] = None,
+    ):
+        B, L, D = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+        key_states = self.k_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+        value_states = self.v_proj(hidden_states).reshape(B, L, self.num_heads, -1)
+
+        query_states, key_states = vision_apply_rotary_emb(
+            query_states, key_states, freqs_ci=freqs_ci
+        )
+
+        query_states = query_states.transpose(0, 2, 1, 3)
+        key_states = key_states.transpose(0, 2, 1, 3)
+        value_states = value_states.transpose(0, 2, 1, 3)
+
+        attn_output = mx.fast.scaled_dot_product_attention(
+            query_states, key_states, value_states, scale=self.scale
+        )
+
+        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        attn_output = self.o_proj(attn_output)
+        return attn_output
+
+
+class Llama4VisionMLP(nn.Module):
+    def __init__(self, config, bias=True, is_projector=False):
+        super().__init__()
+        self.config = config
+        self.activation_fn = nn.GELU(approx="fast")  # ACT2FN[config.hidden_act]
+        self.is_projector = is_projector
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+
+        # Determine dimensions for first linear layer based on whether this is a projector
+        fc1_input_dim = self.intermediate_size if is_projector else self.hidden_size
+        fc1_output_dim = (
+            config.projector_input_dim if is_projector else self.intermediate_size
+        )
+
+        self.fc1 = nn.Linear(fc1_input_dim, fc1_output_dim, bias=bias)
+
+        # Determine dimensions for second linear layer
+        fc2_input_dim = (
+            config.projector_output_dim if is_projector else self.intermediate_size
+        )
+        fc2_output_dim = (
+            config.projector_output_dim if is_projector else self.hidden_size
+        )
+
+        self.fc2 = nn.Linear(fc2_input_dim, fc2_output_dim, bias=bias)
+
+        self.is_projector = is_projector
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+
+        if self.is_projector:
+            return self.activation_fn(self.fc2(hidden_states))
+
+        return self.fc2(hidden_states)
+
+
+class Llama4VisionEncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.self_attn = Llama4VisionAttention(config)
+        self.mlp = Llama4VisionMLP(config)
+
+        self.input_layernorm = nn.LayerNorm(config.hidden_size)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        hidden_state: mx.array,
+        freqs_ci: mx.array,
+        mask: Optional[mx.array] = None,
+    ):
+        # Self Attention
+        residual = hidden_state
+
+        hidden_state = self.input_layernorm(hidden_state)
+
+        hidden_state = self.self_attn(
+            hidden_state,
+            freqs_ci=freqs_ci,
+            mask=mask,
+        )
+        hidden_state = residual + hidden_state
+
+        # Feed forward
+        residual = hidden_state
+        hidden_state = self.post_attention_layernorm(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        hidden_state = residual + hidden_state
+        return hidden_state
+
+
+class Llama4VisionEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+    [`Llama4VisionEncoderLayer`].
+
+    Args:
+        config: VisionConfig
+    """
+
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.layers = [
+            Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
+        ]
+        self.config = config
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        freqs_ci: mx.array,  # TODO move this to an attribute instead of keeping it around
+        mask: Optional[mx.array] = None,
+    ):
+
+        for i, encoder_layer in enumerate(self.layers):
+            hidden_states = encoder_layer(
+                hidden_state=hidden_states,
+                mask=mask,
+                freqs_ci=freqs_ci,
+            )
+
+        return hidden_states
+
+
+class Llama4UnfoldConvolution(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        kernel_size = config.patch_size
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size)
+        self.kernel_size = kernel_size
+        self.stride = config.patch_size
+        self.linear = nn.Linear(
+            config.num_channels * kernel_size[0] * kernel_size[1],
+            config.hidden_size,
+            bias=False,
+        )
+
+    def _pair(self, x):
+        """Convert input to a pair of values."""
+        if isinstance(x, (list, tuple)):
+            return tuple(x)
+        return (x, x)
+
+    def unfold(self, input_tensor):
+        """
+        Extract sliding local blocks from a batched input tensor (MLX implementation).
+
+        This is equivalent to PyTorch's nn.functional.unfold or im2col operation.
+
+        Args:
+            input_tensor: Input tensor of shape (B, C, H, W)
+
+        Returns:
+            Unfolded tensor of shape (B, C*kernel_height*kernel_width, L)
+            where L is the number of blocks
+        """
+        # Convert to pairs
+        kernel_size = self._pair(self.kernel_size)
+        stride = self._pair(self.stride)
+        padding = (0, 0)  # No padding in the original code
+        dilation = (1, 1)  # Default dilation
+
+        # Input shape
+        batch_size, channels, height, width = input_tensor.shape
+
+        # Calculate output dimensions
+        height_out = (
+            height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
+        ) // stride[0] + 1
+        width_out = (
+            width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
+        ) // stride[1] + 1
+
+        # Initialize output arrays
+        blocks = []
+
+        # Extract blocks
+        for i in range(0, height - kernel_size[0] * dilation[0] + 1, stride[0]):
+            for j in range(0, width - kernel_size[1] * dilation[1] + 1, stride[1]):
+                # Extract the block for all channels
+                block = []
+                for di in range(kernel_size[0]):
+                    for dj in range(kernel_size[1]):
+                        h_idx = i + di * dilation[0]
+                        w_idx = j + dj * dilation[1]
+                        # Get the block for all channels and add to our list
+                        block.append(input_tensor[:, :, h_idx, w_idx])
+
+                # Stack the channel-blocks
+                block = mx.stack(block, axis=1)  # Shape: (B, k*k, C)
+                block = mx.transpose(block, [0, 2, 1])  # Shape: (B, C, k*k)
+                blocks.append(block)
+
+        # Stack all blocks together
+        result = mx.stack(blocks, axis=-1)  # Shape: (B, C, k*k, L)
+
+        # Reshape to match PyTorch's unfold output format: (B, C*k*k, L)
+        result = mx.reshape(
+            result,
+            (
+                batch_size,
+                channels * kernel_size[0] * kernel_size[1],
+                height_out * width_out,
+            ),
+        )
+
+        return result
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        hidden_states = self.unfold(hidden_states)
+        hidden_states = hidden_states.swapaxes(1, 2)
+        hidden_states = self.linear(hidden_states)
+        return hidden_states
+
+
+class Llama4VisionRotaryEmbedding:
+    def __init__(self, config):
+        super().__init__()
+        idx = config.image_size // config.patch_size
+        img_idx = mx.arange(idx**2, dtype=mx.int32).reshape(idx**2, 1)
+        img_idx = mx.concatenate([img_idx, img_idx[:1]], axis=0)
+        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
+        frequencies_x = img_idx % idx  # get the coordinates of the 2d matrix along x
+        frequencies_y = img_idx // idx  # get the coordinates of the 2d matrix along y
+        freq_dim = config.hidden_size // config.num_attention_heads // 2
+        rope_freq = 1.0 / (
+            config.rope_theta
+            ** (
+                mx.arange(0, freq_dim, 2, dtype=mx.float32)[: (freq_dim // 2)]
+                / freq_dim
+            )
+        )
+
+        # Expand dimensions for frequencies_x and frequencies_y
+        freqs_x_expanded = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
+        freqs_y_expanded = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]
+
+        def repeat_interleave(tensor, repeats, dim=-1):
+            # Get the shape
+            shape = list(tensor.shape)
+
+            # Reshape to add an extra dimension for repeating
+            tensor = mx.reshape(tensor, shape[:-1] + [shape[-1], 1])
+
+            # Repeat along the new dimension
+            tensor = mx.repeat(tensor, repeats, axis=-1)
+
+            # Reshape to flatten the last two dimensions
+            return mx.reshape(tensor, shape[:-1] + [shape[-1] * repeats])
+
+        # Apply interleaving
+        freqs_x = repeat_interleave(freqs_x_expanded, 2)
+        freqs_y = repeat_interleave(freqs_y_expanded, 2)
+        freqs = mx.concatenate([freqs_x, freqs_y], axis=-1).astype(mx.float32)[..., ::2]
+        # Replaced masked_fill with where
+        mask = img_idx.reshape(-1, 1, 1) < 0
+        freqs = mx.where(mask, mx.zeros_like(freqs), freqs)
+        freq_cis = mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1)
+        freq_cis = view_as_complex(freq_cis)
+        self.freqs_ci = freq_cis  # idx**2, idx**2, idx * 2
+
+    def __call__(self, hidden_states):
+        return self.freqs_ci
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        self.hidden_size = config.hidden_size
+        self.num_channels = config.num_channels
+        self.model_type = config.model_type
+        if self.model_type not in ["llama4", "llama4_vision_model"]:
+            raise ValueError(f"Model type {self.model_type} not supported")
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+        self.scale = config.hidden_size**-0.5
+
+        self.class_embedding = self.scale * mx.random.normal((self.hidden_size,))
+        self.positional_embedding_vlm = self.scale * mx.random.normal(
+            (self.num_patches, self.hidden_size)
+        )
+
+        self.patch_embedding = Llama4UnfoldConvolution(config)
+
+        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)
+
+        # layer norms
+        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
+        self.layernorm_post = nn.LayerNorm(self.hidden_size)
+
+        # encoders
+        self.model = Llama4VisionEncoder(config)
+        self.vision_adapter = Llama4VisionPixelShuffleMLP(config)
+
+    def get_input_embeddings(self):
+        """
+        This function is used to fetch the first embedding layer to activate grads on inputs.
+        """
+        return self.patch_embedding
+
+    def __call__(
+        self,
+        pixel_values: mx.array,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        capture_activations: Optional[bool] = True,
+    ):
+
+        batch_size_times_num_tiles, num_channels, height, width = pixel_values.shape
+        num_concurrent_media = 1
+        num_chunks = 1
+
+        hidden_state = self.patch_embedding(pixel_values)
+
+        _, num_patches, hidden_dim = hidden_state.shape
+
+        # Add cls token
+        hidden_state = hidden_state.reshape(
+            batch_size_times_num_tiles * num_concurrent_media * num_chunks,
+            num_patches,
+            hidden_dim,
+        )
+
+        class_embedding = mx.broadcast_to(
+            self.class_embedding, (hidden_state.shape[0], 1, hidden_state.shape[-1])
+        )
+        hidden_state = mx.concatenate([hidden_state, class_embedding], axis=1)
+        num_patches += 1
+
+        # Position embeddings
+        hidden_state = hidden_state.reshape(
+            batch_size_times_num_tiles * num_concurrent_media,
+            num_chunks,
+            num_patches,
+            hidden_dim,
+        )
+
+        positional_embedding = self.positional_embedding_vlm
+        hidden_state = hidden_state + positional_embedding
+
+        hidden_state = self.layernorm_pre(hidden_state)
+
+        hidden_state = hidden_state.reshape(batch_size_times_num_tiles, -1, hidden_dim)
+        freqs_ci = self.rotary_embedding(pixel_values)
+
+        hidden_state = self.model(
+            hidden_state,
+            mask=None,
+            freqs_ci=freqs_ci,
+        )
+
+        hidden_state = self.layernorm_post(hidden_state)
+
+        hidden_state = hidden_state[:, :-1, :]
+
+        # now, we use Llama4VisionPixelShuffle + mlp to project embeddings
+        final_hidden_state = self.vision_adapter(hidden_state)
+
+        # Return only the final state
+        return final_hidden_state
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
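Note: `Llama4UnfoldConvolution.unfold` above implements im2col with nested Python loops. Since the stride equals the kernel size here (both come from `config.patch_size`), patches never overlap, and under that assumption the same (B, C*k*k, L) layout can be produced with a reshape and transpose. The sketch below is an illustration of that equivalence, not code from the package.

import mlx.core as mx

def unfold_nonoverlapping(x: mx.array, k: int) -> mx.array:
    # x: (B, C, H, W) with H and W divisible by k (the non-overlapping case)
    B, C, H, W = x.shape
    h_out, w_out = H // k, W // k
    # (B, C, h_out, k, w_out, k) -> (B, C, k, k, h_out, w_out)
    patches = x.reshape(B, C, h_out, k, w_out, k).transpose(0, 1, 3, 5, 2, 4)
    return patches.reshape(B, C * k * k, h_out * w_out)

x = mx.random.normal((2, 3, 8, 8))
print(unfold_nonoverlapping(x, 4).shape)   # (2, 48, 4)

The reshape form avoids the Python-level loops and the per-patch `mx.stack` calls, at the cost of only handling the non-overlapping stride-equals-kernel case used by this patch embedding.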