nexaai-1.0.19rc6-cp310-cp310-macosx_14_0_universal2.whl → nexaai-1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/utils/manifest_utils.py +222 -15
- nexaai/utils/model_manager.py +83 -7
- nexaai/utils/model_types.py +2 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +224 -24
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0

nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py (new file)
@@ -0,0 +1,393 @@
+import json
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import mlx.core as mx
+import numpy as np
+from PIL import Image
+from transformers import (
+    AutoImageProcessor,
+    AutoProcessor,
+    AutoTokenizer,
+    BatchFeature,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
+)
+from transformers.image_utils import ImageFeatureExtractionMixin
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+# Constants for image processing (from internvl_chat.py)
+
+IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
+IMAGENET_STD = np.array([0.229, 0.224, 0.225])
+# chat_template = get_conv_template("internvl2_5")
+chat_template = "{% for message in messages %}{{message['role'].capitalize() + ': '}}{# Render all images first #}{% for content in message['content'] | selectattr('type', 'equalto', 'image') %}{{ '<image>\n' }}{% endfor %}{# Render all text next #}{% for content in message['content'] | selectattr('type', 'equalto', 'text') %}{{ content['content'] }}{% endfor %}{{'\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:\n' }}{% endif %}"
+
+IMG_START_TOKEN = "<img>"
+IMG_END_TOKEN = "</img>"
+IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
+
+
+def build_transform(input_size):
+    """
+    Builds a transformation pipeline for images.
+
+    Args:
+        input_size (int): The target size for the image (height and width).
+
+    Returns:
+        function: A function that takes a PIL image and returns a normalized mx.array.
+    """
+    mean = mx.array(IMAGENET_MEAN)
+    std = mx.array(IMAGENET_STD)
+
+    def transform(img: Image.Image) -> mx.array:
+        # Ensure image is RGB
+        if img.mode != "RGB":
+            img = img.convert("RGB")
+
+        # Resize using PIL - BICUBIC interpolation is default in Pillow >= 9.1.0 for resize
+        # For older versions, you might need Pillow-SIMD or explicitly set
+        # resampling=Image.BICUBIC if available.
+        img = img.resize((input_size, input_size), resample=Image.Resampling.BICUBIC)
+
+        # Convert PIL image to NumPy array (H, W, C) and scale to [0, 1]
+        img_np = np.array(img).astype(np.float32) / 255.0
+
+        # Convert to MLX array and transpose to (C, H, W)
+        img_mx = mx.array(img_np).transpose(2, 0, 1)
+
+        # Normalize
+        img_mx = (img_mx - mean[:, None, None]) / std[:, None, None]
+
+        return img_mx
+
+    return transform
+
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+    """Finds the closest aspect ratio from a list of targets."""
+    best_ratio_diff = float("inf")
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            # Prioritize ratios closer to the original image area if diffs are equal
+            target_area = image_size * image_size * ratio[0] * ratio[1]
+            if abs(area - target_area) < abs(
+                area - image_size * image_size * best_ratio[0] * best_ratio[1]
+            ):
+                best_ratio = ratio
+    return best_ratio
+
+
+def dynamic_preprocess(
+    image: Image.Image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
+):
+    """
+    Preprocesses the image by splitting it into blocks based on the closest aspect ratio.
+
+    Args:
+        image (PIL.Image.Image): Input image.
+        min_num (int): Minimum number of blocks.
+        max_num (int): Maximum number of blocks.
+        image_size (int): Target size for each block.
+        use_thumbnail (bool): Whether to include a thumbnail of the original image.
+
+    Returns:
+        list[PIL.Image.Image]: A list of processed image blocks (as PIL images).
+    """
+    orig_width, orig_height = image.size
+    if orig_width == 0 or orig_height == 0:
+        # Handle potential zero dimensions
+        return []
+    aspect_ratio = orig_width / orig_height
+
+    # Calculate the possible target aspect ratios
+    target_ratios = set(
+        (i, j)
+        for n in range(min_num, max_num + 1)
+        for i in range(1, n + 1)
+        for j in range(1, n + 1)
+        if min_num <= i * j <= max_num
+    )
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # Find the closest target aspect ratio
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio, target_ratios, orig_width, orig_height, image_size
+    )
+
+    # Calculate the target dimensions for resizing
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # Resize the image to fit the target block structure
+    # Using BICUBIC resampling
+    resized_img = image.resize(
+        (target_width, target_height), resample=Image.Resampling.BICUBIC
+    )
+
+    processed_images = []
+    # Crop the resized image into blocks
+    for i in range(blocks):
+        # Calculate crop box for the i-th block
+        row_idx = i // target_aspect_ratio[0]
+        col_idx = i % target_aspect_ratio[0]
+        left = col_idx * image_size
+        top = row_idx * image_size
+        right = (col_idx + 1) * image_size
+        bottom = (row_idx + 1) * image_size
+        box = (left, top, right, bottom)
+
+        # Crop and add the block
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+
+    assert (
+        len(processed_images) == blocks
+    ), f"Expected {blocks} blocks, but got {len(processed_images)}"
+
+    # Add a thumbnail if requested and if the image was split
+    if use_thumbnail and blocks > 1:
+        thumbnail_img = image.resize(
+            (image_size, image_size), resample=Image.Resampling.BICUBIC
+        )
+        processed_images.append(thumbnail_img)
+
+    return processed_images
+
+
+class InternVLImageProcessor(ImageFeatureExtractionMixin):
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: int = 448,  # Default image size from dynamic_preprocess
+        resample=Image.Resampling.BICUBIC,
+        do_center_crop: bool = False,  # Not used in original, but standard HF param
+        crop_size=None,
+        do_rescale: bool = True,  # Original code scales by 1/255.0
+        rescale_factor: float = 1 / 255.0,
+        do_normalize: bool = True,
+        image_mean=IMAGENET_MEAN.tolist(),
+        image_std=IMAGENET_STD.tolist(),
+        do_dynamic_preprocess: bool = True,
+        dynamic_min_num: int = 1,
+        dynamic_max_num: int = 12,
+        dynamic_use_thumbnail: bool = True,
+        **kwargs,
+    ):
+        super().__init__()
+        self.do_resize = (
+            do_resize  # Although dynamic_preprocess handles resizing internally
+        )
+        self.size = size
+        self.resample = resample
+        self.do_center_crop = do_center_crop
+        self.crop_size = crop_size
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean
+        self.image_std = image_std
+        # Custom dynamic processing params
+        self.do_dynamic_preprocess = do_dynamic_preprocess
+        self.dynamic_min_num = dynamic_min_num
+        self.dynamic_max_num = dynamic_max_num
+        self.dynamic_use_thumbnail = dynamic_use_thumbnail
+
+    def preprocess(
+        self,
+        images: List[Image.Image],
+        do_dynamic_preprocess: Optional[bool] = None,
+        size: Optional[int] = None,
+        # ... other params matching __init__ ...
+        return_tensors: Optional[str] = None,
+        **kwargs,
+    ) -> List[mx.array]:
+
+        do_dynamic_preprocess = (
+            do_dynamic_preprocess
+            if do_dynamic_preprocess is not None
+            else self.do_dynamic_preprocess
+        )
+        size = size if size is not None else self.size
+        # ... handle other overrides ...
+
+        if not isinstance(images, list):
+            images = [images]
+
+        if not all(isinstance(image, Image.Image) for image in images):
+            raise ValueError("Input must be a list of PIL Images.")
+
+        processed_images_batch = []
+        for image in images:
+            # Apply dynamic preprocessing
+            if do_dynamic_preprocess:
+                processed_images = dynamic_preprocess(
+                    image,
+                    min_num=self.dynamic_min_num,
+                    max_num=self.dynamic_max_num,
+                    image_size=size,
+                    use_thumbnail=self.dynamic_use_thumbnail,
+                )
+            else:
+                # Fallback or alternative simpler preprocessing if needed
+                # e.g., simple resize + normalize
+                processed_images = [image.resize((size, size), resample=self.resample)]
+
+            # Create transform function
+            transform = build_transform(input_size=size)
+
+            # Apply transform to each image block and collect arrays
+            pixel_values_list = [transform(img) for img in processed_images]
+
+            # Stack the arrays along a new dimension (batch dimension)
+            pixel_values = mx.stack(pixel_values_list, axis=0)
+
+            processed_images_batch.append(pixel_values)
+
+        # At this point, processed_images_batch contains a list of mx arrays,
+        # each array corresponding to an input image with stacked blocks.
+
+        data = {"pixel_values": mx.array(processed_images_batch)}
+        return BatchFeature(data=data, tensor_type=None)
+
+
+class InternVLChatProcessor(ProcessorMixin):
+    attributes = ["image_processor", "tokenizer"]
+    image_processor_class = "InternVLImageProcessor"
+    tokenizer_class = (
+        "AutoTokenizer",
+        "Qwen2TokenizerFast",
+    )  # Specify possible classes
+
+    def __init__(
+        self,
+        image_processor=None,
+        tokenizer=None,
+        chat_template=chat_template,
+        **kwargs,
+    ):
+        if image_processor is None:
+            image_processor = InternVLImageProcessor(**kwargs)
+        if isinstance(tokenizer, str):
+            # Defaulting to the likely repo ID found earlier
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer, trust_remote_code=True, **kwargs
+            )
+
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+        self.num_image_token = int((448 // 14) ** 2 * (0.5**2))
+
+    def __call__(
+        self,
+        text: Union[str, List[str]] = None,
+        images: List[Image.Image] = None,
+        padding: Union[bool, str] = True,
+        truncation: bool = True,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[str] = "pt",  # Default to PyTorch tensors
+        **kwargs,
+    ):
+        processed_inputs = {}
+        if images is not None:
+            image_features = self.image_processor.preprocess(
+                images, return_tensors=return_tensors, **kwargs
+            )
+            processed_inputs.update(image_features)  # Should contain 'pixel_values'
+
+        if text is not None:
+            queries = []
+
+            if isinstance(text, str):
+                text = [text]
+
+            for idx in range(len(images)):
+                question = text[idx]
+
+                if images is not None and "<image>" not in question:
+                    question = "<image>\n" + question
+
+                num_patches = image_features["pixel_values"][idx].shape[0]
+                image_tokens = (
+                    IMG_START_TOKEN
+                    + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
+                    + IMG_END_TOKEN
+                )
+                question = question.replace("<image>", image_tokens, 1)
+                queries.append(question)
+
+            self.tokenizer.padding_side = "left"
+            text_inputs = self.tokenizer(
+                queries,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                return_tensors=return_tensors,
+                **kwargs,
+            )
+            processed_inputs.update(text_inputs)  # 'input_ids', 'attention_mask'
+
+        return processed_inputs
+
+    def batch_decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to the tokenizer's batch_decode method.
+        """
+        return self.tokenizer.batch_decode(*args, **kwargs)
+
+    def decode(self, *args, **kwargs):
+        """
+        This method forwards all its arguments to the tokenizer's decode method.
+        """
+        return self.tokenizer.decode(*args, **kwargs)
+
+    def save_pretrained(self, save_directory, **kwargs):
+        pass
+
+    @staticmethod
+    def from_pretrained(pretrained_model_name_or_path, **kwargs):
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path, **kwargs
+        )
+        image_processor = InternVLImageProcessor(**kwargs)
+        return InternVLChatProcessor(
+            image_processor=image_processor, tokenizer=tokenizer
+        )
+
+# Need save_pretrained and from_pretrained
+# save_pretrained should save both tokenizer and image_processor configs/files
+# from_pretrained should load both
+
+# Example:
+# def save_pretrained(self, save_directory, **kwargs):
+#     self.tokenizer.save_pretrained(save_directory, **kwargs)
+#     self.image_processor.save_pretrained(save_directory, **kwargs)
+
+# def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+#     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
+#     image_processor = InternVLImageProcessor.from_pretrained(pretrained_model_name_or_path, **kwargs)
+#     return cls(image_processor=image_processor, tokenizer=tokenizer)


+# Registration
+MODEL_TYPE = "internvl_chat"  # Verify this from the model's config.json
+
+AutoImageProcessor.register(
+    MODEL_TYPE, slow_image_processor_class=InternVLImageProcessor
+)
+AutoProcessor.register(MODEL_TYPE, InternVLChatProcessor)
+
+logger.info(f"Registered custom processor classes for model type '{MODEL_TYPE}'.")
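
For reference, the tile and token budget implied by the constants in processor.py above can be reproduced with a short standalone sketch (not part of the package): a 4:3 image at the default max_num=12 is split into a 4x3 grid of 448x448 tiles plus a thumbnail, and each tile expands to int((448 // 14) ** 2 * 0.5 ** 2) = 256 <IMG_CONTEXT> tokens.

# Standalone sketch, not part of the diff: reproduces the arithmetic used by
# InternVLImageProcessor / InternVLChatProcessor above for a 1600x1200 (4:3) input.
tile_size, patch_size, downsample = 448, 14, 0.5
tokens_per_tile = int((tile_size // patch_size) ** 2 * downsample**2)  # 256, matches num_image_token
grid_cols, grid_rows = 4, 3  # closest grid to a 4:3 aspect ratio with at most max_num=12 tiles
num_tiles = grid_cols * grid_rows + 1  # dynamic_use_thumbnail=True appends a thumbnail tile
print(tokens_per_tile, "tokens per tile,", tokens_per_tile * num_tiles, "<IMG_CONTEXT> tokens total")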

nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py (new file)
@@ -0,0 +1,293 @@
+import inspect
+from dataclasses import dataclass
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from ..base import interpolate
+
+
+@dataclass
+class VisionConfig:
+    model_type: str
+    hidden_size: int = 1024
+    num_attention_heads: int = 16
+    patch_size: int = 14
+    num_hidden_layers: int = 24
+    intermediate_size: int = 4096
+    image_size: int = 448
+    num_channels: int = 3
+    layer_norm_eps: float = 1e-6
+    drop_path_rate: float = 0.1
+    qkv_bias: bool = True
+    qk_normalization: bool = False
+    norm_type: str = "layer_norm"
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class Attention(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+
+        if (config.hidden_size % config.num_attention_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({config.hidden_size} % {config.num_attention_heads}) != 0"
+            )
+
+        self.dims = dims = config.hidden_size
+
+        self.num_heads = config.num_attention_heads
+        head_dim = config.hidden_size // config.num_attention_heads
+        self.scale = head_dim**-0.5
+        self.qkv_bias = config.qkv_bias
+
+        self.qkv = nn.Linear(dims, 3 * dims, bias=config.qkv_bias)
+        self.proj = nn.Linear(dims, dims)
+
+        self.qk_normalization = config.qk_normalization
+
+        if self.qk_normalization:
+            self.q_norm = nn.RMSNorm(dims, eps=config.layer_norm_eps)
+            self.k_norm = nn.RMSNorm(dims, eps=config.layer_norm_eps)
+
+    def __call__(self, x, mask=None):
+        B, L, C = x.shape
+        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.transpose(2, 0, 3, 1, 4)
+        queries, keys, values = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )  # Each has shape (B, groups, N, C//groups)
+
+        if self.qk_normalization:
+            B_, H_, N_, D_ = queries.shape
+            queries = (
+                self.q_norm(queries.transpose(0, 2, 1, 3).flatten(-2, -1))
+                .reshape(B_, N_, H_, D_)
+                .transpose(0, 2, 1, 3)
+            )
+            keys = (
+                self.k_norm(keys.transpose(0, 2, 1, 3).flatten(-2, -1))
+                .reshape(B_, N_, H_, D_)
+                .transpose(0, 2, 1, 3)
+            )
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.activation_fn = nn.GELU(approx="precise")
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.fc1(x)
+        x = self.activation_fn(x)
+        x = self.fc2(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig, drop_path_rate: float = 0.0):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.norm_type = getattr(config, "norm_type", "layer_norm")
+
+        self.attn = Attention(config)
+        self.mlp = MLP(config)
+
+        if self.norm_type == "layer_norm":
+            self.norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        elif self.norm_type == "rms_norm":
+            self.norm1 = nn.RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+            self.norm2 = nn.RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
+        else:
+            raise ValueError(f"Unsupported normalization type: {self.norm_type}")
+
+        self.ls1 = mx.ones((self.embed_dim,))
+        self.ls2 = mx.ones((self.embed_dim,))
+
+        self.drop_path1 = (
+            nn.Dropout(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        )
+        self.drop_path2 = (
+            nn.Dropout(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
+        )
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        dtype = x.dtype
+        x = x + self.drop_path1(self.attn(self.norm1(x).astype(dtype)) * self.ls1)
+
+        x = x + self.drop_path2(self.mlp(self.norm2(x).astype(dtype)) * self.ls2)
+
+        return x.astype(dtype)
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        dpr = [
+            mx.array(x)
+            for x in np.linspace(0, config.drop_path_rate, config.num_hidden_layers)
+        ]
+        self.layers = [
+            EncoderLayer(config, dpr[i]) for i in range(config.num_hidden_layers)
+        ]
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+        mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        encoder_states = (x,) if output_hidden_states else None
+        h = x
+        for l in self.layers:
+            x = l(x, mask=mask)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)
+
+            h = x
+
+        return (h, encoder_states)
+
+
+class VisionEmbeddings(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = mx.random.normal((1, 1, self.embed_dim))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=3,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+
+        self.position_embedding = mx.random.normal(
+            (1, self.num_positions, self.embed_dim)
+        )
+
+    def _get_pos_embed(self, pos_embed, H, W):
+        target_dtype = pos_embed.dtype
+        pos_embed = pos_embed.reshape(
+            1,
+            self.image_size // self.patch_size,
+            self.image_size // self.patch_size,
+            -1,
+        ).transpose(0, 3, 1, 2)
+        pos_embed = interpolate(pos_embed, (H, W))
+        pos_embed = (
+            pos_embed.reshape(1, -1, H * W).transpose(0, 2, 1).astype(target_dtype)
+        )
+        return pos_embed
+
+    def __call__(self, x: mx.array) -> mx.array:
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(x).transpose(
+            0, 3, 1, 2
+        )  # shape = [*, channel, width, height]
+        batch_size, _, height, width = patch_embeds.shape
+        patch_embeds = mx.flatten(patch_embeds, start_axis=2).transpose(0, 2, 1)
+        class_embeds = mx.broadcast_to(
+            self.class_embedding, (batch_size, 1, self.embed_dim)
+        ).astype(target_dtype)
+        embeddings = mx.concatenate([class_embeds, patch_embeds], axis=1)
+        position_embedding = mx.concatenate(
+            [
+                self.position_embedding[:, :1, :],
+                self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
+            ],
+            axis=1,
+        )
+        embeddings = embeddings + position_embedding.astype(target_dtype)
+
+        return embeddings
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        if self.model_type not in ["siglip_vision_model", "intern_vit_6b"]:
+            raise ValueError(f"Unsupported model type: {self.model_type}")
+
+        self.embeddings = VisionEmbeddings(config)
+        self.encoder = Encoder(config)
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+        x = self.embeddings(x)
+        last_hidden_state, encoder_outputs = self.encoder(
+            x=x, output_hidden_states=output_hidden_states, mask=None
+        )
+        pooler_output = last_hidden_state[:, 0, :]
+        return last_hidden_state, pooler_output, encoder_outputs[1:]
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                # Remove unused position_ids
+                continue
+            elif "patch_embedding.weight" in k:
+                # PyTorch conv2d weight tensors have shape:
+                # [out_channels, in_channels, kH, KW]
+                # MLX conv2d expects the weight be of shape:
+                # [out_channels, kH, KW, in_channels]
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
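
The only weight transformation vision.py applies at load time is the conv-layout fix in VisionModel.sanitize(). Below is a minimal standalone sketch of that conversion; it mirrors check_array_shape() rather than importing the module (the relative import of interpolate requires the full package), so treat it as illustrative only.

# Standalone sketch, not part of the diff: PyTorch stores conv2d weights as
# [out_channels, in_channels, kH, kW], while MLX expects [out_channels, kH, kW, in_channels].
import mlx.core as mx

def already_mlx_layout(w) -> bool:
    # Mirrors check_array_shape(): 4-D, square kernel, out_channels at least as large as the kernel dims.
    if w.ndim != 4:
        return False
    out_channels, kh, kw, _ = w.shape
    return out_channels >= kh and out_channels >= kw and kh == kw

torch_style = mx.zeros((1024, 3, 14, 14))  # hidden_size=1024, 3 input channels, 14x14 patches
sanitized = torch_style if already_mlx_layout(torch_style) else torch_style.transpose(0, 2, 3, 1)
print(sanitized.shape)  # (1024, 14, 14, 3): MLX layout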