nexaai 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc9-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +14 -31
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +15 -32
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +7 -23
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +8 -24
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/RECORD +11 -200
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/top_level.txt +0 -0
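Beyond the rebuilt native binaries and the trimmed Qwen3-VL generation paths, the listing above shows the entire bundled nexaai/binds/nexa_mlx/py-lib/ source tree being dropped in rc9, which is why RECORD shrinks by roughly 190 entries. A minimal sketch of how one might confirm this locally by comparing the two wheel archives with the standard library; the wheel filenames are illustrative assumptions following standard wheel naming, not paths taken from this page:

import zipfile

# Illustrative filenames (standard wheel naming); adjust to wherever the wheels were downloaded.
OLD_WHEEL = "nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl"
NEW_WHEEL = "nexaai-1.0.19rc9-cp310-cp310-macosx_14_0_universal2.whl"

with zipfile.ZipFile(OLD_WHEEL) as old, zipfile.ZipFile(NEW_WHEEL) as new:
    removed = sorted(set(old.namelist()) - set(new.namelist()))

# Expect the nexaai/binds/nexa_mlx/py-lib/ tree to dominate this list.
for path in removed:
    print("removed:", path)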
@@ -1,503 +0,0 @@
-import inspect
-from dataclasses import dataclass
-from typing import Optional
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-
-
-@dataclass
-class VisionConfig:
-    model_type: str
-    hidden_size: int
-    num_attention_heads: int
-    patch_size: int
-    num_hidden_layers: int = 12
-    intermediate_size: int = 3072
-    image_size: int = 224
-    num_channels: int = 3
-    layer_norm_eps: float = 1e-6
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
-
-
-def check_array_shape(arr):
-    shape = arr.shape
-
-    # Check if the shape has 4 dimensions
-    if len(shape) != 4:
-        return False
-
-    out_channels, kH, KW, _ = shape
-
-    # Check if out_channels is the largest, and kH and KW are the same
-    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
-        return True
-    else:
-        return False
-
-
-class Attention(nn.Module):
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        query_input_dims: Optional[int] = None,
-        key_input_dims: Optional[int] = None,
-        value_input_dims: Optional[int] = None,
-        value_dims: Optional[int] = None,
-        value_output_dims: Optional[int] = None,
-        bias: bool = True,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                "The input feature dimensions should be divisible by the "
-                f"number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads
-        head_dim = dims // num_heads
-        self.scale = head_dim**-0.5
-
-        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
-        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
-        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
-
-    def __call__(self, x, mask=None):
-        queries = self.q_proj(x)
-        keys = self.k_proj(x)
-        values = self.v_proj(x)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-        return self.out_proj(output)
-
-
-class MLP(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.activation_fn = nn.GELU(approx="precise")
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
-
-    def __call__(self, x: mx.array) -> mx.array:
-        x = self.fc1(x)
-        x = self.activation_fn(x)
-        x = self.fc2(x)
-        return x
-
-
-class EncoderLayer(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embed_dim = config.hidden_size
-        self.self_attn = Attention(
-            config.hidden_size, config.num_attention_heads, bias=True
-        )
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = MLP(config)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-
-    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
-        r = self.self_attn(self.layer_norm1(x), mask)
-        h = x + r
-        r = self.mlp(self.layer_norm2(h))
-        return h + r
-
-
-class Encoder(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
-
-    def __call__(
-        self,
-        x: mx.array,
-        output_hidden_states: Optional[bool] = None,
-        mask: Optional[mx.array] = None,
-    ) -> mx.array:
-        encoder_states = (x,) if output_hidden_states else None
-        h = x
-        for l in self.layers:
-            x = l(x, mask=mask)
-            if output_hidden_states:
-                encoder_states = encoder_states + (x,)
-
-            h = x
-
-        return (h, encoder_states)
-
-
-def gaussian_blur_axis(image, sigma, axis):
-    """
-    Applies a 1D Gaussian blur along the given axis.
-    This version works for arrays with any number of dimensions.
-    """
-    radius = int(3 * sigma)
-    if radius < 1:
-        return image
-    x = mx.arange(-radius, radius + 1)
-    kernel = mx.exp(-(x**2) / (2 * sigma**2))
-    kernel = kernel / mx.sum(kernel)
-
-    # MLX doesn't have a direct apply_along_axis equivalent,
-    # so we'll implement the convolution differently based on the axis
-
-    # Helper function to apply 1D convolution along specific axis
-    def conv_1d(array, kernel, axis):
-        # Reshape kernel to broadcast along the right dimensions
-        kernel_shape = [1] * image.ndim
-        kernel_shape[axis] = len(kernel)
-        kernel_reshaped = kernel.reshape(kernel_shape)
-
-        # Pad the array
-        pad_width = [(0, 0)] * image.ndim
-        pad_width[axis] = (radius, radius)
-        padded = mx.pad(array, pad_width, mode="edge")
-
-        # Perform convolution via sliding window sum
-        result = mx.zeros_like(array)
-        slices = [slice(None)] * padded.ndim
-
-        for i in range(2 * radius + 1):
-            slices[axis] = slice(i, i + array.shape[axis])
-            result = result + padded[tuple(slices)] * kernel_reshaped
-
-        return result
-
-    return conv_1d(image, kernel, axis)
-
-
-def bilinear_interpolate(image, new_height, new_width, align_corners=False):
-    """
-    Performs bilinear interpolation on an array whose spatial dimensions are the first two.
-    It supports extra dimensions (e.g. channels or batch dimensions that have been moved to the trailing axes).
-    """
-    # image is assumed to have shape (H, W, ...) where H and W are spatial dimensions.
-    H_in, W_in = image.shape[0], image.shape[1]
-
-    # Compute sampling positions in the input image.
-    if new_height == 1:
-        row_positions = mx.array([0.0])
-    else:
-        if align_corners:
-            row_positions = mx.linspace(0, H_in - 1, new_height)
-        else:
-            row_positions = (mx.arange(new_height) + 0.5) * H_in / new_height - 0.5
-
-    if new_width == 1:
-        col_positions = mx.array([0.0])
-    else:
-        if align_corners:
-            col_positions = mx.linspace(0, W_in - 1, new_width)
-        else:
-            col_positions = (mx.arange(new_width) + 0.5) * W_in / new_width - 0.5
-
-    # Compute floor and ceil indices.
-    row_floor = mx.floor(row_positions).astype(mx.int32)
-    col_floor = mx.floor(col_positions).astype(mx.int32)
-    row_ceil = row_floor + 1
-    col_ceil = col_floor + 1
-
-    row_floor = mx.clip(row_floor, 0, H_in - 1)
-    row_ceil = mx.clip(row_ceil, 0, H_in - 1)
-    col_floor = mx.clip(col_floor, 0, W_in - 1)
-    col_ceil = mx.clip(col_ceil, 0, W_in - 1)
-
-    row_weight = row_positions - row_floor  # shape (new_height,)
-    col_weight = col_positions - col_floor  # shape (new_width,)
-
-    # Use advanced indexing for gather operations
-    # Create meshgrid for coordinates
-    row_floor_grid, col_floor_grid = mx.meshgrid(row_floor, col_floor, indexing="ij")
-    row_ceil_grid, col_floor_grid = mx.meshgrid(row_ceil, col_floor, indexing="ij")
-    row_floor_grid, col_ceil_grid = mx.meshgrid(row_floor, col_ceil, indexing="ij")
-    row_ceil_grid, col_ceil_grid = mx.meshgrid(row_ceil, col_ceil, indexing="ij")
-
-    # Gather the four surrounding pixels using take_along_axis
-    # For higher dimensional arrays, we'll need to reshape and broadcast
-    extra_dims = image.ndim - 2
-
-    def gather_pixels(row_indices, col_indices):
-        # Flatten the spatial dimensions for gathering
-        flat_indices = row_indices * W_in + col_indices
-        flat_image = mx.reshape(image, (-1,) + image.shape[2:])
-        # Gather and reshape back
-        gathered = mx.take(flat_image, flat_indices.reshape(-1), axis=0)
-        return mx.reshape(gathered, (new_height, new_width) + image.shape[2:])
-
-    top_left = gather_pixels(row_floor_grid, col_floor_grid)
-    top_right = gather_pixels(row_floor_grid, col_ceil_grid)
-    bottom_left = gather_pixels(row_ceil_grid, col_floor_grid)
-    bottom_right = gather_pixels(row_ceil_grid, col_ceil_grid)
-
-    # Expand the weights to have shape (new_height, new_width, *[1]*extra_dims)
-    r_weight = row_weight.reshape(new_height, 1, *([1] * extra_dims))
-    c_weight = col_weight.reshape(1, new_width, *([1] * extra_dims))
-
-    # Perform bilinear interpolation.
-    result = (
-        (1 - r_weight) * (1 - c_weight) * top_left
-        + (1 - r_weight) * c_weight * top_right
-        + r_weight * (1 - c_weight) * bottom_left
-        + r_weight * c_weight * bottom_right
-    )
-    return result
-
-
-def resize_bilinear(image, new_size, align_corners=False, antialias=True):
-    """
-    Resizes an image (or embedding tensor) to new_size=(new_height, new_width)
-    using bilinear interpolation with MLX.
-
-    Supports:
-      - 2D: (H, W)
-      - 3D: (H, W, C)
-      - 4D: (B, C, H, W) (assumed for typical image batches)
-    """
-    new_height, new_width = new_size
-
-    # Convert numpy arrays to MLX arrays if needed
-    if isinstance(image, np.ndarray):
-        image = mx.array(image)
-
-    if image.ndim == 2 or image.ndim == 3:
-        # Assume spatial dims are the first two.
-        resized = image
-        H_in, W_in = image.shape[:2]
-        if antialias:
-            if new_height < H_in:
-                scale_y = new_height / H_in
-                sigma_y = (1 / scale_y - 1) / 2.0  # heuristic
-                if sigma_y > 0:
-                    resized = gaussian_blur_axis(resized, sigma_y, axis=0)
-            if new_width < W_in:
-                scale_x = new_width / W_in
-                sigma_x = (1 / scale_x - 1) / 2.0
-                if sigma_x > 0:
-                    resized = gaussian_blur_axis(resized, sigma_x, axis=1)
-        resized = bilinear_interpolate(
-            resized, new_height, new_width, align_corners=align_corners
-        )
-        return resized
-
-    elif image.ndim == 4:
-        # Assume shape is (B, C, H, W) (typical PyTorch/MLX format).
-        B, C, H_in, W_in = image.shape
-        # Permute to bring spatial dims to the front: (H, W, B, C)
-        image_perm = mx.transpose(image, (2, 3, 0, 1))
-        resized = image_perm
-        if antialias:
-            if new_height < H_in:
-                scale_y = new_height / H_in
-                sigma_y = (1 / scale_y - 1) / 2.0
-                if sigma_y > 0:
-                    resized = gaussian_blur_axis(resized, sigma_y, axis=0)
-            if new_width < W_in:
-                scale_x = new_width / W_in
-                sigma_x = (1 / scale_x - 1) / 2.0
-                if sigma_x > 0:
-                    resized = gaussian_blur_axis(resized, sigma_x, axis=1)
-        resized = bilinear_interpolate(
-            resized, new_height, new_width, align_corners=align_corners
-        )
-        # Permute back to (B, C, new_height, new_width)
-        resized = mx.transpose(resized, (2, 3, 0, 1))
-        return resized
-
-    else:
-        raise ValueError("Unsupported image dimensions.")
-
-
-class VisionEmbeddings(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.image_size = config.image_size
-        self.patch_size = config.patch_size
-
-        self.patch_embedding = nn.Conv2d(
-            config.num_channels,
-            config.hidden_size,
-            kernel_size=self.patch_size,
-            stride=self.patch_size,
-        )
-
-        self.num_patches = (self.image_size // self.patch_size) ** 2
-        self.num_positions = self.num_patches
-        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-
-    @staticmethod
-    def resize_positional_embeddings(
-        positional_embeddings: mx.array,
-        spatial_shapes: mx.array,
-        max_length: int,
-    ) -> mx.array:
-        """
-        Resize positional embeddings to image-specific size and pad to a fixed size.
-
-        Args:
-            positional_embeddings (`torch.Tensor`):
-                Position embeddings of shape (height, width, embed_dim)
-            spatial_shapes (`torch.LongTensor`):
-                Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
-            max_length (`int`):
-                Maximum length of the positional embeddings to pad resized positional embeddings to
-
-        Returns:
-            `torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
-        """
-        batch_size = spatial_shapes.shape[0]
-        embed_dim = positional_embeddings.shape[-1]
-        source_dtype = positional_embeddings.dtype
-
-        resulted_positional_embeddings = mx.zeros(
-            (batch_size, max_length, embed_dim)
-        ).astype(source_dtype)
-
-        # (height, width, embed_dim) -> (1, embed_dim, height, width) for interpolation
-        positional_embeddings = positional_embeddings.transpose(2, 0, 1).reshape(
-            1, embed_dim, -1
-        )
-
-        # Upcast to float32 on CPU because antialias is not supported for bfloat16/float16 on CPU
-        if positional_embeddings.device.type == "cpu":
-            positional_embeddings = positional_embeddings.astype(mx.float32)
-
-        for i in range(batch_size):
-            # (1, dim, height, width) -> (1, dim, target_height, target_width)
-            height, width = spatial_shapes[i]
-            # Then upsample width dimension
-            resized_embeddings = resize_bilinear(
-                positional_embeddings,
-                (height, width),
-                align_corners=False,
-                antialias=True,
-            )
-
-            # (1, dim, target_height, target_width) -> (target_height * target_width, dim)
-            resized_embeddings = resized_embeddings.reshape(
-                embed_dim, height * width
-            ).transpose(0, 1)
-
-            # Cast to original dtype
-            resized_embeddings = resized_embeddings.astype(source_dtype)
-
-            resulted_positional_embeddings[i, : height * width] = resized_embeddings
-            resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
-
-        return resulted_positional_embeddings
-
-    def __call__(
-        self, x: mx.array, spatial_shapes: Optional[mx.array] = None
-    ) -> mx.array:
-        batch_size = x.shape[0]
-        patch_embeddings = self.patch_embedding(x)
-        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
-        if spatial_shapes is None:
-            position_ids = mx.array(np.arange(self.num_positions)[None, :])
-            embeddings = patch_embeddings
-            embeddings += self.position_embedding(position_ids)
-
-        else:
-            # Get positional resized and padded positional embeddings
-            positional_embeddings = self.position_embedding.weight.reshape(
-                self.position_embedding_size, self.position_embedding_size, -1
-            )
-
-            resized_positional_embeddings = self.resize_positional_embeddings(
-                positional_embeddings, spatial_shapes, max_length=x.shape[1]
-            )
-
-            # Add positional embeddings to patch embeddings
-            embeddings = patch_embeds + resized_positional_embeddings
-        return embeddings
-
-
-class SigLipVisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-
-        self.embeddings = VisionEmbeddings(config)
-        self.encoder = Encoder(config)
-        self.post_layernorm = nn.LayerNorm(config.hidden_size)
-
-    def __call__(
-        self,
-        x: mx.array,
-        spatial_shapes: mx.array,
-        output_hidden_states: Optional[bool] = None,
-    ) -> mx.array:
-        x = self.embeddings(x, spatial_shapes)
-        x = x.astype(self.embeddings.patch_embedding.weight.dtype)
-        encoder_outputs = self.encoder(
-            x=x, output_hidden_states=output_hidden_states, mask=None
-        )
-        pooler_output = self.post_layernorm(encoder_outputs[0])
-        return pooler_output, x, encoder_outputs[-1]
-
-
-class VisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.model_type = config.model_type
-        if self.model_type not in ["siglip_vision_model"]:
-            raise ValueError(f"Unsupported model type: {self.model_type}")
-
-        self.vision_model = SigLipVisionModel(config)
-
-    def __call__(
-        self,
-        x: mx.array,
-        spatial_shapes: Optional[mx.array] = None,
-        output_hidden_states: Optional[bool] = None,
-    ) -> mx.array:
-        return self.vision_model(x, spatial_shapes, output_hidden_states)
-
-    def sanitize(self, weights):
-        sanitized_weights = {}
-        for k, v in weights.items():
-            if "position_ids" in k:
-                # Remove unused position_ids
-                continue
-            elif "patch_embedding.weight" in k:
-                # PyTorch conv2d weight tensors have shape:
-                #   [out_channels, in_channels, kH, KW]
-                # MLX conv2d expects the weight be of shape:
-                #   [out_channels, kH, KW, in_channels]
-                if check_array_shape(v):
-                    sanitized_weights[k] = v
-                else:
-                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
-            else:
-                sanitized_weights[k] = v
-
-        return sanitized_weights
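The hunk above removes a self-contained MLX SigLIP-style vision tower (config, attention, encoder, bilinear resizing of positional embeddings, and weight sanitization). If this code is vendored elsewhere, the detail most easily lost is the sanitize() layout conversion; a minimal sketch of that transpose, with illustrative shapes that are not taken from the package:

import mlx.core as mx

# PyTorch Conv2d weights are [out_channels, in_channels, kH, kW];
# MLX Conv2d expects [out_channels, kH, kW, in_channels].
torch_layout = mx.zeros((1152, 3, 14, 14))       # hypothetical patch-embedding weight
mlx_layout = torch_layout.transpose(0, 2, 3, 1)  # the transpose sanitize() applies when needed
print(mlx_layout.shape)                          # (1152, 14, 14, 3)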
@@ -1,202 +0,0 @@
-import math
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional
-
-import mlx.core as mx
-import mlx.nn as nn
-from mlx_lm.models.base import create_attention_mask, scaled_dot_product_attention
-from mlx_lm.models.cache import RotatingKVCache
-from PIL import Image
-from transformers.image_processing_utils import BaseImageProcessor as ImageProcessor
-from transformers.image_processing_utils import get_size_dict
-from transformers.image_utils import ChannelDimension, PILImageResampling
-
-
-@dataclass
-class LanguageModelOutput:
-    logits: mx.array
-    cross_attention_states: Optional[List[mx.array]] = None
-    encoder_outputs: Optional[List[mx.array]] = None
-
-
-def expand2square(pil_img, background_color):
-    width, height = pil_img.size
-    if width == height:
-        return pil_img
-    elif width > height:
-        result = Image.new(pil_img.mode, (width, width), background_color)
-        result.paste(pil_img, (0, (width - height) // 2))
-        return result
-    else:
-        result = Image.new(pil_img.mode, (height, height), background_color)
-        result.paste(pil_img, ((height - width) // 2, 0))
-        return result
-
-
-def check_array_shape(arr):
-    shape = arr.shape
-
-    # Check if the shape has 4 dimensions
-    if len(shape) == 4:
-        out_channels, kH, KW, _ = shape
-        # Check if out_channels is the largest, and kH and KW are the same
-        if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
-            return True
-        else:
-            return False
-    # Check if the shape has 3 dimensions
-    elif len(shape) == 3:
-        _, kW, out_channels = shape
-        # Check if out_channels is the largest
-        if kW >= out_channels:
-            return True
-        else:
-            return False
-    else:
-        return False
-
-
-class BaseImageProcessor(ImageProcessor):
-    def __init__(
-        self,
-        image_mean=(0.5, 0.5, 0.5),
-        image_std=(0.5, 0.5, 0.5),
-        size=(384, 384),
-        crop_size: Dict[str, int] = None,
-        resample=PILImageResampling.BICUBIC,
-        rescale_factor=1 / 255,
-        data_format=ChannelDimension.FIRST,
-    ):
-        crop_size = (
-            crop_size if crop_size is not None else {"height": 384, "width": 384}
-        )
-        crop_size = get_size_dict(
-            crop_size, default_to_square=True, param_name="crop_size"
-        )
-
-        self.image_mean = image_mean
-        self.image_std = image_std
-        self.size = size
-        self.resample = resample
-        self.rescale_factor = rescale_factor
-        self.data_format = data_format
-        self.crop_size = crop_size
-
-    @abstractmethod
-    def preprocess(self, images):
-        pass
-
-
-# Add this code to visualize the chunked attention mask
-def visualize_attention_mask(mask):
-    """Visualize attention mask with symbols for better readability."""
-    if mask is None:
-        print("No mask")
-        return
-
-    seq_len = mask.shape[0]
-
-    print(" ", end="")
-    for i in range(seq_len):
-        print(f"{i:2d} ", end="")
-    print()
-
-    for i in range(seq_len):
-        print(f"Token {i:2d}: ", end="")
-        for j in range(seq_len):
-            if mask[i, j]:
-                print(" ■ ", end="")
-            else:
-                print(" ⬚ ", end="")
-        print()
-
-
-def check_activation_stats(name, tensor):
-    """Helper function to check for anomalies and log stats."""
-
-    print(f"--- Activation Stats: {name} ---")
-    # Check for NaNs/Infs
-    has_nan = mx.isnan(tensor).any()
-    has_inf = mx.isinf(tensor).any()
-    if has_nan:
-        print(f"WARNING: Found NaN in {name}")
-    if has_inf:
-        print(f"WARNING: Found Inf in {name}")
-
-    # Calculate and print stats (ensure computation happens)
-    min_val = mx.min(tensor).item()
-    max_val = mx.max(tensor).item()
-    mean_val = mx.mean(tensor).item()
-    std_val = mx.std(tensor).item()
-    print(f"  Shape: {tensor.shape}")
-    print(f"  Min: {min_val:.4f}, Max: {max_val:.4f}")
-    print(f"  Mean: {mean_val:.4f}, Std: {std_val:.4f}")
-    print("-" * (len(name) + 24))
-
-
-def pixel_shuffle(input_tensor, shuffle_ratio):
-    # input_tensor: [batch_size, num_patches, channels]
-    batch_size, num_patches, channels = input_tensor.shape
-    patch_size = int(math.sqrt(num_patches))
-
-    input_tensor = input_tensor.reshape(batch_size, patch_size, patch_size, -1)
-    batch_size, height, width, channels = input_tensor.shape
-
-    reshaped_tensor = input_tensor.reshape(
-        batch_size, height, int(width * shuffle_ratio), int(channels / shuffle_ratio)
-    )
-    reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3)
-
-    reshaped_tensor = reshaped_tensor.reshape(
-        batch_size,
-        int(height * shuffle_ratio),
-        int(width * shuffle_ratio),
-        int(channels / (shuffle_ratio**2)),
-    )
-    reshaped_tensor = reshaped_tensor.transpose(0, 2, 1, 3)
-
-    output_tensor = reshaped_tensor.reshape(batch_size, -1, reshaped_tensor.shape[-1])
-    return output_tensor
-
-
-def interpolate(pos_embed, size, mode="cubic", align_corners=False):
-    """
-    MLX implementation of PyTorch's F.interpolate with bicubic mode
-
-    Args:
-        pos_embed: MLX array with shape [B, C, H_src, W_src] or [C, H_src, W_src]
-        size: Tuple (H_dst, W_dst) - target size
-        align_corners: Boolean - whether to align corners
-
-    Returns:
-        Interpolated array with shape [B, C, H_dst, W_dst] or [C, H_dst, W_dst]
-    """
-    # Handle different input shapes
-    input_dim = pos_embed.ndim
-    original_shape = pos_embed.shape
-
-    if input_dim == 3:
-        # [C, H, W] -> [1, C, H, W]
-        pos_embed = pos_embed.reshape(1, *original_shape)
-
-    # Get source dimensions
-    h_src, w_src = pos_embed.shape[-2:]
-    h_dst, w_dst = size
-
-    # Calculate scale factors
-    scale_h = h_dst / h_src
-    scale_w = w_dst / w_src
-
-    # Create upsampler
-    upsampler = nn.Upsample(
-        scale_factor=(scale_h, scale_w), mode=mode, align_corners=align_corners
-    )
-
-    # Apply upsampling
-    result = upsampler(pos_embed)
-
-    # Return in the original dimension format
-    if input_dim == 3:
-        return result.reshape(original_shape[0], *size)
-    return result