nexaai 1.0.19rc6__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/utils/manifest_utils.py +222 -15
- nexaai/utils/model_manager.py +83 -7
- nexaai/utils/model_types.py +2 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +224 -24
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,572 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import codecs
|
|
3
|
+
import contextlib
|
|
4
|
+
import functools
|
|
5
|
+
import os
|
|
6
|
+
import re
|
|
7
|
+
import time
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
|
10
|
+
|
|
11
|
+
import mlx.core as mx
|
|
12
|
+
import mlx.nn as nn
|
|
13
|
+
from mlx_lm.generate import maybe_quantize_kv_cache
|
|
14
|
+
from transformers import PreTrainedTokenizer
|
|
15
|
+
|
|
16
|
+
from .modeling.models import cache
|
|
17
|
+
from .modeling.prompt_utils import apply_chat_template
|
|
18
|
+
from .modeling.sample_utils import top_p_sampling
|
|
19
|
+
from .modeling.utils import (
|
|
20
|
+
StoppingCriteria,
|
|
21
|
+
apply_repetition_penalty,
|
|
22
|
+
load,
|
|
23
|
+
prepare_inputs,
|
|
24
|
+
tree_reduce,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Default model location used by parse_arguments() when --model is not given.
DEFAULT_MODEL_PATH = "mlx-community/gemma-3-4b-it-8bit"
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def parse_media_from_input(user_input):
    """Split a raw user prompt into prompt text and quoted media file paths.

    Every span enclosed in a *matching* pair of quotes (``"..."`` or
    ``'...'``) is treated as a candidate file path and removed from the
    prompt text. Candidates are ``~``-expanded, checked for existence and
    sorted into image or audio lists by (case-insensitive) file extension.
    Missing files are reported with a printed warning and skipped; existing
    files with any other extension are ignored.

    Args:
        user_input: The raw line typed by the user.

    Returns:
        A ``(prompt, image_paths, audio_paths)`` tuple. ``prompt`` is the
        input with all quoted spans removed and outer whitespace stripped;
        each path list is ``None`` when no file of that kind was found.
    """
    # The backreference (\1) requires the closing quote to be the same
    # character as the opening one, so an apostrophe inside a double-quoted
    # path (e.g. "/photos/it's.png") is not mistaken for a delimiter.
    # DOTALL keeps the old behaviour of allowing newlines inside quotes.
    quoted_re = re.compile(r'(["\'])(.*?)\1', re.DOTALL)
    quoted_matches = [match.group(2) for match in quoted_re.finditer(user_input)]

    # Remove quoted strings from the input to get the actual prompt
    prompt = quoted_re.sub('', user_input).strip()

    # Separate image and audio files based on extensions
    image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
    audio_extensions = {'.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a'}

    image_paths = []
    audio_paths = []

    for quoted_file in quoted_matches:
        if not quoted_file:  # Skip empty quotes
            continue

        # expanduser only rewrites paths that start with '~'
        quoted_file = os.path.expanduser(quoted_file)

        if not os.path.exists(quoted_file):
            print(f"Warning: File '{quoted_file}' not found")
            continue

        file_ext = os.path.splitext(quoted_file.lower())[1]
        if file_ext in image_extensions:
            image_paths.append(quoted_file)
        elif file_ext in audio_extensions:
            audio_paths.append(quoted_file)

    return prompt, image_paths if image_paths else None, audio_paths if audio_paths else None
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def parse_arguments():
    """Parse command-line options for this script.

    Only ``--model`` is recognised: a local model directory or a Hugging
    Face repo id, falling back to ``DEFAULT_MODEL_PATH``.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    cli = argparse.ArgumentParser(description="Generate text from an image using a model.")
    cli.add_argument(
        "--model",
        default=DEFAULT_MODEL_PATH,
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
    )
    namespace = cli.parse_args()
    return namespace
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# A stream on the default device just for generation.
# generate_step runs its per-token decode steps under
# `mx.stream(generation_stream)`, and stream_generate passes this stream to
# wired_limit so it is synchronized before the wired limit is restored.
generation_stream = mx.new_stream(mx.default_device())
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """
    A context manager to temporarily change the wired limit.

    On entry the Metal wired limit is set to the device's
    ``max_recommended_working_set_size``; a warning is printed when the
    model's arrays occupy more than 90% of that size. On exit the given
    ``streams`` are synchronized (or ``mx.synchronize()`` is called when
    ``streams`` is ``None``) and the previous limit is restored.

    Note, the wired limit should not be changed during an async eval. If an
    async eval could be running pass in the streams to synchronize with prior
    to exiting the context manager.

    When Metal is not available this is a no-op context manager.

    Args:
        model: Module whose ``mx.array`` leaves are summed to estimate its size.
        streams: Streams to synchronize with before restoring the limit.
    """
    if not mx.metal.is_available():
        # device_info()/set_wired_limit() require a Metal device; without one
        # there is no wired limit to adjust.
        yield None
        return

    # Sum nbytes over every mx.array in the module tree.
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(
            f"[WARNING] Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-lm/tree/main#large-models"
        )
    old_limit = mx.set_wired_limit(max_rec_size)
    try:
        yield None
    finally:
        # Let any in-flight async work finish before changing the limit back.
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.set_wired_limit(old_limit)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
@dataclass
class GenerationResult:
    """One chunk of streamed generation output, as yielded by ``stream_generate``."""

    # Text segment newly produced by the detokenizer for this step.
    text: str
    # Id of the token this result refers to (None if nothing was generated).
    token: Optional[int]
    # Log-probability vector produced by generate_step for this token
    # (an mx.array, not a Python list).
    logprobs: Optional["mx.array"]
    # Number of prompt tokens (input_ids.size).
    prompt_tokens: int
    # Number of tokens generated so far.
    generation_tokens: int
    # Prompt tokens per second measured over the prefill step.
    prompt_tps: float
    # Generated tokens per second, timed from the first generated token.
    generation_tps: float
    # mx.get_peak_memory() divided by 1e9 (gigabytes).
    peak_memory: float
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def generate_step(
    input_ids: mx.array,
    model: nn.Module,
    pixel_values,
    mask,
    *,
    max_tokens: int = 256,
    temperature: float = 0.0,
    repetition_penalty: Optional[float] = None,
    repetition_context_size: Optional[int] = 20,
    top_p: float = 1.0,
    logit_bias: Optional[Dict[int, float]] = None,
    prompt_cache: Optional[List[Any]] = None,
    max_kv_size: Optional[int] = None,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
    **kwargs,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    """
    A generator producing token ids based on the given prompt from the model.

    The prompt is first run through the full ``model`` (prefill); every
    later token comes from ``model.language_model`` one step at a time, with
    the next step scheduled via ``mx.async_eval`` before the current token is
    yielded.

    Args:
        input_ids (mx.array): The tokenized prompt.
        model (nn.Module): The model to use for generation; its
            ``language_model`` attribute is used for the decode steps.
        pixel_values: Visual inputs passed to ``model`` on the prefill call.
        mask: Attention mask passed to ``model`` on the prefill call.
        max_tokens (int): Maximum number of tokens to yield. Default: ``256``.
        temperature (float): The temperature for sampling, if 0 the argmax is used.
            Default: ``0``.
        repetition_penalty (float, optional): The penalty factor for repeating
            tokens. Must be a non-negative ``float``.
        repetition_context_size (int, optional): The number of most recent
            tokens (prompt and generated) to consider for repetition penalty.
            Default: ``20``.
        top_p (float, optional): Nucleus sampling; a value in ``(0, 1)``
            enables it, higher means the model considers more less likely words.
        logit_bias (dictionary, optional): Additive logit bias keyed by token id.
        prompt_cache (list, optional): KV cache to use; a new one is built from
            ``model.language_model`` when omitted.
        max_kv_size (int, optional): Passed to ``cache.make_prompt_cache`` when
            a new cache is built.
        kv_bits, kv_group_size, quantized_kv_start: Passed to
            ``maybe_quantize_kv_cache`` after every model call.
        **kwargs: Extra inputs forwarded to the prefill ``model`` call.

    Yields:
        Tuple[int, mx.array]: One token id and its vector of log probabilities.

    Raises:
        ValueError: If ``repetition_penalty`` is negative or not a ``float``.
    """

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    def sample(logits: mx.array) -> Tuple[mx.array, float]:
        """Apply logit_bias and pick the next token (greedy, top-p or categorical)."""
        if logit_bias:
            indices = mx.array(list(logit_bias.keys()))
            values = mx.array(list(logit_bias.values()))
            logits[:, indices] += values
        logprobs = logits - mx.logsumexp(logits)

        if temperature == 0:
            token = mx.argmax(logits, axis=-1)
        elif 0 < top_p < 1.0:
            token = top_p_sampling(logits, top_p, temperature)
        else:
            token = mx.random.categorical(logits * (1 / temperature))

        return token, logprobs

    if repetition_penalty and (
        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
    ):
        raise ValueError(
            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
        )

    # Create the KV cache for generation
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(
            model.language_model,
            max_kv_size=max_kv_size,
        )

    repetition_context = input_ids.reshape(-1).tolist()
    if repetition_context_size:
        repetition_context = repetition_context[-repetition_context_size:]

    def penalize_and_sample(logits: mx.array) -> Tuple[mx.array, mx.array]:
        """Apply the repetition penalty (if enabled), sample, and remember the token."""
        nonlocal repetition_context
        if not repetition_penalty:
            return sample(logits)

        logits = apply_repetition_penalty(logits, repetition_context, repetition_penalty)
        y, logprobs = sample(logits)
        repetition_context.append(y.item())
        if repetition_context_size and len(repetition_context) > repetition_context_size:
            repetition_context = repetition_context[-repetition_context_size:]
        return y, logprobs

    def _step(y, **kwargs):
        """Run one decode step of the language model on the generation stream."""
        with mx.stream(generation_stream):
            if "decoder_input_ids" in kwargs:
                outputs = model.language_model(cache=prompt_cache, **kwargs)
            else:
                outputs = model.language_model(y[None], cache=prompt_cache, **kwargs)

            logits = outputs.logits[:, -1, :]
            y, logprobs = penalize_and_sample(logits)

            quantize_cache_fn(prompt_cache)
            return y, logprobs.squeeze(0)

    # Prefill: run the full multimodal model over the prompt.
    outputs = model(input_ids, pixel_values, cache=prompt_cache, mask=mask, **kwargs)

    logits = outputs.logits[:, -1, :]
    quantize_cache_fn(prompt_cache)
    # The first token goes through the same penalty/context bookkeeping as
    # every later token, so it is penalized and can itself be penalized later.
    y, logprobs = penalize_and_sample(logits)
    # Squeeze like _step does so every yielded logprobs has the same shape.
    logprobs = logprobs.squeeze(0)
    mx.async_eval(y)

    # Extra inputs the decode steps need, depending on the model family.
    if outputs.cross_attention_states is not None:
        kwargs = {"cross_attention_states": outputs.cross_attention_states}
    elif outputs.encoder_outputs is not None:
        kwargs = {
            "decoder_input_ids": y[None],
            "encoder_outputs": outputs.encoder_outputs,
        }
    else:
        kwargs = {}

    n = 0
    while True:
        if n != max_tokens:
            # Schedule the next step before yielding so it overlaps with the
            # consumer's work on the current token.
            next_y, next_logprobs = _step(y, **kwargs)
            mx.async_eval(next_y)
            if "decoder_input_ids" in kwargs:
                kwargs["decoder_input_ids"] = next_y[None]
            yield y.item(), logprobs
            y, logprobs = next_y, next_logprobs
        if n == max_tokens:
            break

        n += 1

        # Periodically clear cache to prevent memory accumulation
        if n % 256 == 0:  # Clear cache every 256 tokens
            mx.clear_cache()
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def stream_generate(
    model: nn.Module,
    processor: PreTrainedTokenizer,
    prompt: str,
    image: Union[str, List[str]] = None,
    audio: Union[str, List[str]] = None,
    **kwargs,
) -> Union[str, Generator[str, None, None]]:
    """
    A generator producing text based on the given prompt from the model.

    Args:
        model (nn.Module): The model to use for generation.
        processor (PreTrainedTokenizer): Processor (or bare tokenizer) used to
            prepare inputs. It must provide ``detokenizer``, and its tokenizer
            must provide ``stopping_criteria``.
        prompt (str): The text prompt.
        image: Image path(s) passed to ``prepare_inputs``.
        audio: Audio path(s) passed to ``prepare_inputs``.
        kwargs: ``skip_special_tokens`` and ``resize_shape`` are consumed here.
            Passing pre-computed ``input_ids`` (with optional ``pixel_values``
            and ``mask``) bypasses ``prepare_inputs``. The remaining options
            (e.g. ``max_tokens``) get passed to :func:`generate_step`.
            See :func:`generate_step` for more details.

    Yields:
        GenerationResult: One result per generated token carrying the newly
        detokenized text, then a final result with any text flushed by
        ``detokenizer.finalize()``. If no token is generated, the final result
        has ``token=None``, ``logprobs=None`` and zero token counts.
    """
    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor

    # Skip special tokens
    skip_special_tokens = kwargs.pop("skip_special_tokens", False)
    skip_special_token_ids = (
        set(tokenizer.all_special_ids)
        if skip_special_tokens and hasattr(tokenizer, "all_special_ids")
        else []
    )

    add_special_tokens = (
        not hasattr(processor, "chat_template")
        if model.config.model_type in ["gemma3", "gemma3n"]
        else True
    )

    resize_shape = kwargs.pop("resize_shape", None)
    image_token_index = getattr(model.config, "image_token_index", None)

    if kwargs.get("input_ids", None) is not None:
        input_ids = kwargs.pop("input_ids")
        pixel_values = kwargs.pop("pixel_values", None)
        mask = kwargs.pop("mask", None)
    else:
        inputs = prepare_inputs(
            processor,
            images=image,
            audio=audio,
            prompts=prompt,
            image_token_index=image_token_index,
            resize_shape=resize_shape,
            add_special_tokens=add_special_tokens,
        )
        input_ids = inputs.get("input_ids", None)
        pixel_values = inputs.get("pixel_values", None)
        mask = inputs.get("attention_mask", None)
        data_kwargs = {
            k: v
            for k, v in inputs.items()
            if k not in ["input_ids", "pixel_values", "attention_mask"]
        }
        kwargs.update(data_kwargs)

    with wired_limit(model, [generation_stream]):
        detokenizer = processor.detokenizer
        detokenizer.reset()

        def _result(token, logprobs, count, prompt_tps, start):
            """Package the current detokenizer segment with timing statistics."""
            elapsed = time.perf_counter() - start
            return GenerationResult(
                text=detokenizer.last_segment,
                token=token,
                logprobs=logprobs,
                prompt_tokens=input_ids.size,
                generation_tokens=count,
                prompt_tps=prompt_tps,
                generation_tps=count / elapsed if elapsed > 0 else 0.0,
                peak_memory=mx.get_peak_memory() / 1e9,
            )

        # Defaults for the final result when generate_step yields nothing
        # (e.g. max_tokens=0); otherwise the loop overwrites them.
        token, logprobs = None, None
        prompt_tps = 0.0
        num_generated = 0

        tic = time.perf_counter()
        for n, (token, logprobs) in enumerate(
            generate_step(input_ids, model, pixel_values, mask, **kwargs)
        ):
            num_generated = n + 1
            if n == 0:
                # The first token arrives after the prefill; restart the clock
                # so generation_tps only measures decoding.
                prompt_time = time.perf_counter() - tic
                prompt_tps = input_ids.size / prompt_time
                tic = time.perf_counter()

            # Stop generation if the token is in the eos_token_ids
            if tokenizer.stopping_criteria(token):
                break

            detokenizer.add_token(token, skip_special_token_ids=skip_special_token_ids)

            # Yield the last segment if streaming
            yield _result(token, logprobs, num_generated, prompt_tps, tic)

        detokenizer.finalize()
        yield _result(token, logprobs, num_generated, prompt_tps, tic)

        # Cleanup after generation
        mx.clear_cache()
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def _as_path_list(paths) -> list:
    """Normalize ``None``, a single path string, or an iterable of paths into a list.

    ``list.extend`` on a bare string would add one entry per character, so a
    single path has to be wrapped instead of extended directly.
    """
    if paths is None:
        return []
    if isinstance(paths, str):
        return [paths]
    return list(paths)


def generate(
    model: nn.Module,
    processor: PreTrainedTokenizer,
    prompt: str,
    image: Union[str, List[str]] = None,
    audio: Union[str, List[str]] = None,
    verbose: bool = False,
    **kwargs,
) -> "tuple[str, dict]":
    """
    Generate text from the model and return it with usage statistics.

    Args:
        model (nn.Module): The language model.
        processor (PreTrainedTokenizer): A tokenizer, or a processor exposing
            one via its ``tokenizer`` attribute.
        prompt (str): The (already formatted) string prompt.
        image (str | List[str], optional): Image path(s) passed to ``stream_generate``.
        audio (str | List[str], optional): Audio path(s) passed to ``stream_generate``.
        verbose (bool): If ``True``, print the inputs, the streamed text and
            timing/memory statistics (default ``False``).
        **kwargs: Forwarded to ``stream_generate`` (e.g. sampling options).
            This function itself also reads:
            - ``eos_tokens``: extra EOS token ids added to the tokenizer's
              stopping criteria.
            - ``stopping_criteria``: a ``StoppingCriteria`` instance or
              callable that replaces the tokenizer's stopping criteria.
            - ``video``: video path(s), only used for the verbose file listing.

    Returns:
        tuple[str, dict]: The generated text and a usage dict with keys
        ``input_tokens``, ``output_tokens``, ``total_tokens``, ``prompt_tps``,
        ``generation_tps`` and ``peak_memory`` (GB).

    Raises:
        ValueError: If ``stopping_criteria`` is neither a ``StoppingCriteria``
            nor callable.

    Note:
        The stopping-criteria setup mutates ``tokenizer.stopping_criteria``,
        so the change persists on the tokenizer object after this call.
    """
    if verbose:
        print("=" * 10)
        files = (
            _as_path_list(image)
            + _as_path_list(audio)
            + _as_path_list(kwargs.get("video"))
        )
        print(f"Files: {files}", "\n")
        print("Prompt:", prompt)

    text = ""
    last_response = None

    eos_tokens = kwargs.get("eos_tokens", None)
    stopping_criteria = kwargs.get("stopping_criteria", None)

    # Processors wrap a tokenizer; plain tokenizers are used directly.
    tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor

    if eos_tokens is not None:
        # Add custom EOS tokens to the existing stopping criteria.
        tokenizer.stopping_criteria.add_eos_token_ids(eos_tokens)
    elif stopping_criteria is not None:
        # Replace the stopping criteria with a user-supplied one.
        if isinstance(stopping_criteria, StoppingCriteria) or callable(
            stopping_criteria
        ):
            tokenizer.stopping_criteria = stopping_criteria
        else:
            raise ValueError(
                "stopping_criteria must be an instance of StoppingCriteria or a callable"
            )
    else:
        # Fall back to the model's configured EOS token id(s).
        tokenizer.stopping_criteria.reset(model.config.eos_token_id)

    for response in stream_generate(model, processor, prompt, image, audio, **kwargs):
        if verbose:
            print(response.text, end="", flush=True)
        text += response.text
        last_response = response

    # Each response carries cumulative counters, so the last one summarizes the run.
    usage_stats = {
        "input_tokens": last_response.prompt_tokens,
        "output_tokens": last_response.generation_tokens,
        "total_tokens": last_response.prompt_tokens + last_response.generation_tokens,
        "prompt_tps": last_response.prompt_tps,
        "generation_tps": last_response.generation_tps,
        "peak_memory": last_response.peak_memory,
    }

    if verbose:
        print("\n" + "=" * 10)
        if len(text) == 0:
            print("No text generated for this prompt")
        else:
            print(
                f"Prompt: {last_response.prompt_tokens} tokens, "
                f"{last_response.prompt_tps:.3f} tokens-per-sec"
            )
            print(
                f"Generation: {last_response.generation_tokens} tokens, "
                f"{last_response.generation_tps:.3f} tokens-per-sec"
            )
            print(f"Peak memory: {last_response.peak_memory:.3f} GB")

    # Always return the same shape so callers can unpack unconditionally
    # (the verbose/empty-output path previously returned None).
    return text, usage_stats
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def main():
    """Run an interactive multi-round chat with the model loaded from ``args.model``.

    Each user line may embed quoted image/audio paths, which
    ``parse_media_from_input`` extracts. The whole chat history is re-templated
    and sent to ``stream_generate`` every turn, and the reply is streamed to stdout.
    Typing ``exit``, ``quit``, or an empty line ends the session, as do
    Ctrl-C and end-of-input.
    """
    args = parse_arguments()

    # Load model and processor
    model, processor = load(args.model, None)
    config = model.config

    # Chat history: list of {"role": ..., "content": ...} dicts.
    chat = []

    print("Multi-round conversation started. Type 'exit' or 'quit' to stop.")
    print("You can include image/audio files in quotes, e.g.: 'what does this image mean \"/path/to/image.jpg\"'")
    print("=" * 50)

    # Main chat loop
    while True:
        # User turn appended during this iteration; rolled back if the turn fails.
        pending_turn = None
        try:
            user_input = input("User: ").strip()

            # Exit conditions
            if user_input.lower() in ['exit', 'quit', '']:
                break

            # Parse media files from user input
            prompt_text, image_paths, audio_paths = parse_media_from_input(user_input)

            # If no text prompt after parsing, use the original input
            if not prompt_text.strip():
                prompt_text = user_input
                image_paths = None
                audio_paths = None

            # Add user message to chat history
            pending_turn = {"role": "user", "content": prompt_text}
            chat.append(pending_turn)

            # Calculate number of media items for the chat template
            num_images = len(image_paths) if image_paths else 0
            num_audios = len(audio_paths) if audio_paths else 0

            # Apply chat template
            formatted_prompt = apply_chat_template(
                processor, config, chat, num_images=num_images, num_audios=num_audios
            )

            # Generate response
            response = ""
            print("Assistant: ", end="", flush=True)

            for chunk in stream_generate(
                model,
                processor,
                formatted_prompt,
                image_paths,
                audio_paths,
                max_tokens=100,
                temperature=0.7,
                top_p=0.9,
                verbose=True,
            ):
                response += chunk.text
                print(chunk.text, end="", flush=True)

            print()  # New line after response

            # Add assistant response to chat history
            chat.append({"role": "assistant", "content": response})

        except KeyboardInterrupt:
            print("\nConversation interrupted by user.")
            break
        except EOFError:
            # stdin is exhausted (Ctrl-D / end of piped input). input() would
            # raise again on every iteration, so stop instead of looping forever.
            print()
            break
        except Exception as e:
            # Drop the unanswered user turn so the next round does not send
            # two consecutive user messages.
            if pending_turn is not None and chat and chat[-1] is pending_turn:
                chat.pop()
            print(f"Error: {e}")
            continue
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
if __name__ == "__main__":
    # Script entry point: start the interactive chat session.
    main()
|