nexaai 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc9__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +14 -31
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +15 -32
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +7 -23
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +8 -24
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/RECORD +11 -200
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/top_level.txt +0 -0
|
@@ -1,522 +0,0 @@
|
|
|
1
|
-
import inspect
|
|
2
|
-
from dataclasses import dataclass
|
|
3
|
-
from typing import List, Optional
|
|
4
|
-
|
|
5
|
-
import mlx.core as mx
|
|
6
|
-
import mlx.nn as nn
|
|
7
|
-
|
|
8
|
-
from ..kernels import bicubic_interpolate
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
@dataclass
class VisionConfig:
    """Hyper-parameters for the vision tower.

    When ``merge_kernel_size`` is left as ``None`` it is filled in by
    ``__post_init__`` as ``(spatial_merge_size, spatial_merge_size)``.
    """

    model_type: str = "moonvit"
    depth: int = 27
    embed_dim: int = 1152
    hidden_size: int = 1152
    num_heads: int = 16
    image_size: int = 384
    patch_size: int = 14
    vocab_size: int = 32000
    mlp_ratio: float = 4.0
    num_channels: int = 3
    layer_norm_eps: float = 1e-6
    intermediate_size: int = 4304
    init_pos_emb_height: int = 64
    init_pos_emb_width: int = 64
    spatial_patch_size: int = 14
    spatial_merge_size: int = 2
    temporal_patch_size: int = 2
    # (kernel_height, kernel_width). A list coming from a JSON config is kept as-is.
    merge_kernel_size: Optional[tuple[int, int]] = None

    def __post_init__(self):
        if self.merge_kernel_size is None:
            self.merge_kernel_size = (self.spatial_merge_size, self.spatial_merge_size)

    @classmethod
    def from_dict(cls, params):
        """Build a config from a mapping, silently dropping keys that are not fields.

        Args:
            params: mapping of field name to value (e.g. a parsed ``config.json``).

        Returns:
            A new ``VisionConfig`` instance.
        """
        # Resolve the accepted parameter names once rather than once per key.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
def check_array_shape(arr):
    """Report whether ``arr`` is 4-D with a square kernel dominated by axis 0.

    The shape is read as ``(out_channels, kH, kW, _)``. The check passes when
    ``kH == kW`` and ``out_channels`` is at least as large as both kernel sides.

    Returns:
        ``True`` if all conditions hold, otherwise ``False`` (including for any
        array that is not 4-D).
    """
    dims = arr.shape
    if len(dims) != 4:
        return False

    out_ch, k_h, k_w, _ = dims
    square_kernel = k_h == k_w
    out_dominates = out_ch >= k_h and out_ch >= k_w
    return bool(out_dominates and square_kernel)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last axis into halves ``(a, b)`` and returns ``concat(-b, a)``.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return mx.concatenate([-second, first], axis=-1)
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
def apply_rotary_pos_emb_vision(tensor, freqs) -> mx.array:
    """Apply a rotary position embedding built from the angle table ``freqs``.

    ``cos(freqs)`` and ``sin(freqs)`` are each expanded by inserting an axis at
    position 1, tiling the last axis twice, and prepending a leading axis. Then
    ``tensor * cos + rotate_half(tensor) * sin`` is returned, cast back to
    ``tensor``'s dtype.

    NOTE(review): if ``freqs`` is 2-D ``(S, F)`` the expanded tables are
    ``(1, S, 1, 2F)`` — confirm against callers.
    """

    def _broadcastable(table):
        # unsqueeze(1) -> repeat(1, 1, 2) -> [None, ...]
        table = mx.tile(mx.expand_dims(table, axis=1), (1, 1, 2))
        return mx.expand_dims(table, axis=0)

    out_dtype = tensor.dtype
    cos_table = _broadcastable(mx.cos(freqs))
    sin_table = _broadcastable(mx.sin(freqs))

    rotated = rotate_half(tensor)
    return ((tensor * cos_table) + (rotated * sin_table)).astype(out_dtype)
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
class VisionRotaryEmbedding(nn.Module):
    """Table of 1-D rotary angles ``position * inv_freq`` for ``dim`` channels."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta

    def __call__(self, seqlen: int) -> mx.array:
        """Return the outer product of positions ``0..seqlen-1`` and inverse frequencies.

        Args:
            seqlen: number of positions, either a Python int or a scalar array
                (anything exposing ``.tolist()``).

        Returns:
            Array of shape ``(seqlen, len(range(0, dim, 2)))``.
        """
        inv_freq = 1.0 / (
            self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim)
        )
        # Scalar arrays expose .tolist(); plain Python ints do not, and calling it
        # unconditionally made the documented ``int`` argument fail.
        length = seqlen.tolist() if hasattr(seqlen, "tolist") else seqlen
        seq = mx.arange(length, dtype=inv_freq.dtype)
        return mx.outer(seq, inv_freq)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
class Learnable2DInterpPosEmb(nn.Module):
    """Learnable 2-D position table, resized per image and added to the tokens.

    ``weight`` has shape ``(height, width, dim)``. For each ``(h, w)`` row of
    ``grid_hws``, the table is used as-is when it already has that size.
    Otherwise it is resized to ``(h, w)`` with ``bicubic_interpolate``. The
    per-image tables are flattened row-major, concatenated in ``grid_hws``
    order, and added to ``x``.
    """

    def __init__(
        self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
    ) -> None:
        super().__init__()
        self.height = height
        self.width = width
        # NOTE(review): stored but not consulted; bicubic_interpolate is always used.
        self.interpolation_mode = interpolation_mode
        # Initialized to ones; presumably overwritten by checkpoint weights.
        self.weight = mx.ones((height, width, dim))

    def __call__(self, x: mx.array, grid_hws: mx.array) -> mx.array:
        native_hw = tuple(self.weight.shape[:-1])
        pos_embs = []
        for shape in grid_hws.tolist():
            # ``tolist()`` yields lists while ``.shape`` is a tuple; a list never
            # equals a tuple, so compare as tuples for the fast path to trigger.
            if tuple(shape) == native_hw:
                pos_embs.append(self.weight.flatten(end_axis=1))
                continue

            resized = bicubic_interpolate(
                mx.expand_dims(self.weight.transpose(2, 0, 1), axis=0),
                size=shape,
            )
            # (1, dim, h, w) -> (h, w, dim) -> (h * w, dim)
            pos_embs.append(resized.squeeze(0).transpose(1, 2, 0).flatten(end_axis=1))

        return x + mx.concatenate(pos_embs).astype(x.dtype)
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
class PatchEmbed(nn.Module):
    """Project image patches to ``embed_dim`` and add learnable 2-D position embeddings.

    Args:
        patch_size: kernel size and stride of the projection convolution.
        num_channels: number of input image channels.
        embed_dim: output feature size per patch.
        init_pos_emb_height: height of the learnable position table.
        init_pos_emb_width: width of the learnable position table. When ``None``
            (the default) the table is square with side ``init_pos_emb_height``.
    """

    def __init__(
        self,
        patch_size: int = 14,
        num_channels: int = 3,
        embed_dim: int = 1152,
        init_pos_emb_height: int = 64,
        init_pos_emb_width: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.init_pos_emb_height = init_pos_emb_height
        self.init_pos_emb_width = (
            init_pos_emb_height if init_pos_emb_width is None else init_pos_emb_width
        )

        self.proj = nn.Conv2d(
            num_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=True,
        )
        self.pos_emb = Learnable2DInterpPosEmb(
            height=init_pos_emb_height,
            width=self.init_pos_emb_width,
            dim=embed_dim,
        )

    def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> mx.array:
        # NOTE(review): assumes each input row is one channels-last patch, so the
        # stride-`patch_size` conv leaves a single spatial position per patch and
        # the reshape yields (num_patches, embed_dim) — confirm against callers.
        hidden_states = self.proj(hidden_states).swapaxes(1, 3)
        hidden_states = hidden_states.reshape(hidden_states.shape[0], -1)
        return self.pos_emb(hidden_states, grid_thw)
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
def _apply_rope_input_validation(x, freqs_cis):
    """Assert that ``x`` and ``freqs_cis`` are shape- and dtype-compatible for RoPE.

    Requirements:
      * ``x`` has exactly one more axis than ``freqs_cis``;
      * ``x``'s axes except its last two equal ``freqs_cis``'s axes except its last;
      * ``x``'s last axis is twice as long as ``freqs_cis``'s last axis;
      * ``freqs_cis`` is ``complex64``.

    Raises:
        AssertionError: carrying ``(x.shape, freqs_cis.shape)`` for the shape
        checks, or ``freqs_cis.dtype`` for the dtype check.
    """
    shapes = (x.shape, freqs_cis.shape)
    assert x.ndim == freqs_cis.ndim + 1, shapes
    assert x.shape[:-2] == freqs_cis.shape[:-1], shapes
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], shapes
    assert freqs_cis.dtype == mx.complex64, freqs_cis.dtype
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
def view_as_complex(x):
    """
    Interpret the trailing size-2 axis of ``x`` as ``(real, imag)`` pairs.

    Returns ``x[..., 0] + 1j * x[..., 1]``, which has one axis fewer than ``x``.
    """
    return x[..., 0] + 1j * x[..., 1]
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
def view_as_real(x):
    """
    Split a complex array into a new trailing axis of size 2 holding ``(real, imag)``.

    This is the inverse of ``view_as_complex``.
    """
    parts = [mx.real(x), mx.imag(x)]
    return mx.stack(parts, axis=-1)
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
def apply_rope(
    q: mx.array, k: mx.array, freqs_cis: mx.array
) -> tuple[mx.array, mx.array]:
    """Rotate queries and keys by the complex phases in ``freqs_cis``.

    Args:
        q: query array of shape ``(..., num_heads, head_dim)``.
        k: key array of shape ``(..., num_heads, head_dim)``.
        freqs_cis: complex64 array of shape ``(..., head_dim / 2)`` whose leading
            axes match those of ``q``/``k`` without the heads axis.

    Returns:
        ``(q_out, k_out)`` with the same shapes as the inputs, cast back to
        ``q.dtype`` and ``k.dtype`` respectively.
    """
    for tensor in (q, k):
        _apply_rope_input_validation(tensor, freqs_cis)

    # Broadcast one phase table over every head: (..., 1, head_dim / 2).
    phases = mx.expand_dims(freqs_cis, axis=-2)

    def _rotate(tensor):
        # Pair adjacent channels as (real, imag), rotate in float32, flatten back.
        pairs = tensor.astype(mx.float32).reshape(*tensor.shape[:-1], -1, 2)
        return view_as_real(view_as_complex(pairs) * phases).flatten(-2)

    return _rotate(q).astype(q.dtype), _rotate(k).astype(k.dtype)
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
class Attention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences.

    ``x`` is ``(seq_len, dim)`` with several segments concatenated along axis 0.
    ``cu_seqlens`` holds cumulative segment boundaries, and each token may only
    attend to tokens in its own ``[cu_seqlens[i - 1], cu_seqlens[i])`` segment.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim = dim // num_heads
        # NOTE(review): kept for compatibility; __call__ divides by sqrt(head_dim).
        self.scale = head_dim**-0.5
        self.wqkv = nn.Linear(dim, dim * 3, bias=True)
        self.wo = nn.Linear(dim, dim, bias=True)

    def __call__(
        self, x: mx.array, cu_seqlens: mx.array, rotary_pos_emb: mx.array = None
    ) -> mx.array:
        seq_length = x.shape[0]
        qkv = self.wqkv(x)

        # (seq_len, 3, num_heads, head_dim) — there is no batch axis here.
        qkv = qkv.reshape(*qkv.shape[:-1], 3, self.num_heads, self.head_dim)
        q, k, v = (t.squeeze(1) for t in mx.split(qkv, 3, axis=1))

        q, k = apply_rope(q, k, rotary_pos_emb)

        # Additive block-diagonal mask: 0 inside a segment, a very large finite
        # negative value elsewhere. A +1 (inside) / 0 (outside) mask does not
        # block cross-segment attention at all when several images are packed.
        # A finite minimum (rather than -inf) keeps fully-masked rows NaN-free.
        attention_mask = mx.full(
            (1, seq_length, seq_length), mx.finfo(mx.float32).min, dtype=mx.float32
        )
        for i in range(1, len(cu_seqlens)):
            start = int(cu_seqlens[i - 1])
            end = int(cu_seqlens[i])
            attention_mask[..., start:end, start:end] = 0

        # (num_heads, seq_len, head_dim)
        q = q.transpose(1, 0, 2)
        k = k.transpose(1, 0, 2)
        v = v.transpose(1, 0, 2)

        attn_weight = q @ k.swapaxes(-2, -1) / mx.sqrt(q.shape[-1])
        attn_weight = attn_weight + attention_mask
        attn_weight = mx.softmax(attn_weight, axis=-1).astype(q.dtype)

        attn_output = attn_weight @ v
        attn_output = attn_output.transpose(1, 0, 2).reshape(seq_length, -1)
        return self.wo(attn_output)
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
class MLP(nn.Module):
    """Two-layer feed-forward network: ``fc1(GELU(fc0(x)))``.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the intermediate layer.
    """

    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.activation_fn = nn.GELU()
        self.fc0 = nn.Linear(dim, hidden_dim)
        self.fc1 = nn.Linear(hidden_dim, dim)

    def __call__(self, x: mx.array) -> mx.array:
        hidden = self.activation_fn(self.fc0(x))
        return self.fc1(hidden)
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
class Qwen2VLVisionBlock(nn.Module):
    """Pre-norm transformer block.

    Computes ``h = x + Attn(norm0(x))`` followed by ``h + MLP(norm1(h))``.
    """

    def __init__(self, config: VisionConfig) -> None:
        super().__init__()
        width = config.embed_dim
        # NOTE(review): eps is fixed at 1e-6 rather than read from config.layer_norm_eps.
        self.norm0 = nn.LayerNorm(width, eps=1e-6)
        self.norm1 = nn.LayerNorm(width, eps=1e-6)

        self.attn = Attention(dim=width, num_heads=config.num_heads)
        self.mlp = MLP(dim=width, hidden_dim=config.intermediate_size)

    def __call__(self, hidden_states, cu_seqlens, rotary_pos_emb) -> mx.array:
        attn_out = self.attn(
            self.norm0(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
        )
        hidden_states = hidden_states + attn_out
        mlp_out = self.mlp(self.norm1(hidden_states))
        return hidden_states + mlp_out
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
class Rope2DPosEmb(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    Intended usage:
    1. Create one instance. It lazily precomputes and caches the cis table on
       the first call to ``get_freqs_cis``.
    2. Before each forward pass, call ``get_freqs_cis`` with the per-image
       ``(height, width)`` grid to obtain ``freqs_cis``.
    3. Pass ``freqs_cis`` to each attention layer and apply it to queries and
       keys (e.g. with ``apply_rope``) just before each attention operation.
    The rope is shared across all attention layers and all heads.

    Refs:
    - RoFormer: https://arxiv.org/abs/2104.09864
    - VisionLLaMA: https://arxiv.org/abs/2403.00522
    - https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py

    Args:
        dim (int): usually the multi-head attention dimension; must be divisible by 4.
        max_height (int): the maximum height of the 2D grid.
        max_width (int): the maximum width of the 2D grid.
        theta_base (float): the base of the rotary frequencies.
    """

    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
        super().__init__()
        self.dim = dim
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base

        # Lazily filled by get_freqs_cis(); the leading underscore presumably
        # keeps it out of the module's parameters.
        self._freqs_cis = None

    def extra_repr(self):
        return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"

    def _precompute_freqs_cis(self) -> mx.array:
        """Calculate cis(freqs) for every position of the max-size 2D grid.

        Returns a complex array of shape ``(max_height, max_width, dim // 2)``.
        For ``i`` in ``[0, dim // 4)``, width (x) and height (y) frequencies
        are interleaved as follows::

            ret[h, w, 2*i]     = cis(w * theta_base**(-4*i/dim))   # width / x axis
            ret[h, w, 2*i + 1] = cis(h * theta_base**(-4*i/dim))   # height / y axis

        where ``cis(t) = cos(t) + i*sin(t)``.
        """
        N = self.max_height * self.max_width
        flat_pos = mx.arange(0, N, dtype=mx.float32)
        x_pos = flat_pos % self.max_width  # column index w
        y_pos = flat_pos // self.max_width  # row index h
        dim_range = mx.arange(0, self.dim, 4)[: (self.dim // 4)].astype(
            mx.float32
        )  # C/4 values 4*i
        freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
        x_freqs = mx.outer(x_pos, freqs)  # N, C/4
        y_freqs = mx.outer(y_pos, freqs)  # N, C/4

        x_cis = mx.cos(x_freqs) + 1j * mx.sin(x_freqs)  # N, C/4
        y_cis = mx.cos(y_freqs) + 1j * mx.sin(y_freqs)  # N, C/4

        # N, C/4, 2 — x (width) first, then y (height)
        freqs_cis = mx.stack([x_cis, y_cis], axis=-1)

        # max_height, max_width, C/2
        return freqs_cis.reshape(self.max_height, self.max_width, -1)

    def get_freqs_cis(self, grid_hws: mx.array) -> mx.array:
        """Gather the cis table for a batch of image grids.

        Args:
            grid_hws (mx.array): one ``(height, width)`` row per image.

        Returns:
            freqs_cis: array of shape ``(sum(height * width), dim // 2)``, with
            images in ``grid_hws`` order and positions row-major within each.

        Raises:
            AssertionError: if any height/width lies outside
            ``[1, max_height]`` / ``[1, max_width]``.
        """
        if self._freqs_cis is None:
            self._freqs_cis = self._precompute_freqs_cis()

        shapes = grid_hws.tolist()
        assert all(
            1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
        ), (
            shapes,
            self.max_height,
            self.max_width,
        )

        freqs_cis_list = []
        for h, w in shapes:
            # Top-left (h, w) window of the precomputed table, flattened row-major.
            shape_freqs = self._freqs_cis[:h, :w]
            shape_freqs = shape_freqs.reshape(-1, self.dim // 2)
            freqs_cis_list.append(shape_freqs)

        return mx.concatenate(freqs_cis_list, axis=0)
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
def patch_merger(
    x: mx.array,
    grid_hws: mx.array,
    merge_kernel_size: tuple[int, int] = (2, 2),
) -> List[mx.array]:
    """Group each image's patch tokens into non-overlapping merge windows.

    ``x`` holds the tokens of all images concatenated along axis 0. Image
    ``i`` occupies the next ``height_i * width_i`` rows, laid out row-major.
    For every image, the tokens are regrouped so that each
    ``kernel_height x kernel_width`` spatial window becomes one output row.

    Args:
        x: token array, expected to be ``(total_tokens, d_model)`` because
            each per-image slice is reshaped with ``d_model = x.shape[-1]``
            as its last axis.
        grid_hws: one row per image whose first two entries are
            ``(height, width)`` in patches.
        merge_kernel_size: ``(kernel_height, kernel_width)`` of the merge
            window. Any 2-item sequence is accepted.

    Returns:
        One array per image, of shape
        ``(new_height * new_width, kernel_height * kernel_width, d_model)``,
        where ``new_height = height // kernel_height`` and
        ``new_width = width // kernel_width``.

    Raises:
        ValueError: if an image's height or width is not divisible by the
            corresponding kernel size.
    """
    kernel_height, kernel_width = merge_kernel_size
    d_model = x.shape[-1]

    outputs = []
    offset = 0
    for grid in grid_hws.tolist():
        height, width = grid[0], grid[1]
        if height % kernel_height or width % kernel_width:
            raise ValueError(
                f"grid ({height}, {width}) is not divisible by merge kernel "
                f"({kernel_height}, {kernel_width})"
            )
        num_tokens = height * width
        seq = x[offset : offset + num_tokens]

        new_height, new_width = height // kernel_height, width // kernel_width
        # (nh, kh, nw, kw, d) -> (nh, nw, kh, kw, d): bring the kh*kw tokens
        # of each spatial window next to each other before flattening.
        windows = seq.reshape(
            new_height, kernel_height, new_width, kernel_width, d_model
        )
        windows = mx.transpose(windows, (0, 2, 1, 3, 4))
        outputs.append(
            windows.reshape(
                new_height * new_width, kernel_height * kernel_width, -1
            )
        )
        offset += num_tokens

    return outputs
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
class VisionModel(nn.Module):
|
|
434
|
-
|
|
435
|
-
def __init__(self, config: VisionConfig) -> None:
    """Build the vision tower described by ``config``.

    Raises:
        ValueError: if ``config.model_type`` is neither ``"qwen2_vl"`` nor
            ``"moonvit"``.
    """
    super().__init__()
    self.config = config
    self.model_type = config.model_type
    supported_types = ("qwen2_vl", "moonvit")
    if self.model_type not in supported_types:
        raise ValueError(f"Unsupported model type: {self.model_type}")
    self.spatial_merge_size = config.spatial_merge_size
    self.merge_kernel_size = config.merge_kernel_size

    self.patch_embed = PatchEmbed(
        patch_size=config.patch_size,
        num_channels=config.num_channels,
        embed_dim=config.embed_dim,
        init_pos_emb_height=config.init_pos_emb_height,
    )

    # Width of a single attention head; both positional helpers derive from it.
    per_head_dim = config.embed_dim // config.num_heads
    # NOTE(review): __call__ uses rope_pos_emb, not this module; kept as-is.
    self.rotary_pos_emb = VisionRotaryEmbedding(per_head_dim // 2)

    self.blocks = [Qwen2VLVisionBlock(config) for _ in range(config.depth)]
    self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=1e-6)
    # 2D RoPE table with a hard-coded 512 x 512 maximum grid.
    self.rope_pos_emb = Rope2DPosEmb(per_head_dim, 512, 512)
|
|
457
|
-
|
|
458
|
-
def __call__(
    self,
    hidden_states: mx.array,
    grid_thw: mx.array,
    output_hidden_states: Optional[bool] = None,
) -> List[mx.array]:
    """Embed patches, run the transformer blocks, and merge patch windows.

    Args:
        hidden_states: patch input, passed straight to ``self.patch_embed``.
        grid_thw: one row per image whose first two columns are
            ``(height, width)`` in patches. Despite the name, only two
            columns are read here, and ``get_freqs_cis`` unpacks each row
            into exactly two values.
        output_hidden_states: accepted for interface compatibility. The
            per-block states are not returned, so it has no effect on the
            result.

    Returns:
        The list produced by ``patch_merger``: one array per image.
    """
    hidden_states = self.patch_embed(hidden_states, grid_thw)
    rotary_pos_emb = self.rope_pos_emb.get_freqs_cis(grid_thw)

    # cu_seqlens = [0, n_0, n_0 + n_1, ...] with n_i = h_i * w_i tokens for
    # image i; presumably used by the blocks to bound attention per image.
    lengths = mx.concatenate(
        (
            mx.zeros((1,), dtype=grid_thw.dtype),
            grid_thw[:, 0] * grid_thw[:, 1],
        )
    )
    cu_seqlens = mx.cumsum(lengths.astype(mx.int32), axis=0)

    for blk in self.blocks:
        hidden_states = blk(
            hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
        )

    hidden_states = self.final_layernorm(hidden_states)

    return patch_merger(
        hidden_states, grid_thw, merge_kernel_size=self.merge_kernel_size
    )
|
|
496
|
-
|
|
497
|
-
def sanitize(self, weights):
    """Return a new dict of ``weights`` adapted to this module's parameter names.

    Rules, checked in this order per key:

    * keys containing ``"position_ids"`` are dropped;
    * keys containing ``"patch_embed.proj.weight"`` are kept as-is when
      ``check_array_shape(v)`` is truthy. Otherwise they are transposed with
      ``(0, 2, 3, 1)``, i.e. PyTorch ``[out, in, kH, kW]`` to MLX
      ``[out, kH, kW, in]``;
    * keys under ``"vision_tower.blocks"`` that contain ``"wqkv"`` or ``"wo"``
      but not ``"attn"`` get those names moved under ``attn.``;
    * everything else is copied through unchanged.

    NOTE(review): the ``"wo"`` test is a plain substring match.
    """
    out = {}
    for name, tensor in weights.items():
        # Unused buffer; nothing to load.
        if "position_ids" in name:
            continue

        if "patch_embed.proj.weight" in name:
            # PyTorch conv2d weight: [out_channels, in_channels, kH, KW]
            # MLX conv2d weight:     [out_channels, kH, KW, in_channels]
            out[name] = (
                tensor if check_array_shape(tensor) else tensor.transpose(0, 2, 3, 1)
            )
            continue

        needs_attn_prefix = (
            "vision_tower.blocks" in name
            and "attn" not in name
            and ("wqkv" in name or "wo" in name)
        )
        if needs_attn_prefix:
            name = name.replace("wqkv", "attn.wqkv").replace("wo", "attn.wo")
        out[name] = tensor

    return out
|