nexaai 1.0.19rc6-cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/utils/manifest_utils.py +222 -15
- nexaai/utils/model_manager.py +83 -7
- nexaai/utils/model_types.py +2 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +224 -24
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc6.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py
@@ -0,0 +1,330 @@
+# Copyright © 2023-2024 Apple Inc.
+
+import math
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from mlx_lm.models.base import (
+    BaseModelArgs,
+    scaled_dot_product_attention,
+)
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str = "xlm_roberta"
+    vocab_size: int = 250002
+    hidden_size: int = 768
+    num_hidden_layers: int = 12
+    num_attention_heads: int = 12
+    intermediate_size: int = 3072
+    hidden_act: str = "gelu"
+    hidden_dropout_prob: float = 0.1
+    attention_probs_dropout_prob: float = 0.1
+    max_position_embeddings: int = 1026
+    type_vocab_size: int = 1
+    initializer_range: float = 0.02
+    layer_norm_eps: float = 1e-05
+    pad_token_id: int = 1
+    bos_token_id: int = 0
+    eos_token_id: int = 2
+    position_embedding_type: str = "absolute"
+    use_cache: bool = True
+    classifier_dropout: Optional[float] = None
+    num_labels: int = 1
+
+
+class XLMRobertaEmbeddings(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
+        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+    def __call__(
+        self,
+        input_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        if token_type_ids is None:
+            token_type_ids = mx.zeros_like(input_ids)
+
+        inputs_embeds = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
+        return embeddings
+
+
+class SelfAttention(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.num_attention_heads = config.num_attention_heads
+        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+    def __call__(
+        self,
+        qkv: mx.array,
+        key_padding_mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        # qkv shape: [batch, seqlen, 3, num_heads, head_dim]
+        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+        q, k, v = mx.split(qkv, 3, axis=2) # Each: [batch, seqlen, 1, num_heads, head_dim]
+        q = mx.squeeze(q, axis=2) # [batch, seqlen, num_heads, head_dim]
+        k = mx.squeeze(k, axis=2)
+        v = mx.squeeze(v, axis=2)
+
+        # Transpose for attention computation: [batch, num_heads, seqlen, head_dim]
+        q = mx.transpose(q, (0, 2, 1, 3))
+        k = mx.transpose(k, (0, 2, 1, 3))
+        v = mx.transpose(v, (0, 2, 1, 3))
+
+        scale = 1.0 / math.sqrt(self.attention_head_size)
+
+        mask = None
+        if key_padding_mask is not None:
+            # key_padding_mask: [batch, seqlen] where True means keep, False means mask
+            # Convert to attention mask: [batch, 1, 1, seqlen]
+            mask = mx.expand_dims(mx.expand_dims(key_padding_mask, axis=1), axis=1)
+            # Use the same dtype as the query tensor to match model dtype
+            target_dtype = q.dtype
+            mask = (1.0 - mask.astype(target_dtype)) * -10000.0
+
+        context = scaled_dot_product_attention(q, k, v, cache=None, scale=scale, mask=mask)
+
+        # Transpose back and reshape: [batch, seqlen, hidden_size]
+        context = mx.transpose(context, (0, 2, 1, 3))
+        new_context_shape = context.shape[:-2] + (self.all_head_size,)
+        context = mx.reshape(context, new_context_shape)
+        return context
+
+
+class MHA(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+
+        # QKV projection
+        qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads) # q + k + v
+        self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=True)
+
+        # Self attention
+        self.inner_attn = SelfAttention(config)
+
+        # Output projection
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
+
+    def __call__(
+        self,
+        x: mx.array,
+        key_padding_mask: Optional[mx.array] = None,
+    ) -> tuple:
+        residual = x
+        qkv = self.Wqkv(x)
+
+        # Reshape to [batch, seqlen, 3, num_heads, head_dim]
+        batch, seqlen = qkv.shape[0], qkv.shape[1]
+        qkv = mx.reshape(qkv, (batch, seqlen, 3, self.num_heads, self.head_dim))
+
+        context = self.inner_attn(qkv, key_padding_mask=key_padding_mask)
+        out = self.out_proj(context)
+
+        return out, residual
+
+
+class Mlp(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=True)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size, bias=True)
+
+    def __call__(self, x: mx.array) -> tuple:
+        residual = x
+        y = self.fc1(x)
+        y = nn.gelu(y)
+        y = self.fc2(y)
+        return y, residual
+
+
+class Block(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.mixer = MHA(config)
+        self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.mlp = Mlp(config)
+        self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        mixer_kwargs: Optional[dict] = None,
+    ) -> mx.array:
+        mixer_kwargs = mixer_kwargs or {}
+
+        # Attention block
+        mixer_out, residual = self.mixer(hidden_states, **mixer_kwargs)
+        hidden_states = self.norm1(mixer_out + residual)
+
+        # MLP block
+        mlp_out, residual = self.mlp(hidden_states)
+        hidden_states = self.norm2(mlp_out + residual)
+
+        return hidden_states
+
+
+class XLMRobertaEncoder(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        # Create layers list to match torch naming
+        self.layers = [Block(config) for _ in range(config.num_hidden_layers)]
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        key_padding_mask: Optional[mx.array] = None,
+    ) -> mx.array:
+        mixer_kwargs = None
+        if key_padding_mask is not None:
+            mixer_kwargs = {"key_padding_mask": key_padding_mask}
+
+        # Access layers from the list
+        for layer_module in self.layers:
+            hidden_states = layer_module(hidden_states, mixer_kwargs=mixer_kwargs)
+
+        return hidden_states
+
+
+class XLMRobertaModel(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.config = config
+        self.embeddings = XLMRobertaEmbeddings(config)
+        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.encoder = XLMRobertaEncoder(config)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        attention_mask: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+        )
+
+        hidden_states = self.emb_ln(hidden_states)
+
+        # Convert attention_mask for padding (True=keep, False=mask)
+        key_padding_mask = attention_mask
+
+        sequence_output = self.encoder(hidden_states, key_padding_mask=key_padding_mask)
+
+        return sequence_output
+
+
+class XLMRobertaClassificationHead(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def __call__(self, features: mx.array) -> mx.array:
+        x = features[:, 0, :] # take first token (equivalent to [CLS])
+        x = self.dense(x)
+        x = mx.tanh(x)
+        x = self.out_proj(x)
+        return x
+
+
+class XLMRobertaForSequenceClassification(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        self.num_labels = config.num_labels
+        self.config = config
+        self.roberta = XLMRobertaModel(config)
+        self.classifier = XLMRobertaClassificationHead(config)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        attention_mask: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        sequence_output = self.roberta(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+        )
+        logits = self.classifier(sequence_output)
+        return logits
+
+    def nexa_forward(
+        self,
+        input_ids: mx.array,
+        attention_mask: mx.array,
+        token_type_ids: mx.array,
+        position_ids: mx.array,
+    ) -> mx.array:
+        return self(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+        )
+
+
+class Model(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+        self.model_type = args.model_type
+        self.model = XLMRobertaForSequenceClassification(args)
+
+    def __call__(
+        self,
+        input_ids: mx.array,
+        attention_mask: Optional[mx.array] = None,
+        token_type_ids: Optional[mx.array] = None,
+        position_ids: Optional[mx.array] = None,
+    ) -> mx.array:
+        return self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+        )
+
+    def nexa_forward(
+        self,
+        input_ids: mx.array,
+        attention_mask: mx.array,
+        token_type_ids: mx.array,
+        position_ids: mx.array,
+    ) -> mx.array:
+        return self.model.nexa_forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+        )
+
+    def sanitize(self, weights):
+        """Remove parameters that don't exist in our model"""
+        return weights
+
+    @property
+    def layers(self):
+        return self.model.roberta.encoder.layers
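
For orientation only, here is a minimal usage sketch of the reranker model added above; it is not part of the release. The import path `nexa_jina_rerank` is a placeholder for wherever the module above is made importable, and the weights are randomly initialized (real use would load the converted Jina rerank weights first). Note that the embeddings require explicit position_ids; XLM-RoBERTa conventionally offsets positions by pad_token_id + 1.

import mlx.core as mx
from nexa_jina_rerank import Model, ModelArgs  # hypothetical import path

args = ModelArgs()
model = Model(args)

# Two padded token sequences (batch=2, seqlen=8); the ids are placeholders.
input_ids = mx.array([[0, 100, 200, 2, 1, 1, 1, 1],
                      [0, 300, 400, 500, 600, 2, 1, 1]])
attention_mask = (input_ids != args.pad_token_id).astype(mx.float32)  # 1.0 = keep
token_type_ids = mx.zeros_like(input_ids)  # type_vocab_size is 1, so all zeros
# XLM-RoBERTa-style absolute positions, offset past the padding index
positions = mx.arange(input_ids.shape[1]) + args.pad_token_id + 1
position_ids = mx.broadcast_to(positions, input_ids.shape)

scores = model(input_ids, attention_mask, token_type_ids, position_ids)
print(scores.shape)  # (2, num_labels): one relevance logit per sequence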
nexaai/binds/nexa_mlx/py-lib/sd/__init__.py
@@ -0,0 +1 @@
+"""Stable Diffusion MLX interface package"""
nexaai/binds/nexa_mlx/py-lib/sd/interface.py
@@ -0,0 +1,362 @@
+from __future__ import annotations
+
+import time
+from typing import (
+    Any,
+    Callable,
+    List,
+    Optional,
+)
+
+import mlx.core as mx
+import numpy as np
+from PIL import Image as PILImage
+import mlx.nn as nn
+import os
+
+from .modeling import StableDiffusion, StableDiffusionXL
+
+# --------------------------------------------------------------------------------------
+# Core aliases & callback protocols
+# --------------------------------------------------------------------------------------
+
+Path = str
+LogCallback = Callable[[str], None]
+
+# --------------------------------------------------------------------------------------
+# Core module functions
+# --------------------------------------------------------------------------------------
+
+def init() -> None:
+    """Initialize the stable diffusion module"""
+    pass
+
+def deinit() -> None:
+    """Deinitialize the stable diffusion module"""
+    pass
+
+def set_log(callback: LogCallback) -> None:
+    """Set the logging callback"""
+    pass
+
+def log(message: str) -> None:
+    """Log a message"""
+    print(message)
+
+# --------------------------------------------------------------------------------------
+# Basic data structures
+# --------------------------------------------------------------------------------------
+
+class Image:
+    def __init__(self, data: List[float], width: int, height: int, channels: int) -> None:
+        """Initialize an image with pixel data"""
+        self.data = data
+        self.width = width
+        self.height = height
+        self.channels = channels
+
+    @classmethod
+    def from_numpy(cls, array: np.ndarray) -> 'Image':
+        """Create Image from numpy array (H, W, C)"""
+        height, width, channels = array.shape
+        data = array.flatten().tolist()
+        return cls(data, width, height, channels)
+
+    @classmethod
+    def from_pil(cls, pil_image: PILImage.Image) -> 'Image':
+        """Create Image from PIL Image"""
+        array = np.array(pil_image).astype(np.float32) / 255.0
+        return cls.from_numpy(array)
+
+    def to_numpy(self) -> np.ndarray:
+        """Convert to numpy array (H, W, C)"""
+        return np.array(self.data).reshape(self.height, self.width, self.channels)
+
+    def to_pil(self) -> PILImage.Image:
+        """Convert to PIL Image"""
+        array = (self.to_numpy() * 255).astype(np.uint8)
+        return PILImage.fromarray(array)
+
+class ImageSamplerConfig:
+    def __init__(
+        self,
+        method: str = "ddim",
+        steps: int = 20,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0,
+        seed: int = -1,
+    ) -> None:
+        """Initialize sampler configuration"""
+        self.method = method
+        self.steps = steps
+        self.guidance_scale = guidance_scale
+        self.eta = eta
+        self.seed = seed
+
+class ImageGenerationConfig:
+    def __init__(
+        self,
+        prompts: str | List[str],
+        negative_prompts: str | List[str] | None = None,
+        height: int = 512,
+        width: int = 512,
+        sampler_config: Optional[ImageSamplerConfig] = None,
+        lora_id: int = -1, # Not used but kept for compatibility
+        init_image: Optional[Image] = None,
+        strength: float = 1.0,
+        n_images: int = 1,
+        n_rows: int = 1,
+        decoding_batch_size: int = 1,
+    ) -> None:
+        """Initialize image generation configuration"""
+        self.prompts = prompts
+        self.negative_prompts = negative_prompts or ""
+        self.height = height
+        self.width = width
+        self.sampler_config = sampler_config or ImageSamplerConfig()
+        self.lora_id = lora_id
+        self.init_image = init_image
+        self.strength = strength
+        self.n_images = n_images
+        self.n_rows = n_rows
+        self.decoding_batch_size = decoding_batch_size
+
+# --------------------------------------------------------------------------------------
+# Helper functions - following txt2img.py pattern
+# --------------------------------------------------------------------------------------
+
+def load_model(model_path: Path, float16: bool = True, quantize: bool = False) -> StableDiffusion:
+    """Load a model from the given path - following txt2img.py pattern"""
+
+    # Check if it's a local path or HuggingFace repo
+    # If it contains path separators or exists as a file/directory, treat as local
+    is_local_path = ('/' in model_path or '\\' in model_path or os.path.exists(model_path))
+
+    if is_local_path:
+        # For local paths, determine model type from the path or model files
+        if "xl" in model_path.lower() or "turbo" in model_path.lower():
+            model = StableDiffusionXL(model_path, float16=float16)
+        else:
+            model = StableDiffusion(model_path, float16=float16)
+    else:
+        # For HuggingFace repo names, use the original logic
+        if "xl" in model_path.lower() or "turbo" in model_path.lower():
+            model = StableDiffusionXL(model_path, float16=float16)
+        else:
+            model = StableDiffusion(model_path, float16=float16)
+
+    # Apply quantization if requested - same as txt2img.py
+    if quantize:
+        if "xl" in model_path.lower() or "turbo" in model_path.lower():
+            nn.quantize(
+                model.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+            nn.quantize(
+                model.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+        else:
+            nn.quantize(
+                model.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
+            )
+        nn.quantize(model.unet, group_size=32, bits=8)
+
+    return model
+
+def _prepare_image_for_sd(image: Image, target_width: int, target_height: int) -> mx.array:
+    """Prepare image for stable diffusion processing - simplified"""
+    # Convert to PIL and resize
+    pil_img = image.to_pil()
+    pil_img = pil_img.resize((target_width, target_height), PILImage.LANCZOS)
+
+    # Convert to array and normalize to [0,1] range (following txt2img.py pattern)
+    img_array = np.array(pil_img).astype(np.float32)[:, :, :3] # Ensure RGB
+    img_tensor = mx.array(img_array / 255.0)
+
+    return img_tensor
+
+# --------------------------------------------------------------------------------------
+# Image generation
+# --------------------------------------------------------------------------------------
+
+class ImageGen:
+    def __init__(
+        self,
+        model_path: Path,
+        scheduler_config_path: Path = "", # Make optional
+        device: Optional[str] = None,
+        float16: bool = True,
+        quantize: bool = False,
+    ) -> None:
+        """Initialize the image generation model"""
+        self.model_path = model_path
+        self.scheduler_config_path = scheduler_config_path # Store for compatibility
+        self.float16 = float16
+        self.quantize = quantize
+        self.model = None
+
+    def destroy(self) -> None:
+        """Clean up resources"""
+        self.model = None
+
+    def load_model(self, model_path: Path, extra_data: Any = None) -> bool:
+        """Load the model from a file"""
+        try:
+            if os.path.isfile(model_path):
+                model_path = os.path.dirname(model_path)
+
+            self.model_path = model_path
+            self.model = load_model(model_path, self.float16, self.quantize)
+            self.model.ensure_models_are_loaded()
+            return True
+        except Exception as e:
+            log(f"Failed to load model: {e}")
+            return False
+
+    def close(self) -> None:
+        """Close the model"""
+        self.destroy()
+
+    def set_scheduler(self, config: Any) -> None:
+        """Set scheduler configuration (placeholder for compatibility)"""
+        log("Warning: set_scheduler not implemented")
+        pass
+
+    def set_sampler(self, config: ImageSamplerConfig) -> None:
+        """Set sampler configuration (placeholder for compatibility)"""
+        log("Warning: set_sampler not implemented")
+        pass
+
+    def reset_sampler(self) -> None:
+        """Reset sampler configuration (placeholder for compatibility)"""
+        log("Warning: reset_sampler not implemented")
+        pass
+
+    def set_lora(self, lora_id: int) -> None:
+        """Set LoRA (placeholder for compatibility)"""
+        log("Warning: LoRA management not implemented")
+        pass
+
+    def add_lora(self, lora_path: Path) -> int:
+        """Add LoRA (placeholder for compatibility)"""
+        log("Warning: LoRA management not implemented")
+        return -1
+
+    def remove_lora(self, lora_id: int) -> None:
+        """Remove LoRA (placeholder for compatibility)"""
+        log("Warning: LoRA management not implemented")
+        pass
+
+    def list_loras(self) -> List[int]:
+        """List LoRAs (placeholder for compatibility)"""
+        log("Warning: LoRA management not implemented")
+        return []
+
+    def txt2img(self, prompt: str, config: ImageGenerationConfig, clear_cache: bool = True) -> Image:
+        """Generate an image from a text prompt - following txt2img.py pattern"""
+        if not self.model and not self.load_model(self.model_path):
+            raise RuntimeError("Model not loaded")
+
+        sampler_config = config.sampler_config
+
+        # Extract prompts
+        negative_prompt = ""
+        if config.negative_prompts:
+            negative_prompt = config.negative_prompts if isinstance(config.negative_prompts, str) else config.negative_prompts[0]
+
+        try:
+            # Generate latents - following txt2img.py approach
+            latents_generator = self.model.generate_latents(
+                prompt,
+                n_images=1,
+                num_steps=sampler_config.steps,
+                cfg_weight=sampler_config.guidance_scale,
+                negative_text=negative_prompt,
+                seed=sampler_config.seed if sampler_config.seed >= 0 else None
+            )
+
+            # Get final latents - following txt2img.py pattern
+            final_latents = None
+            for latents in latents_generator:
+                final_latents = latents
+                mx.eval(final_latents)
+
+            if final_latents is None:
+                raise RuntimeError("No latents generated")
+
+            # Decode to image - following txt2img.py pattern
+            decoded_image = self.model.decode(final_latents)
+            mx.eval(decoded_image)
+
+            # Convert to numpy array - following txt2img.py pattern
+            image_array = np.array(decoded_image.squeeze(0))
+
+            if clear_cache:
+                mx.clear_cache()
+
+            return Image.from_numpy(image_array)
+
+        except Exception as e:
+            log(f"Generation failed: {e}")
+            raise e
+
+    def img2img(self, init_image: Image, prompt: str, config: ImageGenerationConfig, clear_cache: bool = True) -> Image:
+        """Generate an image from an initial image and a text prompt"""
+        if not self.model and not self.load_model(self.model_path):
+            raise RuntimeError("Model not loaded")
+
+        sampler_config = config.sampler_config
+
+        # Extract prompts
+        negative_prompt = ""
+        if config.negative_prompts:
+            negative_prompt = config.negative_prompts if isinstance(config.negative_prompts, str) else config.negative_prompts[0]
+
+        try:
+            # Prepare image for SD processing
+            img_tensor = _prepare_image_for_sd(init_image, config.width, config.height)
+
+            # Generate latents from image
+            latents_generator = self.model.generate_latents_from_image(
+                img_tensor,
+                prompt,
+                n_images=1,
+                strength=config.strength,
+                num_steps=sampler_config.steps,
+                cfg_weight=sampler_config.guidance_scale,
+                negative_text=negative_prompt,
+                seed=sampler_config.seed if sampler_config.seed >= 0 else None
+            )
+
+            # Get final latents
+            final_latents = None
+            for latents in latents_generator:
+                final_latents = latents
+                mx.eval(final_latents)
+
+            if final_latents is None:
+                raise RuntimeError("No latents generated")
+
+            # Decode to image
+            decoded_image = self.model.decode(final_latents)
+            mx.eval(decoded_image)
+
+            # Convert to numpy array
+            image_array = np.array(decoded_image.squeeze(0))
+
+            if clear_cache:
+                mx.clear_cache()
+
+            return Image.from_numpy(image_array)
+
+        except Exception as e:
+            log(f"Generation failed: {e}")
+            raise e
+
+    def generate(self, config: ImageGenerationConfig) -> Image:
+        """Generate an image from configuration"""
+        if config.init_image:
+            prompt = config.prompts if isinstance(config.prompts, str) else config.prompts[0]
+            return self.img2img(config.init_image, prompt, config)
+        else:
+            prompt = config.prompts if isinstance(config.prompts, str) else config.prompts[0]
+            return self.txt2img(prompt, config)
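
Again for orientation only, a minimal sketch of driving the new ImageGen wrapper end to end; it is not part of the release. The names all come from the interface above (assumed importable from the sd package), the model directory path is a placeholder, and loading is lazy: txt2img calls load_model itself if no model is loaded yet.

from nexaai.binds.nexa_mlx.py_lib.sd.interface import (  # hypothetical import path
    ImageGen, ImageGenerationConfig, ImageSamplerConfig,
)

# Placeholder path to a local Stable Diffusion model directory.
gen = ImageGen(model_path="/path/to/sd-model", float16=True, quantize=False)

config = ImageGenerationConfig(
    prompts="a watercolor lighthouse at dusk",
    negative_prompts="blurry, low quality",
    sampler_config=ImageSamplerConfig(steps=20, guidance_scale=7.5, seed=42),
)

# No init_image is set, so generate() dispatches to txt2img.
image = gen.generate(config)
image.to_pil().save("out.png")
gen.close()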