nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nexaai/__init__.py +99 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +93 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +7 -0
- nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
- nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
- nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
- nexaai/binds/cpu_gpu/libggml.dylib +0 -0
- nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
- nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/metal/libnexa_plugin.dylib +0 -0
- nexaai/binds/metal/py-lib/ml.py +888 -0
- nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/metal/py-lib/profiling.py +239 -0
- nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
- nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
- nexaai/binds/nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexaml/libggml.dylib +0 -0
- nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
- nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +106 -0
- nexaai/cv.py +95 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +91 -0
- nexaai/cv_impl/pybind_cv_impl.py +124 -0
- nexaai/diarize.py +80 -0
- nexaai/diarize_impl/__init__.py +1 -0
- nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +238 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +162 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +559 -0
- nexaai/mlx_backend/vlm/main.py +365 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1745 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +389 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
- nexaai-1.0.29.dist-info/METADATA +35 -0
- nexaai-1.0.29.dist-info/RECORD +580 -0
- nexaai-1.0.29.dist-info/WHEEL +5 -0
- nexaai-1.0.29.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,764 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
from typing import Any, List, Optional, Sequence, Tuple, Union
|
|
6
|
+
import mlx.core as mx
|
|
7
|
+
import os
|
|
8
|
+
import time
|
|
9
|
+
|
|
10
|
+
# Import necessary modules from mlx_lm
|
|
11
|
+
from mlx_lm import generate, stream_generate, load
|
|
12
|
+
from mlx_lm.sample_utils import make_sampler, make_logits_processors
|
|
13
|
+
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache
|
|
14
|
+
from mlx_lm.tokenizer_utils import TokenizerWrapper
|
|
15
|
+
from mlx_lm.generate import generate_step
|
|
16
|
+
from mlx_lm.tuner.utils import load_adapters
|
|
17
|
+
import mlx.core as mx
|
|
18
|
+
|
|
19
|
+
# Import configs and callback types from ml.py for API alignment
|
|
20
|
+
from ml import (
|
|
21
|
+
LLM as BaseLLM,
|
|
22
|
+
ModelConfig,
|
|
23
|
+
SamplerConfig,
|
|
24
|
+
GenerationConfig,
|
|
25
|
+
ChatMessage,
|
|
26
|
+
EmbeddingConfig,
|
|
27
|
+
TokenCallback,
|
|
28
|
+
Path,
|
|
29
|
+
Tool
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Import profiling module
|
|
33
|
+
from profiling import ProfilingMixin, ProfilingData, StopReason
|
|
34
|
+
|
|
35
|
+
class LLM(BaseLLM, ProfilingMixin):
    """
    LLM interface for mlx-lm.
    API aligned with ml.py LLM abstract base class.

    The instance owns the loaded model/tokenizer pair, a prompt (KV) cache that
    persists across ``generate_stream`` calls until ``reset()`` is called, a
    registry of LoRA adapter directories, and the default sampler settings.
    """

    def __init__(
        self,
        model_path: Path,
        tokenizer_path: Path,
        config: ModelConfig,
        device: Optional[str] = None,
    ) -> None:
        """
        Initialize the LLM model.

        Args:
            model_path: Model directory, or a file inside it (the parent
                directory is used in that case).
            tokenizer_path: Stored on the instance; the tokenizer itself is
                loaded by mlx-lm from ``model_path``.
            config: Model configuration. Only ``n_ctx`` is read (as the
                maximum KV-cache size).
            device: Device label stored on the instance; defaults to "cpu".
        """
        # Initialize profiling mixin
        ProfilingMixin.__init__(self)

        # MLX loads from a directory, so a file path is replaced by its parent.
        if os.path.isfile(model_path):
            model_path = os.path.dirname(model_path)

        # Call parent constructor
        super().__init__(model_path, tokenizer_path, config, device)

        # Store the basic parameters
        self.model_path = model_path
        self.tokenizer_path = tokenizer_path
        self.config = config
        self.device = device if device is not None else "cpu"

        # Simulate C handle (would be pointer in C, here just store info)
        self.handle = {
            "model_path": model_path,
            "tokenizer_path": tokenizer_path,
            "device": self.device,
        }

        # Load model and tokenizer using mlx-lm
        self.model, self.tokenizer = load(model_path)
        self.sampler_config = SamplerConfig()
        self.default_generation_config = GenerationConfig()
        self.kv_cache = None
        # Builds the KV cache and zeroes the per-conversation tracking fields.
        self._reset_cache()
        self.token_generator = None
        self.loras = {}  # lora_id -> (adapter directory, parsed adapter_config.json)
        self.current_lora_id = -1  # -1 means the base model is active
        self._next_lora_id = 0
        # Track whether KV cache has been used for generation
        self.kv_cache_used = False
        # Track total tokens processed (prompts + responses) for prompt cache functionality
        self.global_n_past = 0

    def destroy(self) -> None:
        """Destroy LLM instance and free associated resources (ml_llm_destroy)."""
        self.model = None
        self.tokenizer = None
        self.kv_cache = None
        self.token_generator = None
        self.sampler_config = SamplerConfig()
        self.default_generation_config = GenerationConfig()
        self.loras.clear()
        self.current_lora_id = -1
        self._next_lora_id = 0
        self.kv_cache_used = False
        self.global_n_past = 0
        self.reset_profiling()

    def reset(self) -> None:
        """Reset LLM internal state (ml_llm_reset)."""
        mx.clear_cache()
        self._reset_cache()
        self.reset_profiling()

    def _reset_cache(self) -> None:
        """Rebuild the KV cache and clear the per-conversation tracking state."""
        if self.model is not None:
            # Use n_ctx if provided and > 0, otherwise let mlx-lm decide.
            # A missing or None n_ctx is treated the same as "not provided".
            n_ctx = getattr(self.config, "n_ctx", None)
            if n_ctx is not None and n_ctx > 0:
                self.kv_cache = make_prompt_cache(self.model, max_kv_size=n_ctx)
            else:
                self.kv_cache = make_prompt_cache(self.model)
        self.token_generator = None  # Reset generator for new conversation
        self.kv_cache_used = False  # Reset cache usage flag
        self.global_n_past = 0  # Reset prompt cache tracking

    # Tokenization methods
    def encode(self, text: str) -> List[int]:
        """Encode UTF-8 text to token IDs (ml_llm_encode)."""
        tokenizer = self.tokenizer
        if not isinstance(tokenizer, TokenizerWrapper):
            tokenizer = TokenizerWrapper(tokenizer)
        return tokenizer.encode(text, add_special_tokens=True)

    def decode(self, token_ids: Sequence[int]) -> str:
        """Decode token IDs to UTF-8 text (ml_llm_decode)."""
        tokenizer = self.tokenizer
        if not isinstance(tokenizer, TokenizerWrapper):
            tokenizer = TokenizerWrapper(tokenizer)
        return tokenizer.decode(list(token_ids))

    # KV-cache methods
    def save_kv_cache(self, path: Path) -> bool:
        """
        Save KV cache to file. Returns True on success, False on error.

        A ".safetensors" suffix is appended when ``path`` does not already end
        with it. Returns False when there is no cache to save.
        """
        try:
            if self.kv_cache is None:
                return False
            # Accept both plain strings and os.PathLike objects.
            path = os.fspath(path)
            if not path.endswith('.safetensors'):
                path = path + '.safetensors'
            save_prompt_cache(path, self.kv_cache)
            return True
        except Exception as e:
            print(f"Error saving KV cache: {e}")
            return False

    def load_kv_cache(self, path: Path) -> bool:
        """
        Load KV cache from file. Returns True on success, False on error.

        A ".safetensors" suffix is appended when ``path`` does not already end
        with it. ``global_n_past`` and ``kv_cache_used`` are left unchanged.
        """
        try:
            # Accept both plain strings and os.PathLike objects.
            path = os.fspath(path)
            if not path.endswith('.safetensors'):
                path = path + '.safetensors'
            self.kv_cache = load_prompt_cache(path)
            return True
        except Exception as e:
            print(f"Error loading KV cache: {e}")
            return False

    # LoRA methods
    #
    # LoRA (Low-Rank Adaptation) support for fine-tuned model variants.
    # This implementation supports dynamic switching between different LoRA adapters
    # by reloading the model with the appropriate adapter weights.
    #
    # Usage:
    # 1. Add LoRA adapter: lora_id = model.add_lora("/path/to/adapter")
    # 2. Activate LoRA: model.set_lora(lora_id)
    # 3. Switch back to base model: model.set_lora(-1)
    # 4. Or combine steps 1-2: lora_id = model.load_and_activate_lora("/path/to/adapter")

    def set_lora(self, lora_id: int) -> None:
        """
        Set active LoRA adapter by ID (ml_llm_set_lora).

        ``-1`` selects the base model. Nothing is reloaded when the requested
        adapter (or the base model) is already active.

        Raises:
            ValueError: If ``lora_id`` is not -1 and is not registered.
            RuntimeError: If reloading the model fails.
        """
        if lora_id == -1:
            if self.current_lora_id != -1:
                self._switch_to_base_model()
            return
        if lora_id not in self.loras:
            raise ValueError(f"LoRA adapter with ID {lora_id} not found")
        if self.current_lora_id != lora_id:
            self._switch_to_lora(lora_id)

    def add_lora(self, lora_path: Path) -> int:
        """
        Add LoRA adapter from file (ml_llm_add_lora). Returns LoRA ID on success, negative on error.

        Registration records the adapter directory and its parsed
        ``adapter_config.json``. The weights are applied only on activation,
        when the model is reloaded with ``adapter_path`` set, so adding an
        adapter never modifies the currently loaded model.

        Returns:
            The adapter ID (an existing ID when the same directory was already
            registered), or a negative code:
            -1  the path is empty or does not exist,
            -2  the directory does not look like a LoRA adapter,
            -99 ``adapter_config.json`` could not be read or parsed.
        """
        if not lora_path or not os.path.exists(lora_path):
            return -1
        if not self._validate_lora_adapter(lora_path):
            return -2

        wanted = os.path.abspath(lora_path)
        for lora_id, (path, _) in self.loras.items():
            if os.path.abspath(path) == wanted:
                return lora_id

        try:
            config_file = os.path.join(lora_path, "adapter_config.json")
            with open(config_file, "r", encoding="utf-8") as fh:
                adapter_config = json.load(fh)
        except (OSError, ValueError):
            return -99

        # Allocate the ID only after the adapter is known to be readable so a
        # failed attempt does not consume an ID.
        lora_id = self._next_lora_id
        self._next_lora_id += 1
        self.loras[lora_id] = (lora_path, adapter_config)
        return lora_id

    def _validate_lora_adapter(self, lora_path: Path) -> bool:
        """
        Validate that a path contains a valid LoRA adapter.

        The path must be a directory containing ``adapter_config.json`` and at
        least one of the known weight files.
        """
        if not os.path.isdir(lora_path):
            return False

        required_files = ["adapter_config.json"]
        weight_files = [
            "adapters.safetensors",
            "adapter_model.safetensors",
            "pytorch_model.bin",  # PyTorch format
            "adapter_model.bin",  # Alternative PyTorch format
        ]

        def present(name: str) -> bool:
            return os.path.exists(os.path.join(lora_path, name))

        # Every required file must exist ...
        if not all(present(name) for name in required_files):
            return False
        # ... together with at least one weight file.
        return any(present(name) for name in weight_files)

    def remove_lora(self, lora_id: int) -> None:
        """
        Remove LoRA adapter by ID (ml_llm_remove_lora).

        Unknown IDs are ignored. Removing the active adapter first switches
        back to the base model.
        """
        if lora_id not in self.loras:
            return
        if self.current_lora_id == lora_id:
            self._switch_to_base_model()
        self.loras.pop(lora_id, None)

    def list_loras(self) -> List[int]:
        """List all loaded LoRA adapter IDs (ml_llm_list_loras)."""
        return list(self.loras.keys())

    def _switch_to_base_model(self) -> None:
        """
        Switch to the base model (no LoRA).

        Reloads model and tokenizer from ``self.model_path`` and resets the KV
        cache.

        Raises:
            RuntimeError: If the reload fails.
        """
        try:
            # Reload the base model
            self.model, self.tokenizer = load(self.model_path)
            self.current_lora_id = -1
            self._reset_cache()  # Reset cache when switching models
        except Exception as e:
            raise RuntimeError(f"Failed to switch to base model: {str(e)}") from e

    def _switch_to_lora(self, lora_id: int) -> None:
        """
        Switch to a specific LoRA adapter.

        Reloads the model with ``adapter_path`` set to the registered adapter
        directory and resets the KV cache.

        Raises:
            ValueError: If ``lora_id`` is not registered.
            RuntimeError: If the reload fails.
        """
        if lora_id not in self.loras:
            raise ValueError(f"LoRA adapter with ID {lora_id} not found")

        lora_path, _ = self.loras[lora_id]
        try:
            # Load model with LoRA adapter
            self.model, self.tokenizer = load(self.model_path, adapter_path=lora_path)
            self.current_lora_id = lora_id
            self._reset_cache()  # Reset cache when switching models
        except Exception as e:
            raise RuntimeError(
                f"Failed to switch to LoRA adapter {lora_id} (path: {lora_path}): {str(e)}"
            ) from e

    def get_current_lora_id(self) -> int:
        """Get the currently active LoRA adapter ID."""
        return self.current_lora_id

    def get_lora_info(self, lora_id: int) -> dict:
        """
        Get information about a specific LoRA adapter.

        Returns:
            A dict with keys ``id``, ``path``, ``is_active`` and ``config``
            (the parsed ``adapter_config.json`` captured at registration).

        Raises:
            ValueError: If ``lora_id`` is not registered.
        """
        if lora_id not in self.loras:
            raise ValueError(f"LoRA adapter with ID {lora_id} not found")

        lora_path, adapter_config = self.loras[lora_id]
        return {
            "id": lora_id,
            "path": lora_path,
            "is_active": lora_id == self.current_lora_id,
            "config": adapter_config,
        }

    def load_and_activate_lora(self, lora_path: Path) -> int:
        """
        Load a LoRA adapter and immediately activate it.

        Returns:
            The adapter ID, or the negative error code from ``add_lora``. On
            failure the currently active model is left untouched.
        """
        lora_id = self.add_lora(lora_path)
        if lora_id < 0:
            # Passing -1 on to set_lora() would silently deactivate whatever
            # adapter is currently active, so stop here instead.
            return lora_id
        self.set_lora(lora_id)
        return lora_id

    # Sampler methods
    def set_sampler(self, config: SamplerConfig) -> None:
        """Configure text generation sampling parameters (ml_llm_set_sampler)."""
        self.sampler_config = config

    def reset_sampler(self) -> None:
        """Reset sampling parameters to defaults (ml_llm_reset_sampler)."""
        self.sampler_config = SamplerConfig()

    # Generation config methods
    def set_generation_config(self, config: GenerationConfig) -> None:
        """Set default generation configuration for token-level generation."""
        self.default_generation_config = config

    def _make_mlx_sampler_from_config(self, sampler_config: SamplerConfig):
        """
        Create mlx-lm sampler from specific config.

        A seed other than -1 reseeds the global MLX random state before the
        sampler is built.
        """
        if sampler_config.seed != -1:
            mx.random.seed(sampler_config.seed)

        return make_sampler(
            temp=sampler_config.temperature,
            top_p=sampler_config.top_p,
            top_k=sampler_config.top_k,
        )

    def _make_logits_processors_from_config(self, sampler_config: SamplerConfig):
        """
        Create logits processors from specific config.

        Only the repetition penalty is forwarded; ``None`` is returned when it
        is exactly 1.0 (no penalty).
        """
        if sampler_config.repetition_penalty != 1.0:
            return make_logits_processors(
                repetition_penalty=sampler_config.repetition_penalty,
            )
        return None

    def _make_mlx_sampler(self):
        """Create mlx-lm sampler from class config."""
        return self._make_mlx_sampler_from_config(self.sampler_config)

    def _make_logits_processors(self):
        """Create logits processors from class config."""
        return self._make_logits_processors_from_config(self.sampler_config)

    def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig],
        on_token: TokenCallback,
        user_data: Any = None,
    ) -> str:
        """
        Generate text with streaming callback and profiling.

        The prompt should be the incremental part after applying chat template.
        apply_chat_template returns only the incremental prompt based on global_n_past:
        - First round (global_n_past = 0): Last user message + last system message (if exists)
        - Subsequent rounds (global_n_past > 0): Only last user message

        Prompt Cache Behavior:
        - Tracks global_n_past to know how many tokens (prompts + responses) have been processed
        - Passes incremental token arrays directly to stream_generate as prompt cache already contains the past history
        - KV cache retains the conversation context until reset() is called

        Args:
            prompt: Incremental prompt text (already chat-templated).
            config: Generation settings; a fresh ``GenerationConfig()`` is used
                when None. Its ``sampler_config``, when set, overrides the
                instance sampler settings for this call.
            on_token: Called as ``on_token(text, user_data)`` for every streamed
                piece; a falsy return value stops generation.
            user_data: Opaque value forwarded to ``on_token``.

        Returns:
            The concatenation of all pieces accepted by ``on_token``. If an
            exception occurs during generation it is not raised; a string
            starting with "Streaming generation error:" (including the
            traceback) is returned instead.
        """
        # Start profiling
        self._start_profiling()

        if config is None:
            config = GenerationConfig()

        # Use sampler config from GenerationConfig if provided, otherwise use class config
        effective_sampler_config = config.sampler_config if config.sampler_config else self.sampler_config
        sampler = self._make_mlx_sampler_from_config(effective_sampler_config)
        logits_processors = self._make_logits_processors_from_config(effective_sampler_config)

        # The prompt is already incremental, so every encoded token is new
        # input on top of whatever the KV cache holds.
        incremental_tokens = self.encode(prompt)
        incremental_length = len(incremental_tokens)

        # Record prompt tokens for profiling (use incremental length for this call)
        self._update_prompt_tokens(incremental_length)

        generated_tokens = 0
        pieces = []
        last_response = None
        stopped_by_user = False

        try:
            # End prompt processing, start decode
            self._prompt_end()
            self._decode_start()

            for response in stream_generate(
                model=self.model,
                tokenizer=self.tokenizer,
                prompt=incremental_tokens,
                max_tokens=config.max_tokens,
                sampler=sampler,
                logits_processors=logits_processors,
                prompt_cache=self.kv_cache,
            ):
                # Record TTFT on first token
                if generated_tokens == 0:
                    self._record_ttft()

                generated_tokens += 1
                # Remember the response before invoking the callback so the
                # bookkeeping below still works when the very first token is
                # rejected.
                last_response = response
                token_text = response.text

                # Call the token callback - if it returns False, stop generation
                if not on_token(token_text, user_data):
                    stopped_by_user = True
                    break
                pieces.append(token_text)

            # Set stop reason based on how generation ended. A user stop takes
            # precedence, even when it lands on the max_tokens-th token.
            if stopped_by_user:
                self._set_stop_reason(StopReason.ML_STOP_REASON_USER)
            elif generated_tokens >= config.max_tokens:
                self._set_stop_reason(StopReason.ML_STOP_REASON_LENGTH)
            elif last_response is None:
                # Fallback: generation loop ended naturally, likely due to EOS
                self._set_stop_reason(StopReason.ML_STOP_REASON_EOS)
            elif getattr(last_response, 'finish_reason', None) == "stop":
                self._set_stop_reason(StopReason.ML_STOP_REASON_EOS)
            else:
                self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)

            # Update global_n_past to reflect the new tokens processed (incremental prompt + response)
            # Use the response metadata to get accurate token counts
            response_tokens = last_response.generation_tokens if last_response is not None else 0
            self.global_n_past += incremental_length + response_tokens

            # Mark cache as used after successful generation
            self.kv_cache_used = True

            # Update generated tokens and end profiling
            self._update_generated_tokens(generated_tokens)
            self._decode_end()
            self._end_profiling()

            return "".join(pieces)
        except Exception as e:
            import traceback
            self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
            self._decode_end()
            self._end_profiling()
            return f"Streaming generation error: {str(e)}\n{traceback.format_exc()}"

    # Chat template methods
    def get_chat_template(self, template_name: str) -> str:
        """
        Get chat template by name.

        ``template_name`` is ignored: the tokenizer's single chat template is
        always returned.
        """
        return self.tokenizer.chat_template

    def apply_chat_template(self, messages: Sequence[ChatMessage], tools: Optional[str] = None, enable_thinking: bool = True, add_generation_prompt: bool = True) -> str:
        """
        Apply chat template to messages with incremental prompt support and optional tools.

        This method returns only the incremental prompt based on global_n_past:
        - When global_n_past = 0 (first conversation): Last user message + last system message (if exists)
        - When global_n_past > 0 (subsequent rounds): Only last user message

        Args:
            messages: Conversation messages; only the last "user" message and
                (on the first round) the last "system" message are used.
            tools: Optional JSON string describing tools; it is parsed and
                passed to the tokenizer's chat template.
            enable_thinking: Forwarded to the tokenizer's chat template.
            add_generation_prompt: Forwarded to the tokenizer's chat template.

        Raises:
            RuntimeError: On any failure, including a missing user message or
                invalid ``tools`` JSON; the message embeds the traceback.
        """
        # TODO: this is temporary solution to account for the no-thinking requirement of GPT-OSS. In the long term we need to revisit the API design of apply_chat_template.
        try:
            # global_n_past <= 0 means nothing has been processed yet, i.e. first round
            is_first_round = self.global_n_past <= 0

            # Find last user message and last system message
            last_user_msg = None
            last_system_msg = None
            for msg in messages:
                if msg.role == "user":
                    last_user_msg = msg
                elif msg.role == "system":
                    last_system_msg = msg

            if not last_user_msg:
                round_name = "first" if is_first_round else "subsequent"
                raise ValueError(f"No user message found for {round_name} conversation round")

            # Build incremental message list based on conversation round; the
            # system message is only sent on the first round.
            incremental_messages = []
            if is_first_round and last_system_msg:
                incremental_messages.append({
                    "role": last_system_msg.role,
                    "content": last_system_msg.content,
                })
            incremental_messages.append({
                "role": last_user_msg.role,
                "content": last_user_msg.content,
            })

            parsed_tools = json.loads(tools) if tools is not None else None

            return self.tokenizer.apply_chat_template(
                incremental_messages,
                tokenize=False,
                enable_thinking=enable_thinking,
                add_generation_prompt=add_generation_prompt,
                tools=parsed_tools,
            )
        except Exception as e:
            import traceback
            raise RuntimeError(f"Error applying chat template: {str(e)}\n{traceback.format_exc()}") from e

    # Embeddings - using the model's embedding layer directly
    def embed(
        self,
        texts: Sequence[str],
        config: Optional[EmbeddingConfig] = None,
    ) -> List[List[float]]:
        """
        Generate embeddings for texts with profiling.

        Each text is tokenized, looked up in the model's token-embedding layer
        (``self.model.model.embed_tokens``) and mean-pooled over the token
        axis, giving one vector per text. ``config`` is currently not used.

        Raises:
            RuntimeError: If the embedding lookup or pooling fails.
        """
        # Start profiling
        self._start_profiling()

        # Tokenize every text once; the token lists are reused both for the
        # prompt-token count and for the embedding lookup.
        token_lists = [self.encode(text) for text in texts]
        self._update_prompt_tokens(sum(len(tokens) for tokens in token_lists))

        # End prompt processing, start decode
        self._prompt_end()
        self._decode_start()

        try:
            embeddings = []
            for tokens in token_lists:
                # Get embeddings directly from the model's embedding layer
                embedding_tensor = self.model.model.embed_tokens(mx.array(tokens))

                # Average pool across sequence dimension to get a single embedding per text
                # Shape: [seq_len, hidden_size] -> [hidden_size]
                pooled_embedding = mx.mean(embedding_tensor, axis=0)

                # Convert to Python list of floats
                embeddings.append(pooled_embedding.tolist())

            # End timing and finalize profiling data
            self._update_generated_tokens(0)  # No generation in embedding
            self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
            self._decode_end()
            self._end_profiling()

            return embeddings

        except Exception as e:
            self._set_stop_reason(StopReason.ML_STOP_REASON_UNKNOWN)
            self._decode_end()
            self._end_profiling()
            raise RuntimeError(f"Error generating embeddings: {str(e)}") from e
|
|
588
|
+
|
|
589
|
+
# =============================================================================
|
|
590
|
+
# Test functions
|
|
591
|
+
# =============================================================================
|
|
592
|
+
# Add test functions at the bottom before the main conversation test
|
|
593
|
+
def test_kv_cache_save_load():
    """Test KV cache save and load functionality.

    Loads the model, streams a short completion so the cache has content,
    saves the cache, resets the model, then repeats generate/save and finally
    reloads the saved cache.  Cache files are written into a temporary
    directory that is removed when the test finishes, so repeated runs do not
    leave ``.safetensors`` artifacts in the current working directory.

    Raises:
        AssertionError: If a save or load call does not report success.
    """
    # Function-scope imports keep this block self-contained.
    import os
    import tempfile

    print("Testing KV cache save and load...")

    # Initialize model
    model_path = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    config = ModelConfig()
    config.n_ctx = 512

    llm = LLM(model_path, model_path, config)

    def stream_callback(token, user_data):
        print(token, end="", flush=True)
        return True

    # Test prompt
    test_prompt = "🥳 🎂 Once upon a time"

    gen_config = GenerationConfig()
    gen_config.max_tokens = 20  # Generate enough tokens to populate cache

    with tempfile.TemporaryDirectory() as cache_dir:
        # Test save
        print("Testing KV cache save...")
        print("Generating text to populate cache:")
        response = llm.generate_stream(test_prompt, gen_config, stream_callback)
        print(f"\nGenerated: {response}")

        cache_path = os.path.join(cache_dir, "test_kvcache_save.safetensors")
        save_result = llm.save_kv_cache(cache_path)
        print(f"Save result: {save_result}")
        assert save_result, "KV cache save should succeed"

        # Reset cache
        llm.reset()

        # Test load
        print("Testing KV cache load...")
        cache_path = os.path.join(cache_dir, "test_kvcache_load.safetensors")

        # First generate and save so there is a file to load from
        llm.generate_stream(test_prompt, gen_config, stream_callback)
        save_result = llm.save_kv_cache(cache_path)
        assert save_result, "KV cache save should succeed"

        # Reset and load
        llm.reset()
        load_result = llm.load_kv_cache(cache_path)
        print(f"Load result: {load_result}")
        assert load_result, "KV cache load should succeed"

    print("KV cache save/load tests passed!")
|
|
644
|
+
|
|
645
|
+
def test_tokenization():
    """Encode a sample string and decode the ids back.

    Only checks that both directions produce a non-empty result.
    """
    print("Testing tokenization...")

    repo_id = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    model_config = ModelConfig()
    llm = LLM(repo_id, repo_id, model_config)

    sample = "🥳 🎂 Once upon a time"

    # Text -> token ids
    ids = llm.encode(sample)
    print(f"Encoded '{sample}' to {len(ids)} tokens")
    assert len(ids) > 0, "Encoding should produce tokens"

    # Token ids -> text
    round_trip = llm.decode(ids)
    print(f"Decoded back to: '{round_trip}'")
    assert len(round_trip) > 0, "Decoding should produce text"

    print("Tokenization tests passed!")
|
|
667
|
+
|
|
668
|
+
def test_generation():
    """Stream a short completion and check that some text comes back."""
    print("Testing generation...")

    repo_id = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    model_config = ModelConfig()
    llm = LLM(repo_id, repo_id, model_config)

    def echo_token(token, user_data):
        # Print each streamed token immediately, without a trailing newline.
        print(token, end="", flush=True)
        return True

    prompt = "🥳 🎂 Once upon a time"
    generation_config = GenerationConfig()
    generation_config.max_tokens = 10

    print("Generating text:")
    output = llm.generate_stream(prompt, generation_config, echo_token)
    print(f"\nGenerated response length: {len(output)}")
    assert len(output) > 0, "Generation should produce text"

    print("Generation test passed!")
|
|
691
|
+
|
|
692
|
+
def run_tests():
    """Run all test cases in order and report the outcome.

    Any exception raised by a test (including a failed assertion) is caught
    and printed together with its traceback; the remaining tests are not run.

    Returns:
        bool: True if every test completed, False if one of them raised.
    """
    import traceback

    try:
        test_tokenization()
        print()
        test_generation()
        print()
        test_kv_cache_save_load()
        print()
        print("All tests passed! ✅")
    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        # Let callers distinguish failure instead of silently returning None.
        return False
    return True
|
|
706
|
+
|
|
707
|
+
# For testing
|
|
708
|
+
if __name__ == "__main__":
    import sys

    # Check if running tests
    if len(sys.argv) > 1 and sys.argv[1] == "test":
        # Exit non-zero only when run_tests() explicitly reports failure by
        # returning False; any other return value is treated as success.
        sys.exit(1 if run_tests() is False else 0)

    def on_token(token_text, user_data):
        """Token callback that prints each token as it's generated"""
        print(token_text, end="", flush=True)
        return True  # Continue generation

    # Multi-round conversation test case
    model_path = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    tokenizer_path = "mlx-community/Qwen3-1.7B-4bit-DWQ"
    config = ModelConfig()

    llm = LLM(model_path, tokenizer_path, config)

    # Run tests
    print("================================================")
    print("Running tests")
    run_tests()
    print("================================================")

    # Multi-round conversation test case
    chat = []
    print("Multi-round conversation test. Type 'exit' to quit.")

    while True:
        try:
            user_input = input("User: ").strip()

            # Exit conditions
            if user_input.lower() in ['exit', 'quit', '']:
                break

            # Add user message to chat history
            chat.append(ChatMessage(role="user", content=user_input))

            # Apply chat template to get full conversation history as formatted prompt
            formatted_prompt = llm.apply_chat_template(chat)
            # Generate response using streaming with on_token callback
            print("Assistant: ", end="", flush=True)  # Following generate.py pattern
            response = llm.generate_stream(formatted_prompt, None, on_token)

            # Add assistant response to chat history for next round
            chat.append(ChatMessage(role="assistant", content=response))
            print()  # New line after response

        except KeyboardInterrupt:
            print("\nConversation interrupted by user.")
            break
        except EOFError:
            # stdin was closed (e.g. piped input ran out).  input() would raise
            # again on every iteration, so falling through to the generic
            # handler below would loop forever.
            print()
            break
        except Exception as e:
            print(f"Error: {e}")
            continue
|