nexaai 1.0.4rc10__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai might be problematic. Click here for more details.
- nexaai/__init__.py +71 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +60 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +91 -0
- nexaai/asr_impl/pybind_asr_impl.py +43 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +3 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/ml.py +842 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
- nexaai/common.py +61 -0
- nexaai/cv.py +87 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +88 -0
- nexaai/cv_impl/pybind_cv_impl.py +31 -0
- nexaai/embedder.py +68 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
- nexaai/image_gen.py +136 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
- nexaai/llm.py +89 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +249 -0
- nexaai/llm_impl/pybind_llm_impl.py +207 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +151 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +130 -0
- nexaai/mlx_backend/embedding/interface.py +312 -0
- nexaai/mlx_backend/embedding/main.py +82 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +842 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +330 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/interface.py +406 -0
- nexaai/mlx_backend/vlm/main.py +157 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +51 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
- nexaai/runtime.py +64 -0
- nexaai/tts.py +70 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +93 -0
- nexaai/tts_impl/pybind_tts_impl.py +42 -0
- nexaai/utils/avatar_fetcher.py +104 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/model_manager.py +1195 -0
- nexaai/utils/progress_tracker.py +372 -0
- nexaai/vlm.py +120 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
- nexaai-1.0.4rc10.dist-info/METADATA +26 -0
- nexaai-1.0.4rc10.dist-info/RECORD +519 -0
- nexaai-1.0.4rc10.dist-info/WHEEL +5 -0
- nexaai-1.0.4rc10.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
import glob
|
|
2
|
+
import importlib
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import shutil
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import List, Optional, Tuple, Union
|
|
8
|
+
|
|
9
|
+
import mlx.core as mx
|
|
10
|
+
import mlx.nn as nn
|
|
11
|
+
from mlx.utils import tree_flatten
|
|
12
|
+
from mlx_lm.convert import mixed_quant_predicate_builder
|
|
13
|
+
from mlx_lm.utils import dequantize_model, quantize_model, save_config, save_model
|
|
14
|
+
|
|
15
|
+
# Maps a model-type alias (or a segment of a model directory name) to the
# sub-module name that is imported from ``mlx_audio.tts.models``.
MODEL_REMAPPING = {"outetts": "outetts", "spark": "spark", "sam": "sesame"}
# NOTE(review): not referenced in the visible part of this module; presumably
# a size limit for saved weight shards - confirm before relying on it.
MAX_FILE_SIZE_GB = 5
# dtype names that convert() accepts for casting weights; each is looked up
# as an attribute of ``mlx.core``.
MODEL_CONVERSION_DTYPES = ["float16", "bfloat16", "float32"]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_model_path(path: str, revision: Optional[str] = None) -> Path:
    """
    Resolve a local model location and verify that it exists.

    No download is attempted: the argument is only wrapped in a ``Path`` and
    checked for existence.

    Args:
        path (str): The local path to the model.
        revision (str, optional): Unused; accepted so existing callers that
            pass a revision keep working.

    Returns:
        Path: ``path`` wrapped in a ``Path`` object.

    Raises:
        FileNotFoundError: If nothing exists at ``path``.
    """
    candidate = Path(path)
    if candidate.exists():
        return candidate

    raise FileNotFoundError(
        f"Model path '{path}' does not exist locally. Please ensure the model is available at the specified path."
    )
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Get a list of all available model types from the models directory
|
|
43
|
+
def get_available_models():
    """
    List the model types found in the ``models`` directory next to this file.

    Every sub-directory whose name does not start with ``"__"`` counts as a
    model type (this skips entries such as ``__pycache__``).

    Returns:
        List[str]: Sub-directory names, in directory iteration order. Empty
        when the ``models`` directory is missing or is not a directory.
    """
    root = Path(__file__).parent / "models"

    # is_dir() is already False for a missing path, so one check covers both
    # the "does not exist" and the "not a directory" cases.
    if not root.is_dir():
        return []

    return [
        entry.name
        for entry in root.iterdir()
        if entry.is_dir() and not entry.name.startswith("__")
    ]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def get_model_and_args(model_type: str, model_name: List[str]):
    """
    Import the architecture module that implements a model.

    The model type is resolved in two stages:
    1. ``model_type`` is translated through MODEL_REMAPPING (unknown types
       are kept unchanged).
    2. Each segment of ``model_name`` is inspected in order. A segment that
       names an available model directory overrides the type; a segment that
       is a MODEL_REMAPPING key overrides it with the remapped value and ends
       the scan.

    Args:
        model_type (str): The type of model to load (e.g., "outetts").
        model_name (List[str]): Segments of the model name, or None to skip
            the second stage.

    Returns:
        Tuple[module, str]: The imported ``mlx_audio.tts.models.<type>``
        module and the resolved type string.

    Raises:
        ValueError: If the module for the resolved type cannot be imported.
    """
    # Stage 1: direct remapping of the requested type.
    resolved_type = MODEL_REMAPPING.get(model_type, model_type)

    # Stage 2: let segments of the model name override the type.
    known_models = get_available_models()
    segments = [] if model_name is None else model_name
    for segment in segments:
        if segment in known_models:
            resolved_type = segment

        if segment in MODEL_REMAPPING:
            # An explicit remapping wins and stops the scan.
            resolved_type = MODEL_REMAPPING[segment]
            break

    module_name = f"mlx_audio.tts.models.{resolved_type}"
    try:
        arch = importlib.import_module(module_name)
    except ImportError:
        msg = f"Model type {resolved_type} not supported."
        logging.error(msg)
        raise ValueError(msg)

    return arch, resolved_type
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def load_config(model_path: Union[str, Path], **kwargs) -> dict:
    """Read ``config.json`` from a local model directory.

    Args:
        model_path: Model directory. A string is first resolved (and checked
            for existence) with get_model_path; a Path is used as given.
        **kwargs: Accepted for call compatibility; not used.

    Returns:
        dict: The parsed contents of ``config.json``.

    Raises:
        FileNotFoundError: If the path does not exist (string input) or if
            ``config.json`` is missing from the directory.
    """
    directory = get_model_path(model_path) if isinstance(model_path, str) else model_path
    config_file = directory / "config.json"

    try:
        with open(config_file, encoding="utf-8") as handle:
            parsed = json.load(handle)
    except FileNotFoundError as exc:
        raise FileNotFoundError(f"Config not found at {directory}") from exc

    return parsed
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def load_model(
    model_path: Union[str, Path], lazy: bool = False, strict: bool = True, **kwargs
) -> nn.Module:
    """
    Load and initialize the model from a given path.

    Args:
        model_path (Union[str, Path]): Local model directory. A string is
            resolved with get_model_path; a Path is used as given.
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        strict (bool): Forwarded to ``model.load_weights``. Default: ``True``
        **kwargs: Forwarded to load_config.

    Returns:
        nn.Module: The loaded and initialized model, switched to eval mode.

    Raises:
        FileNotFoundError: If the weight files (.safetensors) are not found.
        ValueError: If ``model_path`` is neither ``str`` nor ``Path``, if the
            model type is not "kokoro", or if the architecture module cannot
            be imported.
    """
    # The directory name, lower-cased and split on "-", is used both as a
    # fallback model type and as a hint list for get_model_and_args.
    if isinstance(model_path, str):
        model_name = model_path.lower().split("/")[-1].split("-")
        model_path = get_model_path(model_path)
    elif isinstance(model_path, Path):
        model_name = model_path.name.lower().split("-")
    else:
        raise ValueError(f"Invalid model path type: {type(model_path)}")

    config = load_config(model_path, **kwargs)
    config["tokenizer_name"] = model_path

    # Determine model_type from config, falling back to the first name segment
    model_type = config.get("model_type", None)
    if model_type is None:
        model_type = model_name[0].lower()

    # TODO: remove this check once we cleaned other models.
    if model_type != "kokoro":
        raise ValueError(f"Model type {model_type} not supported. Only kokoro is supported for now.")

    weight_files = glob.glob(str(model_path / "*.safetensors"))
    if not weight_files:
        # Check in LLM directory if no safetensors found in the main directory
        # For Spark model
        weight_files = glob.glob(str(model_path / "LLM" / "*.safetensors"))

    if not weight_files:
        logging.error(f"No safetensors found in {model_path}")
        message = f"""
No safetensors found in {model_path}
Please ensure that the model directory contains the required .safetensors weight files.
The model directory should contain:
- config.json (model configuration)
- *.safetensors (model weights)
- Any other required model files

If you have a PyTorch model, you may need to convert it to safetensors format first.
"""
        raise FileNotFoundError(message)

    # Merge all shards into one flat name -> array mapping.
    weights = {}
    for wf in weight_files:
        weights.update(mx.load(wf))

    model_class, model_type = get_model_and_args(
        model_type=model_type, model_name=model_name
    )

    # Get model config from model class if it exists, otherwise use the config
    model_config = (
        model_class.ModelConfig.from_dict(config)
        if hasattr(model_class, "ModelConfig")
        else config
    )

    if model_config is not None and hasattr(model_config, "model_path"):
        # For Spark model
        model_config.model_path = model_path

    model = model_class.Model(model_config)

    # Read the quantization settings once: un-quantized checkpoints are
    # sanitized, quantized ones get their layers converted before loading.
    quantization = config.get("quantization", None)
    if quantization is None:
        weights = model.sanitize(weights)
    else:

        def get_class_predicate(p, m):
            # Handle custom per layer quantizations
            if p in quantization:
                return quantization[p]
            if not hasattr(m, "to_quantized"):
                return False
            # Skip layers not divisible by 64
            if hasattr(m, "weight") and m.weight.size % 64 != 0:
                return False
            # Handle legacy models which may not have everything quantized
            return f"{p}.scales" in weights

        nn.quantize(
            model,
            group_size=quantization["group_size"],
            bits=quantization["bits"],
            class_predicate=get_class_predicate,
        )

    model.load_weights(list(weights.items()), strict=strict)

    if not lazy:
        mx.eval(model.parameters())

    model.eval()
    return model
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def convert(
    hf_path: str,
    mlx_path: str = "mlx_model",
    quantize: bool = False,
    q_group_size: int = 64,
    q_bits: int = 4,
    dtype: Optional[str] = None,
    revision: Optional[str] = None,
    dequantize: bool = False,
    trust_remote_code: bool = True,
    quant_predicate: Optional[str] = None,
):
    """
    Convert a local model directory into an MLX model directory.

    The model is loaded, its floating-point weights are optionally cast,
    it is optionally quantized or dequantized, auxiliary files are copied
    from the source directory, and the weights plus ``config.json`` are
    written to ``mlx_path``.

    Args:
        hf_path (str): Local path of the source model.
        mlx_path (str): Output directory; created if it does not exist.
        quantize (bool): Quantize the model with quantize_model.
        q_group_size (int): Group size passed to quantize_model.
        q_bits (int): Bit width passed to quantize_model.
        dtype (str, optional): One of MODEL_CONVERSION_DTYPES. When None the
            config's "torch_dtype" entry is used. Any other value leaves the
            weights unchanged.
        revision (str, optional): Forwarded to get_model_path.
        dequantize (bool): Dequantize the model with dequantize_model.
        trust_remote_code (bool): Forwarded as a keyword argument to
            load_model and load_config.
        quant_predicate (str, optional): When a string, expanded into a
            predicate by mixed_quant_predicate_builder and combined with the
            base quantization requirements.

    Raises:
        ValueError: If both ``quantize`` and ``dequantize`` are set.
    """
    # Reject the contradictory flag combination before any expensive loading.
    if quantize and dequantize:
        raise ValueError("Choose either quantize or dequantize, not both.")

    print("[INFO] Loading")
    model_path = get_model_path(hf_path, revision=revision)
    model = load_model(model_path, lazy=True, trust_remote_code=trust_remote_code)
    config = load_config(model_path, trust_remote_code=trust_remote_code)

    if isinstance(quant_predicate, str):
        quant_predicate = mixed_quant_predicate_builder(quant_predicate, model)

    # Get model-specific quantization predicate if available
    model_quant_predicate = getattr(
        model, "model_quant_predicate", lambda p, m, config: True
    )

    # Define base quantization requirements
    def base_quant_requirements(p, m, config):
        return (
            hasattr(m, "weight")
            and m.weight.shape[-1] % 64 == 0  # Skip layers not divisible by 64
            and hasattr(m, "to_quantized")
            and model_quant_predicate(p, m, config)
        )

    # Combine with user-provided predicate if available
    if quant_predicate is None:
        quant_predicate = base_quant_requirements
    else:
        original_predicate = quant_predicate
        quant_predicate = lambda p, m, config: (
            base_quant_requirements(p, m, config) and original_predicate(p, m, config)
        )

    weights = dict(tree_flatten(model.parameters()))

    if dtype is None:
        dtype = config.get("torch_dtype", None)
    if dtype in MODEL_CONVERSION_DTYPES:
        print("[INFO] Using dtype:", dtype)
        dtype = getattr(mx, dtype)
        # Only floating-point arrays are cast; integer arrays (for example
        # packed quantized weights) must keep their dtype.
        weights = {
            k: v.astype(dtype) if mx.issubdtype(v.dtype, mx.floating) else v
            for k, v in weights.items()
        }
        # save_model() below serializes the model itself, so the cast weights
        # have to be loaded back or the requested dtype would be dropped
        # whenever quantize is False.
        model.load_weights(list(weights.items()))

    if quantize:
        print("[INFO] Quantizing")
        weights, config = quantize_model(
            model, config, q_group_size, q_bits, quant_predicate=quant_predicate
        )

    if dequantize:
        print("[INFO] Dequantizing")
        model = dequantize_model(model)
        weights = dict(tree_flatten(model.parameters()))

    if isinstance(mlx_path, str):
        mlx_path = Path(mlx_path)

    # Ensure the destination directory for MLX model exists before copying files
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Copy auxiliary files from the model path to the MLX path, keeping the
    # relative directory layout.
    for pattern in ["*.py", "*.json", "*.wav", "*.pt", "*.safetensors", "*.yaml"]:
        # With recursive=True, "**" matches zero or more directories, so this
        # one glob covers top-level files as well as files at any depth.
        for file in glob.glob(str(model_path / "**" / pattern), recursive=True):
            rel_path = Path(file).relative_to(model_path)
            # Create subdirectories if they don't exist
            dest_dir = mlx_path / rel_path.parent
            dest_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy(file, dest_dir)

    save_model(mlx_path, model, donate_model=True)

    save_config(config, config_path=mlx_path / "config.json")
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from functools import lru_cache
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import mlx.core as mx
|
|
6
|
+
|
|
7
|
+
# Common window functions
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@lru_cache(maxsize=None)
def hanning(size):
    """Return a symmetric Hann window with ``size`` points.

    Sample ``n`` is ``0.5 * (1 - cos(2 * pi * n / (size - 1)))``. Results are
    memoized per ``size`` by ``lru_cache``.

    Args:
        size (int): Number of window points.

    Returns:
        mx.array: The window values; ``[1.0]`` when ``size`` is 1.
    """
    if size == 1:
        # The formula divides by size - 1; a one-point window is defined as
        # [1.0] (the same convention numpy.hanning uses).
        return mx.array([1.0])
    denom = size - 1
    return mx.array(
        [0.5 * (1 - math.cos(2 * math.pi * n / denom)) for n in range(size)]
    )
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@lru_cache(maxsize=None)
def hamming(size):
    """Return a symmetric Hamming window with ``size`` points.

    Sample ``n`` is ``0.54 - 0.46 * cos(2 * pi * n / (size - 1))``. Results
    are memoized per ``size`` by ``lru_cache``.

    Args:
        size (int): Number of window points.

    Returns:
        mx.array: The window values; ``[1.0]`` when ``size`` is 1.
    """
    if size == 1:
        # The formula divides by size - 1; a one-point window is defined as
        # [1.0] (the same convention numpy.hamming uses).
        return mx.array([1.0])
    denom = size - 1
    return mx.array(
        [0.54 - 0.46 * math.cos(2 * math.pi * n / denom) for n in range(size)]
    )
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@lru_cache(maxsize=None)
def blackman(size):
    """Return a symmetric Blackman window with ``size`` points.

    Sample ``n`` is
    ``0.42 - 0.5 * cos(2 * pi * n / (size - 1)) + 0.08 * cos(4 * pi * n / (size - 1))``.
    Results are memoized per ``size`` by ``lru_cache``.

    Args:
        size (int): Number of window points.

    Returns:
        mx.array: The window values; ``[1.0]`` when ``size`` is 1.
    """
    if size == 1:
        # The formula divides by size - 1; a one-point window is defined as
        # [1.0] (the same convention numpy.blackman uses).
        return mx.array([1.0])
    denom = size - 1
    return mx.array(
        [
            0.42
            - 0.5 * math.cos(2 * math.pi * n / denom)
            + 0.08 * math.cos(4 * math.pi * n / denom)
            for n in range(size)
        ]
    )
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@lru_cache(maxsize=None)
def bartlett(size):
    """Return a symmetric Bartlett (triangular) window with ``size`` points.

    Sample ``n`` is ``1 - 2 * |n - (size - 1) / 2| / (size - 1)``. Results
    are memoized per ``size`` by ``lru_cache``.

    Args:
        size (int): Number of window points.

    Returns:
        mx.array: The window values; ``[1.0]`` when ``size`` is 1.
    """
    if size == 1:
        # The formula divides by size - 1; a one-point window is defined as
        # [1.0] (the same convention numpy.bartlett uses).
        return mx.array([1.0])
    return mx.array([1 - 2 * abs(n - (size - 1) / 2) / (size - 1) for n in range(size)])
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
# Window name -> window factory. "hann" and "hanning" are two names for the
# same function. stft() and istft() lower-case the requested name before
# looking it up here.
STR_TO_WINDOW_FN = dict(
    hann=hanning,
    hanning=hanning,
    hamming=hamming,
    blackman=blackman,
    bartlett=bartlett,
)
|
|
48
|
+
|
|
49
|
+
# STFT and ISTFT
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def stft(
    x,
    n_fft=800,
    hop_length=None,
    win_length=None,
    window: mx.array | str = "hann",
    center=True,
    pad_mode="reflect",
):
    """Compute the short-time Fourier transform of a signal.

    Args:
        x: Input signal. Indexed and padded along its first axis only
            (assumes a 1-D array - TODO confirm with callers).
        n_fft (int): Frame length passed to the FFT.
        hop_length (int, optional): Step between frames. Defaults to
            ``n_fft // 4``.
        win_length (int, optional): Length of a named window. Defaults to
            ``n_fft``.
        window: Either a key of STR_TO_WINDOW_FN (case-insensitive) or a
            ready-made window array. A window shorter than ``n_fft`` is
            zero-padded on the right.
        center (bool): Pad ``n_fft // 2`` samples on both ends before framing.
        pad_mode (str): "constant" or "reflect"; only consulted when
            ``center`` is true.

    Returns:
        mx.array: ``mx.fft.rfft`` of the windowed frames, one frame per row.

    Raises:
        ValueError: For an unknown window name, an invalid ``pad_mode``, or
            an input too short to produce a single frame.
    """
    hop_length = n_fft // 4 if hop_length is None else hop_length
    win_length = n_fft if win_length is None else win_length

    if isinstance(window, str):
        make_window = STR_TO_WINDOW_FN.get(window.lower())
        if make_window is None:
            raise ValueError(f"Unknown window function: {window}")
        w = make_window(win_length)
    else:
        w = window

    # Right-pad a short window with zeros so it spans a full frame.
    shortfall = n_fft - w.shape[0]
    if shortfall > 0:
        w = mx.concatenate([w, mx.zeros((shortfall,))], axis=0)

    if center:
        half = n_fft // 2
        if pad_mode == "constant":
            x = mx.pad(x, [(half, half)])
        elif pad_mode == "reflect":
            # Mirror the signal around its first and last samples, without
            # repeating the edge samples themselves.
            head = x[1 : half + 1][::-1]
            tail = x[-(half + 1) : -1][::-1]
            x = mx.concatenate([head, x, tail])
        else:
            raise ValueError(f"Invalid pad_mode {pad_mode}")

    num_frames = 1 + (x.shape[0] - n_fft) // hop_length
    if num_frames <= 0:
        raise ValueError(
            f"Input is too short (length={x.shape[0]}) for n_fft={n_fft} with "
            f"hop_length={hop_length} and center={center}."
        )

    # Overlapping frames are built with strides: row i starts at sample
    # i * hop_length.
    frames = mx.as_strided(x, shape=(num_frames, n_fft), strides=(hop_length, 1))
    return mx.fft.rfft(frames * w)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def istft(
    x,
    hop_length=None,
    win_length=None,
    window="hann",
    center=True,
    length=None,
):
    """Reconstruct a time-domain signal from a spectrogram by overlap-add.

    Args:
        x: Complex spectrogram with frequency bins on axis 0 and frames on
            axis 1 (the inverse FFT runs over axis 0 and the frame count is
            read from axis 1).
        hop_length (int, optional): Step between frames. Defaults to
            ``win_length // 4``.
        win_length (int, optional): Frame length in samples. Defaults to
            ``(x.shape[0] - 1) * 2``, which is the frame length produced by
            the inverse real FFT over axis 0.
        window: Either a key of STR_TO_WINDOW_FN (case-insensitive) or a
            ready-made window array. A window shorter than ``win_length`` is
            zero-padded on the right.
        center (bool): When true and ``length`` is None, ``win_length // 2``
            samples are trimmed from both ends of the result.
        length (int, optional): When given, the result is cut to its first
            ``length`` samples.

    Returns:
        mx.array: The reconstructed signal.

    Raises:
        ValueError: For an unknown window name.
    """
    if win_length is None:
        # Derive the frame length from the frequency axis (axis 0). Using the
        # frame axis here would only work when both axes have the same size.
        win_length = (x.shape[0] - 1) * 2
    if hop_length is None:
        hop_length = win_length // 4

    if isinstance(window, str):
        window_fn = STR_TO_WINDOW_FN.get(window.lower())
        if window_fn is None:
            raise ValueError(f"Unknown window function: {window}")
        # Build win_length + 1 points and drop the last one.
        w = window_fn(win_length + 1)[:-1]
    else:
        w = window

    if w.shape[0] < win_length:
        w = mx.concatenate([w, mx.zeros((win_length - w.shape[0],))], axis=0)

    num_frames = x.shape[1]
    t = (num_frames - 1) * hop_length + win_length

    reconstructed = mx.zeros(t)
    window_sum = mx.zeros(t)

    # inverse FFT of each frame
    frames_time = mx.fft.irfft(x, axis=0).transpose(1, 0)

    # get the position in the time-domain signal to add the frame
    frame_offsets = mx.arange(num_frames) * hop_length
    indices = frame_offsets[:, None] + mx.arange(win_length)
    indices_flat = indices.flatten()

    updates_reconstructed = (frames_time * w).flatten()
    updates_window = mx.tile(w, (num_frames,)).flatten()

    # overlap-add the inverse transformed frame, scaled by the window
    reconstructed = reconstructed.at[indices_flat].add(updates_reconstructed)
    window_sum = window_sum.at[indices_flat].add(updates_window)

    # normalize by the sum of the window values
    reconstructed = mx.where(window_sum != 0, reconstructed / window_sum, reconstructed)

    if center and length is None:
        reconstructed = reconstructed[win_length // 2 : -win_length // 2]

    # NOTE(review): when ``length`` is given the leading win_length // 2
    # samples are not trimmed even if ``center`` is true - confirm that
    # callers rely on this before changing it.
    if length is not None:
        reconstructed = reconstructed[:length]

    return reconstructed
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
# Mel filterbank
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@lru_cache(maxsize=None)
def mel_filters(
    sample_rate: int,
    n_fft: int,
    n_mels: int,
    f_min: float = 0,
    f_max: Optional[float] = None,
    norm: Optional[str] = None,
    mel_scale: str = "htk",
) -> mx.array:
    """Build a bank of triangular mel filters.

    Results are memoized per argument combination by ``lru_cache``.

    Args:
        sample_rate (int): Sampling rate of the audio.
        n_fft (int): FFT size; the bank covers ``n_fft // 2 + 1`` bins.
        n_mels (int): Number of mel filters.
        f_min (float): Lowest filter frequency in Hz.
        f_max (float, optional): Highest filter frequency in Hz. Falls back
            to ``sample_rate / 2`` when None or 0.
        norm (str, optional): "slaney" scales each filter by
            ``2 / (upper edge - lower edge)``; any other value leaves the
            filters unscaled.
        mel_scale (str): "htk" selects the HTK formula; any other value
            selects the Slaney formula.

    Returns:
        mx.array: Filterbank of shape ``(n_mels, n_fft // 2 + 1)``.
    """
    # Constants of the Slaney scale: linear below 1 kHz, logarithmic above.
    slaney_f_min, f_sp = 0.0, 200.0 / 3
    min_log_hz = 1000.0
    min_log_mel = (min_log_hz - slaney_f_min) / f_sp
    logstep = math.log(6.4) / 27.0

    def _scalar_hz_to_mel(freq):
        # Operates on a single Python number.
        if mel_scale == "htk":
            return 2595.0 * math.log10(1.0 + freq / 700.0)
        if freq >= min_log_hz:
            return min_log_mel + math.log(freq / min_log_hz) / logstep
        return (freq - slaney_f_min) / f_sp

    def _array_mel_to_hz(mels):
        # Operates element-wise on an array of mel values.
        if mel_scale == "htk":
            return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
        linear_part = slaney_f_min + f_sp * mels
        log_part = min_log_hz * mx.exp(logstep * (mels - min_log_mel))
        return mx.where(mels >= min_log_mel, log_part, linear_part)

    f_max = f_max or sample_rate / 2

    # Centre frequency of every FFT bin.
    n_freqs = n_fft // 2 + 1
    bin_freqs = mx.linspace(0, sample_rate // 2, n_freqs)

    # n_mels + 2 edge frequencies, evenly spaced on the mel axis.
    mel_lo = _scalar_hz_to_mel(f_min)
    mel_hi = _scalar_hz_to_mel(f_max)
    edges_hz = _array_mel_to_hz(mx.linspace(mel_lo, mel_hi, n_mels + 2))

    # Distance from every bin to every edge, shape (n_freqs, n_mels + 2).
    edge_widths = edges_hz[1:] - edges_hz[:-1]
    distances = mx.expand_dims(edges_hz, 0) - mx.expand_dims(bin_freqs, 1)

    # Each triangle is the smaller of its rising and falling ramp, clipped at 0.
    rising = (-distances[:, :-2]) / edge_widths[:-1]
    falling = distances[:, 2:] / edge_widths[1:]
    bank = mx.maximum(mx.zeros_like(rising), mx.minimum(rising, falling))

    if norm == "slaney":
        scale = 2.0 / (edges_hz[2 : n_mels + 2] - edges_hz[:n_mels])
        bank *= mx.expand_dims(scale, 0)

    # (n_freqs, n_mels) -> (n_mels, n_freqs)
    return bank.moveaxis(0, 1)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Version string exposed as a module attribute.
__version__ = "0.2.3"
|