nexaai-1.0.29-cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nexaai/__init__.py +99 -0
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +4 -0
- nexaai/asr.py +68 -0
- nexaai/asr_impl/__init__.py +0 -0
- nexaai/asr_impl/mlx_asr_impl.py +93 -0
- nexaai/asr_impl/pybind_asr_impl.py +127 -0
- nexaai/base.py +39 -0
- nexaai/binds/__init__.py +7 -0
- nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
- nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
- nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
- nexaai/binds/cpu_gpu/libggml.dylib +0 -0
- nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
- nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
- nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/metal/libnexa_plugin.dylib +0 -0
- nexaai/binds/metal/py-lib/ml.py +888 -0
- nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
- nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
- nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
- nexaai/binds/metal/py-lib/profiling.py +239 -0
- nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
- nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
- nexaai/binds/nexaml/libggml-base.dylib +0 -0
- nexaai/binds/nexaml/libggml-cpu.so +0 -0
- nexaai/binds/nexaml/libggml-metal.so +0 -0
- nexaai/binds/nexaml/libggml.dylib +0 -0
- nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
- nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
- nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
- nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexaml/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
- nexaai/common.py +106 -0
- nexaai/cv.py +95 -0
- nexaai/cv_impl/__init__.py +0 -0
- nexaai/cv_impl/mlx_cv_impl.py +91 -0
- nexaai/cv_impl/pybind_cv_impl.py +124 -0
- nexaai/diarize.py +80 -0
- nexaai/diarize_impl/__init__.py +1 -0
- nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
- nexaai/embedder.py +73 -0
- nexaai/embedder_impl/__init__.py +0 -0
- nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
- nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
- nexaai/image_gen.py +141 -0
- nexaai/image_gen_impl/__init__.py +0 -0
- nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
- nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
- nexaai/llm.py +98 -0
- nexaai/llm_impl/__init__.py +0 -0
- nexaai/llm_impl/mlx_llm_impl.py +271 -0
- nexaai/llm_impl/pybind_llm_impl.py +238 -0
- nexaai/log.py +92 -0
- nexaai/mlx_backend/asr/__init__.py +12 -0
- nexaai/mlx_backend/asr/interface.py +122 -0
- nexaai/mlx_backend/common/__init__.py +0 -0
- nexaai/mlx_backend/common/utils.py +25 -0
- nexaai/mlx_backend/cv/__init__.py +0 -0
- nexaai/mlx_backend/cv/generate.py +195 -0
- nexaai/mlx_backend/cv/interface.py +162 -0
- nexaai/mlx_backend/cv/main.py +81 -0
- nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
- nexaai/mlx_backend/embedding/__init__.py +0 -0
- nexaai/mlx_backend/embedding/generate.py +333 -0
- nexaai/mlx_backend/embedding/interface.py +617 -0
- nexaai/mlx_backend/embedding/main.py +173 -0
- nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
- nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
- nexaai/mlx_backend/image_gen/__init__.py +1 -0
- nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
- nexaai/mlx_backend/image_gen/interface.py +82 -0
- nexaai/mlx_backend/image_gen/main.py +281 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
- nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
- nexaai/mlx_backend/llm/__init__.py +0 -0
- nexaai/mlx_backend/llm/generate.py +149 -0
- nexaai/mlx_backend/llm/interface.py +764 -0
- nexaai/mlx_backend/llm/main.py +68 -0
- nexaai/mlx_backend/ml.py +888 -0
- nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
- nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
- nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
- nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
- nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
- nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
- nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
- nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
- nexaai/mlx_backend/mlx_audio/server.py +525 -0
- nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
- nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
- nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
- nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
- nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
- nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
- nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
- nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
- nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
- nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
- nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
- nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
- nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
- nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
- nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
- nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
- nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
- nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
- nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
- nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
- nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
- nexaai/mlx_backend/mlx_audio/utils.py +237 -0
- nexaai/mlx_backend/mlx_audio/version.py +1 -0
- nexaai/mlx_backend/profiling.py +239 -0
- nexaai/mlx_backend/rerank/__init__.py +0 -0
- nexaai/mlx_backend/rerank/generate.py +174 -0
- nexaai/mlx_backend/rerank/interface.py +287 -0
- nexaai/mlx_backend/rerank/main.py +127 -0
- nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
- nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
- nexaai/mlx_backend/sd/__init__.py +1 -0
- nexaai/mlx_backend/sd/interface.py +362 -0
- nexaai/mlx_backend/sd/main.py +286 -0
- nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
- nexaai/mlx_backend/sd/modeling/clip.py +116 -0
- nexaai/mlx_backend/sd/modeling/config.py +65 -0
- nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
- nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
- nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
- nexaai/mlx_backend/sd/modeling/unet.py +460 -0
- nexaai/mlx_backend/sd/modeling/vae.py +274 -0
- nexaai/mlx_backend/tts/__init__.py +12 -0
- nexaai/mlx_backend/tts/interface.py +276 -0
- nexaai/mlx_backend/vlm/__init__.py +3 -0
- nexaai/mlx_backend/vlm/generate.py +572 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +559 -0
- nexaai/mlx_backend/vlm/main.py +365 -0
- nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
- nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
- nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
- nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
- nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
- nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
- nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
- nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
- nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
- nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
- nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
- nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
- nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
- nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
- nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
- nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
- nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
- nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
- nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
- nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
- nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
- nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
- nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
- nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
- nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
- nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
- nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
- nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
- nexaai/rerank.py +57 -0
- nexaai/rerank_impl/__init__.py +0 -0
- nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
- nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
- nexaai/runtime.py +68 -0
- nexaai/runtime_error.py +24 -0
- nexaai/tts.py +75 -0
- nexaai/tts_impl/__init__.py +0 -0
- nexaai/tts_impl/mlx_tts_impl.py +94 -0
- nexaai/tts_impl/pybind_tts_impl.py +43 -0
- nexaai/utils/decode.py +18 -0
- nexaai/utils/manifest_utils.py +531 -0
- nexaai/utils/model_manager.py +1745 -0
- nexaai/utils/model_types.py +49 -0
- nexaai/utils/progress_tracker.py +389 -0
- nexaai/utils/quantization_utils.py +245 -0
- nexaai/vlm.py +130 -0
- nexaai/vlm_impl/__init__.py +0 -0
- nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
- nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
- nexaai-1.0.29.dist-info/METADATA +35 -0
- nexaai-1.0.29.dist-info/RECORD +580 -0
- nexaai-1.0.29.dist-info/WHEEL +5 -0
- nexaai-1.0.29.dist-info/top_level.txt +1 -0
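The listing above is the wheel's complete file manifest. Since a wheel is just a zip archive (PEP 427), a comparable listing can be reproduced locally; a minimal sketch, where the local filename is an assumption:

```python
# Minimal sketch: list the contents of a downloaded wheel.
# A .whl file is a standard zip archive; the path below is assumed.
import zipfile

WHEEL_PATH = "nexaai-1.0.29-cp310-cp310-macosx_14_0_universal2.whl"

with zipfile.ZipFile(WHEEL_PATH) as wheel:
    for info in wheel.infolist():
        print(f"{info.filename}  ({info.file_size} bytes)")
```

The diff below, judging by its +1,262 line count, appears to correspond to nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py, reconstructed from the rendered view.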
@@ -0,0 +1,1262 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import math
+import numpy as np
+
+# Import from nested llm_common structure using relative imports
+from .llm_common.base import (
+    BaseModelArgs,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from .llm_common.rope_utils import initialize_rope
+
+
+@dataclass
+class VisionConfig:
+    hidden_size: int = 1024
+    intermediate_size: int = 4096
+    num_heads: int = 16
+    num_hidden_layers: int = 24
+    patch_size: int = 16
+    temporal_patch_size: int = 2
+    in_channels: int = 3
+    hidden_act: str = "gelu"
+    spatial_merge_size: int = 2
+    out_hidden_size: int = 2560
+    num_position_embeddings: int = 2304
+    deepstack_visual_indexes: List[int] = None
+
+    def __post_init__(self):
+        if self.deepstack_visual_indexes is None:
+            self.deepstack_visual_indexes = [3, 7, 11]
+
+
+@dataclass
+class TextConfig(BaseModelArgs):
+    model_type: str = "qwen3vl"
+    hidden_size: int = 2560
+    num_hidden_layers: int = 36
+    intermediate_size: int = 9728
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    rms_norm_eps: float = 1e-6
+    vocab_size: int = 151936
+    max_position_embeddings: int = 32768
+    rope_theta: float = 10000.0
+    head_dim: int = 128
+    tie_word_embeddings: bool = True
+    attention_bias: bool = False
+    attention_dropout: float = 0.0
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+
+    def __post_init__(self):
+        if self.rope_scaling is None:
+            # Use default RoPE for now since MRoPE is not implemented in rope_utils
+            self.rope_scaling = None
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    vision_config: VisionConfig = None
+    text_config: TextConfig = None
+    image_token_id: int = 151655
+    vision_start_token_id: int = 151652
+    vision_end_token_id: int = 151653
+
+    def __post_init__(self):
+        if self.vision_config is None:
+            self.vision_config = VisionConfig()
+        if self.text_config is None:
+            self.text_config = TextConfig()
+
+
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb_vision(q, k, cos, sin):
+    cos = mx.expand_dims(cos, axis=-2)
+    sin = mx.expand_dims(sin, axis=-2)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    cos = mx.expand_dims(cos, axis=unsqueeze_dim)
+    sin = mx.expand_dims(sin, axis=unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
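Annotation (not part of the packaged file): the helpers above implement the standard rotate-half form of rotary embeddings, pairing channel i with channel i + dim/2. A quick standalone NumPy check of the two properties the code relies on, position 0 acting as the identity and per-pair norm preservation:

```python
# Standalone check of the rotate-half RoPE used above (NumPy, not MLX).
import numpy as np

def rotate_half(x):
    half = x.shape[-1] // 2
    return np.concatenate([-x[..., half:], x[..., :half]], axis=-1)

dim = 8
q = np.random.randn(1, dim)

# Position 0: cos = 1, sin = 0, so the rotary embedding is the identity.
assert np.allclose(q * np.ones(dim) + rotate_half(q) * np.zeros(dim), q)

# General position: channel i and channel i + dim/2 share one angle, so the
# rotation preserves the norm of each channel pair (cos^2 + sin^2 = 1).
theta = np.random.randn(dim // 2)
cos, sin = np.cos(np.tile(theta, 2)), np.sin(np.tile(theta, 2))
rotated = q * cos + rotate_half(q) * sin
assert np.allclose(np.linalg.norm(rotated), np.linalg.norm(q))
```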
+class VisionMLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+
+    def __call__(self, hidden_state):
+        return self.linear_fc2(nn.gelu(self.linear_fc1(hidden_state)))
+
+
+class VisionPatchEmbed(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.in_channels = config.in_channels
+        self.embed_dim = config.hidden_size
+
+        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+        self.proj = nn.Conv3d(
+            self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True
+        )
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        target_dtype = self.proj.weight.dtype
+
+        # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
+        # This matches the PyTorch ground truth exactly
+        hidden_states = hidden_states.reshape(
+            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+        )
+
+        # Convert to MLX format: [batch, temporal, height, width, channels]
+        hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
+
+        # Apply conv3d with target dtype and reshape to match PyTorch output
+        hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
+
+        return hidden_states
+
+
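Annotation: with the `VisionConfig` defaults above, each spatio-temporal patch carries 3 × 2 × 16 × 16 = 1,536 raw values, and a Conv3d whose stride equals its kernel reduces to a single linear projection per patch. A small shape sketch, illustrative only:

```python
# Shape arithmetic for VisionPatchEmbed under the default VisionConfig.
in_channels, temporal_patch_size, patch_size, hidden_size = 3, 2, 16, 1024

values_per_patch = in_channels * temporal_patch_size * patch_size**2
print(values_per_patch)  # 1536 raw pixel values per patch

# With kernel_size == stride, the Conv3d sees each patch exactly once:
# it is effectively a dense 1536 -> 1024 projection applied patch-wise.
```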
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0):
+        super().__init__()
+        # Don't store inv_freq as a parameter since it causes loading issues
+        self.dim = dim
+        self.theta = theta
+
+    def __call__(self, seqlen: int) -> mx.array:
+        # Compute inv_freq on the fly
+        inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
+        seq = mx.arange(seqlen, dtype=inv_freq.dtype)
+        freqs = mx.outer(seq, inv_freq)
+        return freqs
+
+
+class VisionPatchMerger(nn.Module):
+    def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
+        super().__init__()
+        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
+        self.use_postshuffle_norm = use_postshuffle_norm
+
+        norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
+        self.norm = nn.LayerNorm(norm_size, eps=1e-6)
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if self.use_postshuffle_norm:
+            x = self.norm(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
+        else:
+            x = self.norm(x).reshape(-1, self.hidden_size)
+
+        x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
+        return x
+
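Annotation: `VisionPatchMerger` concatenates each spatial_merge_size × spatial_merge_size group of neighboring patch tokens before projecting, so its MLP width follows directly from the config:

```python
# VisionPatchMerger widths under the default VisionConfig / TextConfig.
hidden_size, spatial_merge_size, out_hidden_size = 1024, 2, 2560

merged_width = hidden_size * spatial_merge_size**2
print(merged_width)     # 4096: a 2x2 block of patch tokens, concatenated
print(out_hidden_size)  # 2560: matches TextConfig.hidden_size for the LLM
```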
+
+class VisionAttention(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.dim = config.hidden_size
+        self.num_heads = config.num_heads
+        self.head_dim = self.dim // self.num_heads
+        self.scaling = self.head_dim**-0.5
+
+        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+        self.proj = nn.Linear(self.dim, self.dim)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        cu_seqlens: mx.array,
+        rotary_pos_emb: Optional[mx.array] = None,
+        position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
+        **kwargs,
+    ) -> mx.array:
+        seq_length = hidden_states.shape[0]
+        qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
+        qkv = qkv.transpose(1, 0, 2, 3)
+        query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
+
+        query_states = query_states.transpose(1, 0, 2)
+        key_states = key_states.transpose(1, 0, 2)
+        value_states = value_states.transpose(1, 0, 2)
+
+        query_states = mx.expand_dims(query_states, axis=0)
+        key_states = mx.expand_dims(key_states, axis=0)
+        value_states = mx.expand_dims(value_states, axis=0)
+
+        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+
+        split_indices = []
+        cumsum = 0
+        for length in lengths[:-1]:
+            cumsum += int(length)
+            split_indices.append(cumsum)
+
+        if split_indices:
+            q_splits = mx.split(query_states, split_indices, axis=1)
+            k_splits = mx.split(key_states, split_indices, axis=1)
+            v_splits = mx.split(value_states, split_indices, axis=1)
+        else:
+            q_splits = [query_states]
+            k_splits = [key_states]
+            v_splits = [value_states]
+
+        attn_outputs = []
+        for q, k, v in zip(q_splits, k_splits, v_splits):
+            attn_out = scaled_dot_product_attention(
+                q, k, v, scale=self.scaling, mask=None, cache=None
+            )
+            attn_outputs.append(attn_out)
+
+        attn_output = mx.concatenate(attn_outputs, axis=1)
+
+        attn_output = attn_output[0].transpose(1, 0, 2)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+
+        return attn_output
+
+
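Annotation: `VisionAttention` keeps all images' patch tokens packed in one sequence and uses the cumulative lengths in `cu_seqlens` to split q/k/v so attention never crosses an image boundary. The index bookkeeping, replayed standalone with hypothetical sizes:

```python
# Replaying VisionAttention's cu_seqlens bookkeeping with NumPy.
import numpy as np

cu_seqlens = np.array([0, 4, 10])            # two images: 4 and 6 tokens
lengths = cu_seqlens[1:] - cu_seqlens[:-1]   # [4, 6]

split_indices, cumsum = [], 0                # same loop as the model
for length in lengths[:-1]:
    cumsum += int(length)
    split_indices.append(cumsum)             # [4]

packed = np.arange(10)                       # stand-in for the token axis
parts = np.split(packed, split_indices)
print([p.tolist() for p in parts])           # [[0, 1, 2, 3], [4, 5, 6, 7, 8, 9]]
```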
|
|
248
|
+
class VisionBlock(nn.Module):
|
|
249
|
+
def __init__(self, config: VisionConfig):
|
|
250
|
+
super().__init__()
|
|
251
|
+
self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
|
|
252
|
+
self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
|
|
253
|
+
self.attn = VisionAttention(config)
|
|
254
|
+
self.mlp = VisionMLP(config)
|
|
255
|
+
|
|
256
|
+
def __call__(
|
|
257
|
+
self,
|
|
258
|
+
hidden_states: mx.array,
|
|
259
|
+
cu_seqlens: mx.array,
|
|
260
|
+
position_embeddings: Tuple[mx.array, mx.array],
|
|
261
|
+
) -> mx.array:
|
|
262
|
+
hidden_states = hidden_states + self.attn(
|
|
263
|
+
self.norm1(hidden_states),
|
|
264
|
+
cu_seqlens=cu_seqlens,
|
|
265
|
+
position_embeddings=position_embeddings,
|
|
266
|
+
)
|
|
267
|
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
|
268
|
+
return hidden_states
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class VisionModel(nn.Module):
|
|
272
|
+
def __init__(self, config: VisionConfig):
|
|
273
|
+
super().__init__()
|
|
274
|
+
self.config = config
|
|
275
|
+
self.spatial_merge_size = config.spatial_merge_size
|
|
276
|
+
self.patch_size = config.patch_size
|
|
277
|
+
|
|
278
|
+
self.patch_embed = VisionPatchEmbed(config)
|
|
279
|
+
self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
|
|
280
|
+
self.num_grid_per_side = int(config.num_position_embeddings**0.5)
|
|
281
|
+
|
|
282
|
+
head_dim = config.hidden_size // config.num_heads
|
|
283
|
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
|
284
|
+
|
|
285
|
+
self.blocks = [VisionBlock(config) for _ in range(config.num_hidden_layers)]
|
|
286
|
+
self.merger = VisionPatchMerger(config, use_postshuffle_norm=False)
|
|
287
|
+
|
|
288
|
+
self.deepstack_visual_indexes = config.deepstack_visual_indexes
|
|
289
|
+
self.deepstack_merger_list = [
|
|
290
|
+
VisionPatchMerger(config, use_postshuffle_norm=True)
|
|
291
|
+
for _ in range(len(config.deepstack_visual_indexes))
|
|
292
|
+
]
|
|
293
|
+
|
|
294
|
+
def rot_pos_emb(self, grid_thw: mx.array) -> mx.array:
|
|
295
|
+
merge_size = self.spatial_merge_size
|
|
296
|
+
|
|
297
|
+
max_hw = int(grid_thw[:, 1:].max().item())
|
|
298
|
+
freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
|
|
299
|
+
|
|
300
|
+
pos_ids_parts = []
|
|
301
|
+
|
|
302
|
+
for i in range(grid_thw.shape[0]):
|
|
303
|
+
num_frames = int(grid_thw[i, 0].item())
|
|
304
|
+
height = int(grid_thw[i, 1].item())
|
|
305
|
+
width = int(grid_thw[i, 2].item())
|
|
306
|
+
|
|
307
|
+
merged_h, merged_w = height // merge_size, width // merge_size
|
|
308
|
+
|
|
309
|
+
block_rows = mx.arange(merged_h) # block row indices
|
|
310
|
+
block_cols = mx.arange(merged_w) # block col indices
|
|
311
|
+
intra_row = mx.arange(merge_size) # intra-block row offsets
|
|
312
|
+
intra_col = mx.arange(merge_size) # intra-block col offsets
|
|
313
|
+
|
|
314
|
+
# Compute full-resolution positions using broadcasting
|
|
315
|
+
row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
|
|
316
|
+
col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
|
|
317
|
+
|
|
318
|
+
row_idx = mx.broadcast_to(
|
|
319
|
+
row_idx, (merged_h, merged_w, merge_size, merge_size)
|
|
320
|
+
).reshape(-1)
|
|
321
|
+
col_idx = mx.broadcast_to(
|
|
322
|
+
col_idx, (merged_h, merged_w, merge_size, merge_size)
|
|
323
|
+
).reshape(-1)
|
|
324
|
+
|
|
325
|
+
coords = mx.stack([row_idx, col_idx], axis=-1)
|
|
326
|
+
|
|
327
|
+
if num_frames > 1:
|
|
328
|
+
coords = mx.tile(coords, (num_frames, 1))
|
|
329
|
+
|
|
330
|
+
pos_ids_parts.append(coords)
|
|
331
|
+
|
|
332
|
+
# Concatenate all coordinate parts
|
|
333
|
+
pos_ids = mx.concatenate(pos_ids_parts, axis=0)
|
|
334
|
+
|
|
335
|
+
embeddings = freq_table[pos_ids] # lookup rotary embeddings
|
|
336
|
+
embeddings = embeddings.reshape(embeddings.shape[0], -1)
|
|
337
|
+
return embeddings
|
|
338
|
+
|
|
339
|
+
def fast_pos_embed_interpolate(self, grid_thw: mx.array):
|
|
340
|
+
patch_pos_embeds = []
|
|
341
|
+
|
|
342
|
+
for i in range(grid_thw.shape[0]):
|
|
343
|
+
t = int(grid_thw[i, 0].item())
|
|
344
|
+
h = int(grid_thw[i, 1].item())
|
|
345
|
+
w = int(grid_thw[i, 2].item())
|
|
346
|
+
|
|
347
|
+
# Simple position embedding interpolation
|
|
348
|
+
h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
|
|
349
|
+
w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
|
|
350
|
+
|
|
351
|
+
h_idxs_floor = mx.floor(h_idxs).astype(mx.int32)
|
|
352
|
+
w_idxs_floor = mx.floor(w_idxs).astype(mx.int32)
|
|
353
|
+
h_idxs_ceil = mx.minimum(h_idxs_floor + 1, self.num_grid_per_side - 1)
|
|
354
|
+
w_idxs_ceil = mx.minimum(w_idxs_floor + 1, self.num_grid_per_side - 1)
|
|
355
|
+
|
|
356
|
+
dh = h_idxs - h_idxs_floor.astype(mx.float32)
|
|
357
|
+
dw = w_idxs - w_idxs_floor.astype(mx.float32)
|
|
358
|
+
|
|
359
|
+
base_h = h_idxs_floor * self.num_grid_per_side
|
|
360
|
+
base_h_ceil = h_idxs_ceil * self.num_grid_per_side
|
|
361
|
+
|
|
362
|
+
# Compute bilinear interpolation indices and weights
|
|
363
|
+
indices_tl = (base_h[:, None] + w_idxs_floor[None, :]).reshape(-1)
|
|
364
|
+
indices_tr = (base_h[:, None] + w_idxs_ceil[None, :]).reshape(-1)
|
|
365
|
+
indices_bl = (base_h_ceil[:, None] + w_idxs_floor[None, :]).reshape(-1)
|
|
366
|
+
indices_br = (base_h_ceil[:, None] + w_idxs_ceil[None, :]).reshape(-1)
|
|
367
|
+
|
|
368
|
+
weights_tl = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
|
|
369
|
+
weights_tr = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
|
|
370
|
+
weights_bl = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
|
|
371
|
+
weights_br = (dh[:, None] * dw[None, :]).reshape(-1)
|
|
372
|
+
|
|
373
|
+
# Get embeddings and interpolate
|
|
374
|
+
pos_embed_tl = self.pos_embed(indices_tl) * weights_tl[:, None]
|
|
375
|
+
pos_embed_tr = self.pos_embed(indices_tr) * weights_tr[:, None]
|
|
376
|
+
pos_embed_bl = self.pos_embed(indices_bl) * weights_bl[:, None]
|
|
377
|
+
pos_embed_br = self.pos_embed(indices_br) * weights_br[:, None]
|
|
378
|
+
|
|
379
|
+
pos_embed = pos_embed_tl + pos_embed_tr + pos_embed_bl + pos_embed_br
|
|
380
|
+
|
|
381
|
+
# Repeat for temporal dimension and apply spatial merging
|
|
382
|
+
pos_embed = mx.tile(pos_embed, (t, 1))
|
|
383
|
+
|
|
384
|
+
# Apply spatial merging pattern
|
|
385
|
+
merge_size = self.config.spatial_merge_size
|
|
386
|
+
pos_embed = pos_embed.reshape(
|
|
387
|
+
t, h // merge_size, merge_size, w // merge_size, merge_size, -1
|
|
388
|
+
)
|
|
389
|
+
pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
|
|
390
|
+
pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
|
|
391
|
+
|
|
392
|
+
patch_pos_embeds.append(pos_embed)
|
|
393
|
+
|
|
394
|
+
return mx.concatenate(patch_pos_embeds, axis=0)
|
|
395
|
+

    def __call__(
        self, hidden_states: mx.array, grid_thw: mx.array
    ) -> Tuple[mx.array, List[mx.array]]:
        hidden_states = self.patch_embed(hidden_states)

        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds

        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        seq_len = hidden_states.shape[0]

        emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
        position_embeddings = (mx.cos(emb), mx.sin(emb))

        # Cumulative sequence lengths, following the HuggingFace implementation:
        # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
        seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2]  # h * w for each image
        seq_lens = []
        for tokens_per_frame, repeats in zip(seq_lens_per_image, grid_thw[:, 0]):
            seq_lens.extend([tokens_per_frame] * int(repeats))
        seq_lens = mx.array(seq_lens)

        # Cumulative sum, padded with a leading 0
        cu_seqlens = mx.cumsum(seq_lens)
        cu_seqlens = mx.concatenate([mx.array([0]), cu_seqlens])

        deepstack_feature_lists = []
        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
            )
            if layer_num in self.deepstack_visual_indexes:
                idx = self.deepstack_visual_indexes.index(layer_num)
                deepstack_feature = self.deepstack_merger_list[idx](hidden_states)
                deepstack_feature_lists.append(deepstack_feature)

        hidden_states = self.merger(hidden_states)
        return hidden_states, deepstack_feature_lists

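
# --- Illustrative sketch, not part of the module: the cu_seqlens layout for
# two hypothetical images with 1x4x6 and 1x2x2 patch grids (t, h, w).
import mlx.core as mx

grid_thw = mx.array([[1, 4, 6], [1, 2, 2]])
seq_lens = []
for n, t in zip((grid_thw[:, 1] * grid_thw[:, 2]).tolist(), grid_thw[:, 0].tolist()):
    seq_lens.extend([n] * int(t))
cu_seqlens = mx.concatenate([mx.array([0]), mx.cumsum(mx.array(seq_lens))])
print(cu_seqlens)  # [0, 24, 28] -> attention stays block-diagonal per image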


class TextRotaryEmbedding(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        # MRoPE configuration
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", "default")
            self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
        else:
            self.rope_type = "default"
            self.mrope_section = [24, 20, 20]

        # Store parameters for computing inv_freq on the fly
        self.head_dim = config.head_dim
        self.theta = config.rope_theta

        # Attention scaling (simplified; may need adjustment based on the actual config)
        self.attention_scaling = 1.0

    def _get_inv_freq(self):
        """Compute inverse frequencies on the fly."""
        inv_freq = 1.0 / (
            self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim)
        )
        # Expand for 3 dimensions (T, H, W)
        return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))

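
# --- Illustrative sketch, not part of the module: the inverse-frequency ladder
# for a hypothetical head_dim=8 and theta=10000.
import mlx.core as mx

head_dim, theta = 8, 10000.0
inv_freq = 1.0 / (theta ** (mx.arange(0, head_dim, 2).astype(mx.float32) / head_dim))
print(inv_freq)  # [1.0, 0.1, 0.01, 0.001]: geometrically decaying frequencies
print(mx.broadcast_to(inv_freq[None, :], (3, 4)).shape)  # (3, 4): one row per T/H/W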

    def apply_interleaved_mrope(self, freqs, mrope_section):
        """Apply interleaved MRoPE to 3D rotary embeddings.

        Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
        interleaved [THWTHW...TT], preserving frequency continuity.

        Args:
            freqs: (3, bs, seq_len, head_dim // 2)
            mrope_section: (3,)

        Returns:
            freqs_t: (bs, seq_len, head_dim // 2)
        """
        freqs_t = freqs[0]  # start from the T dimension, then overwrite the H/W slots
        for dim, offset in enumerate((1, 2), start=1):  # H, W
            length = mrope_section[dim] * 3
            idx = slice(offset, length, 3)
            freqs_t[..., idx] = freqs[dim, ..., idx]
        return freqs_t

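
# --- Illustrative sketch, not part of the module: the interleaving traced with
# integer markers (0=T, 1=H, 2=W) and a hypothetical mrope_section=[2, 1, 1],
# so head_dim // 2 == 4.
import mlx.core as mx

freqs = mx.stack([mx.full((4,), v) for v in (0, 1, 2)])
mrope_section = [2, 1, 1]
out = freqs[0]
for dim, offset in enumerate((1, 2), start=1):
    idx = slice(offset, mrope_section[dim] * 3, 3)
    out[..., idx] = freqs[dim, ..., idx]
print(out)  # [0, 1, 2, 0] -> T H W T: interleaved, remainder stays T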

    def __call__(self, x: mx.array, position_ids: mx.array) -> Tuple[mx.array, mx.array]:
        """
        Args:
            x: Input tensor, used only as a dtype reference
            position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE

        Returns:
            cos, sin: Cosine and sine embeddings
        """
        # Handle 2D position_ids by expanding to 3D for MRoPE
        if position_ids.ndim == 2:
            position_ids = mx.broadcast_to(
                position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1])
            )

        batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]

        # Expand inverse frequencies: (3, dim//2) -> (3, batch_size, 1, dim//2)
        inv_freq = self._get_inv_freq()
        inv_freq_expanded = mx.broadcast_to(
            inv_freq[:, None, None, :],
            (3, batch_size, 1, inv_freq.shape[-1]),
        )

        # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
        position_ids_expanded = position_ids[..., None].astype(mx.float32)

        # Compute frequencies: (3, batch_size, seq_len, dim//2)
        freqs = inv_freq_expanded * position_ids_expanded

        # Apply interleaved MRoPE -> (batch_size, seq_len, dim//2)
        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)

        # Create embeddings
        emb = mx.concatenate([freqs, freqs], axis=-1)  # (batch_size, seq_len, head_dim)
        cos = mx.cos(emb) * self.attention_scaling
        sin = mx.sin(emb) * self.attention_scaling

        return cos.astype(x.dtype), sin.astype(x.dtype)

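
# --- Illustrative usage sketch, not part of the module: assumes a TextConfig
# instance `config` is in scope. For text-only input the T/H/W rows of the
# 3D position ids are identical.
import mlx.core as mx

pos = mx.broadcast_to(mx.arange(4)[None, None, :], (3, 1, 4))
rotary = TextRotaryEmbedding(config)
cos, sin = rotary(mx.zeros((1, 4, config.head_dim)), pos)  # x only supplies the dtype
print(cos.shape)  # (1, 4, head_dim)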


class TextAttention(nn.Module):
    def __init__(self, config: TextConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        dim = config.hidden_size
        self.n_heads = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim**-0.5

        self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=config.attention_bias)

        self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)

        # Initialize rope directly
        self.rope = initialize_rope(
            config.head_dim,
            base=config.rope_theta,
            traditional=False,
            scaling_config=config.rope_scaling,
            max_position_embeddings=config.max_position_embeddings,
        )

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
        cos: Optional[mx.array] = None,
        sin: Optional[mx.array] = None,
        rope_deltas: Optional[mx.array] = None,
    ) -> Tuple[mx.array, Optional[mx.array]]:
        B, L, D = hidden_states.shape

        queries = self.q_proj(hidden_states).reshape(B, L, self.n_heads, -1)
        keys = self.k_proj(hidden_states).reshape(B, L, self.n_kv_heads, -1)
        values = self.v_proj(hidden_states).reshape(B, L, self.n_kv_heads, -1)

        queries = self.q_norm(queries).transpose(0, 2, 1, 3)
        keys = self.k_norm(keys).transpose(0, 2, 1, 3)
        values = values.transpose(0, 2, 1, 3)

        # Prefer precomputed MRoPE tables; otherwise fall back to 1D rope
        if cos is not None and sin is not None:
            queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
            if cache is not None:
                keys, values = cache.update_and_fetch(keys, values)
        else:
            if cache is not None:
                # rope_deltas may be None, a Python scalar, or an array
                if rope_deltas is None:
                    offset_delta = 0
                elif isinstance(rope_deltas, (int, float)):
                    offset_delta = rope_deltas
                elif hasattr(rope_deltas, "size") and rope_deltas.size == 1:
                    offset_delta = rope_deltas.item()
                elif hasattr(rope_deltas, "shape") and rope_deltas.shape:
                    # Multiple elements: take the first
                    offset_delta = rope_deltas.reshape(-1)[0].item()
                else:
                    offset_delta = 0

                queries = self.rope(queries, offset=cache.offset + offset_delta)
                keys = self.rope(keys, offset=cache.offset + offset_delta)
                keys, values = cache.update_and_fetch(keys, values)
            else:
                queries = self.rope(queries)
                keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=attention_mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), None

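
# --- Illustrative sketch, not part of the module: the q/k/v reshapes above
# implement grouped-query attention. Sizes hypothetical.
import mlx.core as mx

B, L, n_heads, n_kv_heads, head_dim = 1, 5, 8, 2, 64
q = mx.zeros((B, L, n_heads * head_dim)).reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3)
k = mx.zeros((B, L, n_kv_heads * head_dim)).reshape(B, L, n_kv_heads, -1).transpose(0, 2, 1, 3)
print(q.shape, k.shape)  # (1, 8, 5, 64) (1, 2, 5, 64): 4 query heads share each kv head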


class TextMLP(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def __call__(self, x):
        # SwiGLU: down(silu(gate(x)) * up(x))
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))

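
# --- Illustrative sketch, not part of the module: the SwiGLU gate evaluated on
# hypothetical scalar inputs.
import mlx.core as mx
import mlx.nn as nn

g, u = mx.array([1.0]), mx.array([2.0])
print(nn.silu(g) * u)  # silu(1) * 2 = 2 / (1 + e^-1) ~= 1.4621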


class TextDecoderLayer(nn.Module):
    def __init__(self, config: TextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = TextAttention(config, layer_idx)
        self.mlp = TextMLP(config)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        hidden_states: mx.array,
        attention_mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
        cos: Optional[mx.array] = None,
        sin: Optional[mx.array] = None,
        rope_deltas: Optional[mx.array] = None,
    ) -> mx.array:
        # Pre-norm residual block: attention, then MLP
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            cache=cache,
            cos=cos,
            sin=sin,
            rope_deltas=rope_deltas,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states



class TextModel(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = TextRotaryEmbedding(config)

    def _deepstack_process(
        self,
        hidden_states: mx.array,
        visual_pos_masks: mx.array,
        deepstack_visual_embeds: mx.array,
    ) -> mx.array:
        if visual_pos_masks is None or deepstack_visual_embeds is None:
            return hidden_states
        B, L, D = hidden_states.shape
        # Map each masked position to its rank among visual tokens, then gather
        # the matching deepstack embedding with a one-hot matmul (scatter-free).
        mask_flat = visual_pos_masks.astype(mx.int32).reshape(-1)
        idx_flat = mx.cumsum(mask_flat, axis=0) - 1
        N = deepstack_visual_embeds.shape[0]
        idx_flat = mx.maximum(idx_flat, 0)
        eq = (idx_flat[:, None] == mx.arange(N)[None, :]).astype(hidden_states.dtype)
        add_flat = eq @ deepstack_visual_embeds.astype(hidden_states.dtype)
        add_flat = add_flat * mask_flat[:, None].astype(hidden_states.dtype)
        add = add_flat.reshape(B, L, D)
        return hidden_states + add

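
# --- Illustrative sketch, not part of the module: the one-hot matmul above is
# a gather in disguise. Hypothetical mask with two visual slots, D=1.
import mlx.core as mx

mask = mx.array([0, 1, 1, 0], dtype=mx.int32)  # visual tokens at positions 1, 2
embeds = mx.array([[10.0], [20.0]])            # two deepstack rows
idx = mx.maximum(mx.cumsum(mask) - 1, 0)       # [0, 0, 1, 1]
eq = (idx[:, None] == mx.arange(2)[None, :]).astype(mx.float32)
add = (eq @ embeds) * mask[:, None].astype(mx.float32)
print(add.squeeze())  # [0, 10, 20, 0]: each visual slot receives its own row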

    def __call__(
        self,
        input_ids: Optional[mx.array] = None,
        inputs_embeds: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
        cache=None,
        visual_pos_masks: Optional[mx.array] = None,
        deepstack_visual_embeds: Optional[List[mx.array]] = None,
        cos: Optional[mx.array] = None,
        sin: Optional[mx.array] = None,
        rope_deltas: Optional[mx.array] = None,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        if attention_mask is None:
            attention_mask = create_attention_mask(hidden_states, cache, return_array=True)

        if cache is None:
            cache = [None] * len(self.layers)

        for layer_idx, (decoder_layer, c) in enumerate(zip(self.layers, cache)):
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                cache=c,
                cos=cos,
                sin=sin,
                rope_deltas=rope_deltas,
            )
            # Inject deepstack visual features into the first few decoder layers
            if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
                hidden_states = self._deepstack_process(
                    hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx]
                )
        hidden_states = self.norm(hidden_states)
        return hidden_states



# Standalone Vision Model
class VEGModel(nn.Module):
    def __init__(self, vision_config: VisionConfig):
        super().__init__()
        self.config = vision_config
        self.visual = VisionModel(vision_config)

    def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
        return self.visual(pixel_values, image_grid_thw)

    def sanitize(self, weights):
        sanitized = {}
        for k, v in weights.items():
            if "visual." in k:
                # Strip combined-model prefixes to match this model's structure
                clean_key = k.replace("model.visual.", "").replace("visual.", "")
                sanitized[f"visual.{clean_key}"] = v
        return sanitized

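
# --- Illustrative sketch, not part of the module: VEGModel.sanitize key
# rewriting on a hypothetical checkpoint key.
k = "model.visual.blocks.0.attn.qkv.weight"
clean = k.replace("model.visual.", "").replace("visual.", "")
print(f"visual.{clean}")  # visual.blocks.0.attn.qkv.weight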


# Pure LLM Model (no vision components)
class LLMModel(nn.Module):
    def __init__(self, text_config: TextConfig):
        super().__init__()
        self.args = text_config
        self.config = text_config
        self.language_model = TextModel(text_config)
        if not text_config.tie_word_embeddings:
            self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)

    def get_rope_index(
        self,
        input_ids: Optional[mx.array] = None,
        image_grid_thw: Optional[mx.array] = None,
        attention_mask: Optional[mx.array] = None,
    ) -> Tuple[mx.array, mx.array]:
        """Simplified version for images only (no video support)."""

        spatial_merge_size = 2
        image_token_id = 151655
        vision_start_token_id = 151652
        mrope_position_deltas = []

        if input_ids is not None and image_grid_thw is not None:
            total_input_ids = input_ids
            if attention_mask is None:
                attention_mask = mx.ones_like(total_input_ids)

            batch_size, seq_len = input_ids.shape
            position_ids_list = []
            image_index = 0

            for i in range(batch_size):
                input_ids_seq = total_input_ids[i]
                mask_seq = attention_mask[i]

                # Use the mask to get the valid (unpadded) length
                valid_length = int(mx.sum(mask_seq).item())
                input_ids_seq = input_ids_seq[:valid_length]

                # Count images: a vision-start token directly followed by an image token
                vision_start_positions = []
                for pos in range(input_ids_seq.shape[0]):
                    if input_ids_seq[pos].item() == vision_start_token_id:
                        vision_start_positions.append(pos)

                image_nums = 0
                for pos in vision_start_positions:
                    if pos + 1 < input_ids_seq.shape[0]:
                        if input_ids_seq[pos + 1].item() == image_token_id:
                            image_nums += 1

                input_tokens = input_ids_seq.tolist()
                llm_pos_ids_list = []
                st = 0

                for _ in range(image_nums):
                    ed = input_tokens.index(image_token_id, st)

                    t = image_grid_thw[image_index, 0].item()
                    h = image_grid_thw[image_index, 1].item()
                    w = image_grid_thw[image_index, 2].item()
                    image_index += 1

                    llm_grid_t = int(t)
                    llm_grid_h = int(h) // spatial_merge_size
                    llm_grid_w = int(w) // spatial_merge_size
                    text_len = ed - st

                    st_idx = (
                        llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
                    )
                    text_pos = mx.arange(text_len).reshape(1, -1)
                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
                    llm_pos_ids_list.append(text_pos)

                    # t_index is always 0 because llm_grid_t is always 1 for images
                    t_index = mx.arange(llm_grid_t).reshape(-1, 1)
                    t_index = mx.broadcast_to(
                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
                    ).reshape(-1)

                    h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
                    h_index = mx.broadcast_to(
                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
                    ).reshape(-1)

                    w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
                    w_index = mx.broadcast_to(
                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
                    ).reshape(-1)

                    vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
                    llm_pos_ids_list.append(vision_pos)
                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w

                if st < len(input_tokens):
                    st_idx = (
                        llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
                    )
                    text_len = len(input_tokens) - st
                    text_pos = mx.arange(text_len).reshape(1, -1)
                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
                    llm_pos_ids_list.append(text_pos)

                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)

                # Position ids for this batch item, padded to seq_len with 1s
                valid_length = min(seq_len, llm_positions.shape[1])
                pos_dim0 = mx.concatenate(
                    [
                        llm_positions[0, :valid_length],
                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
                    ]
                )
                pos_dim1 = mx.concatenate(
                    [
                        llm_positions[1, :valid_length],
                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
                    ]
                )
                pos_dim2 = mx.concatenate(
                    [
                        llm_positions[2, :valid_length],
                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
                    ]
                )

                batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
                position_ids_list.append(batch_position_ids)

                mrope_position_deltas.append(
                    llm_positions.max().item() + 1 - len(total_input_ids[i])
                )

            # Stack all batch position_ids
            position_ids = mx.stack(position_ids_list, axis=1)  # Shape: (3, batch_size, seq_len)
            mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
            return position_ids, mrope_position_deltas
        else:
            if attention_mask is not None:
                position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
                position_ids = mx.where(attention_mask == 0, 1, position_ids)
                position_ids = mx.expand_dims(position_ids, axis=0)
                position_ids = mx.broadcast_to(
                    position_ids, (3, position_ids.shape[1], position_ids.shape[2])
                )
                max_position_ids = mx.max(
                    mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True
                )
                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
            else:
                seq_len = input_ids.shape[1]
                batch_size = input_ids.shape[0]
                position_ids = mx.arange(seq_len).reshape(1, 1, -1)
                position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
                mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)

            return position_ids, mrope_position_deltas

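
# --- Illustrative trace, not part of the module: get_rope_index on the
# hypothetical sequence "a b c <4x4 image, merged 2x2> d e".
#   T row: 0 1 2 | 3 3 3 3 | 5 6
#   H row: 0 1 2 | 3 3 4 4 | 5 6
#   W row: 0 1 2 | 3 4 3 4 | 5 6
# Text advances all three axes together; image patches vary H/W at a fixed T,
# and the trailing text resumes from max(previous) + 1.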

    def __call__(
        self,
        inputs: mx.array = None,
        mask: mx.array = None,
        cache=None,
        inputs_embeds: Optional[mx.array] = None,
        visual_pos_masks: Optional[mx.array] = None,
        deepstack_visual_embeds: Optional[List[mx.array]] = None,
        cos: Optional[mx.array] = None,
        sin: Optional[mx.array] = None,
        rope_deltas: Optional[mx.array] = None,
    ):
        out = self.language_model(
            input_ids=inputs,
            inputs_embeds=inputs_embeds,
            attention_mask=mask,
            cache=cache,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            cos=cos,
            sin=sin,
            rope_deltas=rope_deltas,
        )
        if self.args.tie_word_embeddings:
            return self.language_model.embed_tokens.as_linear(out)
        else:
            return self.lm_head(out)

    def sanitize(self, weights):
        sanitized = {}
        for k, v in weights.items():
            if "visual." not in k:
                # Map keys from the combined checkpoint onto the LLM-only model
                clean_key = k
                if clean_key.startswith("model."):
                    clean_key = clean_key[len("model."):]

                # Map language_* keys onto the language_model submodule
                if clean_key.startswith("language_layers."):
                    clean_key = "language_model.layers." + clean_key[len("language_layers."):]
                elif clean_key.startswith("language_embed_tokens."):
                    clean_key = "language_model.embed_tokens." + clean_key[len("language_embed_tokens."):]
                elif clean_key.startswith("language_norm."):
                    clean_key = "language_model.norm." + clean_key[len("language_norm."):]

                sanitized[clean_key] = v

        # Tied embeddings: drop lm_head, as embed_tokens doubles as the output head
        if self.args.tie_word_embeddings:
            sanitized.pop("lm_head.weight", None)

        return sanitized

    @property
    def layers(self):
        return self.language_model.layers

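
# --- Illustrative trace, not part of the module: LLMModel.sanitize on
# hypothetical checkpoint keys, before -> after.
#   "model.language_model.layers.0.self_attn.q_proj.weight"
#       -> "language_model.layers.0.self_attn.q_proj.weight"  (prefix strip)
#   "language_layers.0.self_attn.q_proj.weight"
#       -> "language_model.layers.0.self_attn.q_proj.weight"  (language_* remap)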


# Combined Model (for compatibility and utility functions)
class Qwen3VLModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.config = args
        self.visual = VisionModel(args.vision_config)
        self.language_model = TextModel(args.text_config)

    def sanitize(self, weights):
        # Map weights to match the combined model structure.
        # Note: str.replace drops every "model." occurrence, not just the prefix.
        sanitized = {}
        for k, v in weights.items():
            clean_key = k.replace("model.", "") if k.startswith("model.") else k
            sanitized[clean_key] = v
        return sanitized

    def get_image_features(self, pixel_values: mx.array, image_grid_thw: Optional[mx.array] = None):
        image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
        # Split based on grid dimensions
        if image_grid_thw is not None:
            split_sizes = (
                mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size**2)
            ).tolist()
            # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
            split_indices = []
            cumsum = 0
            for size in split_sizes[:-1]:  # Exclude last element
                cumsum += size
                split_indices.append(cumsum)

            if split_indices:  # Only split if we have indices
                image_embeds = mx.split(image_embeds, split_indices)
            else:
                image_embeds = [image_embeds]  # Single image case
        return image_embeds, deepstack_visual_embeds

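
# --- Illustrative sketch, not part of the module: converting chunk sizes to
# cut points for mx.split, with hypothetical grids and spatial_merge_size=2.
import mlx.core as mx

grid_thw = mx.array([[1, 4, 6], [1, 2, 2]])
split_sizes = (mx.prod(grid_thw, axis=-1) // 4).tolist()  # [6, 1]
indices, cumsum = [], 0
for s in split_sizes[:-1]:
    cumsum += s
    indices.append(cumsum)
print(indices)  # [6]: splits a 7-row embedding table into rows 0-5 and row 6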

    def __call__(
        self,
        input_ids: mx.array = None,
        attention_mask: Optional[mx.array] = None,
        inputs_embeds: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        image_grid_thw: Optional[mx.array] = None,
        cache=None,
        visual_pos_masks: Optional[mx.array] = None,
        deepstack_visual_embeds: Optional[List[mx.array]] = None,
        cos: Optional[mx.array] = None,
        sin: Optional[mx.array] = None,
        rope_deltas: Optional[mx.array] = None,
    ):
        if inputs_embeds is None:
            inputs_embeds = self.language_model.embed_tokens(input_ids)

        # Process images
        if pixel_values is not None:
            image_embeds, deepstack_visual_embeds = self.get_image_features(
                pixel_values, image_grid_thw
            )

            if isinstance(image_embeds, list):
                image_embeds = mx.concatenate(image_embeds, axis=0)

            # Find image token positions and replace them with visual embeddings
            image_mask = input_ids == self.args.image_token_id
            visual_pos_masks = image_mask

            # Scatter visual embeddings into the image-token slots (see
            # handle_multimodal_embeds below for an mx.where-based alternative)
            inputs_embeds = inputs_embeds.at[image_mask].set(
                image_embeds.astype(inputs_embeds.dtype)
            )

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache=cache,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            cos=cos,
            sin=sin,
            rope_deltas=rope_deltas,
        )

        return outputs



def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
    """
    Handle the processing of multimodal embeddings, including image features and position encoding.

    This function processes vision and text inputs to create unified embeddings that can be fed
    into the language model. It handles:
    - Vision feature extraction from pixel values
    - Deepstack visual embedding collection
    - Image token replacement in text embeddings
    - Position encoding setup for MRoPE (Multi-dimensional RoPE)

    Args:
        vision_model: The vision encoder model (VEGModel instance)
        llm_model: The language model (LLMModel instance)
        input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
        pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
        image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)

    Returns:
        tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
            - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
            - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
            - visual_pos_masks: Boolean mask indicating image token positions
            - cos: Cosine values for rotary position encoding
            - sin: Sine values for rotary position encoding
            - rope_deltas: Position offset deltas for rope computation
    """
    inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
    deepstack_visual_embeds = None
    visual_pos_masks = None
    cos = None
    sin = None
    rope_deltas = 0

    if pixel_values is not None:
        if pixel_values.ndim == 4:
            pixel_values = mx.expand_dims(pixel_values, axis=2)

        # Process each image individually to prevent feature mixing
        image_embeds_list = []
        all_deepstack_embeds = []

        # Track cumulative patch offsets for each image
        cumulative_patches = 0

        for i in range(image_grid_thw.shape[0]):
            # Number of patches for the current image (t == 1 for images)
            current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
            start_idx = cumulative_patches
            end_idx = cumulative_patches + current_patches
            cumulative_patches += current_patches

            single_pixel_values = pixel_values[start_idx:end_idx]
            single_grid_thw = image_grid_thw[i : i + 1]

            # Use the vision model directly
            single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)

            # Split based on grid dimensions
            split_sizes = (
                mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size**2)
            ).tolist()
            split_indices = []
            cumsum = 0
            for size in split_sizes[:-1]:
                cumsum += size
                split_indices.append(cumsum)

            if split_indices:
                single_embeds = mx.split(single_embeds, split_indices)
            else:
                single_embeds = [single_embeds]

            image_embeds_list.extend(single_embeds)

            # Collect deepstack embeddings, concatenating across images per layer
            if i == 0:
                all_deepstack_embeds = single_deepstack
            else:
                for j in range(len(all_deepstack_embeds)):
                    all_deepstack_embeds[j] = mx.concatenate(
                        [all_deepstack_embeds[j], single_deepstack[j]], axis=0
                    )

        deepstack_visual_embeds = all_deepstack_embeds

        # Concatenate all image embeddings for processing
        image_embeds = mx.concatenate(image_embeds_list, axis=0)

        # Find all image token positions
        image_token_id = 151655  # Default image token ID
        image_mask = input_ids.squeeze(0) == image_token_id
        image_token_positions = np.where(np.array(image_mask))[0]

        # Verify we have the correct number of image tokens
        expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
        assert (
            len(image_token_positions) == expected_total_tokens
        ), f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"

        # Replace image tokens with image embeddings, one position at a time
        seq_len = inputs_embeds.shape[0]
        result = inputs_embeds

        embed_idx = 0
        for img_embed in image_embeds_list:
            for patch_idx in range(img_embed.shape[0]):
                token_pos = image_token_positions[embed_idx]
                pos_mask = mx.arange(seq_len) == token_pos
                result = mx.where(
                    mx.expand_dims(pos_mask, axis=-1),
                    mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
                    result,
                )
                embed_idx += 1

        inputs_embeds = result
        position_ids, rope_deltas = llm_model.get_rope_index(input_ids, image_grid_thw)
        cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
        if inputs_embeds.ndim == 2:
            inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)

        visual_pos_masks = image_mask

    return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas

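
# --- Illustrative end-to-end sketch, not part of the module: assumes
# `vision_model` (VEGModel), `llm_model` (LLMModel) and preprocessed
# `input_ids`, `pixel_values`, `image_grid_thw` already exist.
embeds, deepstack, vis_mask, cos, sin, deltas = handle_multimodal_embeds(
    vision_model, llm_model, input_ids, pixel_values, image_grid_thw
)
logits = llm_model(
    inputs_embeds=embeds,
    visual_pos_masks=vis_mask,
    deepstack_visual_embeds=deepstack,
    cos=cos,
    sin=sin,
    rope_deltas=deltas,
)
print(logits.shape)  # (1, seq_len, vocab_size)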


# Legacy Model wrapper (for backward compatibility)
class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model = Qwen3VLModel(args)
        if not args.text_config.tie_word_embeddings:
            self.lm_head = nn.Linear(
                args.text_config.hidden_size, args.text_config.vocab_size, bias=False
            )

    def __call__(
        self,
        inputs: mx.array = None,
        mask: mx.array = None,
        cache=None,
        inputs_embeds: Optional[mx.array] = None,
        pixel_values: Optional[mx.array] = None,
        image_grid_thw: Optional[mx.array] = None,
        visual_pos_masks: Optional[mx.array] = None,
        deepstack_visual_embeds: Optional[List[mx.array]] = None,
        cos: Optional[mx.array] = None,
        sin: Optional[mx.array] = None,
        rope_deltas: Optional[mx.array] = None,
    ):
        out = self.model(
            input_ids=inputs,
            inputs_embeds=inputs_embeds,
            attention_mask=mask,
            cache=cache,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            cos=cos,
            sin=sin,
            rope_deltas=rope_deltas,
        )
        if self.args.text_config.tie_word_embeddings:
            return self.model.language_model.embed_tokens.as_linear(out)
        else:
            return self.lm_head(out)

    def sanitize(self, weights):
        sanitized = dict(weights)

        # Tied embeddings: drop lm_head, as embed_tokens doubles as the output head
        if self.args.text_config.tie_word_embeddings:
            sanitized.pop("lm_head.weight", None)

        return sanitized

    @property
    def layers(self):
        return self.model.language_model.layers