fount-vlm-nell-02 0.3.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fount_vlm_nell_02-0.3.11.dist-info/METADATA +418 -0
- fount_vlm_nell_02-0.3.11.dist-info/RECORD +258 -0
- fount_vlm_nell_02-0.3.11.dist-info/WHEEL +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/entry_points.txt +5 -0
- fount_vlm_nell_02-0.3.11.dist-info/licenses/LICENSE +21 -0
- fount_vlm_nell_02-0.3.11.dist-info/top_level.txt +1 -0
- mlx_vlm/__init__.py +16 -0
- mlx_vlm/__main__.py +24 -0
- mlx_vlm/chat.py +234 -0
- mlx_vlm/chat_ui.py +508 -0
- mlx_vlm/convert.py +284 -0
- mlx_vlm/deprecation.py +52 -0
- mlx_vlm/evals/__init__.py +0 -0
- mlx_vlm/evals/math_vista.py +565 -0
- mlx_vlm/evals/mmmu.py +528 -0
- mlx_vlm/evals/mmstar.py +343 -0
- mlx_vlm/evals/ocrbench.py +453 -0
- mlx_vlm/evals/utils.py +37 -0
- mlx_vlm/generate.py +1457 -0
- mlx_vlm/lora.py +207 -0
- mlx_vlm/models/__init__.py +0 -0
- mlx_vlm/models/aya_vision/__init__.py +2 -0
- mlx_vlm/models/aya_vision/aya_vision.py +188 -0
- mlx_vlm/models/aya_vision/config.py +52 -0
- mlx_vlm/models/aya_vision/language.py +202 -0
- mlx_vlm/models/aya_vision/vision.py +340 -0
- mlx_vlm/models/base.py +356 -0
- mlx_vlm/models/cache.py +238 -0
- mlx_vlm/models/deepseek_vl_v2/__init__.py +2 -0
- mlx_vlm/models/deepseek_vl_v2/config.py +159 -0
- mlx_vlm/models/deepseek_vl_v2/conversation.py +264 -0
- mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +418 -0
- mlx_vlm/models/deepseek_vl_v2/language.py +539 -0
- mlx_vlm/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +536 -0
- mlx_vlm/models/deepseek_vl_v2/vision.py +322 -0
- mlx_vlm/models/deepseekocr/__init__.py +2 -0
- mlx_vlm/models/deepseekocr/config.py +173 -0
- mlx_vlm/models/deepseekocr/conversation.py +264 -0
- mlx_vlm/models/deepseekocr/deepseekocr.py +371 -0
- mlx_vlm/models/deepseekocr/language.py +547 -0
- mlx_vlm/models/deepseekocr/processing_deepseekocr.py +655 -0
- mlx_vlm/models/deepseekocr/sam.py +489 -0
- mlx_vlm/models/deepseekocr/vision.py +263 -0
- mlx_vlm/models/deepseekocr_2/__init__.py +12 -0
- mlx_vlm/models/deepseekocr_2/config.py +216 -0
- mlx_vlm/models/deepseekocr_2/deepseekocr_2.py +297 -0
- mlx_vlm/models/deepseekocr_2/processing_deepseekocr.py +624 -0
- mlx_vlm/models/deepseekocr_2/vision.py +439 -0
- mlx_vlm/models/ernie4_5_moe_vl/__init__.py +5 -0
- mlx_vlm/models/ernie4_5_moe_vl/config.py +139 -0
- mlx_vlm/models/ernie4_5_moe_vl/ernie4_5_moe_vl.py +337 -0
- mlx_vlm/models/ernie4_5_moe_vl/language.py +770 -0
- mlx_vlm/models/ernie4_5_moe_vl/processor.py +686 -0
- mlx_vlm/models/ernie4_5_moe_vl/vision.py +322 -0
- mlx_vlm/models/fastvlm/__init__.py +2 -0
- mlx_vlm/models/fastvlm/config.py +79 -0
- mlx_vlm/models/fastvlm/fastvlm.py +198 -0
- mlx_vlm/models/fastvlm/language.py +49 -0
- mlx_vlm/models/fastvlm/vision.py +692 -0
- mlx_vlm/models/florence2/__init__.py +2 -0
- mlx_vlm/models/florence2/config.py +84 -0
- mlx_vlm/models/florence2/florence2.py +383 -0
- mlx_vlm/models/florence2/language.py +452 -0
- mlx_vlm/models/florence2/processing_florence2.py +30 -0
- mlx_vlm/models/florence2/vision.py +552 -0
- mlx_vlm/models/gemma3/__init__.py +2 -0
- mlx_vlm/models/gemma3/config.py +52 -0
- mlx_vlm/models/gemma3/gemma3.py +194 -0
- mlx_vlm/models/gemma3/language.py +293 -0
- mlx_vlm/models/gemma3/vision.py +215 -0
- mlx_vlm/models/gemma3n/__init__.py +2 -0
- mlx_vlm/models/gemma3n/audio.py +1038 -0
- mlx_vlm/models/gemma3n/config.py +130 -0
- mlx_vlm/models/gemma3n/gemma3n.py +322 -0
- mlx_vlm/models/gemma3n/language.py +631 -0
- mlx_vlm/models/gemma3n/vision.py +994 -0
- mlx_vlm/models/glm4v/__init__.py +3 -0
- mlx_vlm/models/glm4v/config.py +79 -0
- mlx_vlm/models/glm4v/glm4v.py +188 -0
- mlx_vlm/models/glm4v/language.py +574 -0
- mlx_vlm/models/glm4v/processing.py +220 -0
- mlx_vlm/models/glm4v/vision.py +406 -0
- mlx_vlm/models/glm4v_moe/__init__.py +3 -0
- mlx_vlm/models/glm4v_moe/config.py +81 -0
- mlx_vlm/models/glm4v_moe/glm4v_moe.py +176 -0
- mlx_vlm/models/glm4v_moe/language.py +674 -0
- mlx_vlm/models/glm4v_moe/processing.py +229 -0
- mlx_vlm/models/glm4v_moe/vision.py +405 -0
- mlx_vlm/models/glm_ocr/__init__.py +3 -0
- mlx_vlm/models/glm_ocr/config.py +93 -0
- mlx_vlm/models/glm_ocr/glm_ocr.py +180 -0
- mlx_vlm/models/glm_ocr/language.py +585 -0
- mlx_vlm/models/glm_ocr/processing.py +208 -0
- mlx_vlm/models/glm_ocr/vision.py +342 -0
- mlx_vlm/models/hunyuan_vl/__init__.py +7 -0
- mlx_vlm/models/hunyuan_vl/config.py +136 -0
- mlx_vlm/models/hunyuan_vl/hunyuan_vl.py +181 -0
- mlx_vlm/models/hunyuan_vl/language.py +509 -0
- mlx_vlm/models/hunyuan_vl/processing_hunyuan_vl.py +607 -0
- mlx_vlm/models/hunyuan_vl/vision.py +322 -0
- mlx_vlm/models/idefics2/__init__.py +2 -0
- mlx_vlm/models/idefics2/config.py +65 -0
- mlx_vlm/models/idefics2/idefics2.py +321 -0
- mlx_vlm/models/idefics2/language.py +161 -0
- mlx_vlm/models/idefics2/vision.py +244 -0
- mlx_vlm/models/idefics3/__init__.py +4 -0
- mlx_vlm/models/idefics3/config.py +54 -0
- mlx_vlm/models/idefics3/idefics3.py +221 -0
- mlx_vlm/models/idefics3/language.py +157 -0
- mlx_vlm/models/idefics3/vision.py +265 -0
- mlx_vlm/models/internvl_chat/__init__.py +3 -0
- mlx_vlm/models/internvl_chat/config.py +89 -0
- mlx_vlm/models/internvl_chat/internvl_chat.py +115 -0
- mlx_vlm/models/internvl_chat/language.py +187 -0
- mlx_vlm/models/internvl_chat/processor.py +395 -0
- mlx_vlm/models/internvl_chat/vision.py +265 -0
- mlx_vlm/models/interpolate.py +183 -0
- mlx_vlm/models/jina_vlm/__init__.py +3 -0
- mlx_vlm/models/jina_vlm/config.py +142 -0
- mlx_vlm/models/jina_vlm/image_processor.py +430 -0
- mlx_vlm/models/jina_vlm/jina_vlm.py +280 -0
- mlx_vlm/models/jina_vlm/language.py +272 -0
- mlx_vlm/models/jina_vlm/processing_jinavlm.py +266 -0
- mlx_vlm/models/jina_vlm/vision.py +202 -0
- mlx_vlm/models/kernels.py +447 -0
- mlx_vlm/models/kimi_vl/__init__.py +4 -0
- mlx_vlm/models/kimi_vl/config.py +84 -0
- mlx_vlm/models/kimi_vl/kimi_vl.py +127 -0
- mlx_vlm/models/kimi_vl/language.py +460 -0
- mlx_vlm/models/kimi_vl/processing_kimi_vl.py +560 -0
- mlx_vlm/models/kimi_vl/vision.py +485 -0
- mlx_vlm/models/lfm2_vl/__init__.py +2 -0
- mlx_vlm/models/lfm2_vl/config.py +94 -0
- mlx_vlm/models/lfm2_vl/language.py +49 -0
- mlx_vlm/models/lfm2_vl/lfm2_vl.py +223 -0
- mlx_vlm/models/lfm2_vl/processing_lfm2_vl.py +320 -0
- mlx_vlm/models/lfm2_vl/vision.py +223 -0
- mlx_vlm/models/llama4/__init__.py +2 -0
- mlx_vlm/models/llama4/config.py +83 -0
- mlx_vlm/models/llama4/language.py +334 -0
- mlx_vlm/models/llama4/llama4.py +146 -0
- mlx_vlm/models/llama4/vision.py +526 -0
- mlx_vlm/models/llava/__init__.py +2 -0
- mlx_vlm/models/llava/config.py +61 -0
- mlx_vlm/models/llava/language.py +200 -0
- mlx_vlm/models/llava/llava.py +132 -0
- mlx_vlm/models/llava/vision.py +233 -0
- mlx_vlm/models/llava_bunny/__init__.py +2 -0
- mlx_vlm/models/llava_bunny/config.py +85 -0
- mlx_vlm/models/llava_bunny/language.py +194 -0
- mlx_vlm/models/llava_bunny/llava_bunny.py +217 -0
- mlx_vlm/models/llava_bunny/vision.py +278 -0
- mlx_vlm/models/llava_next/__init__.py +2 -0
- mlx_vlm/models/llava_next/config.py +60 -0
- mlx_vlm/models/llava_next/language.py +192 -0
- mlx_vlm/models/llava_next/llava_next.py +138 -0
- mlx_vlm/models/llava_next/vision.py +217 -0
- mlx_vlm/models/mistral3/__init__.py +2 -0
- mlx_vlm/models/mistral3/config.py +59 -0
- mlx_vlm/models/mistral3/language.py +269 -0
- mlx_vlm/models/mistral3/mistral3.py +383 -0
- mlx_vlm/models/mllama/__init__.py +4 -0
- mlx_vlm/models/mllama/config.py +74 -0
- mlx_vlm/models/mllama/language.py +377 -0
- mlx_vlm/models/mllama/mllama.py +210 -0
- mlx_vlm/models/mllama/vision.py +458 -0
- mlx_vlm/models/molmo/__init__.py +5 -0
- mlx_vlm/models/molmo/config.py +93 -0
- mlx_vlm/models/molmo/language.py +208 -0
- mlx_vlm/models/molmo/molmo.py +108 -0
- mlx_vlm/models/molmo/processing_molmo.py +763 -0
- mlx_vlm/models/molmo/vision.py +408 -0
- mlx_vlm/models/molmo2/__init__.py +6 -0
- mlx_vlm/models/molmo2/config.py +137 -0
- mlx_vlm/models/molmo2/language.py +206 -0
- mlx_vlm/models/molmo2/molmo2.py +330 -0
- mlx_vlm/models/molmo2/processing.py +773 -0
- mlx_vlm/models/molmo2/vision.py +286 -0
- mlx_vlm/models/moondream2/__init__.py +11 -0
- mlx_vlm/models/moondream2/config.py +92 -0
- mlx_vlm/models/moondream2/image_crops.py +269 -0
- mlx_vlm/models/moondream2/language.py +267 -0
- mlx_vlm/models/moondream2/moondream2.py +522 -0
- mlx_vlm/models/moondream2/processing_moondream.py +144 -0
- mlx_vlm/models/moondream2/vision.py +200 -0
- mlx_vlm/models/multi_modality/__init__.py +4 -0
- mlx_vlm/models/multi_modality/config.py +108 -0
- mlx_vlm/models/multi_modality/language.py +191 -0
- mlx_vlm/models/multi_modality/multi_modality.py +338 -0
- mlx_vlm/models/multi_modality/sam.py +543 -0
- mlx_vlm/models/multi_modality/vision.py +450 -0
- mlx_vlm/models/paddleocr_vl/__init__.py +3 -0
- mlx_vlm/models/paddleocr_vl/config.py +93 -0
- mlx_vlm/models/paddleocr_vl/language.py +522 -0
- mlx_vlm/models/paddleocr_vl/paddleocr_vl.py +207 -0
- mlx_vlm/models/paddleocr_vl/processing_paddleocr_vl.py +425 -0
- mlx_vlm/models/paddleocr_vl/vision.py +358 -0
- mlx_vlm/models/paligemma/__init__.py +4 -0
- mlx_vlm/models/paligemma/config.py +50 -0
- mlx_vlm/models/paligemma/language.py +253 -0
- mlx_vlm/models/paligemma/paligemma.py +140 -0
- mlx_vlm/models/paligemma/vision.py +218 -0
- mlx_vlm/models/phi3_v/__init__.py +5 -0
- mlx_vlm/models/phi3_v/config.py +55 -0
- mlx_vlm/models/phi3_v/language.py +2 -0
- mlx_vlm/models/phi3_v/phi3_v.py +239 -0
- mlx_vlm/models/phi3_v/processing_phi3_v.py +704 -0
- mlx_vlm/models/phi3_v/vision.py +294 -0
- mlx_vlm/models/pixtral/__init__.py +4 -0
- mlx_vlm/models/pixtral/config.py +69 -0
- mlx_vlm/models/pixtral/language.py +195 -0
- mlx_vlm/models/pixtral/pixtral.py +208 -0
- mlx_vlm/models/pixtral/vision.py +293 -0
- mlx_vlm/models/qwen2_5_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_5_vl/config.py +90 -0
- mlx_vlm/models/qwen2_5_vl/language.py +541 -0
- mlx_vlm/models/qwen2_5_vl/qwen2_5_vl.py +184 -0
- mlx_vlm/models/qwen2_5_vl/vision.py +414 -0
- mlx_vlm/models/qwen2_vl/__init__.py +2 -0
- mlx_vlm/models/qwen2_vl/config.py +86 -0
- mlx_vlm/models/qwen2_vl/language.py +539 -0
- mlx_vlm/models/qwen2_vl/qwen2_vl.py +180 -0
- mlx_vlm/models/qwen2_vl/vision.py +308 -0
- mlx_vlm/models/qwen3_omni_moe/__init__.py +29 -0
- mlx_vlm/models/qwen3_omni_moe/audio.py +317 -0
- mlx_vlm/models/qwen3_omni_moe/code2wav.py +542 -0
- mlx_vlm/models/qwen3_omni_moe/config.py +264 -0
- mlx_vlm/models/qwen3_omni_moe/language.py +622 -0
- mlx_vlm/models/qwen3_omni_moe/omni_utils.py +69 -0
- mlx_vlm/models/qwen3_omni_moe/qwen3_omni_moe.py +706 -0
- mlx_vlm/models/qwen3_omni_moe/talker.py +873 -0
- mlx_vlm/models/qwen3_omni_moe/thinker.py +366 -0
- mlx_vlm/models/qwen3_omni_moe/vision.py +419 -0
- mlx_vlm/models/qwen3_vl/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl/config.py +103 -0
- mlx_vlm/models/qwen3_vl/language.py +596 -0
- mlx_vlm/models/qwen3_vl/qwen3_vl.py +166 -0
- mlx_vlm/models/qwen3_vl/vision.py +441 -0
- mlx_vlm/models/qwen3_vl_moe/__init__.py +2 -0
- mlx_vlm/models/qwen3_vl_moe/config.py +108 -0
- mlx_vlm/models/qwen3_vl_moe/language.py +656 -0
- mlx_vlm/models/qwen3_vl_moe/qwen3_vl_moe.py +184 -0
- mlx_vlm/models/qwen3_vl_moe/vision.py +442 -0
- mlx_vlm/models/smolvlm/__init__.py +4 -0
- mlx_vlm/models/smolvlm/config.py +59 -0
- mlx_vlm/models/smolvlm/smolvlm.py +60 -0
- mlx_vlm/prompt_utils.py +565 -0
- mlx_vlm/sample_utils.py +39 -0
- mlx_vlm/server.py +1107 -0
- mlx_vlm/smolvlm_video_generate.py +109 -0
- mlx_vlm/tokenizer_utils.py +371 -0
- mlx_vlm/trainer/__init__.py +9 -0
- mlx_vlm/trainer/lora.py +70 -0
- mlx_vlm/trainer/trainer.py +299 -0
- mlx_vlm/trainer/utils.py +160 -0
- mlx_vlm/utils.py +1339 -0
- mlx_vlm/version.py +1 -0
- mlx_vlm/video_generate.py +611 -0
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
|
|
6
|
+
from ..base import InputEmbeddingsFeatures
|
|
7
|
+
from .config import ModelConfig
|
|
8
|
+
from .language import LanguageModel
|
|
9
|
+
from .vision import VisionModel
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Model(nn.Module):
    """Vision-language model: ``vision_tower`` features spliced into ``language_model`` inputs."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.vision_tower = VisionModel(config.vision_config)
        self.language_model = LanguageModel(config.text_config, config)

    def get_input_embeddings(
        self,
        input_ids: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        **kwargs,
    ):
        """Build input embeddings, inserting vision features when pixels are given.

        Args:
            input_ids: Token IDs; the merge step unpacks their shape as
                ``(batch_size, seq_len)``.
            pixel_values: Input for the vision tower, or ``None`` for a
                text-only call.
            **kwargs: Read for ``image_grid_thw``, ``video_grid_thw`` and
                ``mask``; other keys are ignored here.

        Returns:
            ``InputEmbeddingsFeatures`` whose ``inputs_embeds`` are the token
            embeddings, with vision features written at placeholder positions
            when ``pixel_values`` is given.

        Raises:
            ValueError: If ``pixel_values`` is given but neither
                ``image_grid_thw`` nor ``video_grid_thw`` is, or if there are
                fewer vision features than placeholder tokens.
        """
        image_grid_thw = kwargs.get("image_grid_thw", None)
        video_grid_thw = kwargs.get("video_grid_thw", None)
        mask = kwargs.get("mask", None)
        grid_thw = image_grid_thw if image_grid_thw is not None else video_grid_thw

        if pixel_values is None:
            # Reset position state for text-only generation
            self.language_model._position_ids = None
            self.language_model._rope_deltas = None
            return InputEmbeddingsFeatures(
                inputs_embeds=self.language_model.model.embed_tokens(input_ids)
            )

        if grid_thw is None:
            # The vision tower calls grid_thw.tolist(); fail early with a clear
            # message instead of an AttributeError deep inside it.
            raise ValueError(
                "pixel_values were provided without image_grid_thw or "
                "video_grid_thw; the vision tower needs the patch grid."
            )

        dtype = self.vision_tower.patch_embed.proj.weight.dtype
        pixel_values = pixel_values.astype(dtype)

        # Get the input embeddings from the language model
        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

        # Get the output hidden states from the vision model
        hidden_states = self.vision_tower(
            pixel_values, grid_thw, output_hidden_states=False
        )

        # Write the vision features at the image (or video) token positions
        final_inputs_embeds = self.merge_input_ids_with_image_features(
            self.config.image_token_id,
            self.config.video_token_id,
            hidden_states,
            inputs_embeds,
            input_ids,
        )

        # Pre-calculate position_ids for chunked prefill. At least one grid is
        # present here, guaranteed by the check above.
        position_ids, rope_deltas = self.language_model.get_rope_index(
            input_ids, image_grid_thw, video_grid_thw, mask
        )
        self.language_model._position_ids = position_ids
        self.language_model._rope_deltas = rope_deltas

        return InputEmbeddingsFeatures(inputs_embeds=final_inputs_embeds)

    @staticmethod
    def merge_input_ids_with_image_features(
        image_token_id,
        video_token_id,
        image_features,
        inputs_embeds,
        input_ids,
    ):
        """Merge vision features into input embeddings at placeholder positions.

        Placeholders are tokens equal to ``image_token_id``. If no such token
        occurs anywhere in ``input_ids``, tokens equal to ``video_token_id``
        are used instead. Features are consumed in order: each batch row takes
        the next ``k`` rows of ``image_features``, where ``k`` is that row's
        placeholder count.

        Args:
            image_token_id: The token ID for image placeholders
            video_token_id: The token ID for video placeholders (fallback)
            image_features: Vision features from the vision tower [num_features, hidden_dim]
            inputs_embeds: Input embeddings [batch_size, seq_len, hidden_dim]
            input_ids: Input token IDs [batch_size, seq_len]

        Returns:
            Updated input embeddings with image features inserted

        Raises:
            ValueError: If ``image_features`` runs out before every
                placeholder in a batch row is filled.
        """
        # Find positions of image tokens
        image_positions = input_ids == image_token_id
        if mx.sum(image_positions) == 0:
            image_positions = input_ids == video_token_id

        # Get dimensions
        batch_size, seq_len = input_ids.shape

        # Process each batch item
        batch_outputs = []
        feature_start_idx = 0

        for batch_idx in range(batch_size):
            # Get mask for this batch
            image_mask = image_positions[batch_idx]
            num_positions = mx.sum(image_mask).item()

            if num_positions > 0:
                # Extract features for this batch
                batch_features = image_features[
                    feature_start_idx : feature_start_idx + num_positions
                ]

                # Validate we have the right number of features
                if batch_features.shape[0] != num_positions:
                    raise ValueError(
                        f"Number of image token positions ({num_positions}) does not match "
                        f"number of image features ({batch_features.shape[0]}) for batch {batch_idx}"
                    )

                # Map the k-th placeholder to feature row k; other positions
                # point at row 0 but are discarded by the where() below.
                cumsum = mx.cumsum(image_mask.astype(mx.int32))
                feature_indices = mx.where(image_mask, cumsum - 1, 0)

                # Gather features
                gathered_features = batch_features[feature_indices]

                # Combine with original embeddings
                image_mask_expanded = mx.expand_dims(image_mask, axis=-1)
                batch_output = mx.where(
                    image_mask_expanded, gathered_features, inputs_embeds[batch_idx]
                )

                feature_start_idx += num_positions
            else:
                # No image tokens in this batch item
                batch_output = inputs_embeds[batch_idx]

            batch_outputs.append(batch_output)

        # Stack all batch outputs
        return mx.stack(batch_outputs, axis=0)

    @property
    def layers(self):
        """Decoder layers of the wrapped language model."""
        return self.language_model.model.layers

    def __call__(
        self,
        input_ids: mx.array,
        pixel_values: Optional[mx.array] = None,
        mask: Optional[mx.array] = None,
        cache=None,
        **kwargs,
    ):
        """Run the full model and return the language model's output.

        Embeddings come from ``get_input_embeddings``. ``pixel_values`` and
        the remaining kwargs are forwarded to ``language_model`` as well.
        """
        # NOTE(review): `mask` is bound by this signature, so it is not in
        # kwargs and get_input_embeddings computes rope positions with
        # mask=None. Confirm this is intended.
        input_embeddings_features = self.get_input_embeddings(
            input_ids, pixel_values, **kwargs
        )

        kwargs = {
            "pixel_values": pixel_values,
            **kwargs,
        }

        logits = self.language_model(
            input_ids,
            input_embeddings_features.inputs_embeds,
            mask=mask,
            cache=cache,
            **kwargs,
        )

        return logits

    def sanitize(self, weights):
        """Rename checkpoint keys to this module's attribute layout.

        - Keys without ``vision_tower`` have ``visual`` replaced by ``vision_tower``.
        - Keys still without ``language_model`` have every ``model`` replaced
          by ``language_model.model``, or otherwise ``lm_head`` replaced by
          ``language_model.lm_head``.
        - Values are passed through unchanged.
        """

        def transform_key(key):
            if "vision_tower" not in key:
                key = key.replace("visual", "vision_tower")
            if "language_model" not in key:
                if "model" in key:
                    key = key.replace("model", "language_model.model")
                elif "lm_head" in key:
                    key = key.replace("lm_head", "language_model.lm_head")
            return key

        return {transform_key(k): v for k, v in weights.items()}
|
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
import mlx.nn as nn
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
from .config import VisionConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def check_array_shape(arr):
    """Heuristically decide whether a convolution weight needs no transpose.

    Only ``arr.shape`` is inspected.

    - 5-D shapes: a trailing axis of size 3 is taken to be the input-channel
      axis (``PatchEmbed`` defaults to ``in_channels=3``), meaning the weight
      is already channels-last. Otherwise ``shape[1]`` must be at least
      ``shape[2]`` and ``shape[3]``, and ``shape[2] == shape[3]``.
    - 4-D shapes are read as ``(out_channels, kH, kW, in_channels)`` and pass
      when ``out_channels`` is at least both kernel sides and the kernel is
      square.
    - Any other rank returns False.

    NOTE(review): presumably used during weight sanitization to skip
    transposing weights that are already in MLX layout. Confirm at call sites.

    Args:
        arr: Any object exposing a ``shape`` tuple.

    Returns:
        True when the shape passes the heuristic, False otherwise.
    """
    shape = arr.shape

    if len(shape) == 5:
        _, out_channels, kH, KW, t = shape
        if t == 3:
            return True
    elif len(shape) == 4:
        # Unpack four values here. Falling through to a 5-way unpack would
        # raise ValueError for every 4-D weight.
        out_channels, kH, KW, _ = shape
    else:
        return False

    # Check if out_channels is the largest, and kH and KW are the same
    return (out_channels >= kH) and (out_channels >= KW) and (kH == KW)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def rotate_half(x):
    """Swap the two halves of the last axis and negate the new first half.

    A last axis laid out as ``[a, b]`` becomes ``[-b, a]``.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return mx.concatenate([-back, front], axis=-1)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
    """Rotate ``tensor`` by the angles in ``freqs`` and keep its dtype.

    Cos and sin of ``freqs`` each get a singleton axis inserted at position
    1, their last axis tiled twice, and a leading axis prepended. For a 2-D
    ``freqs`` of shape ``(n, d)`` that gives ``(1, n, 1, 2 * d)``, which is
    then broadcast against ``tensor``.
    """

    def _as_broadcast(table):
        # Insert axis 1, tile the last axis twice, then prepend axis 0.
        tiled = mx.tile(mx.expand_dims(table, axis=1), (1, 1, 2))
        return mx.expand_dims(tiled, axis=0)

    cos = _as_broadcast(mx.cos(freqs))
    sin = _as_broadcast(mx.sin(freqs))

    rotated = (tensor * cos) + (rotate_half(tensor) * sin)
    return rotated.astype(tensor.dtype)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class VisionRotaryEmbedding(nn.Module):
    """Table of rotary angles for the vision tower.

    Calling the module with ``seqlen`` returns ``outer(arange(seqlen),
    inv_freq)``, where ``inv_freq[i] = 1 / theta ** (2 * i / dim)`` for
    ``i`` in ``range(0, dim, 2) / 2``.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta

    def __call__(self, seqlen: int) -> mx.array:
        """Return the angle table with one row per position in ``range(seqlen)``.

        Args:
            seqlen: Number of positions, as a Python int or as a scalar array
                (anything with ``.item()``), e.g. the result of ``mx.max``.
        """
        # The original called seqlen.item() unconditionally, which fails for
        # the plain int the annotation advertises.
        if hasattr(seqlen, "item"):
            seqlen = seqlen.item()
        seqlen = int(seqlen)

        inv_freq = 1.0 / (
            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
        )
        seq = mx.arange(seqlen, dtype=inv_freq.dtype)
        freqs = mx.outer(seq, inv_freq)
        return freqs
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class PatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches into ``hidden_size`` vectors.

    Input rows are reshaped to ``(in_channels, temporal_patch_size,
    patch_size, patch_size)``, moved to channels-last, and passed through a
    bias-free ``Conv3d`` whose stride equals its kernel. The result is then
    flattened to ``(-1, hidden_size)``.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.hidden_size = hidden_size

        patch_volume = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=patch_volume,
            stride=patch_volume,
            bias=False,
        )

    def __call__(self, hidden_states: mx.array) -> mx.array:
        patches = hidden_states.reshape(
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.patch_size,
            self.patch_size,
        )
        # (N, C, T, P, P) -> (N, T, P, P, C): put the channel axis last.
        patches = mx.moveaxis(patches, 1, 4)

        embedded = self.proj(patches)
        return embedded.reshape(-1, self.hidden_size)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class PatchMerger(nn.Module):
    """Group ``spatial_merge_size**2`` consecutive patch features and project to ``dim``.

    Each patch feature is RMS-normalized, consecutive groups are concatenated
    by a reshape, and the result goes through Linear -> GELU -> Linear.
    """

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        merged_width = context_dim * spatial_merge_size**2
        self.hidden_size = merged_width
        self.ln_q = nn.RMSNorm(context_dim, eps=1e-6)
        # Kept as a list: MLX names list-held submodules by index (mlp.0,
        # mlp.2), which checkpoint keys presumably rely on.
        self.mlp = [
            nn.Linear(merged_width, merged_width),
            nn.GELU(),
            nn.Linear(merged_width, dim),
        ]

    def __call__(self, x: mx.array) -> mx.array:
        merged = self.ln_q(x).reshape(-1, self.hidden_size)
        expand, activation, project = self.mlp
        return project(activation(expand(merged)))
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class Attention(nn.Module):
    """Multi-head self-attention restricted to segments given by ``cu_seqlens``."""

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def __call__(
        self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
    ) -> mx.array:
        """Attend within each ``[cu_seqlens[i-1], cu_seqlens[i])`` segment.

        Args:
            x: Token features; ``x.shape[0]`` is the sequence length.
            cu_seqlens: Cumulative segment boundaries (array or sequence of
                ints). Tokens attend only to tokens in the same segment.
            rotary_pos_emb: Rotary angles applied to queries and keys, or
                ``None`` to skip the rotation.

        Returns:
            Projected attention output with one row per input token.
        """
        seq_length = x.shape[0]
        qkv = (
            self.qkv(x).reshape(seq_length, 3, self.num_heads, -1).transpose(1, 0, 2, 3)
        )
        q, k, v = mx.split(qkv, 3)

        # rotary_pos_emb defaults to None; previously that crashed inside the
        # rotary helper.
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(mx.expand_dims(q, 0), rotary_pos_emb)[0]
            k = apply_rotary_pos_emb_vision(mx.expand_dims(k, 0), rotary_pos_emb)[0]

        # Additive block-diagonal mask: dtype-min everywhere, zeroed inside
        # each segment.
        attention_mask = mx.full(
            (1, seq_length, seq_length), mx.finfo(q.dtype).min, dtype=q.dtype
        )

        # Read all boundaries to the host once, instead of one scalar
        # conversion per loop iteration.
        if hasattr(cu_seqlens, "tolist"):
            boundaries = cu_seqlens.tolist()
        else:
            boundaries = list(cu_seqlens)
        boundaries = [int(b) for b in boundaries]
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            attention_mask[..., start:end, start:end] = 0

        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)

        output = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=attention_mask
        )
        output = output.transpose(0, 2, 1, 3)
        output = output.reshape(seq_length, -1)
        return self.proj(output)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class MLP(nn.Module):
    """Gated feed-forward layer computing ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim)
        self.up_proj = nn.Linear(dim, hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, dim)

    def __call__(self, x: mx.array) -> mx.array:
        gate = nn.silu(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm transformer block.

    RMSNorm -> attention is added to the residual stream, then
    RMSNorm -> gated MLP is added to it.
    """

    def __init__(self, config: VisionConfig) -> None:
        super().__init__()
        width = config.hidden_size
        self.norm1 = nn.RMSNorm(width, eps=1e-6)
        self.norm2 = nn.RMSNorm(width, eps=1e-6)

        self.attn = Attention(dim=width, num_heads=config.num_heads)
        self.mlp = MLP(dim=width, hidden_dim=config.intermediate_size)

    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
        attended = self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        residual = hidden_states + attended
        return residual + self.mlp(self.norm2(residual))
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class VisionModel(nn.Module):
|
|
196
|
+
|
|
197
|
+
def __init__(self, config: VisionConfig) -> None:
    """Build the vision encoder: patch embedding, transformer blocks, merger.

    Raises:
        ValueError: If ``config.model_type`` is not ``"qwen2_5_vl"``.
    """
    super().__init__()
    self.config = config
    self.model_type = config.model_type
    if self.model_type != "qwen2_5_vl":
        raise ValueError(f"Unsupported model type: {self.model_type}")
    self.spatial_merge_size = config.spatial_merge_size

    self.patch_embed = PatchEmbed(
        patch_size=config.patch_size,
        temporal_patch_size=config.temporal_patch_size,
        in_channels=config.in_channels,
        hidden_size=config.hidden_size,
    )

    self.window_size = config.window_size
    self.patch_size = config.patch_size
    self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
    self.fullatt_block_indexes = config.fullatt_block_indexes
    head_dim = config.hidden_size // config.num_heads
    self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

    self.blocks = [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
    # Pass the configured merge size so the merger's input width matches the
    # spatial_merge_unit grouping used by this model. Without it, the
    # PatchMerger default of 2 is silently assumed.
    self.merger = PatchMerger(
        dim=config.out_hidden_size,
        context_dim=config.hidden_size,
        spatial_merge_size=config.spatial_merge_size,
    )
|
|
223
|
+
|
|
224
|
+
def rot_pos_emb(self, grid_thw):
    """Return rotary angles for every patch of every ``(t, h, w)`` grid.

    Row and column indices are laid out so that each
    ``spatial_merge_size x spatial_merge_size`` block of patches is
    contiguous. The layout is repeated ``t`` times per grid and concatenated
    across grids. Each patch's row and column angles are looked up in a table
    of size ``max(h, w)`` over all grids, and the two lookups are flattened
    into one row.
    """
    merge = self.spatial_merge_size

    def _window_major(ids, h, w):
        # (h, w) -> (h/m, m, w/m, m) -> (h/m, w/m, m, m) -> flat
        blocked = ids.reshape(h // merge, merge, w // merge, merge)
        return mx.transpose(blocked, (0, 2, 1, 3)).flatten()

    per_grid = []
    for t, h, w in grid_thw.tolist():
        rows = mx.repeat(mx.expand_dims(mx.arange(h), 1), w, axis=1)
        cols = mx.repeat(mx.expand_dims(mx.arange(w), 0), h, axis=0)
        pairs = mx.stack(
            [_window_major(rows, h, w), _window_major(cols, h, w)], axis=-1
        )
        per_grid.append(mx.tile(pairs, (t, 1)))

    pos_ids = mx.concatenate(per_grid, axis=0)
    angle_table = self.rotary_pos_emb(mx.max(grid_thw[:, 1:]))
    return angle_table[pos_ids].reshape(pos_ids.shape[0], -1)
|
|
259
|
+
|
|
260
|
+
def get_window_index(self, grid_thw):
    """Order merged patches window by window for windowed attention.

    For each ``(t, h, w)`` grid, the merged-patch grid
    ``(h // spatial_merge_size, w // spatial_merge_size)`` is padded with a
    sentinel up to whole windows. The window side is ``window_size //
    spatial_merge_size // patch_size``; a side that already divides evenly
    still receives one full window of padding. The grid is then cut into
    windows and read out window by window, with sentinel entries dropped.

    Returns:
        A pair ``(window_index, cu_window_seqlens)``:

        - ``window_index``: merged-patch indices in window order, offset so
          that successive grids do not overlap.
        - ``cu_window_seqlens``: an ``mx.array`` of cumulative window sizes
          in un-merged patches (counts scaled by ``spatial_merge_unit``),
          starting at 0. Empty windows repeat the previous boundary.
    """
    sentinel = -100
    win = self.window_size // self.spatial_merge_size // self.patch_size

    order_chunks = []
    boundaries = [0]
    offset = 0

    for grid_t, grid_h, grid_w in grid_thw.tolist():
        rows = grid_h // self.spatial_merge_size
        cols = grid_w // self.spatial_merge_size
        n_merged = grid_t * rows * cols

        pad_h = win - rows % win
        pad_w = win - cols % win
        n_win_h = (rows + pad_h) // win
        n_win_w = (cols + pad_w) // win

        grid = mx.arange(n_merged).reshape(grid_t, rows, cols)
        grid = mx.pad(
            grid,
            ((0, 0), (0, pad_h), (0, pad_w)),
            mode="constant",
            constant_values=sentinel,
        )

        # (t, nh, win, nw, win) -> (t, nh, nw, win, win): one slab per window.
        windows = grid.reshape(grid_t, n_win_h, win, n_win_w, win)
        windows = mx.transpose(windows, (0, 1, 3, 2, 4)).reshape(
            grid_t, n_win_h * n_win_w, win, win
        )

        per_window = mx.sum(windows != sentinel, axis=(2, 3)).reshape(-1)
        flat = windows.reshape(-1)
        # mx has no boolean indexing, so find the real entries with NumPy.
        keep = np.where(flat != sentinel)[0].tolist()
        order_chunks.append(flat[keep] + offset)

        running = (
            mx.cumsum(per_window, axis=0) * self.spatial_merge_unit + boundaries[-1]
        )
        boundaries.extend(running.tolist())
        offset += int(n_merged)

    return mx.concatenate(order_chunks, axis=0), mx.array(boundaries)
|
|
326
|
+
|
|
327
|
+
def __call__(
    self,
    hidden_states: mx.array,
    grid_thw: mx.array,
    output_hidden_states: Optional[bool] = None,
) -> mx.array:
    """Run the windowed vision transformer over patch inputs.

    Args:
        hidden_states: Patch inputs, passed through ``self.patch_embed``; the
            embedded result is unpacked as 2-D ``(seq_len, dim)``.
        grid_thw: Per-item grid sizes; row ``i`` is read as ``(t, h, w)``
            from ``grid_thw[i, 0]``, ``grid_thw[i, 1]``, ``grid_thw[i, 2]``.
        output_hidden_states: If truthy, the block input and every block
            output are accumulated in ``encoder_states``.
            NOTE(review): that tuple is never returned, so this flag has no
            visible effect on the result — confirm whether callers expect it.

    Returns:
        The ``self.merger`` output, re-ordered with the inverse of
        ``window_index`` so rows are back in pre-window order.
    """
    hidden_states = self.patch_embed(hidden_states)
    rotary_pos_emb = self.rot_pos_emb(grid_thw)
    window_index, cu_window_seqlens = self.get_window_index(grid_thw)

    # A window made only of padding contributes a zero-length segment, which
    # shows up as a repeated boundary. De-duplicate on plain Python ints:
    # iterating an mx.array yields 0-d arrays, and set membership on those
    # is not a reliable value comparison. The boundaries come from a running
    # cumulative sum (non-decreasing), so keeping first occurrences is the
    # same as dropping consecutive repeats.
    unique_bounds = list(dict.fromkeys(cu_window_seqlens.tolist()))
    cu_window_seqlens = mx.array(unique_bounds, dtype=mx.int32)

    seq_len, _ = hidden_states.shape
    merge_unit = self.spatial_merge_unit
    num_units = seq_len // merge_unit

    # Group consecutive rows into spatial-merge units, permute the units into
    # window order, then flatten back. Rotary embeddings receive the same
    # permutation so each row keeps its own position encoding.
    hidden_states = hidden_states.reshape(num_units, merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :].reshape(seq_len, -1)
    rotary_pos_emb = rotary_pos_emb.reshape(num_units, merge_unit, -1)
    rotary_pos_emb = rotary_pos_emb[window_index, :, :].reshape(seq_len, -1)

    # Frame-level boundaries: each of the t temporal slices of each item is
    # one segment of h * w entries. Plain ints are used so no 0-d array is
    # passed where a count is expected, and an empty batch yields just [0].
    frame_lengths = []
    for row in grid_thw.tolist():
        t, h, w = int(row[0]), int(row[1]), int(row[2])
        frame_lengths.extend([h * w] * t)
    cu_seqlens = mx.cumsum(mx.array(frame_lengths, dtype=mx.int32), axis=0)
    cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0)

    encoder_states = (hidden_states,) if output_hidden_states else None

    for layer_num, blk in enumerate(self.blocks):
        # Blocks listed in fullatt_block_indexes get the frame-level
        # boundaries; every other block gets the window boundaries.
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
        else:
            cu_seqlens_now = cu_window_seqlens

        hidden_states = blk(
            hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb
        )

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

    hidden_states = self.merger(hidden_states)
    # argsort of a permutation is its inverse: undo the window ordering.
    reverse_indices = mx.argsort(window_index, axis=0)
    return hidden_states[reverse_indices, :]
|
|
395
|
+
|
|
396
|
+
def sanitize(self, weights):
    """Filter a checkpoint's weight dict and fix the patch-embedding layout.

    Entries whose key contains ``position_ids`` are dropped. For keys that
    contain ``patch_embed.proj.weight``, the value is kept unchanged when
    ``check_array_shape`` accepts it (presumably: already in MLX layout —
    confirm). Otherwise it is transposed from the 5-D PyTorch convolution
    layout ``[out_channels, in_channels, k0, k1, k2]`` to the channels-last
    layout ``[out_channels, k0, k1, k2, in_channels]``. All other entries are
    passed through untouched.

    Returns:
        A new dict; ``weights`` itself is not modified.
    """
    sanitized = {}
    for key, value in weights.items():
        if "position_ids" in key:
            # Unused buffers: leave them out of the result entirely.
            continue
        # check_array_shape is only consulted for the patch-embedding weight.
        needs_transpose = (
            "patch_embed.proj.weight" in key and not check_array_shape(value)
        )
        sanitized[key] = value.transpose(0, 2, 3, 4, 1) if needs_transpose else value
    return sanitized
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
from ..base import BaseModelConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class VisionConfig(BaseModelConfig):
    """Vision-encoder hyperparameters, all with default values.

    NOTE(review): field order defines the positional constructor order of
    this dataclass; do not reorder fields. Which fields the encoder actually
    reads (e.g. ``patch_size`` vs ``spatial_patch_size``, both 14 here) is
    not visible from this file — confirm before changing defaults.
    """

    model_type: str = "qwen2_vl"
    depth: int = 32  # presumably the number of encoder blocks — confirm
    embed_dim: int = 1280
    hidden_size: int = 1536
    num_heads: int = 16
    image_size: int = 384
    patch_size: int = 14
    vocab_size: int = 32000
    mlp_ratio: float = 4.0
    in_channels: int = 3
    layer_norm_eps: float = 1e-6
    spatial_patch_size: int = 14
    spatial_merge_size: int = 2  # presumably merge factor per spatial axis — confirm
    temporal_patch_size: int = 2
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclass
class TextConfig(BaseModelConfig):
    """Language-model hyperparameters.

    Required fields have no default; the rest fall back to the values below.
    ``__post_init__`` fills ``num_key_value_heads`` from
    ``num_attention_heads`` when it is ``None`` and validates
    ``rope_scaling`` when one is given.

    Raises:
        ValueError: If ``rope_scaling`` is non-empty but lacks the
            ``mrope_section`` or ``type`` key, or if its ``type`` is neither
            ``"mrope"`` nor ``"default"``.
    """

    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float
    vocab_size: int
    num_key_value_heads: Optional[int] = 8
    max_position_embeddings: Optional[int] = 40960
    rope_theta: float = 1000000.0
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    tie_word_embeddings: bool = False
    sliding_window: int = 32768
    use_sliding_window: bool = False
    use_cache: bool = True

    def __post_init__(self):
        if self.num_key_value_heads is None:
            # Unset: fall back to the number of attention heads.
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = ("mrope_section", "type")
            missing = [key for key in required_keys if key not in self.rope_scaling]
            if missing:
                raise ValueError(
                    f"rope_scaling must contain keys {sorted(required_keys)}, "
                    f"missing {missing}"
                )

            rope_type = self.rope_scaling["type"]
            if rope_type not in ("mrope", "default"):
                raise ValueError(
                    "rope_scaling type must be 'mrope' or 'default', "
                    f"got {rope_type!r}"
                )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
class ModelConfig(BaseModelConfig):
    """Top-level configuration tying the text and vision configs together.

    Holds the two sub-configs plus special-token ids and feature-selection
    settings. Field order defines the positional constructor order.
    """

    text_config: TextConfig
    vision_config: VisionConfig
    model_type: str
    ignore_index: int = -100
    image_token_id: int = 151655
    video_token_id: int = 151656
    vision_start_token_id: int = 151652
    vision_feature_select_strategy: str = "default"
    vision_feature_layer: int = -2
    vocab_size: int = 32000
    eos_token_id: Optional[List[int]] = None

    @classmethod
    def from_dict(cls, params):
        """Build a ModelConfig from a raw config dict.

        Every root-level entry except ``vision_config`` is copied into a new
        dict stored as ``params["text_config"]``, replacing any existing value.
        This makes text settings that live at the root level available there.

        NOTE: ``params`` is mutated in place, so the new ``text_config`` entry
        remains visible to the caller. Callers may rely on this; check before
        changing it.

        Only keys that match this dataclass's constructor parameters are
        forwarded to ``cls``; all other keys are ignored. Sub-config values
        are forwarded as-is, with no conversion in this method.
        """
        params["text_config"] = {
            key: value for key, value in params.items() if key != "vision_config"
        }

        accepted = inspect.signature(cls).parameters
        return cls(**{key: value for key, value in params.items() if key in accepted})
|