nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nexaai/__init__.py +99 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +93 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +7 -0
- nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
- nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
- nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
- nexaai/binds/cpu_gpu/libggml.dylib +0 -0
- nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
- nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/metal/libnexa_plugin.dylib +0 -0
- nexaai/binds/metal/py-lib/ml.py +888 -0
- nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/metal/py-lib/profiling.py +239 -0
- nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
- nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
- nexaai/binds/nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexaml/libggml.dylib +0 -0
- nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
- nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +106 -0
- nexaai/cv.py +95 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +91 -0
- nexaai/cv_impl/pybind_cv_impl.py +124 -0
- nexaai/diarize.py +80 -0
- nexaai/diarize_impl/__init__.py +1 -0
- nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +238 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +162 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +559 -0
- nexaai/mlx_backend/vlm/main.py +365 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1745 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +389 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
- nexaai-1.0.29.dist-info/METADATA +35 -0
- nexaai-1.0.29.dist-info/RECORD +580 -0
- nexaai-1.0.29.dist-info/WHEEL +5 -0
- nexaai-1.0.29.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1038 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import mlx.core as mx
|
|
5
|
+
import mlx.nn as nn
|
|
6
|
+
|
|
7
|
+
from ..base import check_array_shape
|
|
8
|
+
from .config import AudioConfig, ModelConfig
|
|
9
|
+
from .language import Gemma3nRMSNorm
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def convert_torch_to_mlx_pad_width(padding, input_shape):
    """Convert a PyTorch-style flat padding spec to MLX ``pad_width`` format.

    PyTorch's ``F.pad`` takes a flat sequence ``(left, right, top, bottom,
    front, back, ...)`` whose pairs apply to the input's dimensions starting
    from the *last* dimension and moving backwards.  ``mx.pad`` instead
    expects one ``(before, after)`` pair per dimension, ordered from the
    first dimension.

    For example, ``padding=(12, 11, 0, 0)`` on a 2-D input becomes
    ``[(0, 0), (12, 11)]``.  The last dim gets ``(12, 11)`` and the
    second-to-last gets ``(0, 0)``.

    Args:
        padding: Flat sequence of pad amounts in PyTorch order.  Only
            complete pairs are used: a trailing unpaired value is ignored.
            Pairs beyond the input's rank are also ignored.
        input_shape: Shape of the array to be padded.  Only its length is
            used.

    Returns:
        A list of ``(before, after)`` tuples, one per dimension of
        ``input_shape``.  Dimensions not covered by ``padding`` get
        ``(0, 0)``.
    """
    ndim = len(input_shape)

    # Start with no padding on any dimension.
    pad_width = [(0, 0)] * ndim

    # Pair i of the PyTorch spec targets dimension -(i + 1).  Iterating over
    # every available pair (instead of a fixed four) supports inputs of any
    # rank, e.g. the 2 * ndim-long specs built for arrays with 6+ dims.
    num_pairs = min(len(padding) // 2, ndim)
    for pair_idx in range(num_pairs):
        pad_width[-(pair_idx + 1)] = (
            padding[2 * pair_idx],
            padding[2 * pair_idx + 1],
        )

    return pad_width
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class Gemma3nAudioRelativePositionEmbedding(nn.Module):
    """Attention logits with a content term and a relative-position term.

    The position term projects sinusoidal timing signals for the relative
    offsets ``L, L-1, ..., -R`` through ``pos_proj``.  It is then
    "relative-shifted" so that every query row lines up with its key
    context window.
    """

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config

        self.num_heads = self.config.conf_num_attention_heads
        self.channels = self.config.hidden_size
        self.head_dim = self.channels // self.num_heads
        # L: number of past positions a query may see (context_left - 1, floored at 0).
        self.max_backward = (
            self.config.conf_attention_context_left - 1
            if self.config.conf_attention_context_left > 0
            else 0
        )
        # R: number of future positions a query may see.
        self.max_forward = self.config.conf_attention_context_right

        self.pos_proj = nn.Linear(
            self.channels, self.num_heads * self.head_dim, bias=False
        )

        # Geometric sequence of inverse timescales for the sinusoidal signal.
        # The sin and cos halves together give 2 * (channels // 2) features.
        # NOTE(review): this equals `channels` only when hidden_size is even.
        min_timescale = 1.0
        max_timescale = 1.0e4
        num_timescales = self.channels // 2
        log_timescale_increment = math.log(
            float(max_timescale) / float(min_timescale)
        ) / max(num_timescales - 1, 1)
        inv_timescales = min_timescale * mx.exp(
            mx.arange(num_timescales) * -log_timescale_increment
        )

        # Shape [1, 1, num_timescales] so it broadcasts against positions
        # [B, F, 1].  The leading underscore is presumably there to keep it out
        # of the module's loadable parameters.
        self._inv_timescales = inv_timescales[None, None, ...]

    def _get_timing_signal_1d_pos(self, position: mx.array, dtype) -> mx.array:
        """Return sin/cos timing signals for a 2-D ``position`` array.

        Args:
            position: Positions of shape [B, F].
            dtype: Output dtype.

        Returns:
            Array of shape [B, F, 2 * num_timescales]: sines followed by
            cosines, computed in float32 and then cast to ``dtype``.

        Raises:
            ValueError: If ``position`` is not 2-D.
        """
        if position.ndim != 2:
            raise ValueError(
                f"position must be a 2-D array, got {position.ndim} dimensions"
            )
        position = mx.expand_dims(position.astype(mx.float32), axis=-1)

        scaled_time = position * self._inv_timescales
        timing_signal = mx.concatenate(
            [mx.sin(scaled_time), mx.cos(scaled_time)], axis=-1
        )
        return timing_signal.astype(dtype)

    def _relative_shift(
        self,
        term_bd_before_shift: mx.array,
        batch_size: int,
        num_heads: int,
        num_query_blocks: int,
        query_block_size: int,
        key_context_size: int,
        max_span_plus_1: int,
    ) -> mx.array:
        """Shift position logits from offset space [.., W, F] into key space [.., W, C].

        This uses the pad / flatten / slice / reshape trick (emulating JAX).
        The last axis is right-padded to C + 1.  The trailing (W, C + 1) axes
        are flattened, the first W * C entries are kept, and the result is
        reshaped to (W, C).  Net effect:
        ``out[..., w, c] = in[..., w, c - w]`` when ``0 <= c - w < F``,
        and 0 otherwise.
        """
        pad_amount_last_dim = (key_context_size + 1) - max_span_plus_1

        # Pad only the last dimension, on the right: [B, N, U, W, C + 1].
        term_bd_padded = mx.pad(
            term_bd_before_shift,
            convert_torch_to_mlx_pad_width(
                (0, pad_amount_last_dim), term_bd_before_shift.shape
            ),
        )

        # [B, N, U, W * (C + 1)]
        term_bd_flat = term_bd_padded.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size * (key_context_size + 1),
            )
        )

        # Keep the first W * C entries: [B, N, U, W * C].
        term_bd_sliced = term_bd_flat[
            :, :, :, : query_block_size * key_context_size
        ]

        # Back to [B, N, U, W, C].
        return term_bd_sliced.reshape(
            (
                batch_size,
                num_heads,
                num_query_blocks,
                query_block_size,
                key_context_size,
            )
        )

    def __call__(self, queries: mx.array, keys: mx.array) -> mx.array:
        """Compute attention logits for blocked queries and keys.

        Args:
            queries: [B, U, W, N, H] (batch, query blocks, query block size,
                heads, head_dim).
            keys: [B, U, C, N, H] where C = W + L + R is the key context size.

        Returns:
            Logits of shape [B, N, U, W, C]: the content term plus the
            relative-shifted position term.
        """
        batch_size, num_query_blocks, query_block_size, num_heads, head_dim = (
            queries.shape
        )
        key_context_size = keys.shape[2]

        # Relative positions [L, L-1, ..., -R] as shape [1, F], F = L + R + 1.
        pos_indices = mx.arange(self.max_backward, -self.max_forward - 1, -1)[
            None, :
        ]
        max_span_plus_1 = pos_indices.shape[1]

        # [1, F, channels] -> pos_proj -> [1, F, N * H] -> [F, N, H]
        sin_emb_timing_signal = self._get_timing_signal_1d_pos(
            pos_indices, dtype=queries.dtype
        )
        sin_emb = self.pos_proj(sin_emb_timing_signal).reshape(
            max_span_plus_1, self.num_heads, self.head_dim
        )

        # Heads-first query layout, shared by both terms: [B, N, U, W, H].
        queries_p = queries.transpose(0, 3, 1, 2, 4)

        # Content term: [B, N, U, W, H] @ [B, N, U, H, C] -> [B, N, U, W, C].
        keys_p_t = keys.transpose(0, 3, 1, 4, 2)
        term_ac = mx.matmul(queries_p, keys_p_t)

        # Position term, equivalent to einsum('buwnh,fnh->bnuwf', queries, sin_emb):
        # [B, N, U*W, H] @ [N, H, F] (broadcast over B) -> [B, N, U*W, F].
        q_flat = queries_p.reshape(
            batch_size, num_heads, num_query_blocks * query_block_size, head_dim
        )
        term_bd_unshifted = mx.matmul(q_flat, sin_emb.transpose(1, 2, 0)).reshape(
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            max_span_plus_1,
        )

        # Align offsets with key positions: [B, N, U, W, C].
        term_bd_shifted = self._relative_shift(
            term_bd_unshifted,
            batch_size,
            num_heads,
            num_query_blocks,
            query_block_size,
            key_context_size,
            max_span_plus_1,
        )

        return term_ac + term_bd_shifted
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class Gemma3nAudioAttention(nn.Module):
    """Chunked local self-attention with relative-position logits.

    The time axis is split into query blocks of ``chunk_size`` frames.  Each
    block attends to a key/value context of ``context_size`` frames:
    ``max_past_horizon`` frames before the block, the block itself, and
    ``max_future_horizon`` frames after it.

    Logits come from ``Gemma3nAudioRelativePositionEmbedding`` and are
    soft-capped with ``tanh``.  They are then masked by both input validity
    and a local causal window.
    """

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config

        self.num_heads = self.config.conf_num_attention_heads
        self.hidden_size = self.config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads

        self.chunk_size = self.config.conf_attention_chunk_size
        self.max_future_horizon = self.config.conf_attention_context_right
        self.max_past_horizon = max(0, self.config.conf_attention_context_left - 1)
        self.attention_invalid_logits_value = (
            self.config.conf_attention_invalid_logits_value
        )
        self.attention_logits_soft_cap = self.config.conf_attention_logit_cap
        # C = W + L + R frames of key/value context per query block.
        self.context_size = (
            self.chunk_size + self.max_past_horizon + self.max_future_horizon
        )

        self.relative_position_embedding = Gemma3nAudioRelativePositionEmbedding(config)
        self.per_dim_scale = mx.zeros((self.head_dim,))

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )

        # Query scale = head_dim**-0.5 / softplus(0), and softplus(0) = log(2).
        # With per_dim_scale initialised to zeros, the extra
        # softplus(per_dim_scale) factor applied in __call__ is exactly 1.
        q_scale = self.head_dim**-0.5
        r_softplus_0 = 1.0 / mx.log(2.0)
        # NOTE(review): this is an mx.array rather than a Python float, so
        # multiplying lower-precision queries by it may promote them to a
        # wider dtype. Confirm this is intended.
        self._q_scale = q_scale * r_softplus_0

        # Local causal window over [W, C].  Query row w may see context
        # column c iff  w <= c <= w + max_past_horizon + max_future_horizon,
        # i.e. keys from L frames before the query to R frames after it.
        lower_causal_mask = mx.tril(
            mx.ones((self.context_size, self.chunk_size), dtype=mx.bool_),
            k=0,
        ).T
        upper_causal_mask = mx.tril(
            mx.ones((self.chunk_size, self.context_size), dtype=mx.bool_),
            k=self.max_past_horizon + self.max_future_horizon,
        )
        self._local_causal_valid_mask = mx.logical_and(
            lower_causal_mask, upper_causal_mask
        )

        self._softcap = mx.array(self.attention_logits_soft_cap, dtype=mx.float32)

    def _pad_dim1(
        self,
        x: mx.array,
        dim10_val: int,
        dim11_val: int,
    ) -> mx.array:
        """Zero-pad axis 1 of ``x``.

        Args:
            x: Array with at least 2 dimensions.
            dim10_val: Number of zeros prepended on axis 1.
            dim11_val: Number of zeros appended on axis 1.

        Raises:
            ValueError: If ``x`` has fewer than 2 dimensions.
        """
        if x.ndim < 2:
            raise ValueError(
                f"_pad_dim1 expects an array with at least 2 dims, got {x.ndim}"
            )
        pad_width = [(0, 0)] * x.ndim
        pad_width[1] = (dim10_val, dim11_val)
        return mx.pad(x, pad_width)

    def _convert_to_block(
        self, x: mx.array, padding_val: Union[bool, float] = 0.0
    ) -> mx.array:
        """Split axis 1 (time) into non-overlapping blocks of ``chunk_size``.

        ``[B, T, ...] -> [B, ceil(T / chunk_size), chunk_size, ...]``.  The
        time axis is zero-padded at the end up to a multiple of
        ``chunk_size``.

        Note: ``padding_val`` is accepted for interface compatibility but is
        currently not used.  Padding is always zero.
        """
        shape = x.shape
        b, t = shape[:2]
        num_blocks = (t + self.chunk_size - 1) // self.chunk_size

        if (padding_len := num_blocks * self.chunk_size - t) > 0:
            x = self._pad_dim1(x, 0, padding_len)

        return x.reshape((b, num_blocks, self.chunk_size) + shape[2:])

    def unfold_mlx(self, x, dimension, size, step):
        """Extract sliding windows of ``size`` elements every ``step`` along ``dimension``.

        The result has the shape of ``x`` with axis ``dimension`` replaced by
        ``size`` (position inside a window), followed by a new axis at
        ``dimension + 1`` of length ``num_windows`` (window index).

        Raises:
            ValueError: If the axis is shorter than one window.
        """
        num_windows = (x.shape[dimension] - size) // step + 1
        if num_windows <= 0:
            raise ValueError(
                f"Axis {dimension} of length {x.shape[dimension]} is shorter "
                f"than the window size {size}"
            )

        def _window(start):
            # Full slices everywhere except the windowed axis.
            index = [slice(None)] * x.ndim
            index[dimension] = slice(start, start + size)
            return x[tuple(index)]

        windows = [_window(i * step) for i in range(num_windows)]
        return mx.stack(windows, axis=dimension + 1)

    def _extract_block_context(self, x: mx.array) -> mx.array:
        """Gather the overlapping key/value context for every query block.

        ``[B, T, *rest]`` becomes ``[B, U, C, *rest]`` when ``x`` has more
        than 2 dims.  A 2-D ``[B, T]`` input becomes ``[B, C, U]``.  Frames
        outside ``[0, T)`` are filled with zeros.
        """
        pad_left = self.max_past_horizon
        # The extra chunk_size - 1 frames on the right yield exactly
        # ceil(T / chunk_size) windows, one per (possibly partial) query block.
        pad_right = self.max_future_horizon + self.chunk_size - 1
        x = self._pad_dim1(x, pad_left, pad_right)

        # unfold_mlx yields [B, C, U, *rest].
        x_unfolded = self.unfold_mlx(x, 1, self.context_size, self.chunk_size)

        # Move the block index next to batch: [B, U, C, *rest].  swapaxes
        # works for any rank; the 2-D case keeps its [B, C, U] layout.
        if x.ndim > 2:
            x_unfolded = mx.swapaxes(x_unfolded, 1, 2)

        return x_unfolded

    def __call__(self, x: mx.array, mask: mx.array) -> mx.array:
        """Run chunked local attention.

        Args:
            x: Hidden states.  Axis 0 is treated as batch and axis 1 as time,
                and the projections act on the last axis, so presumably
                [B, T, hidden_size].
            mask: Boolean [B, T] array where True marks padded (invalid)
                frames.

        Returns:
            Context vectors of shape [B, T, num_heads, head_dim].
        """
        qkv_shape = (*x.shape[:-1], self.num_heads, self.head_dim)
        query_states = self.q_proj(x).reshape(qkv_shape)
        key_states = self.k_proj(x).reshape(qkv_shape)
        value_states = self.v_proj(x).reshape(qkv_shape)

        # softplus(per_dim_scale) written as logaddexp(per_dim_scale, 0).
        per_dim_scale_sp = mx.logaddexp(self.per_dim_scale, 0.0)
        query_states = (
            query_states
            * self._q_scale
            * per_dim_scale_sp.reshape((1, 1, 1, self.head_dim))
        )

        batch_size, q_time = query_states.shape[:2]

        query_blocks = self._convert_to_block(query_states)  # [B, U, W, N, H]
        key_blocks = self._extract_block_context(key_states)  # [B, U, C, N, H]
        value_blocks = self._extract_block_context(value_states)  # [B, U, C, N, H]
        num_query_blocks = query_blocks.shape[1]

        # Validity of every context position per block: [B, C, U] -> [B, U, C].
        # Context frames that fall outside the input are zero-padded -> False.
        extracted_valid_mask_blocks = self._extract_block_context(
            ~mask
        ).transpose(0, 2, 1)

        if extracted_valid_mask_blocks.shape != (
            batch_size,
            num_query_blocks,
            self.context_size,
        ):
            raise ValueError(
                "Shape of extracted_valid_mask_blocks"
                f" {extracted_valid_mask_blocks.shape} is not ({batch_size},"
                f" {num_query_blocks}, {self.context_size}) after potential reshape."
            )

        # [B, U, C] -> [B, 1, U, 1, C] to broadcast over heads and query rows.
        condition_from_input_validity = extracted_valid_mask_blocks[
            :, None, :, None, :
        ]
        # [W, C] -> [1, 1, 1, W, C]
        condition_from_causality = self._local_causal_valid_mask[None, None, None, ...]

        # True where a key is both originally valid and inside the local
        # window.  Broadcasts to [B, 1, U, W, C].
        final_condition_for_where = mx.logical_and(
            condition_from_input_validity, condition_from_causality
        )

        # [B, N, U, W, C]
        logits = self.relative_position_embedding(query_blocks, key_blocks)

        # Soft-cap: cap * tanh(logits / cap).
        logits = mx.tanh(logits / self._softcap) * self._softcap

        logits = mx.where(
            final_condition_for_where, logits, self.attention_invalid_logits_value
        )
        probabilities = mx.softmax(logits.astype(mx.float32), axis=-1).astype(
            value_blocks.dtype
        )

        # Equivalent of einsum("BNuwc,BucNH->BuwNH") as one batched matmul:
        # [B*U*N, W, C] @ [B*U*N, C, H] -> [B*U*N, W, H].
        b_dim, n_dim, u_dim, w_dim, c_dim = probabilities.shape
        h_dim = value_blocks.shape[-1]
        prob_bun = probabilities.transpose(0, 2, 1, 3, 4).reshape(-1, w_dim, c_dim)
        v_bun = value_blocks.transpose(0, 1, 3, 2, 4).reshape(-1, c_dim, h_dim)
        result_bmm = mx.matmul(prob_bun, v_bun)

        # [B, U, N, W, H] -> [B, U, W, N, H] -> [B, U*W, N, H]
        context_vectors = result_bmm.reshape(
            b_dim, u_dim, n_dim, w_dim, h_dim
        ).transpose(0, 1, 3, 2, 4)
        context_vectors = context_vectors.reshape(
            (
                batch_size,
                num_query_blocks * self.chunk_size,
                self.num_heads,
                self.head_dim,
            )
        )

        # Drop the frames added to fill the last query block.
        return context_vectors[:, :q_time]
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
class Gemma3nCumulativeGroupNorm(nn.Module):
    """Applies Group Normalization cumulatively over the time dimension.

    This layer normalizes the input by calculating the mean and variance
    cumulatively over the time dimension (dim 1). The statistics are computed
    over all feature dimensions (specified by `feature_dims` and `num_channels`)
    for elements marked as valid by the optional `mask`.

    If a `mask` is provided (True for valid, False for invalid/padded),
    invalid time steps do not contribute to the statistics calculation, and
    their corresponding output values are zeroed out.

    Scale and bias, if enabled, are applied per-channel (last dimension).
    This behavior is similar to JAX's `GroupNormalization` with `num_groups=1`
    and `cumulative=True`.
    """

    def __init__(
        self,
        num_channels: int,  # Number of channels (size of the last dimension)
        feature_dims: Tuple[
            int
        ],  # Sizes of non-channel feature dimensions, e.g., (H, W) for input [B,T,H,W,C]
        eps: float = 1e-3,
        use_scale: bool = True,
        use_bias: bool = False,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.feature_dims = tuple(feature_dims)
        self.eps = eps
        self.use_scale = use_scale
        self.use_bias = use_bias

        # Per-channel scale (depends only on the last dimension).
        self.weight = mx.ones(num_channels) if self.use_scale else None

        # Per-channel bias (depends only on the last dimension).
        self.bias = mx.zeros(num_channels) if self.use_bias else None

        # Normalize over every axis except Batch (0) and Time (1).
        # For input [B, T, *feature_dims, C] these are axes 2 .. ndim - 1.
        self.reduction_axes = tuple(range(2, 2 + len(self.feature_dims) + 1))

    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
        """Applies cumulative group norm, optionally using a mask.

        Args:
            x: Input tensor, shape [B, T, *feature_dims, C].
            mask: Optional boolean mask, shape [B, T]. True indicates a valid
                (non-padded) time step. If None, all time steps are considered valid.

        Returns:
            Normalized tensor with the same shape and dtype as x.

        Raises:
            ValueError: If the shape of ``x`` or ``mask`` is inconsistent.
            TypeError: If ``mask`` is not boolean.
        """
        expected_input_suffix = self.feature_dims + (self.num_channels,)
        if x.shape[2:] != expected_input_suffix:
            raise ValueError(
                f"Input tensor shape suffix {x.shape[2:]} does not match expected"
                f" suffix (feature_dims + num_channels) {expected_input_suffix}"
            )

        if mask is not None:
            if mask.shape != x.shape[:2]:
                raise ValueError(
                    f"Mask shape {mask.shape} must match input Batch/Time dimensions {x.shape[:2]}"
                )
            # MLX's boolean dtype is `bool_` (as used elsewhere in this file).
            if mask.dtype != mx.bool_:
                raise TypeError("Mask must be a boolean tensor.")

        input_dtype = x.dtype
        # Calculations are performed in float32 for numerical stability.
        calc_dtype = mx.float32
        x_calc = x.astype(calc_dtype)

        # Element-wise validity mask with the full shape of x.  Broadcasting
        # the [B, T] mask to every feature element makes the per-step element
        # count below equal prod(feature_dims) * C for a valid step.  This
        # matches the unmasked path; a [B, T, 1, ..., 1] mask would count only
        # 1 per step and mis-scale the mean and variance.
        if mask is not None:
            mask_suffix_shape = (1,) * len(expected_input_suffix)
            mask_calc = mx.broadcast_to(
                mask.reshape(mask.shape + mask_suffix_shape), x_calc.shape
            ).astype(calc_dtype)
        else:
            mask_calc = mx.ones_like(x_calc)

        # Only valid elements contribute to the sums.
        x_masked_for_sum = x_calc * mask_calc

        # Cumulative sum of values over time: [B, T, 1, ..., 1].
        sum_values_at_t = mx.sum(
            x_masked_for_sum, axis=self.reduction_axes, keepdims=True
        )
        cum_sum_values = mx.cumsum(sum_values_at_t, axis=1)

        # Cumulative count of valid elements over time.
        elements_in_group_at_t = mx.sum(
            mask_calc, axis=self.reduction_axes, keepdims=True
        )
        cum_count_elements = mx.cumsum(elements_in_group_at_t, axis=1)
        # Avoid division by zero if all preceding elements were masked.
        safe_cum_count_elements = mx.clip(cum_count_elements, 1, None)

        cum_mean = cum_sum_values / safe_cum_count_elements

        # Cumulative variance: masked squared deviations from the cumulative
        # mean, accumulated over time and divided by the cumulative count.
        squared_diff_from_mean = (x_calc - cum_mean) ** 2
        sum_sq_diff_at_t = mx.sum(
            squared_diff_from_mean * mask_calc,
            axis=self.reduction_axes,
            keepdims=True,
        )
        cum_sum_sq_diff = mx.cumsum(sum_sq_diff_at_t, axis=1)
        cum_variance = cum_sum_sq_diff / safe_cum_count_elements

        # (x - E[x]) / sqrt(Var[x] + eps)
        normalized_x = (x_calc - cum_mean) * mx.rsqrt(cum_variance + self.eps)

        # Per-channel affine transform, broadcast as [1, ..., 1, C].
        channel_view_shape = [1] * (x.ndim - 1) + [self.num_channels]
        if self.use_scale and self.weight is not None:
            scale = self.weight.astype(calc_dtype)
            normalized_x = normalized_x * scale.reshape(channel_view_shape)

        if self.use_bias and self.bias is not None:
            bias = self.bias.astype(calc_dtype)
            normalized_x = normalized_x + bias.reshape(channel_view_shape)

        # Padded/invalid positions produce zero output.
        final_output = normalized_x * mask_calc

        return final_output.astype(input_dtype)
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
class Gemma3nAudioSSCPConvBlock(nn.Module):
    """One stage of the SSCP front end: manual pad -> Conv2d -> cumulative group norm -> ReLU.

    Args:
        idx: Index of this stage in the ``config.sscp_conv_*`` lists. Stage 0
            consumes a single input channel; later stages consume the previous
            stage's output channel count.
        input_freq_dim: Unpadded size of the frequency (width) axis fed in.
        config: Audio config providing channel sizes, kernel sizes, strides
            and ``sscp_conv_eps``.
        manual_padding: Torch-style tuple
            ``(pad_F_left, pad_F_right, pad_T_top, pad_T_bottom)``; the
            convolution itself is built with zero padding.
    """

    def __init__(
        self,
        idx: int,
        input_freq_dim: int,
        config: AudioConfig,
        manual_padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
        *args,
        **kwargs,
    ):
        super().__init__()
        self.config = config
        self.manual_padding = manual_padding

        channel_sizes = config.sscp_conv_channel_size
        in_channels = channel_sizes[idx - 1] if idx != 0 else 1
        out_channels = channel_sizes[idx]
        kernel_h, kernel_w = config.sscp_conv_kernel_size[idx]
        stride_h, stride_w = config.sscp_conv_stride_size[idx]

        # Kernel (kH, kW) slides over (time, frequency); padding is applied by hand.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(kernel_h, kernel_w),
            stride=(stride_h, stride_w),
            padding=(0, 0),
            bias=False,
        )

        # Frequency size after the conv: only the F-side padding (first two
        # entries of manual_padding) affects the width.
        pad_f_left, pad_f_right = manual_padding[0], manual_padding[1]
        padded_width = input_freq_dim + pad_f_left + pad_f_right
        freq_out = (padded_width - kernel_w) // stride_w + 1

        self.norm = Gemma3nCumulativeGroupNorm(
            num_channels=out_channels,
            feature_dims=(freq_out,),
            eps=config.sscp_conv_eps,
            use_scale=True,
            use_bias=False,
        )

    def __call__(self, x: mx.array) -> mx.array:
        # x is channels-first [B, C_in, T_in, F_in]; pad F and T as configured.
        pad_width = convert_torch_to_mlx_pad_width(self.manual_padding, x.shape)
        padded = mx.pad(x, pad_width)

        # MLX convolutions are channels-last, so [B, C, T, F] -> [B, T, F, C].
        # Both the conv output and the norm output are [B, T_out, F_out, C_out].
        normed = self.norm(self.conv(padded.transpose(0, 2, 3, 1)))

        # Return to channels-first [B, C_out, T_out, F_out] before the ReLU.
        return nn.relu(normed.transpose(0, 3, 1, 2))
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
class Gemma3nAudioSubSampleConvProjection(nn.Module):
    """Two SSCP conv stages followed by a linear projection to ``hidden_size``.

    Input is ``[B, T, F_in]``; output is ``[B, T_sub, hidden_size]``.
    """

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config

        # Pre-compute each stage's manual padding and output frequency size so
        # the per-stage norms and the final Linear get correct shapes.
        block_paddings = []
        freq_out_dims = []
        freq_dim = config.input_feat_size
        for layer in range(2):  # the module always builds exactly conv_0 and conv_1
            kernel_h, kernel_w = config.sscp_conv_kernel_size[layer]
            _, stride_w = config.sscp_conv_stride_size[layer]

            # Time (height): JAX-style "reverse causal" padding (0, kernel_h - 1).
            # Frequency (width): fixed (1, 1), matching the JAX reference for
            # F_in=10, K_w=3, S_w=2; other kernel/stride/input sizes may need
            # generic 'SAME' padding instead.
            padding = (1, 1, 0, kernel_h - 1)
            block_paddings.append(padding)

            # Output width with dilation 1, using the padding actually applied.
            freq_dim = (freq_dim + padding[0] + padding[1] - kernel_w) // stride_w + 1
            freq_out_dims.append(freq_dim)

        self.conv_0 = Gemma3nAudioSSCPConvBlock(
            idx=0,
            input_freq_dim=config.input_feat_size,
            config=config,
            manual_padding=block_paddings[0],
        )
        self.conv_1 = Gemma3nAudioSSCPConvBlock(
            idx=1,
            input_freq_dim=freq_out_dims[0],
            config=config,
            manual_padding=block_paddings[1],
        )

        # The projection consumes (final frequency bins) x (final channels).
        self.input_proj_in_features = config.sscp_conv_channel_size[-1] * freq_out_dims[-1]
        self.input_proj_linear = nn.Linear(
            self.input_proj_in_features, config.hidden_size, bias=False
        )

    def __call__(self, x: mx.array) -> mx.array:
        # [B, T, F_in] -> [B, 1, T, F_in]: add a singleton channel axis.
        features = self.conv_1(self.conv_0(mx.expand_dims(x, 1)))

        # features is [B, C, T_sub, F_sub]; move channels last and flatten
        # (F_sub, C) into one axis: [B, T_sub, F_sub * C].
        batch, channels, frames, freqs = features.shape
        flat = features.transpose(0, 2, 3, 1).reshape(batch, frames, freqs * channels)
        return self.input_proj_linear(flat)
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
class Gemma3nAudioConformerAttention(nn.Module):
    """Pre-norm attention sub-layer of a conformer block with a residual connection.

    Activations are clipped to ``[-gradient_clipping, gradient_clipping]``
    both before the pre-norm and after the output projection.
    """

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config

        num_heads = config.conf_num_attention_heads
        hidden = config.hidden_size
        self.post_in_shape = (num_heads, hidden // num_heads)
        self.post_in_features = hidden

        # Leading underscore keeps this constant out of the module's parameters.
        self._gradient_clipping = mx.array(config.gradient_clipping)

        self.pre_attn_norm = Gemma3nRMSNorm(hidden)
        self.attn = Gemma3nAudioAttention(config)
        self.post = nn.Linear(self.post_in_features, hidden, bias=False)
        self.post_norm = Gemma3nRMSNorm(hidden)

    def __call__(self, x: mx.array, mask: mx.array) -> mx.array:
        limit = self._gradient_clipping
        normed = self.pre_attn_norm(mx.clip(x, -limit, limit))

        # self.attn yields [B, T, NumHeads, HeadDim]; merge the head axes so
        # the result is [B, T, NumHeads * HeadDim] (== hidden_size).
        heads_out = self.attn(normed, mask)
        b, t, num_heads, head_dim = heads_out.shape
        projected = self.post(heads_out.reshape(b, t, num_heads * head_dim))
        projected = mx.clip(projected, -limit, limit)

        # Residual uses the unclipped input.
        return x + self.post_norm(projected)
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
class Gemma3nAudioConformerFeedForward(nn.Module):
    """Clipped, pre/post-normalized SiLU feed-forward with a scaled residual.

    Computes ``x + conf_residual_weight * post_norm(clip(W2 silu(W1 pre_norm(clip(x)))))``
    with a 4x hidden expansion.
    """

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config
        hidden = config.hidden_size

        # Underscore-prefixed arrays are constants, not trainable parameters.
        self._gradient_clipping = mx.array(config.gradient_clipping)

        self.pre_layer_norm = Gemma3nRMSNorm(hidden)
        self.ffw_layer_1 = nn.Linear(hidden, hidden * 4, bias=False)
        self.ffw_layer_2 = nn.Linear(hidden * 4, hidden, bias=False)
        self.post_layer_norm = Gemma3nRMSNorm(hidden)
        self._post_layer_scale = mx.array(config.conf_residual_weight)

    def __call__(self, x: mx.array) -> mx.array:
        limit = self._gradient_clipping

        hidden_states = self.pre_layer_norm(mx.clip(x, -limit, limit))
        hidden_states = nn.silu(self.ffw_layer_1(hidden_states))
        hidden_states = mx.clip(self.ffw_layer_2(hidden_states), -limit, limit)
        hidden_states = self.post_layer_norm(hidden_states)

        return x + (hidden_states * self._post_layer_scale)
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
class Gemma3nAudioConformerLightConv1d(nn.Module):
    """Light-weight convolution module of a conformer block.

    Pipeline: RMSNorm -> Linear(D -> 2D) -> GLU -> causal depthwise Conv1d ->
    clip -> RMSNorm -> SiLU -> Linear(D -> D), plus a residual connection.
    """

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        kernel = config.conf_conv_kernel_size

        self.pre_layer_norm = Gemma3nRMSNorm(hidden, eps=config.rms_norm_eps)
        self.linear_start = nn.Linear(hidden, hidden * 2, bias=False)
        # groups == channels makes this depthwise; padding is applied manually
        # on the left only so each output frame sees only current/past frames.
        self.depthwise_conv1d = nn.Conv1d(
            in_channels=hidden,
            out_channels=hidden,
            kernel_size=kernel,
            stride=1,
            padding=0,
            groups=hidden,
            bias=False,
        )
        self._gradient_clipping = mx.array(config.gradient_clipping)
        self.conv_norm = Gemma3nRMSNorm(hidden, eps=config.rms_norm_eps)
        self.linear_end = nn.Linear(hidden, hidden, bias=False)

        self.causal_padding = kernel - 1

    def __call__(self, audio_encodings: mx.array) -> mx.array:
        residual = audio_encodings

        gated = nn.glu(
            self.linear_start(self.pre_layer_norm(audio_encodings)), axis=-1
        )

        # Pad the time axis on the left by (kernel - 1): view as [B, D, T],
        # pad the last axis torch-style, then return to [B, T, D] because
        # MLX Conv1d is channels-last.
        channels_first = gated.transpose(0, 2, 1)
        pad_width = convert_torch_to_mlx_pad_width(
            (self.causal_padding, 0), channels_first.shape
        )
        padded = mx.pad(channels_first, pad_width)
        conv_out = self.depthwise_conv1d(padded.transpose(0, 2, 1))

        limit = self._gradient_clipping
        conv_out = mx.clip(conv_out, -limit, limit)
        conv_out = nn.silu(self.conv_norm(conv_out))

        return self.linear_end(conv_out) + residual
|
|
893
|
+
|
|
894
|
+
|
|
895
|
+
class Gemma3nAudioConformerBlock(nn.Module):
    """Conformer block: FFN -> attention -> light conv -> FFN -> clip -> RMSNorm.

    ``audio_mel_mask`` follows the convention that ``True`` marks an invalid
    (padded) frame; those frames are zeroed before the convolution module.
    """

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config

        self.ffw_layer_start = Gemma3nAudioConformerFeedForward(config)
        self.attention = Gemma3nAudioConformerAttention(config)
        self.lconv1d = Gemma3nAudioConformerLightConv1d(config)
        self.ffw_layer_end = Gemma3nAudioConformerFeedForward(config)
        self._gradient_clipping = mx.array(config.gradient_clipping)
        self.norm = Gemma3nRMSNorm(config.hidden_size)

    def __call__(self, audio_encodings: mx.array, audio_mel_mask: mx.array) -> mx.array:
        hidden = self.attention(self.ffw_layer_start(audio_encodings), audio_mel_mask)

        # Invert the mask (True -> valid) and zero padded frames so they cannot
        # leak into neighbouring frames through the convolution.
        keep = mx.expand_dims(~audio_mel_mask, -1).astype(hidden.dtype)
        hidden = self.ffw_layer_end(self.lconv1d(hidden * keep))

        limit = self._gradient_clipping
        return self.norm(mx.clip(hidden, -limit, limit))
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
class AudioModel(nn.Module):
    """Gemma3n audio encoder: convolutional subsampler plus a stack of conformer blocks.

    Mask convention: ``True`` marks a padded/invalid frame. The conformer
    blocks invert it to get validity, and the final encodings are zeroed
    wherever the mask is ``True``.
    """

    def __init__(self, config: AudioConfig, *args, **kwargs):
        super().__init__()
        self.config = config

        self.subsample_conv_projection = Gemma3nAudioSubSampleConvProjection(config)
        self.conformer = [
            Gemma3nAudioConformerBlock(config)
            for _ in range(config.conf_num_hidden_layers)
        ]

    @staticmethod
    def _fit_mask_length(mask: mx.array, target_len: int) -> mx.array:
        """Truncate or right-pad ``mask`` along axis 1 to exactly ``target_len`` frames.

        Padded positions are filled with ``True`` so that frames with no
        corresponding mask entry are treated as invalid (and later zeroed),
        rather than silently marked valid.
        """
        current_len = mask.shape[1]
        if current_len > target_len:
            return mask[:, :target_len]
        if current_len < target_len:
            pad_width = convert_torch_to_mlx_pad_width(
                (0, target_len - current_len), mask.shape
            )
            return mx.pad(mask, pad_width, constant_values=True)
        return mask

    def __call__(
        self, audio_mel: mx.array, audio_mel_mask: mx.array
    ) -> Tuple[mx.array, mx.array]:
        """Encode mel features.

        Args:
            audio_mel: Mel features; the subsampler expects ``[B, T, F_in]``.
            audio_mel_mask: Frame mask aligned with ``audio_mel``'s time axis,
                ``True`` for padded frames.

        Returns:
            ``(audio_encodings, mask)`` where ``audio_encodings`` is
            ``[B, T_out, hidden_size]`` with masked frames zeroed and ``mask``
            is the matching ``[B, T_out]`` mask.
        """
        audio_encodings = self.subsample_conv_projection(audio_mel)  # [B, T_sub, D]
        t_sub = audio_encodings.shape[1]

        # Total stride of the subsampling convs along the time (height) axis.
        time_stride_product = 1
        for stride_pair in self.config.sscp_conv_stride_size:
            time_stride_product *= stride_pair[0]

        # For each subsampled frame, pick the original frame at the start of
        # its receptive field, clamped to the last valid index.
        indices = mx.arange(t_sub) * time_stride_product
        indices = mx.clip(indices, None, a_max=audio_mel_mask.shape[1] - 1)

        if audio_mel_mask.ndim > 1 and indices.ndim == 1:
            # Broadcast the shared index row across the batch: [B, T_sub].
            indices = mx.broadcast_to(
                indices[None, :], (audio_mel_mask.shape[0], indices.shape[0])
            )
        elif (
            audio_mel_mask.ndim == indices.ndim
            and audio_mel_mask.shape[0] == 1
            and indices.shape[0] != 1
            and t_sub == indices.shape[0]
        ):
            # B == 1 but indices are [T_sub] rather than [1, T_sub].
            indices = indices[None, :]

        current_mask = mx.take_along_axis(audio_mel_mask, indices, axis=1)  # [B, T_sub]

        # Fallback: the gathered mask must match the feature length.
        if current_mask.shape[1] != t_sub:
            print(
                f"Warning: Subsampled mask length {current_mask.shape[1]} mismatch "
                f"with feature length {t_sub} after gather. Adjusting."
            )
            current_mask = self._fit_mask_length(current_mask, t_sub)

        for block in self.conformer:
            audio_encodings = block(audio_encodings, current_mask)

        reduction = self.config.conf_reduction_factor
        if reduction > 1:
            # Temporal reduction: keep every `reduction`-th frame of both the
            # encodings and the mask.
            audio_encodings = audio_encodings[:, ::reduction]
            current_mask = current_mask[:, ::reduction]

        # Guarantee the mask lines up with the final encodings, then zero
        # every frame the mask marks as invalid.
        current_mask = self._fit_mask_length(current_mask, audio_encodings.shape[1])
        audio_encodings = mx.where(current_mask[..., None], 0.0, audio_encodings)
        return audio_encodings, current_mask

    def sanitize(self, weights):
        """Convert checkpoint convolution kernels to MLX's channels-last layout.

        Keys containing ``"conv.weight"`` (2-D convs) are transposed with
        ``(0, 2, 3, 1)`` and keys containing ``"conv1d.weight"`` (1-D convs)
        with ``(0, 2, 1)``, unless ``check_array_shape(v)`` reports the array
        as already converted. All other weights pass through unchanged.
        """
        sanitized_weights = {}
        for k, v in weights.items():
            if "conv.weight" in k:
                if not check_array_shape(v):
                    v = v.transpose(0, 2, 3, 1)
            elif "conv1d.weight" in k:
                if not check_array_shape(v):
                    v = v.transpose(0, 2, 1)
            sanitized_weights[k] = v
        return sanitized_weights
|