nexaai 1.0.4rc10__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nexaai has been flagged as potentially problematic; see the registry's release page for more details.
- nexaai/__init__.py +71 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +60 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +91 -0
- nexaai/asr_impl/pybind_asr_impl.py +43 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +3 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
- nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
- nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_mlx/py-lib/ml.py +842 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
- nexaai/common.py +61 -0
- nexaai/cv.py +87 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +88 -0
- nexaai/cv_impl/pybind_cv_impl.py +31 -0
- nexaai/embedder.py +68 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +114 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +91 -0
- nexaai/image_gen.py +136 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +291 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +84 -0
- nexaai/llm.py +89 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +249 -0
- nexaai/llm_impl/pybind_llm_impl.py +207 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +151 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +130 -0
- nexaai/mlx_backend/embedding/interface.py +312 -0
- nexaai/mlx_backend/embedding/main.py +82 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +842 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +330 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/interface.py +406 -0
- nexaai/mlx_backend/vlm/main.py +157 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +51 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +91 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +42 -0
- nexaai/runtime.py +64 -0
- nexaai/tts.py +70 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +93 -0
- nexaai/tts_impl/pybind_tts_impl.py +42 -0
- nexaai/utils/avatar_fetcher.py +104 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/model_manager.py +1195 -0
- nexaai/utils/progress_tracker.py +372 -0
- nexaai/vlm.py +120 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +205 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +228 -0
- nexaai-1.0.4rc10.dist-info/METADATA +26 -0
- nexaai-1.0.4rc10.dist-info/RECORD +519 -0
- nexaai-1.0.4rc10.dist-info/WHEEL +5 -0
- nexaai-1.0.4rc10.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,406 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from typing import Any, List, Optional, Sequence, Tuple, Union
|
|
7
|
+
import mlx.core as mx
|
|
8
|
+
import codecs
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
|
|
11
|
+
# Import configs and callback types from ml.py for API alignment
|
|
12
|
+
from ml import (
|
|
13
|
+
VLM as BaseVLM,
|
|
14
|
+
SamplerConfig,
|
|
15
|
+
GenerationConfig,
|
|
16
|
+
ChatMessage,
|
|
17
|
+
EmbeddingConfig,
|
|
18
|
+
TokenCallback,
|
|
19
|
+
Path,
|
|
20
|
+
Tool, # Add Path alias for type hints
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Import profiling module
|
|
24
|
+
from profiling import ProfilingMixin, ProfilingData, StopReason
|
|
25
|
+
|
|
26
|
+
# Import from the actual mlx_vlm structure
|
|
27
|
+
from .generate import generate, stream_generate, load
|
|
28
|
+
from .modeling.prompt_utils import apply_chat_template
|
|
29
|
+
|
|
30
|
+
# --------------------------------------------------------------------------------------
|
|
31
|
+
# Updated GenerationResult to match the new structure
|
|
32
|
+
# --------------------------------------------------------------------------------------
|
|
33
|
+
|
|
34
|
+
@dataclass
class GenerationResult:
    """Outcome of one generation call: the text plus token counts and throughput stats."""

    # Generated text.
    text: str = ""
    # Last token id reported by the generator, if any.
    token: int | None = None
    # Log-probabilities reported by the generator, if any.
    logprobs: list[float] | None = None
    # Token counts for the prompt, the generated continuation, and their total.
    prompt_tokens: int = 0
    generation_tokens: int = 0
    total_tokens: int = 0
    # Throughput figures as reported by the generator.
    prompt_tps: float = 0.0
    generation_tps: float = 0.0
    # Peak memory as reported by the generator (units as reported upstream).
    peak_memory: float = 0.0
|
|
45
|
+
# --------------------------------------------------------------------------------------
|
|
46
|
+
# VLM (Vision-Language Model)
|
|
47
|
+
# --------------------------------------------------------------------------------------
|
|
48
|
+
|
|
49
|
+
class VLM(ProfilingMixin):
    """
    Vision-Language Models for mlx-vlm.

    API aligned with the ml.py ``VLM`` abstract base class. The model and
    processor are obtained from ``load``; generation is delegated to the
    module-level ``generate`` / ``stream_generate`` helpers, and every entry
    point that runs the model is instrumented through ``ProfilingMixin``.
    """

    def __init__(
        self,
        model_path: Path,
        mmproj_path: Path,
        context_length: int,
        device: Optional[str] = None,
    ) -> None:
        """Load the model and processor.

        Args:
            model_path: Model directory, or a file inside it (in which case the
                file's parent directory is loaded).
            mmproj_path: Stored as ``self.mmproj_path``; not read elsewhere in
                this class.
            context_length: Stored as ``self.context_length``; not read
                elsewhere in this class.
            device: Stored as ``self.device``; not read elsewhere in this class.
        """
        # Initialize profiling mixin
        ProfilingMixin.__init__(self)

        # If model_path points at a file, load from its parent directory instead.
        if os.path.isfile(model_path):
            model_path = os.path.dirname(model_path)

        self.model_path = model_path
        self.mmproj_path = mmproj_path
        self.context_length = context_length
        self.device = device

        self.model, self.processor = load(str(model_path))

        # Start from the default sampler configuration.
        self.sampler_config = SamplerConfig()

    def destroy(self) -> None:
        """Drop the references to the model and processor so they can be freed."""
        self.model = None
        self.processor = None

    def reset(self) -> None:
        """Reset the model state (currently: the KV cache)."""
        self._reset_cache()

    def _reset_cache(self) -> None:
        """Clear ``self.model.cache`` if the model exposes such an attribute."""
        if hasattr(self.model, "cache"):
            self.model.cache = None

    # Tokenization
    def encode(self, text: str) -> List[int]:
        """Encode text to token IDs using the processor."""
        return self.processor.encode(text)

    def decode(self, token_ids: Sequence[int]) -> str:
        """Decode token IDs to text using the processor."""
        return self.processor.decode(token_ids)

    # Sampler
    def set_sampler(self, config: SamplerConfig) -> None:
        """Set the sampler configuration applied to subsequent generations."""
        self.sampler_config = config

    def reset_sampler(self) -> None:
        """Reset the sampler to the default configuration installed by ``__init__``."""
        # Restore the real default rather than None: None would make the
        # generation helpers skip sampler settings altogether.
        self.sampler_config = SamplerConfig()

    # Generation helpers
    def _prepare_generation(
        self, config: Optional[GenerationConfig]
    ) -> Tuple[dict, Optional[List[str]], Optional[List[str]]]:
        """Split a generation config into generator kwargs and media path lists.

        Returns:
            ``(gen_kwargs, image_list, audio_list)``. ``gen_kwargs`` holds the
            config's attributes without ``image_paths``/``audio_paths``,
            overlaid with the current sampler configuration (sampler values
            win on key clashes). The media lists are the configured paths
            converted to ``str``, or ``None`` when none are configured.
        """
        gen_kwargs: dict = {}
        if config is not None:
            gen_kwargs = config.__dict__.copy()
            # Media paths are passed to the generator explicitly, not as kwargs.
            gen_kwargs.pop("image_paths", None)
            gen_kwargs.pop("audio_paths", None)
        # Apply the sampler even when no config is given, so set_sampler()
        # always takes effect.
        if self.sampler_config is not None:
            gen_kwargs.update(self.sampler_config.__dict__)

        image_paths = config.image_paths if config else None
        audio_paths = config.audio_paths if config else None

        image_list = [str(path) for path in image_paths] if image_paths else None
        audio_list = [str(path) for path in audio_paths] if audio_paths else None
        return gen_kwargs, image_list, audio_list

    @staticmethod
    def _config_with_media(
        config: Optional[GenerationConfig],
        image_paths: Optional[Sequence[Path]],
        audio_paths: Optional[Sequence[Path]],
    ) -> GenerationConfig:
        """Return a config carrying the given media paths, without mutating ``config``."""
        import copy

        # Shallow-copy so the caller's config object keeps its original paths.
        config = GenerationConfig() if config is None else copy.copy(config)
        if image_paths is not None:
            config.image_paths = image_paths
        if audio_paths is not None:
            config.audio_paths = audio_paths
        return config

    # Generation
    def generate(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None,
    ) -> GenerationResult:
        """Generate a completion for ``prompt`` in a single blocking call.

        Args:
            prompt: Fully formatted prompt text.
            config: Optional generation settings; its ``image_paths`` and
                ``audio_paths`` are forwarded as media inputs.

        Returns:
            A ``GenerationResult`` filled from the generator's stats dict.

        Raises:
            RuntimeError: if generation fails (original error chained as cause).
        """
        # Start profiling
        self._start_profiling()

        gen_kwargs, image_list, audio_list = self._prepare_generation(config)

        # The prompt is consumed inside the generator call, so the prompt phase
        # is closed here and decode timing starts immediately.
        self._prompt_end()
        self._decode_start()

        try:
            generation_start_time = time.perf_counter()

            text, stats = generate(
                self.model,
                self.processor,
                prompt,
                image=image_list,
                audio=audio_list,
                **gen_kwargs,
            )

            generation_end_time = time.perf_counter()

            generated_tokens = stats.get("output_tokens", 0)
            if generated_tokens > 0:
                # The blocking generator does not report when the first token
                # appeared, so estimate TTFT as the prompt phase duration plus
                # one average per-token generation time.
                total_generation_time = generation_end_time - generation_start_time
                avg_time_per_token = total_generation_time / generated_tokens
                ctx = self._profiling_context
                estimated_ttft = (ctx.prompt_end_time - ctx.prompt_start_time) + avg_time_per_token
                ctx.first_token_time = ctx.prompt_start_time + estimated_ttft
                ctx.ttft_recorded = True
            else:
                # If no tokens were generated, record TTFT at this point.
                self._record_ttft()

            # Update profiling data
            prompt_tokens = stats.get("input_tokens", 0)
            self._update_prompt_tokens(prompt_tokens)
            self._update_generated_tokens(generated_tokens)
            self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
            self._decode_end()
            self._end_profiling()

            return GenerationResult(
                text=text,
                prompt_tokens=prompt_tokens,
                generation_tokens=generated_tokens,
                total_tokens=stats.get("total_tokens", 0),
                prompt_tps=stats.get("prompt_tps", 0.0),
                generation_tps=stats.get("generation_tps", 0.0),
                peak_memory=stats.get("peak_memory", 0.0),
            )
        except Exception as e:
            self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
            self._decode_end()
            self._end_profiling()
            raise RuntimeError(f"Generation error: {str(e)}") from e

    def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None,
        on_token: Optional[TokenCallback] = None,
    ) -> GenerationResult:
        """Generate text with a streaming callback (text-only or multimodal).

        Args:
            prompt: Fully formatted prompt text.
            config: Optional generation settings, including media paths.
            on_token: Called with each streamed chunk's text. A falsy return
                value stops generation; that chunk is not added to the result.

        Returns:
            A ``GenerationResult`` with the accumulated text and the stats of
            the last accepted chunk (zeros/None if no chunk was accepted).

        Raises:
            RuntimeError: if generation (or the callback) fails; the original
                error is chained as the cause.
        """
        # Start profiling
        self._start_profiling()

        gen_kwargs, image_list, audio_list = self._prepare_generation(config)

        # End prompt processing, start decode
        self._prompt_end()
        self._decode_start()

        text = ""
        last_result = None
        first_token = True
        stopped_by_user = False

        try:
            for result in stream_generate(
                self.model,
                self.processor,
                prompt,
                image=image_list,
                audio=audio_list,
                **gen_kwargs,
            ):
                # Record TTFT on the first streamed chunk.
                if first_token:
                    self._record_ttft()
                    first_token = False

                if on_token is not None and not on_token(result.text):
                    self._set_stop_reason(StopReason.ML_STOP_REASON_USER)
                    stopped_by_user = True
                    break
                text += result.text
                last_result = result

            if not stopped_by_user:
                self._set_stop_reason(StopReason.ML_STOP_REASON_EOS)

            # Update profiling data
            if last_result is not None:
                self._update_prompt_tokens(last_result.prompt_tokens)
                self._update_generated_tokens(last_result.generation_tokens)

            self._decode_end()
            self._end_profiling()

            if last_result is None:
                return GenerationResult(text=text)
            return GenerationResult(
                text=text,
                token=last_result.token,
                logprobs=last_result.logprobs,
                prompt_tokens=last_result.prompt_tokens,
                generation_tokens=last_result.generation_tokens,
                total_tokens=last_result.prompt_tokens + last_result.generation_tokens,
                prompt_tps=last_result.prompt_tps,
                generation_tps=last_result.generation_tps,
                peak_memory=last_result.peak_memory,
            )
        except Exception as e:
            self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
            self._decode_end()
            self._end_profiling()
            raise RuntimeError(f"Streaming generation error: {str(e)}") from e

    # Legacy multimodal methods - kept for backward compatibility but delegate to unified method
    def generate_multimodal(
        self,
        prompt: str,
        image_paths: Optional[Sequence[Path]] = None,
        audio_paths: Optional[Sequence[Path]] = None,
        config: Optional[GenerationConfig] = None,
    ) -> str:
        """Generate text from a prompt with images/audio; returns only the text.

        The media paths override those in ``config`` for this call only; the
        passed ``config`` object itself is not modified.
        """
        merged = self._config_with_media(config, image_paths, audio_paths)
        return self.generate(prompt, merged).text

    def generate_stream_multimodal(
        self,
        prompt: str,
        image_paths: Optional[Sequence[Path]] = None,
        audio_paths: Optional[Sequence[Path]] = None,
        config: Optional[GenerationConfig] = None,
        on_token: Optional[TokenCallback] = None,
    ) -> str:
        """Streaming variant of ``generate_multimodal``; returns only the text.

        The media paths override those in ``config`` for this call only; the
        passed ``config`` object itself is not modified.
        """
        merged = self._config_with_media(config, image_paths, audio_paths)
        return self.generate_stream(prompt, merged, on_token).text

    def get_chat_template(self, template_name: str) -> str:
        """Return the processor's chat template by name, or "" if unsupported."""
        # This is a stub; actual implementation depends on processor internals
        if hasattr(self.processor, "get_chat_template"):
            return self.processor.get_chat_template(template_name)
        return ""

    @staticmethod
    def _messages_to_dicts(messages: Sequence[ChatMessage]) -> List[dict]:
        """Convert ChatMessage objects to the role/content dicts the template expects."""
        return [{"role": msg.role, "content": msg.content} for msg in messages]

    @staticmethod
    def _parse_tools(tools: Optional[str]) -> Any:
        """Parse a JSON tools string; ``None`` or blank input yields ``None``."""
        if tools is not None and tools.strip():
            return json.loads(tools)
        return None

    def apply_chat_template(self, messages: Sequence[ChatMessage], tools: Optional[str] = None, enable_thinking: bool = True) -> str:
        """Apply the chat template to messages, with optional JSON tools.

        Falls back to ``"role: content"`` lines joined by newlines when the
        processor has no ``apply_chat_template``.
        """
        if hasattr(self.processor, "apply_chat_template"):
            # The bare name below is the module-level prompt_utils helper, not this method.
            return apply_chat_template(
                self.processor,
                self.model.config,
                self._messages_to_dicts(messages),
                add_generation_prompt=True,
                enable_thinking=enable_thinking,
                tools=self._parse_tools(tools),
            )
        # Fallback: join messages
        return "\n".join(f"{m.role}: {m.content}" for m in messages)

    def apply_chat_template_with_media(self, messages: Sequence[ChatMessage], num_images: int = 0, num_audios: int = 0, tools: Optional[str] = None, enable_thinking: bool = True) -> str:
        """Apply the chat template, inserting image/audio placeholder tokens.

        Args:
            messages: Conversation history.
            num_images: Number of images attached to this request.
            num_audios: Number of audio clips attached to this request.
            tools: Optional JSON string describing tools.
            enable_thinking: Forwarded to the template helper.
        """
        return apply_chat_template(
            self.processor,
            self.model.config,
            self._messages_to_dicts(messages),
            num_images=num_images,
            num_audios=num_audios,
            enable_thinking=enable_thinking,
            tools=self._parse_tools(tools),
        )

    # Embeddings
    def embed(
        self,
        texts: Sequence[str],
        config: Optional[EmbeddingConfig] = None,
    ) -> List[List[float]]:
        """Generate embeddings for texts with profiling.

        Raises:
            RuntimeError: if the model has no ``embed`` method or embedding
                fails (original error chained as cause).
        """
        # Start profiling
        self._start_profiling()

        try:
            if not hasattr(self.model, "embed"):
                raise NotImplementedError("Embedding not supported for this model.")

            embed_kwargs = config.__dict__ if config else {}

            # End prompt processing, start decode
            self._prompt_end()
            self._decode_start()

            result = self.model.embed(texts, **embed_kwargs)

            # Finalize profiling data; embedding produces no generated tokens.
            self._update_generated_tokens(0)
            self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
            self._decode_end()
            self._end_profiling()

            return result
        except Exception as e:
            self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
            self._decode_end()
            self._end_profiling()
            raise RuntimeError(f"Error generating embeddings: {str(e)}") from e
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
from .interface import VLM
|
|
2
|
+
from ml import GenerationConfig, SamplerConfig, ChatMessage
|
|
3
|
+
import re
|
|
4
|
+
import os
|
|
5
|
+
import codecs
|
|
6
|
+
|
|
7
|
+
def parse_media_from_input(user_input):
    """Split user input into a text prompt and quoted media file paths.

    Every quoted substring ("..." or '...', closed by the same quote character
    that opened it) is removed from the prompt. Quoted paths that exist on disk
    are classified by extension as image or audio; missing files produce a
    warning and are skipped, and files with other extensions are ignored.

    Returns:
        ``(prompt, image_paths, audio_paths)``; each path list is ``None`` when
        it would be empty.
    """
    # The closing quote must match the opening one (backreference \1), so an
    # apostrophe inside a double-quoted path such as "it's.png" doesn't end it.
    quoted_pattern = r'(["\'])(.*?)\1'
    quoted_matches = [
        match.group(2)
        for match in re.finditer(quoted_pattern, user_input, flags=re.DOTALL)
    ]

    # Remove quoted strings from the input to get the actual prompt
    prompt = re.sub(quoted_pattern, '', user_input, flags=re.DOTALL).strip()

    # Separate image and audio files based on extensions
    image_extensions = {'.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff', '.webp'}
    audio_extensions = {'.mp3', '.wav', '.flac', '.aac', '.ogg', '.m4a'}

    image_paths = []
    audio_paths = []

    for quoted_file in quoted_matches:
        if not quoted_file:  # Skip empty quotes
            continue

        # Expand user path if it starts with ~
        if quoted_file.startswith('~'):
            quoted_file = os.path.expanduser(quoted_file)

        if not os.path.exists(quoted_file):
            print(f"Warning: File '{quoted_file}' not found")
            continue

        file_ext = os.path.splitext(quoted_file.lower())[1]
        if file_ext in image_extensions:
            image_paths.append(quoted_file)
        elif file_ext in audio_extensions:
            audio_paths.append(quoted_file)

    return prompt, image_paths if image_paths else None, audio_paths if audio_paths else None
|
|
41
|
+
|
|
42
|
+
def test_vlm_generate_stream(model_path):
    """Run an interactive multi-round chat with a VLM, streaming each reply.

    Media files are referenced inline in quotes (see ``parse_media_from_input``).
    Type ``quit``, ``exit`` or ``q`` (or send EOF) to end the session. The model
    is destroyed on exit, including on Ctrl-C.
    """
    # Specify the checkpoint
    context_length = 2048

    # Load the corresponding model and VLM instance
    vlm = VLM(
        model_path=model_path,
        mmproj_path=None,  # Not needed for this model
        context_length=context_length,
        device=None,
    )

    try:
        # Configure sampler
        sampler_config = SamplerConfig(
            temperature=0.7,
            top_p=0.9,
        )
        vlm.set_sampler(sampler_config)

        # Chat history using ChatMessage objects (following ml.py API)
        chat = []

        print("Multi-round VLM conversation started. Type 'quit' or 'exit' to end.")
        print("Include images/audios in quotes, e.g.: 'describe \"image1.jpg\" \"image2.png\"'")
        print("You can also use single quotes: 'describe '/path/to/image.jpg''")
        print("=" * 50)

        def on_token(text_chunk):
            """Print each streamed chunk; returning True keeps generation going."""
            # VLM.generate_stream invokes the callback with the chunk text only.
            print(text_chunk, end="", flush=True)
            return True

        while True:
            # Get user input; treat EOF (Ctrl-D) like an exit command.
            try:
                user_input = input("\nUser: ").strip()
            except EOFError:
                print()
                print("Goodbye!")
                break

            # Check for exit commands
            if user_input.lower() in ["quit", "exit", "q"]:
                print("Goodbye!")
                break

            if not user_input:
                continue

            # Parse media files and prompt from user input
            prompt_text, image_paths, audio_paths = parse_media_from_input(user_input)

            print(f"image_paths: {image_paths}")
            print(f"audio_paths: {audio_paths}")

            # If no text prompt after parsing, use the original input
            if not prompt_text.strip():
                prompt_text = user_input
                image_paths = None
                audio_paths = None

            # Add user message to chat history using ChatMessage (following ml.py API)
            chat.append(ChatMessage(role="user", content=prompt_text))

            # Calculate number of images and audios for chat template
            num_images = len(image_paths) if image_paths else 0
            num_audios = len(audio_paths) if audio_paths else 0

            # Apply chat template with image/audio token insertion
            try:
                formatted_prompt = vlm.apply_chat_template_with_media(chat, num_images=num_images, num_audios=num_audios)
            except (NotImplementedError, AttributeError):
                # Fallback to manual formatting if chat template is not implemented
                formatted_prompt = "".join(f"{msg.role}: {msg.content}\n" for msg in chat)
                formatted_prompt += "Assistant: "

            # Generation config with media paths
            generation_config = GenerationConfig(
                max_tokens=512,
                sampler_config=sampler_config,
                image_paths=image_paths,
                audio_paths=audio_paths,
            )

            # Generate response
            print("Assistant: ", end="", flush=True)

            try:
                response = vlm.generate_stream(
                    prompt=formatted_prompt,
                    config=generation_config,
                    on_token=on_token,
                )
            except Exception as e:
                print(f"Error generating response: {e}")
                print()
                # Drop the unanswered user turn so the history keeps alternating roles.
                chat.pop()
                continue

            print()  # New line after streaming

            # Add assistant response to chat history using ChatMessage
            chat.append(ChatMessage(role="assistant", content=response.text))
    finally:
        # Clean up
        vlm.destroy()
|
|
151
|
+
|
|
152
|
+
if __name__ == "__main__":
    # Script entry point: parse the checkpoint path and start the interactive chat.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", type=str, default="mlx-community/gemma-3-4b-it-8bit")
    test_vlm_generate_stream(cli.parse_args().model_path)
|
|
File without changes
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
# Copyright © 2023-2024 Apple Inc.
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
|
|
5
|
+
from .utils import MODEL_CONVERSION_DTYPES, convert
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def configure_parser() -> argparse.ArgumentParser:
    """
    Configures and returns the argument parser for the script.

    ``--hf-path`` is required: ``convert`` needs a source model, and without
    this flag it would receive ``hf_path=None`` and fail with an obscure error
    instead of a clear usage message.

    Returns:
        argparse.ArgumentParser: Configured argument parser.
    """
    parser = argparse.ArgumentParser(
        description="Convert Hugging Face model to MLX format"
    )

    parser.add_argument(
        "--hf-path", type=str, required=True, help="Path to the Hugging Face model."
    )
    parser.add_argument(
        "--mlx-path", type=str, default="mlx_model", help="Path to save the MLX model."
    )
    parser.add_argument(
        "-q", "--quantize", help="Generate a quantized model.", action="store_true"
    )
    parser.add_argument(
        "--q-group-size", help="Group size for quantization.", type=int, default=64
    )
    parser.add_argument(
        "--q-bits", help="Bits per weight for quantization.", type=int, default=4
    )
    parser.add_argument(
        "--dtype",
        help="Type to save the parameter. Defaults to config.json's `torch_dtype` or the current model weights dtype",
        type=str,
        choices=MODEL_CONVERSION_DTYPES,
        default=None,
    )
    parser.add_argument(
        "--upload-repo",
        help="The Hugging Face repo to upload the model to.",
        type=str,
        default=None,
    )
    parser.add_argument(
        "-d",
        "--dequantize",
        help="Dequantize a quantized model.",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--skip-vision",
        help="Skip vision module quantization.",
        action="store_true",
        default=False,
    )
    return parser
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def main():
    """Parse command-line options and pass them straight to ``convert`` as keywords."""
    cli_args = configure_parser().parse_args()
    convert(**vars(cli_args))


if __name__ == "__main__":
    main()
|
|
File without changes
|