nexaai 1.0.16rc13__cp310-cp310-macosx_15_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai may be problematic — click here for more details.
- nexaai/__init__.py +83 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +64 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +92 -0
- nexaai/asr_impl/pybind_asr_impl.py +44 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +4 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/ml.py +888 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
- nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
- nexaai/binds/nexa_nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexa_nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexa_nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexa_nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexa_nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +104 -0
- nexaai/cv.py +92 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +89 -0
- nexaai/cv_impl/pybind_cv_impl.py +32 -0
- nexaai/embedder.py +72 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +116 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +95 -0
- nexaai/image_gen.py +140 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +96 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +269 -0
- nexaai/llm_impl/pybind_llm_impl.py +218 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +151 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +261 -0
- nexaai/mlx_backend/vlm/interface.py +415 -0
- nexaai/mlx_backend/vlm/main.py +316 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +55 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +92 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +43 -0
- nexaai/runtime.py +68 -0
- nexaai/tts.py +74 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/avatar_fetcher.py +104 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +324 -0
- nexaai/utils/model_manager.py +1353 -0
- nexaai/utils/model_types.py +47 -0
- nexaai/utils/progress_tracker.py +385 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +128 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +258 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +230 -0
- nexaai-1.0.16rc13.dist-info/METADATA +32 -0
- nexaai-1.0.16rc13.dist-info/RECORD +557 -0
- nexaai-1.0.16rc13.dist-info/WHEEL +5 -0
- nexaai-1.0.16rc13.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,362 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from typing import (
|
|
5
|
+
Any,
|
|
6
|
+
Callable,
|
|
7
|
+
List,
|
|
8
|
+
Optional,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
import mlx.core as mx
|
|
12
|
+
import numpy as np
|
|
13
|
+
from PIL import Image as PILImage
|
|
14
|
+
import mlx.nn as nn
|
|
15
|
+
import os
|
|
16
|
+
|
|
17
|
+
from .modeling import StableDiffusion, StableDiffusionXL
|
|
18
|
+
|
|
19
|
+
# --------------------------------------------------------------------------------------
# Core aliases & callback protocols
# --------------------------------------------------------------------------------------

# Filesystem path or HuggingFace repo id, kept as a plain str alias for API clarity.
Path = str
# Callback invoked with a single log-message string (see set_log / log below).
LogCallback = Callable[[str], None]
|
|
25
|
+
|
|
26
|
+
# --------------------------------------------------------------------------------------
|
|
27
|
+
# Core module functions
|
|
28
|
+
# --------------------------------------------------------------------------------------
|
|
29
|
+
|
|
30
|
+
def init() -> None:
    """Initialize the stable diffusion module (placeholder; nothing to set up)."""
    return None
|
|
33
|
+
|
|
34
|
+
def deinit() -> None:
    """Deinitialize the stable diffusion module (placeholder; nothing to tear down)."""
    return None
|
|
37
|
+
|
|
38
|
+
def set_log(callback: LogCallback) -> None:
    """Register a logging callback (placeholder; the callback is not stored)."""
    return None
|
|
41
|
+
|
|
42
|
+
def log(message: str) -> None:
    """Log a message by printing it to stdout."""
    print(message)
|
|
45
|
+
|
|
46
|
+
# --------------------------------------------------------------------------------------
|
|
47
|
+
# Basic data structures
|
|
48
|
+
# --------------------------------------------------------------------------------------
|
|
49
|
+
|
|
50
|
+
class Image:
    """Raw image container: a flat list of float pixels plus its geometry."""

    def __init__(self, data: List[float], width: int, height: int, channels: int) -> None:
        """Store flat pixel data together with width, height and channel count."""
        self.data = data
        self.width = width
        self.height = height
        self.channels = channels

    @classmethod
    def from_numpy(cls, array: np.ndarray) -> 'Image':
        """Build an Image from an (H, W, C) numpy array."""
        rows, cols, depth = array.shape
        flat = array.flatten().tolist()
        return cls(flat, cols, rows, depth)

    @classmethod
    def from_pil(cls, pil_image: PILImage.Image) -> 'Image':
        """Build an Image from a PIL image, rescaling byte pixels to [0, 1]."""
        normalized = np.array(pil_image).astype(np.float32) / 255.0
        return cls.from_numpy(normalized)

    def to_numpy(self) -> np.ndarray:
        """Return the pixels as an (H, W, C) numpy array."""
        shape = (self.height, self.width, self.channels)
        return np.array(self.data).reshape(shape)

    def to_pil(self) -> PILImage.Image:
        """Return the image as a PIL image (values rescaled back to 0-255)."""
        byte_pixels = (self.to_numpy() * 255).astype(np.uint8)
        return PILImage.fromarray(byte_pixels)
|
|
79
|
+
|
|
80
|
+
class ImageSamplerConfig:
    """Sampler settings for diffusion image generation."""

    def __init__(
        self,
        method: str = "ddim",
        steps: int = 20,
        guidance_scale: float = 7.5,
        eta: float = 0.0,
        seed: int = -1,
    ) -> None:
        """Record sampler parameters.

        A negative seed is treated as "unset" by the generation routines
        (they pass None downstream in that case).
        """
        self.method = method
        self.steps = steps
        self.guidance_scale = guidance_scale
        self.eta = eta
        self.seed = seed
|
|
95
|
+
|
|
96
|
+
class ImageGenerationConfig:
    """Full configuration for a txt2img / img2img generation run."""

    def __init__(
        self,
        prompts: str | List[str],
        negative_prompts: str | List[str] | None = None,
        height: int = 512,
        width: int = 512,
        sampler_config: Optional[ImageSamplerConfig] = None,
        lora_id: int = -1,  # Not used but kept for compatibility
        init_image: Optional[Image] = None,
        strength: float = 1.0,
        n_images: int = 1,
        n_rows: int = 1,
        decoding_batch_size: int = 1,
    ) -> None:
        """Capture generation settings, filling defaults where omitted."""
        self.prompts = prompts
        # Normalize a missing/empty negative prompt to "" so downstream
        # isinstance-based extraction stays simple.
        self.negative_prompts = negative_prompts if negative_prompts else ""
        self.height = height
        self.width = width
        # Fall back to the default DDIM sampler configuration when none given.
        self.sampler_config = sampler_config if sampler_config else ImageSamplerConfig()
        self.lora_id = lora_id
        self.init_image = init_image
        self.strength = strength
        self.n_images = n_images
        self.n_rows = n_rows
        self.decoding_batch_size = decoding_batch_size
|
|
123
|
+
|
|
124
|
+
# --------------------------------------------------------------------------------------
|
|
125
|
+
# Helper functions - following txt2img.py pattern
|
|
126
|
+
# --------------------------------------------------------------------------------------
|
|
127
|
+
|
|
128
|
+
def load_model(model_path: Path, float16: bool = True, quantize: bool = False) -> StableDiffusion:
    """Load a stable-diffusion model from a local path or HuggingFace repo id.

    Args:
        model_path: Local directory/file path or a HuggingFace repo name.
        float16: Load weights in float16 when True.
        quantize: Quantize the text encoder(s) and unet after loading.

    Returns:
        A StableDiffusion or StableDiffusionXL instance.
    """
    # The model family is inferred from the name: "xl"/"turbo" selects SDXL.
    # NOTE: the original code first branched on local-path vs repo-name, but
    # both branches were identical, so that dead distinction is removed here.
    is_xl = "xl" in model_path.lower() or "turbo" in model_path.lower()
    if is_xl:
        model = StableDiffusionXL(model_path, float16=float16)
    else:
        model = StableDiffusion(model_path, float16=float16)

    # Apply quantization if requested - same settings as txt2img.py.
    if quantize:
        # Only the Linear layers of the text encoder(s) are quantized.
        if is_xl:
            nn.quantize(
                model.text_encoder_1, class_predicate=lambda _, m: isinstance(m, nn.Linear)
            )
            nn.quantize(
                model.text_encoder_2, class_predicate=lambda _, m: isinstance(m, nn.Linear)
            )
        else:
            nn.quantize(
                model.text_encoder, class_predicate=lambda _, m: isinstance(m, nn.Linear)
            )
        # The unet is quantized to 8 bits with group size 32.
        nn.quantize(model.unet, group_size=32, bits=8)

    return model
|
|
164
|
+
|
|
165
|
+
def _prepare_image_for_sd(image: Image, target_width: int, target_height: int) -> mx.array:
    """Resize *image* and return it as an mx.array with values in [0, 1]."""
    # Resize through PIL using high-quality Lanczos resampling.
    resized = image.to_pil().resize((target_width, target_height), PILImage.LANCZOS)
    # Keep only the RGB channels and normalize byte values to [0, 1]
    # (same layout txt2img.py expects).
    rgb = np.array(resized).astype(np.float32)[:, :, :3]
    return mx.array(rgb / 255.0)
|
|
176
|
+
|
|
177
|
+
# --------------------------------------------------------------------------------------
|
|
178
|
+
# Image generation
|
|
179
|
+
# --------------------------------------------------------------------------------------
|
|
180
|
+
|
|
181
|
+
class ImageGen:
    """High-level text-to-image / image-to-image generation wrapper.

    The underlying StableDiffusion model is loaded lazily on first use.
    """

    def __init__(
        self,
        model_path: Path,
        scheduler_config_path: Path = "",  # Make optional
        device: Optional[str] = None,
        float16: bool = True,
        quantize: bool = False,
    ) -> None:
        """Initialize the image generation model (not loaded until needed)."""
        self.model_path = model_path
        self.scheduler_config_path = scheduler_config_path  # Stored for compatibility only
        self.float16 = float16
        self.quantize = quantize
        self.model = None

    def destroy(self) -> None:
        """Release the loaded model."""
        self.model = None

    def load_model(self, model_path: Path, extra_data: Any = None) -> bool:
        """Load the model from a directory (or from a file's parent directory).

        Returns:
            True on success; False on failure (the error is logged, not raised).
        """
        try:
            if os.path.isfile(model_path):
                # A file path points inside the model directory; use the directory.
                model_path = os.path.dirname(model_path)

            self.model_path = model_path
            # Delegates to the module-level load_model helper.
            self.model = load_model(model_path, self.float16, self.quantize)
            self.model.ensure_models_are_loaded()
            return True
        except Exception as e:
            log(f"Failed to load model: {e}")
            return False

    def close(self) -> None:
        """Close the model and release resources."""
        self.destroy()

    def set_scheduler(self, config: Any) -> None:
        """Set scheduler configuration (placeholder for compatibility)."""
        log("Warning: set_scheduler not implemented")

    def set_sampler(self, config: ImageSamplerConfig) -> None:
        """Set sampler configuration (placeholder for compatibility)."""
        log("Warning: set_sampler not implemented")

    def reset_sampler(self) -> None:
        """Reset sampler configuration (placeholder for compatibility)."""
        log("Warning: reset_sampler not implemented")

    def set_lora(self, lora_id: int) -> None:
        """Set LoRA (placeholder for compatibility)."""
        log("Warning: LoRA management not implemented")

    def add_lora(self, lora_path: Path) -> int:
        """Add LoRA (placeholder for compatibility); always returns -1."""
        log("Warning: LoRA management not implemented")
        return -1

    def remove_lora(self, lora_id: int) -> None:
        """Remove LoRA (placeholder for compatibility)."""
        log("Warning: LoRA management not implemented")

    def list_loras(self) -> List[int]:
        """List LoRAs (placeholder for compatibility); always returns []."""
        log("Warning: LoRA management not implemented")
        return []

    # -- private helpers -----------------------------------------------------

    def _ensure_loaded(self) -> None:
        """Raise RuntimeError if the model is neither loaded nor loadable."""
        if not self.model and not self.load_model(self.model_path):
            raise RuntimeError("Model not loaded")

    @staticmethod
    def _extract_negative_prompt(config: ImageGenerationConfig) -> str:
        """Return the (first) negative prompt from config, or '' when unset."""
        if not config.negative_prompts:
            return ""
        if isinstance(config.negative_prompts, str):
            return config.negative_prompts
        return config.negative_prompts[0]

    def _finalize_latents(self, latents_generator, clear_cache: bool) -> Image:
        """Drain the latent generator, decode the final latents, return an Image.

        Shared tail of txt2img and img2img (previously duplicated verbatim).
        """
        final_latents = None
        for latents in latents_generator:
            final_latents = latents
            mx.eval(final_latents)

        if final_latents is None:
            raise RuntimeError("No latents generated")

        decoded_image = self.model.decode(final_latents)
        mx.eval(decoded_image)

        # Drop the leading batch dimension before converting to (H, W, C).
        image_array = np.array(decoded_image.squeeze(0))

        if clear_cache:
            mx.clear_cache()

        return Image.from_numpy(image_array)

    # -- generation entry points ---------------------------------------------

    def txt2img(self, prompt: str, config: ImageGenerationConfig, clear_cache: bool = True) -> Image:
        """Generate an image from a text prompt - following txt2img.py pattern."""
        self._ensure_loaded()

        sampler_config = config.sampler_config
        negative_prompt = self._extract_negative_prompt(config)

        try:
            latents_generator = self.model.generate_latents(
                prompt,
                n_images=1,
                num_steps=sampler_config.steps,
                cfg_weight=sampler_config.guidance_scale,
                negative_text=negative_prompt,
                # A negative seed means "unset": let the model pick one.
                seed=sampler_config.seed if sampler_config.seed >= 0 else None,
            )
            return self._finalize_latents(latents_generator, clear_cache)
        except Exception as e:
            log(f"Generation failed: {e}")
            raise  # bare raise preserves the original traceback

    def img2img(self, init_image: Image, prompt: str, config: ImageGenerationConfig, clear_cache: bool = True) -> Image:
        """Generate an image from an initial image and a text prompt."""
        self._ensure_loaded()

        sampler_config = config.sampler_config
        negative_prompt = self._extract_negative_prompt(config)

        try:
            # Resize/normalize the conditioning image for the model.
            img_tensor = _prepare_image_for_sd(init_image, config.width, config.height)

            latents_generator = self.model.generate_latents_from_image(
                img_tensor,
                prompt,
                n_images=1,
                strength=config.strength,
                num_steps=sampler_config.steps,
                cfg_weight=sampler_config.guidance_scale,
                negative_text=negative_prompt,
                seed=sampler_config.seed if sampler_config.seed >= 0 else None,
            )
            return self._finalize_latents(latents_generator, clear_cache)
        except Exception as e:
            log(f"Generation failed: {e}")
            raise  # bare raise preserves the original traceback

    def generate(self, config: ImageGenerationConfig) -> Image:
        """Generate an image from configuration (img2img when init_image is set)."""
        prompt = config.prompts if isinstance(config.prompts, str) else config.prompts[0]
        if config.init_image:
            return self.img2img(config.init_image, prompt, config)
        return self.txt2img(prompt, config)
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
from interface import ImageGen, ImageGenerationConfig, ImageSamplerConfig, Image
|
|
2
|
+
import numpy as np
|
|
3
|
+
from PIL import Image as PILImage
|
|
4
|
+
import mlx.core as mx
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def test_txt2image(
    prompt="A photo of an astronaut riding a horse on Mars.",
    model="sdxl",
    local_model_path="",
    n_images=1,
    steps=None,
    cfg=None,
    negative_prompt="",
    n_rows=1,
    decoding_batch_size=1,
    float16=True,
    quantize=False,
    preload_models=False,
    output="out_txt2img.png",
    seed=None,
    verbose=False,
    width=512,
    height=512,
):
    """Generate images from a text prompt using the high-level interface.

    Returns:
        The output path on success, or None if the model failed to load.
    """
    # Determine model path and per-family defaults based on model type.
    if model == "sdxl":
        model_path = local_model_path or "stabilityai/sdxl-turbo"
        default_cfg = 0.0
        default_steps = 2
    else:
        model_path = local_model_path or "stabilityai/stable-diffusion-2-1-base"
        default_cfg = 7.5
        default_steps = 50

    # BUGFIX: compare against None instead of truthiness so an explicit
    # cfg=0.0 or steps=0 is not silently replaced by the default.
    cfg = default_cfg if cfg is None else cfg
    steps = default_steps if steps is None else steps

    # Create ImageGen instance with proper parameters
    image_gen = ImageGen(model_path, "", device=None, float16=float16, quantize=quantize)

    # Load the model
    if not image_gen.load_model(model_path):
        print(f"Failed to load model: {model_path}")
        return None

    # Create sampler configuration (seed < 0 means "unset").
    sampler_config = ImageSamplerConfig(
        method="ddim",
        steps=steps,
        guidance_scale=cfg,
        seed=seed if seed is not None else -1,
    )

    # Create generation configuration with all parameters
    gen_config = ImageGenerationConfig(
        prompts=prompt,
        negative_prompts=negative_prompt,
        height=height,
        width=width,
        sampler_config=sampler_config,
        n_images=n_images,
        n_rows=n_rows,
        decoding_batch_size=decoding_batch_size,
    )

    if verbose:
        print(f"Generating {n_images} image(s) with prompt: '{prompt}'")
        print(f"Model: {model_path}, Steps: {steps}, CFG: {cfg}")
        print(f"Float16: {float16}, Quantize: {quantize}")

    # Generate image using txt2img
    result_image = image_gen.txt2img(prompt, gen_config)

    # Free memory by deleting heavyweight model components before measuring
    # peak usage (following main_duplicate.py pattern).
    if image_gen.model:
        if model == "sdxl":
            if hasattr(image_gen.model, "text_encoder_1"):
                del image_gen.model.text_encoder_1
            if hasattr(image_gen.model, "text_encoder_2"):
                del image_gen.model.text_encoder_2
        else:
            if hasattr(image_gen.model, "text_encoder"):
                del image_gen.model.text_encoder

        if hasattr(image_gen.model, "unet"):
            del image_gen.model.unet
        if hasattr(image_gen.model, "sampler"):
            del image_gen.model.sampler

    # Peak memory after the unet has run (GB).
    peak_mem_unet = mx.metal.get_peak_memory() / 1024**3

    # Convert to PIL and save (pixels are [0,1] floats -> 0-255 bytes).
    image_np = result_image.to_numpy()
    image_pil = PILImage.fromarray((image_np * 255).astype(np.uint8))
    image_pil.save(output)

    print(f"Text-to-image output saved to: {output}")

    # Peak memory including decode/save (GB).
    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3

    if verbose:
        print(f"Peak memory used for unet: {peak_mem_unet:.3f}GB")
        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")

    # Clean up
    image_gen.close()

    return output
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def test_image2image(
    prompt="A lit fireplace",
    model="sdxl",
    strength=0.5,
    local_model_path="",
    n_images=1,
    steps=None,
    cfg=None,
    negative_prompt="",
    n_rows=1,
    decoding_batch_size=1,
    quantize=False,
    float16=True,
    preload_models=False,
    init_image_path="out_txt2img.png",
    output="out_img2img.png",
    verbose=False,
    seed=None,
    width=256,
    height=256,
):
    """Generate images from an init image plus text prompt using the high-level interface.

    Returns:
        The output path on success, or None on load/image errors.
    """
    # Determine model path and per-family defaults based on model type.
    if model == "sdxl":
        model_path = local_model_path or "stabilityai/sdxl-turbo"
        default_cfg = 0.0
        default_steps = 2
    else:
        model_path = local_model_path or "stabilityai/stable-diffusion-2-1-base"
        default_cfg = 7.5
        default_steps = 50

    # BUGFIX: compare against None instead of truthiness so an explicit
    # cfg=0.0 or steps=0 is not silently replaced by the default.
    cfg = default_cfg if cfg is None else cfg
    steps = default_steps if steps is None else steps

    # Load and normalize the conditioning image.
    try:
        pil_img = PILImage.open(init_image_path)
        # Ensure RGB format
        if pil_img.mode != "RGB":
            pil_img = pil_img.convert("RGB")

        # Convert to numpy array and then to our Image class
        img_np = np.array(pil_img).astype(np.float32) / 255.0  # Normalize to [0,1]
        init_image = Image.from_numpy(img_np)

    except FileNotFoundError:
        print(f"Error: Image file '{init_image_path}' not found.")
        return None
    except Exception as e:
        print(f"Error loading image: {e}")
        return None

    # Create ImageGen instance
    image_gen = ImageGen(model_path, "", device=None)

    # Load the model
    if not image_gen.load_model(model_path):
        print(f"Failed to load model: {model_path}")
        return None

    # Create sampler configuration (seed < 0 means "unset").
    sampler_config = ImageSamplerConfig(
        method="ddim",
        steps=steps,
        guidance_scale=cfg,
        seed=seed if seed is not None else -1,
    )

    # Create generation configuration
    gen_config = ImageGenerationConfig(
        prompts=prompt,
        negative_prompts=negative_prompt,
        height=height,
        width=width,
        sampler_config=sampler_config,
        init_image=init_image,
        strength=strength,
    )

    if verbose:
        print(f"Generating image with prompt: '{prompt}' and strength: {strength}")
        print(f"Model: {model_path}, Steps: {steps}, CFG: {cfg}")

    # Generate image using img2img
    result_image = image_gen.img2img(init_image, prompt, gen_config)

    # Free memory by deleting heavyweight model components before measuring
    # peak usage (following main_duplicate.py pattern).
    if image_gen.model:
        if model == "sdxl":
            if hasattr(image_gen.model, "text_encoder_1"):
                del image_gen.model.text_encoder_1
            if hasattr(image_gen.model, "text_encoder_2"):
                del image_gen.model.text_encoder_2
        else:
            if hasattr(image_gen.model, "text_encoder"):
                del image_gen.model.text_encoder

        if hasattr(image_gen.model, "unet"):
            del image_gen.model.unet
        if hasattr(image_gen.model, "sampler"):
            del image_gen.model.sampler

    # Peak memory after the unet has run (GB).
    peak_mem_unet = mx.metal.get_peak_memory() / 1024**3

    # Convert to PIL and save (pixels are [0,1] floats -> 0-255 bytes).
    image_np = result_image.to_numpy()
    image_pil = PILImage.fromarray((image_np * 255).astype(np.uint8))
    image_pil.save(output)

    print(f"Image-to-image output saved to: {output}")

    # Peak memory including decode/save (GB).
    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3

    if verbose:
        print(f"Peak memory used for unet: {peak_mem_unet:.3f}GB")
        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")

    # Clean up
    image_gen.close()

    return output
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
if __name__ == "__main__":
    # Stage 1 parameters: text-to-image.
    txt2img_params = dict(
        prompt="A photo of an astronaut riding a horse on Mars.",
        model="sdxl",
        n_images=1,
        n_rows=1,
        output="out_txt2img.png",
        verbose=True,
        width=256,
        height=256,
    )

    # Stage 2 parameters: image-to-image, seeded by stage 1's output.
    img2img_params = dict(
        prompt="A lit fireplace",
        model="sdxl",
        strength=0.5,
        n_images=1,
        n_rows=1,
        init_image_path="out_txt2img.png",
        output="out_img2img.png",
        verbose=True,
        width=512,
        height=512,
    )

    print("Running text-to-image generation...")
    generated_image = test_txt2image(**txt2img_params)

    if generated_image:
        print(f"\nRunning image-to-image generation using: {generated_image}")
        img2img_params["init_image_path"] = generated_image
        test_image2image(**img2img_params)

        print(f"\nPipeline complete!")
        print(f"Text-to-image result: {txt2img_params['output']}")
        print(f"Image-to-image result: {img2img_params['output']}")
    else:
        print("Failed to generate initial image, skipping img2img test")