nexaai-1.0.29-cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py
@@ -0,0 +1,1308 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import math
+import numpy as np
+
+# Import from nested llm_common structure using relative imports
+from .llm_common.base import (
+    BaseModelArgs,
+    create_attention_mask,
+    scaled_dot_product_attention,
+)
+from .llm_common.rope_utils import initialize_rope
+from .switch_layers import SwitchGLU
+
+
+@dataclass
+class VisionConfig:
+    hidden_size: int = 1152
+    intermediate_size: int = 4304
+    num_heads: int = 16
+    num_hidden_layers: int = 27
+    patch_size: int = 16
+    temporal_patch_size: int = 2
+    in_channels: int = 3
+    hidden_act: str = "gelu_pytorch_tanh"
+    spatial_merge_size: int = 2
+    out_hidden_size: int = 2048
+    num_position_embeddings: int = 2304
+    deepstack_visual_indexes: List[int] = None
+
+    def __post_init__(self):
+        if self.deepstack_visual_indexes is None:
+            self.deepstack_visual_indexes = [8, 16, 24]
+
+
+@dataclass
+class TextConfig(BaseModelArgs):
+    model_type: str = "qwen3_vl_moe_text"
+    hidden_size: int = 2048
+    num_hidden_layers: int = 48
+    intermediate_size: int = 6144
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 4
+    rms_norm_eps: float = 1e-6
+    vocab_size: int = 152064
+    max_position_embeddings: int = 128000
+    rope_theta: float = 1000000.0
+    head_dim: int = 128
+    tie_word_embeddings: bool = False
+    attention_bias: bool = False
+    attention_dropout: float = 0.0
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    # MoE specific parameters
+    num_experts: int = 128
+    num_experts_per_tok: int = 8
+    moe_intermediate_size: int = 768
+    shared_expert_intermediate_size: int = 0
+    norm_topk_prob: bool = True
+    decoder_sparse_step: int = 1
+    max_window_layers: int = 48
+    sliding_window: int = 32768
+    mlp_only_layers: List[int] = None
+    use_qk_norm: bool = True
+    layer_types: List[str] = None
+
+    def __post_init__(self):
+        if self.rope_scaling is None:
+            self.rope_scaling = {
+                "mrope_interleaved": True,
+                "mrope_section": [24, 20, 20],
+                "rope_type": "default"
+            }
+        if self.mlp_only_layers is None:
+            self.mlp_only_layers = []
+        if self.layer_types is None:
+            # This would need to be populated based on the actual model architecture
+            self.layer_types = []
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    vision_config: VisionConfig = None
+    text_config: TextConfig = None
+    image_token_id: int = 151655
+    vision_start_token_id: int = 151652
+    vision_end_token_id: int = 151653
+
+    def __post_init__(self):
+        if self.vision_config is None:
+            self.vision_config = VisionConfig()
+        if self.text_config is None:
+            self.text_config = TextConfig()
+
+
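The three dataclasses above mirror the upstream Qwen3-VL-MoE configuration: VisionConfig drives the vision tower, TextConfig adds the MoE fields (128 experts, 8 active per token), and the default rope_scaling selects interleaved MRoPE with mrope_section [24, 20, 20], which together cover the 64 rotary frequency pairs implied by head_dim = 128. A minimal standalone check of that relationship (a sketch, not part of the diff; assumes the classes above are importable):

cfg = TextConfig()                           # defaults as defined above
section = cfg.rope_scaling["mrope_section"]  # [24, 20, 20] after __post_init__
assert sum(section) == cfg.head_dim // 2     # 24 + 20 + 20 == 64 frequency pairs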
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return mx.concatenate([-x2, x1], axis=-1)
+
+
+def apply_rotary_pos_emb_vision(q, k, cos, sin):
+    cos = mx.expand_dims(cos, axis=-2)
+    sin = mx.expand_dims(sin, axis=-2)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+    cos = mx.expand_dims(cos, axis=unsqueeze_dim)
+    sin = mx.expand_dims(sin, axis=unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
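rotate_half above implements the standard rotary-embedding trick: split the last axis in two, negate the second half, and swap, so that x * cos + rotate_half(x) * sin rotates each paired dimension. A tiny worked example using only mx as imported above (illustrative, not part of the file):

x = mx.array([1.0, 2.0, 3.0, 4.0])
print(rotate_half(x))   # [-3.0, -4.0, 1.0, 2.0]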
+class VisionMLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+        self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+
+    def __call__(self, hidden_state):
+        return self.linear_fc2(nn.gelu(self.linear_fc1(hidden_state)))
+
+
+class VisionPatchEmbed(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.in_channels = config.in_channels
+        self.embed_dim = config.hidden_size
+
+        kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+        self.proj = nn.Conv3d(
+            self.in_channels,
+            self.embed_dim,
+            kernel_size=kernel_size,
+            stride=kernel_size,
+            bias=True
+        )
+
+    def __call__(self, hidden_states: mx.array) -> mx.array:
+        target_dtype = self.proj.weight.dtype
+
+        # Reshape to 5D: [batch, channels, temporal, height, width] (PyTorch format)
+        # This matches the PyTorch ground truth exactly
+        hidden_states = hidden_states.reshape(
+            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+        )
+
+        # Convert to MLX format: [batch, temporal, height, width, channels]
+        hidden_states = hidden_states.transpose(0, 2, 3, 4, 1)
+
+        # Apply conv3d with target dtype and reshape to match PyTorch output
+        hidden_states = self.proj(hidden_states.astype(target_dtype)).reshape(-1, self.embed_dim)
+
+        return hidden_states
+
+
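VisionPatchEmbed expects each patch already flattened to in_channels * temporal_patch_size * patch_size * patch_size values (3 * 2 * 16 * 16 = 1536 with the defaults). It rebuilds the PyTorch-style 5D layout, moves channels last for MLX's Conv3d, and, since the kernel size equals the stride, every patch maps to exactly one 1152-dimensional embedding. A shape sketch under the default VisionConfig (the patch count is hypothetical, not from the diff):

patches = mx.zeros((256, 3 * 2 * 16 * 16))   # 256 flattened patches, hypothetical
embed = VisionPatchEmbed(VisionConfig())
print(embed(patches).shape)                  # expected (256, 1152)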
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0):
+        super().__init__()
+        # Don't store inv_freq as a parameter since it causes loading issues
+        self.dim = dim
+        self.theta = theta
+
+    def __call__(self, seqlen: int) -> mx.array:
+        # Compute inv_freq on the fly
+        inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.dim, 2, dtype=mx.float32) / self.dim))
+        seq = mx.arange(seqlen, dtype=inv_freq.dtype)
+        freqs = mx.outer(seq, inv_freq)
+        return freqs
+
+
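With the default VisionConfig, each attention head has hidden_size // num_heads = 72 dimensions, so the vision model below constructs VisionRotaryEmbedding(72 // 2 = 36), and the table returned for a given sequence length has dim // 2 = 18 frequencies per position. A quick shape check (standalone sketch):

rope = VisionRotaryEmbedding(dim=36)   # head_dim // 2 under the defaults
print(rope(10).shape)                  # (10, 18): one row of frequencies per position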
+class VisionPatchMerger(nn.Module):
+    def __init__(self, config: VisionConfig, use_postshuffle_norm=False):
+        super().__init__()
+        self.hidden_size = config.hidden_size * (config.spatial_merge_size ** 2)
+        self.use_postshuffle_norm = use_postshuffle_norm
+
+        norm_size = self.hidden_size if use_postshuffle_norm else config.hidden_size
+        self.ln_q = nn.LayerNorm(norm_size, eps=1e-6)
+        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
+        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        if self.use_postshuffle_norm:
+            x = self.ln_q(x.reshape(-1, self.hidden_size)).reshape(-1, self.hidden_size)
+        else:
+            x = self.ln_q(x).reshape(-1, self.hidden_size)
+
+        x = self.linear_fc2(nn.gelu(self.linear_fc1(x)))
+        return x
+
+
+class VisionAttention(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.dim = config.hidden_size
+        self.num_heads = config.num_heads
+        self.head_dim = self.dim // self.num_heads
+        self.scaling = self.head_dim ** -0.5
+
+        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
+        self.proj = nn.Linear(self.dim, self.dim)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        cu_seqlens: mx.array,
+        rotary_pos_emb: Optional[mx.array] = None,
+        position_embeddings: Optional[Tuple[mx.array, mx.array]] = None,
+        **kwargs,
+    ) -> mx.array:
+        seq_length = hidden_states.shape[0]
+        qkv = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
+        qkv = qkv.transpose(1, 0, 2, 3)
+        query_states, key_states, value_states = qkv[0], qkv[1], qkv[2]
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb_vision(
+            query_states, key_states, cos, sin
+        )
+
+        query_states = query_states.transpose(1, 0, 2)
+        key_states = key_states.transpose(1, 0, 2)
+        value_states = value_states.transpose(1, 0, 2)
+
+        query_states = mx.expand_dims(query_states, axis=0)
+        key_states = mx.expand_dims(key_states, axis=0)
+        value_states = mx.expand_dims(value_states, axis=0)
+
+        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+
+        split_indices = []
+        cumsum = 0
+        for length in lengths[:-1]:
+            cumsum += int(length)
+            split_indices.append(cumsum)
+
+        if split_indices:
+            q_splits = mx.split(query_states, split_indices, axis=1)
+            k_splits = mx.split(key_states, split_indices, axis=1)
+            v_splits = mx.split(value_states, split_indices, axis=1)
+        else:
+            q_splits = [query_states]
+            k_splits = [key_states]
+            v_splits = [value_states]
+
+        attn_outputs = []
+        for q, k, v in zip(q_splits, k_splits, v_splits):
+            attn_out = scaled_dot_product_attention(
+                q, k, v,
+                scale=self.scaling, mask=None, cache=None
+            )
+            attn_outputs.append(attn_out)
+
+        attn_output = mx.concatenate(attn_outputs, axis=1)
+
+        attn_output = attn_output[0].transpose(1, 0, 2)
+        attn_output = attn_output.reshape(seq_length, -1)
+        attn_output = self.proj(attn_output)
+
+        return attn_output
+
+
+class VisionBlock(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.norm1 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.norm2 = nn.LayerNorm(config.hidden_size, eps=1e-6)
+        self.attn = VisionAttention(config)
+        self.mlp = VisionMLP(config)
+
+    def __call__(
+        self,
+        hidden_states: mx.array,
+        cu_seqlens: mx.array,
+        position_embeddings: Tuple[mx.array, mx.array],
+    ) -> mx.array:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class VisionModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.spatial_merge_size = config.spatial_merge_size
+        self.patch_size = config.patch_size
+
+        self.patch_embed = VisionPatchEmbed(config)
+        self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size)
+        self.num_grid_per_side = int(config.num_position_embeddings ** 0.5)
+
+        head_dim = config.hidden_size // config.num_heads
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        self.blocks = [VisionBlock(config) for _ in range(config.num_hidden_layers)]
+        self.merger = VisionPatchMerger(config, use_postshuffle_norm=False)
+
+        self.deepstack_visual_indexes = config.deepstack_visual_indexes
+        self.deepstack_merger_list = [
+            VisionPatchMerger(config, use_postshuffle_norm=True)
+            for _ in range(len(config.deepstack_visual_indexes))
+        ]
+
+    def rot_pos_emb(self, grid_thw: mx.array) -> mx.array:
+        merge_size = self.spatial_merge_size
+
+        max_hw = int(grid_thw[:, 1:].max().item())
+        freq_table = self.rotary_pos_emb(max_hw)  # (max_hw, dim // 2)
+
+        pos_ids_parts = []
+
+        for i in range(grid_thw.shape[0]):
+            num_frames = int(grid_thw[i, 0].item())
+            height = int(grid_thw[i, 1].item())
+            width = int(grid_thw[i, 2].item())
+
+            merged_h, merged_w = height // merge_size, width // merge_size
+
+            block_rows = mx.arange(merged_h)  # block row indices
+            block_cols = mx.arange(merged_w)  # block col indices
+            intra_row = mx.arange(merge_size)  # intra-block row offsets
+            intra_col = mx.arange(merge_size)  # intra-block col offsets
+
+            # Compute full-resolution positions using broadcasting
+            row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None]
+            col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :]
+
+            row_idx = mx.broadcast_to(row_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
+            col_idx = mx.broadcast_to(col_idx, (merged_h, merged_w, merge_size, merge_size)).reshape(-1)
+
+            coords = mx.stack([row_idx, col_idx], axis=-1)
+
+            if num_frames > 1:
+                coords = mx.tile(coords, (num_frames, 1))
+
+            pos_ids_parts.append(coords)
+
+        # Concatenate all coordinate parts
+        pos_ids = mx.concatenate(pos_ids_parts, axis=0)
+
+        embeddings = freq_table[pos_ids]  # lookup rotary embeddings
+        embeddings = embeddings.reshape(embeddings.shape[0], -1)
+        return embeddings
+
+    def fast_pos_embed_interpolate(self, grid_thw: mx.array):
+        patch_pos_embeds = []
+
+        for i in range(grid_thw.shape[0]):
+            t = int(grid_thw[i, 0].item())
+            h = int(grid_thw[i, 1].item())
+            w = int(grid_thw[i, 2].item())
+
+            # Simple position embedding interpolation
+            h_idxs = mx.linspace(0, self.num_grid_per_side - 1, h)
+            w_idxs = mx.linspace(0, self.num_grid_per_side - 1, w)
+
+            h_idxs_floor = mx.floor(h_idxs).astype(mx.int32)
+            w_idxs_floor = mx.floor(w_idxs).astype(mx.int32)
+            h_idxs_ceil = mx.minimum(h_idxs_floor + 1, self.num_grid_per_side - 1)
+            w_idxs_ceil = mx.minimum(w_idxs_floor + 1, self.num_grid_per_side - 1)
+
+            dh = h_idxs - h_idxs_floor.astype(mx.float32)
+            dw = w_idxs - w_idxs_floor.astype(mx.float32)
+
+            base_h = h_idxs_floor * self.num_grid_per_side
+            base_h_ceil = h_idxs_ceil * self.num_grid_per_side
+
+            # Compute bilinear interpolation indices and weights
+            indices_tl = (base_h[:, None] + w_idxs_floor[None, :]).reshape(-1)
+            indices_tr = (base_h[:, None] + w_idxs_ceil[None, :]).reshape(-1)
+            indices_bl = (base_h_ceil[:, None] + w_idxs_floor[None, :]).reshape(-1)
+            indices_br = (base_h_ceil[:, None] + w_idxs_ceil[None, :]).reshape(-1)
+
+            weights_tl = ((1 - dh)[:, None] * (1 - dw)[None, :]).reshape(-1)
+            weights_tr = ((1 - dh)[:, None] * dw[None, :]).reshape(-1)
+            weights_bl = (dh[:, None] * (1 - dw)[None, :]).reshape(-1)
+            weights_br = (dh[:, None] * dw[None, :]).reshape(-1)
+
+            # Get embeddings and interpolate
+            pos_embed_tl = self.pos_embed(indices_tl) * weights_tl[:, None]
+            pos_embed_tr = self.pos_embed(indices_tr) * weights_tr[:, None]
+            pos_embed_bl = self.pos_embed(indices_bl) * weights_bl[:, None]
+            pos_embed_br = self.pos_embed(indices_br) * weights_br[:, None]
+
+            pos_embed = pos_embed_tl + pos_embed_tr + pos_embed_bl + pos_embed_br
+
+            # Repeat for temporal dimension and apply spatial merging
+            pos_embed = mx.tile(pos_embed, (t, 1))
+
+            # Apply spatial merging pattern
+            merge_size = self.config.spatial_merge_size
+            pos_embed = pos_embed.reshape(t, h // merge_size, merge_size, w // merge_size, merge_size, -1)
+            pos_embed = mx.transpose(pos_embed, (0, 1, 3, 2, 4, 5))
+            pos_embed = pos_embed.reshape(-1, pos_embed.shape[-1])
+
+            patch_pos_embeds.append(pos_embed)
+
+        return mx.concatenate(patch_pos_embeds, axis=0)
417
+
418
+ def __call__(self, hidden_states: mx.array, grid_thw: mx.array) -> Tuple[mx.array, List[mx.array]]:
419
+ hidden_states = self.patch_embed(hidden_states)
420
+
421
+ pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
422
+ hidden_states = hidden_states + pos_embeds
423
+
424
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
425
+ seq_len = hidden_states.shape[0]
426
+
427
+ emb = mx.concatenate([rotary_pos_emb, rotary_pos_emb], axis=-1)
428
+ position_embeddings = (mx.cos(emb), mx.sin(emb))
429
+
430
+ # Create cumulative sequence lengths (following HuggingFace implementation)
431
+ # torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
432
+ seq_lens_per_image = grid_thw[:, 1] * grid_thw[:, 2] # h * w for each image
433
+ seq_lens = []
434
+ for per_image_len, repeats in zip(seq_lens_per_image, grid_thw[:, 0]):
435
+ seq_lens.extend([per_image_len] * int(repeats))
436
+ seq_lens = mx.array(seq_lens)
437
+
438
+ # Then compute cumulative sum
439
+ cu_seqlens = mx.cumsum(seq_lens)
440
+ # Pad with 0 at the beginning
441
+ cu_seqlens = mx.concatenate([mx.array([0]), cu_seqlens])
442
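+ # cu_seqlens marks where each image's tokens start and end, letting the attention
+ # blocks keep attention within a single image (block-diagonal across images).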
+
443
+ deepstack_feature_lists = []
444
+ for layer_num, blk in enumerate(self.blocks):
445
+ hidden_states = blk(
446
+ hidden_states,
447
+ cu_seqlens=cu_seqlens,
448
+ position_embeddings=position_embeddings,
449
+ )
450
+ if layer_num in self.deepstack_visual_indexes:
451
+ idx = self.deepstack_visual_indexes.index(layer_num)
452
+ deepstack_feature = self.deepstack_merger_list[idx](hidden_states)
453
+ deepstack_feature_lists.append(deepstack_feature)
454
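+ # Deepstack features are intermediate-layer outputs passed through dedicated mergers;
+ # they are later injected into the early language-model layers (see TextModel._deepstack_process).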
+
455
+ hidden_states = self.merger(hidden_states)
456
+ return hidden_states, deepstack_feature_lists
457
+
458
+
459
+ class TextRotaryEmbedding(nn.Module):
460
+ def __init__(self, config: TextConfig):
461
+ super().__init__()
462
+ self.config = config
463
+ self.max_seq_len_cached = config.max_position_embeddings
464
+ self.original_max_seq_len = config.max_position_embeddings
465
+
466
+ # MRoPE configuration
467
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
468
+ self.rope_type = config.rope_scaling.get("rope_type", "default")
469
+ self.mrope_section = config.rope_scaling.get("mrope_section", [24, 20, 20])
470
+ else:
471
+ self.rope_type = "default"
472
+ self.mrope_section = [24, 20, 20]
473
+
474
+ # Store parameters for computing inv_freq on the fly
475
+ self.head_dim = config.head_dim
476
+ self.theta = config.rope_theta
477
+
478
+ # Attention scaling (simplified - may need adjustment based on actual config)
479
+ self.attention_scaling = 1.0
480
+
481
+ def _get_inv_freq(self):
482
+ """Compute inverse frequencies on the fly"""
483
+ inv_freq = 1.0 / (self.theta ** (mx.arange(0, self.head_dim, 2).astype(mx.float32) / self.head_dim))
484
+ # Expand for 3 dimensions (T, H, W)
485
+ return mx.broadcast_to(inv_freq[None, :], (3, len(inv_freq)))
486
+
487
+ def apply_interleaved_mrope(self, freqs, mrope_section):
488
+ """Apply interleaved MRoPE to 3D rotary embeddings.
489
+ Reorganizes the frequency layout from chunked [TTT...HHH...WWW] to
490
+ interleaved [THWTHWTHW...TT], preserving frequency continuity.
491
+ args:
492
+ freqs: (3, bs, seq_len, head_dim // 2) rotary frequencies for the T, H, W axes
493
+ mrope_section: (3,) number of frequency slots assigned to T, H, W
494
+ returns:
495
+ freqs_t: (bs, seq_len, head_dim // 2) interleaved frequencies
496
+ """
497
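+ # Example with the default mrope_section = [24, 20, 20] (head_dim // 2 = 64 slots):
+ # slots 0, 3, 6, ... keep T frequencies, slots 1, 4, ... (20 of them) take H,
+ # slots 2, 5, ... (20 of them) take W, and the trailing slots remain T.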
+ freqs_t = freqs[0]  # start from the T frequencies; H and W slots are overwritten below
498
+ for dim, offset in enumerate((1, 2), start=1): # H, W
499
+ length = mrope_section[dim] * 3
500
+ idx = slice(offset, length, 3)
501
+ freqs_t[..., idx] = freqs[dim, ..., idx]
502
+ return freqs_t
503
+
504
+ def __call__(self, x: mx.array, position_ids: mx.array) -> mx.array:
505
+ """
506
+ Args:
507
+ x: Input tensor for dtype reference
508
+ position_ids: Position indices, shape (3, batch_size, seq_len) for MRoPE
509
+
510
+ Returns:
511
+ cos, sin: Cosine and sine embeddings
512
+ """
513
+ # Handle 2D position_ids by expanding to 3D for MRoPE
514
+ if position_ids.ndim == 2:
515
+ position_ids = mx.broadcast_to(position_ids[None, ...], (3, position_ids.shape[0], position_ids.shape[1]))
516
+
517
+ batch_size, seq_len = position_ids.shape[1], position_ids.shape[2]
518
+
519
+ # Expand inverse frequencies: (3, 1, 1, dim//2) -> (3, batch_size, 1, dim//2)
520
+ inv_freq_expanded = mx.broadcast_to(
521
+ self._get_inv_freq()[:, None, None, :],
522
+ (3, batch_size, 1, self._get_inv_freq().shape[-1])
523
+ )
524
+
525
+ # Expand position ids: (3, batch_size, seq_len) -> (3, batch_size, seq_len, 1)
526
+ position_ids_expanded = position_ids[..., None].astype(mx.float32)
527
+
528
+ # Compute frequencies: (3, batch_size, seq_len, dim//2)
529
+ freqs = inv_freq_expanded * position_ids_expanded
530
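+ # Each of the three position-id planes (T, H, W) yields its own frequency set here;
+ # the interleave below merges them into a single plane per token.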
+
531
+ # Apply interleaved MRoPE
532
+ freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
533
+
534
+ # Create embeddings
535
+ emb = mx.concatenate([freqs, freqs], axis=-1) # (batch_size, seq_len, head_dim)
536
+ cos = mx.cos(emb) * self.attention_scaling
537
+ sin = mx.sin(emb) * self.attention_scaling
538
+
539
+ return cos.astype(x.dtype), sin.astype(x.dtype)
540
+
541
+
542
+ class TextAttention(nn.Module):
543
+ def __init__(self, config: TextConfig, layer_idx: int):
544
+ super().__init__()
545
+ self.config = config
546
+ self.layer_idx = layer_idx
547
+
548
+ dim = config.hidden_size
549
+ self.n_heads = config.num_attention_heads
550
+ self.n_kv_heads = config.num_key_value_heads
551
+ self.head_dim = config.head_dim
552
+ self.scale = self.head_dim ** -0.5
553
+
554
+ self.q_proj = nn.Linear(dim, self.n_heads * self.head_dim, bias=config.attention_bias)
555
+ self.k_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
556
+ self.v_proj = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=config.attention_bias)
557
+ self.o_proj = nn.Linear(self.n_heads * self.head_dim, dim, bias=config.attention_bias)
558
+
559
+ self.q_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
560
+ self.k_norm = nn.RMSNorm(self.head_dim, eps=config.rms_norm_eps)
561
+
562
+ # Initialize rope directly
563
+ self.rope = initialize_rope(
564
+ config.head_dim,
565
+ base=config.rope_theta,
566
+ traditional=False,
567
+ scaling_config=config.rope_scaling,
568
+ max_position_embeddings=config.max_position_embeddings,
569
+ )
570
+
571
+ def __call__(
572
+ self,
573
+ hidden_states: mx.array,
574
+ attention_mask: Optional[mx.array] = None,
575
+ cache: Optional[Any] = None,
576
+ cos: Optional[mx.array] = None,
577
+ sin: Optional[mx.array] = None,
578
+ rope_deltas: Optional[mx.array] = None,
579
+ ) -> Tuple[mx.array, Optional[mx.array]]:
580
+ B, L, D = hidden_states.shape
581
+
582
+ queries = self.q_proj(hidden_states).reshape(B, L, self.n_heads, -1)
583
+ keys = self.k_proj(hidden_states).reshape(B, L, self.n_kv_heads, -1)
584
+ values = self.v_proj(hidden_states).reshape(B, L, self.n_kv_heads, -1)
585
+
586
+ queries = self.q_norm(queries).transpose(0, 2, 1, 3)
587
+ keys = self.k_norm(keys).transpose(0, 2, 1, 3)
588
+ values = values.transpose(0, 2, 1, 3)
589
+
590
+ # Apply rope directly to queries and keys
591
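+ # When MRoPE cos/sin built from 3-D position ids are supplied, apply them directly;
+ # otherwise fall back to plain 1-D RoPE with the cache offset shifted by rope_deltas.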
+ if cos is not None and sin is not None:
592
+ queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)
593
+ if cache is not None:
594
+ keys, values = cache.update_and_fetch(keys, values)
595
+ else:
596
+ if cache is not None:
597
+ # Handle different types of rope_deltas: scalar, array, or None
598
+ if rope_deltas is None:
599
+ offset_delta = 0
600
+ elif isinstance(rope_deltas, (int, float)):
601
+ # rope_deltas is a scalar
602
+ offset_delta = rope_deltas
603
+ elif hasattr(rope_deltas, 'size') and rope_deltas.size == 1:
604
+ # rope_deltas is an array with single element
605
+ offset_delta = rope_deltas.item()
606
+ elif hasattr(rope_deltas, 'shape') and rope_deltas.shape:
607
+ # rope_deltas is an array with multiple elements, take first
608
+ offset_delta = rope_deltas.reshape(-1)[0].item()
609
+ else:
610
+ offset_delta = 0
611
+
612
+ queries = self.rope(queries, offset=cache.offset+offset_delta)
613
+ keys = self.rope(keys, offset=cache.offset+offset_delta)
614
+ keys, values = cache.update_and_fetch(keys, values)
615
+ else:
616
+ queries = self.rope(queries)
617
+ keys = self.rope(keys)
618
+
619
+ output = scaled_dot_product_attention(
620
+ queries, keys, values, cache=cache, scale=self.scale, mask=attention_mask
621
+ )
622
+ output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
623
+ return self.o_proj(output), None
624
+
625
+
626
+ class TextMLP(nn.Module):
627
+ def __init__(self, config: TextConfig):
628
+ super().__init__()
629
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
630
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
631
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
632
+
633
+ def __call__(self, x):
634
+ return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
635
+
636
+
637
+ # Custom MoE block built on top of the optimized SwitchGLU expert kernel
638
+
639
+ class TextMoEExperts(nn.Module):
640
+ def __init__(self, config: TextConfig):
641
+ super().__init__()
642
+ # Use the optimized SwitchGLU implementation for efficient expert computation
643
+ self.switch_glu = SwitchGLU(
644
+ input_dims=config.hidden_size,
645
+ hidden_dims=config.moe_intermediate_size,
646
+ num_experts=config.num_experts,
647
+ activation=nn.SiLU(),
648
+ bias=False
649
+ )
650
+
651
+ def __call__(self, hidden_states: mx.array, routing_weights: mx.array, router_indices: mx.array) -> mx.array:
652
+ # Use the efficient SwitchGLU implementation
653
+ # SwitchGLU handles the expert routing internally and is highly optimized
654
+ expert_output = self.switch_glu(hidden_states, router_indices)
655
+
656
+ # Apply routing weights and sum over experts (top_k dimension)
657
+ weighted_output = expert_output * mx.expand_dims(routing_weights, -1)
658
+ final_output = mx.sum(weighted_output, axis=-2)
659
+
660
+ return final_output
661
+
662
+ class TextSparseMoeBlock(nn.Module):
663
+ def __init__(self, config: TextConfig):
664
+ super().__init__()
665
+ self.hidden_size = config.hidden_size
666
+ self.num_experts = config.num_experts
667
+ self.top_k = config.num_experts_per_tok
668
+ self.norm_topk_prob = config.norm_topk_prob
669
+
670
+ self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
671
+ self.experts = TextMoEExperts(config)
672
+
673
+ def __call__(self, x: mx.array) -> mx.array:
674
+ batch_size, sequence_length, hidden_dim = x.shape
675
+ x_flat = x.reshape(-1, hidden_dim)
676
+
677
+ # Router computation
678
+ router_logits = self.gate(x_flat)
679
+ routing_weights = mx.softmax(router_logits, axis=-1, precise=True)
680
+
681
+ # Top-k selection
682
+ router_indices = mx.argpartition(-routing_weights, kth=self.top_k - 1, axis=-1)[..., :self.top_k]
683
+ routing_weights = mx.take_along_axis(routing_weights, router_indices, axis=-1)
684
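+ # routing_weights now holds the top_k gate probabilities per token, aligned with router_indices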
+
685
+ if self.norm_topk_prob:
686
+ routing_weights = routing_weights / mx.sum(routing_weights, axis=-1, keepdims=True)
687
+
688
+ # Expert computation
689
+ final_hidden_states = self.experts(x_flat, routing_weights, router_indices)
690
+
691
+ return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
692
+
693
+
694
+ class TextDecoderLayer(nn.Module):
695
+ def __init__(self, config: TextConfig, layer_idx: int):
696
+ super().__init__()
697
+ self.hidden_size = config.hidden_size
698
+ self.self_attn = TextAttention(config, layer_idx)
699
+
700
+ # Determine if this layer should use MoE
701
+ use_moe = (
702
+ layer_idx not in config.mlp_only_layers and
703
+ config.num_experts > 0 and
704
+ (layer_idx + 1) % config.decoder_sparse_step == 0
705
+ )
706
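+ # With decoder_sparse_step == 1 every layer outside mlp_only_layers is sparse (MoE);
+ # larger steps interleave dense MLP layers between the sparse ones.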
+
707
+ if use_moe:
708
+ self.mlp = TextSparseMoeBlock(config)
709
+ else:
710
+ self.mlp = TextMLP(config)
711
+
712
+ self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
713
+ self.post_attention_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
714
+
715
+ def __call__(
716
+ self,
717
+ hidden_states: mx.array,
718
+ attention_mask: Optional[mx.array] = None,
719
+ cache: Optional[Any] = None,
720
+ cos: Optional[mx.array] = None,
721
+ sin: Optional[mx.array] = None,
722
+ rope_deltas: Optional[mx.array] = None,
723
+ ) -> mx.array:
724
+ residual = hidden_states
725
+ hidden_states = self.input_layernorm(hidden_states)
726
+
727
+ hidden_states, _ = self.self_attn(
728
+ hidden_states=hidden_states,
729
+ attention_mask=attention_mask,
730
+ cache=cache,
731
+ cos=cos,
732
+ sin=sin,
733
+ rope_deltas=rope_deltas,
734
+ )
735
+ hidden_states = residual + hidden_states
736
+ residual = hidden_states
737
+ hidden_states = self.post_attention_layernorm(hidden_states)
738
+ hidden_states = self.mlp(hidden_states)
739
+ hidden_states = residual + hidden_states
740
+ return hidden_states
741
+
742
+
743
+ class TextModel(nn.Module):
744
+ def __init__(self, config: TextConfig):
745
+ super().__init__()
746
+ self.config = config
747
+ self.vocab_size = config.vocab_size
748
+
749
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
750
+ self.layers = [
751
+ TextDecoderLayer(config, layer_idx)
752
+ for layer_idx in range(config.num_hidden_layers)
753
+ ]
754
+ self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
755
+ self.rotary_emb = TextRotaryEmbedding(config)
756
+
757
+ def _deepstack_process(
758
+ self,
759
+ hidden_states: mx.array,
760
+ visual_pos_masks: mx.array,
761
+ deepstack_visual_embeds: mx.array,
762
+ ) -> mx.array:
763
+ if visual_pos_masks is None or deepstack_visual_embeds is None:
764
+ return hidden_states
765
+ B, L, D = hidden_states.shape
766
+ mask_flat = visual_pos_masks.astype(mx.int32).reshape(-1)
767
+ idx_flat = mx.cumsum(mask_flat, axis=0) - 1
768
+ N = deepstack_visual_embeds.shape[0]
769
+ idx_flat = mx.maximum(idx_flat, 0)
770
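+ # One-hot selection: eq is a (B*L, N) matrix whose matmul picks the visual embedding
+ # assigned to each visual position; mask_flat then zeroes the rows of non-visual tokens.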
+ eq = (idx_flat[:, None] == mx.arange(N)[None, :]).astype(hidden_states.dtype)
771
+ add_flat = eq @ deepstack_visual_embeds.astype(hidden_states.dtype)
772
+ add_flat = add_flat * mask_flat[:, None].astype(hidden_states.dtype)
773
+ add = add_flat.reshape(B, L, D)
774
+ return hidden_states + add
775
+
776
+ def __call__(
777
+ self,
778
+ input_ids: Optional[mx.array] = None,
779
+ inputs_embeds: Optional[mx.array] = None,
780
+ attention_mask: Optional[mx.array] = None,
781
+ cache=None,
782
+ visual_pos_masks: Optional[mx.array] = None,
783
+ deepstack_visual_embeds: Optional[List[mx.array]] = None,
784
+ cos: Optional[mx.array] = None,
785
+ sin: Optional[mx.array] = None,
786
+ rope_deltas: Optional[mx.array] = None,
787
+ ):
788
+ if inputs_embeds is None:
789
+ inputs_embeds = self.embed_tokens(input_ids)
790
+
791
+ hidden_states = inputs_embeds
792
+
793
+ if attention_mask is None:
794
+ attention_mask = create_attention_mask(hidden_states, cache, return_array=True)
795
+
796
+ if cache is None:
797
+ cache = [None] * len(self.layers)
798
+
799
+ for layer_idx, (decoder_layer, c) in enumerate(zip(self.layers, cache)):
800
+ hidden_states = decoder_layer(
801
+ hidden_states,
802
+ attention_mask=attention_mask,
803
+ cache=c,
804
+ cos=cos,
805
+ sin=sin,
806
+ rope_deltas=rope_deltas,
807
+ )
808
+ if deepstack_visual_embeds is not None and layer_idx < len(deepstack_visual_embeds):
809
+ hidden_states = self._deepstack_process(hidden_states, visual_pos_masks, deepstack_visual_embeds[layer_idx])
810
+ hidden_states = self.norm(hidden_states)
811
+ return hidden_states
812
+
813
+
814
+ # Standalone Vision Model
815
+ class VEGModel(nn.Module):
816
+ def __init__(self, vision_config: VisionConfig):
817
+ super().__init__()
818
+ self.config = vision_config
819
+ self.visual = VisionModel(vision_config)
820
+
821
+ def __call__(self, pixel_values: mx.array, image_grid_thw: mx.array):
822
+ return self.visual(pixel_values, image_grid_thw)
823
+
824
+ def sanitize(self, weights):
825
+ sanitized = {}
826
+ for k, v in weights.items():
827
+ if 'visual.' in k:
828
+ # Remove prefixes to match our model structure
829
+ clean_key = k.replace('model.visual.', '').replace('visual.', '')
830
+ sanitized[f'visual.{clean_key}'] = v
831
+ return sanitized
832
+
833
+
834
+ # Pure LLM Model (no vision components)
835
+ class LLMModel(nn.Module):
836
+ def __init__(self, text_config: TextConfig):
837
+ super().__init__()
838
+ self.args = text_config
839
+ self.config = text_config
840
+ self.language_model = TextModel(text_config)
841
+ if not text_config.tie_word_embeddings:
842
+ self.lm_head = nn.Linear(text_config.hidden_size, text_config.vocab_size, bias=False)
843
+
844
+ def get_rope_index(
845
+ self,
846
+ input_ids: Optional[mx.array] = None,
847
+ image_grid_thw: Optional[mx.array] = None,
848
+ attention_mask: Optional[mx.array] = None,
849
+ ) -> Tuple[mx.array, mx.array]:
850
+ """Simplified version for images only (no video support)."""
851
+
852
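+ # MRoPE indexing: text tokens share the same position index on all three axes, while
+ # image tokens take T from the frame index and H/W from the merged patch grid.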
+ spatial_merge_size = 2
853
+ image_token_id = 151655
854
+ vision_start_token_id = 151652
855
+ mrope_position_deltas = []
856
+
857
+ if input_ids is not None and image_grid_thw is not None:
858
+ total_input_ids = input_ids
859
+ if attention_mask is None:
860
+ attention_mask = mx.ones_like(total_input_ids)
861
+
862
+ batch_size, seq_len = input_ids.shape
863
+ position_ids_list = []
864
+ image_index = 0
865
+
866
+ for i in range(batch_size):
867
+ input_ids_seq = total_input_ids[i]
868
+ mask_seq = attention_mask[i]
869
+
870
+ # Use mask to get valid length
871
+ valid_length = int(mx.sum(mask_seq).item())
872
+ input_ids_seq = input_ids_seq[:valid_length]
873
+
874
+ image_nums = 0
875
+ # Find vision start tokens by iterating through the sequence
876
+ vision_start_positions = []
877
+ for pos in range(input_ids_seq.shape[0]):
878
+ if input_ids_seq[pos].item() == vision_start_token_id:
879
+ vision_start_positions.append(pos)
880
+
881
+ if len(vision_start_positions) > 0:
882
+ for pos in vision_start_positions:
883
+ if pos + 1 < input_ids_seq.shape[0]:
884
+ if input_ids_seq[pos + 1].item() == image_token_id:
885
+ image_nums += 1
886
+
887
+ input_tokens = input_ids_seq.tolist()
888
+ llm_pos_ids_list = []
889
+ st = 0
890
+ remain_images = image_nums
891
+
892
+ for _ in range(image_nums):
893
+ ed_image = input_tokens.index(image_token_id, st)
894
+
895
+ t = image_grid_thw[image_index, 0].item()
896
+ h = image_grid_thw[image_index, 1].item()
897
+ w = image_grid_thw[image_index, 2].item()
898
+ image_index += 1
899
+ remain_images -= 1
900
+ ed = ed_image
901
+
902
+ llm_grid_t = int(t)
903
+ llm_grid_h = int(h) // spatial_merge_size
904
+ llm_grid_w = int(w) // spatial_merge_size
905
+ text_len = ed - st
906
+
907
+ st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
908
+ text_pos = mx.arange(text_len).reshape(1, -1)
909
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
910
+ llm_pos_ids_list.append(text_pos)
911
+
912
+ # t_index is always 0 because llm_grid_t is always 1 for images
913
+ t_index = mx.arange(llm_grid_t).reshape(-1, 1)
914
+ t_index = mx.broadcast_to(t_index, (llm_grid_t, llm_grid_h * llm_grid_w)).reshape(-1)
915
+
916
+ h_index = mx.arange(llm_grid_h).reshape(1, -1, 1)
917
+ h_index = mx.broadcast_to(h_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
918
+
919
+ w_index = mx.arange(llm_grid_w).reshape(1, 1, -1)
920
+ w_index = mx.broadcast_to(w_index, (llm_grid_t, llm_grid_h, llm_grid_w)).reshape(-1)
921
+
922
+ vision_pos = mx.stack([t_index, h_index, w_index]) + text_len + st_idx
923
+ llm_pos_ids_list.append(vision_pos)
924
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
925
+
926
+ if st < len(input_tokens):
927
+ st_idx = llm_pos_ids_list[-1].max().item() + 1 if len(llm_pos_ids_list) > 0 else 0
928
+ text_len = len(input_tokens) - st
929
+ text_pos = mx.arange(text_len).reshape(1, -1)
930
+ text_pos = mx.broadcast_to(text_pos, (3, text_len)) + st_idx
931
+ llm_pos_ids_list.append(text_pos)
932
+
933
+ llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
934
+
935
+ # Create position_ids for this batch item, pad to seq_len
936
+ batch_position_ids = mx.ones((3, seq_len), dtype=input_ids.dtype)
937
+ valid_length = min(seq_len, llm_positions.shape[1])
938
+
939
+ # Create new arrays for each dimension
940
+ pos_dim0 = mx.concatenate([llm_positions[0, :valid_length],
941
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
942
+ pos_dim1 = mx.concatenate([llm_positions[1, :valid_length],
943
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
944
+ pos_dim2 = mx.concatenate([llm_positions[2, :valid_length],
945
+ mx.ones(seq_len - valid_length, dtype=input_ids.dtype)])
946
+
947
+ batch_position_ids = mx.stack([pos_dim0, pos_dim1, pos_dim2])
948
+ position_ids_list.append(batch_position_ids)
949
+
950
+ mrope_position_deltas.append(llm_positions.max().item() + 1 - len(total_input_ids[i]))
951
+
952
+ # Stack all batch position_ids
953
+ position_ids = mx.stack(position_ids_list, axis=1) # Shape: (3, batch_size, seq_len)
954
+ mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1)
955
+ return position_ids, mrope_position_deltas
956
+ else:
957
+ if attention_mask is not None:
958
+ position_ids = mx.cumsum(attention_mask.astype(mx.int32), axis=-1) - 1
959
+ position_ids = mx.where(attention_mask == 0, 1, position_ids)
960
+ position_ids = mx.expand_dims(position_ids, axis=0)
961
+ position_ids = mx.broadcast_to(position_ids, (3, position_ids.shape[1], position_ids.shape[2]))
962
+ max_position_ids = mx.max(mx.max(position_ids, axis=0, keepdims=False), axis=-1, keepdims=True)
963
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
964
+ else:
965
+ seq_len = input_ids.shape[1]
966
+ batch_size = input_ids.shape[0]
967
+ position_ids = mx.arange(seq_len).reshape(1, 1, -1)
968
+ position_ids = mx.broadcast_to(position_ids, (3, batch_size, seq_len))
969
+ mrope_position_deltas = mx.zeros((batch_size, 1), dtype=input_ids.dtype)
970
+
971
+ return position_ids, mrope_position_deltas
972
+
973
+ def __call__(
974
+ self,
975
+ inputs: mx.array = None,
976
+ mask: mx.array = None,
977
+ cache=None,
978
+ inputs_embeds: Optional[mx.array] = None,
979
+ visual_pos_masks: Optional[mx.array] = None,
980
+ deepstack_visual_embeds: Optional[List[mx.array]] = None,
981
+ cos: Optional[mx.array] = None,
982
+ sin: Optional[mx.array] = None,
983
+ rope_deltas: Optional[mx.array] = None,
984
+ ):
985
+ out = self.language_model(
986
+ input_ids=inputs,
987
+ inputs_embeds=inputs_embeds,
988
+ attention_mask=mask,
989
+ cache=cache,
990
+ visual_pos_masks=visual_pos_masks,
991
+ deepstack_visual_embeds=deepstack_visual_embeds,
992
+ cos=cos,
993
+ sin=sin,
994
+ rope_deltas=rope_deltas,
995
+ )
996
+ if self.args.tie_word_embeddings:
997
+ return self.language_model.embed_tokens.as_linear(out)
998
+ else:
999
+ return self.lm_head(out)
1000
+
1001
+ def sanitize(self, weights):
1002
+ sanitized = {}
1003
+ for k, v in weights.items():
1004
+ if not ('visual.' in k):
1005
+ # Handle key mapping from combined model to LLM-only model
1006
+ clean_key = k
1007
+
1008
+ # Remove model. prefix if present
1009
+ if clean_key.startswith('model.'):
1010
+ clean_key = clean_key[6:] # Remove 'model.'
1011
+
1012
+ # Map language_ prefixed keys to language_model structure
1013
+ if clean_key.startswith('language_'):
1014
+ if clean_key.startswith('language_layers.'):
1015
+ clean_key = 'language_model.layers.' + clean_key[16:] # Map to language_model.layers.
1016
+ elif clean_key.startswith('language_embed_tokens.'):
1017
+ clean_key = 'language_model.embed_tokens.' + clean_key[22:] # Map to language_model.embed_tokens.
1018
+ elif clean_key.startswith('language_norm.'):
1019
+ clean_key = 'language_model.norm.' + clean_key[14:] # Map to language_model.norm.
1020
+
1021
+ sanitized[clean_key] = v
1022
+
1023
+ # Handle tied embeddings - remove lm_head if using tied embeddings
1024
+ if self.args.tie_word_embeddings:
1025
+ sanitized.pop("lm_head.weight", None)
1026
+
1027
+ return sanitized
1028
+
1029
+ @property
1030
+ def layers(self):
1031
+ return self.language_model.layers
1032
+
1033
+
1034
+ # Combined Model (for compatibility and utility functions)
1035
+ class Qwen3VLModel(nn.Module):
1036
+ def __init__(self, args: ModelArgs):
1037
+ super().__init__()
1038
+ self.args = args
1039
+ self.config = args
1040
+ self.visual = VisionModel(args.vision_config)
1041
+ self.language_model = TextModel(args.text_config)
1042
+
1043
+ def sanitize(self, weights):
1044
+ # Map weights to match the combined model structure
1045
+ sanitized = {}
1046
+ for k, v in weights.items():
1047
+ # Remove 'model.' prefix if present to match our structure
1048
+ clean_key = k[len('model.'):] if k.startswith('model.') else k  # strip only the leading prefix; keys may also contain 'language_model.'
1049
+ sanitized[clean_key] = v
1050
+ return sanitized
1051
+
1052
+ def get_image_features(
1053
+ self,
1054
+ pixel_values: mx.array,
1055
+ image_grid_thw: Optional[mx.array] = None
1056
+ ):
1057
+ image_embeds, deepstack_visual_embeds = self.visual(pixel_values, image_grid_thw)
1058
+ # Split based on grid dimensions
1059
+ if image_grid_thw is not None:
1060
+ split_sizes = (mx.prod(image_grid_thw, axis=-1) // (self.visual.spatial_merge_size ** 2)).tolist()
1061
+ # Convert sizes to indices for mx.split (cumulative sum, excluding the last)
1062
+ split_indices = []
1063
+ cumsum = 0
1064
+ for size in split_sizes[:-1]: # Exclude last element
1065
+ cumsum += size
1066
+ split_indices.append(cumsum)
1067
+
1068
+ if split_indices: # Only split if we have indices
1069
+ image_embeds = mx.split(image_embeds, split_indices)
1070
+ else:
1071
+ image_embeds = [image_embeds] # Single image case
1072
+ return image_embeds, deepstack_visual_embeds
1073
+
1074
+
1075
+ def __call__(
1076
+ self,
1077
+ input_ids: mx.array = None,
1078
+ attention_mask: Optional[mx.array] = None,
1079
+ inputs_embeds: Optional[mx.array] = None,
1080
+ pixel_values: Optional[mx.array] = None,
1081
+ image_grid_thw: Optional[mx.array] = None,
1082
+ cache=None,
1083
+ visual_pos_masks: Optional[mx.array] = None,
1084
+ deepstack_visual_embeds: Optional[List[mx.array]] = None,
1085
+ cos: Optional[mx.array] = None,
1086
+ sin: Optional[mx.array] = None,
1087
+ rope_deltas: Optional[mx.array] = None,
1088
+ ):
1089
+ if inputs_embeds is None:
1090
+ inputs_embeds = self.language_model.embed_tokens(input_ids)
1091
+
1092
+ # Process images
1093
+
1094
+ if pixel_values is not None:
1095
+ image_embeds, deepstack_visual_embeds = self.get_image_features(
1096
+ pixel_values, image_grid_thw
1097
+ )
1098
+
1099
+ # Create masks and embed visual features
1100
+ if isinstance(image_embeds, list):
1101
+ image_embeds = mx.concatenate(image_embeds, axis=0)
1102
+
1103
+ # Find image token positions and replace with visual embeddings
1104
+ image_mask = (input_ids == self.args.image_token_id)
1105
+ visual_pos_masks = image_mask
1106
+
1107
+ # Replace image tokens with visual embeddings
1108
+ inputs_embeds = inputs_embeds.at[image_mask].set(
1109
+ image_embeds.astype(inputs_embeds.dtype)
1110
+ )
1111
+
1112
+
1113
+ outputs = self.language_model(
1114
+ inputs_embeds=inputs_embeds,
1115
+ attention_mask=attention_mask,
1116
+ cache=cache,
1117
+ visual_pos_masks=visual_pos_masks,
1118
+ deepstack_visual_embeds=deepstack_visual_embeds,
1119
+ cos=cos,
1120
+ sin=sin,
1121
+ rope_deltas=rope_deltas,
1122
+ )
1123
+
1124
+ return outputs
1125
+
1126
+
1127
+ def handle_multimodal_embeds(vision_model, llm_model, input_ids, pixel_values, image_grid_thw):
1128
+ """
1129
+ Handle the processing of multimodal embeddings including image features and position encoding.
1130
+
1131
+ This function processes vision and text inputs to create unified embeddings that can be fed
1132
+ into the language model. It handles:
1133
+ - Vision feature extraction from pixel values
1134
+ - Deepstack visual embedding collection
1135
+ - Image token replacement in text embeddings
1136
+ - Position encoding setup for MRoPE (Multi-dimensional RoPE)
1137
+
1138
+ Args:
1139
+ vision_model: The vision encoder model (VEGModel instance)
1140
+ llm_model: The language model (LLMModel instance)
1141
+ input_ids: Tokenized text input with image token placeholders [batch_size, seq_len]
1142
+ pixel_values: Preprocessed image pixel data [num_patches, feature_dim]
1143
+ image_grid_thw: Grid dimensions for each image [num_images, 3] (time, height, width)
1144
+
1145
+ Returns:
1146
+ tuple: (inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas)
1147
+ - inputs_embeds: Combined text and image embeddings [batch_size, seq_len, hidden_size]
1148
+ - deepstack_visual_embeds: Multi-layer visual features for deepstack processing
1149
+ - visual_pos_masks: Boolean mask indicating image token positions
1150
+ - cos: Cosine values for rotary position encoding
1151
+ - sin: Sine values for rotary position encoding
1152
+ - rope_deltas: Position offset deltas for rope computation
1153
+ """
1154
+ inputs_embeds = llm_model.language_model.embed_tokens(input_ids.squeeze(0))
1155
+ deepstack_visual_embeds = None
1156
+ visual_pos_masks = None
1157
+ cos = None
1158
+ sin = None
1159
+ rope_deltas = 0
1160
+
1161
+ if pixel_values is not None:
1162
+ if pixel_values.ndim == 4:
1163
+ pixel_values = mx.expand_dims(pixel_values, axis=2)
1164
+
1165
+ # Process each image individually to prevent feature mixing
1166
+ image_embeds_list = []
1167
+ all_deepstack_embeds = []
1168
+
1169
+ # Calculate cumulative indices for each image
1170
+ cumulative_patches = 0
1171
+
1172
+ for i in range(image_grid_thw.shape[0]):
1173
+ # Calculate number of patches for current image
1174
+ current_patches = int(image_grid_thw[i, 1] * image_grid_thw[i, 2])
1175
+ start_idx = cumulative_patches
1176
+ end_idx = cumulative_patches + current_patches
1177
+ cumulative_patches += current_patches
1178
+
1179
+ single_pixel_values = pixel_values[start_idx:end_idx]
1180
+ single_grid_thw = image_grid_thw[i:i+1]
1181
+
1182
+ # Use vision model directly
1183
+ single_embeds, single_deepstack = vision_model(single_pixel_values, single_grid_thw)
1184
+
1185
+ # Split based on grid dimensions
1186
+ if single_grid_thw is not None:
1187
+ split_sizes = (mx.prod(single_grid_thw, axis=-1) // (vision_model.visual.spatial_merge_size ** 2)).tolist()
1188
+ split_indices = []
1189
+ cumsum = 0
1190
+ for size in split_sizes[:-1]:
1191
+ cumsum += size
1192
+ split_indices.append(cumsum)
1193
+
1194
+ if split_indices:
1195
+ single_embeds = mx.split(single_embeds, split_indices)
1196
+ else:
1197
+ single_embeds = [single_embeds]
1198
+
1199
+ image_embeds_list.extend(single_embeds)
1200
+
1201
+ # Collect deepstack embeddings
1202
+ if i == 0:
1203
+ all_deepstack_embeds = single_deepstack
1204
+ else:
1205
+ # Concatenate deepstack embeddings from different images
1206
+ for j in range(len(all_deepstack_embeds)):
1207
+ all_deepstack_embeds[j] = mx.concatenate([all_deepstack_embeds[j], single_deepstack[j]], axis=0)
1208
+
1209
+ deepstack_visual_embeds = all_deepstack_embeds
1210
+
1211
+ # Concatenate all image embeddings for processing
1212
+ image_embeds = mx.concatenate(image_embeds_list, axis=0)
1213
+
1214
+ # Find all image token positions
1215
+ image_token_id = 151655 # Default image token ID
1216
+ image_mask = (input_ids.squeeze(0) == image_token_id)
1217
+ image_mask_np = np.array(image_mask)
1218
+ image_token_positions = np.where(image_mask_np)[0]
1219
+
1220
+ # Verify we have the correct number of image tokens
1221
+ expected_total_tokens = sum(embed.shape[0] for embed in image_embeds_list)
1222
+ assert len(image_token_positions) == expected_total_tokens, f"Expected {expected_total_tokens} image tokens, got {len(image_token_positions)}"
1223
+
1224
+ # Replace image tokens with image embeddings
1225
+ seq_len = inputs_embeds.shape[0]
1226
+ result = inputs_embeds
1227
+
1228
+ # Replace image tokens with image embeddings sequentially
1229
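+ # Each patch is written with an mx.where over the full sequence: one pass per patch,
+ # which trades speed for avoiding boolean-mask scatter updates.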
+ embed_idx = 0
1230
+ for img_embed in image_embeds_list:
1231
+ for patch_idx in range(img_embed.shape[0]):
1232
+ token_pos = image_token_positions[embed_idx]
1233
+ pos_mask = mx.arange(seq_len) == token_pos
1234
+ result = mx.where(
1235
+ mx.expand_dims(pos_mask, axis=-1),
1236
+ mx.expand_dims(img_embed[patch_idx], axis=0).astype(inputs_embeds.dtype),
1237
+ result
1238
+ )
1239
+ embed_idx += 1
1240
+
1241
+ inputs_embeds = result
1242
+ position_ids, rope_deltas = llm_model.get_rope_index(input_ids, image_grid_thw)
1243
+ cos, sin = llm_model.language_model.rotary_emb(inputs_embeds, position_ids)
1244
+ if inputs_embeds.ndim == 2:
1245
+ inputs_embeds = mx.expand_dims(inputs_embeds, axis=0)
1246
+
1247
+ if image_mask is not None:
1248
+ visual_pos_masks = image_mask
1249
+
1250
+ return inputs_embeds, deepstack_visual_embeds, visual_pos_masks, cos, sin, rope_deltas
1251
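+ # Usage sketch (assumed call pattern, not shown in this file): the returned tuple is fed
+ # to LLMModel.__call__ as inputs_embeds / deepstack_visual_embeds / visual_pos_masks /
+ # cos / sin / rope_deltas for the multimodal prefill step.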
+
1252
+
1253
+ # Legacy Model wrapper (for backward compatibility)
1254
+ class Model(nn.Module):
1255
+ def __init__(self, args: ModelArgs):
1256
+ super().__init__()
1257
+ self.args = args
1258
+ self.model = Qwen3VLModel(args)
1259
+ if not args.text_config.tie_word_embeddings:
1260
+ self.lm_head = nn.Linear(args.text_config.hidden_size, args.text_config.vocab_size, bias=False)
1261
+
1262
+ def __call__(
1263
+ self,
1264
+ inputs: mx.array = None,
1265
+ mask: mx.array = None,
1266
+ cache=None,
1267
+ inputs_embeds: Optional[mx.array] = None,
1268
+ pixel_values: Optional[mx.array] = None,
1269
+ image_grid_thw: Optional[mx.array] = None,
1270
+ visual_pos_masks: Optional[mx.array] = None,
1271
+ deepstack_visual_embeds: Optional[List[mx.array]] = None,
1272
+ cos: Optional[mx.array] = None,
1273
+ sin: Optional[mx.array] = None,
1274
+ rope_deltas: Optional[mx.array] = None,
1275
+ ):
1276
+ out = self.model(
1277
+ input_ids=inputs,
1278
+ inputs_embeds=inputs_embeds,
1279
+ attention_mask=mask,
1280
+ cache=cache,
1281
+ pixel_values=pixel_values,
1282
+ image_grid_thw=image_grid_thw,
1283
+ visual_pos_masks=visual_pos_masks,
1284
+ deepstack_visual_embeds=deepstack_visual_embeds,
1285
+ cos=cos,
1286
+ sin=sin,
1287
+ rope_deltas=rope_deltas,
1288
+ )
1289
+ if self.args.text_config.tie_word_embeddings:
1290
+ return self.model.language_model.embed_tokens.as_linear(out)
1291
+ else:
1292
+ return self.lm_head(out)
1293
+
1294
+ def sanitize(self, weights):
1295
+ # Remove any unnecessary weights
1296
+ sanitized = {}
1297
+ for k, v in weights.items():
1298
+ sanitized[k] = v
1299
+
1300
+ # Handle tied embeddings - remove lm_head if using tied embeddings
1301
+ if self.args.text_config.tie_word_embeddings:
1302
+ sanitized.pop("lm_head.weight", None)
1303
+
1304
+ return sanitized
1305
+
1306
+ @property
1307
+ def layers(self):
1308
+ return self.model.language_model.layers