nexaai 1.0.18rc1__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19__cp310-cp310-macosx_14_0_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nexaai might be problematic.
- nexaai/_stub.cpython-310-darwin.so +0 -0
- nexaai/_version.py +1 -1
- nexaai/asr.py +2 -1
- nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libggml-base.dylib +0 -0
- nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libmtmd.dylib +0 -0
- nexaai/binds/{nexa_llama_cpp/libllama.dylib → cpu_gpu/libnexa_cpu_gpu.dylib} +0 -0
- nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libnexa_plugin.dylib +0 -0
- nexaai/binds/libnexa_bridge.dylib +0 -0
- nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
- nexaai/binds/{nexa_mlx → metal}/libnexa_plugin.dylib +0 -0
- nexaai/binds/{nexa_nexaml → nexaml}/libggml-base.dylib +0 -0
- nexaai/binds/{nexa_nexaml → nexaml}/libnexa-mm-process.dylib +0 -0
- nexaai/binds/{nexa_nexaml → nexaml}/libnexa-sampling.dylib +0 -0
- nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/{nexa_nexaml → nexaml}/libomp.dylib +0 -0
- nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
- nexaai/cv.py +2 -1
- nexaai/embedder.py +1 -1
- nexaai/image_gen.py +2 -1
- nexaai/llm.py +5 -3
- nexaai/llm_impl/mlx_llm_impl.py +2 -0
- nexaai/llm_impl/pybind_llm_impl.py +2 -0
- nexaai/mlx_backend/vlm/generate_qwen3_vl.py +176 -96
- nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
- nexaai/mlx_backend/vlm/interface.py +99 -30
- nexaai/mlx_backend/vlm/main.py +58 -9
- nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +338 -299
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
- nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
- nexaai/rerank.py +2 -1
- nexaai/tts.py +2 -1
- nexaai/utils/manifest_utils.py +222 -15
- nexaai/utils/model_manager.py +120 -14
- nexaai/utils/model_types.py +2 -0
- nexaai/vlm.py +2 -1
- {nexaai-1.0.18rc1.dist-info → nexaai-1.0.19.dist-info}/METADATA +1 -2
- {nexaai-1.0.18rc1.dist-info → nexaai-1.0.19.dist-info}/RECORD +211 -200
- nexaai/binds/nexa_nexaml/libnexa_plugin.dylib +0 -0
- nexaai/binds/nexa_nexaml/libnexaproc.dylib +0 -0
- nexaai/binds/nexa_nexaml/libqwen3-vl.dylib +0 -0
- nexaai/binds/nexa_nexaml/libqwen3vl-vision.dylib +0 -0
- /nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libggml-cpu.so +0 -0
- /nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libggml-metal.so +0 -0
- /nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libggml.dylib +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/ml.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/activation.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/amp.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/conv.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/resample.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/base.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/dac.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/nn/layers.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/encodec/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/encodec/encodec.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/mimi.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/s3/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/s3/model.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/s3/model_v2.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/s3/utils.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/attention.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/layers.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/snac.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/vq.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/vocos/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/vocos/mel.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/vocos/vocos.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_bigvgan.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_descript.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_encodec.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_mimi.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_s3.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_snac.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_vocos.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/server.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/sts/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/sts/voice_pipeline.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/generate.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/alignment.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/attention.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/audio.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/conformer.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/ctc.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/audio.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/decoding.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/timing.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/whisper.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/writers.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/tests/test_models.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/utils.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/audio_player.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/convert.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/generate.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/bark/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/bark/bark.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/bark/isftnet.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/bark/pipeline.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/base.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/audio.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/config.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/dia.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/layers.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/attention.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/conformer.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/gpt2.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/indextts.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/mel.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/normalize.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/perceiver.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/interpolate.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/modules.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/voice.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/llama/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/llama/llama.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/outetts.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/tokens.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/sesame/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/sesame/attention.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/sesame/sesame.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/sesame/watermarking.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/bicodec.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/residual.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/spark.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/utils/audio.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/utils/file.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/test_base.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/test_convert.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/test_interpolate.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/test_models.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/utils.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/utils.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/version.py +0 -0
- /nexaai/binds/{nexa_mlx → metal}/py-lib/profiling.py +0 -0
- /nexaai/binds/{nexa_nexaml → nexaml}/libfftw3.3.dylib +0 -0
- /nexaai/binds/{nexa_nexaml → nexaml}/libfftw3f.3.dylib +0 -0
- /nexaai/binds/{nexa_nexaml → nexaml}/libggml-cpu.so +0 -0
- /nexaai/binds/{nexa_nexaml → nexaml}/libggml-metal.so +0 -0
- /nexaai/binds/{nexa_nexaml → nexaml}/libggml.dylib +0 -0
- /nexaai/binds/{nexa_nexaml → nexaml}/libmp3lame.0.dylib +0 -0
- /nexaai/binds/{nexa_nexaml → nexaml}/libmpg123.0.dylib +0 -0
- {nexaai-1.0.18rc1.dist-info → nexaai-1.0.19.dist-info}/WHEEL +0 -0
- {nexaai-1.0.18rc1.dist-info → nexaai-1.0.19.dist-info}/top_level.txt +0 -0
nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py

@@ -8,29 +8,13 @@ import mlx.nn as nn
 import math
 import numpy as np
 
-
-import
-
-
-
-
-
-# Try relative imports first, fallback to sys.path approach for Nuitka compatibility
-try:
-    from .llm_common.base import (
-        BaseModelArgs,
-        create_attention_mask,
-        scaled_dot_product_attention,
-    )
-    from .llm_common.rope_utils import initialize_rope
-except ImportError:
-    # Fallback for Nuitka compiled environment
-    from llm_common.base import (
-        BaseModelArgs,
-        create_attention_mask,
-        scaled_dot_product_attention,
-    )
-    from llm_common.rope_utils import initialize_rope
+# Import from nested llm_common structure using relative imports
+from .llm_common.base import (
+    BaseModelArgs,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from .llm_common.rope_utils import initialize_rope
 
 
 @dataclass
@@ -136,28 +120,24 @@ class VisionPatchEmbed(nn.Module):
 
         kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
         self.proj = nn.Conv3d(
-            self.in_channels,
-            self.embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
-            bias=True
+            self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True
         )
 
     def __call__(self, hidden_states: mx.array) -> mx.array:
         target_dtype = self.proj.weight.dtype
-
+
         # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
         # This matches the PyTorch ground truth exactly
         hidden_states = hidden_states.reshape(
             -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
         )
-
+
         # Convert to MLX format: [batch, temporal, height, width, channels]
         hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
-
+
         # Apply conv3d with target dtype and reshape to match PyTorch output
         hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
-
+
         return hidden_states
 
 
@@ -179,20 +159,20 @@ class VisionRotaryEmbedding(nn.Module):
 class VisionPatchMerger(nn.Module):
     def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
         super().__init__()
-        self.hidden_size = config.hidden_size * (config.spatial_merge_size
+        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
         self.use_postshuffle_norm = use_postshuffle_norm
-
+
         norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
-        self.
+        self.norm = nn.LayerNorm(norm_size, eps=1e-6)
         self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
         self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
 
     def __call__(self, x: mx.array) -> mx.array:
         if self.use_postshuffle_norm:
-            x = self.
+            x = self.norm(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
         else:
-            x = self.
-
+            x = self.norm(x).reshape(-1, self.hidden_size)
+
         x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
         return x
 
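Note: in the VisionPatchMerger hunk above, the restored `**2` makes the merger's input width explicit. With spatial_merge_size = 2, each 2x2 block of vision tokens is flattened into a single token, so the merger's linear layers see config.hidden_size * 4 features. A minimal arithmetic sketch, with an illustrative hidden size that is not taken from the package's actual config:

    # hypothetical values, for illustration only
    hidden_size = 1152                          # assumed per-token vision feature width
    spatial_merge_size = 2
    merged_width = hidden_size * spatial_merge_size**2
    print(merged_width)                         # 4608 features enter linear_fc1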
@@ -203,8 +183,8 @@ class VisionAttention(nn.Module):
         self.dim = config.hidden_size
         self.num_heads = config.num_heads
         self.head_dim = self.dim // self.num_heads
-        self.scaling = self.head_dim
-
+        self.scaling = self.head_dim**-0.5
+
         self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
         self.proj = nn.Linear(self.dim, self.dim)
 
@@ -220,51 +200,48 @@ class VisionAttention(nn.Module):
         qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
         qkv = qkv.transpose(1, 0, 2, 3)
         query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
-
+
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb_vision(
-            query_states, key_states, cos, sin
-        )
+        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
         query_states = query_states.transpose(1, 0, 2)
         key_states = key_states.transpose(1, 0, 2)
         value_states = value_states.transpose(1, 0, 2)
-
+
         query_states = mx.expand_dims(query_states, axis=0)
         key_states = mx.expand_dims(key_states, axis=0)
         value_states = mx.expand_dims(value_states, axis=0)
-
+
         lengths = cu_seqlens[1:] - cu_seqlens[:-1]
-
+
         split_indices = []
         cumsum = 0
         for length in lengths[:-1]:
             cumsum += int(length)
             split_indices.append(cumsum)
-
+
         if split_indices:
             q_splits = mx.split(query_states, split_indices, axis=1)
             k_splits = mx.split(key_states, split_indices, axis=1)
             v_splits = mx.split(value_states, split_indices, axis=1)
         else:
             q_splits = [query_states]
-            k_splits = [key_states]
+            k_splits = [key_states]
             v_splits = [value_states]
-
+
         attn_outputs = []
         for q, k, v in zip(q_splits, k_splits, v_splits):
             attn_out = scaled_dot_product_attention(
-                q, k, v,
-                scale=self.scaling, mask=None, cache=None
+                q, k, v, scale=self.scaling, mask=None, cache=None
             )
             attn_outputs.append(attn_out)
-
+
         attn_output = mx.concatenate(attn_outputs, axis=1)
-
+
         attn_output = attn_output[0].transpose(1, 0, 2)
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
-
+
         return attn_output
 
 
@@ -300,7 +277,7 @@ class VisionModel(nn.Module):
 
         self.patch_embed = VisionPatchEmbed(config)
         self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
-        self.num_grid_per_side = int(config.num_position_embeddings
+        self.num_grid_per_side = int(config.num_position_embeddings**0.5)
 
         head_dim = config.hidden_size // config.num_heads
         self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -326,7 +303,7 @@ class VisionModel(nn.Module):
             num_frames = int(grid_thw[i, 0].item())
             height = int(grid_thw[i, 1].item())
             width = int(grid_thw[i, 2].item())
-
+
             merged_h, merged_w = height // merge_size, width // merge_size
 
             block_rows = mx.arange(merged_h)  # block row indices
@@ -338,8 +315,12 @@ class VisionModel(nn.Module):
             row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
             col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
 
-            row_idx = mx.broadcast_to(
-
+            row_idx = mx.broadcast_to(
+                row_idx, (merged_h, merged_w, merge_size, merge_size)
+            ).reshape(-1)
+            col_idx = mx.broadcast_to(
+                col_idx, (merged_h, merged_w, merge_size, merge_size)
+            ).reshape(-1)
 
             coords = mx.stack([row_idx, col_idx], axis=-1)
 
@@ -350,19 +331,19 @@ class VisionModel(nn.Module):
 
         # Concatenate all coordinate parts
         pos_ids = mx.concatenate(pos_ids_parts, axis=0)
-
+
         embeddings = freq_table[pos_ids]  # lookup rotary embeddings
         embeddings = embeddings.reshape(embeddings.shape[0], -1)
         return embeddings
 
     def fast_pos_embed_interpolate(self, grid_thw: mx.array):
         patch_pos_embeds = []
-
+
         for i in range(grid_thw.shape[0]):
             t = int(grid_thw[i, 0].item())
             h = int(grid_thw[i, 1].item())
             w = int(grid_thw[i, 2].item())
-
+
             # Simple position embedding interpolation
             h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
             w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
@@ -399,37 +380,41 @@ class VisionModel(nn.Module):
 
             # Repeat for temporal dimension and apply spatial merging
             pos_embed = mx.tile(pos_embed, (t, 1))
-
+
             # Apply spatial merging pattern
             merge_size = self.config.spatial_merge_size
-            pos_embed = pos_embed.reshape(
+            pos_embed = pos_embed.reshape(
+                t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+            )
             pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
             pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
-
+
             patch_pos_embeds.append(pos_embed)
-
+
         return mx.concatenate(patch_pos_embeds, axis=0)
 
-    def __call__(
+    def __call__(
+        self, hidden_states: mx.array, grid_thw: mx.array
+    ) -> Tuple[mx.array, List[mx.array]]:
         hidden_states = self.patch_embed(hidden_states)
-
+
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
         hidden_states = hidden_states + pos_embeds
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         seq_len = hidden_states.shape[0]
-
+
         emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
         position_embeddings = (mx.cos(emb), mx.sin(emb))
 
-
+        # Create cumulative sequence lengths (following HuggingFace implementation)
         # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
         seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2]  # h * w for each image
         seq_lens = []
         for i, (seq_len, repeats) in enumerate(zip(seq_lens_per_image, grid_thw[:, 0])):
             seq_lens.extend([seq_len] * int(repeats))
         seq_lens = mx.array(seq_lens)
-
+
         # Then compute cumulative sum
         cu_seqlens = mx.cumsum(seq_lens)
         # Pad with 0 at the beginning
@@ -457,7 +442,7 @@ class TextRotaryEmbedding(nn.Module):
         self.config = config
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
-
+
         # MRoPE configuration
         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
             self.rope_type = config.rope_scaling.get("rope_type", "default")
@@ -465,17 +450,19 @@ class TextRotaryEmbedding(nn.Module):
         else:
             self.rope_type = "default"
             self.mrope_section = [24, 20, 20]
-
+
         # Store parameters for computing inv_freq on the fly
         self.head_dim = config.head_dim
         self.theta = config.rope_theta
-
+
         # Attention scaling (simplified - may need adjustment based on actual config)
         self.attention_scaling = 1.0
 
     def _get_inv_freq(self):
         """Compute inverse frequencies on the fly"""
-        inv_freq = 1.0 / (
+        inv_freq = 1.0 / (
+            self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim)
+        )
        # Expand for 3 dimensions (T, H, W)
         return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
 
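Note: the reconstructed _get_inv_freq hunk implements the standard RoPE frequency schedule, inv_freq[j] = theta**(-j/head_dim) over even channel indices j, broadcast to three identical rows so the temporal, height, and width position streams share one frequency table. A runnable numpy sketch of the same computation (head_dim and theta here are illustrative, not the model's configured values):

    import numpy as np

    head_dim, theta = 128, 1_000_000.0                  # assumed values, not from the package
    j = np.arange(0, head_dim, 2, dtype=np.float32)     # even channel indices 0, 2, ..., 126
    inv_freq = 1.0 / theta ** (j / head_dim)            # shape (head_dim // 2,)
    inv_freq_3d = np.broadcast_to(inv_freq, (3, inv_freq.size))  # one row per T/H/W axis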
@@ -501,36 +488,38 @@ class TextRotaryEmbedding(nn.Module):
         Args:
             x: Input tensor for dtype reference
             position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
-
+
         Returns:
             cos, sin: Cosine and sine embeddings
         """
         # Handle 2D position_ids by expanding to 3D for MRoPE
         if position_ids.ndim == 2:
-            position_ids = mx.broadcast_to(
-
+            position_ids = mx.broadcast_to(
+                position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1])
+            )
+
         batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
-
+
         # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
         inv_freq_expanded = mx.broadcast_to(
-            self._get_inv_freq()[:, None, None, :],
-            (3, batch_size, 1, self._get_inv_freq().shape[-1])
+            self._get_inv_freq()[:, None, None, :],
+            (3, batch_size, 1, self._get_inv_freq().shape[-1]),
         )
-
+
         # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
         position_ids_expanded = position_ids[..., None].astype(mx.float32)
-
+
         # Compute frequencies: (3, batch_size, seq_len, dim//2)
         freqs = inv_freq_expanded * position_ids_expanded
-
+
         # Apply interleaved MRoPE
         freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
-
+
         # Create embeddings
         emb = mx.concatenate([freqs, freqs], axis=-1)  # (batch_size, seq_len, head_dim)
         cos = mx.cos(emb) * self.attention_scaling
         sin = mx.sin(emb) * self.attention_scaling
-
+
         return cos.astype(x.dtype), sin.astype(x.dtype)
 
 
@@ -539,12 +528,12 @@ class TextAttention(nn.Module):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
-
+
         dim = config.hidden_size
         self.n_heads = config.num_attention_heads
         self.n_kv_heads = config.num_key_value_heads
         self.head_dim = config.head_dim
-        self.scale = self.head_dim
+        self.scale = self.head_dim**-0.5
 
         self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
         self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
@@ -553,7 +542,7 @@ class TextAttention(nn.Module):
 
         self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
         self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
+
         # Initialize rope directly
         self.rope = initialize_rope(
             config.head_dim,
@@ -589,8 +578,23 @@ class TextAttention(nn.Module):
             keys, values = cache.update_and_fetch(keys, values)
         else:
             if cache is not None:
-
-
+                # Handle different types of rope_deltas: scalar, array, or None
+                if rope_deltas is None:
+                    offset_delta = 0
+                elif isinstance(rope_deltas, (int, float)):
+                    # rope_deltas is a scalar
+                    offset_delta = rope_deltas
+                elif hasattr(rope_deltas, 'size') and rope_deltas.size == 1:
+                    # rope_deltas is an array with single element
+                    offset_delta = rope_deltas.item()
+                elif hasattr(rope_deltas, 'shape') and rope_deltas.shape:
+                    # rope_deltas is an array with multiple elements, take first
+                    offset_delta = rope_deltas.reshape(-1)[0].item()
+                else:
+                    offset_delta = 0
+
+                queries = self.rope(queries, offset=cache.offset + offset_delta)
+                keys = self.rope(keys, offset=cache.offset + offset_delta)
                 keys, values = cache.update_and_fetch(keys, values)
             else:
                 queries = self.rope(queries)
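Note: the block above normalizes whatever rope_deltas arrives as (None, a Python scalar, a one-element array, or a longer array) into one plain offset before shifting the RoPE cache offset. The same logic as a standalone, runnable sketch, using numpy in place of mlx and a hypothetical helper name:

    import numpy as np

    def normalize_rope_delta(rope_deltas):
        # Collapse None / scalar / array inputs to one plain offset value.
        if rope_deltas is None:
            return 0
        if isinstance(rope_deltas, (int, float)):
            return rope_deltas
        if hasattr(rope_deltas, "size") and rope_deltas.size == 1:
            return rope_deltas.item()                  # single-element array
        if hasattr(rope_deltas, "shape") and rope_deltas.shape:
            return rope_deltas.reshape(-1)[0].item()   # take the first element
        return 0

    assert normalize_rope_delta(None) == 0
    assert normalize_rope_delta(7) == 7
    assert normalize_rope_delta(np.array([[5]])) == 5
    assert normalize_rope_delta(np.array([3, 9])) == 3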
@@ -634,7 +638,7 @@ class TextDecoderLayer(nn.Module):
     ) -> mx.array:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-
+
         hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
@@ -656,11 +660,10 @@ class TextModel(nn.Module):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
-
+
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [
-            TextDecoderLayer(config, layer_idx)
-            for layer_idx in range(config.num_hidden_layers)
+            TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = TextRotaryEmbedding(config)
@@ -717,7 +720,9 @@ class TextModel(nn.Module):
                 rope_deltas=rope_deltas,
             )
             if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
-                hidden_states = self._deepstack_process(
+                hidden_states = self._deepstack_process(
+                    hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx]
+                )
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
@@ -728,17 +733,17 @@ class VEGModel(nn.Module):
         super().__init__()
         self.config = vision_config
         self.visual = VisionModel(vision_config)
-
+
     def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
         return self.visual(pixel_values, image_grid_thw)
-
+
     def sanitize(self, weights):
         sanitized = {}
         for k, v in weights.items():
-            if
+            if "visual." in k:
                 # Remove prefixes to match our model structure
-                clean_key = k.replace(
-                sanitized[f
+                clean_key = k.replace("model.visual.", "").replace("visual.", "")
+                sanitized[f"visual.{clean_key}"] = v
         return sanitized
 
 
@@ -751,140 +756,164 @@ class LLMModel(nn.Module):
         self.language_model = TextModel(text_config)
         if not text_config.tie_word_embeddings:
             self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
-
+
     def get_rope_index(
-        [... the previous 112-line body of get_rope_index; its content was not preserved in this view ...]
+        self,
+        input_ids: Optional[mx.array] = None,
+        image_grid_thw: Optional[mx.array] = None,
+        attention_mask: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, mx.array]:
+        """Simplified version for images only (no video support)."""
+
+        spatial_merge_size = 2
+        image_token_id = 151655
+        vision_start_token_id = 151652
+        mrope_position_deltas = []
+
+        if input_ids is not None and image_grid_thw is not None:
+            total_input_ids = input_ids
+            if attention_mask is None:
+                attention_mask = mx.ones_like(total_input_ids)
+
+            batch_size, seq_len = input_ids.shape
+            position_ids_list = []
+            image_index = 0
+
+            for i in range(batch_size):
+                input_ids_seq = total_input_ids[i]
+                mask_seq = attention_mask[i]
+
+                # Use mask to get valid length
+                valid_length = int(mx.sum(mask_seq).item())
+                input_ids_seq = input_ids_seq[:valid_length]
+
+                image_nums = 0
+                # Find vision start tokens by iterating through the sequence
+                vision_start_positions = []
+                for pos in range(input_ids_seq.shape[0]):
+                    if input_ids_seq[pos].item() == vision_start_token_id:
+                        vision_start_positions.append(pos)
+
+                if len(vision_start_positions) > 0:
+                    for pos in vision_start_positions:
+                        if pos + 1 < input_ids_seq.shape[0]:
+                            if input_ids_seq[pos + 1].item() == image_token_id:
+                                image_nums += 1
+
+                input_tokens = input_ids_seq.tolist()
+                llm_pos_ids_list = []
+                st = 0
+                remain_images = image_nums
+
+                for _ in range(image_nums):
+                    ed_image = input_tokens.index(image_token_id, st)
+
+                    t = image_grid_thw[image_index, 0].item()
+                    h = image_grid_thw[image_index, 1].item()
+                    w = image_grid_thw[image_index, 2].item()
+                    image_index += 1
+                    remain_images -= 1
+                    ed = ed_image
+
+                    llm_grid_t = int(t)
+                    llm_grid_h = int(h) // spatial_merge_size
+                    llm_grid_w = int(w) // spatial_merge_size
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    )
+                    text_pos = mx.arange(text_len).reshape(1, -1)
+                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+                    llm_pos_ids_list.append(text_pos)
+
+                    # t_index is always 0 because llm_grid_t is always 1 for images
+                    t_index = mx.arange(llm_grid_t).reshape(-1, 1)
+                    t_index = mx.broadcast_to(
+                        t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+                    ).reshape(-1)
+
+                    h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
+                    h_index = mx.broadcast_to(
+                        h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    ).reshape(-1)
+
+                    w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
+                    w_index = mx.broadcast_to(
+                        w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+                    ).reshape(-1)
+
+                    vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    llm_pos_ids_list.append(vision_pos)
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    text_pos = mx.arange(text_len).reshape(1, -1)
+                    text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+                    llm_pos_ids_list.append(text_pos)
+
+                llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+
+                # Create position_ids for this batch item, pad to seq_len
+                batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
+                valid_length = min(seq_len, llm_positions.shape[1])
+
+                # Create new arrays for each dimension
+                pos_dim0 = mx.concatenate(
+                    [
+                        llm_positions[0, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+                pos_dim1 = mx.concatenate(
+                    [
+                        llm_positions[1, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+                pos_dim2 = mx.concatenate(
+                    [
+                        llm_positions[2, :valid_length],
+                        mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+                    ]
+                )
+
+                batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
+                position_ids_list.append(batch_position_ids)
+
+                mrope_position_deltas.append(
+                    llm_positions.max().item() + 1 - len(total_input_ids[i])
+                )
+
+            # Stack all batch position_ids
+            position_ids = mx.stack(position_ids_list, axis=1)  # Shape: (3, batch_size, seq_len)
+            mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
+            return position_ids, mrope_position_deltas
+        else:
+            if attention_mask is not None:
+                position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
+                position_ids = mx.where(attention_mask == 0, 1, position_ids)
+                position_ids = mx.expand_dims(position_ids, axis=0)
+                position_ids = mx.broadcast_to(
+                    position_ids, (3, position_ids.shape[1], position_ids.shape[2])
+                )
+                max_position_ids = mx.max(
+                    mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True
+                )
+                mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
             else:
-                [... 8 lines not preserved in this view ...]
-                mrope_position_deltas = mx.reshape(mrope_position_deltas, (-1,))
-            else:
-                seq_len = input_ids.shape[1]
-                batch_size = input_ids.shape[0]
-                position_ids = mx.arange(seq_len).reshape(1, 1, -1)
-                position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
-                # 1D zeros for rope deltas
-                mrope_position_deltas = mx.zeros((batch_size,), dtype=input_ids.dtype)
-
-            return position_ids, mrope_position_deltas
-
+                seq_len = input_ids.shape[1]
+                batch_size = input_ids.shape[0]
+                position_ids = mx.arange(seq_len).reshape(1, 1, -1)
+                position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
+                mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
+
+            return position_ids, mrope_position_deltas
+
     def __call__(
         self,
         inputs: mx.array = None,
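Note: intuitively, the new get_rope_index gives text tokens a single shared position on all three MRoPE axes and gives image tokens (t, h, w) grid coordinates offset past the preceding text. A hand-worked trace for a toy sequence of 3 text tokens followed by one image whose merged grid is (t, h, w) = (1, 2, 2), i.e. 4 image tokens:

    # text tokens:  T/H/W positions 0, 1, 2 on every axis
    # image tokens: offset by text_len + st_idx = 3
    #   t axis: 3, 3, 3, 3      (llm_grid_t == 1, so t_index is all zeros)
    #   h axis: 3, 3, 4, 4      (row index + 3)
    #   w axis: 3, 4, 3, 4      (column index + 3)
    # any trailing text resumes at max position + 1 = 5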
@@ -912,35 +941,41 @@ class LLMModel(nn.Module):
             return self.language_model.embed_tokens.as_linear(out)
         else:
             return self.lm_head(out)
-
+
     def sanitize(self, weights):
         sanitized = {}
         for k, v in weights.items():
-            if not (
+            if not ("visual." in k):
                 # Handle key mapping from combined model to LLM-only model
                 clean_key = k
-
+
                 # Remove model. prefix if present
-                if clean_key.startswith(
+                if clean_key.startswith("model."):
                     clean_key = clean_key[6:]  # Remove 'model.'
-
+
                 # Map language_ prefixed keys to language_model structure
-                if clean_key.startswith(
-                    if clean_key.startswith(
-                        clean_key =
-
-
-                    elif clean_key.startswith(
-                        clean_key =
-
+                if clean_key.startswith("language_"):
+                    if clean_key.startswith("language_layers."):
+                        clean_key = (
+                            "language_model.layers." + clean_key[16:]
+                        )  # Map to language_model.layers.
+                    elif clean_key.startswith("language_embed_tokens."):
+                        clean_key = (
+                            "language_model.embed_tokens." + clean_key[22:]
+                        )  # Map to language_model.embed_tokens.
+                    elif clean_key.startswith("language_norm."):
+                        clean_key = (
+                            "language_model.norm." + clean_key[14:]
+                        )  # Map to language_model.norm.
+
                 sanitized[clean_key] = v
-
+
         # Handle tied embeddings - remove lm_head if using tied embeddings
         if self.args.tie_word_embeddings:
             sanitized.pop("lm_head.weight", None)
-
+
         return sanitized
-
+
     @property
     def layers(self):
         return self.language_model.layers
@@ -954,39 +989,36 @@ class Qwen3VLModel(nn.Module):
         self.config = args
         self.visual = VisionModel(args.vision_config)
         self.language_model = TextModel(args.text_config)
-
+
     def sanitize(self, weights):
         # Map weights to match the combined model structure
         sanitized = {}
         for k, v in weights.items():
             # Remove 'model.' prefix if present to match our structure
-            clean_key = k.replace(
+            clean_key = k.replace("model.", "") if k.startswith("model.") else k
             sanitized[clean_key] = v
         return sanitized
 
-    def get_image_features(
-        self,
-        pixel_values: mx.array,
-        image_grid_thw: Optional[mx.array] = None
-    ):
+    def get_image_features(self, pixel_values: mx.array, image_grid_thw: Optional[mx.array] = None):
         image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
         # Split based on grid dimensions
         if image_grid_thw is not None:
-            split_sizes = (
+            split_sizes = (
+                mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size**2)
+            ).tolist()
             # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
             split_indices = []
             cumsum = 0
             for size in split_sizes[:-1]:  # Exclude last element
                 cumsum += size
                 split_indices.append(cumsum)
-
+
             if split_indices:  # Only split if we have indices
                 image_embeds = mx.split(image_embeds, split_indices)
             else:
                 image_embeds = [image_embeds]  # Single image case
         return image_embeds, deepstack_visual_embeds
 
-
     def __call__(
         self,
         input_ids: mx.array = None,
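Note: the token-count arithmetic in get_image_features follows directly from the patch merger: an image with grid (t, h, w) produces t*h*w raw patches but only t*h*w // merge_size**2 merged tokens. A worked example assuming spatial_merge_size = 2 and an illustrative grid:

    # image_grid_thw row (t, h, w) = (1, 16, 24)
    # raw patches   = 1 * 16 * 24       = 384
    # merged tokens = 384 // (2 ** 2)   = 96
    # so mx.split receives cumulative indices [96, 192, ...] for a batch of such images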
@@ -1005,26 +1037,25 @@ class Qwen3VLModel(nn.Module):
         inputs_embeds = self.language_model.embed_tokens(input_ids)
 
         # Process images
-
+
         if pixel_values is not None:
             image_embeds, deepstack_visual_embeds = self.get_image_features(
                 pixel_values, image_grid_thw
             )
-
+
             # Create masks and embed visual features
             if isinstance(image_embeds, list):
                 image_embeds = mx.concatenate(image_embeds, axis=0)
-
+
             # Find image token positions and replace with visual embeddings
-            image_mask =
+            image_mask = input_ids == self.args.image_token_id
             visual_pos_masks = image_mask
-
+
             # Replace image tokens with visual embeddings
             inputs_embeds = inputs_embeds.at[image_mask].set(
                 image_embeds.astype(inputs_embeds.dtype)
             )
 
-
         outputs = self.language_model(
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -1042,28 +1073,28 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
 def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
     """
     Handle the processing of multimodal embeddings including image features and position encoding.
-
+
     This function processes vision and text inputs to create unified embeddings that can be fed
     into the language model. It handles:
     - Vision feature extraction from pixel values
     - Deepstack visual embedding collection
     - Image token replacement in text embeddings
     - Position encoding setup for MRoPE (Multi-dimensional RoPE)
-
+
     Args:
         vision_model: The vision encoder model (VEGModel instance)
-        llm_model: The language model (LLMModel instance)
+        llm_model: The language model (LLMModel instance)
         input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
         pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
         image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
-
+
     Returns:
         tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
         - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
        - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
         - visual_pos_masks: Boolean mask indicating image token positions
         - cos: Cosine values for rotary position encoding
-        - sin: Sine values for rotary position encoding
+        - sin: Sine values for rotary position encoding
         - rope_deltas: Position offset deltas for rope computation
     """
     inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
@@ -1072,74 +1103,80 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
     cos = None
     sin = None
     rope_deltas = 0
-
+
     if pixel_values is not None:
         if pixel_values.ndim == 4:
             pixel_values = mx.expand_dims(pixel_values, axis=2)
-
+
         # Process each image individually to prevent feature mixing
         image_embeds_list = []
         all_deepstack_embeds = []
-
+
         # Calculate cumulative indices for each image
         cumulative_patches = 0
-
+
         for i in range(image_grid_thw.shape[0]):
             # Calculate number of patches for current image
             current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
             start_idx = cumulative_patches
             end_idx = cumulative_patches + current_patches
             cumulative_patches += current_patches
-
+
             single_pixel_values = pixel_values[start_idx:end_idx]
-            single_grid_thw = image_grid_thw[i:i+1]
-
+            single_grid_thw = image_grid_thw[i : i + 1]
+
             # Use vision model directly
             single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
-
+
             # Split based on grid dimensions
             if single_grid_thw is not None:
-                split_sizes = (
+                split_sizes = (
+                    mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size**2)
+                ).tolist()
                 split_indices = []
                 cumsum = 0
                 for size in split_sizes[:-1]:
                     cumsum += size
                     split_indices.append(cumsum)
-
+
                 if split_indices:
                     single_embeds = mx.split(single_embeds, split_indices)
                 else:
                     single_embeds = [single_embeds]
-
+
             image_embeds_list.extend(single_embeds)
-
+
             # Collect deepstack embeddings
             if i == 0:
                 all_deepstack_embeds = single_deepstack
             else:
                 # Concatenate deepstack embeddings from different images
                 for j in range(len(all_deepstack_embeds)):
-                    all_deepstack_embeds[j] = mx.concatenate(
-
+                    all_deepstack_embeds[j] = mx.concatenate(
+                        [all_deepstack_embeds[j], single_deepstack[j]], axis=0
+                    )
+
         deepstack_visual_embeds = all_deepstack_embeds
-
+
         # Concatenate all image embeddings for processing
         image_embeds = mx.concatenate(image_embeds_list, axis=0)
-
+
         # Find all image token positions
         image_token_id = 151655  # Default image token ID
-        image_mask =
+        image_mask = input_ids.squeeze(0) == image_token_id
         image_mask_np = np.array(image_mask)
         image_token_positions = np.where(image_mask_np)[0]
-
+
         # Verify we have the correct number of image tokens
         expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
-        assert
-
+        assert (
+            len(image_token_positions) == expected_total_tokens
+        ), f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
+
         # Replace image tokens with image embeddings
         seq_len = inputs_embeds.shape[0]
         result = inputs_embeds
-
+
         # Replace image tokens with image embeddings sequentially
         embed_idx = 0
         for img_embed in image_embeds_list:
@@ -1149,7 +1186,7 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
             result = mx.where(
                 mx.expand_dims(pos_mask, axis=-1),
                 mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
-                result
+                result,
             )
             embed_idx += 1
 
@@ -1158,10 +1195,10 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
     cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
     if inputs_embeds.ndim == 2:
         inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
-
+
     if image_mask is not None:
         visual_pos_masks = image_mask
-
+
     return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
 
 
@@ -1172,7 +1209,9 @@ class Model(nn.Module):
         self.args = args
         self.model = Qwen3VLModel(args)
         if not args.text_config.tie_word_embeddings:
-            self.lm_head = nn.Linear(
+            self.lm_head = nn.Linear(
+                args.text_config.hidden_size, args.text_config.vocab_size, bias=False
+            )
 
@@ -1180,7 +1219,7 @@ class Model(nn.Module):
         mask: mx.array = None,
         cache=None,
         inputs_embeds: Optional[mx.array] = None,
-        pixel_values: Optional[mx.array] = None,
+        pixel_values: Optional[mx.array] = None,
         image_grid_thw: Optional[mx.array] = None,
         visual_pos_masks: Optional[mx.array] = None,
         deepstack_visual_embeds: Optional[List[mx.array]] = None,
@@ -1211,13 +1250,13 @@ class Model(nn.Module):
         sanitized = {}
         for k, v in weights.items():
             sanitized[k] = v
-
+
         # Handle tied embeddings - remove lm_head if using tied embeddings
         if self.args.text_config.tie_word_embeddings:
             sanitized.pop("lm_head.weight", None)
-
+
         return sanitized
 
     @property
     def layers(self):
-        return self.model.language_model.layers
+        return self.model.language_model.layers