nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nexaai/__init__.py +99 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +93 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +7 -0
- nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
- nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
- nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
- nexaai/binds/cpu_gpu/libggml.dylib +0 -0
- nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
- nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/metal/libnexa_plugin.dylib +0 -0
- nexaai/binds/metal/py-lib/ml.py +888 -0
- nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/metal/py-lib/profiling.py +239 -0
- nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
- nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
- nexaai/binds/nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexaml/libggml.dylib +0 -0
- nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
- nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +106 -0
- nexaai/cv.py +95 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +91 -0
- nexaai/cv_impl/pybind_cv_impl.py +124 -0
- nexaai/diarize.py +80 -0
- nexaai/diarize_impl/__init__.py +1 -0
- nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +238 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +162 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +559 -0
- nexaai/mlx_backend/vlm/main.py +365 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1745 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +389 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
- nexaai-1.0.29.dist-info/METADATA +35 -0
- nexaai-1.0.29.dist-info/RECORD +580 -0
- nexaai-1.0.29.dist-info/WHEEL +5 -0
- nexaai-1.0.29.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,979 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import List, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import mlx.core as mx
|
|
5
|
+
import mlx.nn as nn
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from mlx_audio.utils import istft, stft
|
|
9
|
+
|
|
10
|
+
from ..base import check_array_shape
|
|
11
|
+
from ..interpolate import interpolate
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def get_padding(kernel_size: int, dilation: int = 1) -> int:
    """Return the "same" padding for a 1-D convolution.

    For an odd kernel, padding ``(kernel_size - 1) * dilation // 2`` keeps the
    output length equal to the input length at stride 1.

    Args:
        kernel_size: Convolution kernel size.
        dilation: Dilation factor (default 1).

    Returns:
        Padding (number of samples) to apply on each side.
    """
    # Exact integer arithmetic; the original `int((k*d - d) / 2)` went through
    # float division, which can lose precision for large operands.
    return (kernel_size - 1) * dilation // 2
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def compute_norm(
    x: mx.array,
    p: int,
    dim: Optional[Union[int, List[int]]] = None,
    keepdim: bool = False,
) -> mx.array:
    """Compute the L1 or L2 norm of ``x`` over the given axes.

    Args:
        x: Input array.
        p: Norm order; only 1 (sum of absolute values) and 2 (Euclidean)
            are supported.
        dim: Axis or axes to reduce over; ``None`` reduces over all axes.
        keepdim: If True, the reduced axes are kept with size 1.

    Returns:
        Array holding the requested norm.

    Raises:
        ValueError: If ``p`` is neither 1 nor 2.
    """
    if p not in (1, 2):
        raise ValueError("Only p-norms with p of 1 or 2 are supported")

    # Normalize `dim` into a tuple/list of axes.
    if dim is None:
        axes = tuple(range(x.ndim))
    elif isinstance(dim, int):
        axes = (dim,)
    else:
        axes = dim

    if p == 1:
        return mx.sum(mx.abs(x), axis=axes, keepdims=keepdim)
    # p == 2: Euclidean norm.
    return mx.sqrt(mx.sum(x * x, axis=axes, keepdims=keepdim))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def weight_norm(
    weight_v: mx.array, weight_g: mx.array, dim: Optional[int] = None
) -> mx.array:
    """Reconstruct a weight tensor from its weight-normalized parts.

    Weight normalization factors a weight as ``w = g * v / ||v||``: a
    per-slice magnitude ``g`` times a direction ``v`` scaled to unit norm.

    Args:
        weight_v: Direction tensor ``v``.
        weight_g: Magnitude tensor ``g``.
        dim: Axis kept out of the norm reduction. ``None`` (or ``-1``)
            reduces over every axis.

    Returns:
        The recombined weight tensor.
    """
    rank = len(weight_v.shape)

    if dim is None:
        # Norm over the whole tensor.
        norm_axes = list(range(rank))
    else:
        # Map dims below -1 into the valid range (mirrors torch behavior).
        if dim < -1:
            dim += rank
        norm_axes = list(range(rank))
        # dim == -1 keeps all axes (full-tensor norm); otherwise exclude `dim`.
        if dim != -1:
            norm_axes.remove(dim)

    # ||v|| along the reduction axes, kept broadcastable against v.
    v_norm = compute_norm(weight_v, p=2, dim=norm_axes, keepdim=True)

    # Epsilon guards against division by a zero-norm direction.
    unit_direction = weight_v / (v_norm + 1e-7)
    return unit_direction * weight_g
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class ConvWeighted(nn.Module):
    """Conv1d with weight normalization.

    Stores the weight as separate magnitude (``weight_g``) and direction
    (``weight_v``) tensors and recombines them with :func:`weight_norm` on
    every call. The actual convolution op (e.g. ``mx.conv1d`` or a transposed
    variant) is passed in by the caller at call time, so one module can serve
    both forward and transposed convolutions.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        encode: bool = False,
    ):
        # encode=True sizes the bias for in_channels instead of out_channels
        # (used when this module wraps a transposed/encoding convolution).
        super().__init__()

        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        # Initialize weight magnitude (g) and direction (v) vectors
        self.weight_g = mx.ones(
            (out_channels, 1, 1)
        )  # Scalar magnitude per output channel
        self.weight_v = mx.ones(
            (out_channels, kernel_size, in_channels)
        )  # Direction vectors

        self.bias = mx.zeros(in_channels if encode else out_channels) if bias else None

    def __call__(self, x, conv):
        # `conv` is the convolution function to apply (e.g. mx.conv1d).
        # Recombine g and v into the effective weight for this call.
        weight = weight_norm(self.weight_v, self.weight_g, dim=0)

        if self.bias is not None:
            bias = self.bias.reshape(1, 1, -1)
        else:
            bias = None

        def apply_conv(x, weight_to_use):
            # Run `conv` with this module's hyperparameters; add the
            # broadcast bias when one is configured.
            if self.bias is not None:
                return (
                    conv(
                        x,
                        weight_to_use,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups,
                    )
                    + bias
                )
            return conv(
                x,
                weight_to_use,
                stride=self.stride,
                padding=self.padding,
                dilation=self.dilation,
                groups=self.groups,
            )

        try:
            # Check if channels last match or if groups > 1 for ConvTransposed1d
            if x.shape[-1] == weight.shape[-1] or self.groups > 1:
                # Input is channels first, use weight as-is
                return apply_conv(x, weight)
            else:
                # Input is channels last, need to transpose weight
                # NOTE(review): layout is inferred purely from a shape match,
                # which can misfire when channel counts coincide — confirm
                # callers always pass the expected layout.
                return apply_conv(x, weight.T)
        except Exception as e:
            # Surface the shapes involved before re-raising to aid debugging.
            print(f"Error: {e}")
            print(f"x.shape: {x.shape}, weight.shape: {weight.shape}")
            raise e
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class _InstanceNorm(nn.Module):
    """Base class for instance normalization in MLX.

    Subclasses pin down the expected dimensionality by implementing
    ``_check_input_dim`` and ``_get_no_batch_dim``. Normalization statistics
    are computed per sample and per feature channel over the remaining axes;
    optional affine parameters and running statistics mirror the PyTorch
    ``_InstanceNorm`` API.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = False,
        track_running_stats: bool = False,
    ) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        # Initialize parameters
        if self.affine:
            # Learnable per-channel scale and shift.
            self.weight = mx.ones((num_features,))
            self.bias = mx.zeros((num_features,))
        else:
            self.weight = None
            self.bias = None

        if self.track_running_stats:
            # Exponential-moving-average buffers updated during training.
            self.running_mean = mx.zeros((num_features,))
            self.running_var = mx.ones((num_features,))
        else:
            self.running_mean = None
            self.running_var = None

    def _check_input_dim(self, input):
        # Subclasses validate the expected ndim here.
        raise NotImplementedError

    def _get_no_batch_dim(self):
        # Subclasses return the ndim of an unbatched input (e.g. 2 for 1-D).
        raise NotImplementedError

    def _handle_no_batch_input(self, input):
        # Add batch dimension, apply norm, then remove batch dimension
        expanded = mx.expand_dims(input, axis=0)
        result = self._apply_instance_norm(expanded)
        return mx.squeeze(result, axis=0)

    def _apply_instance_norm(self, input):
        # MLX doesn't have a direct instance_norm function like PyTorch
        # So we need to implement it manually

        # Get dimensions
        dims = list(range(input.ndim))
        feature_dim = dims[-self._get_no_batch_dim()]

        # Compute statistics along all dims except batch and feature dims
        reduce_dims = [d for d in dims if d != 0 and d != feature_dim]

        if self.training or not self.track_running_stats:
            # Compute mean and variance for normalization
            mean = mx.mean(input, axis=reduce_dims, keepdims=True)
            var = mx.var(input, axis=reduce_dims, keepdims=True)

            # Update running stats if tracking
            if self.track_running_stats and self.training:
                # Compute overall mean and variance (across batch too)
                overall_mean = mx.mean(mean, axis=0)
                overall_var = mx.mean(var, axis=0)

                # Update running statistics
                self.running_mean = (
                    1 - self.momentum
                ) * self.running_mean + self.momentum * overall_mean
                self.running_var = (
                    1 - self.momentum
                ) * self.running_var + self.momentum * overall_var
        else:
            # Use running statistics
            # Reshape the 1-D buffers so they broadcast along feature_dim.
            mean_shape = [1] * input.ndim
            mean_shape[feature_dim] = self.num_features
            var_shape = mean_shape.copy()

            mean = mx.reshape(self.running_mean, mean_shape)
            var = mx.reshape(self.running_var, var_shape)

        # Normalize
        x_norm = (input - mean) / mx.sqrt(var + self.eps)

        # Apply affine transform if needed
        if self.affine:
            weight_shape = [1] * input.ndim
            weight_shape[feature_dim] = self.num_features
            bias_shape = weight_shape.copy()

            weight = mx.reshape(self.weight, weight_shape)
            bias = mx.reshape(self.bias, bias_shape)

            return x_norm * weight + bias
        else:
            return x_norm

    def __call__(self, input):
        self._check_input_dim(input)

        # Locate the feature/channel axis for this input rank.
        feature_dim = input.ndim - self._get_no_batch_dim()
        if input.shape[feature_dim] != self.num_features:
            if self.affine:
                # Affine parameters are sized by num_features, so a mismatch
                # is a hard error.
                raise ValueError(
                    f"expected input's size at dim={feature_dim} to match num_features"
                    f" ({self.num_features}), but got: {input.shape[feature_dim]}."
                )
            else:
                # Without affine parameters num_features is advisory only.
                print(
                    f"input's size at dim={feature_dim} does not match num_features. "
                    "You can silence this warning by not passing in num_features, "
                    "which is not used because affine=False"
                )

        # Unbatched input: temporarily add a batch axis.
        if input.ndim == self._get_no_batch_dim():
            return self._handle_no_batch_input(input)

        return self._apply_instance_norm(input)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class InstanceNorm1d(_InstanceNorm):
    """Instance normalization for 2D (unbatched) or 3D (batched) inputs.

    Implements the scheme from "Instance Normalization: The Missing
    Ingredient for Fast Stylization": statistics are computed per sample
    and per channel over the length axis.

    Args:
        num_features: Number of channels ``C`` of the input.
        eps: Numerical-stability term added to the variance. Default: 1e-5.
        momentum: EMA factor for running statistics. Default: 0.1.
        affine: If True, learn per-channel scale/shift. Default: False.
        track_running_stats: If True, keep running statistics. Default: False.

    Shape:
        - Input: ``(N, C, L)`` or ``(C, L)``
        - Output: same as input.

    Example:
        >>> m = nn.InstanceNorm1d(100, affine=True)
        >>> out = m(mx.random.normal((20, 100, 40)))
    """

    def _get_no_batch_dim(self):
        # An unbatched 1-D input is (C, L): two dimensions.
        return 2

    def _check_input_dim(self, input):
        # Accept only (C, L) or (N, C, L).
        if input.ndim != 2 and input.ndim != 3:
            raise ValueError(f"expected 2D or 3D input (got {input.ndim}D input)")
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
class AdaIN1d(nn.Module):
    """Adaptive instance normalization conditioned on a style vector.

    The style vector is projected to a per-channel (gamma, beta) pair which
    scales and shifts the instance-normalized input:
    ``out = (1 + gamma) * IN(x) + beta``.
    """

    def __init__(self, style_dim: int, num_features: int):
        super().__init__()
        self.norm = InstanceNorm1d(num_features, affine=False)
        # Projects the style vector to 2 * num_features: gamma and beta.
        self.fc = nn.Linear(style_dim, num_features * 2)

    def __call__(self, x: mx.array, s: mx.array) -> mx.array:
        # Project style, then add a trailing length axis for broadcasting.
        style = mx.expand_dims(self.fc(s), axis=2)
        # First half of the channel axis is gamma, second half beta.
        gamma, beta = mx.split(style, 2, axis=1)
        return self.norm(x) * (1 + gamma) + beta
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
class AdaINResBlock1(nn.Module):
    """Residual block with AdaIN conditioning and Snake activations.

    Each of the three branches applies: AdaIN -> Snake -> dilated conv ->
    AdaIN -> Snake -> conv, then adds the input back (residual connection).
    The Snake activation ``x + sin(a*x)^2 / a`` uses a learnable per-channel
    frequency ``a``.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 3,
        dilation: Tuple[int, int, int] = (1, 3, 5),
        style_dim: int = 64,
    ):
        super().__init__()
        # First convs use increasing dilation; second convs are undilated.
        self.convs1 = []
        for d in dilation:
            self.convs1.append(
                ConvWeighted(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    padding=get_padding(kernel_size, d),
                    dilation=d,
                )
            )
        self.convs2 = []
        for _ in dilation:
            self.convs2.append(
                ConvWeighted(
                    channels,
                    channels,
                    kernel_size,
                    stride=1,
                    padding=get_padding(kernel_size, 1),
                    dilation=1,
                )
            )
        self.adain1 = [AdaIN1d(style_dim, channels) for _ in range(3)]
        self.adain2 = [AdaIN1d(style_dim, channels) for _ in range(3)]
        # Learnable Snake frequencies, one per conv.
        self.alpha1 = [mx.ones((1, channels, 1)) for _ in range(len(self.convs1))]
        self.alpha2 = [mx.ones((1, channels, 1)) for _ in range(len(self.convs2))]

    def __call__(self, x: mx.array, s: mx.array) -> mx.array:
        for i in range(len(self.convs1)):
            norm_a, norm_b = self.adain1[i], self.adain2[i]
            conv_a, conv_b = self.convs1[i], self.convs2[i]
            alpha_a, alpha_b = self.alpha1[i], self.alpha2[i]

            # First half-branch: AdaIN -> Snake -> dilated conv.
            residual = norm_a(x, s)
            residual = residual + (1 / alpha_a) * (mx.sin(alpha_a * residual) ** 2)
            residual = residual.swapaxes(2, 1)  # to channels-last for conv1d
            residual = conv_a(residual, mx.conv1d)
            residual = residual.swapaxes(2, 1)  # back to channels-first

            # Second half-branch: AdaIN -> Snake -> plain conv.
            residual = norm_b(residual, s)
            residual = residual + (1 / alpha_b) * (mx.sin(alpha_b * residual) ** 2)
            residual = residual.swapaxes(2, 1)
            residual = conv_b(residual, mx.conv1d)
            residual = residual.swapaxes(2, 1)

            # Residual connection.
            x = residual + x
        return x
|
|
397
|
+
|
|
398
|
+
|
|
399
|
+
def mlx_angle(z, deg=False):
    """Return the phase angle of ``z`` (MLX port of ``numpy.angle``).

    Args:
        z: Input array; complex64 inputs use their real/imaginary parts,
            real inputs are treated as having zero imaginary part.
        deg: If True, return degrees instead of radians.

    Returns:
        Array of angles, elementwise ``atan2(imag, real)``.
    """
    z = mx.array(z)

    if z.dtype == mx.complex64:
        real_part = mx.real(z)
        imag_part = mx.imag(z)
    else:
        # Real input: angle is 0 for positive values, pi for negative.
        real_part = z
        imag_part = mx.zeros_like(z)

    angles = mx.arctan2(imag_part, real_part)

    if deg:
        # Radians -> degrees.
        angles = angles * (180.0 / math.pi)

    return angles
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def mlx_unwrap(p, discont=None, axis=-1, period=2 * math.pi):
    """Unwrap phase values along an axis (MLX port of ``numpy.unwrap``).

    Jumps between consecutive elements larger than ``discont`` are corrected
    by adding multiples of ``period`` so the sequence becomes continuous.

    Args:
        p: Input array of phase values.
        discont: Maximum allowed discontinuity; defaults to ``period / 2``
            and is never allowed below that value.
        axis: Axis along which to unwrap.
        period: Period of the wrapped values (default ``2*pi``).

    Returns:
        Array with the same shape as ``p``, unwrapped along ``axis``.
    """
    if discont is None:
        discont = period / 2

    # Clamp: a discontinuity below half a period cannot be distinguished
    # from wrapping (same behavior as numpy.unwrap).
    discont = max(discont, period / 2)

    # Build slices selecting p[1:] and p[:-1] along `axis`.
    slice_indices = [slice(None)] * p.ndim

    slice_indices[axis] = slice(1, None)
    after_slice = tuple(slice_indices)

    slice_indices[axis] = slice(None, -1)
    before_slice = tuple(slice_indices)

    # First differences along the unwrap axis.
    dd = p[after_slice] - p[before_slice]

    interval_high = period / 2
    interval_low = -interval_high

    # Map differences into [-period/2, period/2).
    ddmod = dd - period * mx.floor((dd - interval_low) / period)
    # Keep +period/2 (rather than -period/2) for positive boundary jumps.
    ddmod = mx.where(
        (mx.abs(dd - interval_high) < 1e-10) & (dd > 0), interval_high, ddmod
    )

    ph_correct = ddmod - dd
    # Differences within the allowed discontinuity need no correction.
    ph_correct = mx.where(mx.abs(dd) < discont, 0, ph_correct)

    # Prepend a zero so corrections align with p, then accumulate them.
    padding_shape = list(ph_correct.shape)
    padding_shape[axis] = 1
    zero_padding = mx.zeros(padding_shape)
    padded_corrections = mx.concatenate([zero_padding, ph_correct], axis=axis)
    cumulative_corrections = mx.cumsum(padded_corrections, axis=axis)

    return p + cumulative_corrections
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
class MLXSTFT:
    """Forward/inverse STFT wrapper that round-trips audio via
    magnitude/phase representations.

    ``transform`` splits each batch item into (magnitude, phase);
    ``inverse`` recombines them (with phase unwrapping) and runs the
    inverse STFT.
    """

    def __init__(
        self, filter_length=800, hop_length=200, win_length=800, window="hann"
    ):
        # filter_length doubles as the FFT size (n_fft) in transform().
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length

        self.window = window

    def transform(self, input_data):
        """Return (magnitudes, phases) for a (B, T) or (T,) signal."""
        # Ensure 2D
        if input_data.ndim == 1:
            input_data = input_data[None, :]

        magnitudes = []
        phases = []

        # STFT is applied per batch item; results are stacked afterwards.
        for batch_idx in range(input_data.shape[0]):
            # Compute STFT
            x_stft = stft(
                input_data[batch_idx],
                n_fft=self.filter_length,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=True,
                pad_mode="reflect",
            ).transpose(1, 0)

            # Get magnitude
            magnitude = mx.abs(x_stft)

            # Get phase
            phase = mlx_angle(x_stft)

            magnitudes.append(magnitude)
            phases.append(phase)

        magnitudes = mx.stack(magnitudes, axis=0)
        phases = mx.stack(phases, axis=0)

        return magnitudes, phases

    def inverse(self, magnitude, phase):
        """Reconstruct audio from magnitude/phase; returns (B, 1, T)."""
        reconstructed = []

        for batch_idx in range(magnitude.shape[0]):
            # Unwrap phases for reconstruction
            phase_cont = mlx_unwrap(phase[batch_idx], axis=1)

            # Combine magnitude and phase
            real_part = magnitude[batch_idx] * mx.cos(phase_cont)
            imag_part = magnitude[batch_idx] * mx.sin(phase_cont)
            x_stft = real_part + 1j * imag_part

            # Inverse STFT
            audio = istft(
                x_stft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=self.window,
                center=True,
                length=None,
            )

            reconstructed.append(audio)

        # Stack batch and add a channel axis: (B, 1, T).
        reconstructed = mx.stack(reconstructed, axis=0)[:, None, :]

        return reconstructed

    def __call__(self, input_data: mx.array) -> mx.array:
        """Round-trip: analyze then resynthesize the input signal.

        Side effect: stores the analysis results on ``self.magnitude`` and
        ``self.phase``.
        """
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        # NOTE(review): inverse() already returns (B, 1, T); this adds a
        # further singleton axis, giving (B, 1, 1, T) — confirm callers
        # expect the extra dimension.
        return mx.expand_dims(reconstruction, axis=-2)
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
class SineGen:
    """Sine-wave generator for neural source-filter vocoders.

    Given per-frame F0 values, produces sine waveforms for the fundamental
    and ``harmonic_num`` overtones, a voiced/unvoiced mask, and additive
    noise whose amplitude depends on voicing.
    """

    def __init__(
        self,
        samp_rate: int,
        upsample_scale: int,
        harmonic_num: int = 0,
        sine_amp: float = 0.1,
        noise_std: float = 0.003,
        voiced_threshold: float = 0,
        flag_for_pulse: bool = False,
    ):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        # Total number of generated components: fundamental + harmonics.
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        # flag_for_pulse=True switches to pulse-train-style phase alignment.
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0: mx.array) -> mx.array:
        # Voiced/unvoiced mask: 1.0 where F0 exceeds the threshold.
        return mx.array(f0 > self.voiced_threshold, dtype=mx.float32)

    def _f02sine(self, f0_values: mx.array) -> mx.array:
        """Convert F0 values to sine waveforms.

        Args:
            f0_values: (batchsize, length, dim) array where dim indexes the
                fundamental tone and overtones.

        Returns:
            Sine waveforms of the same shape.
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1
        # initial phase noise (no noise for fundamental component)
        rand_ini = mx.random.normal((f0_values.shape[0], f0_values.shape[2]))
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            # Downsample the per-sample increments, integrate, then
            # upsample the phase back; this keeps the accumulated phase
            # consistent across the resampling.
            rad_values = interpolate(
                rad_values.transpose(0, 2, 1),
                scale_factor=1 / self.upsample_scale,
                mode="linear",
            ).transpose(0, 2, 1)
            phase = mx.cumsum(rad_values, axis=1) * 2 * mx.pi
            phase = interpolate(
                phase.transpose(0, 2, 1) * self.upsample_scale,
                scale_factor=self.upsample_scale,
                mode="linear",
            ).transpose(0, 2, 1)
            sines = mx.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation
            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = mx.roll(uv, shifts=-1, axis=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)
            # get the instantaneous phase
            tmp_cumsum = mx.cumsum(rad_values, axis=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = mx.cumsum(rad_values - tmp_cumsum, axis=1)
            # get the sines
            sines = mx.cos(i_phase * 2 * mx.pi)
        return sines

    def __call__(self, f0: mx.array) -> Tuple[mx.array, mx.array, mx.array]:
        """Return (sine_waves, uv, noise) for F0 input (batch, length, 1)."""
        # NOTE(review): f0_buf is never used below — dead allocation.
        f0_buf = mx.zeros((f0.shape[0], f0.shape[1], self.dim))

        # Fundamental component
        # Harmonic multipliers 1..harmonic_num+1 broadcast over (B, L, 1).
        fn = f0 * mx.arange(1, self.harmonic_num + 2)[None, None, :]

        # Generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # Generate UV signal
        uv = self._f02uv(f0)

        # Generate noise
        # Voiced frames get noise_std noise; unvoiced get sine_amp/3.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * mx.random.normal(sine_waves.shape)

        # Silence sines in unvoiced regions and add the noise floor.
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
|
|
624
|
+
|
|
625
|
+
|
|
626
|
+
class SourceModuleHnNSF(nn.Module):
    """Harmonic-plus-noise source module (hn-NSF).

    Builds a sine-based excitation from sampled F0 values, alongside a
    Gaussian noise source and the voiced/unvoiced mask.

    Args:
        sampling_rate: Sampling rate in Hz.
        upsample_scale: Upsampling factor forwarded to the sine generator.
        harmonic_num: Number of harmonics above F0 (default: 0).
        sine_amp: Amplitude of the sine source signal (default: 0.1).
        add_noise_std: Std of additive Gaussian noise (default: 0.003);
            the unvoiced-region noise amplitude is tied to ``sine_amp``.
        voiced_threshod: Threshold deciding voiced/unvoiced from F0
            (default: 0).

    Call:
        ``sine_merge, noise, uv = module(f0)`` with ``f0`` of shape
        (batchsize, length, 1); all outputs share that shape.
    """

    def __init__(
        self,
        sampling_rate,
        upsample_scale,
        harmonic_num=0,
        sine_amp=0.1,
        add_noise_std=0.003,
        voiced_threshod=0,
    ):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        # Sine generator producing fundamental + harmonics.
        self.l_sin_gen = SineGen(
            sampling_rate,
            upsample_scale,
            harmonic_num,
            sine_amp,
            add_noise_std,
            voiced_threshod,
        )
        # Merges all harmonic channels into a single excitation channel.
        self.l_linear = nn.Linear(harmonic_num + 1, 1)

    def __call__(self, x):
        """Return (sine_merge, noise, uv) for F0 input of shape (B, L, 1)."""
        # Harmonic branch: generate sines, merge harmonics, squash with tanh.
        sine_wavs, uv, _ = self.l_sin_gen(x)
        merged = self.l_linear(sine_wavs)
        sine_merge = mx.tanh(merged)
        # Noise branch: Gaussian noise with the same shape as uv.
        noise = mx.random.normal(uv.shape) * self.sine_amp / 3
        return sine_merge, noise, uv
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
class ReflectionPad1d(nn.Module):
    """Pads the last axis of a (B, C, T) array.

    NOTE(review): despite the name, ``mx.pad`` is called without a mode
    argument, so this applies the default (constant zero) padding rather
    than reflection — confirm this is intentional.
    """

    def __init__(self, padding):
        super().__init__()
        # padding: (left, right) amounts for the last axis.
        self.padding = padding

    def __call__(self, x):
        # Pad only the trailing (time) axis; batch and channel are untouched.
        return mx.pad(x, ((0, 0), (0, 0), (self.padding[0], self.padding[1])))
|
|
690
|
+
|
|
691
|
+
|
|
692
|
+
def leaky_relu(x, negative_slope=0.01):
    """Elementwise leaky ReLU: identity for positives, scaled negatives.

    Args:
        x: Input array.
        negative_slope: Multiplier applied to negative values (default 0.01).

    Returns:
        Array of the same shape as ``x``.
    """
    scaled_negatives = x * negative_slope
    return mx.where(x > 0, x, scaled_negatives)
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
class Generator(nn.Module):
    """iSTFT-based waveform generator with a harmonic source branch.

    Upsamples hidden features through transposed convolutions and AdaIN
    residual blocks, mixes in an excitation derived from F0 via
    SourceModuleHnNSF, then predicts magnitude/phase and inverts with an
    STFT to produce the waveform.
    """

    def __init__(
        self,
        style_dim,
        resblock_kernel_sizes,
        upsample_rates,
        upsample_initial_channel,
        resblock_dilation_sizes,
        upsample_kernel_sizes,
        gen_istft_n_fft,
        gen_istft_hop_size,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        upsample_rates = mx.array(upsample_rates)
        # Harmonic-plus-noise source; sampling rate is hard-coded to 24 kHz here.
        # Total upsample factor = product of all rates times the iSTFT hop size.
        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=mx.prod(upsample_rates) * gen_istft_hop_size,
            harmonic_num=8,
            voiced_threshod=10,
        )
        # Brings the frame-rate F0 curve up to (near) sample rate.
        self.f0_upsamp = nn.Upsample(
            scale_factor=mx.prod(upsample_rates) * gen_istft_hop_size
        )
        self.noise_convs = []
        self.noise_res = []
        self.ups = []
        # One weight-normalized transposed conv per upsampling stage.
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                ConvWeighted(
                    upsample_initial_channel // (2 ** (i + 1)),
                    upsample_initial_channel // (2**i),
                    int(k),
                    int(u),
                    padding=int((k - u) // 2),
                    encode=True,
                )
            )
        self.resblocks = []
        for i in range(len(self.ups)):
            # Channel count halves at each stage.
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            if i + 1 < len(upsample_rates):
                # Downsample the harmonic spectrogram to this stage's rate.
                stride_f0 = int(mx.prod(upsample_rates[i + 1 :]))
                self.noise_convs.append(
                    nn.Conv1d(
                        gen_istft_n_fft + 2,
                        c_cur,
                        kernel_size=stride_f0 * 2,
                        stride=stride_f0,
                        padding=(stride_f0 + 1) // 2,
                    )
                )
                self.noise_res.append(AdaINResBlock1(c_cur, 7, [1, 3, 5], style_dim))
            else:
                # Last stage: excitation is already at the right rate.
                self.noise_convs.append(
                    nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)
                )
                self.noise_res.append(AdaINResBlock1(c_cur, 11, [1, 3, 5], style_dim))
        self.post_n_fft = gen_istft_n_fft
        # NOTE: `ch` deliberately carries the last loop iteration's value
        # (final stage channel count) into conv_post.
        self.conv_post = ConvWeighted(ch, self.post_n_fft + 2, 7, 1, padding=3)
        self.reflection_pad = ReflectionPad1d((1, 0))
        self.stft = MLXSTFT(
            filter_length=gen_istft_n_fft,
            hop_length=gen_istft_hop_size,
            win_length=gen_istft_n_fft,
        )

    def __call__(self, x, s, f0):
        """Generate a waveform from features `x`, style `s`, and F0 curve `f0`."""
        f0 = self.f0_upsamp(f0[:, None].transpose(0, 2, 1))  # bs,n,t
        # `noi_source` is produced by the source module but unused here.
        har_source, noi_source, uv = self.m_source(f0)
        har_source = mx.squeeze(har_source.transpose(0, 2, 1), axis=1)
        # Spectrogram of the harmonic excitation, fed to each noise branch.
        har_spec, har_phase = self.stft.transform(har_source)
        har = mx.concatenate([har_spec, har_phase], axis=1)
        har = har.swapaxes(2, 1)
        for i in range(self.num_upsamples):
            x = leaky_relu(x, negative_slope=0.1)
            x_source = self.noise_convs[i](har)
            x_source = x_source.swapaxes(2, 1)
            x_source = self.noise_res[i](x_source, s)

            # swapaxes pairs convert between channel-first and MLX's
            # channel-last conv layout around each ConvWeighted call.
            x = x.swapaxes(2, 1)
            x = self.ups[i](x, mx.conv_transpose1d)
            x = x.swapaxes(2, 1)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            x = x + x_source

            # Average the outputs of this stage's resblock group.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x, s)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x, s)
            x = xs / self.num_kernels

        x = leaky_relu(x, negative_slope=0.01)

        x = x.swapaxes(2, 1)
        x = self.conv_post(x, mx.conv1d)
        x = x.swapaxes(2, 1)

        # First half (+1) of channels -> log-magnitude, rest -> phase.
        spec = mx.exp(x[:, : self.post_n_fft // 2 + 1, :])
        phase = mx.sin(x[:, self.post_n_fft // 2 + 1 :, :])
        result = self.stft.inverse(spec, phase)
        return result
|
|
808
|
+
|
|
809
|
+
|
|
810
|
+
class UpSample1d(nn.Module):
    """Optional 2x nearest-neighbour upsampling gated by `layer_type`."""

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type
        # NOTE(review): align_corners is normally irrelevant for "nearest"
        # mode — confirm nn.Upsample simply ignores it here.
        self.interpolate = nn.Upsample(
            scale_factor=2, mode="nearest", align_corners=True
        )

    def __call__(self, x):
        # Pass through unchanged when the layer is disabled.
        return x if self.layer_type == "none" else self.interpolate(x)
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
class AdainResBlk1d(nn.Module):
    """1-D residual block with AdaIN conditioning and optional 2x upsampling.

    Output is (residual(x, s) + shortcut(x)) / sqrt(2); the shortcut uses a
    learned 1x1 conv when dim_in != dim_out.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        style_dim=64,
        # NOTE(review): module instance as a default argument is shared by
        # every block that doesn't pass `actv`; harmless only if the
        # activation is stateless — confirm.
        actv=nn.LeakyReLU(0.2),
        upsample="none",
        dropout_p=0.0,
        bias=False,
        conv_type=None,
    ):
        super().__init__()
        self.actv = actv
        self.dim_in = dim_in
        self.conv_type = conv_type
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        # Shortcut needs a learned projection when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == "none":
            self.pool = nn.Identity()
        else:
            # Depthwise (groups=dim_in) transposed conv used as the
            # upsampling "pool" in the residual path.
            self.pool = ConvWeighted(
                1, dim_in, kernel_size=3, stride=2, padding=1, groups=dim_in
            )

    def _build_weights(self, dim_in, dim_out, style_dim):
        # Two 3x1 convs with AdaIN norms; optional 1x1 shortcut projection.
        self.conv1 = ConvWeighted(dim_in, dim_out, kernel_size=3, stride=1, padding=1)
        self.conv2 = ConvWeighted(dim_out, dim_out, kernel_size=3, stride=1, padding=1)
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = ConvWeighted(
                dim_in, dim_out, kernel_size=1, stride=1, padding=0, bias=False
            )

    def _shortcut(self, x):
        # swapaxes pairs convert to the layout UpSample1d / mx.conv1d expect.
        x = x.swapaxes(2, 1)
        x = self.upsample(x)
        x = x.swapaxes(2, 1)

        if self.learned_sc:
            x = x.swapaxes(2, 1)
            x = self.conv1x1(x, mx.conv1d)
            x = x.swapaxes(2, 1)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)

        # Manually implement grouped ConvTranspose1d since MLX doesn't support groups
        x = x.swapaxes(2, 1)
        x = self.pool(x, mx.conv_transpose1d) if self.upsample_type != "none" else x
        # Extra left-pad compensates the transposed conv's output length.
        x = mx.pad(x, ((0, 0), (1, 0), (0, 0))) if self.upsample_type != "none" else x
        x = x.swapaxes(2, 1)

        x = x.swapaxes(2, 1)
        x = self.conv1(self.dropout(x), mx.conv1d)
        x = x.swapaxes(2, 1)

        x = self.norm2(x, s)
        x = self.actv(x)

        x = x.swapaxes(2, 1)
        x = self.conv2(x, mx.conv1d)
        x = x.swapaxes(2, 1)
        return x

    def __call__(self, x, s):
        out = self._residual(x, s)
        # Variance-preserving merge of residual and shortcut paths.
        out = (out + self._shortcut(x)) / mx.sqrt(2)
        return out
|
|
900
|
+
|
|
901
|
+
|
|
902
|
+
class Decoder(nn.Module):
    """Acoustic decoder: fuses ASR features, F0 and energy (N) under a style
    vector through AdaIN residual blocks, then synthesizes audio with the
    iSTFT Generator.
    """

    def __init__(
        self,
        dim_in,
        style_dim,
        dim_out,  # NOTE(review): accepted but never used in this class body.
        resblock_kernel_sizes,
        upsample_rates,
        upsample_initial_channel,
        resblock_dilation_sizes,
        upsample_kernel_sizes,
        gen_istft_n_fft,
        gen_istft_hop_size,
    ):
        super().__init__()
        # +2 channels account for the concatenated F0 and N curves.
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim, conv_type=mx.conv1d)
        self.decode = []
        # Each decode block consumes x (1024) + asr_res (64) + F0 (1) + N (1);
        # the literal is written 1024 + 2 + 64.
        self.decode.append(
            AdainResBlk1d(1024 + 2 + 64, 1024, style_dim, conv_type=mx.conv1d)
        )
        self.decode.append(
            AdainResBlk1d(1024 + 2 + 64, 1024, style_dim, conv_type=mx.conv1d)
        )
        self.decode.append(
            AdainResBlk1d(1024 + 2 + 64, 1024, style_dim, conv_type=mx.conv1d)
        )
        # NOTE(review): upsample expects the string "none" or a layer-type
        # string; passing True relies on it being truthy/!= "none" — works,
        # but confirm intent.
        self.decode.append(
            AdainResBlk1d(
                1024 + 2 + 64, 512, style_dim, upsample=True, conv_type=mx.conv1d
            )
        )
        # Stride-2 convs halve the F0 / energy curves to the feature rate.
        self.F0_conv = ConvWeighted(1, 1, kernel_size=3, stride=2, padding=1, groups=1)
        self.N_conv = ConvWeighted(1, 1, kernel_size=3, stride=2, padding=1, groups=1)
        # 1x1 projection of ASR features re-injected into every decode block.
        self.asr_res = [ConvWeighted(512, 64, kernel_size=1, padding=0)]
        self.generator = Generator(
            style_dim,
            resblock_kernel_sizes,
            upsample_rates,
            upsample_initial_channel,
            resblock_dilation_sizes,
            upsample_kernel_sizes,
            gen_istft_n_fft,
            gen_istft_hop_size,
        )

    def __call__(self, asr, F0_curve, N, s):
        """Decode (asr, F0_curve, N) under style vector `s` into audio."""
        s = mx.array(s)
        # swapaxes pairs adapt channel-first tensors to MLX's channel-last convs.
        F0 = self.F0_conv(F0_curve[:, None, :].swapaxes(2, 1), mx.conv1d).swapaxes(2, 1)
        N = self.N_conv(N[:, None, :].swapaxes(2, 1), mx.conv1d).swapaxes(2, 1)
        x = mx.concatenate([asr, F0, N], axis=1)
        x = self.encode(x, s)
        asr_res = self.asr_res[0](asr.swapaxes(2, 1), mx.conv1d).swapaxes(2, 1)
        # Re-inject the conditioning only while lengths still match, i.e.
        # until the first upsampling block changes the time axis.
        res = True
        for block in self.decode:  # Working in MLX
            if res:
                x = mx.concatenate([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            # Check if this block has upsampling
            if hasattr(block, "upsample_type") and block.upsample_type != "none":
                res = False
        x = self.generator(x, s, F0_curve)  # Working in MLX
        return x

    def sanitize(self, key, weights):
        """Adjust checkpoint weight layouts for MLX's conv conventions.

        Transposes plain `noise_convs` conv weights and `weight_v` tensors
        whose shape (per `check_array_shape`) is not already MLX-ordered;
        everything else passes through unchanged.
        """
        sanitized_weights = None
        if "noise_convs" in key and key.endswith(".weight"):
            # nn.Conv1d weights: (out, in, k) -> (out, k, in).
            sanitized_weights = weights.transpose(0, 2, 1)

        elif "weight_v" in key:
            if check_array_shape(weights):
                sanitized_weights = weights
            else:
                sanitized_weights = weights.transpose(0, 2, 1)

        else:
            sanitized_weights = weights

        return sanitized_weights
|