nexaai 1.0.19rc5__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai has been flagged as potentially problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +221 -21
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py
@@ -0,0 +1,366 @@

import glob
import inspect
import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_map

from .language import LanguageModel, TextConfig
from .vision import VisionConfig, VisionModel


@dataclass
class ModelConfig:
    """Configuration class for Florence2."""

    vision_config: VisionConfig
    text_config: TextConfig
    model_type: str = "florence2"
    vocab_size: int = 50265
    max_position_embeddings: int = 1024
    pad_token_id: int = 1
    bos_token_id: int = 0
    # The source declares eos_token_id twice (int = 2, then
    # Optional[List[int]] = None); in a dataclass only the later declaration
    # takes effect, so the field is written once here.
    eos_token_id: Optional[List[int]] = None
    image_token_index: int = 0
    image_feature_source: List[str] = field(
        default_factory=lambda: ["temporal_avg_pool", "spatial_avg_pool"]
    )
    visual_temporal_embedding: Optional[dict] = field(
        default_factory=lambda: {"type": "COSINE", "max_temporal_embeddings": 100}
    )
    image_pos_embed: Optional[dict] = field(
        default_factory=lambda: {"type": "learned_abs_2d", "max_pos_embeddings": 50}
    )

    @classmethod
    def from_dict(cls, params):
        # Keep only keys that match constructor parameters so that extra
        # fields in a raw config payload do not raise a TypeError.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
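ModelConfig.from_dict keeps only the keys that match constructor parameters, so a raw config payload with extra fields passes through without a TypeError. A minimal, self-contained sketch of the same filtering pattern (illustrative only, not part of the diff):

    import inspect
    from dataclasses import dataclass

    @dataclass
    class Demo:
        vocab_size: int = 50265

    params = {"vocab_size": 51289, "unexpected_key": True}
    kept = {k: v for k, v in params.items() if k in inspect.signature(Demo).parameters}
    print(Demo(**kept))  # Demo(vocab_size=51289); unexpected_key is dropped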
def shift_tokens_right(
    input_ids: mx.array, pad_token_id: int, decoder_start_token_id: int
) -> mx.array:
    """Shift input tokens right, adding the decoder start token at the beginning."""
    shifted = mx.roll(input_ids, 1, axis=-1)
    # Write the decoder start token into the first position, then map any
    # -100 ignore-index entries (the usual label-masking convention) to pad.
    shifted[:, 0] = decoder_start_token_id
    shifted = mx.where(shifted == -100, pad_token_id, shifted)
    return shifted
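The helper rolls the labels right by one position, writes the decoder start token into column 0, and replaces any -100 ignore-index entries with the pad id. A small sketch of the expected behavior, assuming mlx is installed (illustrative, not part of the diff):

    import mlx.core as mx

    labels = mx.array([[5, -100, 6]])
    out = shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0)
    print(out)  # [[0, 5, 1]]: rolled right, start token written, -100 -> pad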
class LearnedPositionEmbedding2D(nn.Module):
    """2D learned position embeddings."""

    def __init__(self, embedding_dim: int = 256, num_pos: int = 50):
        super().__init__()
        self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2)
        self.column_embeddings = nn.Embedding(
            num_pos, embedding_dim - (embedding_dim // 2)
        )

    def __call__(self, x):
        batch_size, height, width, channels = x.shape
        width_pos = mx.arange(width)
        height_pos = mx.arange(height)

        x_emb = self.column_embeddings(width_pos)
        y_emb = self.row_embeddings(height_pos)

        pos = mx.concatenate(
            [
                mx.broadcast_to(x_emb[None, :, :], (height, width, x_emb.shape[-1])),
                mx.broadcast_to(y_emb[:, None, :], (height, width, y_emb.shape[-1])),
            ],
            axis=-1,
        )

        return mx.broadcast_to(pos[None, ...], (batch_size, height, width, channels))
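The column embeddings fill the first embedding_dim // 2 channels and the row embeddings the remainder, broadcast over the spatial grid and the batch. A usage sketch (illustrative, not part of the diff; a 7x7 feature map is assumed):

    import mlx.core as mx

    embed = LearnedPositionEmbedding2D(embedding_dim=256, num_pos=50)
    features = mx.zeros((2, 7, 7, 256))  # (batch, height, width, channels)
    print(embed(features).shape)         # (2, 7, 7, 256)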
class PositionalEmbeddingCosine1D(nn.Module):
    """
    MLX implementation of 1D cosine positional embeddings.

    Args:
        embed_dim: The dimension of the embeddings
        max_seq_len: The maximum length to precompute the positional encodings
    """

    def __init__(self, embed_dim: int = 512, max_seq_len: int = 1024) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_seq_len = max_seq_len

        # Generate position indices and dimension indices
        position = mx.arange(max_seq_len)
        dim_pos = mx.arange(0, embed_dim // 2)  # Half the dimensions for sin/cos pairs

        # Calculate frequency bands
        factor = math.log(10000)
        denominator = mx.exp(-factor * dim_pos / embed_dim)

        # Create position-frequency product matrix [max_seq_len, embed_dim//2]
        frequencies = mx.reshape(position, (-1, 1)) * denominator

        # Calculate sin and cos values [max_seq_len, embed_dim//2]
        sin_values = mx.sin(frequencies)
        cos_values = mx.cos(frequencies)

        # Interleave sin and cos values to create the final embeddings
        pos_idx_to_embed = mx.concatenate(
            [mx.expand_dims(sin_values, -1), mx.expand_dims(cos_values, -1)], axis=-1
        ).reshape(max_seq_len, embed_dim)

        # Store the positional embeddings
        self.pos_idx_to_embed = pos_idx_to_embed

    def __call__(self, seq_embeds: mx.array) -> mx.array:
        """
        Apply positional embeddings to the input sequence.

        Args:
            seq_embeds: Input sequence embeddings with shape:
                - [T, D] where T is sequence length and D is embedding dimension
                - [B, T, D] where B is batch size

        Returns:
            Positional embeddings matching the input shape
        """
        shape_len = len(seq_embeds.shape)
        assert 2 <= shape_len <= 3, "Input must be a 2D or 3D tensor"

        len_seq = seq_embeds.shape[-2]
        assert (
            len_seq <= self.max_seq_len
        ), f"Sequence length {len_seq} exceeds maximum length {self.max_seq_len}"

        # Get the relevant portion of the pre-computed embeddings
        pos_embeds = self.pos_idx_to_embed[:len_seq]

        # Add a leading batch dimension if the input is 3D
        if shape_len == 3:
            pos_embeds = mx.expand_dims(pos_embeds, 0)

        return pos_embeds
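The sin/cos table is precomputed once for max_seq_len positions; __call__ slices it to the input's sequence length and, for batched input, prepends a broadcastable batch axis. A usage sketch (illustrative, not part of the diff):

    import mlx.core as mx

    pe = PositionalEmbeddingCosine1D(embed_dim=512, max_seq_len=1024)
    seq = mx.zeros((2, 16, 512))  # (batch, seq_len, dim)
    print(pe(seq).shape)          # (1, 16, 512), broadcastable over the batch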
class Model(nn.Module):
    """Florence-2 model for conditional generation."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Initialize vision model
        self.vision_tower = VisionModel(config.vision_config)

        # Initialize language model
        self.language_model = LanguageModel(config.text_config)

        # Image projection layers
        image_dim = config.vision_config.dim_embed[-1]
        text_dim = config.text_config.d_model
        self.image_projection = mx.zeros((image_dim, text_dim))

        self.image_proj_norm = nn.LayerNorm(text_dim)

        # Position embeddings
        if config.image_pos_embed["type"] == "learned_abs_2d":
            self.image_pos_embed = LearnedPositionEmbedding2D(
                embedding_dim=image_dim,
                num_pos=config.image_pos_embed["max_pos_embeddings"],
            )
        else:
            raise NotImplementedError(
                f"Position embedding type {config.image_pos_embed['type']} not supported"
            )

        # Temporal embeddings
        if config.visual_temporal_embedding["type"] == "COSINE":
            self.visual_temporal_embed = PositionalEmbeddingCosine1D(
                embed_dim=image_dim,
                max_seq_len=config.visual_temporal_embedding["max_temporal_embeddings"],
            )
        else:
            raise NotImplementedError(
                f"Temporal embedding type {config.visual_temporal_embedding['type']} not supported"
            )

        self.image_feature_source = config.image_feature_source
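Only the two embedding types named in the default config are implemented; any other type string in image_pos_embed or visual_temporal_embedding raises NotImplementedError. With the defaults, the constructor wires the modules up as below (illustrative, not part of the diff; 768 is an assumed feature width, not a value taken from the package):

    image_dim = 768  # stands in for config.vision_config.dim_embed[-1]
    image_pos_embed = LearnedPositionEmbedding2D(embedding_dim=image_dim, num_pos=50)
    visual_temporal_embed = PositionalEmbeddingCosine1D(embed_dim=image_dim, max_seq_len=100)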
    def _encode_image(self, pixel_values, extract_features=True):
        """Encode an image with the vision model and add position embeddings."""
        T = 1  # Single frame for now

        # Get vision features
        if extract_features:
            batch_size, C, H, W = pixel_values.shape
            x = self.vision_tower(pixel_values)
        else:
            x = pixel_values
            batch_size = pixel_values.shape[0]

        # Add 2D position embeddings over the spatial grid
        if self.image_pos_embed is not None:
            # Reshape to (batch_size * T, -1, feature_dim)
            x = mx.reshape(x, (batch_size * T, -1, x.shape[-1]))
            num_tokens = x.shape[-2]
            h, w = int(num_tokens**0.5), int(num_tokens**0.5)
            assert h * w == num_tokens, "only square feature maps are supported for now"
            # Reshape to (batch_size * T, h, w, feature_dim)
            x = mx.reshape(x, (batch_size * T, h, w, x.shape[-1]))
            pos_embed = self.image_pos_embed(x)
            x = x + pos_embed
            # Reshape to (batch_size, T * h * w, feature_dim)
            x = mx.reshape(x, (batch_size, T * h * w, x.shape[-1]))

        if self.visual_temporal_embed is not None:
            # Reshape for the temporal embedding
            x_temp = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
            temporal_input = x_temp[:, :, 0]
            visual_temporal_embed = self.visual_temporal_embed(temporal_input)
            # Expand dims for broadcasting
            visual_temporal_embed = mx.expand_dims(visual_temporal_embed, axis=2)
            x = mx.reshape(x, (batch_size, T, -1, x.shape[-1])) + visual_temporal_embed

        x_feat_dict = {}

        # Spatial average pooling
        x_spatial = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
        spatial_avg_pool_x = mx.mean(x_spatial, axis=2)
        x_feat_dict["spatial_avg_pool"] = spatial_avg_pool_x

        # Temporal average pooling
        x_temporal = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
        temporal_avg_pool_x = mx.mean(x_temporal, axis=1)
        x_feat_dict["temporal_avg_pool"] = temporal_avg_pool_x

        # Last frame features
        x_last = mx.reshape(x, (batch_size, T, -1, x.shape[-1]))
        x = x_last[:, -1]
        x_feat_dict["last_frame"] = x

        # Gather features based on the source configuration
        new_x = []
        for _image_feature_source in self.image_feature_source:
            if _image_feature_source not in x_feat_dict:
                raise ValueError(
                    f"invalid image feature source: {_image_feature_source}"
                )
            new_x.append(x_feat_dict[_image_feature_source])

        # Concatenate features
        x = mx.concatenate(new_x, axis=1)

        # Final projection and normalization
        x = x @ self.image_projection
        x = self.image_proj_norm(x)

        return x
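With the default image_feature_source of ["temporal_avg_pool", "spatial_avg_pool"], temporal pooling contributes h * w tokens and spatial pooling contributes T tokens, so a single 7x7 frame yields 50 image tokens before projection. A standalone sketch of the pooling arithmetic (illustrative, not part of the diff; 768 is again an assumed width):

    import mlx.core as mx

    B, T, h, w, D = 1, 1, 7, 7, 768
    x = mx.zeros((B, T, h * w, D))
    temporal_avg = mx.mean(x, axis=1)  # (B, h*w, D): 49 tokens
    spatial_avg = mx.mean(x, axis=2)   # (B, T, D): 1 token
    print(mx.concatenate([temporal_avg, spatial_avg], axis=1).shape)  # (1, 50, 768)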
    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds=None):
        batch_size, image_token_length, _ = image_features.shape
        image_attention_mask = mx.ones((batch_size, image_token_length))

        if inputs_embeds is None:
            return image_features, image_attention_mask

        task_prefix_embeds = inputs_embeds
        task_prefix_attention_mask = mx.ones((batch_size, task_prefix_embeds.shape[1]))

        if len(task_prefix_attention_mask.shape) == 3:
            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]

        # Concatenate image features and task prefix embeddings
        inputs_embeds = mx.concatenate([image_features, task_prefix_embeds], axis=1)
        attention_mask = mx.concatenate(
            [image_attention_mask, task_prefix_attention_mask], axis=1
        )
        return inputs_embeds, attention_mask

    @property
    def layers(self):
        return self.language_model.model.layers

    def __call__(
        self,
        input_ids=None,
        pixel_values=None,
        cache=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        labels=None,
        **kwargs,
    ):
        """Forward pass."""
        attention_mask = None
        decoder_inputs_embeds = None

        # Process the image if provided
        if pixel_values is not None:
            image_features = self._encode_image(pixel_values)

            # Get input embeddings if needed
            inputs_embeds = None
            if input_ids is not None:
                inputs_embeds = self.language_model.model.shared(input_ids)

            # Merge image features with text embeddings
            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(
                image_features, inputs_embeds
            )
        else:
            inputs_embeds = None
            attention_mask = None

        # Handle decoder input IDs
        if labels is not None and decoder_input_ids is None:
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.bos_token_id
            )

        if decoder_input_ids is None and decoder_inputs_embeds is None:
            # Fall back to 0 if the config does not define a start token
            # (2 is common for many BART-style models).
            decoder_start_token_id = getattr(self.config, "decoder_start_token_id", 0)
            if decoder_start_token_id is None:
                decoder_start_token_id = 0

            decoder_input_ids = mx.array([decoder_start_token_id])[None, :]
            decoder_inputs_embeds = self.language_model.model.shared(decoder_input_ids)
            decoder_input_ids = None

        # Forward through the language model
        outputs = self.language_model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_attention_mask=decoder_attention_mask,
            cache=cache,
        )

        return outputs
    @staticmethod
    def sanitize(weights):
        # Drop "final_logits_bias" checkpoint entries (a BART-style buffer
        # that is not used here); everything else passes through unchanged.
        sanitized_weights = {}
        for k, v in weights.items():
            if "final_logits_bias" in k:
                continue
            sanitized_weights[k] = v
        return sanitized_weights
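A quick sketch of sanitize() on hypothetical checkpoint keys (illustrative, not part of the diff; the key names below are assumed, not taken from the package):

    weights = {
        "language_model.final_logits_bias": 0,
        "language_model.model.shared.weight": 0,
    }
    print(list(Model.sanitize(weights)))  # ['language_model.model.shared.weight']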