nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc9-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +14 -31
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +15 -32
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +7 -23
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +8 -24
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/RECORD +11 -200
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +0 -122
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +0 -25
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +0 -195
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +0 -151
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +0 -81
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +0 -1736
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +0 -333
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +0 -617
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +0 -173
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +0 -399
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +0 -244
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +0 -82
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +0 -281
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +0 -149
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +0 -764
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +0 -174
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +0 -287
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +0 -127
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +0 -330
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +0 -1
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +0 -362
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +0 -286
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +0 -306
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +0 -116
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +0 -65
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +0 -105
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +0 -100
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +0 -460
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +0 -274
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +0 -12
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +0 -3
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +0 -572
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +0 -276
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +0 -504
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +0 -68
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +0 -193
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +0 -186
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +0 -503
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +0 -202
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +0 -264
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +0 -472
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +0 -356
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +0 -366
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +0 -488
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +0 -591
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +0 -213
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +0 -315
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +0 -238
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +0 -1038
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +0 -139
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +0 -322
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +0 -629
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +0 -1022
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +0 -294
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +0 -191
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +0 -267
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +0 -175
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +0 -192
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +0 -233
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +0 -140
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +0 -220
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +0 -393
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +0 -293
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +0 -307
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +0 -143
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +0 -509
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +0 -522
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +0 -386
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +0 -138
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +0 -560
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +0 -240
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +0 -153
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +0 -259
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +0 -236
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +0 -256
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +0 -283
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +0 -416
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +0 -172
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +0 -499
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +0 -133
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +0 -465
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +0 -10
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +0 -230
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +0 -385
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +0 -557
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +0 -526
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +0 -282
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +0 -242
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +0 -21
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +0 -243
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +0 -71
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +0 -324
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +0 -229
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +0 -161
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +0 -320
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +0 -108
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +0 -168
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +0 -414
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +0 -2
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +0 -104
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +0 -490
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +0 -167
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +0 -312
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +0 -1223
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +0 -117
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +0 -531
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +0 -701
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +0 -255
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +0 -303
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +0 -407
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +0 -476
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +0 -1309
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +0 -210
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +0 -8
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +0 -62
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +0 -209
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +0 -215
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +0 -474
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +0 -39
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +0 -344
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +0 -9
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +0 -70
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +0 -296
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +0 -160
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +0 -928
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc7.dist-info → nexaai-1.0.19rc9.dist-info}/top_level.txt +0 -0
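The bulk of this change is the removal of the entire vendored `nexaai/binds/nexa_mlx/py-lib/` tree from the wheel. If you want to confirm what was dropped, a minimal sketch like the one below compares the two archives directly (it assumes both wheels have already been downloaded to the working directory; the wheel file names are illustrative, not download links). The two inline hunks that follow show the complete contents of two of the removed modules.

```python
# Sketch: list files present in the rc7 wheel but absent from the rc9 wheel.
# Assumes both wheels are available locally; the file names are illustrative.
from zipfile import ZipFile

OLD_WHL = "nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl"
NEW_WHL = "nexaai-1.0.19rc9-cp310-cp310-macosx_14_0_universal2.whl"

with ZipFile(OLD_WHL) as old, ZipFile(NEW_WHL) as new:
    removed = sorted(set(old.namelist()) - set(new.namelist()))

# The nexaai/binds/nexa_mlx/py-lib/ entries listed above should appear here.
for path in removed:
    print("-", path)
```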
@@ -1,243 +0,0 @@
-import inspect
-import math
-from dataclasses import dataclass
-from typing import Optional
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-
-
-@dataclass
-class VisionConfig:
-    model_type: str
-    num_hidden_layers: int = 24
-    hidden_size: int = 1024
-    intermediate_size: int = 4096
-    num_attention_heads: int = 16
-    image_size: int = 336
-    patch_size: int = 14
-    projection_dim: int = 768
-    vocab_size: int = 32000
-    num_channels: int = 3
-    layer_norm_eps: float = 1e-5
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
-
-
-def check_array_shape(arr):
-    shape = arr.shape
-
-    # Check if the shape has 4 dimensions
-    if len(shape) != 4:
-        return False
-
-    out_channels, kH, KW, _ = shape
-
-    # Check if out_channels is the largest, and kH and KW are the same
-    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
-        return True
-    else:
-        return False
-
-
-class Attention(nn.Module):
-    def __init__(
-        self,
-        dims: int,
-        num_heads: int,
-        query_input_dims: Optional[int] = None,
-        key_input_dims: Optional[int] = None,
-        value_input_dims: Optional[int] = None,
-        value_dims: Optional[int] = None,
-        value_output_dims: Optional[int] = None,
-        bias: bool = False,
-    ):
-        super().__init__()
-
-        if (dims % num_heads) != 0:
-            raise ValueError(
-                "The input feature dimensions should be divisible by the "
-                f"number of heads ({dims} % {num_heads}) != 0"
-            )
-
-        query_input_dims = query_input_dims or dims
-        key_input_dims = key_input_dims or dims
-        value_input_dims = value_input_dims or key_input_dims
-        value_dims = value_dims or dims
-        value_output_dims = value_output_dims or dims
-
-        self.num_heads = num_heads = num_heads
-        head_dim = dims // num_heads
-        self.scale = head_dim**-0.5
-
-        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
-        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
-        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
-        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
-
-    def __call__(self, queries, keys, values, mask=None):
-        queries = self.q_proj(queries)
-        keys = self.k_proj(keys)
-        values = self.v_proj(values)
-
-        num_heads = self.num_heads
-        B, L, D = queries.shape
-        _, S, _ = keys.shape
-        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
-
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
-        )
-        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
-
-        return self.out_proj(output)
-
-
-class MLP(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.activation_fn = nn.GELU(approx="fast")
-        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
-
-    def __call__(self, x: mx.array) -> mx.array:
-        x = self.activation_fn(self.fc1(x))
-        x = self.fc2(x)
-        return x
-
-
-class EncoderLayer(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embed_dim = config.hidden_size
-        self.self_attn = Attention(
-            config.hidden_size, config.num_attention_heads, bias=True
-        )
-        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-        self.mlp = MLP(config)
-        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-
-    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
-        y = self.layer_norm1(x)
-        y = self.self_attn(y, y, y, mask)
-        x = x + y
-        y = self.layer_norm2(x)
-        y = self.mlp(y)
-        return x + y
-
-
-class Encoder(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
-
-
-class VisionEmbeddings(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.config = config
-        self.embed_dim = config.hidden_size
-        self.image_size = config.image_size
-        self.patch_size = config.patch_size
-
-        self.class_embedding = mx.zeros((config.hidden_size,))
-
-        self.patch_embedding = nn.Conv2d(
-            in_channels=config.num_channels,
-            out_channels=self.embed_dim,
-            kernel_size=self.patch_size,
-            stride=self.patch_size,
-            bias=False,
-        )
-
-        self.num_patches = (self.image_size // self.patch_size) ** 2
-        self.num_positions = self.num_patches + 1
-        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
-
-    def __call__(self, x: mx.array) -> mx.array:
-        batch_size = x.shape[0]
-        patch_embeddings = self.patch_embedding(x)
-        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
-        embed_dim = patch_embeddings.shape[-1]
-        cls_embeddings = mx.broadcast_to(
-            self.class_embedding, (batch_size, 1, embed_dim)
-        )
-        position_ids = mx.array(np.arange(self.num_positions)[None, :])
-
-        embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
-        embeddings += self.position_embedding(position_ids)
-        return embeddings
-
-
-class ClipVisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-        self.embeddings = VisionEmbeddings(config)
-        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
-        self.encoder = Encoder(config)
-        self.post_layernorm = nn.LayerNorm(config.hidden_size)
-
-    def __call__(
-        self,
-        x: mx.array,
-        output_hidden_states: Optional[bool] = None,
-    ) -> mx.array:
-        x = self.embeddings(x)
-        x = self.pre_layrnorm(x)
-
-        encoder_states = (x,) if output_hidden_states else None
-
-        for l in self.encoder.layers:
-            x = l(x, mask=None)
-            if output_hidden_states:
-                encoder_states = encoder_states + (x,)
-
-        pooler_output = self.post_layernorm(x[:, 0, :])
-        return pooler_output, x, encoder_states
-
-
-class VisionModel(nn.Module):
-    def __init__(self, config: VisionConfig):
-        super().__init__()
-
-        self.model_type = config.model_type
-        if self.model_type != "clip_vision_model":
-            raise ValueError(f"Unsupported model type: {self.model_type}")
-
-        self.vision_model = ClipVisionModel(config)
-
-    def __call__(
-        self, x: mx.array, output_hidden_states: Optional[bool] = None
-    ) -> mx.array:
-        return self.vision_model(x, output_hidden_states)
-
-    def sanitize(self, weights):
-        sanitized_weights = {}
-        for k, v in weights.items():
-            if "position_ids" in k:
-                # Remove unused position_ids
-                continue
-            elif "patch_embedding.weight" in k:
-                # PyTorch conv2d weight tensors have shape:
-                # [out_channels, in_channels, kH, KW]
-                # MLX conv2d expects the weight be of shape:
-                # [out_channels, kH, KW, in_channels]
-                if check_array_shape(v):
-                    sanitized_weights[k] = v
-                else:
-                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
-            else:
-                sanitized_weights[k] = v
-
-        return sanitized_weights
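The `sanitize` hook at the end of the removed vision module exists because PyTorch and MLX lay out `Conv2d` weights differently, which is what `check_array_shape` is probing for. A minimal, self-contained sketch of that conversion (assuming `mlx` is installed; the shapes are the illustrative defaults from `VisionConfig`):

```python
# Sketch of the weight-layout conversion performed by sanitize() above.
# PyTorch Conv2d weights are [out_channels, in_channels, kH, kW]; MLX's
# nn.Conv2d expects [out_channels, kH, kW, in_channels].
import mlx.core as mx

torch_style = mx.zeros((1024, 3, 14, 14))       # layout as exported from PyTorch
mlx_style = torch_style.transpose(0, 2, 3, 1)   # layout MLX's Conv2d expects
print(mlx_style.shape)                          # (1024, 14, 14, 3)
```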
@@ -1,283 +0,0 @@
-import glob
-import inspect
-import json
-from dataclasses import dataclass
-from pathlib import Path
-from typing import List, Optional, Tuple, Union
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-
-from ..pixtral import LanguageModel
-from ..pixtral import Model as PixtralModel
-from ..pixtral import TextConfig, VisionConfig, VisionModel
-
-
-@dataclass
-class ModelConfig:
-    text_config: TextConfig
-    vision_config: VisionConfig
-    model_type: str
-    ignore_index: int = -100
-    image_token_index: int = 10
-    vision_feature_select_strategy: str = "full"
-    vision_feature_layer: int = -1
-    vocab_size: int = 32000
-    spatial_merge_size: int = 2
-    multimodal_projector_bias: bool = False
-    eos_token_id: Optional[List[int]] = None
-
-    @classmethod
-    def from_dict(cls, params):
-        return cls(
-            **{
-                k: v
-                for k, v in params.items()
-                if k in inspect.signature(cls).parameters
-            }
-        )
-
-
-def _pair(x) -> Tuple[int, int]:
-    """Convert input to a pair of values."""
-    if isinstance(x, (list, tuple)):
-        return tuple(x)
-    return (x, x)
-
-
-def unfold(
-    input: mx.array,
-    kernel_size: Union[int, Tuple[int, int], List[int]],
-    dilation: Union[int, Tuple[int, int], List[int]] = 1,
-    padding: Union[int, Tuple[int, int], List[int]] = 0,
-    stride: Union[int, Tuple[int, int], List[int]] = 1,
-) -> mx.array:
-    """
-    Extract sliding local blocks from a batched input tensor (MLX implementation).
-
-    This is equivalent to PyTorch's nn.functional.unfold or im2col operation.
-
-    Args:
-        input: Input tensor of shape (B, C, H, W)
-        kernel_size: Size of the sliding blocks
-        dilation: Controls the spacing between kernel elements
-        padding: Controls the amount of implicit padding
-        stride: Controls the stride between blocks
-
-    Returns:
-        Unfolded tensor of shape (B, C*kernel_height*kernel_width, L)
-        where L is the number of blocks
-    """
-    # Convert to pairs
-    kernel_size = _pair(kernel_size)
-    dilation = _pair(dilation)
-    padding = _pair(padding)
-    stride = _pair(stride)
-
-    # Input shape
-    batch_size, channels, height, width = input.shape
-
-    # Add padding if needed
-    if padding[0] > 0 or padding[1] > 0:
-        padding_shape = (
-            (0, 0),
-            (0, 0),
-            (padding[0], padding[0]),
-            (padding[1], padding[1]),
-        )
-        input = mx.pad(input, padding_shape)
-
-    # Calculate output dimensions
-    height_out = (
-        height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1
-    ) // stride[0] + 1
-    width_out = (
-        width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1
-    ) // stride[1] + 1
-
-    # Initialize output arrays
-    blocks = []
-
-    # Extract blocks
-    for i in range(
-        0, height + 2 * padding[0] - kernel_size[0] * dilation[0] + 1, stride[0]
-    ):
-        for j in range(
-            0, width + 2 * padding[1] - kernel_size[1] * dilation[1] + 1, stride[1]
-        ):
-            # Extract the block for all channels
-            block = []
-            for di in range(kernel_size[0]):
-                for dj in range(kernel_size[1]):
-                    h_idx = i + di * dilation[0]
-                    w_idx = j + dj * dilation[1]
-                    # Get the block for all channels and add to our list
-                    block.append(input[:, :, h_idx, w_idx])
-
-            # Stack the channel-blocks
-            block = mx.stack(block, axis=1)  # Shape: (B, k*k, C)
-            block = mx.transpose(block, [0, 2, 1])  # Shape: (B, C, k*k)
-            blocks.append(block)
-
-    # Stack all blocks together
-    result = mx.stack(blocks, axis=-1)  # Shape: (B, C, k*k, L)
-
-    # Reshape to match PyTorch's unfold output format: (B, C*k*k, L)
-    result = mx.reshape(
-        result,
-        (
-            batch_size,
-            channels * kernel_size[0] * kernel_size[1],
-            height_out * width_out,
-        ),
-    )
-
-    return result
-
-
-class Mistral3PatchMerger(nn.Module):
-    """
-    Learned merging of spatial_merge_size ** 2 patches
-    """
-
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-        self.config = config
-
-        hidden_size = config.vision_config.hidden_size
-        self.spatial_merge_size = config.spatial_merge_size
-        self.patch_size = self.config.vision_config.patch_size
-        self.merging_layer = nn.Linear(
-            hidden_size * self.spatial_merge_size**2, hidden_size, bias=False
-        )
-
-    def __call__(self, image_features: mx.array, image_sizes: mx.array) -> mx.array:
-
-        image_sizes = [
-            (image_size[0] // self.patch_size, image_size[1] // self.patch_size)
-            for image_size in image_sizes
-        ]
-
-        tokens_per_image = [h * w for h, w in image_sizes]
-        d = image_features.shape[-1]
-        image_features = image_features.astype(mx.bfloat16)
-        image_sizes = mx.array(image_sizes)
-
-        # Split the image features into chunks based on tokens_per_image
-        split_indices = []
-        current_index = 0
-        for tokens in tokens_per_image:
-            split_indices.append(current_index + tokens)
-            current_index += tokens
-
-        # Perform the split
-        chunks = mx.split(image_features, split_indices[:-1], axis=1)
-
-        permuted_tensor = []
-        for image_index, image_tokens in enumerate(chunks):
-
-            # Reshape image_tokens into a 2D grid
-            if image_tokens.shape[1] > 0:
-                h, w = image_sizes[image_index].tolist()
-
-                image_grid = image_tokens.reshape(h, w, d).transpose(2, 0, 1)[None, ...]
-
-                grid = unfold(
-                    image_grid,
-                    kernel_size=self.spatial_merge_size,
-                    stride=self.spatial_merge_size,
-                )
-                grid = grid.reshape(d * self.spatial_merge_size**2, -1).T
-                permuted_tensor.append(grid)
-
-        image_features = mx.concatenate(permuted_tensor, axis=0)
-        image_features = self.merging_layer(image_features)
-        return image_features[None, ...]
-
-
-class Mistral3MultiModalProjector(nn.Module):
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-
-        self.norm = nn.RMSNorm(config.vision_config.hidden_size)
-        self.patch_merger = Mistral3PatchMerger(config)
-
-        num_feature_layers = (
-            1
-            if isinstance(config.vision_feature_layer, int)
-            else len(config.vision_feature_layer)
-        )
-        self.linear_1 = nn.Linear(
-            config.vision_config.hidden_size * num_feature_layers,
-            config.text_config.hidden_size,
-            bias=config.multimodal_projector_bias,
-        )
-        self.gelu = nn.GELU()
-        self.linear_2 = nn.Linear(
-            config.text_config.hidden_size,
-            config.text_config.hidden_size,
-            bias=config.multimodal_projector_bias,
-        )
-
-    def __call__(self, x: mx.array, image_sizes: mx.array) -> mx.array:
-        x = self.norm(x)
-
-        x = self.patch_merger(x, image_sizes)
-        x = self.linear_1(x)
-        x = self.gelu(x)
-        x = self.linear_2(x)
-        return x
-
-
-class Model(PixtralModel):
-    def __init__(self, config: ModelConfig):
-        super().__init__(config)
-        self.config = config
-
-        self.multi_modal_projector = Mistral3MultiModalProjector(config)
-
-    def get_input_embeddings(
-        self,
-        input_ids: Optional[mx.array] = None,
-        pixel_values: Optional[mx.array] = None,
-        **kwargs,
-    ):
-        image_sizes = kwargs.get("image_sizes", None)
-
-        if pixel_values is None:
-            return self.language_model.model.embed_tokens(input_ids)
-
-        # Get the input embeddings from the language model
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # Get the output hidden states from the vision model
-        if isinstance(pixel_values, list):
-            pixel_values = mx.concatenate(
-                [mx.array(pv)[None, ...] for pv in pixel_values], axis=0
-            )
-        if pixel_values.ndim == 3:
-            pixel_values = pixel_values[None, ...]
-
-        # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding
-        # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21
-        # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85
-        *_, hidden_states = self.vision_tower(
-            pixel_values.transpose(0, 2, 3, 1),
-            output_hidden_states=True,
-        )
-        # Select the hidden states from the desired layer
-        selected_image_feature = hidden_states[self.vision_feature_layer]
-
-        # Pass image features through the multi-modal projector
-        image_features = self.multi_modal_projector(selected_image_feature, image_sizes)
-
-        # Insert special image tokens in the input_ids
-        final_inputs_embeds = self.merge_input_ids_with_image_features(
-            self.config.image_token_index, image_features, inputs_embeds, input_ids
-        )
-        return final_inputs_embeds
-
-    @property
-    def layers(self):
-        return self.language_model.model.layers