nexaai 1.0.19rc5__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19rc7__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/asr/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/asr/interface.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/common/utils.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/cv/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/cv/generate.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/cv/interface.py +151 -0
- nexaai/binds/nexa_mlx/py-lib/cv/main.py +81 -0
- nexaai/binds/nexa_mlx/py-lib/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/generate.py +333 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/interface.py +617 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/main.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/generate_sd.py +244 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/interface.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/main.py +281 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/llm/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/llm/generate.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/llm/interface.py +764 -0
- nexaai/binds/nexa_mlx/py-lib/llm/main.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/interface.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/main.py +127 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/binds/nexa_mlx/py-lib/sd/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/sd/interface.py +362 -0
- nexaai/binds/nexa_mlx/py-lib/sd/main.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/__init__.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/clip.py +116 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/config.py +65 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/model_io.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/sampler.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/tokenizer.py +100 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/unet.py +460 -0
- nexaai/binds/nexa_mlx/py-lib/sd/modeling/vae.py +274 -0
- nexaai/binds/nexa_mlx/py-lib/tts/__init__.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/tts/interface.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate.py +572 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/interface.py +504 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/main.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/convert.py +68 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/base.py +202 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/cache.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kernels.py +307 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/language.py +240 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/prompt_utils.py +474 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/sample_utils.py +39 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/lora.py +70 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/trainer/utils.py +160 -0
- nexaai/binds/nexa_mlx/py-lib/vlm/modeling/utils.py +928 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +276 -0
- nexaai/mlx_backend/vlm/interface.py +21 -4
- nexaai/mlx_backend/vlm/main.py +6 -2
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1309 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/METADATA +1 -1
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/RECORD +221 -21
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/WHEEL +0 -0
- {nexaai-1.0.19rc5.dist-info → nexaai-1.0.19rc7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import (
|
|
4
|
+
List,
|
|
5
|
+
Optional,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
import mlx.core as mx
|
|
9
|
+
import numpy as np
|
|
10
|
+
from PIL import Image as PILImage
|
|
11
|
+
import mlx.nn as nn
|
|
12
|
+
import os
|
|
13
|
+
|
|
14
|
+
from .stable_diffusion import StableDiffusion, StableDiffusionXL
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Image:
    """Pixel-buffer container: a flat list of floats plus (height, width, channels)."""

    def __init__(self, data: List[float], width: int, height: int, channels: int) -> None:
        """Store a flat pixel buffer together with its geometry."""
        self.data = data
        self.width = width
        self.height = height
        self.channels = channels

    @classmethod
    def from_numpy(cls, array: np.ndarray) -> 'Image':
        """Build an Image from an (H, W, C) numpy array."""
        h, w, c = array.shape
        return cls(array.flatten().tolist(), w, h, c)

    @classmethod
    def from_pil(cls, pil_image: 'PILImage.Image') -> 'Image':
        """Build an Image from a PIL image, scaling pixel values into [0, 1]."""
        normalized = np.array(pil_image).astype(np.float32) / 255.0
        return cls.from_numpy(normalized)

    def to_numpy(self) -> np.ndarray:
        """Reshape the flat buffer back into an (H, W, C) numpy array."""
        return np.array(self.data).reshape(self.height, self.width, self.channels)

    def to_pil(self) -> 'PILImage.Image':
        """Convert to a PIL image by mapping [0, 1] floats onto uint8 [0, 255]."""
        scaled = (self.to_numpy() * 255).astype(np.uint8)
        return PILImage.fromarray(scaled)
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class ImageSamplerConfig:
    """Diffusion sampler settings; defaults are tuned for SDXL Turbo."""

    def __init__(
        self,
        method: str = "ddim",
        steps: int = 4,
        guidance_scale: float = 0.0,
        eta: float = 0.0,
        seed: int = -1,
    ) -> None:
        """Record sampling method, step count, CFG weight, eta, and seed.

        The defaults (steps=4, guidance_scale=0.0) match SDXL Turbo's
        few-step, guidance-free operation. seed=-1 means "no fixed seed";
        callers downstream translate negative seeds to None.
        """
        self.method = method
        self.steps = steps
        self.guidance_scale = guidance_scale
        self.eta = eta
        self.seed = seed
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ImageGenerationConfig:
    """Bundle of prompt, geometry, and sampler options for one generation run."""

    def __init__(
        self,
        prompts: "str | List[str]",
        negative_prompts: "str | List[str] | None" = None,
        height: int = 512,
        width: int = 512,
        sampler_config: "Optional[ImageSamplerConfig]" = None,
        lora_id: int = -1,
        init_image: "Optional[Image]" = None,
        strength: float = 1.0,
        n_images: int = 1,
        n_rows: int = 1,
        decoding_batch_size: int = 1,
    ) -> None:
        """Capture generation settings, substituting defaults for omitted ones.

        lora_id is accepted only for interface compatibility; nothing in this
        module reads it. A falsy negative prompt becomes "", and a missing
        sampler_config becomes a default ImageSamplerConfig.
        """
        self.height = height
        self.width = width
        self.prompts = prompts
        # Falsy (None or empty) negative prompts collapse to the empty string.
        self.negative_prompts = negative_prompts or ""
        self.sampler_config = sampler_config or ImageSamplerConfig()
        self.init_image = init_image
        self.strength = strength
        self.lora_id = lora_id
        self.n_images = n_images
        self.n_rows = n_rows
        self.decoding_batch_size = decoding_batch_size
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class ImageGen:
    """Text-to-image / image-to-image generator backed by MLX Stable Diffusion.

    The underlying StableDiffusion(XL) model is loaded lazily on the first
    generation call, honoring the float16/quantize options given at
    construction time.
    """

    def __init__(
        self,
        model_path: str,
        scheduler_config_path: Optional[str] = None,
        device: Optional[str] = None,
        float16: bool = True,
        quantize: bool = False,
    ) -> None:
        """Record model location and load options; actual loading is deferred.

        Args:
            model_path: Local path or HuggingFace repo id of the model.
            scheduler_config_path: Optional scheduler config path (stored,
                not read by this class).
            device: Accepted for interface compatibility; unused here.
            float16: Load weights in float16 when True.
            quantize: Quantize text encoder(s) and UNet after loading.
        """
        self.model_path = model_path
        self.scheduler_config_path = scheduler_config_path
        self.float16 = float16
        self.quantize = quantize
        self.model = None

    @staticmethod
    def load_model(model_path: str, float16: bool = True, quantize: bool = False) -> 'StableDiffusion':
        """Load a StableDiffusion(XL) model from a local path or HF repo.

        The model family is inferred purely from the name: paths containing
        "xl" or "turbo" load StableDiffusionXL, everything else loads base
        StableDiffusion. (The original local-vs-remote split performed the
        same logic in both branches, so it has been collapsed.)
        """
        is_xl = "xl" in model_path.lower() or "turbo" in model_path.lower()
        if is_xl:
            model = StableDiffusionXL(model_path, float16=float16)
        else:
            model = StableDiffusion(model_path, float16=float16)

        if quantize:
            def linear_only(_, m):
                # Quantize only Linear layers of the text encoder(s).
                return isinstance(m, nn.Linear)

            if is_xl:
                # SDXL carries two text encoders; base SD has a single one.
                nn.quantize(model.text_encoder_1, class_predicate=linear_only)
                nn.quantize(model.text_encoder_2, class_predicate=linear_only)
            else:
                nn.quantize(model.text_encoder, class_predicate=linear_only)
            nn.quantize(model.unet, group_size=32, bits=8)
        return model

    def _ensure_model(self):
        """Lazily load the model using the options stored at construction.

        Bug fix: the previous lazy load ignored self.float16/self.quantize
        and always loaded with load_model's defaults.
        """
        if not self.model:
            self.model = self.load_model(
                self.model_path, float16=self.float16, quantize=self.quantize)
        if not self.model:
            raise RuntimeError("Model not loaded")
        return self.model

    @staticmethod
    def _negative_prompt(config: 'ImageGenerationConfig') -> str:
        """Return a single negative-prompt string from the config.

        Uses the first entry when a list is given and "" when absent.
        """
        if not config.negative_prompts:
            return ""
        if isinstance(config.negative_prompts, str):
            return config.negative_prompts
        return config.negative_prompts[0]

    def _decode_latents(self, latents_generator, clear_cache: bool) -> 'Image':
        """Drain the denoising generator, decode the final latents, and wrap
        the decoded array as an Image. Raises RuntimeError if the generator
        yields nothing."""
        final_latents = None
        for latents in latents_generator:
            final_latents = latents
            mx.eval(final_latents)

        if final_latents is None:
            raise RuntimeError("No latents generated")

        decoded_image = self.model.decode(final_latents)
        mx.eval(decoded_image)
        image_array = np.array(decoded_image.squeeze(0))

        if clear_cache:
            mx.clear_cache()
        return Image.from_numpy(image_array)

    def txt2img(self, prompt: str, config: 'ImageGenerationConfig', clear_cache: bool = True) -> 'Image':
        """Generate one image from a text prompt.

        Args:
            prompt: Positive text prompt.
            config: Generation settings (steps, guidance, seed, ...).
            clear_cache: Free MLX's memory cache after decoding.

        Raises:
            RuntimeError: If the model cannot be loaded or no latents are produced.
        """
        model = self._ensure_model()
        sampler_config = config.sampler_config
        latents_generator = model.generate_latents(
            prompt,
            n_images=1,
            num_steps=sampler_config.steps,
            cfg_weight=sampler_config.guidance_scale,
            negative_text=self._negative_prompt(config),
            # Negative seed means "no fixed seed".
            seed=sampler_config.seed if sampler_config.seed >= 0 else None
        )
        return self._decode_latents(latents_generator, clear_cache)

    def img2img(self, init_image: 'Image', prompt: str, config: 'ImageGenerationConfig', clear_cache: bool = True) -> 'Image':
        """Generate one image conditioned on an initial image and a prompt.

        Args:
            init_image: Starting image; resized/normalized for the model.
            prompt: Positive text prompt.
            config: Generation settings; config.strength controls how far
                the result may drift from init_image.
            clear_cache: Free MLX's memory cache after decoding.

        Raises:
            RuntimeError: If the model cannot be loaded or no latents are produced.
        """
        model = self._ensure_model()
        sampler_config = config.sampler_config
        img_tensor = _prepare_image_for_sd(
            init_image, config.width, config.height)
        latents_generator = model.generate_latents_from_image(
            img_tensor,
            prompt,
            n_images=1,
            strength=config.strength,
            num_steps=sampler_config.steps,
            cfg_weight=sampler_config.guidance_scale,
            negative_text=self._negative_prompt(config),
            seed=sampler_config.seed if sampler_config.seed >= 0 else None
        )
        return self._decode_latents(latents_generator, clear_cache)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import os
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from ml import ImageGenCreateInput, ImageGenerationConfig, ImageGenImg2ImgInput, ImageGenTxt2ImgInput, ImageGenOutput
|
|
6
|
+
from profiling import ProfilingMixin, StopReason
|
|
7
|
+
|
|
8
|
+
from .generate_sd import ImageGen as SDImageGen, Image, ImageGenerationConfig as SDImageGenerationConfig, ImageSamplerConfig
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ImageGen(ProfilingMixin):
    """Bridge between the ml-layer image-generation API and the MLX SD backend."""

    # Backend generator handle; None after destroy().
    sd_gen: 'Optional[SDImageGen]' = None

    def __init__(self, input: 'ImageGenCreateInput'):
        """Create the backend generator for the model at input.model_path."""
        self.sd_gen = SDImageGen(model_path=input.model_path)

    def destroy(self) -> None:
        """Release the backend generator so its model can be freed."""
        self.sd_gen = None

    @staticmethod
    def _build_sd_config(prompt, config, init_image=None) -> 'SDImageGenerationConfig':
        """Translate an ml-layer config into the backend's config type.

        Raises AssertionError when height/width are not multiples of 16
        (required by the SD latent space).
        """
        height = config.height
        width = config.width
        assert height % 16 == 0, f"Height must be divisible by 16 ({height}/16={height/16})"
        assert width % 16 == 0, f"Width must be divisible by 16 ({width}/16={width/16})"

        return SDImageGenerationConfig(
            prompts=prompt,
            negative_prompts=config.negative_prompts,
            height=height,
            width=width,
            sampler_config=ImageSamplerConfig(
                steps=config.sampler_config.steps,
                guidance_scale=config.sampler_config.guidance_scale,
                seed=config.sampler_config.seed
            ),
            init_image=init_image,
            strength=config.strength
        )

    @staticmethod
    def _write_output(result_image: 'Image', output_path: str) -> 'ImageGenOutput':
        """Persist the generated image, creating parent directories as needed."""
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        result_image.to_pil().save(output_path)
        return ImageGenOutput(output_image_path=output_path)

    def txt2img(self, input: 'ImageGenTxt2ImgInput') -> 'ImageGenOutput':
        """Generate an image from a text prompt and save it to input.output_path."""
        internal_config = self._build_sd_config(input.prompt, input.config)
        result_image = self.sd_gen.txt2img(input.prompt, internal_config)
        return self._write_output(result_image, input.output_path)

    def img2img(self, input: 'ImageGenImg2ImgInput') -> 'ImageGenOutput':
        """Generate an image from an initial image plus a prompt and save it."""
        # Bug fix: init_image_path is a filesystem path, but Image.from_pil
        # expects an already-opened PIL image (np.array over a str fails in
        # astype). Open the file first; local import because this module does
        # not import PIL at top level.
        from PIL import Image as PILImage
        with PILImage.open(input.init_image_path) as pil_img:
            init_image = Image.from_pil(pil_img)

        internal_config = self._build_sd_config(
            input.prompt, input.config, init_image=init_image)
        result_image = self.sd_gen.img2img(
            init_image, input.prompt, internal_config)
        return self._write_output(result_image, input.output_path)
|
|
@@ -0,0 +1,281 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
"""
|
|
3
|
+
Command line interface for text-to-image generation using MLX backend.
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import argparse
|
|
7
|
+
import sys
|
|
8
|
+
import os
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
# Add the parent directory to the path to import the interface
|
|
13
|
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
14
|
+
|
|
15
|
+
from interface import ImageGen, ImageSamplerConfig
|
|
16
|
+
from ml import (
|
|
17
|
+
ImageGenCreateInput,
|
|
18
|
+
ImageGenTxt2ImgInput,
|
|
19
|
+
ImageGenerationConfig,
|
|
20
|
+
ImageSamplerConfig as MLImageSamplerConfig,
|
|
21
|
+
SchedulerConfig,
|
|
22
|
+
ModelConfig
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def create_default_config() -> 'ImageGenerationConfig':
    """Build the baseline generation config used by the CLI.

    512x512 output, DDIM sampling in SDXL-Turbo style (4 steps, guidance
    disabled), and a standard scaled-linear DDIM scheduler. The prompt list
    holds a placeholder that the caller replaces with the user's prompt.
    """
    default_sampler = MLImageSamplerConfig(
        method="ddim",
        steps=4,             # few-step sampling, as SDXL Turbo expects
        guidance_scale=0.0,  # guidance disabled for SDXL Turbo
        eta=0.0,
        seed=-1,
    )

    default_scheduler = SchedulerConfig(
        type="ddim",
        num_train_timesteps=1000,
        steps_offset=0,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        prediction_type="epsilon",
        timestep_type="discrete",
        timestep_spacing="linspace",
        interpolation_type="linear",
    )

    return ImageGenerationConfig(
        prompts=[""],  # placeholder; overwritten with the CLI prompt
        sampler_config=default_sampler,
        scheduler_config=default_scheduler,
        strength=1.0,
        negative_prompts=None,
        height=512,
        width=512,
    )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def parse_arguments():
    """Parse command line arguments for text-to-image generation.

    Returns:
        argparse.Namespace: parsed options (prompt, output, model_path,
        width, height, steps, guidance_scale, seed, negative_prompt,
        device_id, verbose).
    """
    parser = argparse.ArgumentParser(
        description="Generate images from text prompts using MLX backend",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python main.py "a beautiful sunset over mountains"
  python main.py "a cat sitting on a chair" --output output.png --width 1024 --height 1024
  python main.py "a futuristic city" --model-path ./models/sdxl-turbo --steps 8 --seed 42
        """
    )

    # Required arguments
    parser.add_argument(
        "prompt",
        help="Text prompt for image generation"
    )

    # Optional arguments
    parser.add_argument(
        "--output", "-o",
        type=str,
        help="Output image path (default: generated_image.png)"
    )

    parser.add_argument(
        "--model-path", "-m",
        type=str,
        default="stabilityai/sdxl-turbo",
        help="Path to the model or HuggingFace model name (default: stabilityai/sdxl-turbo)"
    )

    parser.add_argument(
        "--width", "-w",
        type=int,
        default=512,
        help="Image width (must be divisible by 16, default: 512)"
    )

    # BUG FIX: the short flag was "-h", which collides with argparse's
    # automatically added -h/--help option and raises ArgumentError at
    # parser construction time (the CLI crashed before parsing anything).
    # Use uppercase "-H" for height instead.
    parser.add_argument(
        "--height", "-H",
        type=int,
        default=512,
        help="Image height (must be divisible by 16, default: 512)"
    )

    parser.add_argument(
        "--steps", "-s",
        type=int,
        default=4,
        help="Number of denoising steps (default: 4 for SDXL Turbo)"
    )

    parser.add_argument(
        "--guidance-scale", "-g",
        type=float,
        default=0.0,
        help="Guidance scale (default: 0.0 for SDXL Turbo)"
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=-1,
        help="Random seed (-1 for random, default: -1)"
    )

    parser.add_argument(
        "--negative-prompt", "-n",
        type=str,
        help="Negative prompt to avoid certain elements"
    )

    parser.add_argument(
        "--device-id",
        type=str,
        help="Device ID to use (default: auto-detect)"
    )

    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Enable verbose logging"
    )

    return parser.parse_args()
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def validate_arguments(args):
    """Validate parsed command line options.

    Raises:
        ValueError: if width/height are not multiples of 16, steps is
            not positive, or the guidance scale is negative.
    """
    # The diffusion pipeline requires dimensions divisible by 16.
    for label, value in (("Width", args.width), ("Height", args.height)):
        if value % 16 != 0:
            raise ValueError(f"{label} must be divisible by 16, got {value}")

    # At least one denoising step is needed to produce an image.
    if not args.steps > 0:
        raise ValueError(f"Steps must be positive, got {args.steps}")

    # A negative guidance scale is meaningless.
    if args.guidance_scale < 0:
        raise ValueError(f"Guidance scale must be non-negative, got {args.guidance_scale}")
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def main():
    """Command line entry point: parse arguments, generate one image, save it.

    Exits with status 1 on user cancellation (Ctrl-C) or any error.
    """
    args = None  # pre-bind so the except handler can safely read verbosity
    try:
        # Parse and validate arguments
        args = parse_arguments()
        validate_arguments(args)

        # Resolve the output path and make sure its directory exists.
        if args.output:
            output_path = Path(args.output)
        else:
            output_path = Path("generated_image.png")
        output_path.parent.mkdir(parents=True, exist_ok=True)

        if args.verbose:
            print("Initializing image generation...")
            print(f"Model: {args.model_path}")
            print(f"Prompt: {args.prompt}")
            print(f"Output: {output_path}")
            print(f"Dimensions: {args.width}x{args.height}")
            print(f"Steps: {args.steps}")
            print(f"Guidance scale: {args.guidance_scale}")
            print(f"Seed: {args.seed}")
            if args.negative_prompt:
                print(f"Negative prompt: {args.negative_prompt}")

        # Static model metadata passed to the plugin.
        model_config = ModelConfig(
            name="sdxl-turbo",
            version="1.0",
            description="SDXL Turbo model for fast image generation"
        )

        # Instantiate the generator against the MLX plugin.
        create_input = ImageGenCreateInput(
            model_name="sdxl-turbo",
            model_path=args.model_path,
            config=model_config,
            scheduler_config_path="",  # Not used for SDXL Turbo
            plugin_id="mlx",
            device_id=args.device_id
        )

        image_gen = ImageGen(create_input)

        # Sampler reflects the CLI options; defaults suit SDXL Turbo.
        sampler_config = MLImageSamplerConfig(
            method="ddim",
            steps=args.steps,
            guidance_scale=args.guidance_scale,
            eta=0.0,
            seed=args.seed
        )

        # Standard scaled-linear DDIM schedule with epsilon prediction.
        scheduler_config = SchedulerConfig(
            type="ddim",
            num_train_timesteps=1000,
            steps_offset=0,
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            prediction_type="epsilon",
            timestep_type="discrete",
            timestep_spacing="linspace",
            interpolation_type="linear"
        )

        generation_config = ImageGenerationConfig(
            prompts=[args.prompt],
            sampler_config=sampler_config,
            scheduler_config=scheduler_config,
            strength=1.0,
            negative_prompts=[args.negative_prompt] if args.negative_prompt else None,
            height=args.height,
            width=args.width
        )

        txt2img_input = ImageGenTxt2ImgInput(
            prompt=args.prompt,
            config=generation_config,
            output_path=str(output_path)
        )

        if args.verbose:
            print("Generating image...")

        # Generate image
        result = image_gen.txt2img(txt2img_input)

        if args.verbose:
            print("Image generated successfully!")
            print(f"Saved to: {result.output_image_path}")
        else:
            print(f"Image saved to: {result.output_image_path}")

        # Release model resources.
        image_gen.close()

    except KeyboardInterrupt:
        print("\nGeneration cancelled by user.")
        sys.exit(1)
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        # BUG FIX: the original read args.verbose here even when
        # parse_arguments() itself raised, so `args` was unbound and a
        # NameError masked the real exception. getattr on the pre-bound
        # `args = None` keeps the handler safe in every failure path.
        if getattr(args, "verbose", False):
            import traceback
            traceback.print_exc()
        sys.exit(1)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|