nexaai 1.0.19rc6__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their public registry.
Potentially problematic release: this version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/utils/manifest_utils.py +222 -15
- nexaai/utils/model_manager.py +83 -7
- nexaai/utils/model_types.py +2 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +224 -24
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,117 @@
import inspect
from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
from mlx.utils import tree_map

from .cache import QuantizedKVCache


@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):
        return cls(**{k: v for k, v in params.items() if k in inspect.signature(cls).parameters})


def create_causal_mask(
    N: int,
    offset: int = 0,
    window_size: Optional[int] = None,
    lengths: Optional[mx.array] = None,
):
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    linds = linds[:, None]
    rinds = rinds[None]
    mask = linds >= rinds
    if window_size is not None:
        mask = mask & (linds <= rinds + window_size)
    if lengths is not None:
        lengths = lengths[:, None, None, None]
        mask = mask & (rinds < lengths)
    return mask


def create_attention_mask(h: mx.array, cache: Optional[Any] = None, return_array: bool = False):
    T = h.shape[1]
    if T > 1:
        offset = 0
        window_size = None
        if cache is not None and cache[0] is not None:
            c = cache[0]
            offset = c.offset
            if hasattr(c, "max_size"):
                window_size = c.max_size
                offset = min(window_size, offset)
                return_array = return_array or offset + T > window_size
        if return_array:
            return create_causal_mask(T, offset, window_size=window_size)
        else:
            return "causal"
    else:
        mask = None
    return mask


def quantized_scaled_dot_product_attention(
    queries: mx.array,
    q_keys: tuple[mx.array, mx.array, mx.array],
    q_values: tuple[mx.array, mx.array, mx.array],
    scale: float,
    mask: Optional[mx.array],
    group_size: int = 64,
    bits: int = 8,
) -> mx.array:
    B, n_q_heads, L, D = queries.shape
    n_kv_heads = q_keys[0].shape[-3]
    n_repeats = n_q_heads // n_kv_heads

    queries *= scale

    if n_repeats > 1:
        queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D))
        q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys)
        q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values)

    scores = mx.quantized_matmul(queries, *q_keys, transpose=True, group_size=group_size, bits=bits)
    if mask is not None:
        if isinstance(mask, str):
            qL, kL = scores.shape[-2:]
            q_indices = mx.arange(kL - qL, kL)
            k_indices = mx.arange(kL)
            mask = q_indices[:, None] >= k_indices[None]
        if mask.dtype == mx.bool_:
            scores = mx.where(mask, scores, mx.finfo(scores.dtype).min)
        else:
            scores += mask
    scores = mx.softmax(scores, axis=-1, precise=True)
    out = mx.quantized_matmul(scores, *q_values, transpose=False, group_size=group_size, bits=bits)

    if n_repeats > 1:
        out = mx.reshape(out, (B, n_q_heads, L, D))

    return out


def scaled_dot_product_attention(
    queries,
    keys,
    values,
    cache,
    scale: float,
    mask: Optional[mx.array],
) -> mx.array:
    if isinstance(cache, QuantizedKVCache):
        return quantized_scaled_dot_product_attention(
            queries,
            keys,
            values,
            scale=scale,
            mask=mask,
            group_size=cache.group_size,
            bits=cache.bits,
        )
    else:
        return mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=mask)
@@ -0,0 +1,531 @@
from typing import Any, Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map, tree_unflatten


def make_prompt_cache(
    model: nn.Module,
    max_kv_size: Optional[int] = None,
) -> List[Any]:
    """
    Construct the model's cache for use when generating.

    This function will defer the cache construction to the model if it has a
    ``make_cache`` method, otherwise it will make a default KV cache.

    Args:
        model (nn.Module): The language model.
        max_kv_size (Optional[int]): If provided and the model does not have a
            ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
            size of ``max_kv_size``
    """
    if hasattr(model, "make_cache"):
        return model.make_cache()

    num_layers = len(model.layers)
    if max_kv_size is not None:
        return [RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)]
    else:
        return [KVCache() for _ in range(num_layers)]


def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
    """
    Save a pre-computed prompt cache to a file.

    Args:
        file_name (str): The ``.safetensors`` file name.
        cache (List[Any]): The model state.
        metadata (Dict[str, str]): Optional metadata to save along with model
            state.
    """
    cache_data = [c.state for c in cache]
    cache_info = [c.meta_state for c in cache]
    cache_data = dict(tree_flatten(cache_data))
    cache_classes = [type(c).__name__ for c in cache]
    cache_metadata = [cache_info, metadata, cache_classes]
    cache_metadata = dict(tree_flatten(cache_metadata))
    mx.save_safetensors(file_name, cache_data, cache_metadata)


def load_prompt_cache(file_name, return_metadata=False):
    """
    Load a prompt cache from a file.

    Args:
        file_name (str): The ``.safetensors`` file name.
        return_metadata (bool): Whether or not to return metadata.
            Default: ``False``.

    Returns:
        List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
            the metadata if requested.
    """
    arrays, cache_metadata = mx.load(file_name, return_metadata=True)
    arrays = tree_unflatten(list(arrays.items()))
    cache_metadata = tree_unflatten(list(cache_metadata.items()))
    info, metadata, classes = cache_metadata
    cache = [globals()[c]() for c in classes]
    for c, state, meta_state in zip(cache, arrays, info):
        c.state = state
        c.meta_state = meta_state
    if return_metadata:
        return cache, metadata
    return cache


def can_trim_prompt_cache(cache: List[Any]) -> bool:
    """
    Check if model's cache can be trimmed.
    """
    return all(c.is_trimmable() for c in cache)


def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
    """
    Trim the model's cache by the given number of tokens.

    This function will trim the cache if possible (in-place) and return the
    number of tokens that were trimmed.

    Args:
        cache (List[Any]): The model's cache.
        num_tokens (int): The number of tokens to trim.

    Returns:
        (int): The number of tokens that were trimmed.
    """
    if not can_trim_prompt_cache(cache) or len(cache) == 0:
        return 0
    return [c.trim(num_tokens) for c in cache][0]


class _BaseCache:
    @property
    def state(self):
        return []

    @state.setter
    def state(self, v):
        if v is not None and v:
            raise ValueError("This cache has no state but a state was set.")

    @property
    def meta_state(self):
        return ""

    @meta_state.setter
    def meta_state(self, v):
        if v is not None and v:
            raise ValueError("This cache has no meta_state but a meta_state was set.")

    def is_trimmable(self):
        return False


class ConcatenateKVCache(_BaseCache):
    """ConcatenateKVCache the simplest KV cache implementation.

    Can be used as a mock KV cache or when large blocks are being processed at
    a time in which case KVCache isn't necessarily faster. Consider using the
    KVCache with a larger step size before using this cache.
    """

    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0

    def update_and_fetch(self, keys, values):
        if self.keys is None:
            self.keys = keys
            self.values = values
        else:
            self.keys = mx.concatenate([self.keys, keys], axis=-2)
            self.values = mx.concatenate([self.values, values], axis=-2)
        self.offset = self.keys.shape[-2]

        return self.keys, self.values

    @property
    def state(self):
        return self.keys, self.values

    @state.setter
    def state(self, v):
        self.keys, self.values = v
        self.offset = self.keys.shape[-2]


class QuantizedKVCache(_BaseCache):
    def __init__(self, group_size: int = 64, bits: int = 8):
        self.keys = None
        self.values = None
        self.offset = 0
        self.step = 256
        self.group_size = group_size
        self.bits = bits

    def update_and_fetch(self, keys, values):
        B, n_kv_heads, num_steps, k_head_dim = keys.shape
        v_head_dim = values.shape[-1]
        prev = self.offset

        if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]:
            el_per_int = 8 * mx.uint32.size // self.bits
            new_steps = (self.step + num_steps - 1) // self.step * self.step
            shape = (B, n_kv_heads, new_steps)

            def init_quant(dim):
                return (
                    mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32),
                    mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
                    mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype),
                )

            def expand_quant(x):
                new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype)
                return mx.concatenate([x, new_x], axis=-2)

            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys, self.values = tree_map(
                        lambda x: x[..., :prev, :], (self.keys, self.values)
                    )

                self.keys, self.values = tree_map(expand_quant, (self.keys, self.values))
            else:
                self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim)

        self.offset += num_steps

        keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits)
        values = mx.quantize(values, group_size=self.group_size, bits=self.bits)
        for i in range(len(self.keys)):
            self.keys[i][..., prev : self.offset, :] = keys[i]
            self.values[i][..., prev : self.offset, :] = values[i]

        return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))

    @property
    def state(self):
        if self.offset == self.keys[0].shape[2]:
            return self.keys, self.values
        else:
            return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values))

    @state.setter
    def state(self, v):
        self.keys, self.values = v

    @property
    def meta_state(self):
        return tuple(map(str, (self.step, self.offset, self.group_size, self.bits)))

    @meta_state.setter
    def meta_state(self, v):
        self.step, self.offset, self.group_size, self.bits = map(int, v)

    def is_trimmable(self):
        return True

    def trim(self, n):
        n = min(self.offset, n)
        self.offset -= n
        return n


class KVCache(_BaseCache):
    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0
        self.step = 256

    def update_and_fetch(self, keys, values):
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            B, n_kv_heads, _, k_head_dim = keys.shape
            v_head_dim = values.shape[3]
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
            v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

    @property
    def state(self):
        if self.offset == self.keys.shape[2]:
            return self.keys, self.values
        else:
            return (
                self.keys[..., : self.offset, :],
                self.values[..., : self.offset, :],
            )

    @state.setter
    def state(self, v):
        self.keys, self.values = v
        self.offset = self.keys.shape[2]

    def is_trimmable(self):
        return True

    def trim(self, n):
        n = min(self.offset, n)
        self.offset -= n
        return n

    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
        quant_cache = QuantizedKVCache(group_size=group_size, bits=bits)
        quant_cache.offset = self.offset
        if self.keys is not None:
            quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits)
            quant_cache.values = mx.quantize(self.values, group_size=group_size, bits=bits)
        return quant_cache


class RotatingKVCache(_BaseCache):

    def __init__(self, max_size=None, keep=0, step=256):
        self.keep = keep
        self.keys = None
        self.values = None
        self.offset = 0
        self.max_size = max_size
        self.step = step
        self._idx = 0

    def _trim(self, trim_size, v, append=None):
        to_cat = []
        if trim_size > 0:
            to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
        else:
            to_cat = [v]
        if append is not None:
            to_cat.append(append)
        return mx.concatenate(to_cat, axis=2)

    def _temporal_order(self, v):
        """
        Rearrange the cache into temporal order, slicing off the end if unused.
        """
        if self._idx == v.shape[2]:
            return v
        elif self._idx < self.offset:
            return mx.concatenate(
                [
                    v[..., : self.keep, :],
                    v[..., self._idx :, :],
                    v[..., self.keep : self._idx, :],
                ],
                axis=2,
            )
        else:
            return v[..., : self._idx, :]

    def _update_concat(self, keys, values):
        if self.keys is None:
            self.keys = keys
            self.values = values
        else:
            # Put the keys/values in temporal order to
            # preserve context
            self.keys = self._temporal_order(self.keys)
            self.values = self._temporal_order(self.values)

            # The largest size is self.max_size + S to ensure
            # every token gets at least self.max_size context
            trim_size = self._idx - self.max_size
            self.keys = self._trim(trim_size, self.keys, keys)
            self.values = self._trim(trim_size, self.values, values)
        self.offset += keys.shape[2]
        self._idx = self.keys.shape[2]
        return self.keys, self.values

    def _update_in_place(self, keys, values):
        # May not have hit the max size yet, so potentially
        # keep growing the cache
        B, n_kv_heads, S, k_head_dim = keys.shape
        prev = self.offset
        if self.keys is None or (prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size):
            v_head_dim = values.shape[3]
            new_size = min(self.step, self.max_size - prev)
            k_shape = (B, n_kv_heads, new_size, k_head_dim)
            v_shape = (B, n_kv_heads, new_size, v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
            self._idx = prev

        # Trim if needed
        trim_size = self.keys.shape[2] - self.max_size
        if trim_size > 0:
            self.keys = self._trim(trim_size, self.keys)
            self.values = self._trim(trim_size, self.values)
            self._idx = self.max_size

        # Rotate
        if self._idx == self.max_size:
            self._idx = self.keep

        # Assign
        self.keys[..., self._idx : self._idx + S, :] = keys
        self.values[..., self._idx : self._idx + S, :] = values
        self.offset += S
        self._idx += S

        # If the buffer is not full, slice off the end
        if self.offset < self.max_size:
            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
        return self.keys, self.values

    def update_and_fetch(self, keys, values):
        if keys.shape[2] == 1:
            return self._update_in_place(keys, values)
        return self._update_concat(keys, values)

    @property
    def state(self):
        if self.offset < self.keys.shape[2]:
            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
        else:
            return self.keys, self.values

    @state.setter
    def state(self, v):
        self.keys, self.values = v

    @property
    def meta_state(self):
        return tuple(map(str, (self.keep, self.max_size, self.step, self.offset, self._idx)))

    @meta_state.setter
    def meta_state(self, v):
        self.keep, self.max_size, self.step, self.offset, self._idx = map(
            int,
            v,
        )

    def is_trimmable(self):
        return self.offset < self.max_size

    def trim(self, n):
        n = min(self.offset, n)
        self.offset -= n
        self._idx -= n
        return n

    def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache:
        raise NotImplementedError("RotatingKVCache Quantization NYI")


class MambaCache(_BaseCache):
    def __init__(self):
        self.cache = [None, None]

    def __setitem__(self, idx, value):
        self.cache[idx] = value

    def __getitem__(self, idx):
        return self.cache[idx]

    @property
    def state(self):
        return self.cache

    @state.setter
    def state(self, v):
        self.cache = v


class ChunkedKVCache(KVCache):
    def __init__(self, chunk_size=None):
        super().__init__()
        self.chunk_size = chunk_size
        self.start_position = 0

    def maybe_trim_front(self):
        # Maintain the cache below the chunk size
        if self.keys is not None and self.keys.shape[2] >= self.chunk_size:
            self.start_position += self.keys.shape[2] - self.chunk_size
            self.keys = self.keys[..., -self.chunk_size :, :]
            self.values = self.values[..., -self.chunk_size :, :]

    def update_and_fetch(self, keys, values):
        prev = self.offset - self.start_position
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            B, n_kv_heads, _, k_head_dim = keys.shape
            v_head_dim = values.shape[3]
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
            v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        end = self.offset - self.start_position
        self.keys[..., prev:end, :] = keys
        self.values[..., prev:end, :] = values
        return self.keys[..., :end, :], self.values[..., :end, :]

    def trim(self, n):
        n = min(self.offset - self.start_position, n)
        self.offset -= n
        return n

    @property
    def meta_state(self):
        return tuple(map(str, (self.chunk_size, self.start_position)))

    @meta_state.setter
    def meta_state(self, v):
        self.chunk_size, self.start_position = map(int, v)


class CacheList(KVCache):
    def __init__(self, *caches):
        self.caches = caches

    def __getitem__(self, idx):
        return self.caches[idx]

    @property
    def state(self):
        return [s for c in self.caches for s in c.state]

    @state.setter
    def state(self, v):
        state_lens = [len(c.state) for c in self.caches]
        start = 0
        for c in self.caches:
            l = len(c.state)
            c.state = v[start : start + l]
            start += l