nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nexaai/__init__.py +99 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +93 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +7 -0
- nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
- nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
- nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
- nexaai/binds/cpu_gpu/libggml.dylib +0 -0
- nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
- nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/metal/libnexa_plugin.dylib +0 -0
- nexaai/binds/metal/py-lib/ml.py +888 -0
- nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/metal/py-lib/profiling.py +239 -0
- nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
- nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
- nexaai/binds/nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexaml/libggml.dylib +0 -0
- nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
- nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +106 -0
- nexaai/cv.py +95 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +91 -0
- nexaai/cv_impl/pybind_cv_impl.py +124 -0
- nexaai/diarize.py +80 -0
- nexaai/diarize_impl/__init__.py +1 -0
- nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +238 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +162 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +559 -0
- nexaai/mlx_backend/vlm/main.py +365 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1745 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +389 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
- nexaai-1.0.29.dist-info/METADATA +35 -0
- nexaai-1.0.29.dist-info/RECORD +580 -0
- nexaai-1.0.29.dist-info/WHEEL +5 -0
- nexaai-1.0.29.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
import typing as tp
|
|
2
|
+
|
|
3
|
+
import mlx.core as mx
|
|
4
|
+
|
|
5
|
+
from .config import DataConfig
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def build_delay_indices(
    B: int, T: int, C: int, delay_pattern: tp.List[int]
) -> tp.Tuple[mx.array, mx.array]:
    """Precompute delay-application indices.

    Returns a pair ``(t_idx_BxTxC, indices_BTCx3)`` such that
    ``out[t, c] = in[t - delay[c], c]``. A negative time index marks a BOS
    position and a time index >= T marks a PAD position.
    """
    delays = mx.array(delay_pattern, dtype=mx.int32)

    # Per-(b, t, c) shifted time index: t - delay[c].
    time_grid = mx.expand_dims(
        mx.broadcast_to(mx.arange(T, dtype=mx.int32)[None, :], [B, T]), -1
    )
    t_idx_BxTxC = time_grid - mx.reshape(delays, (1, 1, C))

    # Batch and channel coordinates broadcast over the full [B, T, C] grid.
    batch_grid = mx.broadcast_to(
        mx.reshape(mx.arange(B, dtype=mx.int32), (B, 1, 1)), [B, T, C]
    )
    chan_grid = mx.broadcast_to(
        mx.reshape(mx.arange(C, dtype=mx.int32), (1, 1, C)), [B, T, C]
    )

    # Clamp time into [0, T-1] so the downstream gather never indexes out of
    # bounds; out-of-range positions are overwritten with BOS/PAD later.
    clamped_time = mx.clip(t_idx_BxTxC, 0, T - 1)

    def _flat(a):
        return mx.reshape(a, (-1,))

    indices_BTCx3 = mx.stack(
        [_flat(batch_grid), _flat(clamped_time), _flat(chan_grid)],
        axis=1,
    ).astype(mx.int32)

    return t_idx_BxTxC, indices_BTCx3
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def apply_audio_delay(
    audio_BxTxC: mx.array,
    pad_value: int,
    bos_value: int,
    precomp: tp.Tuple[mx.array, mx.array],
) -> mx.array:
    """
    Applies the delay pattern to batched audio tokens using precomputed indices,
    inserting BOS where t_idx < 0 and PAD where t_idx >= T.

    Args:
        audio_BxTxC: [B, T, C] audio tokens (int or float dtype)
        pad_value: the padding token
        bos_value: the BOS token
        precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices

    Returns:
        result_BxTxC: [B, T, C] delayed audio tokens
    """
    t_idx_BxTxC, indices_BTCx3 = precomp

    # Vectorized gather: a single advanced-indexing operation. The previous
    # implementation looped in Python over all B*T*C positions, issuing one
    # host-side scalar read per element, which dominated runtime for long
    # sequences. indices_BTCx3 holds clamped (batch, time, channel) triples.
    gathered_flat = audio_BxTxC[
        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
    ]
    gathered_BxTxC = mx.reshape(gathered_flat, audio_BxTxC.shape)

    # Positions whose shifted time fell outside [0, T) get BOS / PAD instead
    # of the (clamped) gathered value.
    mask_bos = t_idx_BxTxC < 0  # => place bos_value
    mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value

    # Scalar fill values matching the input dtype (broadcast by mx.where).
    bos_tensor = mx.full(1, bos_value, dtype=audio_BxTxC.dtype)
    pad_tensor = mx.full(1, pad_value, dtype=audio_BxTxC.dtype)

    # If mask_bos, BOS; else if mask_pad, PAD; else the gathered token.
    result_BxTxC = mx.where(
        mask_bos, bos_tensor, mx.where(mask_pad, pad_tensor, gathered_BxTxC)
    )

    return result_BxTxC
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def audio_to_codebook(
    model,
    input_values,
    data_config: DataConfig,
    padding_mask=None,
    sample_rate=44100,
):
    """
    Encodes the input audio waveform into discrete codes and applies the
    per-channel delay pattern from `data_config`.

    Args:
        model: Codec model exposing `preprocess` and `encode`.
        input_values (`mx.array` of shape `(batch_size, channels, sequence_length)`):
            Float values of the input audio waveform.
        padding_mask: Accepted for API compatibility but currently unused —
            `model.encode` is called without it.
        sample_rate (`int`, *optional*):
            Signal sampling rate, forwarded to `model.preprocess`.

    Returns:
        `mx.array` of shape `[1, T, C]`: encoded codes with the delay
        pattern applied (BOS/PAD inserted per `data_config`).
    """
    audio_data = model.preprocess(input_values, sample_rate)

    # NOTE: the previous implementation materialized a default padding mask
    # here (mx.ones_like) that was never consumed by any subsequent call;
    # that dead computation has been removed. The parameter is kept so the
    # signature stays backward-compatible.

    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
    seq_length = encoded_frame.shape[2]

    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
        B=1,
        T=seq_length,
        C=data_config.channels,
        delay_pattern=data_config.delay_pattern,
    )

    encoded_frame = apply_audio_delay(
        audio_BxTxC=mx.transpose(encoded_frame, (0, 2, 1)),  # 1, T, C
        pad_value=data_config.audio_pad_value,
        bos_value=data_config.audio_bos_value,
        precomp=(t_idx_BxTxC, indices_BTCx3),
    )

    return encoded_frame
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def build_revert_indices(
    B: int, T: int, C: int, delay_pattern: tp.List[int]
) -> tp.Tuple[mx.array, mx.array]:
    """Precompute indices for the revert operation using MLX.

    Returns:
        A tuple (t_idx_BxTxC, indices_BTCx3) where:
            - t_idx_BxTxC: [B, T, C] time indices shifted forward by each
              channel's delay and saturated at T - 1.
            - indices_BTCx3: [B*T*C, 3] (batch, time, channel) gather indices.
    """
    delays = mx.array(delay_pattern, dtype=mx.int32)

    # Base [B, T, 1] time grid.
    base_time = mx.expand_dims(
        mx.broadcast_to(mx.expand_dims(mx.arange(T), 0), [B, T]), -1
    )

    # t + delay[c], capped at the last valid frame.
    t_idx_BxTxC = mx.minimum(
        base_time + mx.reshape(delays, (1, 1, C)),
        mx.array(T - 1, dtype=mx.int32),
    )

    batch_grid = mx.broadcast_to(mx.reshape(mx.arange(B), (B, 1, 1)), [B, T, C])
    chan_grid = mx.broadcast_to(mx.reshape(mx.arange(C), (1, 1, C)), [B, T, C])

    def _flat(a):
        return mx.reshape(a, (-1,))

    indices_BTCx3 = mx.stack(
        [_flat(batch_grid), _flat(t_idx_BxTxC), _flat(chan_grid)],
        axis=1,
    ).astype(mx.int32)

    return t_idx_BxTxC, indices_BTCx3
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def revert_audio_delay(
    audio_BxTxC: mx.array,
    pad_value: int,
    precomp: tp.Tuple[mx.array, mx.array],
    T: int,
) -> mx.array:
    """
    Reverts a delay pattern from batched audio tokens using precomputed indices (MLX version).

    Args:
        audio_BxTxC: Input delayed audio tensor
        pad_value: Padding value for out-of-bounds indices
        precomp: Precomputed revert indices tuple containing:
            - t_idx_BxTxC: Time offset indices tensor
            - indices_BTCx3: Gather indices tensor for original audio
        T: Original sequence length before padding

    Returns:
        Reverted audio tensor with same shape as input
    """
    t_idx_BxTxC, indices_BTCx3 = precomp

    # Vectorized gather via advanced indexing. The previous implementation
    # looped in Python over every (b, t, c) triple — O(B*T*C) host-side
    # scalar reads — which was extremely slow for long sequences.
    gathered_flat = audio_BxTxC[
        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
    ]
    gathered_BxTxC = mx.reshape(gathered_flat, audio_BxTxC.shape)

    # Positions whose shifted time index runs past the original length T are
    # replaced with the pad value.
    pad_tensor = mx.full(1, pad_value, dtype=audio_BxTxC.dtype)
    T_tensor = mx.array(T)

    result_BxTxC = mx.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)

    return result_BxTxC
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def decode(
    model,
    audio_codes,
):
    """
    Decodes the given frames into an output audio waveform
    """

    # Exactly one frame is supported.
    if len(audio_codes) != 1:
        raise ValueError(f"Expected one frame, got {len(audio_codes)}")

    try:
        quantized = model.quantizer.from_codes(audio_codes)
        waveform = model.decode(quantized[0])
    except Exception as e:
        # Surface the failure before re-raising so callers see context.
        print(f"Error in decode method: {str(e)}")
        raise
    else:
        return waveform
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def codebook_to_audio(
    generated_codes: mx.array, model, delay_pattern, B=1, T=2600, C=9
):
    """Process a single codebook file to generate audio"""
    # Drop the BOS token and cap the sequence at T frames.
    trimmed = generated_codes[:, 1:]
    if trimmed.shape[1] > T:
        trimmed = trimmed[:, :T]

    seq_length = trimmed.shape[1]

    # Undo the per-channel delay pattern, then trim the tail.
    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(
        B=B, T=seq_length, C=C, delay_pattern=delay_pattern
    )
    batched_TxC = mx.expand_dims(mx.transpose(trimmed, (1, 0)), 0)
    reverted = revert_audio_delay(
        audio_BxTxC=batched_TxC,
        pad_value=0,
        precomp=(t_idx_BxTxC, indices_BTCx3),
        T=seq_length,
    )[:, :-30, :]

    codebook = mx.transpose(reverted, (0, 2, 1))

    # Any code outside the valid index range is zeroed before decoding.
    min_valid_index = 0
    max_valid_index = 1023
    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)

    num_invalid = mx.sum(invalid_mask).item()
    if num_invalid > 0:
        print(
            f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0."
        )

    codebook = mx.where(invalid_mask, mx.zeros_like(codebook), codebook)

    return decode(model, codebook)
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""Configuration management module for the Dia model.
|
|
2
|
+
|
|
3
|
+
This module provides comprehensive configuration management for the Dia model,
|
|
4
|
+
utilizing dataclasses for validation. It defines configurations for data processing,
|
|
5
|
+
model architecture (encoder and decoder), and training settings.
|
|
6
|
+
|
|
7
|
+
Key components:
|
|
8
|
+
- DataConfig: Parameters for data loading and preprocessing.
|
|
9
|
+
- EncoderConfig: Architecture details for the encoder module.
|
|
10
|
+
- DecoderConfig: Architecture details for the decoder module.
|
|
11
|
+
- ModelConfig: Combined model architecture settings.
|
|
12
|
+
- TrainingConfig: Training hyperparameters and settings.
|
|
13
|
+
- DiaConfig: Master configuration combining all components.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import json
|
|
17
|
+
import os
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from typing import List, Optional, Tuple
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class DataConfig:
    """Configuration for data loading and preprocessing.

    ``text_length`` and ``audio_length`` are rounded UP to the nearest
    multiple of 128 by ``__post_init__``.

    Attributes:
        text_length: Maximum length of text sequences (rounded up to a multiple of 128).
        audio_length: Maximum length of audio sequences (rounded up to a multiple of 128).
        channels: Number of audio channels.
        text_pad_value: Value used for padding text sequences.
        audio_eos_value: Value representing the end of audio sequences.
        audio_bos_value: Value representing the beginning of audio sequences.
        audio_pad_value: Value used for padding audio sequences.
        delay_pattern: List of delay values for each audio channel.
    """

    text_length: int
    audio_length: int
    channels: int = 9
    text_pad_value: int = 0
    audio_eos_value: int = 1024
    audio_pad_value: int = 1025
    audio_bos_value: int = 1026
    delay_pattern: List[int] = field(
        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
    )

    def __post_init__(self):
        # Round text_length and audio_length up to the next multiple of 128.
        # The dataclass is frozen, so fields must be rebound through
        # object.__setattr__ rather than plain attribute assignment.
        object.__setattr__(self, "text_length", (self.text_length + 127) // 128 * 128)
        object.__setattr__(self, "audio_length", (self.audio_length + 127) // 128 * 128)

    def __hash__(self) -> int:
        """Generate a hash based on all fields of the config."""
        # Explicit hash because delay_pattern is a (unhashable) list; it is
        # folded in as a tuple so equal configs hash equally.
        return hash(
            (
                self.text_length,
                self.audio_length,
                self.channels,
                self.text_pad_value,
                self.audio_pad_value,
                self.audio_bos_value,
                self.audio_eos_value,
                tuple(self.delay_pattern),
            )
        )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass(frozen=True)
class EncoderConfig:
    """Configuration for the encoder component of the Dia model.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        n_head: Number of attention heads.
        head_dim: Dimension per attention head.
        mlp_activations: List of activation functions for the MLP layers.
        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
    """

    n_layer: int
    n_embd: int
    n_hidden: int
    n_head: int
    head_dim: int
    # Two entries: gate activation and up-projection activation.
    mlp_activations: List[str] = field(default_factory=lambda: ["silu", "linear"])
    use_pre_norm: bool = False
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass(frozen=True)
class DecoderConfig:
    """Configuration for the decoder component of the Dia model.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        gqa_query_heads: Number of query heads for grouped-query self-attention.
        kv_heads: Number of key/value heads for grouped-query self-attention.
        gqa_head_dim: Dimension per query head for grouped-query self-attention.
        cross_query_heads: Number of query heads for cross-attention.
        cross_head_dim: Dimension per cross-attention head.
        mlp_activations: List of activation functions for the MLP layers.
        use_pre_norm: Whether to use pre-normalization.
    """

    n_layer: int
    n_embd: int
    n_hidden: int
    gqa_query_heads: int
    kv_heads: int
    gqa_head_dim: int
    cross_query_heads: int
    cross_head_dim: int
    # Two entries: gate activation and up-projection activation.
    mlp_activations: List[str] = field(default_factory=lambda: ["silu", "linear"])
    use_pre_norm: bool = False
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@dataclass(frozen=True)
class ModelConfig:
    """Main configuration container for the Dia model architecture.

    Attributes:
        encoder: Configuration for the encoder component.
        decoder: Configuration for the decoder component.
        src_vocab_size: Size of the source (text) vocabulary.
        tgt_vocab_size: Size of the target (audio code) vocabulary.
        dropout: Dropout probability applied within the model.
        normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
        weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
        sample_rate: Audio sample rate in Hz.
    """

    encoder: EncoderConfig
    decoder: DecoderConfig
    src_vocab_size: int = 128
    # 1024 codes + EOS/PAD/BOS sentinel values (1024/1025/1026 in DataConfig).
    tgt_vocab_size: int = 1028
    dropout: float = 0.0
    normalization_layer_epsilon: float = 1.0e-5
    weight_dtype: str = "float32"
    rope_min_timescale: int = 1
    rope_max_timescale: int = 10_000
    sample_rate: int = 44100
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
@dataclass(frozen=True)
class TrainingConfig:
    """Training process configuration and hyperparameters.

    Note: This configuration currently only includes precision settings.
    Other training parameters (like batch size, learning rate, optimizer settings)
    are assumed to be handled externally.

    Attributes:
        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
    """

    dtype: str = "bfloat16"
    logits_dot_in_fp32: bool = False
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
@dataclass(frozen=True)
|
|
167
|
+
class DiaConfig:
|
|
168
|
+
"""Master configuration for the Dia model.
|
|
169
|
+
|
|
170
|
+
Combines all sub-configurations into a single validated object.
|
|
171
|
+
|
|
172
|
+
Attributes:
|
|
173
|
+
version: Configuration version string.
|
|
174
|
+
model: Model architecture configuration.
|
|
175
|
+
training: Training process configuration (precision settings).
|
|
176
|
+
data: Data loading and processing configuration.
|
|
177
|
+
"""
|
|
178
|
+
|
|
179
|
+
model: ModelConfig
|
|
180
|
+
training: TrainingConfig
|
|
181
|
+
data: DataConfig
|
|
182
|
+
version: str = "1.0"
|
|
183
|
+
|
|
184
|
+
def save(self, path: str) -> None:
|
|
185
|
+
"""Save the current configuration instance to a JSON file.
|
|
186
|
+
|
|
187
|
+
Ensures the parent directory exists and the file has a .json extension.
|
|
188
|
+
|
|
189
|
+
Args:
|
|
190
|
+
path: The target file path to save the configuration.
|
|
191
|
+
|
|
192
|
+
Raises:
|
|
193
|
+
ValueError: If the path is not a file with a .json extension.
|
|
194
|
+
"""
|
|
195
|
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
|
196
|
+
config_dict = {
|
|
197
|
+
"version": self.version,
|
|
198
|
+
"model": {
|
|
199
|
+
"encoder": vars(self.model.encoder),
|
|
200
|
+
"decoder": vars(self.model.decoder),
|
|
201
|
+
"src_vocab_size": self.model.src_vocab_size,
|
|
202
|
+
"tgt_vocab_size": self.model.tgt_vocab_size,
|
|
203
|
+
"dropout": self.model.dropout,
|
|
204
|
+
"normalization_layer_epsilon": self.model.normalization_layer_epsilon,
|
|
205
|
+
"weight_dtype": self.model.weight_dtype,
|
|
206
|
+
"rope_min_timescale": self.model.rope_min_timescale,
|
|
207
|
+
"rope_max_timescale": self.model.rope_max_timescale,
|
|
208
|
+
"sample_rate": self.model.sample_rate,
|
|
209
|
+
},
|
|
210
|
+
"training": vars(self.training),
|
|
211
|
+
"data": vars(self.data),
|
|
212
|
+
}
|
|
213
|
+
with open(path, "w") as f:
|
|
214
|
+
json.dump(config_dict, f, indent=2)
|
|
215
|
+
|
|
216
|
+
@classmethod
|
|
217
|
+
def load_dict(cls, config: dict) -> Optional["DiaConfig"]:
|
|
218
|
+
try:
|
|
219
|
+
model_config = ModelConfig(
|
|
220
|
+
encoder=EncoderConfig(**config["model"]["encoder"]),
|
|
221
|
+
decoder=DecoderConfig(**config["model"]["decoder"]),
|
|
222
|
+
**{
|
|
223
|
+
k: v
|
|
224
|
+
for k, v in config["model"].items()
|
|
225
|
+
if k not in ["encoder", "decoder"]
|
|
226
|
+
},
|
|
227
|
+
)
|
|
228
|
+
return cls(
|
|
229
|
+
version=config.get("version", "1.0"),
|
|
230
|
+
model=model_config,
|
|
231
|
+
training=TrainingConfig(**config["training"]),
|
|
232
|
+
data=DataConfig(**config["data"]),
|
|
233
|
+
)
|
|
234
|
+
except (KeyError, TypeError):
|
|
235
|
+
return None
|
|
236
|
+
|
|
237
|
+
@classmethod
|
|
238
|
+
def load(cls, path: str) -> Optional["DiaConfig"]:
|
|
239
|
+
"""Load and validate a Dia configuration from a JSON file.
|
|
240
|
+
|
|
241
|
+
Args:
|
|
242
|
+
path: The path to the configuration file.
|
|
243
|
+
|
|
244
|
+
Returns:
|
|
245
|
+
A validated DiaConfig instance if the file exists and is valid,
|
|
246
|
+
otherwise None if the file is not found.
|
|
247
|
+
|
|
248
|
+
Raises:
|
|
249
|
+
ValueError: If the JSON content fails validation against the DiaConfig schema.
|
|
250
|
+
"""
|
|
251
|
+
try:
|
|
252
|
+
with open(path, "r") as f:
|
|
253
|
+
config = json.load(f)
|
|
254
|
+
return cls.load_dict(config)
|
|
255
|
+
except FileNotFoundError:
|
|
256
|
+
return None
|