nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nexaai/__init__.py +99 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +93 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +7 -0
- nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
- nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
- nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
- nexaai/binds/cpu_gpu/libggml.dylib +0 -0
- nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
- nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/metal/libnexa_plugin.dylib +0 -0
- nexaai/binds/metal/py-lib/ml.py +888 -0
- nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/metal/py-lib/profiling.py +239 -0
- nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
- nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
- nexaai/binds/nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexaml/libggml.dylib +0 -0
- nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
- nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +106 -0
- nexaai/cv.py +95 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +91 -0
- nexaai/cv_impl/pybind_cv_impl.py +124 -0
- nexaai/diarize.py +80 -0
- nexaai/diarize_impl/__init__.py +1 -0
- nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +238 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +162 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +559 -0
- nexaai/mlx_backend/vlm/main.py +365 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1745 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +389 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
- nexaai-1.0.29.dist-info/METADATA +35 -0
- nexaai-1.0.29.dist-info/RECORD +580 -0
- nexaai-1.0.29.dist-info/WHEEL +5 -0
- nexaai-1.0.29.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,239 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any, Optional
|
|
6
|
+
from enum import IntEnum
|
|
7
|
+
|
|
8
|
+
# --------------------------------------------------------------------------------------
|
|
9
|
+
# Stop reason constants matching profile.h
|
|
10
|
+
# --------------------------------------------------------------------------------------
|
|
11
|
+
|
|
12
|
+
class StopReason(IntEnum):
    """Integer stop-reason codes; the numeric values mirror those in profile.h.

    Being an IntEnum, members compare equal to their plain integer codes, so
    they can be stored in ``int``-typed fields and passed across the C API.
    """

    ML_STOP_REASON_UNKNOWN = 0  # default / not yet set
    ML_STOP_REASON_EOS = 1
    ML_STOP_REASON_LENGTH = 2
    ML_STOP_REASON_USER = 3
    ML_STOP_REASON_STOP_SEQUENCE = 4
    ML_STOP_REASON_COMPLETED = 5
|
|
20
|
+
|
|
21
|
+
# --------------------------------------------------------------------------------------
|
|
22
|
+
# Profiling data structure
|
|
23
|
+
# --------------------------------------------------------------------------------------
|
|
24
|
+
|
|
25
|
+
@dataclass
class ProfilingData:
    """Summary metrics for one profiled operation (times in microseconds)."""

    ttft_us: int = 0  # start -> first token
    total_time_us: int = 0  # start -> end of the whole operation
    prompt_time_us: int = 0  # duration of the prompt-processing phase
    decode_time_us: int = 0  # duration of the decode phase
    tokens_per_second: float = 0.0  # generated tokens per second of decode time
    total_tokens: int = 0  # prompt_tokens + generated_tokens
    prompt_tokens: int = 0  # number of prompt tokens
    generated_tokens: int = 0  # number of generated tokens
    stop_reason: int = StopReason.ML_STOP_REASON_UNKNOWN  # a StopReason code

    def reset(self):
        """Put every metric back to its declared default value."""
        self.ttft_us = self.total_time_us = 0
        self.prompt_time_us = self.decode_time_us = 0
        self.tokens_per_second = 0.0
        self.total_tokens = self.prompt_tokens = self.generated_tokens = 0
        self.stop_reason = StopReason.ML_STOP_REASON_UNKNOWN
|
|
49
|
+
|
|
50
|
+
# --------------------------------------------------------------------------------------
|
|
51
|
+
# Profiling context (similar to ml_ProfilingContext in profile.h)
|
|
52
|
+
# --------------------------------------------------------------------------------------
|
|
53
|
+
|
|
54
|
+
@dataclass
class ProfilingContext:
    """Mutable timing and bookkeeping state for one profiled operation.

    Timestamps hold ``time.perf_counter()`` readings in seconds, or ``None``
    while the corresponding phase has not been marked.
    """

    start_time: Optional[float] = None
    prompt_start_time: Optional[float] = None
    prompt_end_time: Optional[float] = None
    decode_start_time: Optional[float] = None
    decode_end_time: Optional[float] = None
    first_token_time: Optional[float] = None
    end_time: Optional[float] = None

    ttft_recorded: bool = False  # set once first_token_time has been stamped
    stop_reason: int = StopReason.ML_STOP_REASON_UNKNOWN
    prompt_tokens: int = 0
    generated_tokens: int = 0

    def reset(self):
        """Return the context to its freshly constructed state."""
        self.start_time = self.end_time = None
        self.prompt_start_time = self.prompt_end_time = None
        self.decode_start_time = self.decode_end_time = None
        self.first_token_time = None
        self.ttft_recorded = False
        self.stop_reason = StopReason.ML_STOP_REASON_UNKNOWN
        self.prompt_tokens = self.generated_tokens = 0
|
|
83
|
+
|
|
84
|
+
# --------------------------------------------------------------------------------------
|
|
85
|
+
# Profiling functions (similar to profile.h functions)
|
|
86
|
+
# --------------------------------------------------------------------------------------
|
|
87
|
+
|
|
88
|
+
def profiling_reset(ctx: ProfilingContext) -> None:
    """Clear every timestamp, counter and the stop reason held by *ctx*.

    Python counterpart of ``ml_profiling_reset``; delegates to ``ctx.reset()``.
    """
    ctx.reset()
|
|
91
|
+
|
|
92
|
+
def profiling_start(ctx: ProfilingContext) -> None:
    """Mark the start of an operation (``ml_profiling_start``).

    The prompt phase is taken to begin at the same instant, so a single
    clock reading is stored in both ``start_time`` and ``prompt_start_time``.
    """
    now = time.perf_counter()
    ctx.start_time = now
    ctx.prompt_start_time = now
|
|
96
|
+
|
|
97
|
+
def profiling_prompt_start(ctx: ProfilingContext) -> None:
    """Stamp ``ctx.prompt_start_time`` with the current perf-counter reading.

    Python counterpart of ``ml_profiling_prompt_start``.
    """
    now = time.perf_counter()
    ctx.prompt_start_time = now
|
|
100
|
+
|
|
101
|
+
def profiling_prompt_end(ctx: ProfilingContext) -> None:
    """Stamp ``ctx.prompt_end_time`` with the current perf-counter reading.

    Python counterpart of ``ml_profiling_prompt_end``.
    """
    now = time.perf_counter()
    ctx.prompt_end_time = now
|
|
104
|
+
|
|
105
|
+
def profiling_decode_start(ctx: ProfilingContext) -> None:
    """Stamp ``ctx.decode_start_time`` with the current perf-counter reading.

    Python counterpart of ``ml_profiling_decode_start``.
    """
    now = time.perf_counter()
    ctx.decode_start_time = now
|
|
108
|
+
|
|
109
|
+
def profiling_decode_end(ctx: ProfilingContext) -> None:
    """Stamp ``ctx.decode_end_time`` with the current perf-counter reading.

    Python counterpart of ``ml_profiling_decode_end``.
    """
    now = time.perf_counter()
    ctx.decode_end_time = now
|
|
112
|
+
|
|
113
|
+
def profiling_record_ttft(ctx: ProfilingContext) -> None:
    """Stamp the first-token time once per run (``ml_profiling_record_ttft``).

    Does nothing if profiling has not started (``start_time`` is ``None``) or
    a first-token time was already recorded, so it is safe to call on every
    emitted token.
    """
    if ctx.ttft_recorded or ctx.start_time is None:
        return
    ctx.first_token_time = time.perf_counter()
    ctx.ttft_recorded = True
|
|
118
|
+
|
|
119
|
+
def profiling_update_prompt_tokens(ctx: ProfilingContext, prompt_tokens: int) -> None:
    """Overwrite ``ctx.prompt_tokens`` (``ml_profiling_update_prompt_tokens``)."""
    ctx.prompt_tokens = prompt_tokens
|
|
122
|
+
|
|
123
|
+
def profiling_update_generated_tokens(ctx: ProfilingContext, generated_tokens: int) -> None:
    """Overwrite ``ctx.generated_tokens`` (``ml_profiling_update_generated_tokens``)."""
    ctx.generated_tokens = generated_tokens
|
|
126
|
+
|
|
127
|
+
def profiling_stop_reason(ctx: ProfilingContext, stop_reason: int) -> None:
    """Store *stop_reason* on the context (``ml_profiling_stop_reason``).

    The value is kept as given; no validation against StopReason is done.
    """
    ctx.stop_reason = stop_reason
|
|
130
|
+
|
|
131
|
+
def profiling_end(ctx: ProfilingContext) -> None:
    """Stamp ``ctx.end_time`` with the current perf-counter reading.

    Python counterpart of ``ml_profiling_end``.
    """
    now = time.perf_counter()
    ctx.end_time = now
|
|
134
|
+
|
|
135
|
+
def profiling_gen_data(ctx: ProfilingContext) -> ProfilingData:
    """Summarise *ctx* into a ProfilingData (``ml_profiling_gen_data``).

    If ``start_time`` or ``end_time`` is unset, an all-default ProfilingData
    is returned. Otherwise:

    - each duration whose two endpoints were recorded is filled in,
      truncated to whole microseconds;
    - token counts are copied, with ``total_tokens`` = prompt + generated;
    - the decode rate is generated tokens per second of decode time;
    - the stop reason is copied.
    """
    summary = ProfilingData()
    start, end = ctx.start_time, ctx.end_time
    if start is None or end is None:
        return summary

    def to_us(t0: float, t1: float) -> int:
        # Seconds -> microseconds; int() truncates toward zero.
        return int((t1 - t0) * 1_000_000)

    summary.total_time_us = to_us(start, end)

    if ctx.prompt_start_time is not None and ctx.prompt_end_time is not None:
        summary.prompt_time_us = to_us(ctx.prompt_start_time, ctx.prompt_end_time)

    if ctx.decode_start_time is not None and ctx.decode_end_time is not None:
        summary.decode_time_us = to_us(ctx.decode_start_time, ctx.decode_end_time)

    # `start` is known to be set at this point; only the first-token stamp matters.
    if ctx.first_token_time is not None:
        summary.ttft_us = to_us(start, ctx.first_token_time)

    summary.prompt_tokens = ctx.prompt_tokens
    summary.generated_tokens = ctx.generated_tokens
    summary.total_tokens = ctx.prompt_tokens + ctx.generated_tokens

    if summary.decode_time_us > 0:
        summary.tokens_per_second = (
            summary.generated_tokens * 1_000_000.0
        ) / summary.decode_time_us

    summary.stop_reason = ctx.stop_reason
    return summary
|
|
170
|
+
|
|
171
|
+
def stop_reason_to_string(reason: int) -> str:
    """Return the StopReason member name for *reason*.

    Codes that are not StopReason values produce ``"UNKNOWN(<code>)"``
    instead of raising.
    """
    try:
        member = StopReason(reason)
    except ValueError:
        return f"UNKNOWN({reason})"
    return member.name
|
|
177
|
+
|
|
178
|
+
# --------------------------------------------------------------------------------------
|
|
179
|
+
# Profiling mixin for model classes
|
|
180
|
+
# --------------------------------------------------------------------------------------
|
|
181
|
+
|
|
182
|
+
class ProfilingMixin:
    """Adds per-operation profiling hooks to a model class.

    Subclasses call the underscore-prefixed hooks around an operation. The
    summary produced by the latest ``_end_profiling()`` call is available
    from ``get_profiling_data()``.
    """

    def __init__(self):
        """Create an empty profiling context and an all-default summary."""
        self._profiling_context = ProfilingContext()
        self._profiling_data = ProfilingData()

    def _start_profiling(self) -> None:
        """Clear the context, then stamp the operation start time."""
        ctx = self._profiling_context
        profiling_reset(ctx)
        profiling_start(ctx)

    def _prompt_start(self) -> None:
        """Stamp the beginning of prompt processing."""
        profiling_prompt_start(self._profiling_context)

    def _prompt_end(self) -> None:
        """Stamp the end of prompt processing."""
        profiling_prompt_end(self._profiling_context)

    def _decode_start(self) -> None:
        """Stamp the beginning of the decode phase."""
        profiling_decode_start(self._profiling_context)

    def _decode_end(self) -> None:
        """Stamp the end of the decode phase."""
        profiling_decode_end(self._profiling_context)

    def _record_ttft(self) -> None:
        """Stamp the first-token time (only the first call per run counts)."""
        profiling_record_ttft(self._profiling_context)

    def _update_prompt_tokens(self, prompt_tokens: int) -> None:
        """Store the prompt-token count on the live context."""
        profiling_update_prompt_tokens(self._profiling_context, prompt_tokens)

    def _update_generated_tokens(self, generated_tokens: int) -> None:
        """Store the generated-token count on the live context."""
        profiling_update_generated_tokens(self._profiling_context, generated_tokens)

    def _set_stop_reason(self, stop_reason: int) -> None:
        """Store the stop reason on the live context."""
        profiling_stop_reason(self._profiling_context, stop_reason)

    def _end_profiling(self) -> ProfilingData:
        """Stamp the end time, then build, store and return the summary."""
        ctx = self._profiling_context
        profiling_end(ctx)
        summary = profiling_gen_data(ctx)
        self._profiling_data = summary
        return summary

    def get_profiling_data(self) -> ProfilingData:
        """Return the summary stored by the latest ``_end_profiling()`` call."""
        return self._profiling_data

    def reset_profiling(self) -> None:
        """Zero the stored summary; the live profiling context is left as is."""
        self._profiling_data.reset()
|
|
File without changes
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
# Copyright © Nexa AI
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import sys
|
|
16
|
+
import os
|
|
17
|
+
import mlx.core as mx
|
|
18
|
+
import mlx.nn as nn
|
|
19
|
+
import numpy as np
|
|
20
|
+
import time
|
|
21
|
+
|
|
22
|
+
from transformers import AutoTokenizer
|
|
23
|
+
from huggingface_hub import snapshot_download
|
|
24
|
+
from .modeling.nexa_jina_rerank import Model, ModelArgs
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """Build position ids that skip padding tokens.

    Along axis 1, non-padding tokens are numbered
    ``padding_idx + past_key_values_length + 1, +2, ...`` in order.
    Every padding position receives ``padding_idx`` itself. Result is int32.
    """
    is_token = (input_ids != padding_idx).astype(mx.int32)
    running_count = mx.cumsum(is_token, axis=1) + past_key_values_length
    # Multiplying by the mask zeroes the padding slots before the offset is added.
    positions = (running_count * is_token).astype(mx.int32)
    return positions + padding_idx
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def prepare_inputs(query, documents, tokenizer, max_length=1024):
    """Tokenise (query, document) pairs into MLX model inputs.

    Mirrors the PyTorch reference preprocessing: every pair is padded to
    exactly ``max_length`` tokens and truncated if longer.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, position_ids)``.
        ``input_ids``, ``token_type_ids`` and ``position_ids`` are int32;
        ``attention_mask`` is float32. ``token_type_ids`` is all zeros with
        shape ``(batch, seq_len)``.
    """
    encoded = tokenizer(
        [[query, doc] for doc in documents],
        padding="max_length",
        truncation=True,
        return_tensors="np",
        max_length=max_length,
    )

    input_ids = mx.array(encoded["input_ids"]).astype(mx.int32)
    attention_mask = mx.array(encoded["attention_mask"]).astype(mx.float32)
    batch_size, seqlen = input_ids.shape[0], input_ids.shape[1]

    # Single-segment model: one zero row, broadcast to every batch item.
    token_type_ids = mx.broadcast_to(
        mx.zeros((1, seqlen), dtype=mx.int32), (batch_size, seqlen)
    )

    # NOTE(review): padding id 1 is assumed to match the tokenizer's pad token
    # and the model's position-embedding padding index -- confirm if the
    # tokenizer changes.
    position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)

    return input_ids, attention_mask, token_type_ids, position_ids
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def load_model(model_id):
    """Download (if needed) and load the Jina V2 rerank model.

    Files are cached in ``modelfiles/nexaml_jina_v2_rerank_mlx`` next to this
    module. A download is attempted whenever that directory holds no
    ``.safetensors`` file. A failed or interrupted earlier download therefore
    no longer leaves a permanently broken cache behind.

    Args:
        model_id: Hugging Face repo id to download from.

    Returns:
        ``(model, model_dir)``: the loaded model (after ``model.eval()``) and
        the local directory holding its config/tokenizer files.

    Raises:
        FileNotFoundError: if no ``.safetensors`` file is present after the
            download step.
        Exception: any error from ``snapshot_download`` is printed and re-raised.
    """
    curr_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = f"{curr_dir}/modelfiles/nexaml_jina_v2_rerank_mlx"

    def _weight_files():
        # Sorted so the fallback choice below is the same on every run/platform.
        if not os.path.isdir(model_dir):
            return []
        return sorted(f for f in os.listdir(model_dir) if f.endswith(".safetensors"))

    # Download when the cache is missing *or* incomplete. Checking only that the
    # directory exists skipped the download forever after one failed attempt,
    # because the directory is created before downloading starts.
    if not _weight_files():
        print(f"Downloading model {model_id}...")

        os.makedirs(model_dir, exist_ok=True)

        try:
            snapshot_download(
                repo_id=model_id,
                allow_patterns=["*.safetensors", "config.json", "tokenizer*"],
                local_dir=model_dir,
                local_dir_use_symlinks=False
            )
            print("Model download completed!")
        except Exception as e:
            print(f"Failed to download model: {e}")
            print("Try: huggingface-cli login (if authentication required)")
            raise

    # Prefer the canonical weight file name; fall back to any .safetensors file.
    weight_file = os.path.join(model_dir, "model.safetensors")
    if not os.path.exists(weight_file):
        safetensors_files = _weight_files()
        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors file found in {model_dir}")
        weight_file = os.path.join(model_dir, safetensors_files[0])

    # Build the model only once we know there are weights to load into it.
    config = ModelArgs()
    model = Model(config)

    print(f"Loading weights from: {weight_file}")
    model.load_weights(weight_file, strict=True)
    model.eval()

    return model, model_dir
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def load_tokenizer(model_path):
    """Return the Hugging Face tokenizer stored under *model_path*."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return tokenizer
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def rerank_documents(model, tokenizer, query, documents, max_length=1024):
    """Score each document's relevance to *query*.

    Returns:
        ``(scores, scores_sigmoid, inference_time)``:
        - ``scores``: the model's raw scores, with the trailing size-1 axis
          squeezed away;
        - ``scores_sigmoid``: ``mx.sigmoid(scores)``;
        - ``inference_time``: forward-pass wall time in milliseconds.
    """
    input_ids, attention_mask, token_type_ids, position_ids = prepare_inputs(
        query, documents, tokenizer, max_length
    )

    # MLX is lazy: nexa_forward only records a compute graph. Evaluate inside
    # the timed region, otherwise the measurement covers graph construction
    # only and the reported throughput is wildly inflated. perf_counter is
    # monotonic and higher resolution than time.time.
    start_time = time.perf_counter()
    scores = model.nexa_forward(input_ids, attention_mask, token_type_ids, position_ids)
    scores = mx.squeeze(scores, axis=-1)
    mx.eval(scores)
    end_time = time.perf_counter()

    # Apply sigmoid to get probabilities
    scores_sigmoid = mx.sigmoid(scores)

    inference_time = (end_time - start_time) * 1000  # Convert to ms

    return scores, scores_sigmoid, inference_time
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def main(model_id):
    """Run a small end-to-end reranking demo and print the results.

    Args:
        model_id: Hugging Face repo id passed to ``load_model``.
    """
    # Load model and tokenizer
    model, model_path = load_model(model_id)
    tokenizer = load_tokenizer(model_path)

    # Example query and documents
    query = "What are the health benefits of green tea?"
    documents = [
        "Green tea is rich in antioxidants and may improve brain function.",
        "Coffee contains caffeine and can boost energy levels.",
        "Das Trinken von grünem Tee kann das Risiko für Herzkrankheiten senken.",
        "Black tea is another popular beverage with its own health benefits.",
    ]

    # Perform reranking
    scores, scores_sigmoid, inference_time = rerank_documents(
        model, tokenizer, query, documents
    )

    # Display results
    print("=" * 70)
    print("Reranking Results:")
    print("=" * 70)
    print(f"Query: {query}")
    print()

    for i, (doc, score, prob) in enumerate(zip(documents, scores.tolist(), scores_sigmoid.tolist())):
        print(f"Document {i+1}:")
        print(f" Text: {doc}")
        print(f" Score: {score:.4f}")
        print(f" Probability: {prob:.4f}")
        print()

    print(f"Inference time: {inference_time:.1f}ms")
    # A coarse clock can report 0 ms; avoid a ZeroDivisionError in that case.
    if inference_time > 0:
        print(f"Throughput: {len(documents)/inference_time*1000:.1f} docs/s")


if __name__ == "__main__":
    model_id = "nexaml/jina-v2-rerank-mlx"
    main(model_id)
|
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
# Copyright © Nexa AI
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
import os
|
|
18
|
+
import json
|
|
19
|
+
import mlx.core as mx
|
|
20
|
+
import mlx.nn as nn
|
|
21
|
+
import numpy as np
|
|
22
|
+
import time
|
|
23
|
+
from pathlib import Path
|
|
24
|
+
from typing import Any, List, Optional, Sequence
|
|
25
|
+
from dataclasses import dataclass
|
|
26
|
+
from abc import ABC, abstractmethod
|
|
27
|
+
|
|
28
|
+
# Import necessary modules
|
|
29
|
+
from transformers import AutoTokenizer
|
|
30
|
+
|
|
31
|
+
# Import from ml.py for API alignment (assuming similar structure)
|
|
32
|
+
try:
|
|
33
|
+
from ml import (
|
|
34
|
+
Reranker as BaseReranker,
|
|
35
|
+
Path as PathType,
|
|
36
|
+
)
|
|
37
|
+
except ImportError:
|
|
38
|
+
# Fallback to local definitions if ml.py not available
|
|
39
|
+
PathType = Path
|
|
40
|
+
BaseReranker = ABC
|
|
41
|
+
|
|
42
|
+
# Import profiling module
|
|
43
|
+
from profiling import ProfilingMixin, ProfilingData, StopReason
|
|
44
|
+
|
|
45
|
+
# Import the model implementation
|
|
46
|
+
from .modeling.nexa_jina_rerank import Model, ModelArgs
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class RerankConfig:
    """Options controlling ``Reranker.rerank``.

    The dataclass-generated ``__init__`` takes ``batch_size``, ``normalize``
    and ``normalize_method`` (positionally or by keyword, all defaulted).

    Attributes:
        batch_size: number of documents scored per forward pass.
        normalize: whether scores are post-processed with ``normalize_method``.
        normalize_method: one of ``"softmax"``, ``"min-max"`` or ``"none"``.
    """

    batch_size: int = 1
    normalize: bool = True
    normalize_method: str = "softmax"  # "softmax" | "min-max" | "none"
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class Reranker(BaseReranker, ProfilingMixin):
|
|
68
|
+
"""
|
|
69
|
+
Reranker interface for MLX reranking models.
|
|
70
|
+
API aligned with ml.py Reranker abstract base class.
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
def __init__(
    self,
    model_path: PathType,
    tokenizer_path: PathType,
    device: Optional[str] = None,
) -> None:
    """Record model/tokenizer paths and device; weights load in ``load_model()``.

    Args:
        model_path: model directory, or a file inside it (its parent
            directory is stored instead).
        tokenizer_path: stored unchanged on ``self.tokenizer_path``.
        device: stored on ``self.device``; ``None`` becomes ``"cpu"``.
    """
    # Profiling state must exist before any profiling hook can run.
    ProfilingMixin.__init__(self)

    # Accept a path to a file (e.g. the weights) and keep its directory.
    if os.path.isfile(model_path):
        model_path = os.path.dirname(model_path)

    # Only chain to the ml.py base class when it was actually imported.
    # Under the ImportError fallback, BaseReranker is ABC, which defines no
    # __init__. super().__init__(...) then resolved to ProfilingMixin.__init__,
    # which takes no extra arguments, and raised TypeError. The old
    # `hasattr(super(), '__init__')` check is always true and never guarded this.
    if BaseReranker is not ABC:
        BaseReranker.__init__(self, model_path, tokenizer_path, device)

    # Store paths and device
    self.model_path = model_path
    self.tokenizer_path = tokenizer_path
    self.device = device if device is not None else "cpu"

    # model/tokenizer are populated by load_model(); all three are cleared by destroy().
    self.model = None
    self.tokenizer = None
    self.config = None
|
|
100
|
+
|
|
101
|
+
def destroy(self) -> None:
    """Drop the references to model, tokenizer and config so they can be freed."""
    self.model = self.tokenizer = self.config = None
|
|
106
|
+
|
|
107
|
+
def load_model(self, model_path: PathType, extra_data: Any = None) -> bool:
    """Load weights and tokenizer; return True on success, False on any error.

    A truthy *model_path* replaces ``self.model_path``; a path to a file is
    reduced to its directory first. Otherwise the path given at construction
    is reused. ``extra_data`` is accepted for interface compatibility and is
    not used. Failures are printed to stdout rather than raised.
    """
    try:
        if model_path:
            self.model_path = (
                os.path.dirname(model_path) if os.path.isfile(model_path) else model_path
            )
        self.model = self._load_jina_model(self.model_path)
        self.tokenizer = self._load_tokenizer()
    except Exception as e:
        print(f"Failed to load model: {e}")
        return False
    return True
|
|
125
|
+
|
|
126
|
+
def close(self) -> None:
    """Shut the reranker down by delegating to ``destroy()``."""
    self.destroy()
|
|
129
|
+
|
|
130
|
+
def rerank(
    self,
    query: str,
    documents: Sequence[str],
    config: Optional[RerankConfig] = None,
    clear_cache: bool = True,
) -> mx.array:
    """Score *documents* against *query*; one score per document, in order.

    Documents go through the model ``config.batch_size`` at a time. When
    ``config.normalize`` is set, ``config.normalize_method`` is applied once
    to the scores of *all* documents, so results do not depend on the batch
    size.

    Args:
        query: the search query.
        documents: candidate texts to score.
        config: batching/normalization options; defaults to ``RerankConfig()``.
        clear_cache: call ``mx.clear_cache()`` after every batch.

    Returns:
        The concatenated scores; an empty float32 array when *documents* is
        empty.

    Raises:
        RuntimeError: if ``load_model()`` has not succeeded.
        ValueError: if ``config.batch_size`` is smaller than 1.
    """
    if self.model is None or self.tokenizer is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if config is None:
        config = RerankConfig()

    batch_size = config.batch_size
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Batches return raw scores; normalization happens once below, over every
    # document. Normalizing inside each batch made results depend on
    # batch_size. With the default batch_size=1, softmax turned every score
    # into 1.0 and the ranking became meaningless.
    raw_config = RerankConfig(
        batch_size=batch_size,
        normalize=False,
        normalize_method=config.normalize_method,
    )

    # Start profiling
    self._start_profiling()
    self._prompt_start()

    all_scores = []
    for i in range(0, len(documents), batch_size):
        batch_docs = documents[i:i + batch_size]
        batch_scores = self._rerank_batch(query, batch_docs, raw_config)
        # MLX is lazy: evaluate now, so the compute is attributed to this phase
        # and clear_cache() can actually release this batch's temporaries.
        mx.eval(batch_scores)
        all_scores.append(batch_scores)

        if clear_cache:
            mx.clear_cache()

    # End prompt processing, start decode
    self._prompt_end()
    self._decode_start()

    # Concatenate all batch scores into a single array
    if not all_scores:
        res = mx.array([], dtype=mx.float32)
    elif len(all_scores) == 1:
        res = all_scores[0]
    else:
        res = mx.concatenate(all_scores, axis=0)

    # Normalize across the full result set (skipped for an empty result).
    if config.normalize and all_scores:
        res = self._normalize_scores(res, config.normalize_method)

    # End decode and profiling
    self._decode_end()
    self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
    self._end_profiling()

    return res
|
|
173
|
+
|
|
174
|
+
def _load_jina_model(self, model_dir: str) -> Model:
    """Build the Jina V2 rerank model and load its weights from *model_dir*.

    Prefers ``model.safetensors``. Otherwise it uses the alphabetically first
    ``*.safetensors`` file in the directory. Also records ``self._model_dir``,
    which ``_load_tokenizer()`` reads.

    Raises:
        ValueError: if *model_dir* does not exist.
        FileNotFoundError: if *model_dir* contains no ``.safetensors`` file.
    """
    # Validate that model path exists
    if not os.path.exists(model_dir):
        raise ValueError(f"Model path does not exist: {model_dir}")

    # Store model directory for tokenizer loading
    self._model_dir = model_dir

    weight_file = os.path.join(model_dir, "model.safetensors")
    if not os.path.exists(weight_file):
        # os.listdir order is arbitrary; sort so the same file is picked on
        # every run and platform.
        safetensors_files = sorted(
            f for f in os.listdir(model_dir) if f.endswith(".safetensors")
        )
        if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors file found in {model_dir}")
        weight_file = os.path.join(model_dir, safetensors_files[0])

    # Build the model only once we know there are weights to load into it.
    config = ModelArgs()
    model = Model(config)

    model.load_weights(weight_file, strict=True)
    model.eval()

    return model
|
|
202
|
+
|
|
203
|
+
def _load_tokenizer(self) -> AutoTokenizer:
    """Build the Hugging Face tokenizer from the stored model directory.

    Reads ``self._model_dir`` (set by ``_load_jina_model``), so the tokenizer
    files are expected to live next to the model weights.

    Returns:
        Whatever ``AutoTokenizer.from_pretrained`` returns for that directory.
    """
    # NOTE(review): this uses the model directory, not any separately supplied
    # tokenizer path; confirm against how Reranker is constructed.
    tokenizer_dir = self._model_dir
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    return tokenizer
|
|
206
|
+
|
|
207
|
+
def _rerank_batch(self, query: str, documents: List[str], config: RerankConfig) -> mx.array:
    """Score one batch of (query, document) pairs.

    Args:
        query: Query text paired with every document in the batch.
        documents: Documents to score against ``query``.
        config: Supplies ``normalize`` and ``normalize_method``.

    Returns:
        The model's scores with the trailing size-1 axis squeezed away,
        optionally passed through ``_normalize_scores``.
    """
    # Tokenize to (input_ids, attention_mask, token_type_ids, position_ids).
    model_inputs = self._prepare_inputs(
        query, documents, self.tokenizer, max_length=1024
    )

    logits = self.model.nexa_forward(*model_inputs)
    batch_scores = mx.squeeze(logits, axis=-1)

    if not config.normalize:
        return batch_scores

    # NOTE(review): normalization runs per batch. The caller concatenates the
    # per-batch results, so softmax / min-max are relative to each batch rather
    # than to the full document list when it spans several batches.
    return self._normalize_scores(batch_scores, config.normalize_method)
|
|
223
|
+
|
|
224
|
+
def _create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0):
    """Derive position ids from token ids, giving padding its own position.

    Non-padding tokens get consecutive positions (counted along axis 1),
    starting at ``padding_idx + 1 + past_key_values_length``. Tokens equal to
    ``padding_idx`` get position ``padding_idx``.

    Args:
        input_ids: Token id array; positions are counted along axis 1.
        padding_idx: Token id treated as padding, also used as the base offset.
        past_key_values_length: Extra offset added to every non-padding position.

    Returns:
        int32 array of the same shape as ``input_ids``.
    """
    is_token = (input_ids != padding_idx).astype(mx.int32)
    running_count = mx.cumsum(is_token, axis=1) + past_key_values_length
    # Multiplying by the mask zeroes padding slots before the offset is added.
    return (running_count * is_token).astype(mx.int32) + padding_idx
|
|
229
|
+
|
|
230
|
+
def _prepare_inputs(self, query, documents, tokenizer, max_length=1024):
    """Tokenize (query, document) pairs into MLX model inputs.

    Intended to reproduce the reference PyTorch preprocessing exactly.

    Args:
        query: Query string paired with every document.
        documents: Document strings to pair with ``query``.
        tokenizer: Hugging Face-style tokenizer called on the pairs.
        max_length: Every pair is truncated/padded to this many tokens.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, position_ids)``:
        int32 ids, float32 mask, all-zero int32 token types, and int32
        position ids, each shaped ``(batch_size, seqlen)``.
    """
    pairs = [[query, doc] for doc in documents]
    encoded = tokenizer(
        pairs,
        # Fixed-width padding: every batch is max_length tokens wide.
        padding="max_length",
        truncation=True,
        return_tensors="np",
        max_length=max_length,
    )

    input_ids = mx.array(encoded["input_ids"]).astype(mx.int32)
    attention_mask = mx.array(encoded["attention_mask"]).astype(mx.float32)
    batch_size = input_ids.shape[0]
    seqlen = input_ids.shape[1]

    # Single-segment input: one all-zero row broadcast across the batch.
    zero_row = mx.zeros((1, seqlen), dtype=mx.int32)
    token_type_ids = mx.broadcast_to(zero_row, (batch_size, seqlen))

    # NOTE(review): padding_idx=1 is hard-coded; presumably it matches the
    # tokenizer's pad token id. Confirm before using other tokenizers.
    position_ids = self._create_position_ids_from_input_ids(input_ids, padding_idx=1)

    return input_ids, attention_mask, token_type_ids, position_ids
|
|
256
|
+
|
|
257
|
+
def _normalize_scores(self, scores: mx.array, method: str) -> mx.array:
    """Optionally rescale raw reranker scores.

    Args:
        scores: Score array to normalize.
        method: ``"softmax"`` (softmax over the last axis), ``"min-max"``
            (rescale by the global min and max), or ``"none"``. Any other
            value is treated like ``"none"``.

    Returns:
        The normalized scores. The input is returned unchanged for ``"none"``,
        for unrecognized methods, and for ``"min-max"`` when every score is
        equal (the range would be zero).
    """
    if method == "none":
        return scores

    if method == "softmax":
        # For 1-D input the last axis is axis 0, so one call covers every rank.
        return mx.softmax(scores, axis=-1)

    if method == "min-max":
        lowest = mx.min(scores)
        highest = mx.max(scores)
        if highest > lowest:
            return (scores - lowest) / (highest - lowest)

    return scores
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
# Factory function for creating reranker instances
|
|
278
|
+
def create_reranker(
    model_path: PathType,
    tokenizer_path: Optional[PathType] = None,
    device: Optional[str] = None,
) -> Reranker:
    """Construct a ``Reranker``.

    Args:
        model_path: Location of the model.
        tokenizer_path: Location of the tokenizer; ``model_path`` is used
            when this is ``None``.
        device: Forwarded to ``Reranker`` unchanged.

    Returns:
        A new ``Reranker`` instance.
    """
    resolved_tokenizer_path = model_path if tokenizer_path is None else tokenizer_path
    return Reranker(model_path, resolved_tokenizer_path, device)
|