nexaai 1.0.19rc5__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +221 -21
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,560 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import math
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Optional, Tuple
|
|
5
|
+
|
|
6
|
+
import mlx.core as mx
|
|
7
|
+
import mlx.nn as nn
|
|
8
|
+
|
|
9
|
+
from ..base import pixel_shuffle
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class VisionConfig:
    """Hyper-parameters for the Llama 4 vision encoder and its projector."""

    model_type: str
    hidden_size: int
    image_size: int
    initializer_range: float
    intermediate_size: int
    norm_eps: float
    num_attention_heads: int
    num_channels: int
    num_hidden_layers: int
    patch_size: int
    pixel_shuffle_ratio: float
    projector_dropout: float
    projector_input_dim: int
    projector_output_dim: int
    rope_theta: float
    vision_feature_layer: int
    vision_feature_select_strategy: str
    vision_output_dim: int

    @classmethod
    def from_dict(cls, params):
        """Build a config from ``params``, silently ignoring unknown keys.

        Args:
            params: Mapping of field name to value (e.g. a parsed JSON config).

        Returns:
            A new instance constructed from only those keys of ``params`` that
            are constructor parameters of ``cls``.
        """
        # Compute the accepted parameter names once rather than re-running
        # inspect.signature for every key in ``params``.
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def check_array_shape(arr):
    """Return True if ``arr`` looks like a ``(out_channels, kH, kW, _)`` array.

    The check passes only when ``arr.shape`` has exactly four entries, the two
    middle (kernel) entries are equal, and the leading entry is at least as
    large as each of them. Every other shape yields False.
    """
    dims = arr.shape
    if len(dims) != 4:
        return False

    out_ch, k_h, k_w, _ = dims
    # Square kernel, and the leading axis must dominate both kernel axes.
    return k_h == k_w and out_ch >= k_h and out_ch >= k_w
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class Llama4MultiModalProjector(nn.Module):
    """Bias-free linear map from vision_output_dim to the text hidden size."""

    def __init__(self, config):
        super().__init__()
        in_dim = config.vision_config.vision_output_dim
        out_dim = config.text_config.hidden_size
        self.linear_1 = nn.Linear(in_dim, out_dim, bias=False)

    def __call__(self, image_features):
        # A single projection; its output is returned unchanged.
        return self.linear_1(image_features)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class Llama4VisionPixelShuffleMLP(nn.Module):
    """Apply ``pixel_shuffle`` to encoded patches, then a projector-mode MLP."""

    def __init__(self, config):
        super().__init__()
        ratio = config.pixel_shuffle_ratio
        self.pixel_shuffle_ratio = ratio
        # projector_input_dim / ratio**2, truncated to int; stored for
        # reference but not read by this class.
        self.inner_dim = int(config.projector_input_dim // (ratio**2))
        self.output_dim = config.projector_output_dim
        self.mlp = Llama4VisionMLP(config, bias=False, is_projector=True)

    def __call__(self, encoded_patches: mx.array) -> mx.array:
        shuffled = pixel_shuffle(encoded_patches, self.pixel_shuffle_ratio)
        return self.mlp(shuffled)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# TODO there is a different RoPE for vision encoder, defined as below
def reshape_for_broadcast(freqs_ci: mx.array, query: mx.array):
    """Reshape ``freqs_ci`` so that it broadcasts against ``query``.

    The target shape keeps ``query``'s size on axis 1 and on its last axis;
    every other axis becomes size 1.
    """
    last_axis = query.ndim - 1
    target = []
    for axis, size in enumerate(query.shape):
        target.append(size if axis in (1, last_axis) else 1)
    return freqs_ci.reshape(*target)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def view_as_complex(x):
    """
    Convert a tensor with shape (..., 2) to a complex tensor with shape (...).

    The trailing axis is read as ``(real, imag)`` pairs.

    Args:
        x: A real tensor with last dimension of size 2.

    Returns:
        A complex tensor with one fewer dimension than the input.

    Raises:
        ValueError: If the last dimension of ``x`` is not 2.
    """
    # Validate explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with -O.
    if x.shape[-1] != 2:
        raise ValueError(f"Last dimension must be 2, got {x.shape[-1]}")

    real, imag = x[..., 0], x[..., 1]
    return real + 1j * imag
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def view_as_real(x):
    """
    Split a complex tensor of shape (...) into a real tensor of shape (..., 2).

    Args:
        x: A complex tensor.

    Returns:
        A real tensor whose new trailing axis holds ``(real, imag)``.
    """
    parts = [mx.real(x), mx.imag(x)]
    return mx.stack(parts, axis=-1)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def vision_apply_rotary_emb(
    query: mx.array,
    key: mx.array,
    freqs_ci: mx.array,
) -> Tuple[mx.array, mx.array]:
    """Rotate ``query`` and ``key`` by the complex factors in ``freqs_ci``.

    Adjacent pairs along the last axis are viewed as complex numbers (in
    float32) and multiplied by ``freqs_ci``, reshaped to broadcast against
    them. The result is flattened back from axis 3 onward, and each output is
    cast to the dtype of its corresponding input.
    """

    def _as_complex(t: mx.array) -> mx.array:
        pairs = t.astype(mx.float32).reshape(*t.shape[:-1], -1, 2)
        return view_as_complex(pairs)

    q_complex = _as_complex(query)
    k_complex = _as_complex(key)
    rotation = reshape_for_broadcast(freqs_ci=freqs_ci, query=q_complex)

    q_rotated = view_as_real(q_complex * rotation).flatten(3)
    k_rotated = view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.astype(query.dtype), k_rotated.astype(key.dtype)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class Llama4VisionAttention(nn.Module):
    """Multi-head self-attention with rotary embeddings on queries and keys.

    All four projections (q/k/v/o) have a bias. Key/value heads are not
    grouped, and scores are scaled by ``head_dim ** -0.5``.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = 1
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(
            self.embed_dim, self.num_heads * self.head_dim, bias=True
        )
        self.k_proj = nn.Linear(
            self.embed_dim, self.num_heads * self.head_dim, bias=True
        )
        self.v_proj = nn.Linear(
            self.embed_dim, self.num_heads * self.head_dim, bias=True
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.embed_dim, bias=True
        )

    def __call__(
        self,
        hidden_states: mx.array,
        freqs_ci: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[mx.array] = None,
    ):
        """Self-attend over ``hidden_states``.

        Args:
            hidden_states: Input of shape ``(B, L, D)``.
            freqs_ci: Rotary factors consumed by ``vision_apply_rotary_emb``.
            mask: Optional attention mask, forwarded to
                ``mx.fast.scaled_dot_product_attention``. ``None`` means no mask.
            cache: Accepted for signature compatibility; not used.

        Returns:
            Array of shape ``(B, L, hidden_size)`` after the output projection.
        """
        B, L, _ = hidden_states.shape

        query_states = self.q_proj(hidden_states).reshape(B, L, self.num_heads, -1)
        key_states = self.k_proj(hidden_states).reshape(B, L, self.num_heads, -1)
        value_states = self.v_proj(hidden_states).reshape(B, L, self.num_heads, -1)

        # Rotary embedding is applied in (B, L, heads, head_dim) layout,
        # before moving the head axis forward.
        query_states, key_states = vision_apply_rotary_emb(
            query_states, key_states, freqs_ci=freqs_ci
        )

        query_states = query_states.transpose(0, 2, 1, 3)
        key_states = key_states.transpose(0, 2, 1, 3)
        value_states = value_states.transpose(0, 2, 1, 3)

        # Forward the mask: it was previously accepted but silently dropped,
        # so a caller-supplied mask had no effect.
        attn_output = mx.fast.scaled_dot_product_attention(
            query_states, key_states, value_states, scale=self.scale, mask=mask
        )

        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(attn_output)
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class Llama4VisionMLP(nn.Module):
    """Two linear layers with a fast-approximation GELU in between.

    Layer sizes depend on ``is_projector``:

    * ``False``: ``hidden_size -> intermediate_size -> hidden_size``.
    * ``True``: ``fc1`` maps ``intermediate_size -> projector_input_dim`` and
      ``fc2`` maps ``projector_output_dim -> projector_output_dim``. The GELU
      is also applied to the ``fc2`` output.

    Args:
        config: Vision config supplying the layer sizes.
        bias: Whether both linear layers have a bias term.
        is_projector: Selects the projector variant described above.
    """

    def __init__(self, config, bias=True, is_projector=False):
        super().__init__()
        self.config = config
        self.activation_fn = nn.GELU(approx="fast")  # ACT2FN[config.hidden_act]
        self.is_projector = is_projector
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        if is_projector:
            # NOTE(review): fc1 emits projector_input_dim but fc2 consumes
            # projector_output_dim, so this path only composes when the two
            # are equal — confirm against the model configs.
            fc1_dims = (self.intermediate_size, config.projector_input_dim)
            fc2_dims = (config.projector_output_dim, config.projector_output_dim)
        else:
            fc1_dims = (self.hidden_size, self.intermediate_size)
            fc2_dims = (self.intermediate_size, self.hidden_size)

        self.fc1 = nn.Linear(*fc1_dims, bias=bias)
        self.fc2 = nn.Linear(*fc2_dims, bias=bias)

    def __call__(self, hidden_states: mx.array) -> mx.array:
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        if self.is_projector:
            # The projector variant also activates its output.
            return self.activation_fn(hidden_states)
        return hidden_states
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class Llama4VisionEncoderLayer(nn.Module):
    """Pre-norm transformer block: attention, then MLP, each with a residual add."""

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = Llama4VisionAttention(config)
        self.mlp = Llama4VisionMLP(config)

        self.input_layernorm = nn.LayerNorm(config.hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size)

    def __call__(
        self,
        hidden_state: mx.array,
        freqs_ci: mx.array,
        mask: Optional[mx.array] = None,
    ):
        # Attention sub-block on the normalized input, added back to the input.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_state),
            freqs_ci=freqs_ci,
            mask=mask,
        )
        after_attn = hidden_state + attn_out

        # Feed-forward sub-block, likewise normalized first and added back.
        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class Llama4VisionEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Llama4VisionEncoderLayer`].

    Args:
        config: VisionConfig
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.config = config
        self.layers = [
            Llama4VisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states: mx.array,
        freqs_ci: mx.array,  # TODO move this to an attribute instead of keeping it around
        mask: Optional[mx.array] = None,
    ):
        """Run ``hidden_states`` through every layer in order.

        The same ``freqs_ci`` and ``mask`` are passed to each layer.
        """
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_state=hidden_states,
                mask=mask,
                freqs_ci=freqs_ci,
            )

        return hidden_states
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
class Llama4UnfoldConvolution(nn.Module):
    """Split an image into patches and linearly project each flattened patch.

    Patches use kernel size and stride ``config.patch_size``, and the
    projection has no bias.
    """

    def __init__(self, config):
        super().__init__()
        kernel_size = config.patch_size
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = kernel_size
        self.stride = config.patch_size
        self.linear = nn.Linear(
            config.num_channels * kernel_size[0] * kernel_size[1],
            config.hidden_size,
            bias=False,
        )

    def _pair(self, x):
        """Convert input to a pair of values."""
        if isinstance(x, (list, tuple)):
            return tuple(x)
        return (x, x)

    def unfold(self, input_tensor):
        """
        Extract sliding local blocks from a batched input tensor (MLX implementation).

        Equivalent to PyTorch's ``nn.functional.unfold`` (im2col) with no
        padding and dilation 1. Blocks are gathered with whole-array slicing
        rather than a Python loop over every block position and kernel offset.

        Args:
            input_tensor: Input tensor of shape (B, C, H, W)

        Returns:
            Unfolded tensor of shape (B, C*kernel_height*kernel_width, L),
            where L is the number of blocks in row-major order and the second
            axis is ordered (channel, kernel row, kernel column).

        Raises:
            ValueError: If the kernel is larger than the input in either
                spatial dimension, so no block can be extracted.
        """
        kh, kw = self._pair(self.kernel_size)
        sh, sw = self._pair(self.stride)

        _, _, height, width = input_tensor.shape

        # Output grid size for padding 0 and dilation 1.
        height_out = (height - kh) // sh + 1
        width_out = (width - kw) // sw + 1
        if height_out <= 0 or width_out <= 0:
            raise ValueError(
                f"Kernel {(kh, kw)} does not fit in input of spatial size "
                f"{(height, width)}"
            )

        if (sh, sw) == (kh, kw):
            return self._unfold_non_overlapping(
                input_tensor, kh, kw, height_out, width_out
            )
        return self._unfold_strided(
            input_tensor, kh, kw, sh, sw, height_out, width_out
        )

    def _unfold_non_overlapping(self, input_tensor, kh, kw, height_out, width_out):
        """Unfold when stride equals kernel size, using only reshape/transpose."""
        batch_size, channels = input_tensor.shape[:2]
        # Drop trailing rows/columns that do not fill a whole block.
        x = input_tensor[:, :, : height_out * kh, : width_out * kw]
        # (B, C, Ho, kh, Wo, kw) -> (B, C, kh, kw, Ho, Wo)
        x = x.reshape(batch_size, channels, height_out, kh, width_out, kw)
        x = x.transpose(0, 1, 3, 5, 2, 4)
        return x.reshape(batch_size, channels * kh * kw, height_out * width_out)

    def _unfold_strided(self, input_tensor, kh, kw, sh, sw, height_out, width_out):
        """General unfold: one strided slice per kernel offset (kh*kw slices)."""
        batch_size, channels = input_tensor.shape[:2]
        h_span = sh * (height_out - 1) + 1
        w_span = sw * (width_out - 1) + 1
        # Each slice is (B, C, Ho, Wo): the pixel at offset (di, dj) of every block.
        offsets = [
            input_tensor[:, :, di : di + h_span : sh, dj : dj + w_span : sw]
            for di in range(kh)
            for dj in range(kw)
        ]
        # (B, C, kh*kw, Ho, Wo) with offsets ordered (di, dj) row-major.
        x = mx.stack(offsets, axis=2)
        return x.reshape(batch_size, channels * kh * kw, height_out * width_out)

    def __call__(self, hidden_states: mx.array) -> mx.array:
        """Map (B, C, H, W) pixels to (B, L, hidden_size) patch embeddings."""
        hidden_states = self.unfold(hidden_states)
        hidden_states = hidden_states.swapaxes(1, 2)
        return self.linear(hidden_states)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
class Llama4VisionRotaryEmbedding:
    """Precomputed 2D rotary-embedding table for the vision encoder.

    Each of the ``(image_size // patch_size) ** 2`` patch positions is mapped
    to its (x, y) coordinate in the square patch grid. The frequencies for
    both axes are built from inverse powers of ``config.rope_theta``. One
    extra trailing position is reserved for the class token, and its
    frequencies are forced to zero.

    The table is built once in ``__init__``. ``__call__`` ignores its argument
    and returns the cached table.
    """

    def __init__(self, config):
        super().__init__()
        idx = config.image_size // config.patch_size  # patches per grid side
        img_idx = mx.arange(idx**2, dtype=mx.int32).reshape(idx**2, 1)
        # Append one extra row for the class token and tag it with a negative id.
        img_idx = mx.concatenate([img_idx, img_idx[:1]], axis=0)
        img_idx[-1, -1] = -2  # ID_CLS_TOKEN
        # Grid coordinates. The class-token row yields meaningless values here,
        # but it is zeroed by the mask below.
        frequencies_x = img_idx % idx  # column (x) coordinate in the patch grid
        frequencies_y = img_idx // idx  # row (y) coordinate in the patch grid
        freq_dim = config.hidden_size // config.num_attention_heads // 2
        rope_freq = 1.0 / (
            config.rope_theta
            ** (
                mx.arange(0, freq_dim, 2, dtype=mx.float32)[: (freq_dim // 2)]
                / freq_dim
            )
        )

        # (N, 1, 1) * (1, 1, F) -> (N, 1, F), with N = idx**2 + 1 and
        # F = freq_dim // 2. Coordinates are shifted by +1 before scaling.
        freqs_x_expanded = (frequencies_x + 1)[..., None] * rope_freq[None, None, :]
        freqs_y_expanded = (frequencies_y + 1)[..., None] * rope_freq[None, None, :]

        # Interleaved repeat along the last axis: [a, b] -> [a, a, b, b].
        # mx.repeat with an explicit axis has exactly this semantics, which
        # replaces the previous hand-rolled expand/repeat/flatten helper.
        freqs_x = mx.repeat(freqs_x_expanded, 2, axis=-1)
        freqs_y = mx.repeat(freqs_y_expanded, 2, axis=-1)
        freqs = mx.concatenate([freqs_x, freqs_y], axis=-1).astype(mx.float32)[..., ::2]
        # Zero the class token's frequencies (the equivalent of masked_fill).
        mask = img_idx.reshape(-1, 1, 1) < 0
        freqs = mx.where(mask, mx.zeros_like(freqs), freqs)
        freq_cis = mx.stack([mx.cos(freqs), mx.sin(freqs)], axis=-1)
        freq_cis = view_as_complex(freq_cis)
        # Before view_as_complex the stack is (idx**2 + 1, 1, 2 * (freq_dim // 2), 2).
        # NOTE(review): assumes view_as_complex folds the trailing (cos, sin)
        # axis into complex values, giving (idx**2 + 1, 1, 2 * (freq_dim // 2)).
        self.freqs_ci = freq_cis

    def __call__(self, hidden_states):
        # The table depends only on the config; the input is intentionally ignored.
        return self.freqs_ci
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
class VisionModel(nn.Module):
    """Llama 4 vision tower.

    ``__call__`` runs the following steps in order:

    1. Unfold-convolution patch embedding.
    2. Append a class token at the end of the sequence.
    3. Add the learned positional embedding.
    4. Apply the pre-LayerNorm.
    5. Run the encoder with 2D rotary frequencies.
    6. Apply the post-LayerNorm.
    7. Drop the class token.
    8. Run the pixel-shuffle MLP adapter.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        self.model_type = config.model_type
        if self.model_type not in ["llama4", "llama4_vision_model"]:
            raise ValueError(f"Model type {self.model_type} not supported")

        # One position per patch plus one for the class token.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        # Randomly initialised here.
        # NOTE(review): presumably overwritten by checkpoint weights; confirm.
        self.class_embedding = self.scale * mx.random.normal((self.hidden_size,))
        self.positional_embedding_vlm = self.scale * mx.random.normal(
            (self.num_patches, self.hidden_size)
        )

        self.patch_embedding = Llama4UnfoldConvolution(config)

        self.rotary_embedding = Llama4VisionRotaryEmbedding(config)

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        self.model = Llama4VisionEncoder(config)
        self.vision_adapter = Llama4VisionPixelShuffleMLP(config)

    def get_input_embeddings(self):
        """Return the patch-embedding layer, the first layer applied to ``pixel_values``."""
        return self.patch_embedding

    def __call__(
        self,
        pixel_values: mx.array,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        capture_activations: Optional[bool] = True,
    ):
        """Encode a batch of images (or image tiles) into adapter embeddings.

        Args:
            pixel_values: 4D array, unpacked as
                (batch * tiles, channels, height, width).
            output_attentions: Accepted for API compatibility; unused.
            output_hidden_states: Accepted for API compatibility; unused.
            capture_activations: Accepted for API compatibility; unused.

        Returns:
            The output of ``self.vision_adapter`` applied to the
            post-LayerNorm encoder states, with the class token removed.

        Raises:
            ValueError: if ``pixel_values`` is not 4-dimensional (from the
                tuple unpacking below).
        """
        # The 4-way unpack also rejects inputs that are not 4D.
        batch_size_times_num_tiles, _, _, _ = pixel_values.shape

        # (B, num_patches, hidden_dim)
        hidden_state = self.patch_embedding(pixel_values)
        hidden_dim = hidden_state.shape[-1]

        # Append the class token at the END of the patch sequence.
        class_embedding = mx.broadcast_to(
            self.class_embedding, (batch_size_times_num_tiles, 1, hidden_dim)
        )
        hidden_state = mx.concatenate([hidden_state, class_embedding], axis=1)

        # Learned positional embedding of shape (num_patches + 1, hidden_dim),
        # broadcast over the batch axis.
        hidden_state = hidden_state + self.positional_embedding_vlm

        hidden_state = self.layernorm_pre(hidden_state)

        freqs_ci = self.rotary_embedding(pixel_values)

        hidden_state = self.model(
            hidden_state,
            mask=None,
            freqs_ci=freqs_ci,
        )

        hidden_state = self.layernorm_post(hidden_state)

        # Drop the trailing class token before projecting.
        hidden_state = hidden_state[:, :-1, :]

        # Llama4VisionPixelShuffle + MLP projects the patch embeddings.
        return self.vision_adapter(hidden_state)

    def sanitize(self, weights):
        """Return ``weights`` without any ``position_ids`` entries (unused here)."""
        return {k: v for k, v in weights.items() if "position_ids" not in k}
|