nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nexaai/__init__.py +99 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +93 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +7 -0
- nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
- nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
- nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
- nexaai/binds/cpu_gpu/libggml.dylib +0 -0
- nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
- nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/metal/libnexa_plugin.dylib +0 -0
- nexaai/binds/metal/py-lib/ml.py +888 -0
- nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/metal/py-lib/profiling.py +239 -0
- nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
- nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
- nexaai/binds/nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexaml/libggml.dylib +0 -0
- nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
- nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +106 -0
- nexaai/cv.py +95 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +91 -0
- nexaai/cv_impl/pybind_cv_impl.py +124 -0
- nexaai/diarize.py +80 -0
- nexaai/diarize_impl/__init__.py +1 -0
- nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +238 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +162 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +559 -0
- nexaai/mlx_backend/vlm/main.py +365 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1745 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +389 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
- nexaai-1.0.29.dist-info/METADATA +35 -0
- nexaai-1.0.29.dist-info/RECORD +580 -0
- nexaai-1.0.29.dist-info/WHEEL +5 -0
- nexaai-1.0.29.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,742 @@
|
|
|
1
|
+
# Copyright © 2023 Apple Inc.
|
|
2
|
+
|
|
3
|
+
import zlib
|
|
4
|
+
from dataclasses import dataclass, field, replace
|
|
5
|
+
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
|
6
|
+
|
|
7
|
+
import mlx.core as mx
|
|
8
|
+
import numpy as np
|
|
9
|
+
from mlx.utils import tree_map
|
|
10
|
+
|
|
11
|
+
from .audio import CHUNK_LENGTH
|
|
12
|
+
from .tokenizer import Tokenizer, get_tokenizer
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def compression_ratio(text) -> float:
    """Return the UTF-8 byte length of *text* divided by its zlib-compressed length.

    A larger value means the text compresses better (e.g. it is highly repetitive).
    """
    raw = text.encode("utf-8")
    compressed_size = len(zlib.compress(raw))
    return len(raw) / compressed_size
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def detect_language(
    model: "Whisper", mel: mx.array, tokenizer: Tokenizer = None
) -> Tuple[mx.array, List[dict]]:
    """
    Identify the spoken language of one or more audio inputs.

    A single decoder step is run on the start-of-transcript token. This happens
    outside the main decode loop so that it does not interfere with kv-caching.
    All non-language logits are masked out before taking argmax / softmax.

    Parameters
    ----------
    model : Whisper
        model providing ``encoder``, ``logits``, ``dims`` and language metadata
    mel : mx.array
        log-mel input, or already-encoded audio features whose last two
        dimensions equal ``(n_audio_ctx, n_audio_state)``; a 2-D input is
        treated as a single example
    tokenizer : Tokenizer, optional
        built from the model's language settings when omitted

    Returns
    -------
    language_tokens : mx.array, shape = (n_audio,)
        id of the most probable language token per input (a scalar for 2-D input)
    language_probs : List[Dict[str, float]], length = n_audio
        per-input mapping from language code to probability (one dict for 2-D input)

    Raises
    ------
    ValueError
        if the tokenizer has no language token in its sot sequence
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    has_language_token = (
        tokenizer.language is not None
        and tokenizer.language_token in tokenizer.sot_sequence
    )
    if not has_language_token:
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    is_single = mel.ndim == 2
    if is_single:
        mel = mel[None]

    # Only run the encoder when we were not already handed encoder output.
    encoded_shape = (model.dims.n_audio_ctx, model.dims.n_audio_state)
    if mel.shape[-2:] != encoded_shape:
        mel = model.encoder(mel)

    # One decoder step per input, fed only the startoftranscript token.
    batch_size = mel.shape[0]
    sot_column = mx.array([[tokenizer.sot]] * batch_size)  # [n_audio, 1]
    logits = model.logits(sot_column, mel)[:, 0]

    # Leave language tokens untouched and push every other token to -inf.
    language_mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32)
    language_mask[list(tokenizer.all_language_tokens)] = 0.0
    logits = logits + language_mask

    language_tokens = mx.argmax(logits, axis=-1)
    probs = np.array(mx.softmax(logits, axis=-1))
    token_code_pairs = list(
        zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
    )
    language_probs = [
        {code: row[token_id].item() for token_id, code in token_code_pairs}
        for row in probs
    ]

    if is_single:
        return language_tokens[0], language_probs[0]
    return language_tokens, language_probs
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@dataclass(frozen=True)
class DecodingOptions:
    """Immutable settings that control a single decoding pass."""

    # "transcribe" keeps the spoken language (X->X); "translate" produces English (X->English)
    task: str = "transcribe"
    # language of the audio; None means the detected language is used
    language: Optional[str] = None

    # ---- sampling ----
    temperature: float = 0.0
    # upper bound on the number of tokens to sample
    sample_len: Optional[int] = None
    # number of independent sampled trajectories (for temperature > 0)
    best_of: Optional[int] = None
    # number of beams for beam search (for temperature == 0)
    beam_size: Optional[int] = None
    # beam-search patience factor (arxiv:2204.05424)
    patience: Optional[float] = None

    # penalty used when ranking beams or best-of-N samples: the "alpha" from the
    # Google NMT paper, or None for plain length normalization
    length_penalty: Optional[float] = None

    # text or token ids used to condition decoding; background:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # context from previous text
    prefix: Optional[Union[str, List[int]]] = None  # start of the current text

    # token ids (or a comma-separated string of ids) that may never be sampled;
    # "-1" expands to the symbols listed in `tokenizer.non_speech_tokens`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    # keep a blank (space / EOT) from being the first sampled token
    suppress_blank: bool = True

    # ---- timestamps ----
    # decode text tokens only, by including <|notimestamps|>
    without_timestamps: bool = False
    # latest allowed initial timestamp, in seconds
    max_initial_timestamp: Optional[float] = 1.0

    # ---- implementation ----
    # run most of the computation in float16
    fp16: bool = True
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
@dataclass(frozen=True)
class DecodingResult:
    """Outcome of decoding one audio input; numeric metrics default to NaN."""

    # encoder output the result was decoded from
    audio_features: mx.array
    # language code associated with this result
    language: str
    # language code -> probability, when language detection produced one
    language_probs: Optional[Dict[str, float]] = None
    # decoded token ids
    tokens: List[int] = field(default_factory=list)
    # decoded text
    text: str = ""
    # average log probability of the tokens (filled in by the caller)
    avg_logprob: float = np.nan
    # probability assigned to the no-speech token
    no_speech_prob: float = np.nan
    # sampling temperature used for this result
    temperature: float = np.nan
    # presumably compression_ratio(text); set by the caller — TODO confirm
    compression_ratio: float = np.nan
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class Inference:
    """Runs the model's decoder and keeps the key/value cache between calls."""

    def __init__(self, model: "Whisper"):
        self.model: "Whisper" = model
        self.kv_cache = None

    def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
        """Decode *tokens* against *audio_features* and return float32 per-token logits.

        The cache returned by the decoder replaces the stored one.
        """
        decoder_logits, new_cache, _ = self.model.decoder(
            tokens, audio_features, kv_cache=self.kv_cache
        )
        self.kv_cache = new_cache
        return decoder_logits.astype(mx.float32)

    def rearrange_kv_cache(self, source_indices):
        """Reorder every cached array along its first axis to follow *source_indices*.

        Nothing is done when *source_indices* is already the identity ordering.
        """
        identity = list(range(len(source_indices)))
        if source_indices == identity:
            return
        self.kv_cache = tree_map(lambda arr: arr[source_indices], self.kv_cache)

    def reset(self):
        """Drop the cache so the next call starts from scratch."""
        self.kv_cache = None
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class SequenceRanker:
    """Interface for picking one winner out of each group of candidate sequences."""

    def rank(
        self, tokens: List[List[mx.array]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        For every group of candidates (with their cumulative log probabilities),
        return the index of the candidate to keep as the final result.
        """
        raise NotImplementedError
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class MaximumLikelihoodRanker(SequenceRanker):
    """
    Keep the candidate with the highest length-penalized log probability.

    With ``length_penalty=None`` the score is ``logprob / length``; otherwise the
    Google NMT penalty ``((5 + length) / 6) ** length_penalty`` is the divisor.
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[List[int]]], sum_logprobs: List[List[float]]):
        def penalized(logprob, length):
            if self.length_penalty is None:
                return logprob / length
            # from the Google NMT paper
            return logprob / (((5 + length) / 6) ** self.length_penalty)

        winners = []
        for group_logprobs, group_tokens in zip(sum_logprobs, tokens):
            group_scores = [
                penalized(logprob, len(candidate))
                for logprob, candidate in zip(group_logprobs, group_tokens)
            ]
            winners.append(np.argmax(group_scores))
        return winners
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class TokenDecoder:
    """Interface for strategies that grow every sequence in a batch by one token per step."""

    def reset(self):
        """Clear any per-run state before decoding new sequences (no-op by default)."""

    def update(
        self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
    ) -> Tuple[mx.array, bool, mx.array]:
        """Pick the next token for each sequence from the current logits.

        Parameters
        ----------
        tokens : mx.array, shape = (n_batch, current_sequence_length)
            every token in the context so far, including the sot_sequence and prefix

        logits : mx.array, shape = (n_batch, vocab_size)
            logits for the position being decoded

        sum_logprobs : mx.array, shape = (n_batch)
            running total of log probabilities for each sequence

        Returns
        -------
        tokens : mx.array, shape = (n_batch, current_sequence_length + 1)
            the input tokens with the chosen token appended

        completed : bool
            whether every sequence has reached end-of-text

        sum_logprobs: mx.array, shape = (n_batch)
            running totals after this step

        """
        raise NotImplementedError

    def finalize(
        self, tokens: mx.array, sum_logprobs: mx.array
    ) -> Tuple[Sequence[Sequence[mx.array]], List[List[float]]]:
        """Finish the search and hand back the candidate sequences.

        Parameters
        ----------
        tokens : mx.array, shape = (n_audio, n_group, current_sequence_length)
            every token in the context so far, including the sot_sequence and prefix

        sum_logprobs : mx.array, shape = (n_audio, n_group)
            running total of log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[mx.array]], length = n_audio
            candidate token sequences for each audio input

        sum_logprobs : List[List[float]], length = n_audio
            cumulative log probabilities matching the candidates above

        """
        raise NotImplementedError
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
@mx.compile
def categorical(logits, temp):
    """Sample token ids from the distribution given by ``logits / temp``.

    ``mx.random.categorical`` takes unnormalized logits and samples along the
    last axis; the whole function is compiled with ``mx.compile``.
    """
    scaled_logits = logits / temp
    return mx.random.categorical(scaled_logits)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class GreedyDecoder(TokenDecoder):
    """Extend each sequence with the argmax token (temperature 0) or a temperature sample.

    Sequences that have already emitted EOT keep receiving EOT and stop
    accumulating log probability.
    """

    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
    ) -> Tuple[mx.array, bool, mx.array]:
        """Append one token per sequence; see ``TokenDecoder.update`` for shapes."""
        if self.temperature == 0:
            next_tokens = logits.argmax(axis=-1)
        else:
            next_tokens = categorical(logits, self.temperature)

        # keepdims=True: the per-row normalizer must stay (n_batch, 1) so it
        # broadcasts along the vocab axis. A bare (n_batch,) vector would be
        # broadcast against the vocab axis and fail whenever n_batch > 1.
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)

        current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]
        # Finished sequences (last token == EOT) add nothing. Rebind rather than
        # `+=` so the caller's array is not mutated in place.
        sum_logprobs = sum_logprobs + current_logprobs * (tokens[:, -1] != self.eot)

        # Once a sequence has ended, keep padding it with EOT.
        eot_mask = tokens[:, -1] == self.eot
        next_tokens = next_tokens * (1 - eot_mask) + self.eot * eot_mask
        tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=-1)

        completed = mx.all(tokens[:, -1] == self.eot)
        return tokens, completed, sum_logprobs

    def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
        """Append one EOT along the last axis so every sequence ends with at least one EOT."""
        tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
        return tokens, sum_logprobs
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
class LogitFilter:
    """Base class for rules that mask or penalize logits before the next token is chosen."""

    def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
        """Return *logits* with this filter's masking applied.

        Parameters
        ----------
        logits : mx.array, shape = (n_batch, vocab_size)
            logits for the position being decoded

        tokens : mx.array, shape = (n_batch, current_sequence_length)
            every token in the context so far, including the sot_sequence and prefix

        """
        raise NotImplementedError
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
class SuppressBlank(LogitFilter):
    """Forbid a blank (the encoding of " ") or EOT as the very first sampled token."""

    def __init__(self, tokenizer: Tokenizer, sample_begin: int, n_vocab: int):
        self.sample_begin = sample_begin
        blocked_ids = tokenizer.encode(" ") + [tokenizer.eot]
        penalty = np.zeros(n_vocab, np.float32)
        penalty[blocked_ids] = -np.inf
        self.mask = mx.array(penalty)

    def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
        # Only the first sampled position (right after the initial tokens) is masked.
        at_first_sample = tokens.shape[1] == self.sample_begin
        return (logits + self.mask) if at_first_sample else logits
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class SuppressTokens(LogitFilter):
    """Forbid a fixed set of token ids at every decoding step."""

    def __init__(self, suppress_tokens: Sequence[int], n_vocab: int):
        banned = np.zeros(n_vocab, np.float32)
        banned[list(suppress_tokens)] = -np.inf
        self.mask = mx.array(banned)

    def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
        # The mask is position-independent, so `tokens` is not consulted.
        return logits + self.mask
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class ApplyTimestampRules(LogitFilter):
    """Enforce the timestamp-token rules on each step's logits.

    Only tokens sampled at or after ``sample_begin`` are considered. Per sequence:

    * ``<|notimestamps|>`` is never sampled here (``without_timestamps`` handles it).
    * Timestamps come in pairs, except directly before EOT.
      - If the last two sampled tokens are timestamps (or the only sampled token
        is one), the next token may not be a timestamp.
      - If only the last one is, the next token may not be text (any id below EOT).
    * Timestamps never decrease. A timestamp must be strictly greater than the
      previous one, except right after a single segment-closing timestamp, where
      an equal value is allowed. This guarantees nonzero-length segments.
    * The first sampled token must be a timestamp. If
      ``max_initial_timestamp_index`` is set, it may be at most that many steps
      past ``timestamp_begin``.
    * If the total probability of all timestamp tokens exceeds the probability of
      the best text token, only timestamps remain allowed.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
        timestamp_begin = self.tokenizer.timestamp_begin
        mask = np.zeros(logits.shape, np.float32)

        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            mask[:, self.tokenizer.no_timestamps] = -np.inf

        rows = tokens.tolist()
        for k, row in enumerate(rows):
            seq = row[self.sample_begin :]
            last_was_timestamp = len(seq) >= 1 and seq[-1] >= timestamp_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= timestamp_begin

            # timestamps have to appear in pairs, except directly before EOT
            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    mask[k, timestamp_begin:] = -np.inf
                else:  # cannot be normal text tokens
                    mask[k, : self.tokenizer.eot] = -np.inf

            # Timestamps shouldn't decrease: forbid timestamp *values* below the
            # last one. Token values, not positions in `seq`, are needed here,
            # otherwise the slice below is always empty.
            sampled_timestamps = [t for t in seq if t >= timestamp_begin]
            if sampled_timestamps:
                if last_was_timestamp and not penultimate_was_timestamp:
                    # a new segment may open at the timestamp that closed the last one
                    timestamp_last = sampled_timestamps[-1]
                else:
                    # force each segment to have a nonzero length, to prevent infinite looping
                    timestamp_last = sampled_timestamps[-1] + 1
                mask[k, timestamp_begin:timestamp_last] = -np.inf

        if len(rows[0]) == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            mask[:, :timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = timestamp_begin + self.max_initial_timestamp_index
                mask[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        mask = mx.array(mask)
        # keepdims=True keeps the normalizer (n_batch, 1) so it broadcasts per row
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        timestamp_logprob = logprobs[:, timestamp_begin:].logsumexp(
            axis=-1, keepdims=True
        )
        max_text_token_logprob = logprobs[:, :timestamp_begin].max(
            axis=-1, keepdims=True
        )
        mask[:, :timestamp_begin] = mx.where(
            timestamp_logprob > max_text_token_logprob,
            -mx.inf,
            mask[:, :timestamp_begin],
        )
        return logits + mask
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class DecodingTask:
|
|
399
|
+
inference: Inference
|
|
400
|
+
sequence_ranker: SequenceRanker
|
|
401
|
+
decoder: TokenDecoder
|
|
402
|
+
logit_filters: List[LogitFilter]
|
|
403
|
+
|
|
404
|
+
def __init__(self, model: "Whisper", options: DecodingOptions):
    """Prepare everything needed to decode with *options*.

    Builds the tokenizer (English when no language is given) and validates the
    options. It then derives the initial prompt tokens and assembles the token
    decoder, the sequence ranker and the logit filters.

    Raises
    ------
    ValueError
        if the options are inconsistent (see ``_verify_options``)
    NotImplementedError
        if ``beam_size`` is set, since beam search is not implemented
    """
    self.model = model

    tokenizer = get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language=options.language or "en",
        task=options.task,
    )
    self.tokenizer: Tokenizer = tokenizer
    self.options: DecodingOptions = self._verify_options(options)

    dims = model.dims
    self.n_group: int = options.beam_size or options.best_of or 1
    self.n_ctx: int = dims.n_text_ctx
    self.sample_len: int = options.sample_len or dims.n_text_ctx // 2

    self.sot_sequence: Tuple[int] = (
        tokenizer.sot_sequence_including_notimestamps
        if self.options.without_timestamps
        else tokenizer.sot_sequence
    )

    self.initial_tokens: Tuple[int] = self._get_initial_tokens()
    self.sample_begin: int = len(self.initial_tokens)
    self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

    # forward pass through the decoder, with kv caching
    self.inference = Inference(model)
    # picks the final sequence out of each sampled group
    self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

    # how the next token is chosen from the autoregressive distribution
    if options.beam_size is not None:
        raise NotImplementedError("Beam search decoder is not yet implemented")
    self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

    # rules that suppress or penalize particular tokens
    filters = []
    if self.options.suppress_blank:
        filters.append(SuppressBlank(self.tokenizer, self.sample_begin, dims.n_vocab))
    if self.options.suppress_tokens:
        filters.append(SuppressTokens(self._get_suppress_tokens(), dims.n_vocab))
    if not options.without_timestamps:
        precision = CHUNK_LENGTH / dims.n_audio_ctx  # usually 0.02 seconds
        max_initial_timestamp_index = None
        if options.max_initial_timestamp:
            max_initial_timestamp_index = round(
                self.options.max_initial_timestamp / precision
            )
        filters.append(
            ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
        )
    self.logit_filters = filters
|
|
464
|
+
|
|
465
|
+
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
    """Reject option combinations that cannot be honored; return *options* unchanged.

    Raises
    ------
    ValueError
        for beam_size together with best_of, best_of with temperature 0,
        patience without beam_size, or a length_penalty outside [0, 1]
    """
    if options.beam_size is not None and options.best_of is not None:
        raise ValueError("beam_size and best_of can't be given together")
    if options.temperature == 0 and options.best_of is not None:
        raise ValueError("best_of with greedy sampling (T=0) is not compatible")
    if options.patience is not None and options.beam_size is None:
        raise ValueError("patience requires beam_size to be given")

    penalty = options.length_penalty
    if penalty is not None and not 0 <= penalty <= 1:
        raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

    return options
|
|
479
|
+
|
|
480
|
+
def _get_initial_tokens(self) -> Tuple[int]:
|
|
481
|
+
tokens = list(self.sot_sequence)
|
|
482
|
+
|
|
483
|
+
if prefix := self.options.prefix:
|
|
484
|
+
prefix_tokens = (
|
|
485
|
+
self.tokenizer.encode(" " + prefix.strip())
|
|
486
|
+
if isinstance(prefix, str)
|
|
487
|
+
else prefix
|
|
488
|
+
)
|
|
489
|
+
if self.sample_len is not None:
|
|
490
|
+
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
|
491
|
+
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
|
492
|
+
tokens = tokens + prefix_tokens
|
|
493
|
+
|
|
494
|
+
if prompt := self.options.prompt:
|
|
495
|
+
prompt_tokens = (
|
|
496
|
+
self.tokenizer.encode(" " + prompt.strip())
|
|
497
|
+
if isinstance(prompt, str)
|
|
498
|
+
else prompt
|
|
499
|
+
)
|
|
500
|
+
tokens = (
|
|
501
|
+
[self.tokenizer.sot_prev]
|
|
502
|
+
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
|
503
|
+
+ tokens
|
|
504
|
+
)
|
|
505
|
+
|
|
506
|
+
return tuple(tokens)
|
|
507
|
+
|
|
508
|
+
def _get_suppress_tokens(self) -> Tuple[int]:
|
|
509
|
+
suppress_tokens = self.options.suppress_tokens
|
|
510
|
+
|
|
511
|
+
if isinstance(suppress_tokens, str):
|
|
512
|
+
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
|
513
|
+
|
|
514
|
+
if -1 in suppress_tokens:
|
|
515
|
+
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
|
516
|
+
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
|
517
|
+
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
|
518
|
+
suppress_tokens = [] # interpret empty string as an empty list
|
|
519
|
+
else:
|
|
520
|
+
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
|
521
|
+
|
|
522
|
+
suppress_tokens.extend(
|
|
523
|
+
[
|
|
524
|
+
self.tokenizer.transcribe,
|
|
525
|
+
self.tokenizer.translate,
|
|
526
|
+
self.tokenizer.sot,
|
|
527
|
+
self.tokenizer.sot_prev,
|
|
528
|
+
self.tokenizer.sot_lm,
|
|
529
|
+
]
|
|
530
|
+
)
|
|
531
|
+
if self.tokenizer.no_speech is not None:
|
|
532
|
+
# no-speech probability is collected separately
|
|
533
|
+
suppress_tokens.append(self.tokenizer.no_speech)
|
|
534
|
+
|
|
535
|
+
return tuple(sorted(set(suppress_tokens)))
|
|
536
|
+
|
|
537
|
+
def _get_audio_features(self, mel: mx.array):
    """Return encoder features for *mel*, encoding it only when necessary.

    With ``options.fp16`` the input is cast to float16 first. If its last two
    dimensions already equal ``(n_audio_ctx, n_audio_state)`` it is treated as
    encoder output and returned as-is.

    Raises
    ------
    TypeError
        if the features' dtype does not match the fp16 setting
    """
    expected_dtype = mx.float16 if self.options.fp16 else mx.float32
    if self.options.fp16:
        mel = mel.astype(mx.float16)

    dims = self.model.dims
    already_encoded = mel.shape[-2:] == (dims.n_audio_ctx, dims.n_audio_state)
    audio_features = mel if already_encoded else self.model.encoder(mel)

    if audio_features.dtype != expected_dtype:
        raise TypeError(
            f"audio_features has an incorrect dtype: {audio_features.dtype}"
        )
    return audio_features
|
|
556
|
+
|
|
557
|
+
def _detect_language(self, audio_features: mx.array, tokens: np.array):
    """Return ``(languages, lang_probs)`` for each audio input.

    With a configured language and a task other than ``"lang_id"``, the
    configured language is repeated per input and ``lang_probs`` is None.
    Otherwise ``model.detect_language`` runs and the most probable code per
    input is reported. When no language was configured, the detected language
    tokens are also written into *tokens* in place, at the column right after
    the sot token.
    """
    requested = self.options.language
    if requested is not None and self.options.task != "lang_id":
        return [requested] * audio_features.shape[0], None

    lang_tokens, lang_probs = self.model.detect_language(
        audio_features, self.tokenizer
    )
    languages = [max(probs, key=probs.get) for probs in lang_probs]
    if requested is None:
        # write language tokens
        tokens[:, self.sot_index + 1] = np.array(lang_tokens)
    return languages, lang_probs
|
|
571
|
+
|
|
572
|
+
def _main_loop(self, audio_features: mx.array, tokens: mx.array):
    """Sample up to ``sample_len`` tokens per sequence.

    Returns ``(tokens, sum_logprobs, no_speech_probs)``. ``no_speech_probs`` is
    the probability of the no-speech token at the sot position of the first
    step, or NaN when the tokenizer has no such token.

    The loop is pipelined. Each step is computed and handed to
    ``mx.async_eval`` before the previous step's ``completed`` flag is
    inspected. So one speculative step is computed (and discarded) once all
    sequences finish.
    """
    batch = tokens.shape[0]
    sum_logprobs = mx.zeros(batch)

    def _step(step_inputs, features, context, running_logprobs):
        # decoder forward pass; only the final position decides the next token
        all_logits = self.inference.logits(step_inputs, features)
        next_logits = all_logits[:, -1]
        # apply the logit filters, e.g. for suppressing or penalizing tokens
        for logit_filter in self.logit_filters:
            next_logits = logit_filter.apply(next_logits, context)
        # expand the context with the selected next tokens
        context, done, running_logprobs = self.decoder.update(
            context, next_logits, running_logprobs
        )
        return context, done, running_logprobs, all_logits

    # the first step consumes the whole initial token sequence
    tokens, completed, sum_logprobs, first_logits = _step(
        tokens, audio_features, tokens, sum_logprobs
    )
    if self.tokenizer.no_speech is None:
        no_speech_probs = mx.full(batch, mx.nan)
    else:
        probs_at_sot = mx.softmax(first_logits[:, self.sot_index], axis=-1)
        no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
    mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)

    for _ in range(1, self.sample_len):
        step_inputs = tokens[:, -1:]
        if tokens.shape[-1] > self.n_ctx:
            break
        # queue the next step before looking at the current one's result
        next_tokens, next_completed, next_sum_logprobs, _logits = _step(
            step_inputs, audio_features, tokens, sum_logprobs
        )
        mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
        if completed:
            break
        tokens, completed, sum_logprobs = (
            next_tokens,
            next_completed,
            next_sum_logprobs,
        )

    return tokens, sum_logprobs, no_speech_probs
|
|
617
|
+
|
|
618
|
+
def run(self, mel: mx.array) -> List[DecodingResult]:
    """Decode a batch of inputs and return one ``DecodingResult`` per item.

    Steps:
    1. Encode ``mel``, unless it already holds audio features.
    2. Build the initial prompt tokens and optionally detect the language.
    3. Run the sampling loop, repeating each prompt ``n_group`` times for
       best-of-n or beam search.
    4. Pick the best candidate per item with ``sequence_ranker`` and
       assemble the results.

    Parameters
    ----------
    mel: mx.array
        Batched input; ``mel.shape[0]`` is the number of audio items.

    Returns
    -------
    List[DecodingResult]
        One result per audio item. When ``options.task == "lang_id"`` only
        ``audio_features``, ``language`` and ``language_probs`` are set.

    Raises
    ------
    RuntimeError
        If the per-item result fields end up with inconsistent lengths.
    """
    self.inference.reset()
    self.decoder.reset()
    tokenizer: Tokenizer = self.tokenizer
    n_audio: int = mel.shape[0]
    n_prompt = len(self.initial_tokens)

    audio_features: mx.array = self._get_audio_features(mel)  # encoder forward pass
    tokens: mx.array = mx.array(self.initial_tokens)
    tokens = mx.broadcast_to(tokens, (n_audio, n_prompt))

    # detect language if requested, overwriting the language token
    languages, language_probs = self._detect_language(audio_features, tokens)
    if self.options.task == "lang_id":
        return [
            DecodingResult(
                audio_features=features, language=language, language_probs=probs
            )
            for features, language, probs in zip(
                audio_features, languages, language_probs
            )
        ]

    # Repeat tokens by the group size, for beam search or best-of-n sampling.
    # Broadcasting (n_audio, 1, L) -> (n_audio, n_group, L) and flattening
    # keeps the n_group copies of each item adjacent. The ``[:: self.n_group]``
    # slicing and the ``reshape(n_audio, self.n_group, -1)`` below rely on this.
    if self.n_group > 1:
        tokens = mx.broadcast_to(
            tokens[:, None, :], (n_audio, self.n_group, n_prompt)
        )
        # mx.array.reshape takes only the target shape (not the array itself)
        tokens = tokens.reshape(n_audio * self.n_group, n_prompt)

    # call the main sampling loop
    tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

    # reshape the tensors to have (n_audio, n_group) as the first two dimensions
    # NOTE(review): audio_features is not repeated per group, so this slice only
    # keeps n_audio rows when n_audio == 1 or n_group == 1; the assert below
    # rejects other combinations. Confirm that batching + best-of-n is unsupported.
    audio_features = audio_features[:: self.n_group]
    no_speech_probs = no_speech_probs[:: self.n_group]
    assert audio_features.shape[0] == len(no_speech_probs) == n_audio

    tokens = tokens.reshape(n_audio, self.n_group, -1)
    sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

    # get the final candidates for each group, and slice between the first sampled token and EOT
    tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
    tokens = tokens[..., self.sample_begin :]

    # eval and convert to list
    mx.eval(tokens, sum_logprobs, no_speech_probs)
    tokens = tokens.tolist()
    sum_logprobs = sum_logprobs.tolist()
    no_speech_probs = no_speech_probs.tolist()
    # cut every candidate at its first EOT
    tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens]

    # select the top-ranked sample in each group
    selected = self.sequence_ranker.rank(tokens, sum_logprobs)
    tokens: List[List[int]] = [t[i] for i, t in zip(selected, tokens)]
    texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

    sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
    avg_logprobs: List[float] = [
        lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
    ]

    fields = (
        texts,
        languages,
        tokens,
        audio_features,
        avg_logprobs,
        no_speech_probs,
    )
    if len(set(map(len, fields))) != 1:
        raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

    return [
        DecodingResult(
            audio_features=features,
            language=language,
            tokens=tokens,
            text=text,
            avg_logprob=avg_logprob,
            no_speech_prob=no_speech_prob,
            temperature=self.options.temperature,
            compression_ratio=compression_ratio(text),
        )
        for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
            *fields
        )
    ]
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
def decode(
    model: "Whisper",
    mel: mx.array,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Decode one or more 30-second audio segments given as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        The Whisper model instance used for decoding.

    mel: mx.array
        A single 2-D Mel spectrogram, or a batch of them stacked along a
        leading axis. (Upstream docs give the shape as (80, 3000) or
        (*, 80, 3000). NOTE(review): the MLX port may use a time-major
        layout; confirm.)

    options: DecodingOptions
        Decoding settings for 30-second segments. Any extra keyword
        arguments override the matching option fields via ``replace``.

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        A single ``DecodingResult`` when ``mel`` is 2-D, otherwise a list
        with one result per batch entry.
    """
    is_single = mel.ndim == 2
    batch = mel[None] if is_single else mel

    effective_options = replace(options, **kwargs) if kwargs else options

    results = DecodingTask(model, effective_options).run(batch)
    if is_single:
        return results[0]
    return results
|