nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nexaai/__init__.py +99 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +93 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +7 -0
- nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
- nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
- nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
- nexaai/binds/cpu_gpu/libggml.dylib +0 -0
- nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
- nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/metal/libnexa_plugin.dylib +0 -0
- nexaai/binds/metal/py-lib/ml.py +888 -0
- nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/metal/py-lib/profiling.py +239 -0
- nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
- nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
- nexaai/binds/nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexaml/libggml.dylib +0 -0
- nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
- nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +106 -0
- nexaai/cv.py +95 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +91 -0
- nexaai/cv_impl/pybind_cv_impl.py +124 -0
- nexaai/diarize.py +80 -0
- nexaai/diarize_impl/__init__.py +1 -0
- nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +238 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +162 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +559 -0
- nexaai/mlx_backend/vlm/main.py +365 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1745 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +389 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
- nexaai-1.0.29.dist-info/METADATA +35 -0
- nexaai-1.0.29.dist-info/RECORD +580 -0
- nexaai-1.0.29.dist-info/WHEEL +5 -0
- nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,476 @@
```python
from typing import Any, Dict, List, Optional, Union
import mlx.core as mx
import numpy as np
from PIL import Image
import io
import base64


class Qwen3VLProcessor:
    def __init__(self, tokenizer=None, image_processor=None):
        self.tokenizer = tokenizer
        self.image_processor = image_processor

        # Vision tokens (following the official implementation)
        self.image_token = "<|image_pad|>"
        self.vision_start_token = "<|vision_start|>"
        self.vision_end_token = "<|vision_end|>"

        # Token IDs (will be set properly if tokenizer is provided)
        if tokenizer:
            self.image_token_id = getattr(tokenizer, 'image_token_id',
                                          tokenizer.convert_tokens_to_ids(self.image_token))
            self.vision_start_token_id = getattr(tokenizer, 'vision_start_token_id',
                                                 tokenizer.convert_tokens_to_ids(self.vision_start_token))
            self.vision_end_token_id = getattr(tokenizer, 'vision_end_token_id',
                                               tokenizer.convert_tokens_to_ids(self.vision_end_token))
        else:
            # Fallback IDs for when no tokenizer is provided
            self.image_token_id = 151655
            self.vision_start_token_id = 151652
            self.vision_end_token_id = 151653

        # Image processing parameters (following Qwen3VL defaults)
        self.min_pixels = 4096
        self.max_pixels = 16777216
        self.patch_size = 16
        self.merge_size = 2
        self.temporal_patch_size = 2

        # Add the missing image_mean and image_std
        self.image_mean = [0.5, 0.5, 0.5]
        self.image_std = [0.5, 0.5, 0.5]

    def _extract_patches(self, image_array: np.ndarray) -> np.ndarray:
        """
        Extract patches from image array to create proper tensor for Conv3d.

        Args:
            image_array: Shape (C, H, W)

        Returns:
            patches: Flattened tensor that can be reshaped to
                (num_patches, C, temporal_patch_size, patch_size, patch_size)
        """
        C, H, W = image_array.shape

        # Calculate number of patches
        patch_h = H // self.patch_size
        patch_w = W // self.patch_size

        # Extract spatial patches
        # Reshape to (C, patch_h, patch_size, patch_w, patch_size)
        patches = image_array.reshape(
            C, patch_h, self.patch_size, patch_w, self.patch_size
        )

        # Rearrange to (patch_h, patch_w, C, patch_size, patch_size)
        patches = patches.transpose(1, 3, 0, 2, 4)

        # Reshape to (patch_h * patch_w, C, patch_size, patch_size)
        num_patches = patch_h * patch_w
        patches = patches.reshape(num_patches, C, self.patch_size, self.patch_size)

        # Add temporal dimension by duplicating the patches
        # Shape: (num_patches, C, temporal_patch_size, patch_size, patch_size)
        patches = np.tile(patches[:, :, None, :, :], (1, 1, self.temporal_patch_size, 1, 1))

        return patches

    def _process_single_image(self, image: Union[str, Image.Image, np.ndarray]) -> Dict[str, Any]:
        """Process a single image and return processed data."""
        if isinstance(image, str):
            if image.startswith('data:image'):
                image_data = base64.b64decode(image.split(',')[1])
                image = Image.open(io.BytesIO(image_data))
            else:
                image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Resize image based on pixel constraints
        width, height = image.size
        pixels = width * height

        if pixels < self.min_pixels:
            scale = (self.min_pixels / pixels) ** 0.5
            width = int(width * scale)
            height = int(height * scale)
        elif pixels > self.max_pixels:
            scale = (self.max_pixels / pixels) ** 0.5
            width = int(width * scale)
            height = int(height * scale)

        # Ensure dimensions are multiples of patch_size AND work with merge_size
        # Use fraction-based rounding to match PyTorch behavior
        import math

        width_frac = (width / self.patch_size) % 1
        height_frac = (height / self.patch_size) % 1

        # Round up if fraction >= 0.3, otherwise round down
        # This matches the observed PyTorch processor behavior
        if width_frac >= 0.3:
            width = math.ceil(width / self.patch_size) * self.patch_size
        else:
            width = (width // self.patch_size) * self.patch_size

        if height_frac >= 0.3:
            height = math.ceil(height / self.patch_size) * self.patch_size
        else:
            height = (height // self.patch_size) * self.patch_size

        # CRITICAL: Ensure patch dimensions are even for 2x2 merging
        # If either dimension is odd, add one more patch to make it even
        h_patches = height // self.patch_size
        w_patches = width // self.patch_size

        if h_patches % 2 == 1:
            height += self.patch_size  # Add one more patch row

        if w_patches % 2 == 1:
            width += self.patch_size  # Add one more patch column

        if width == 0 or height == 0:
            width = height = self.patch_size

        image = image.resize((width, height), Image.Resampling.LANCZOS)

        # Convert to array and normalize
        image_array = np.array(image).astype(np.float32) / 255.0

        # Qwen3VL normalization
        mean = np.array(self.image_mean)
        std = np.array(self.image_std)
        image_array = (image_array - mean) / std

        # Convert HWC to CHW
        image_array = np.transpose(image_array, (2, 0, 1))

        # Calculate grid dimensions
        h_patches = height // self.patch_size
        w_patches = width // self.patch_size

        # Extract patches using the exact same method as PyTorch Conv3d unfold
        C, H, W = image_array.shape

        # Reshape to extract patches: (C, H//patch_size, patch_size, W//patch_size, patch_size)
        patches = image_array.reshape(C, h_patches, self.patch_size, w_patches, self.patch_size)

        # Rearrange to group patches: (h_patches, w_patches, C, patch_size, patch_size)
        patches = patches.transpose(1, 3, 0, 2, 4)

        # Flatten spatial patches: (h_patches * w_patches, C, patch_size, patch_size)
        patches = patches.reshape(-1, C, self.patch_size, self.patch_size)

        # Add temporal dimension: (num_patches, C, T, patch_size, patch_size)
        patches_with_temporal = np.tile(patches[:, :, None, :, :], (1, 1, self.temporal_patch_size, 1, 1))

        # Flatten each patch in the order: C, T, H, W to match PyTorch Conv3d
        pixel_values = patches_with_temporal.reshape(patches_with_temporal.shape[0], -1)

        # Apply spatial merging reordering to match PyTorch processor
        # Group patches into merge_size x merge_size blocks and reorder
        pixel_values = pixel_values.reshape(h_patches // self.merge_size, self.merge_size,
                                            w_patches // self.merge_size, self.merge_size, -1)
        # Rearrange to (h_blocks, w_blocks, merge_size*merge_size, feature_dim)
        pixel_values = pixel_values.transpose(0, 2, 1, 3, 4)
        pixel_values = pixel_values.reshape(h_patches // self.merge_size,
                                            w_patches // self.merge_size,
                                            self.merge_size * self.merge_size, -1)
        # Flatten to (total_merged_patches, feature_dim)
        pixel_values = pixel_values.reshape(-1, pixel_values.shape[-1])

        return {
            'pixel_values': pixel_values,  # Shape: (num_patches, 1536)
            'grid_thw': [1, h_patches, w_patches]  # T=1 for images
        }

    def _insert_image_tokens(self, text: str, image_grid_thw: List[List[int]]) -> str:
        """Insert the correct number of image tokens based on grid dimensions."""
        if not image_grid_thw:
            return text

        merge_length = self.merge_size ** 2
        index = 0

        while self.image_token in text and index < len(image_grid_thw):
            # Calculate number of tokens needed for this image
            t, h, w = image_grid_thw[index]
            num_image_tokens = (t * h * w) // merge_length

            # Replace one image token with the calculated number of tokens
            text = text.replace(self.image_token, self.image_token * num_image_tokens, 1)
            index += 1

        return text

    def __call__(
        self,
        text: Union[str, List[str]] = None,
        images: Union[Image.Image, List[Image.Image], str, List[str], np.ndarray, List[np.ndarray]] = None,
        return_tensors: str = "mlx",
        **kwargs
    ) -> Dict[str, mx.array]:
        """
        Process text and images for Qwen3VL model.

        Returns:
            Dict containing:
            - input_ids: Tokenized text with proper image tokens
            - pixel_values: Processed image patches (if images provided)
            - image_grid_thw: Grid dimensions for images (if images provided)
        """
        result = {}

        # Process images first
        grid_thw_list = None
        if images is not None:
            if not isinstance(images, list):
                images = [images]

            # Check if images list is not empty
            if len(images) > 0:
                if self.image_processor is not None:
                    image_inputs = self.image_processor(images=images, return_tensors="np")
                    result["pixel_values"] = mx.array(image_inputs["pixel_values"])
                    result["image_grid_thw"] = mx.array(image_inputs["image_grid_thw"])
                    grid_thw_list = image_inputs["image_grid_thw"].tolist()
                else:
                    processed_patches = []
                    grid_thw_list = []
                    for image in images:
                        processed = self._process_single_image(image)
                        processed_patches.append(processed["pixel_values"])
                        grid_thw_list.append(processed["grid_thw"])
                    all_patches = np.concatenate(processed_patches, axis=0)
                    result["pixel_values"] = mx.array(all_patches)
                    result["image_grid_thw"] = mx.array(np.array(grid_thw_list))

        # Process text
        if text is not None:
            if not isinstance(text, list):
                text = [text]
            text = text.copy()
            if grid_thw_list is not None:
                for i in range(len(text)):
                    text[i] = self._insert_image_tokens(text[i], grid_thw_list)
            if self.tokenizer:
                text_inputs = self.tokenizer(text, return_tensors="np", **kwargs)
                result["input_ids"] = mx.array(text_inputs["input_ids"])
                if "attention_mask" in text_inputs:
                    result["attention_mask"] = mx.array(text_inputs["attention_mask"])
            else:
                all_tokens = []
                for t in text:
                    tokens = [hash(word) % 50000 for word in t.split()]
                    all_tokens.append(tokens)
                max_len = max(len(tokens) for tokens in all_tokens)
                padded_tokens = []
                for tokens in all_tokens:
                    padded = tokens + [0] * (max_len - len(tokens))
                    padded_tokens.append(padded)
                result["input_ids"] = mx.array(np.array(padded_tokens))

        return result

    def _extract_images_and_text_from_messages(self, messages: List[Dict]) -> tuple:
        """Extract images and text from message format."""
        images = []
        text_parts = []

        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", [])

            if isinstance(content, str):
                # Simple text content
                text_parts.append({"role": role, "content": content})
            elif isinstance(content, list):
                # Multi-modal content
                message_text_parts = []
                for item in content:
                    if item.get("type") == "image":
                        images.append(item.get("image"))
                        message_text_parts.append("<|vision_start|><|image_pad|><|vision_end|>")
                    elif item.get("type") == "text":
                        message_text_parts.append(item.get("text", ""))

                combined_text = "".join(message_text_parts)
                text_parts.append({"role": role, "content": combined_text})

        return images, text_parts

    def apply_chat_template(
        self,
        messages: List[Dict],
        add_generation_prompt: bool = True,
        tokenize: bool = False,
        **kwargs
    ) -> str:
        """Apply chat template to messages."""
        # Handle multi-modal messages
        if any(isinstance(msg.get("content"), list) for msg in messages):
            _, text_messages = self._extract_images_and_text_from_messages(messages)
            messages = text_messages

        if not self.tokenizer:
            # Fallback chat template
            formatted_messages = []
            for msg in messages:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                formatted_messages.append(f"<|im_start|>{role}\n{content}<|im_end|>")

            result = "\n".join(formatted_messages)
            if add_generation_prompt:
                result += "\n<|im_start|>assistant\n"
            return result

        # Use tokenizer and manually remove system message to match ground truth
        result = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=add_generation_prompt,
            tokenize=tokenize,
            **kwargs
        )

        # Remove system message to match ground truth format
        system_prefix = '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n'
        if result.startswith(system_prefix):
            result = result[len(system_prefix):]

        return result

    def messages_to_text(
        self,
        messages: List[Dict],
        add_generation_prompt: bool = True,
        **kwargs
    ) -> tuple:
        """
        Step 1: Convert multi-modal messages to text format.

        Args:
            messages: List of message dicts with role and content
            add_generation_prompt: Whether to add generation prompt
            **kwargs: Additional arguments

        Returns:
            Tuple of (text, images) where text is the formatted string and images is list of image objects
        """
        # Extract images and text from messages
        images, text_messages = self._extract_images_and_text_from_messages(messages)

        # Apply chat template
        text = self.apply_chat_template(
            text_messages,
            add_generation_prompt=add_generation_prompt,
            tokenize=False,
            **kwargs
        )

        # Load images from URLs if needed
        processed_images = []
        for img in images:
            if isinstance(img, str) and (img.startswith('http://') or img.startswith('https://')):
                # Load image from URL
                import requests
                from io import BytesIO
                try:
                    response = requests.get(img, stream=True, timeout=10)
                    img = Image.open(BytesIO(response.content))
                except Exception as e:
                    raise ValueError(f"Failed to load image from URL {img}: {e}")
            processed_images.append(img)

        return text, processed_images

    def text_to_input_ids(
        self,
        text: str,
        images: List = None,
        return_tensors: str = "mlx",
        **kwargs
    ) -> Dict[str, Any]:
        """
        Step 2: Process text and images into input_ids and pixel_values.

        Args:
            text: Formatted text string (from messages_to_text)
            images: List of image objects
            return_tensors: Format of returned tensors
            **kwargs: Additional arguments

        Returns:
            Dict with input_ids, pixel_values, image_grid_thw
        """
        return self(
            text=[text],
            images=images,
            return_tensors=return_tensors,
            **kwargs
        )

    def process_messages(
        self,
        messages: List[Dict],
        add_generation_prompt: bool = True,
        return_tensors: str = "mlx",
        **kwargs
    ) -> Dict[str, Any]:
        """
        Process multi-modal messages end-to-end (combines messages_to_text + text_to_input_ids).

        Args:
            messages: List of message dicts with role and content
            add_generation_prompt: Whether to add generation prompt
            return_tensors: Format of returned tensors
            **kwargs: Additional arguments

        Returns:
            Dict with input_ids, pixel_values, image_grid_thw
        """
        # Step 1: Convert messages to text
        text, processed_images = self.messages_to_text(
            messages,
            add_generation_prompt=add_generation_prompt,
            **kwargs
        )

        # Step 2: Convert text to input_ids
        return self.text_to_input_ids(
            text,
            images=processed_images,
            return_tensors=return_tensors,
            **kwargs
        )

    def post_process_image_text_to_text(
        self,
        generated_outputs,
        skip_special_tokens: bool = True,
        **kwargs
    ) -> List[str]:
        """Decode generated token IDs back to text."""
        if self.tokenizer:
            if hasattr(generated_outputs, 'tolist'):
                generated_outputs = generated_outputs.tolist()

            return self.tokenizer.batch_decode(
                generated_outputs,
                skip_special_tokens=skip_special_tokens,
                **kwargs
            )
        else:
            # Fallback decoding
            return ["[Decoded text - tokenizer not available]"] * len(generated_outputs)


# Convenience function
def create_qwen3vl_processor(tokenizer=None, image_processor=None):
    """Create a Qwen3VL processor instance."""
    return Qwen3VLProcessor(tokenizer=tokenizer, image_processor=image_processor)
```