nexaai 1.0.18rc1__cp310-cp310-macosx_14_0_universal2.whl → 1.0.19__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic.

Files changed (215)
  1. nexaai/_stub.cpython-310-darwin.so +0 -0
  2. nexaai/_version.py +1 -1
  3. nexaai/asr.py +2 -1
  4. nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libggml-base.dylib +0 -0
  5. nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libmtmd.dylib +0 -0
  6. nexaai/binds/{nexa_llama_cpp/libllama.dylib → cpu_gpu/libnexa_cpu_gpu.dylib} +0 -0
  7. nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libnexa_plugin.dylib +0 -0
  8. nexaai/binds/libnexa_bridge.dylib +0 -0
  9. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  10. nexaai/binds/{nexa_mlx → metal}/libnexa_plugin.dylib +0 -0
  11. nexaai/binds/{nexa_nexaml → nexaml}/libggml-base.dylib +0 -0
  12. nexaai/binds/{nexa_nexaml → nexaml}/libnexa-mm-process.dylib +0 -0
  13. nexaai/binds/{nexa_nexaml → nexaml}/libnexa-sampling.dylib +0 -0
  14. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  15. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  16. nexaai/binds/{nexa_nexaml → nexaml}/libomp.dylib +0 -0
  17. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  18. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  19. nexaai/cv.py +2 -1
  20. nexaai/embedder.py +1 -1
  21. nexaai/image_gen.py +2 -1
  22. nexaai/llm.py +5 -3
  23. nexaai/llm_impl/mlx_llm_impl.py +2 -0
  24. nexaai/llm_impl/pybind_llm_impl.py +2 -0
  25. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +176 -96
  26. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  27. nexaai/mlx_backend/vlm/interface.py +99 -30
  28. nexaai/mlx_backend/vlm/main.py +58 -9
  29. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +338 -299
  30. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  31. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  32. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  33. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  34. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  35. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  36. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  37. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  38. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  39. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  40. nexaai/rerank.py +2 -1
  41. nexaai/tts.py +2 -1
  42. nexaai/utils/manifest_utils.py +222 -15
  43. nexaai/utils/model_manager.py +120 -14
  44. nexaai/utils/model_types.py +2 -0
  45. nexaai/vlm.py +2 -1
  46. {nexaai-1.0.18rc1.dist-info → nexaai-1.0.19.dist-info}/METADATA +1 -2
  47. {nexaai-1.0.18rc1.dist-info → nexaai-1.0.19.dist-info}/RECORD +211 -200
  48. nexaai/binds/nexa_nexaml/libnexa_plugin.dylib +0 -0
  49. nexaai/binds/nexa_nexaml/libnexaproc.dylib +0 -0
  50. nexaai/binds/nexa_nexaml/libqwen3-vl.dylib +0 -0
  51. nexaai/binds/nexa_nexaml/libqwen3vl-vision.dylib +0 -0
  52. /nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libggml-cpu.so +0 -0
  53. /nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libggml-metal.so +0 -0
  54. /nexaai/binds/{nexa_llama_cpp → cpu_gpu}/libggml.dylib +0 -0
  55. /nexaai/binds/{nexa_mlx → metal}/py-lib/ml.py +0 -0
  56. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/__init__.py +0 -0
  57. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/__init__.py +0 -0
  58. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/__init__.py +0 -0
  59. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +0 -0
  60. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/activation.py +0 -0
  61. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/amp.py +0 -0
  62. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +0 -0
  63. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/conv.py +0 -0
  64. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/bigvgan/resample.py +0 -0
  65. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/__init__.py +0 -0
  66. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/base.py +0 -0
  67. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/dac.py +0 -0
  68. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +0 -0
  69. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/nn/layers.py +0 -0
  70. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +0 -0
  71. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/encodec/__init__.py +0 -0
  72. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/encodec/encodec.py +0 -0
  73. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/__init__.py +0 -0
  74. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/mimi.py +0 -0
  75. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +0 -0
  76. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +0 -0
  77. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +0 -0
  78. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +0 -0
  79. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +0 -0
  80. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +0 -0
  81. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/s3/__init__.py +0 -0
  82. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/s3/model.py +0 -0
  83. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/s3/model_v2.py +0 -0
  84. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/s3/utils.py +0 -0
  85. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/__init__.py +0 -0
  86. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/attention.py +0 -0
  87. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/layers.py +0 -0
  88. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/snac.py +0 -0
  89. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/snac/vq.py +0 -0
  90. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/vocos/__init__.py +0 -0
  91. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/vocos/mel.py +0 -0
  92. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/models/vocos/vocos.py +0 -0
  93. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  94. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_bigvgan.py +0 -0
  95. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_descript.py +0 -0
  96. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_encodec.py +0 -0
  97. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_mimi.py +0 -0
  98. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_s3.py +0 -0
  99. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_snac.py +0 -0
  100. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/codec/tests/test_vocos.py +0 -0
  101. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/server.py +0 -0
  102. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/sts/__init__.py +0 -0
  103. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +0 -0
  104. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/sts/voice_pipeline.py +0 -0
  105. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/__init__.py +0 -0
  106. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/generate.py +0 -0
  107. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  108. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/__init__.py +0 -0
  109. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/alignment.py +0 -0
  110. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/attention.py +0 -0
  111. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/audio.py +0 -0
  112. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/conformer.py +0 -0
  113. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/ctc.py +0 -0
  114. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +0 -0
  115. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +0 -0
  116. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +0 -0
  117. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +0 -0
  118. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +0 -0
  119. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/__init__.py +0 -0
  120. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/audio.py +0 -0
  121. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/decoding.py +0 -0
  122. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/timing.py +0 -0
  123. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +0 -0
  124. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/whisper.py +0 -0
  125. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/models/whisper/writers.py +0 -0
  126. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/tests/test_models.py +0 -0
  127. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/stt/utils.py +0 -0
  128. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/__init__.py +0 -0
  129. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/audio_player.py +0 -0
  130. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/convert.py +0 -0
  131. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/generate.py +0 -0
  132. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  133. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/bark/__init__.py +0 -0
  134. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/bark/bark.py +0 -0
  135. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/bark/isftnet.py +0 -0
  136. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/bark/pipeline.py +0 -0
  137. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/base.py +0 -0
  138. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/__init__.py +0 -0
  139. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/audio.py +0 -0
  140. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/config.py +0 -0
  141. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/dia.py +0 -0
  142. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/dia/layers.py +0 -0
  143. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/__init__.py +0 -0
  144. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/attention.py +0 -0
  145. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +0 -0
  146. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/conformer.py +0 -0
  147. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  148. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +0 -0
  149. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +0 -0
  150. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +0 -0
  151. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +0 -0
  152. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/gpt2.py +0 -0
  153. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/indextts.py +0 -0
  154. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/mel.py +0 -0
  155. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/normalize.py +0 -0
  156. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/indextts/perceiver.py +0 -0
  157. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/interpolate.py +0 -0
  158. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/__init__.py +0 -0
  159. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +0 -0
  160. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +0 -0
  161. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/modules.py +0 -0
  162. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +0 -0
  163. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/kokoro/voice.py +0 -0
  164. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/llama/__init__.py +0 -0
  165. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/llama/llama.py +0 -0
  166. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/__init__.py +0 -0
  167. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +0 -0
  168. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +0 -0
  169. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/outetts.py +0 -0
  170. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +0 -0
  171. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/outetts/tokens.py +0 -0
  172. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/sesame/__init__.py +0 -0
  173. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/sesame/attention.py +0 -0
  174. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/sesame/sesame.py +0 -0
  175. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/sesame/watermarking.py +0 -0
  176. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/__init__.py +0 -0
  177. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +0 -0
  178. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/bicodec.py +0 -0
  179. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  180. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  181. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +0 -0
  182. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  183. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +0 -0
  184. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +0 -0
  185. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +0 -0
  186. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +0 -0
  187. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/residual.py +0 -0
  188. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +0 -0
  189. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +0 -0
  190. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +0 -0
  191. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +0 -0
  192. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +0 -0
  193. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +0 -0
  194. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/spark.py +0 -0
  195. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/utils/audio.py +0 -0
  196. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/utils/file.py +0 -0
  197. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +0 -0
  198. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  199. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/test_base.py +0 -0
  200. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/test_convert.py +0 -0
  201. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/test_interpolate.py +0 -0
  202. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/tests/test_models.py +0 -0
  203. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/tts/utils.py +0 -0
  204. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/utils.py +0 -0
  205. /nexaai/binds/{nexa_mlx → metal}/py-lib/mlx_audio/version.py +0 -0
  206. /nexaai/binds/{nexa_mlx → metal}/py-lib/profiling.py +0 -0
  207. /nexaai/binds/{nexa_nexaml → nexaml}/libfftw3.3.dylib +0 -0
  208. /nexaai/binds/{nexa_nexaml → nexaml}/libfftw3f.3.dylib +0 -0
  209. /nexaai/binds/{nexa_nexaml → nexaml}/libggml-cpu.so +0 -0
  210. /nexaai/binds/{nexa_nexaml → nexaml}/libggml-metal.so +0 -0
  211. /nexaai/binds/{nexa_nexaml → nexaml}/libggml.dylib +0 -0
  212. /nexaai/binds/{nexa_nexaml → nexaml}/libmp3lame.0.dylib +0 -0
  213. /nexaai/binds/{nexa_nexaml → nexaml}/libmpg123.0.dylib +0 -0
  214. {nexaai-1.0.18rc1.dist-info → nexaai-1.0.19.dist-info}/WHEEL +0 -0
  215. {nexaai-1.0.18rc1.dist-info → nexaai-1.0.19.dist-info}/top_level.txt +0 -0
@@ -8,29 +8,13 @@ import mlx.nn as nn
  import math
  import numpy as np
 
- import os
- import sys
-
- curr_dir = os.path.dirname(os.path.abspath(__file__))
- llm_common_dir = os.path.join(curr_dir, "..", "..")
- sys.path.append(llm_common_dir)
-
- # Try relative imports first, fallback to sys.path approach for Nuitka compatibility
- try:
- from .llm_common.base import (
- BaseModelArgs,
- create_attention_mask,
- scaled_dot_product_attention,
- )
- from .llm_common.rope_utils import initialize_rope
- except ImportError:
- # Fallback for Nuitka compiled environment
- from llm_common.base import (
- BaseModelArgs,
- create_attention_mask,
- scaled_dot_product_attention,
- )
- from llm_common.rope_utils import initialize_rope
+ # Import from nested llm_common structure using relative imports
+ from .llm_common.base import (
+ BaseModelArgs,
+ create_attention_mask,
+ scaled_dot_product_attention,
+ )
+ from .llm_common.rope_utils import initialize_rope
 
 
  @dataclass
@@ -136,28 +120,24 @@ class VisionPatchEmbed(nn.Module):
 
  kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
  self.proj = nn.Conv3d(
- self.in_channels,
- self.embed_dim,
- kernel_size=kernel_size,
- stride=kernel_size,
- bias=True
+ self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True
  )
 
  def __call__(self, hidden_states: mx.array) -> mx.array:
  target_dtype = self.proj.weight.dtype
-
+
  # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
  # This matches the PyTorch ground truth exactly
  hidden_states = hidden_states.reshape(
  -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
  )
-
+
  # Convert to MLX format: [batch, temporal, height, width, channels]
  hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
-
+
  # Apply conv3d with target dtype and reshape to match PyTorch output
  hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
-
+
  return hidden_states
 
 
@@ -179,20 +159,20 @@ class VisionRotaryEmbedding(nn.Module):
  class VisionPatchMerger(nn.Module):
  def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
  super().__init__()
- self.hidden_size = config.hidden_size * (config.spatial_merge_size ** 2)
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
  self.use_postshuffle_norm = use_postshuffle_norm
-
+
  norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
- self.ln_q = nn.LayerNorm(norm_size, eps=1e-6)
+ self.norm = nn.LayerNorm(norm_size, eps=1e-6)
  self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
  self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
 
  def __call__(self, x: mx.array) -> mx.array:
  if self.use_postshuffle_norm:
- x = self.ln_q(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
+ x = self.norm(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
  else:
- x = self.ln_q(x).reshape(-1, self.hidden_size)
-
+ x = self.norm(x).reshape(-1, self.hidden_size)
+
  x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
  return x
 
@@ -203,8 +183,8 @@ class VisionAttention(nn.Module):
  self.dim = config.hidden_size
  self.num_heads = config.num_heads
  self.head_dim = self.dim // self.num_heads
- self.scaling = self.head_dim ** -0.5
-
+ self.scaling = self.head_dim**-0.5
+
  self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
  self.proj = nn.Linear(self.dim, self.dim)
 
@@ -220,51 +200,48 @@ class VisionAttention(nn.Module):
  qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
  qkv = qkv.transpose(1, 0, 2, 3)
  query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
-
+
  cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb_vision(
- query_states, key_states, cos, sin
- )
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
 
  query_states = query_states.transpose(1, 0, 2)
  key_states = key_states.transpose(1, 0, 2)
  value_states = value_states.transpose(1, 0, 2)
-
+
  query_states = mx.expand_dims(query_states, axis=0)
  key_states = mx.expand_dims(key_states, axis=0)
  value_states = mx.expand_dims(value_states, axis=0)
-
+
  lengths = cu_seqlens[1:] - cu_seqlens[:-1]
-
+
  split_indices = []
  cumsum = 0
  for length in lengths[:-1]:
  cumsum += int(length)
  split_indices.append(cumsum)
-
+
  if split_indices:
  q_splits = mx.split(query_states, split_indices, axis=1)
  k_splits = mx.split(key_states, split_indices, axis=1)
  v_splits = mx.split(value_states, split_indices, axis=1)
  else:
  q_splits = [query_states]
- k_splits = [key_states]
+ k_splits = [key_states]
  v_splits = [value_states]
-
+
  attn_outputs = []
  for q, k, v in zip(q_splits, k_splits, v_splits):
  attn_out = scaled_dot_product_attention(
- q, k, v,
- scale=self.scaling, mask=None, cache=None
+ q, k, v, scale=self.scaling, mask=None, cache=None
  )
  attn_outputs.append(attn_out)
-
+
  attn_output = mx.concatenate(attn_outputs, axis=1)
-
+
  attn_output = attn_output[0].transpose(1, 0, 2)
  attn_output = attn_output.reshape(seq_length, -1)
  attn_output = self.proj(attn_output)
-
+
  return attn_output
 
 
@@ -300,7 +277,7 @@ class VisionModel(nn.Module):
 
  self.patch_embed = VisionPatchEmbed(config)
  self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
- self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
 
  head_dim = config.hidden_size // config.num_heads
  self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
@@ -326,7 +303,7 @@ class VisionModel(nn.Module):
  num_frames = int(grid_thw[i, 0].item())
  height = int(grid_thw[i, 1].item())
  width = int(grid_thw[i, 2].item())
-
+
  merged_h, merged_w = height // merge_size, width // merge_size
 
  block_rows = mx.arange(merged_h) # block row indices
@@ -338,8 +315,12 @@ class VisionModel(nn.Module):
  row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
  col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
 
- row_idx = mx.broadcast_to(row_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
- col_idx = mx.broadcast_to(col_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
+ row_idx = mx.broadcast_to(
+ row_idx, (merged_h, merged_w, merge_size, merge_size)
+ ).reshape(-1)
+ col_idx = mx.broadcast_to(
+ col_idx, (merged_h, merged_w, merge_size, merge_size)
+ ).reshape(-1)
 
  coords = mx.stack([row_idx, col_idx], axis=-1)
 
@@ -350,19 +331,19 @@ class VisionModel(nn.Module):
 
  # Concatenate all coordinate parts
  pos_ids = mx.concatenate(pos_ids_parts, axis=0)
-
+
  embeddings = freq_table[pos_ids] # lookup rotary embeddings
  embeddings = embeddings.reshape(embeddings.shape[0], -1)
  return embeddings
 
  def fast_pos_embed_interpolate(self, grid_thw: mx.array):
  patch_pos_embeds = []
-
+
  for i in range(grid_thw.shape[0]):
  t = int(grid_thw[i, 0].item())
  h = int(grid_thw[i, 1].item())
  w = int(grid_thw[i, 2].item())
-
+
  # Simple position embedding interpolation
  h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
  w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
@@ -399,37 +380,41 @@ class VisionModel(nn.Module):
 
  # Repeat for temporal dimension and apply spatial merging
  pos_embed = mx.tile(pos_embed, (t, 1))
-
+
  # Apply spatial merging pattern
  merge_size = self.config.spatial_merge_size
- pos_embed = pos_embed.reshape(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+ pos_embed = pos_embed.reshape(
+ t, h // merge_size, merge_size, w // merge_size, merge_size, -1
+ )
  pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
  pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
-
+
  patch_pos_embeds.append(pos_embed)
-
+
  return mx.concatenate(patch_pos_embeds, axis=0)
 
- def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> Tuple[mx.array, List[mx.array]]:
+ def __call__(
+ self, hidden_states: mx.array, grid_thw: mx.array
+ ) -> Tuple[mx.array, List[mx.array]]:
  hidden_states = self.patch_embed(hidden_states)
-
+
  pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
  hidden_states = hidden_states + pos_embeds
 
  rotary_pos_emb = self.rot_pos_emb(grid_thw)
  seq_len = hidden_states.shape[0]
-
+
  emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
  position_embeddings = (mx.cos(emb), mx.sin(emb))
 
- # Create cumulative sequence lengths (following HuggingFace implementation)
+ # Create cumulative sequence lengths (following HuggingFace implementation)
  # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
  seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2] # h * w for each image
  seq_lens = []
  for i, (seq_len, repeats) in enumerate(zip(seq_lens_per_image, grid_thw[:, 0])):
  seq_lens.extend([seq_len] * int(repeats))
  seq_lens = mx.array(seq_lens)
-
+
  # Then compute cumulative sum
  cu_seqlens = mx.cumsum(seq_lens)
  # Pad with 0 at the beginning
@@ -457,7 +442,7 @@ class TextRotaryEmbedding(nn.Module):
  self.config = config
  self.max_seq_len_cached = config.max_position_embeddings
  self.original_max_seq_len = config.max_position_embeddings
-
+
  # MRoPE configuration
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
  self.rope_type = config.rope_scaling.get("rope_type", "default")
@@ -465,17 +450,19 @@ class TextRotaryEmbedding(nn.Module):
  else:
  self.rope_type = "default"
  self.mrope_section = [24, 20, 20]
-
+
  # Store parameters for computing inv_freq on the fly
  self.head_dim = config.head_dim
  self.theta = config.rope_theta
-
+
  # Attention scaling (simplified - may need adjustment based on actual config)
  self.attention_scaling = 1.0
 
  def _get_inv_freq(self):
  """Compute inverse frequencies on the fly"""
- inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim))
+ inv_freq = 1.0 / (
+ self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim)
+ )
  # Expand for 3 dimensions (T, H, W)
  return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
 
@@ -501,36 +488,38 @@ class TextRotaryEmbedding(nn.Module):
  Args:
  x: Input tensor for dtype reference
  position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
-
+
  Returns:
  cos, sin: Cosine and sine embeddings
  """
  # Handle 2D position_ids by expanding to 3D for MRoPE
  if position_ids.ndim == 2:
- position_ids = mx.broadcast_to(position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1]))
-
+ position_ids = mx.broadcast_to(
+ position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1])
+ )
+
  batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
-
+
  # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
  inv_freq_expanded = mx.broadcast_to(
- self._get_inv_freq()[:, None, None, :],
- (3, batch_size, 1, self._get_inv_freq().shape[-1])
+ self._get_inv_freq()[:, None, None, :],
+ (3, batch_size, 1, self._get_inv_freq().shape[-1]),
  )
-
+
  # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
  position_ids_expanded = position_ids[..., None].astype(mx.float32)
-
+
  # Compute frequencies: (3, batch_size, seq_len, dim//2)
  freqs = inv_freq_expanded * position_ids_expanded
-
+
  # Apply interleaved MRoPE
  freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
-
+
  # Create embeddings
  emb = mx.concatenate([freqs, freqs], axis=-1) # (batch_size, seq_len, head_dim)
  cos = mx.cos(emb) * self.attention_scaling
  sin = mx.sin(emb) * self.attention_scaling
-
+
  return cos.astype(x.dtype), sin.astype(x.dtype)
 
 
@@ -539,12 +528,12 @@ class TextAttention(nn.Module):
  super().__init__()
  self.config = config
  self.layer_idx = layer_idx
-
+
  dim = config.hidden_size
  self.n_heads = config.num_attention_heads
  self.n_kv_heads = config.num_key_value_heads
  self.head_dim = config.head_dim
- self.scale = self.head_dim ** -0.5
+ self.scale = self.head_dim**-0.5
 
  self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
  self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
@@ -553,7 +542,7 @@ class TextAttention(nn.Module):
 
  self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
  self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
-
+
  # Initialize rope directly
  self.rope = initialize_rope(
  config.head_dim,
@@ -589,8 +578,23 @@ class TextAttention(nn.Module):
  keys, values = cache.update_and_fetch(keys, values)
  else:
  if cache is not None:
- queries = self.rope(queries, offset=cache.offset+rope_deltas)
- keys = self.rope(keys, offset=cache.offset+rope_deltas)
+ # Handle different types of rope_deltas: scalar, array, or None
+ if rope_deltas is None:
+ offset_delta = 0
+ elif isinstance(rope_deltas, (int, float)):
+ # rope_deltas is a scalar
+ offset_delta = rope_deltas
+ elif hasattr(rope_deltas, 'size') and rope_deltas.size == 1:
+ # rope_deltas is an array with single element
+ offset_delta = rope_deltas.item()
+ elif hasattr(rope_deltas, 'shape') and rope_deltas.shape:
+ # rope_deltas is an array with multiple elements, take first
+ offset_delta = rope_deltas.reshape(-1)[0].item()
+ else:
+ offset_delta = 0
+
+ queries = self.rope(queries, offset=cache.offset + offset_delta)
+ keys = self.rope(keys, offset=cache.offset + offset_delta)
  keys, values = cache.update_and_fetch(keys, values)
  else:
  queries = self.rope(queries)
@@ -634,7 +638,7 @@ class TextDecoderLayer(nn.Module):
  ) -> mx.array:
  residual = hidden_states
  hidden_states = self.input_layernorm(hidden_states)
-
+
  hidden_states, _ = self.self_attn(
  hidden_states=hidden_states,
  attention_mask=attention_mask,
@@ -656,11 +660,10 @@ class TextModel(nn.Module):
  super().__init__()
  self.config = config
  self.vocab_size = config.vocab_size
-
+
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
  self.layers = [
- TextDecoderLayer(config, layer_idx)
- for layer_idx in range(config.num_hidden_layers)
+ TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
  ]
  self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.rotary_emb = TextRotaryEmbedding(config)
@@ -717,7 +720,9 @@ class TextModel(nn.Module):
  rope_deltas=rope_deltas,
  )
  if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
- hidden_states = self._deepstack_process(hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx])
+ hidden_states = self._deepstack_process(
+ hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx]
+ )
  hidden_states = self.norm(hidden_states)
  return hidden_states
 
@@ -728,17 +733,17 @@ class VEGModel(nn.Module):
  super().__init__()
  self.config = vision_config
  self.visual = VisionModel(vision_config)
-
+
  def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
  return self.visual(pixel_values, image_grid_thw)
-
+
  def sanitize(self, weights):
  sanitized = {}
  for k, v in weights.items():
- if 'visual.' in k:
+ if "visual." in k:
  # Remove prefixes to match our model structure
- clean_key = k.replace('model.visual.', '').replace('visual.', '')
- sanitized[f'visual.{clean_key}'] = v
+ clean_key = k.replace("model.visual.", "").replace("visual.", "")
+ sanitized[f"visual.{clean_key}"] = v
  return sanitized
 
 
@@ -751,140 +756,164 @@ class LLMModel(nn.Module):
  self.language_model = TextModel(text_config)
  if not text_config.tie_word_embeddings:
  self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
-
+
  def get_rope_index(
- self,
- input_ids: Optional[mx.array] = None,
- image_grid_thw: Optional[mx.array] = None,
- attention_mask: Optional[mx.array] = None,
- ) -> Tuple[mx.array, mx.array]:
- """Simplified version for images only (no video support)."""
-
- spatial_merge_size = 2
- image_token_id = 151655
- vision_start_token_id = 151652
- mrope_position_deltas = []
-
- if input_ids is not None and image_grid_thw is not None:
- total_input_ids = input_ids
- if attention_mask is None:
- attention_mask = mx.ones_like(total_input_ids)
-
- batch_size, seq_len = input_ids.shape
- position_ids_list = []
- image_index = 0
-
- for i in range(batch_size):
- input_ids_seq = total_input_ids[i]
- mask_seq = attention_mask[i]
-
- # Use mask to get valid length
- valid_length = int(mx.sum(mask_seq).item())
- input_ids_seq = input_ids_seq[:valid_length]
-
- image_nums = 0
- # Find vision start tokens by iterating through the sequence
- vision_start_positions = []
- for pos in range(input_ids_seq.shape[0]):
- if input_ids_seq[pos].item() == vision_start_token_id:
- vision_start_positions.append(pos)
-
- if len(vision_start_positions) > 0:
- for pos in vision_start_positions:
- if pos + 1 < input_ids_seq.shape[0]:
- if input_ids_seq[pos + 1].item() == image_token_id:
- image_nums += 1
-
- input_tokens = input_ids_seq.tolist()
- llm_pos_ids_list = []
- st = 0
- remain_images = image_nums
-
- for _ in range(image_nums):
- ed_image = input_tokens.index(image_token_id, st)
-
- t = image_grid_thw[image_index, 0].item()
- h = image_grid_thw[image_index, 1].item()
- w = image_grid_thw[image_index, 2].item()
- image_index += 1
- remain_images -= 1
- ed = ed_image
-
- llm_grid_t = int(t)
- llm_grid_h = int(h) // spatial_merge_size
- llm_grid_w = int(w) // spatial_merge_size
- text_len = ed - st
-
- st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_pos = mx.arange(text_len).reshape(1, -1)
- text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
- llm_pos_ids_list.append(text_pos)
-
- # t_index is always 0 because llm_grid_t is always 1 for images
- t_index = mx.arange(llm_grid_t).reshape(-1, 1)
- t_index = mx.broadcast_to(t_index, (llm_grid_t, llm_grid_h * llm_grid_w)).reshape(-1)
-
- h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
- h_index = mx.broadcast_to(h_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
-
- w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
- w_index = mx.broadcast_to(w_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
-
- vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
- llm_pos_ids_list.append(vision_pos)
- st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
- if st < len(input_tokens):
- st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_len = len(input_tokens) - st
- text_pos = mx.arange(text_len).reshape(1, -1)
- text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
- llm_pos_ids_list.append(text_pos)
-
- llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
-
- # Create position_ids for this batch item, pad to seq_len
- batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
- valid_length = min(seq_len, llm_positions.shape[1])
-
- # Create new arrays for each dimension
- pos_dim0 = mx.concatenate([llm_positions[0, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
- pos_dim1 = mx.concatenate([llm_positions[1, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
- pos_dim2 = mx.concatenate([llm_positions[2, :valid_length],
- mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
-
- batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
- position_ids_list.append(batch_position_ids)
-
- mrope_position_deltas.append(llm_positions.max().item() + 1 - len(total_input_ids[i]))
-
- # Stack all batch position_ids
- position_ids = mx.stack(position_ids_list, axis=1) # Shape: (3, batch_size, seq_len)
- # Ensure rope deltas are 1D: (batch,)
- mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1)
- return position_ids, mrope_position_deltas
+ self,
+ input_ids: Optional[mx.array] = None,
+ image_grid_thw: Optional[mx.array] = None,
+ attention_mask: Optional[mx.array] = None,
+ ) -> Tuple[mx.array, mx.array]:
+ """Simplified version for images only (no video support)."""
+
+ spatial_merge_size = 2
+ image_token_id = 151655
+ vision_start_token_id = 151652
+ mrope_position_deltas = []
+
+ if input_ids is not None and image_grid_thw is not None:
+ total_input_ids = input_ids
+ if attention_mask is None:
+ attention_mask = mx.ones_like(total_input_ids)
+
+ batch_size, seq_len = input_ids.shape
+ position_ids_list = []
+ image_index = 0
+
+ for i in range(batch_size):
+ input_ids_seq = total_input_ids[i]
+ mask_seq = attention_mask[i]
+
+ # Use mask to get valid length
+ valid_length = int(mx.sum(mask_seq).item())
+ input_ids_seq = input_ids_seq[:valid_length]
+
+ image_nums = 0
+ # Find vision start tokens by iterating through the sequence
+ vision_start_positions = []
+ for pos in range(input_ids_seq.shape[0]):
+ if input_ids_seq[pos].item() == vision_start_token_id:
+ vision_start_positions.append(pos)
+
+ if len(vision_start_positions) > 0:
+ for pos in vision_start_positions:
+ if pos + 1 < input_ids_seq.shape[0]:
+ if input_ids_seq[pos + 1].item() == image_token_id:
+ image_nums += 1
+
+ input_tokens = input_ids_seq.tolist()
+ llm_pos_ids_list = []
+ st = 0
+ remain_images = image_nums
+
+ for _ in range(image_nums):
+ ed_image = input_tokens.index(image_token_id, st)
+
+ t = image_grid_thw[image_index, 0].item()
+ h = image_grid_thw[image_index, 1].item()
+ w = image_grid_thw[image_index, 2].item()
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+
+ llm_grid_t = int(t)
+ llm_grid_h = int(h) // spatial_merge_size
+ llm_grid_w = int(w) // spatial_merge_size
+ text_len = ed - st
+
+ st_idx = (
+ llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+ )
+ text_pos = mx.arange(text_len).reshape(1, -1)
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+ llm_pos_ids_list.append(text_pos)
+
+ # t_index is always 0 because llm_grid_t is always 1 for images
+ t_index = mx.arange(llm_grid_t).reshape(-1, 1)
+ t_index = mx.broadcast_to(
+ t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
+ ).reshape(-1)
+
+ h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
+ h_index = mx.broadcast_to(
+ h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+ ).reshape(-1)
+
+ w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
+ w_index = mx.broadcast_to(
+ w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
+ ).reshape(-1)
+
+ vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
+ llm_pos_ids_list.append(vision_pos)
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = (
+ llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
+ )
+ text_len = len(input_tokens) - st
+ text_pos = mx.arange(text_len).reshape(1, -1)
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
+ llm_pos_ids_list.append(text_pos)
+
+ llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
+
+ # Create position_ids for this batch item, pad to seq_len
+ batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
+ valid_length = min(seq_len, llm_positions.shape[1])
+
+ # Create new arrays for each dimension
+ pos_dim0 = mx.concatenate(
+ [
+ llm_positions[0, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+ pos_dim1 = mx.concatenate(
+ [
+ llm_positions[1, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+ pos_dim2 = mx.concatenate(
+ [
+ llm_positions[2, :valid_length],
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
+ ]
+ )
+
+ batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
+ position_ids_list.append(batch_position_ids)
+
+ mrope_position_deltas.append(
+ llm_positions.max().item() + 1 - len(total_input_ids[i])
+ )
+
+ # Stack all batch position_ids
+ position_ids = mx.stack(position_ids_list, axis=1) # Shape: (3, batch_size, seq_len)
+ mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
+ return position_ids, mrope_position_deltas
+ else:
+ if attention_mask is not None:
+ position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
+ position_ids = mx.where(attention_mask == 0, 1, position_ids)
+ position_ids = mx.expand_dims(position_ids, axis=0)
+ position_ids = mx.broadcast_to(
+ position_ids, (3, position_ids.shape[1], position_ids.shape[2])
+ )
+ max_position_ids = mx.max(
+ mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True
+ )
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
  else:
- if attention_mask is not None:
- position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
- position_ids = mx.where(attention_mask == 0, 1, position_ids)
- position_ids = mx.expand_dims(position_ids, axis=0)
- position_ids = mx.broadcast_to(position_ids, (3, position_ids.shape[1], position_ids.shape[2]))
- # Compute max position per batch, ensure 1D shape (batch,)
- max_position_ids = mx.max(mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=False)
- mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
- mrope_position_deltas = mx.reshape(mrope_position_deltas, (-1,))
- else:
- seq_len = input_ids.shape[1]
- batch_size = input_ids.shape[0]
- position_ids = mx.arange(seq_len).reshape(1, 1, -1)
- position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
- # 1D zeros for rope deltas
- mrope_position_deltas = mx.zeros((batch_size,), dtype=input_ids.dtype)
-
- return position_ids, mrope_position_deltas
-
+ seq_len = input_ids.shape[1]
+ batch_size = input_ids.shape[0]
+ position_ids = mx.arange(seq_len).reshape(1, 1, -1)
+ position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
+ mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
+
+ return position_ids, mrope_position_deltas
+
  def __call__(
  self,
  inputs: mx.array = None,
@@ -912,35 +941,41 @@ class LLMModel(nn.Module):
  return self.language_model.embed_tokens.as_linear(out)
  else:
  return self.lm_head(out)
-
+
  def sanitize(self, weights):
  sanitized = {}
  for k, v in weights.items():
- if not ('visual.' in k):
+ if not ("visual." in k):
  # Handle key mapping from combined model to LLM-only model
  clean_key = k
-
+
  # Remove model. prefix if present
- if clean_key.startswith('model.'):
+ if clean_key.startswith("model."):
  clean_key = clean_key[6:] # Remove 'model.'
-
+
  # Map language_ prefixed keys to language_model structure
- if clean_key.startswith('language_'):
- if clean_key.startswith('language_layers.'):
- clean_key = 'language_model.layers.' + clean_key[16:] # Map to language_model.layers.
- elif clean_key.startswith('language_embed_tokens.'):
- clean_key = 'language_model.embed_tokens.' + clean_key[22:] # Map to language_model.embed_tokens.
- elif clean_key.startswith('language_norm.'):
- clean_key = 'language_model.norm.' + clean_key[14:] # Map to language_model.norm.
-
+ if clean_key.startswith("language_"):
+ if clean_key.startswith("language_layers."):
+ clean_key = (
+ "language_model.layers." + clean_key[16:]
+ ) # Map to language_model.layers.
+ elif clean_key.startswith("language_embed_tokens."):
+ clean_key = (
+ "language_model.embed_tokens." + clean_key[22:]
+ ) # Map to language_model.embed_tokens.
+ elif clean_key.startswith("language_norm."):
+ clean_key = (
+ "language_model.norm." + clean_key[14:]
+ ) # Map to language_model.norm.
+
  sanitized[clean_key] = v
-
+
  # Handle tied embeddings - remove lm_head if using tied embeddings
  if self.args.tie_word_embeddings:
  sanitized.pop("lm_head.weight", None)
-
+
  return sanitized
-
+
  @property
  def layers(self):
  return self.language_model.layers
@@ -954,39 +989,36 @@ class Qwen3VLModel(nn.Module):
  self.config = args
  self.visual = VisionModel(args.vision_config)
  self.language_model = TextModel(args.text_config)
-
+
  def sanitize(self, weights):
  # Map weights to match the combined model structure
  sanitized = {}
  for k, v in weights.items():
  # Remove 'model.' prefix if present to match our structure
- clean_key = k.replace('model.', '') if k.startswith('model.') else k
+ clean_key = k.replace("model.", "") if k.startswith("model.") else k
  sanitized[clean_key] = v
  return sanitized
 
- def get_image_features(
- self,
- pixel_values: mx.array,
- image_grid_thw: Optional[mx.array] = None
- ):
+ def get_image_features(self, pixel_values: mx.array, image_grid_thw: Optional[mx.array] = None):
  image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
  # Split based on grid dimensions
  if image_grid_thw is not None:
- split_sizes = (mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size ** 2)).tolist()
+ split_sizes = (
+ mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size**2)
+ ).tolist()
  # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
  split_indices = []
  cumsum = 0
  for size in split_sizes[:-1]: # Exclude last element
  cumsum += size
  split_indices.append(cumsum)
-
+
  if split_indices: # Only split if we have indices
  image_embeds = mx.split(image_embeds, split_indices)
  else:
  image_embeds = [image_embeds] # Single image case
  return image_embeds, deepstack_visual_embeds
 
-
  def __call__(
  self,
  input_ids: mx.array = None,
@@ -1005,26 +1037,25 @@ class Qwen3VLModel(nn.Module):
  inputs_embeds = self.language_model.embed_tokens(input_ids)
 
  # Process images
-
+
  if pixel_values is not None:
  image_embeds, deepstack_visual_embeds = self.get_image_features(
  pixel_values, image_grid_thw
  )
-
+
  # Create masks and embed visual features
  if isinstance(image_embeds, list):
  image_embeds = mx.concatenate(image_embeds, axis=0)
-
+
  # Find image token positions and replace with visual embeddings
- image_mask = (input_ids == self.args.image_token_id)
+ image_mask = input_ids == self.args.image_token_id
  visual_pos_masks = image_mask
-
+
  # Replace image tokens with visual embeddings
  inputs_embeds = inputs_embeds.at[image_mask].set(
  image_embeds.astype(inputs_embeds.dtype)
  )
 
-
  outputs = self.language_model(
  inputs_embeds=inputs_embeds,
  attention_mask=attention_mask,
@@ -1042,28 +1073,28 @@ class Qwen3VLModel(nn.Module):
  def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
  """
  Handle the processing of multimodal embeddings including image features and position encoding.
-
+
  This function processes vision and text inputs to create unified embeddings that can be fed
  into the language model. It handles:
  - Vision feature extraction from pixel values
  - Deepstack visual embedding collection
  - Image token replacement in text embeddings
  - Position encoding setup for MRoPE (Multi-dimensional RoPE)
-
+
  Args:
  vision_model: The vision encoder model (VEGModel instance)
- llm_model: The language model (LLMModel instance)
+ llm_model: The language model (LLMModel instance)
  input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
  pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
  image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
-
+
  Returns:
  tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
  - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
  - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
  - visual_pos_masks: Boolean mask indicating image token positions
  - cos: Cosine values for rotary position encoding
- - sin: Sine values for rotary position encoding
+ - sin: Sine values for rotary position encoding
  - rope_deltas: Position offset deltas for rope computation
  """
  inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
@@ -1072,74 +1103,80 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
  cos = None
  sin = None
  rope_deltas = 0
-
+
  if pixel_values is not None:
  if pixel_values.ndim == 4:
  pixel_values = mx.expand_dims(pixel_values, axis=2)
-
+
  # Process each image individually to prevent feature mixing
  image_embeds_list = []
  all_deepstack_embeds = []
-
+
  # Calculate cumulative indices for each image
  cumulative_patches = 0
-
+
  for i in range(image_grid_thw.shape[0]):
  # Calculate number of patches for current image
  current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
  start_idx = cumulative_patches
  end_idx = cumulative_patches + current_patches
  cumulative_patches += current_patches
-
+
  single_pixel_values = pixel_values[start_idx:end_idx]
- single_grid_thw = image_grid_thw[i:i+1]
-
+ single_grid_thw = image_grid_thw[i : i + 1]
+
  # Use vision model directly
  single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
-
+
  # Split based on grid dimensions
  if single_grid_thw is not None:
- split_sizes = (mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size ** 2)).tolist()
+ split_sizes = (
+ mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size**2)
+ ).tolist()
  split_indices = []
  cumsum = 0
  for size in split_sizes[:-1]:
  cumsum += size
  split_indices.append(cumsum)
-
+
  if split_indices:
  single_embeds = mx.split(single_embeds, split_indices)
  else:
  single_embeds = [single_embeds]
-
+
  image_embeds_list.extend(single_embeds)
-
+
  # Collect deepstack embeddings
  if i == 0:
  all_deepstack_embeds = single_deepstack
  else:
  # Concatenate deepstack embeddings from different images
  for j in range(len(all_deepstack_embeds)):
- all_deepstack_embeds[j] = mx.concatenate([all_deepstack_embeds[j], single_deepstack[j]], axis=0)
-
+ all_deepstack_embeds[j] = mx.concatenate(
+ [all_deepstack_embeds[j], single_deepstack[j]], axis=0
+ )
+
  deepstack_visual_embeds = all_deepstack_embeds
-
+
  # Concatenate all image embeddings for processing
  image_embeds = mx.concatenate(image_embeds_list, axis=0)
-
+
  # Find all image token positions
  image_token_id = 151655 # Default image token ID
- image_mask = (input_ids.squeeze(0) == image_token_id)
+ image_mask = input_ids.squeeze(0) == image_token_id
  image_mask_np = np.array(image_mask)
  image_token_positions = np.where(image_mask_np)[0]
-
+
  # Verify we have the correct number of image tokens
  expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
- assert len(image_token_positions) == expected_total_tokens, f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
-
+ assert (
+ len(image_token_positions) == expected_total_tokens
+ ), f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
+
  # Replace image tokens with image embeddings
  seq_len = inputs_embeds.shape[0]
  result = inputs_embeds
-
+
  # Replace image tokens with image embeddings sequentially
  embed_idx = 0
  for img_embed in image_embeds_list:
@@ -1149,7 +1186,7 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
  result = mx.where(
  mx.expand_dims(pos_mask, axis=-1),
  mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
- result
+ result,
  )
  embed_idx += 1
 
@@ -1158,10 +1195,10 @@ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, i
  cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
  if inputs_embeds.ndim == 2:
  inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
-
+
  if image_mask is not None:
  visual_pos_masks = image_mask
-
+
  return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
 
 
@@ -1172,7 +1209,9 @@ class Model(nn.Module):
  self.args = args
  self.model = Qwen3VLModel(args)
  if not args.text_config.tie_word_embeddings:
- self.lm_head = nn.Linear(args.text_config.hidden_size, args.text_config.vocab_size, bias=False)
+ self.lm_head = nn.Linear(
+ args.text_config.hidden_size, args.text_config.vocab_size, bias=False
+ )
 
  def __call__(
  self,
@@ -1180,7 +1219,7 @@ class Model(nn.Module):
  mask: mx.array = None,
  cache=None,
  inputs_embeds: Optional[mx.array] = None,
- pixel_values: Optional[mx.array] = None,
+ pixel_values: Optional[mx.array] = None,
  image_grid_thw: Optional[mx.array] = None,
  visual_pos_masks: Optional[mx.array] = None,
  deepstack_visual_embeds: Optional[List[mx.array]] = None,
@@ -1211,13 +1250,13 @@ class Model(nn.Module):
  sanitized = {}
  for k, v in weights.items():
  sanitized[k] = v
-
+
  # Handle tied embeddings - remove lm_head if using tied embeddings
  if self.args.text_config.tie_word_embeddings:
  sanitized.pop("lm_head.weight", None)
-
+
  return sanitized
 
  @property
  def layers(self):
- return self.model.language_model.layers
+ return self.model.language_model.layers
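One of the behavioural changes in the hunks above is how TextAttention consumes rope_deltas when a KV cache is present: instead of adding the raw value to cache.offset, the new code first collapses it to a plain scalar (None, Python scalars, single-element arrays, and multi-element arrays are all accepted). The snippet below is a minimal standalone sketch of that normalization, using NumPy in place of MLX arrays; the helper name normalize_rope_delta is hypothetical and not part of the nexaai package.

import numpy as np

def normalize_rope_delta(rope_deltas) -> int:
    # Collapse rope_deltas (None, scalar, or array) into a single integer offset,
    # mirroring the branch order used in the updated TextAttention.__call__.
    if rope_deltas is None:
        return 0
    if isinstance(rope_deltas, (int, float)):
        return int(rope_deltas)                               # plain scalar
    if hasattr(rope_deltas, "size") and rope_deltas.size == 1:
        return int(rope_deltas.item())                        # single-element array
    if hasattr(rope_deltas, "shape") and rope_deltas.shape:
        return int(np.asarray(rope_deltas).reshape(-1)[0])    # multi-element: take the first
    return 0

# The updated get_rope_index now returns deltas shaped (batch, 1), so both array
# branches above can occur in practice:
assert normalize_rope_delta(None) == 0
assert normalize_rope_delta(7) == 7
assert normalize_rope_delta(np.array([[5]])) == 5
assert normalize_rope_delta(np.array([[3], [9]])) == 3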