nexaai-1.0.29-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
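
The diff shown below adds the MLX Qwen3-VL modeling code. As a rough, illustrative sketch (not code from the package), the following snippet reproduces the interleaved MRoPE frequency layout that TextRotaryEmbedding.apply_interleaved_mrope builds in the diff below, assuming the default mrope_section = [24, 20, 20] and head_dim = 128 from TextConfig (so head_dim // 2 = 64 rotary frequencies per position):

# Illustrative sketch only (not part of the package): reproduce the
# interleaved MRoPE frequency layout built by apply_interleaved_mrope,
# assuming mrope_section = [24, 20, 20] and head_dim = 128.
mrope_section = [24, 20, 20]      # frequencies assigned to T, H, W
half_dim = sum(mrope_section)     # 64

layout = ["T"] * half_dim         # freqs_t = freqs[0]: start with T everywhere
for dim, label in ((1, "H"), (2, "W")):
    # idx = slice(offset, mrope_section[dim] * 3, 3), with offset == dim
    for i in range(dim, mrope_section[dim] * 3, 3):
        layout[i] = label

print("".join(layout))
# -> "THW" repeated 20 times, then "TTTT": positions 0-59 interleave
#    T/H/W sections and the last 4 frequencies stay temporal.
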
nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py
@@ -0,0 +1,1262 @@
1
+ # Copyright © 2023-2024 Apple Inc.
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import mlx.core as mx
7
+ import mlx.nn as nn
8
+ import math
9
+ import numpy as np
10
+
11
+ # Import from nested llm_common structure using relative imports
12
+ from .llm_common.base import (
13
+ BaseModelArgs,
14
+ create_attention_mask,
15
+ scaled_dot_product_attention,
16
+ )
17
+ from .llm_common.rope_utils import initialize_rope
18
+
19
+
20
+ @dataclass
21
+ class VisionConfig:
22
+ hidden_size: int = 1024
23
+ intermediate_size: int = 4096
24
+ num_heads: int = 16
25
+ num_hidden_layers: int = 24
26
+ patch_size: int = 16
27
+ temporal_patch_size: int = 2
28
+ in_channels: int = 3
29
+ hidden_act: str = "gelu"
30
+ spatial_merge_size: int = 2
31
+ out_hidden_size: int = 2560
32
+ num_position_embeddings: int = 2304
33
+ deepstack_visual_indexes: List[int] = None
34
+
35
+ def __post_init__(self):
36
+ if self.deepstack_visual_indexes is None:
37
+ self.deepstack_visual_indexes = [3, 7, 11]
38
+
39
+
40
+ @dataclass
41
+ class TextConfig(BaseModelArgs):
42
+ model_type: str = "qwen3vl"
43
+ hidden_size: int = 2560
44
+ num_hidden_layers: int = 36
45
+ intermediate_size: int = 9728
46
+ num_attention_heads: int = 32
47
+ num_key_value_heads: int = 8
48
+ rms_norm_eps: float = 1e-6
49
+ vocab_size: int = 151936
50
+ max_position_embeddings: int = 32768
51
+ rope_theta: float = 10000.0
52
+ head_dim: int = 128
53
+ tie_word_embeddings: bool = True
54
+ attention_bias: bool = False
55
+ attention_dropout: float = 0.0
56
+ rope_scaling: Optional[Dict[str, Union[float, str]]] = None
57
+
58
+ def __post_init__(self):
59
+ if self.rope_scaling is None:
60
+ # Use default RoPE for now since MRoPE is not implemented in rope_utils
61
+ self.rope_scaling = None
62
+
63
+
64
+ @dataclass
65
+ class ModelArgs(BaseModelArgs):
66
+ vision_config: VisionConfig = None
67
+ text_config: TextConfig = None
68
+ image_token_id: int = 151655
69
+ vision_start_token_id: int = 151652
70
+ vision_end_token_id: int = 151653
71
+
72
+ def __post_init__(self):
73
+ if self.vision_config is None:
74
+ self.vision_config = VisionConfig()
75
+ if self.text_config is None:
76
+ self.text_config = TextConfig()
77
+
78
+
79
+ def rotate_half(x):
80
+ x1 = x[..., : x.shape[-1] // 2]
81
+ x2 = x[..., x.shape[-1] // 2 :]
82
+ return mx.concatenate([-x2, x1], axis=-1)
83
+
84
+
85
+ def apply_rotary_pos_emb_vision(q, k, cos, sin):
86
+ cos = mx.expand_dims(cos, axis=-2)
87
+ sin = mx.expand_dims(sin, axis=-2)
88
+ q_embed = (q * cos) + (rotate_half(q) * sin)
89
+ k_embed = (k * cos) + (rotate_half(k) * sin)
90
+ return q_embed, k_embed
91
+
92
+
93
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
94
+ cos = mx.expand_dims(cos, axis=unsqueeze_dim)
95
+ sin = mx.expand_dims(sin, axis=unsqueeze_dim)
96
+ q_embed = (q * cos) + (rotate_half(q) * sin)
97
+ k_embed = (k * cos) + (rotate_half(k) * sin)
98
+ return q_embed, k_embed
99
+
100
+
101
+ class VisionMLP(nn.Module):
102
+ def __init__(self, config: VisionConfig):
103
+ super().__init__()
104
+ self.hidden_size = config.hidden_size
105
+ self.intermediate_size = config.intermediate_size
106
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
107
+ self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
108
+
109
+ def __call__(self, hidden_state):
110
+ return self.linear_fc2(nn.gelu(self.linear_fc1(hidden_state)))
111
+
112
+
113
+ class VisionPatchEmbed(nn.Module):
114
+ def __init__(self, config: VisionConfig):
115
+ super().__init__()
116
+ self.patch_size = config.patch_size
117
+ self.temporal_patch_size = config.temporal_patch_size
118
+ self.in_channels = config.in_channels
119
+ self.embed_dim = config.hidden_size
120
+
121
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
122
+ self.proj = nn.Conv3d(
123
+ self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True
124
+ )
125
+
126
+ def __call__(self, hidden_states: mx.array) -> mx.array:
127
+ target_dtype = self.proj.weight.dtype
128
+
129
+ # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
130
+ # This matches the PyTorch ground truth exactly
131
+ hidden_states = hidden_states.reshape(
132
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
133
+ )
134
+
135
+ # Convert to MLX format: [batch, temporal, height, width, channels]
136
+ hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
137
+
138
+ # Apply conv3d with target dtype and reshape to match PyTorch output
139
+ hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
140
+
141
+ return hidden_states
142
+
143
+
144
+ class VisionRotaryEmbedding(nn.Module):
145
+ def __init__(self, dim: int, theta: float = 10000.0):
146
+ super().__init__()
147
+ # Don't store inv_freq as a parameter since it causes loading issues
148
+ self.dim = dim
149
+ self.theta = theta
150
+
151
+ def __call__(self, seqlen: int) -> mx.array:
152
+ # Compute inv_freq on the fly
153
+ inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
154
+ seq = mx.arange(seqlen, dtype=inv_freq.dtype)
155
+ freqs = mx.outer(seq, inv_freq)
156
+ return freqs
157
+
158
+
159
+ class VisionPatchMerger(nn.Module):
160
+ def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
161
+ super().__init__()
162
+ self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
163
+ self.use_postshuffle_norm = use_postshuffle_norm
164
+
165
+ norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
166
+ self.norm = nn.LayerNorm(norm_size, eps=1e-6)
167
+ self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
168
+ self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
169
+
170
+ def __call__(self, x: mx.array) -> mx.array:
171
+ if self.use_postshuffle_norm:
172
+ x = self.norm(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
173
+ else:
174
+ x = self.norm(x).reshape(-1, self.hidden_size)
175
+
176
+ x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
177
+ return x
178
+
179
+
180
+ class VisionAttention(nn.Module):
181
+ def __init__(self, config: VisionConfig):
182
+ super().__init__()
183
+ self.dim = config.hidden_size
184
+ self.num_heads = config.num_heads
185
+ self.head_dim = self.dim // self.num_heads
186
+ self.scaling = self.head_dim**-0.5
187
+
188
+ self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
189
+ self.proj = nn.Linear(self.dim, self.dim)
190
+
191
+ def __call__(
192
+ self,
193
+ hidden_states: mx.array,
194
+ cu_seqlens: mx.array,
195
+ rotary_pos_emb: Optional[mx.array] = None,
196
+ position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
197
+ **kwargs,
198
+ ) -> mx.array:
199
+ seq_length = hidden_states.shape[0]
200
+ qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
201
+ qkv = qkv.transpose(1, 0, 2, 3)
202
+ query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
203
+
204
+ cos, sin = position_embeddings
205
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
206
+
207
+ query_states = query_states.transpose(1, 0, 2)
208
+ key_states = key_states.transpose(1, 0, 2)
209
+ value_states = value_states.transpose(1, 0, 2)
210
+
211
+ query_states = mx.expand_dims(query_states, axis=0)
212
+ key_states = mx.expand_dims(key_states, axis=0)
213
+ value_states = mx.expand_dims(value_states, axis=0)
214
+
215
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
216
+
217
+ split_indices = []
218
+ cumsum = 0
219
+ for length in lengths[:-1]:
220
+ cumsum += int(length)
221
+ split_indices.append(cumsum)
222
+
223
+ if split_indices:
224
+ q_splits = mx.split(query_states, split_indices, axis=1)
225
+ k_splits = mx.split(key_states, split_indices, axis=1)
226
+ v_splits = mx.split(value_states, split_indices, axis=1)
227
+ else:
228
+ q_splits = [query_states]
229
+ k_splits = [key_states]
230
+ v_splits = [value_states]
231
+
232
+ attn_outputs = []
233
+ for q, k, v in zip(q_splits, k_splits, v_splits):
234
+ attn_out = scaled_dot_product_attention(
235
+ q, k, v, scale=self.scaling, mask=None, cache=None
236
+ )
237
+ attn_outputs.append(attn_out)
238
+
239
+ attn_output = mx.concatenate(attn_outputs, axis=1)
240
+
241
+ attn_output = attn_output[0].transpose(1, 0, 2)
242
+ attn_output = attn_output.reshape(seq_length, -1)
243
+ attn_output = self.proj(attn_output)
244
+
245
+ return attn_output
246
+
247
+
248
+ class VisionBlock(nn.Module):
249
+ def __init__(self, config: VisionConfig):
250
+ super().__init__()
251
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
252
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
253
+ self.attn = VisionAttention(config)
254
+ self.mlp = VisionMLP(config)
255
+
256
+ def __call__(
257
+ self,
258
+ hidden_states: mx.array,
259
+ cu_seqlens: mx.array,
260
+ position_embeddings: Tuple[mx.array, mx.array],
261
+ ) -> mx.array:
262
+ hidden_states = hidden_states + self.attn(
263
+ self.norm1(hidden_states),
264
+ cu_seqlens=cu_seqlens,
265
+ position_embeddings=position_embeddings,
266
+ )
267
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
268
+ return hidden_states
269
+
270
+
271
+ class VisionModel(nn.Module):
272
+ def __init__(self, config: VisionConfig):
273
+ super().__init__()
274
+ self.config = config
275
+ self.spatial_merge_size = config.spatial_merge_size
276
+ self.patch_size = config.patch_size
277
+
278
+ self.patch_embed = VisionPatchEmbed(config)
279
+ self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
280
+ self.num_grid_per_side = int(config.num_position_embeddings**0.5)
281
+
282
+ head_dim = config.hidden_size // config.num_heads
283
+ self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
284
+
285
+ self.blocks = [VisionBlock(config) for _ in range(config.num_hidden_layers)]
286
+ self.merger = VisionPatchMerger(config, use_postshuffle_norm=False)
287
+
288
+ self.deepstack_visual_indexes = config.deepstack_visual_indexes
289
+ self.deepstack_merger_list = [
290
+ VisionPatchMerger(config, use_postshuffle_norm=True)
291
+ for _ in range(len(config.deepstack_visual_indexes))
292
+ ]
293
+
294
+ def rot_pos_emb(self, grid_thw: mx.array) -> mx.array:
295
+ merge_size = self.spatial_merge_size
296
+
297
+ max_hw = int(grid_thw[:, 1:].max().item())
298
+ freq_table = self.rotary_pos_emb(max_hw) # (max_hw, dim // 2)
299
+
300
+ pos_ids_parts = []
301
+
302
+ for i in range(grid_thw.shape[0]):
303
+ num_frames = int(grid_thw[i, 0].item())
304
+ height = int(grid_thw[i, 1].item())
305
+ width = int(grid_thw[i, 2].item())
306
+
307
+ merged_h, merged_w = height // merge_size, width // merge_size
308
+
309
+ block_rows = mx.arange(merged_h) # block row indices
310
+ block_cols = mx.arange(merged_w) # block col indices
311
+ intra_row = mx.arange(merge_size) # intra-block row offsets
312
+ intra_col = mx.arange(merge_size) # intra-block col offsets
313
+
314
+ # Compute full-resolution positions using broadcasting
315
+ row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
316
+ col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
317
+
318
+ row_idx = mx.broadcast_to(
319
+ row_idx, (merged_h, merged_w, merge_size, merge_size)
320
+ ).reshape(-1)
321
+ col_idx = mx.broadcast_to(
322
+ col_idx, (merged_h, merged_w, merge_size, merge_size)
323
+ ).reshape(-1)
324
+
325
+ coords = mx.stack([row_idx, col_idx], axis=-1)
326
+
327
+ if num_frames > 1:
328
+ coords = mx.tile(coords, (num_frames, 1))
329
+
330
+ pos_ids_parts.append(coords)
331
+
332
+ # Concatenate all coordinate parts
333
+ pos_ids = mx.concatenate(pos_ids_parts, axis=0)
334
+
335
+ embeddings = freq_table[pos_ids] # lookup rotary embeddings
336
+ embeddings = embeddings.reshape(embeddings.shape[0], -1)
337
+ return embeddings
338
+
339
+ def fast_pos_embed_interpolate(self, grid_thw: mx.array):
340
+ patch_pos_embeds = []
341
+
342
+ for i in range(grid_thw.shape[0]):
343
+ t = int(grid_thw[i, 0].item())
344
+ h = int(grid_thw[i, 1].item())
345
+ w = int(grid_thw[i, 2].item())
346
+
347
+ # Simple position embedding interpolation
348
+ h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
349
+ w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
350
+
351
+ h_idxs_floor = mx.floor(h_idxs).astype(mx.int32)
352
+ w_idxs_floor = mx.floor(w_idxs).astype(mx.int32)
353
+ h_idxs_ceil = mx.minimum(h_idxs_floor + 1, self.num_grid_per_side - 1)
354
+ w_idxs_ceil = mx.minimum(w_idxs_floor + 1, self.num_grid_per_side - 1)
355
+
356
+ dh = h_idxs - h_idxs_floor.astype(mx.float32)
357
+ dw = w_idxs - w_idxs_floor.astype(mx.float32)
358
+
359
+ base_h = h_idxs_floor * self.num_grid_per_side
360
+ base_h_ceil = h_idxs_ceil * self.num_grid_per_side
361
+
362
+ # Compute bilinear interpolation indices and weights
363
+ indices_tl = (base_h[:, None] + w_idxs_floor[None, :]).reshape(-1)
364
+ indices_tr = (base_h[:, None] + w_idxs_ceil[None, :]).reshape(-1)
365
+ indices_bl = (base_h_ceil[:, None] + w_idxs_floor[None, :]).reshape(-1)
366
+ indices_br = (base_h_ceil[:, None] + w_idxs_ceil[None, :]).reshape(-1)
367
+
368
+ weights_tl = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
369
+ weights_tr = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
370
+ weights_bl = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
371
+ weights_br = (dh[:, None] * dw[None, :]).reshape(-1)
372
+
373
+ # Get embeddings and interpolate
374
+ pos_embed_tl = self.pos_embed(indices_tl) * weights_tl[:, None]
375
+ pos_embed_tr = self.pos_embed(indices_tr) * weights_tr[:, None]
376
+ pos_embed_bl = self.pos_embed(indices_bl) * weights_bl[:, None]
377
+ pos_embed_br = self.pos_embed(indices_br) * weights_br[:, None]
378
+
379
+ pos_embed = pos_embed_tl + pos_embed_tr + pos_embed_bl + pos_embed_br
380
+
381
+ # Repeat for temporal dimension and apply spatial merging
382
+ pos_embed = mx.tile(pos_embed, (t, 1))
383
+
384
+ # Apply spatial merging pattern
385
+ merge_size = self.config.spatial_merge_size
386
+ pos_embed = pos_embed.reshape(
387
+ t, h // merge_size, merge_size, w // merge_size, merge_size, -1
388
+ )
389
+ pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
390
+ pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
391
+
392
+ patch_pos_embeds.append(pos_embed)
393
+
394
+ return mx.concatenate(patch_pos_embeds, axis=0)
395
+
396
+ def __call__(
397
+ self, hidden_states: mx.array, grid_thw: mx.array
398
+ ) -> Tuple[mx.array, List[mx.array]]:
399
+ hidden_states = self.patch_embed(hidden_states)
400
+
401
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
402
+ hidden_states = hidden_states + pos_embeds
403
+
404
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
405
+ seq_len = hidden_states.shape[0]
406
+
407
+ emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
408
+ position_embeddings = (mx.cos(emb), mx.sin(emb))
409
+
410
+ # Create cumulative sequence lengths (following HuggingFace implementation)
411
+ # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
412
+ seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2] # h * w for each image
413
+ seq_lens = []
414
+ for i, (seq_len, repeats) in enumerate(zip(seq_lens_per_image, grid_thw[:, 0])):
415
+ seq_lens.extend([seq_len] * int(repeats))
416
+ seq_lens = mx.array(seq_lens)
417
+
418
+ # Then compute cumulative sum
419
+ cu_seqlens = mx.cumsum(seq_lens)
420
+ # Pad with 0 at the beginning
421
+ cu_seqlens = mx.concatenate([mx.array([0]), cu_seqlens])
422
+
423
+ deepstack_feature_lists = []
424
+ for layer_num, blk in enumerate(self.blocks):
425
+ hidden_states = blk(
426
+ hidden_states,
427
+ cu_seqlens=cu_seqlens,
428
+ position_embeddings=position_embeddings,
429
+ )
430
+ if layer_num in self.deepstack_visual_indexes:
431
+ idx = self.deepstack_visual_indexes.index(layer_num)
432
+ deepstack_feature = self.deepstack_merger_list[idx](hidden_states)
433
+ deepstack_feature_lists.append(deepstack_feature)
434
+
435
+ hidden_states = self.merger(hidden_states)
436
+ return hidden_states, deepstack_feature_lists
437
+
438
+
439
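The cumulative-sequence-length bookkeeping inside `__call__` above emulates `torch.repeat_interleave` followed by a cumulative sum; a small self-contained sketch with a hypothetical `grid_thw` (two images, the second with t = 2):

```python
import mlx.core as mx

# Hypothetical grid: one (t, h, w) row per image, not real model output.
grid_thw = mx.array([[1, 4, 6], [2, 2, 2]])

seq_lens = []
for t, h, w in grid_thw.tolist():        # repeat h * w, t times per image
    seq_lens.extend([h * w] * t)

cu_seqlens = mx.concatenate([mx.array([0]), mx.cumsum(mx.array(seq_lens))])
print(cu_seqlens.tolist())               # [0, 24, 28, 32]
```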
+ class TextRotaryEmbedding(nn.Module):
440
+ def __init__(self, config: TextConfig):
441
+ super().__init__()
442
+ self.config = config
443
+ self.max_seq_len_cached = config.max_position_embeddings
444
+ self.original_max_seq_len = config.max_position_embeddings
445
+
446
+ # MRoPE configuration
447
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
448
+ self.rope_type = config.rope_scaling.get("rope_type", "default")
449
+ self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
450
+ else:
451
+ self.rope_type = "default"
452
+ self.mrope_section = [24, 20, 20]
453
+
454
+ # Store parameters for computing inv_freq on the fly
455
+ self.head_dim = config.head_dim
456
+ self.theta = config.rope_theta
457
+
458
+ # Attention scaling (simplified - may need adjustment based on actual config)
459
+ self.attention_scaling = 1.0
460
+
461
+ def _get_inv_freq(self):
462
+ """Compute inverse frequencies on the fly"""
463
+ inv_freq = 1.0 / (
464
+ self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim)
465
+ )
466
+ # Expand for 3 dimensions (T, H, W)
467
+ return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
468
+
469
+ def apply_interleaved_mrope(self, freqs, mrope_section):
470
+ """Apply interleaved MRoPE to 3D rotary embeddings.
471
+ Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
472
+ interleaved [THTHWHTHW...TT], preserving frequency continuity.
473
+ Args:
474
+ freqs: (3, bs, seq_len, head_dim // 2)
475
+ mrope_section: (3,)
476
+ Returns:
477
+ freqs_t: (bs, seq_len, head_dim // 2)
478
+ """
479
+ freqs_t = freqs[0]  # start from the T (temporal) channels; H/W slices are overwritten below
480
+ for dim, offset in enumerate((1, 2), start=1): # H, W
481
+ length = mrope_section[dim] * 3
482
+ idx = slice(offset, length, 3)
483
+ freqs_t[..., idx] = freqs[dim, ..., idx]
484
+ return freqs_t
485
+
486
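To make the interleaved layout concrete, here is a toy labelling of the rotary channels under a hypothetical `mrope_section` (the config above defaults to [24, 20, 20]); it only illustrates which channel ends up carrying T, H, or W frequencies:

```python
# Hypothetical channel split; the real config uses mrope_section = [24, 20, 20].
mrope_section = [3, 2, 2]
half_dim = sum(mrope_section)                  # 7 rotary channels in this toy case

labels = ["T"] * half_dim                      # temporal frequencies everywhere first
for dim, name in ((1, "H"), (2, "W")):         # then overwrite the H and W slots
    for i in range(dim, mrope_section[dim] * 3, 3):
        labels[i] = name

print("".join(labels))                         # THWTHWT: interleaved, T fills the tail
```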
+ def __call__(self, x: mx.array, position_ids: mx.array) -> mx.array:
487
+ """
488
+ Args:
489
+ x: Input tensor for dtype reference
490
+ position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
491
+
492
+ Returns:
493
+ cos, sin: Cosine and sine embeddings
494
+ """
495
+ # Handle 2D position_ids by expanding to 3D for MRoPE
496
+ if position_ids.ndim == 2:
497
+ position_ids = mx.broadcast_to(
498
+ position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1])
499
+ )
500
+
501
+ batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
502
+
503
+ # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
504
+ inv_freq_expanded = mx.broadcast_to(
505
+ self._get_inv_freq()[:, None, None, :],
506
+ (3, batch_size, 1, self._get_inv_freq().shape[-1]),
507
+ )
508
+
509
+ # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
510
+ position_ids_expanded = position_ids[..., None].astype(mx.float32)
511
+
512
+ # Compute frequencies: (3, batch_size, seq_len, dim//2)
513
+ freqs = inv_freq_expanded * position_ids_expanded
514
+
515
+ # Apply interleaved MRoPE
516
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
517
+
518
+ # Create embeddings
519
+ emb = mx.concatenate([freqs, freqs], axis=-1) # (batch_size, seq_len, head_dim)
520
+ cos = mx.cos(emb) * self.attention_scaling
521
+ sin = mx.sin(emb) * self.attention_scaling
522
+
523
+ return cos.astype(x.dtype), sin.astype(x.dtype)
524
+
525
+
526
+ class TextAttention(nn.Module):
527
+ def __init__(self, config: TextConfig, layer_idx: int):
528
+ super().__init__()
529
+ self.config = config
530
+ self.layer_idx = layer_idx
531
+
532
+ dim = config.hidden_size
533
+ self.n_heads = config.num_attention_heads
534
+ self.n_kv_heads = config.num_key_value_heads
535
+ self.head_dim = config.head_dim
536
+ self.scale = self.head_dim**-0.5
537
+
538
+ self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
539
+ self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
540
+ self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
541
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=config.attention_bias)
542
+
543
+ self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
544
+ self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
545
+
546
+ # Initialize rope directly
547
+ self.rope = initialize_rope(
548
+ config.head_dim,
549
+ base=config.rope_theta,
550
+ traditional=False,
551
+ scaling_config=config.rope_scaling,
552
+ max_position_embeddings=config.max_position_embeddings,
553
+ )
554
+
555
+ def __call__(
556
+ self,
557
+ hidden_states: mx.array,
558
+ attention_mask: Optional[mx.array] = None,
559
+ cache: Optional[Any] = None,
560
+ cos: Optional[mx.array] = None,
561
+ sin: Optional[mx.array] = None,
562
+ rope_deltas: Optional[mx.array] = None,
563
+ ) -> Tuple[mx.array, Optional[mx.array]]:
564
+ B, L, D = hidden_states.shape
565
+
566
+ queries = self.q_proj(hidden_states).reshape(B, L, self.n_heads, -1)
567
+ keys = self.k_proj(hidden_states).reshape(B, L, self.n_kv_heads, -1)
568
+ values = self.v_proj(hidden_states).reshape(B, L, self.n_kv_heads, -1)
569
+
570
+ queries = self.q_norm(queries).transpose(0, 2, 1, 3)
571
+ keys = self.k_norm(keys).transpose(0, 2, 1, 3)
572
+ values = values.transpose(0, 2, 1, 3)
573
+
574
+ # Apply rope directly to queries and keys
575
+ if cos is not None and sin is not None:
576
+ queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
577
+ if cache is not None:
578
+ keys, values = cache.update_and_fetch(keys, values)
579
+ else:
580
+ if cache is not None:
581
+ # Handle different types of rope_deltas: scalar, array, or None
582
+ if rope_deltas is None:
583
+ offset_delta = 0
584
+ elif isinstance(rope_deltas, (int, float)):
585
+ # rope_deltas is a scalar
586
+ offset_delta = rope_deltas
587
+ elif hasattr(rope_deltas, 'size') and rope_deltas.size == 1:
588
+ # rope_deltas is an array with single element
589
+ offset_delta = rope_deltas.item()
590
+ elif hasattr(rope_deltas, 'shape') and rope_deltas.shape:
591
+ # rope_deltas is an array with multiple elements, take first
592
+ offset_delta = rope_deltas.reshape(-1)[0].item()
593
+ else:
594
+ offset_delta = 0
595
+
596
+ queries = self.rope(queries, offset=cache.offset + offset_delta)
597
+ keys = self.rope(keys, offset=cache.offset + offset_delta)
598
+ keys, values = cache.update_and_fetch(keys, values)
599
+ else:
600
+ queries = self.rope(queries)
601
+ keys = self.rope(keys)
602
+
603
+ output = scaled_dot_product_attention(
604
+ queries, keys, values, cache=cache, scale=self.scale, mask=attention_mask
605
+ )
606
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
607
+ return self.o_proj(output), None
608
+
609
+
610
+ class TextMLP(nn.Module):
611
+ def __init__(self, config: TextConfig):
612
+ super().__init__()
613
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
614
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
615
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
616
+
617
+ def __call__(self, x):
618
+ return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
619
+
620
+
621
+ class TextDecoderLayer(nn.Module):
622
+ def __init__(self, config: TextConfig, layer_idx: int):
623
+ super().__init__()
624
+ self.hidden_size = config.hidden_size
625
+ self.self_attn = TextAttention(config, layer_idx)
626
+ self.mlp = TextMLP(config)
627
+ self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
628
+ self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
629
+
630
+ def __call__(
631
+ self,
632
+ hidden_states: mx.array,
633
+ attention_mask: Optional[mx.array] = None,
634
+ cache: Optional[Any] = None,
635
+ cos: Optional[mx.array] = None,
636
+ sin: Optional[mx.array] = None,
637
+ rope_deltas: Optional[mx.array] = None,
638
+ ) -> mx.array:
639
+ residual = hidden_states
640
+ hidden_states = self.input_layernorm(hidden_states)
641
+
642
+ hidden_states, _ = self.self_attn(
643
+ hidden_states=hidden_states,
644
+ attention_mask=attention_mask,
645
+ cache=cache,
646
+ cos=cos,
647
+ sin=sin,
648
+ rope_deltas=rope_deltas,
649
+ )
650
+ hidden_states = residual + hidden_states
651
+ residual = hidden_states
652
+ hidden_states = self.post_attention_layernorm(hidden_states)
653
+ hidden_states = self.mlp(hidden_states)
654
+ hidden_states = residual + hidden_states
655
+ return hidden_states
656
+
657
+
658
+ class TextModel(nn.Module):
659
+ def __init__(self, config: TextConfig):
660
+ super().__init__()
661
+ self.config = config
662
+ self.vocab_size = config.vocab_size
663
+
664
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
665
+ self.layers = [
666
+ TextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
667
+ ]
668
+ self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
669
+ self.rotary_emb = TextRotaryEmbedding(config)
670
+
671
+ def _deepstack_process(
672
+ self,
673
+ hidden_states: mx.array,
674
+ visual_pos_masks: mx.array,
675
+ deepstack_visual_embeds: mx.array,
676
+ ) -> mx.array:
677
+ if visual_pos_masks is None or deepstack_visual_embeds is None:
678
+ return hidden_states
679
+ B, L, D = hidden_states.shape
680
+ mask_flat = visual_pos_masks.astype(mx.int32).reshape(-1)
681
+ idx_flat = mx.cumsum(mask_flat, axis=0) - 1
682
+ N = deepstack_visual_embeds.shape[0]
683
+ idx_flat = mx.maximum(idx_flat, 0)
684
+ eq = (idx_flat[:, None] == mx.arange(N)[None, :]).astype(hidden_states.dtype)
685
+ add_flat = eq @ deepstack_visual_embeds.astype(hidden_states.dtype)
686
+ add_flat = add_flat * mask_flat[:, None].astype(hidden_states.dtype)
687
+ add = add_flat.reshape(B, L, D)
688
+ return hidden_states + add
689
+
690
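`_deepstack_process` avoids a boolean scatter by turning the running visual-token index into a one-hot matrix and multiplying it with the stacked visual embeddings; a tiny sketch with made-up sizes showing that the matmul is just a masked gather:

```python
import mlx.core as mx

# Hypothetical deepstack embeddings: 3 visual tokens of dimension 2.
embeds = mx.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])

# Flattened sequence of length 6; 1s mark the visual-token positions.
mask = mx.array([0, 1, 1, 0, 1, 0], dtype=mx.int32)
idx = mx.maximum(mx.cumsum(mask) - 1, 0)        # running index into `embeds`

one_hot = (idx[:, None] == mx.arange(3)[None, :]).astype(mx.float32)
gathered = (one_hot @ embeds) * mask[:, None].astype(mx.float32)

# Visual positions receive their own embedding row, text positions stay zero.
print(gathered.tolist())
```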
+ def __call__(
691
+ self,
692
+ input_ids: Optional[mx.array] = None,
693
+ inputs_embeds: Optional[mx.array] = None,
694
+ attention_mask: Optional[mx.array] = None,
695
+ cache=None,
696
+ visual_pos_masks: Optional[mx.array] = None,
697
+ deepstack_visual_embeds: Optional[List[mx.array]] = None,
698
+ cos: Optional[mx.array] = None,
699
+ sin: Optional[mx.array] = None,
700
+ rope_deltas: Optional[mx.array] = None,
701
+ ):
702
+ if inputs_embeds is None:
703
+ inputs_embeds = self.embed_tokens(input_ids)
704
+
705
+ hidden_states = inputs_embeds
706
+
707
+ if attention_mask is None:
708
+ attention_mask = create_attention_mask(hidden_states, cache, return_array=True)
709
+
710
+ if cache is None:
711
+ cache = [None] * len(self.layers)
712
+
713
+ for layer_idx, (decoder_layer, c) in enumerate(zip(self.layers, cache)):
714
+ hidden_states = decoder_layer(
715
+ hidden_states,
716
+ attention_mask=attention_mask,
717
+ cache=c,
718
+ cos=cos,
719
+ sin=sin,
720
+ rope_deltas=rope_deltas,
721
+ )
722
+ if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
723
+ hidden_states = self._deepstack_process(
724
+ hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx]
725
+ )
726
+ hidden_states = self.norm(hidden_states)
727
+ return hidden_states
728
+
729
+
730
+ # Standalone Vision Model
731
+ class VEGModel(nn.Module):
732
+ def __init__(self, vision_config: VisionConfig):
733
+ super().__init__()
734
+ self.config = vision_config
735
+ self.visual = VisionModel(vision_config)
736
+
737
+ def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
738
+ return self.visual(pixel_values, image_grid_thw)
739
+
740
+ def sanitize(self, weights):
741
+ sanitized = {}
742
+ for k, v in weights.items():
743
+ if "visual." in k:
744
+ # Remove prefixes to match our model structure
745
+ clean_key = k.replace("model.visual.", "").replace("visual.", "")
746
+ sanitized[f"visual.{clean_key}"] = v
747
+ return sanitized
748
+
749
+
750
+ # Pure LLM Model (no vision components)
751
+ class LLMModel(nn.Module):
752
+ def __init__(self, text_config: TextConfig):
753
+ super().__init__()
754
+ self.args = text_config
755
+ self.config = text_config
756
+ self.language_model = TextModel(text_config)
757
+ if not text_config.tie_word_embeddings:
758
+ self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
759
+
760
+ def get_rope_index(
761
+ self,
762
+ input_ids: Optional[mx.array] = None,
763
+ image_grid_thw: Optional[mx.array] = None,
764
+ attention_mask: Optional[mx.array] = None,
765
+ ) -> Tuple[mx.array, mx.array]:
766
+ """Simplified version for images only (no video support)."""
767
+
768
+ spatial_merge_size = 2
769
+ image_token_id = 151655
770
+ vision_start_token_id = 151652
771
+ mrope_position_deltas = []
772
+
773
+ if input_ids is not None and image_grid_thw is not None:
774
+ total_input_ids = input_ids
775
+ if attention_mask is None:
776
+ attention_mask = mx.ones_like(total_input_ids)
777
+
778
+ batch_size, seq_len = input_ids.shape
779
+ position_ids_list = []
780
+ image_index = 0
781
+
782
+ for i in range(batch_size):
783
+ input_ids_seq = total_input_ids[i]
784
+ mask_seq = attention_mask[i]
785
+
786
+ # Use mask to get valid length
787
+ valid_length = int(mx.sum(mask_seq).item())
788
+ input_ids_seq = input_ids_seq[:valid_length]
789
+
790
+ image_nums = 0
791
+ # Find vision start tokens by iterating through the sequence
792
+ vision_start_positions = []
793
+ for pos in range(input_ids_seq.shape[0]):
794
+ if input_ids_seq[pos].item() == vision_start_token_id:
795
+ vision_start_positions.append(pos)
796
+
797
+ if len(vision_start_positions) > 0:
798
+ for pos in vision_start_positions:
799
+ if pos + 1 < input_ids_seq.shape[0]:
800
+ if input_ids_seq[pos + 1].item() == image_token_id:
801
+ image_nums += 1
802
+
803
+ input_tokens = input_ids_seq.tolist()
804
+ llm_pos_ids_list = []
805
+ st = 0
806
+ remain_images = image_nums
807
+
808
+ for _ in range(image_nums):
809
+ ed_image = input_tokens.index(image_token_id, st)
810
+
811
+ t = image_grid_thw[image_index, 0].item()
812
+ h = image_grid_thw[image_index, 1].item()
813
+ w = image_grid_thw[image_index, 2].item()
814
+ image_index += 1
815
+ remain_images -= 1
816
+ ed = ed_image
817
+
818
+ llm_grid_t = int(t)
819
+ llm_grid_h = int(h) // spatial_merge_size
820
+ llm_grid_w = int(w) // spatial_merge_size
821
+ text_len = ed - st
822
+
823
+ st_idx = (
824
+ llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
825
+ )
826
+ text_pos = mx.arange(text_len).reshape(1, -1)
827
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
828
+ llm_pos_ids_list.append(text_pos)
829
+
830
+ # t_index is always 0 because llm_grid_t is always 1 for images
831
+ t_index = mx.arange(llm_grid_t).reshape(-1, 1)
832
+ t_index = mx.broadcast_to(
833
+ t_index, (llm_grid_t, llm_grid_h * llm_grid_w)
834
+ ).reshape(-1)
835
+
836
+ h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
837
+ h_index = mx.broadcast_to(
838
+ h_index, (llm_grid_t, llm_grid_h, llm_grid_w)
839
+ ).reshape(-1)
840
+
841
+ w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
842
+ w_index = mx.broadcast_to(
843
+ w_index, (llm_grid_t, llm_grid_h, llm_grid_w)
844
+ ).reshape(-1)
845
+
846
+ vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
847
+ llm_pos_ids_list.append(vision_pos)
848
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
849
+
850
+ if st < len(input_tokens):
851
+ st_idx = (
852
+ llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
853
+ )
854
+ text_len = len(input_tokens) - st
855
+ text_pos = mx.arange(text_len).reshape(1, -1)
856
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
857
+ llm_pos_ids_list.append(text_pos)
858
+
859
+ llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
860
+
861
+ # Create position_ids for this batch item, pad to seq_len
862
+ batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
863
+ valid_length = min(seq_len, llm_positions.shape[1])
864
+
865
+ # Create new arrays for each dimension
866
+ pos_dim0 = mx.concatenate(
867
+ [
868
+ llm_positions[0, :valid_length],
869
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
870
+ ]
871
+ )
872
+ pos_dim1 = mx.concatenate(
873
+ [
874
+ llm_positions[1, :valid_length],
875
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
876
+ ]
877
+ )
878
+ pos_dim2 = mx.concatenate(
879
+ [
880
+ llm_positions[2, :valid_length],
881
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype),
882
+ ]
883
+ )
884
+
885
+ batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
886
+ position_ids_list.append(batch_position_ids)
887
+
888
+ mrope_position_deltas.append(
889
+ llm_positions.max().item() + 1 - len(total_input_ids[i])
890
+ )
891
+
892
+ # Stack all batch position_ids
893
+ position_ids = mx.stack(position_ids_list, axis=1) # Shape: (3, batch_size, seq_len)
894
+ mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
895
+ return position_ids, mrope_position_deltas
896
+ else:
897
+ if attention_mask is not None:
898
+ position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
899
+ position_ids = mx.where(attention_mask == 0, 1, position_ids)
900
+ position_ids = mx.expand_dims(position_ids, axis=0)
901
+ position_ids = mx.broadcast_to(
902
+ position_ids, (3, position_ids.shape[1], position_ids.shape[2])
903
+ )
904
+ max_position_ids = mx.max(
905
+ mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True
906
+ )
907
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
908
+ else:
909
+ seq_len = input_ids.shape[1]
910
+ batch_size = input_ids.shape[0]
911
+ position_ids = mx.arange(seq_len).reshape(1, 1, -1)
912
+ position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
913
+ mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
914
+
915
+ return position_ids, mrope_position_deltas
916
+
917
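For text-only inputs with no attention mask, `get_rope_index` falls back to an ordinary `arange` broadcast across the three MRoPE axes; a quick illustration with hypothetical batch and sequence sizes:

```python
import mlx.core as mx

batch_size, seq_len = 2, 5               # hypothetical sizes
position_ids = mx.arange(seq_len).reshape(1, 1, -1)
position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))

print(position_ids.shape)                # (3, 2, 5): identical T/H/W positions
```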
+ def __call__(
918
+ self,
919
+ inputs: mx.array = None,
920
+ mask: mx.array = None,
921
+ cache=None,
922
+ inputs_embeds: Optional[mx.array] = None,
923
+ visual_pos_masks: Optional[mx.array] = None,
924
+ deepstack_visual_embeds: Optional[List[mx.array]] = None,
925
+ cos: Optional[mx.array] = None,
926
+ sin: Optional[mx.array] = None,
927
+ rope_deltas: Optional[mx.array] = None,
928
+ ):
929
+ out = self.language_model(
930
+ input_ids=inputs,
931
+ inputs_embeds=inputs_embeds,
932
+ attention_mask=mask,
933
+ cache=cache,
934
+ visual_pos_masks=visual_pos_masks,
935
+ deepstack_visual_embeds=deepstack_visual_embeds,
936
+ cos=cos,
937
+ sin=sin,
938
+ rope_deltas=rope_deltas,
939
+ )
940
+ if self.args.tie_word_embeddings:
941
+ return self.language_model.embed_tokens.as_linear(out)
942
+ else:
943
+ return self.lm_head(out)
944
+
945
+ def sanitize(self, weights):
946
+ sanitized = {}
947
+ for k, v in weights.items():
948
+ if not ("visual." in k):
949
+ # Handle key mapping from combined model to LLM-only model
950
+ clean_key = k
951
+
952
+ # Remove model. prefix if present
953
+ if clean_key.startswith("model."):
954
+ clean_key = clean_key[6:] # Remove 'model.'
955
+
956
+ # Map language_ prefixed keys to language_model structure
957
+ if clean_key.startswith("language_"):
958
+ if clean_key.startswith("language_layers."):
959
+ clean_key = (
960
+ "language_model.layers." + clean_key[16:]
961
+ ) # Map to language_model.layers.
962
+ elif clean_key.startswith("language_embed_tokens."):
963
+ clean_key = (
964
+ "language_model.embed_tokens." + clean_key[22:]
965
+ ) # Map to language_model.embed_tokens.
966
+ elif clean_key.startswith("language_norm."):
967
+ clean_key = (
968
+ "language_model.norm." + clean_key[14:]
969
+ ) # Map to language_model.norm.
970
+
971
+ sanitized[clean_key] = v
972
+
973
+ # Handle tied embeddings - remove lm_head if using tied embeddings
974
+ if self.args.tie_word_embeddings:
975
+ sanitized.pop("lm_head.weight", None)
976
+
977
+ return sanitized
978
+
979
+ @property
980
+ def layers(self):
981
+ return self.language_model.layers
982
+
983
+
984
+ # Combined Model (for compatibility and utility functions)
985
+ class Qwen3VLModel(nn.Module):
986
+ def __init__(self, args: ModelArgs):
987
+ super().__init__()
988
+ self.args = args
989
+ self.config = args
990
+ self.visual = VisionModel(args.vision_config)
991
+ self.language_model = TextModel(args.text_config)
992
+
993
+ def sanitize(self, weights):
994
+ # Map weights to match the combined model structure
995
+ sanitized = {}
996
+ for k, v in weights.items():
997
+ # Remove 'model.' prefix if present to match our structure
998
+ clean_key = k.replace("model.", "") if k.startswith("model.") else k
999
+ sanitized[clean_key] = v
1000
+ return sanitized
1001
+
1002
+ def get_image_features(self, pixel_values: mx.array, image_grid_thw: Optional[mx.array] = None):
1003
+ image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
1004
+ # Split based on grid dimensions
1005
+ if image_grid_thw is not None:
1006
+ split_sizes = (
1007
+ mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size**2)
1008
+ ).tolist()
1009
+ # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
1010
+ split_indices = []
1011
+ cumsum = 0
1012
+ for size in split_sizes[:-1]: # Exclude last element
1013
+ cumsum += size
1014
+ split_indices.append(cumsum)
1015
+
1016
+ if split_indices: # Only split if we have indices
1017
+ image_embeds = mx.split(image_embeds, split_indices)
1018
+ else:
1019
+ image_embeds = [image_embeds] # Single image case
1020
+ return image_embeds, deepstack_visual_embeds
1021
+
1022
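`mx.split` expects cut points rather than chunk sizes, which is why `get_image_features` converts per-image token counts into a running sum; a short sketch with made-up sizes:

```python
import mlx.core as mx

split_sizes = [6, 4, 10]                 # hypothetical tokens per image
split_indices = []
cumsum = 0
for size in split_sizes[:-1]:            # last chunk is implied by the array end
    cumsum += size
    split_indices.append(cumsum)

chunks = mx.split(mx.arange(sum(split_sizes)), split_indices)
print([c.shape[0] for c in chunks])      # [6, 4, 10]
```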
+ def __call__(
1023
+ self,
1024
+ input_ids: mx.array = None,
1025
+ attention_mask: Optional[mx.array] = None,
1026
+ inputs_embeds: Optional[mx.array] = None,
1027
+ pixel_values: Optional[mx.array] = None,
1028
+ image_grid_thw: Optional[mx.array] = None,
1029
+ cache=None,
1030
+ visual_pos_masks: Optional[mx.array] = None,
1031
+ deepstack_visual_embeds: Optional[List[mx.array]] = None,
1032
+ cos: Optional[mx.array] = None,
1033
+ sin: Optional[mx.array] = None,
1034
+ rope_deltas: Optional[mx.array] = None,
1035
+ ):
1036
+ if inputs_embeds is None:
1037
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
1038
+
1039
+ # Process images
1040
+
1041
+ if pixel_values is not None:
1042
+ image_embeds, deepstack_visual_embeds = self.get_image_features(
1043
+ pixel_values, image_grid_thw
1044
+ )
1045
+
1046
+ # Create masks and embed visual features
1047
+ if isinstance(image_embeds, list):
1048
+ image_embeds = mx.concatenate(image_embeds, axis=0)
1049
+
1050
+ # Find image token positions and replace with visual embeddings
1051
+ image_mask = input_ids == self.args.image_token_id
1052
+ visual_pos_masks = image_mask
1053
+
1054
+ # Replace image tokens with visual embeddings
1055
+ inputs_embeds = inputs_embeds.at[image_mask].set(
1056
+ image_embeds.astype(inputs_embeds.dtype)
1057
+ )
1058
+
1059
+ outputs = self.language_model(
1060
+ inputs_embeds=inputs_embeds,
1061
+ attention_mask=attention_mask,
1062
+ cache=cache,
1063
+ visual_pos_masks=visual_pos_masks,
1064
+ deepstack_visual_embeds=deepstack_visual_embeds,
1065
+ cos=cos,
1066
+ sin=sin,
1067
+ rope_deltas=rope_deltas,
1068
+ )
1069
+
1070
+ return outputs
1071
+
1072
+
1073
+ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
1074
+ """
1075
+ Handle the processing of multimodal embeddings including image features and position encoding.
1076
+
1077
+ This function processes vision and text inputs to create unified embeddings that can be fed
1078
+ into the language model. It handles:
1079
+ - Vision feature extraction from pixel values
1080
+ - Deepstack visual embedding collection
1081
+ - Image token replacement in text embeddings
1082
+ - Position encoding setup for MRoPE (Multimodal Rotary Position Embedding)
1083
+
1084
+ Args:
1085
+ vision_model: The vision encoder model (VEGModel instance)
1086
+ llm_model: The language model (LLMModel instance)
1087
+ input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
1088
+ pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
1089
+ image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
1090
+
1091
+ Returns:
1092
+ tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
1093
+ - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
1094
+ - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
1095
+ - visual_pos_masks: Boolean mask indicating image token positions
1096
+ - cos: Cosine values for rotary position encoding
1097
+ - sin: Sine values for rotary position encoding
1098
+ - rope_deltas: Position offset deltas for rope computation
1099
+ """
1100
+ inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
1101
+ deepstack_visual_embeds = None
1102
+ visual_pos_masks = None
1103
+ cos = None
1104
+ sin = None
1105
+ rope_deltas = 0
1106
+
1107
+ if pixel_values is not None:
1108
+ if pixel_values.ndim == 4:
1109
+ pixel_values = mx.expand_dims(pixel_values, axis=2)
1110
+
1111
+ # Process each image individually to prevent feature mixing
1112
+ image_embeds_list = []
1113
+ all_deepstack_embeds = []
1114
+
1115
+ # Calculate cumulative indices for each image
1116
+ cumulative_patches = 0
1117
+
1118
+ for i in range(image_grid_thw.shape[0]):
1119
+ # Calculate number of patches for current image
1120
+ current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
1121
+ start_idx = cumulative_patches
1122
+ end_idx = cumulative_patches + current_patches
1123
+ cumulative_patches += current_patches
1124
+
1125
+ single_pixel_values = pixel_values[start_idx:end_idx]
1126
+ single_grid_thw = image_grid_thw[i : i + 1]
1127
+
1128
+ # Use vision model directly
1129
+ single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
1130
+
1131
+ # Split based on grid dimensions
1132
+ if single_grid_thw is not None:
1133
+ split_sizes = (
1134
+ mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size**2)
1135
+ ).tolist()
1136
+ split_indices = []
1137
+ cumsum = 0
1138
+ for size in split_sizes[:-1]:
1139
+ cumsum += size
1140
+ split_indices.append(cumsum)
1141
+
1142
+ if split_indices:
1143
+ single_embeds = mx.split(single_embeds, split_indices)
1144
+ else:
1145
+ single_embeds = [single_embeds]
1146
+
1147
+ image_embeds_list.extend(single_embeds)
1148
+
1149
+ # Collect deepstack embeddings
1150
+ if i == 0:
1151
+ all_deepstack_embeds = single_deepstack
1152
+ else:
1153
+ # Concatenate deepstack embeddings from different images
1154
+ for j in range(len(all_deepstack_embeds)):
1155
+ all_deepstack_embeds[j] = mx.concatenate(
1156
+ [all_deepstack_embeds[j], single_deepstack[j]], axis=0
1157
+ )
1158
+
1159
+ deepstack_visual_embeds = all_deepstack_embeds
1160
+
1161
+ # Concatenate all image embeddings for processing
1162
+ image_embeds = mx.concatenate(image_embeds_list, axis=0)
1163
+
1164
+ # Find all image token positions
1165
+ image_token_id = 151655 # Default image token ID
1166
+ image_mask = input_ids.squeeze(0) == image_token_id
1167
+ image_mask_np = np.array(image_mask)
1168
+ image_token_positions = np.where(image_mask_np)[0]
1169
+
1170
+ # Verify we have the correct number of image tokens
1171
+ expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
1172
+ assert (
1173
+ len(image_token_positions) == expected_total_tokens
1174
+ ), f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
1175
+
1176
+ # Replace image tokens with image embeddings
1177
+ seq_len = inputs_embeds.shape[0]
1178
+ result = inputs_embeds
1179
+
1180
+ # Replace image tokens with image embeddings sequentially
1181
+ embed_idx = 0
1182
+ for img_embed in image_embeds_list:
1183
+ for patch_idx in range(img_embed.shape[0]):
1184
+ token_pos = image_token_positions[embed_idx]
1185
+ pos_mask = mx.arange(seq_len) == token_pos
1186
+ result = mx.where(
1187
+ mx.expand_dims(pos_mask, axis=-1),
1188
+ mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
1189
+ result,
1190
+ )
1191
+ embed_idx += 1
1192
+
1193
+ inputs_embeds = result
1194
+ position_ids, rope_deltas = llm_model.get_rope_index(input_ids, image_grid_thw)
1195
+ cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
1196
+ if inputs_embeds.ndim == 2:
1197
+ inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
1198
+
1199
+ if image_mask is not None:
1200
+ visual_pos_masks = image_mask
1201
+
1202
+ return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
1203
+
1204
+
1205
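A rough call-order sketch of how `handle_multimodal_embeds` feeds `LLMModel`; every argument is a placeholder supplied by the caller (model loading and cache handling live elsewhere in the package), so this is illustrative rather than the package's actual entry point:

```python
def run_vlm_prefill(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
    """Call-order sketch only; every argument is a placeholder supplied by the caller."""
    embeds, deepstack, visual_masks, cos, sin, rope_deltas = handle_multimodal_embeds(
        vision_model, llm_model, input_ids, pixel_values, image_grid_thw
    )
    return llm_model(
        inputs_embeds=embeds,
        visual_pos_masks=visual_masks,
        deepstack_visual_embeds=deepstack,
        cos=cos,
        sin=sin,
        rope_deltas=rope_deltas,
    )
```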
+ # Legacy Model wrapper (for backward compatibility)
1206
+ class Model(nn.Module):
1207
+ def __init__(self, args: ModelArgs):
1208
+ super().__init__()
1209
+ self.args = args
1210
+ self.model = Qwen3VLModel(args)
1211
+ if not args.text_config.tie_word_embeddings:
1212
+ self.lm_head = nn.Linear(
1213
+ args.text_config.hidden_size, args.text_config.vocab_size, bias=False
1214
+ )
1215
+
1216
+ def __call__(
1217
+ self,
1218
+ inputs: mx.array = None,
1219
+ mask: mx.array = None,
1220
+ cache=None,
1221
+ inputs_embeds: Optional[mx.array] = None,
1222
+ pixel_values: Optional[mx.array] = None,
1223
+ image_grid_thw: Optional[mx.array] = None,
1224
+ visual_pos_masks: Optional[mx.array] = None,
1225
+ deepstack_visual_embeds: Optional[List[mx.array]] = None,
1226
+ cos: Optional[mx.array] = None,
1227
+ sin: Optional[mx.array] = None,
1228
+ rope_deltas: Optional[mx.array] = None,
1229
+ ):
1230
+ out = self.model(
1231
+ input_ids=inputs,
1232
+ inputs_embeds=inputs_embeds,
1233
+ attention_mask=mask,
1234
+ cache=cache,
1235
+ pixel_values=pixel_values,
1236
+ image_grid_thw=image_grid_thw,
1237
+ visual_pos_masks=visual_pos_masks,
1238
+ deepstack_visual_embeds=deepstack_visual_embeds,
1239
+ cos=cos,
1240
+ sin=sin,
1241
+ rope_deltas=rope_deltas,
1242
+ )
1243
+ if self.args.text_config.tie_word_embeddings:
1244
+ return self.model.language_model.embed_tokens.as_linear(out)
1245
+ else:
1246
+ return self.lm_head(out)
1247
+
1248
+ def sanitize(self, weights):
1249
+ # Pass weights through unchanged; the tied lm_head weight is dropped below
1250
+ sanitized = {}
1251
+ for k, v in weights.items():
1252
+ sanitized[k] = v
1253
+
1254
+ # Handle tied embeddings - remove lm_head if using tied embeddings
1255
+ if self.args.text_config.tie_word_embeddings:
1256
+ sanitized.pop("lm_head.weight", None)
1257
+
1258
+ return sanitized
1259
+
1260
+ @property
1261
+ def layers(self):
1262
+ return self.model.language_model.layers