nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580) hide show
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,239 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import dataclass, field
5
+ from typing import Any, Optional
6
+ from enum import IntEnum
7
+
8
+ # --------------------------------------------------------------------------------------
9
+ # Stop reason constants matching profile.h
10
+ # --------------------------------------------------------------------------------------
11
+
12
class StopReason(IntEnum):
    """Integer codes describing why an operation stopped.

    The values are meant to mirror the stop-reason constants in ``profile.h``
    and therefore must not be renumbered.
    """

    ML_STOP_REASON_UNKNOWN = 0  # default before a reason has been recorded
    ML_STOP_REASON_EOS = 1
    ML_STOP_REASON_LENGTH = 2
    ML_STOP_REASON_USER = 3
    ML_STOP_REASON_STOP_SEQUENCE = 4
    ML_STOP_REASON_COMPLETED = 5
20
+
21
+ # --------------------------------------------------------------------------------------
22
+ # Profiling data structure
23
+ # --------------------------------------------------------------------------------------
24
+
25
@dataclass
class ProfilingData:
    """Performance metrics reported for one profiled operation.

    Durations are integer microseconds; ``tokens_per_second`` is a float.
    """

    # --- timings (microseconds) ---
    ttft_us: int = 0                # time to first token
    total_time_us: int = 0          # total generation time
    prompt_time_us: int = 0         # prompt processing time
    decode_time_us: int = 0         # token generation time
    # --- throughput and token counts ---
    tokens_per_second: float = 0.0  # decoding speed
    total_tokens: int = 0           # total token count
    prompt_tokens: int = 0          # number of prompt tokens
    generated_tokens: int = 0       # number of generated tokens
    # --- outcome ---
    stop_reason: int = StopReason.ML_STOP_REASON_UNKNOWN  # numeric stop reason

    def reset(self):
        """Restore every metric to its initial (zero / unknown) value."""
        self.ttft_us = self.total_time_us = 0
        self.prompt_time_us = self.decode_time_us = 0
        self.tokens_per_second = 0.0
        self.total_tokens = self.prompt_tokens = self.generated_tokens = 0
        self.stop_reason = StopReason.ML_STOP_REASON_UNKNOWN
49
+
50
+ # --------------------------------------------------------------------------------------
51
+ # Profiling context (similar to ml_ProfilingContext in profile.h)
52
+ # --------------------------------------------------------------------------------------
53
+
54
@dataclass
class ProfilingContext:
    """Raw timestamps and counters collected while an operation runs."""

    # Timestamps; each stays None until the matching event is recorded.
    start_time: Optional[float] = None
    prompt_start_time: Optional[float] = None
    prompt_end_time: Optional[float] = None
    decode_start_time: Optional[float] = None
    decode_end_time: Optional[float] = None
    first_token_time: Optional[float] = None
    end_time: Optional[float] = None

    # State flags and counters.
    ttft_recorded: bool = False  # True once first_token_time has been captured
    stop_reason: int = StopReason.ML_STOP_REASON_UNKNOWN
    prompt_tokens: int = 0
    generated_tokens: int = 0

    def reset(self):
        """Clear all timestamps and return flags/counters to their defaults."""
        self.start_time = self.end_time = None
        self.prompt_start_time = self.prompt_end_time = None
        self.decode_start_time = self.decode_end_time = None
        self.first_token_time = None
        self.ttft_recorded = False
        self.stop_reason = StopReason.ML_STOP_REASON_UNKNOWN
        self.prompt_tokens = self.generated_tokens = 0
83
+
84
+ # --------------------------------------------------------------------------------------
85
+ # Profiling functions (similar to profile.h functions)
86
+ # --------------------------------------------------------------------------------------
87
+
88
def profiling_reset(ctx: ProfilingContext) -> None:
    """Clear all state held by ``ctx`` (counterpart of ml_profiling_reset)."""
    # The context owns its own reset logic; this is only the functional form.
    ctx.reset()
91
+
92
def profiling_start(ctx: ProfilingContext) -> None:
    """Stamp the start of a profiled operation (ml_profiling_start).

    The same reading also seeds ``prompt_start_time``.
    """
    started = time.perf_counter()
    ctx.start_time = started
    ctx.prompt_start_time = started
96
+
97
def profiling_prompt_start(ctx: ProfilingContext) -> None:
    """Stamp the beginning of prompt processing (ml_profiling_prompt_start)."""
    now = time.perf_counter()
    ctx.prompt_start_time = now
100
+
101
def profiling_prompt_end(ctx: ProfilingContext) -> None:
    """Stamp the end of prompt processing (ml_profiling_prompt_end)."""
    now = time.perf_counter()
    ctx.prompt_end_time = now
104
+
105
def profiling_decode_start(ctx: ProfilingContext) -> None:
    """Stamp the beginning of the decode phase (ml_profiling_decode_start)."""
    now = time.perf_counter()
    ctx.decode_start_time = now
108
+
109
def profiling_decode_end(ctx: ProfilingContext) -> None:
    """Stamp the end of the decode phase (ml_profiling_decode_end)."""
    now = time.perf_counter()
    ctx.decode_end_time = now
112
+
113
def profiling_record_ttft(ctx: ProfilingContext) -> None:
    """Capture the first-token timestamp once (ml_profiling_record_ttft).

    Does nothing if the timestamp was already captured or if profiling has
    not been started (``start_time`` is still ``None``).
    """
    if ctx.ttft_recorded or ctx.start_time is None:
        return
    ctx.first_token_time = time.perf_counter()
    ctx.ttft_recorded = True
118
+
119
def profiling_update_prompt_tokens(ctx: ProfilingContext, prompt_tokens: int) -> None:
    """Store the prompt token count on ``ctx`` (ml_profiling_update_prompt_tokens).

    The value overwrites, rather than adds to, any previous count.
    """
    ctx.prompt_tokens = prompt_tokens
122
+
123
def profiling_update_generated_tokens(ctx: ProfilingContext, generated_tokens: int) -> None:
    """Store the generated token count on ``ctx`` (ml_profiling_update_generated_tokens).

    The value overwrites, rather than adds to, any previous count.
    """
    ctx.generated_tokens = generated_tokens
126
+
127
def profiling_stop_reason(ctx: ProfilingContext, stop_reason: int) -> None:
    """Record why the operation stopped (ml_profiling_stop_reason).

    The value is stored as given; no range validation is performed here.
    """
    ctx.stop_reason = stop_reason
130
+
131
def profiling_end(ctx: ProfilingContext) -> None:
    """Stamp the end of the profiled operation (ml_profiling_end)."""
    now = time.perf_counter()
    ctx.end_time = now
134
+
135
def profiling_gen_data(ctx: ProfilingContext) -> ProfilingData:
    """Derive a ``ProfilingData`` record from ``ctx`` (ml_profiling_gen_data).

    Returns an all-default record when the operation was never started or
    never ended. Any phase whose start or end timestamp is missing keeps a
    duration of 0.
    """

    def elapsed_us(begin, finish):
        # Seconds -> whole microseconds (truncated toward zero by int()).
        return int((finish - begin) * 1_000_000)

    data = ProfilingData()
    if ctx.start_time is None or ctx.end_time is None:
        return data

    data.total_time_us = elapsed_us(ctx.start_time, ctx.end_time)

    if not (ctx.prompt_start_time is None or ctx.prompt_end_time is None):
        data.prompt_time_us = elapsed_us(ctx.prompt_start_time, ctx.prompt_end_time)

    if not (ctx.decode_start_time is None or ctx.decode_end_time is None):
        data.decode_time_us = elapsed_us(ctx.decode_start_time, ctx.decode_end_time)

    # start_time is known to be set here, so only the first-token stamp matters.
    if ctx.first_token_time is not None:
        data.ttft_us = elapsed_us(ctx.start_time, ctx.first_token_time)

    data.prompt_tokens = ctx.prompt_tokens
    data.generated_tokens = ctx.generated_tokens
    data.total_tokens = ctx.prompt_tokens + ctx.generated_tokens

    # Throughput is only defined when some decode time was measured.
    if data.decode_time_us > 0:
        data.tokens_per_second = (data.generated_tokens * 1_000_000.0) / data.decode_time_us

    data.stop_reason = ctx.stop_reason
    return data
170
+
171
def stop_reason_to_string(reason: int) -> str:
    """Return the ``StopReason`` member name for ``reason``.

    Values that are not valid ``StopReason`` members are rendered as
    ``"UNKNOWN(<value>)"`` instead of raising.
    """
    try:
        label = StopReason(reason).name
    except ValueError:
        label = f"UNKNOWN({reason})"
    return label
177
+
178
+ # --------------------------------------------------------------------------------------
179
+ # Profiling mixin for model classes
180
+ # --------------------------------------------------------------------------------------
181
+
182
class ProfilingMixin:
    """Mixin that gives a model class a profiling context and a result record.

    The host class must run ``ProfilingMixin.__init__`` before calling any of
    the helpers below, since they all read ``self._profiling_context``.
    """

    def __init__(self):
        """Create an empty context and an empty result record."""
        self._profiling_context = ProfilingContext()
        self._profiling_data = ProfilingData()

    def _start_profiling(self) -> None:
        """Reset the context, then stamp the start of a new operation."""
        ctx = self._profiling_context
        profiling_reset(ctx)
        profiling_start(ctx)

    def _prompt_start(self) -> None:
        """Stamp the beginning of prompt processing."""
        profiling_prompt_start(self._profiling_context)

    def _prompt_end(self) -> None:
        """Stamp the end of prompt processing."""
        profiling_prompt_end(self._profiling_context)

    def _decode_start(self) -> None:
        """Stamp the beginning of the decode phase."""
        profiling_decode_start(self._profiling_context)

    def _decode_end(self) -> None:
        """Stamp the end of the decode phase."""
        profiling_decode_end(self._profiling_context)

    def _record_ttft(self) -> None:
        """Capture the first-token timestamp (first call only)."""
        profiling_record_ttft(self._profiling_context)

    def _update_prompt_tokens(self, prompt_tokens: int) -> None:
        """Store the prompt token count."""
        profiling_update_prompt_tokens(self._profiling_context, prompt_tokens)

    def _update_generated_tokens(self, generated_tokens: int) -> None:
        """Store the generated token count."""
        profiling_update_generated_tokens(self._profiling_context, generated_tokens)

    def _set_stop_reason(self, stop_reason: int) -> None:
        """Record why the operation stopped."""
        profiling_stop_reason(self._profiling_context, stop_reason)

    def _end_profiling(self) -> ProfilingData:
        """Stamp the end time, compute the metrics, cache and return them."""
        ctx = self._profiling_context
        profiling_end(ctx)
        result = profiling_gen_data(ctx)
        self._profiling_data = result
        return result

    def get_profiling_data(self) -> ProfilingData:
        """Return the metrics cached by the most recent ``_end_profiling``."""
        return self._profiling_data

    def reset_profiling(self) -> None:
        """Zero the cached metrics; the timing context is left untouched."""
        self._profiling_data.reset()
File without changes
@@ -0,0 +1,174 @@
1
+ # Copyright © Nexa AI
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import sys
16
+ import os
17
+ import mlx.core as mx
18
+ import mlx.nn as nn
19
+ import numpy as np
20
+ import time
21
+
22
+ from transformers import AutoTokenizer
23
+ from huggingface_hub import snapshot_download
24
+ from .modeling.nexa_jina_rerank import Model, ModelArgs
25
+
26
+
27
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
    """Build position ids that skip padding tokens.

    Non-padding tokens receive a running count along axis 1 (plus
    ``past_key_values_length``), offset by ``padding_idx``; padding tokens
    receive exactly ``padding_idx``.
    """
    not_pad = (input_ids != padding_idx).astype(mx.int32)
    running = mx.cumsum(not_pad, axis=1) + past_key_values_length
    # Multiplying by the mask zeroes the count at padded slots.
    positions = (running * not_pad).astype(mx.int32)
    return positions + padding_idx
32
+
33
+
34
def prepare_inputs(query, documents, tokenizer, max_length=1024):
    """Tokenize (query, document) pairs and build the model's input arrays.

    Every pair is padded/truncated to ``max_length``.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, position_ids)`` where
        ``token_type_ids`` is all zeros with the same shape as ``input_ids``.
    """
    pairs = [[query, doc] for doc in documents]
    encoded = tokenizer(
        pairs,
        padding="max_length",
        truncation=True,
        return_tensors="np",
        max_length=max_length,
    )

    input_ids = mx.array(encoded["input_ids"]).astype(mx.int32)
    attention_mask = mx.array(encoded["attention_mask"]).astype(mx.float32)
    batch_size, seqlen = input_ids.shape[0], input_ids.shape[1]

    # One zero row, broadcast to every sequence in the batch.
    zero_row = mx.expand_dims(mx.zeros(seqlen, dtype=mx.int32), axis=0)
    token_type_ids = mx.broadcast_to(zero_row, (batch_size, seqlen))

    # NOTE(review): padding_idx=1 assumes the tokenizer's pad id is 1 — confirm.
    position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=1)

    return input_ids, attention_mask, token_type_ids, position_ids
60
+
61
+
62
def load_model(model_id):
    """Download (if needed) and load the Jina V2 rerank model.

    Args:
        model_id: Hugging Face repo id to download the weights from.

    Returns:
        ``(model, model_dir)``: the loaded model in eval mode and the local
        directory holding its files.

    Raises:
        FileNotFoundError: if no ``.safetensors`` file exists after the
            download step.
        Exception: any download error is reported and re-raised.
    """
    curr_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = f"{curr_dir}/modelfiles/nexaml_jina_v2_rerank_mlx"

    def _weight_files():
        # Sorted so the fallback choice below does not depend on listdir order.
        if not os.path.isdir(model_dir):
            return []
        return sorted(f for f in os.listdir(model_dir) if f.endswith(".safetensors"))

    # Decide on the presence of weights, not of the directory: a previously
    # failed download leaves an empty directory behind, which would otherwise
    # block every later retry.
    if not _weight_files():
        print(f"Downloading model {model_id}...")

        os.makedirs(model_dir, exist_ok=True)

        try:
            snapshot_download(
                repo_id=model_id,
                allow_patterns=["*.safetensors", "config.json", "tokenizer*"],
                local_dir=model_dir,
                local_dir_use_symlinks=False
            )
            print("Model download completed!")
        except Exception as e:
            print(f"Failed to download model: {e}")
            print("Try: huggingface-cli login (if authentication required)")
            raise

    # Create model config
    config = ModelArgs()
    model = Model(config)

    # Prefer the canonical file name; otherwise fall back to the first
    # .safetensors file in sorted order.
    weight_file = os.path.join(model_dir, "model.safetensors")
    if not os.path.exists(weight_file):
        candidates = _weight_files()
        if not candidates:
            raise FileNotFoundError(f"No .safetensors file found in {model_dir}")
        weight_file = os.path.join(model_dir, candidates[0])

    print(f"Loading weights from: {weight_file}")
    model.load_weights(weight_file, strict=True)
    model.eval()

    return model, model_dir
105
+
106
+
107
def load_tokenizer(model_path):
    """Return the ``AutoTokenizer`` loaded from ``model_path``."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return tokenizer
110
+
111
+
112
def rerank_documents(model, tokenizer, query, documents, max_length=1024):
    """Score ``documents`` against ``query`` and time the forward pass.

    Returns:
        ``(scores, scores_sigmoid, inference_time)``: the raw scores, their
        sigmoid, and the inference time in milliseconds.
    """
    # Prepare inputs
    input_ids, attention_mask, token_type_ids, position_ids = prepare_inputs(
        query, documents, tokenizer, max_length
    )

    # Run inference. perf_counter is monotonic and meant for interval timing.
    start_time = time.perf_counter()
    scores = model.nexa_forward(input_ids, attention_mask, token_type_ids, position_ids)
    scores = mx.squeeze(scores, axis=-1)
    # MLX is lazy: without forcing evaluation the timer would only cover graph
    # construction, not the actual computation.
    mx.eval(scores)
    end_time = time.perf_counter()

    # Apply sigmoid to get probabilities
    scores_sigmoid = mx.sigmoid(scores)

    inference_time = (end_time - start_time) * 1000  # Convert to ms

    return scores, scores_sigmoid, inference_time
131
+
132
+
133
def main(model_id):
    """Run a small reranking demo and print scores, timing and throughput."""
    model, model_path = load_model(model_id)
    tokenizer = load_tokenizer(model_path)

    # Fixed demo query and candidate documents.
    query = "What are the health benefits of green tea?"
    documents = [
        "Green tea is rich in antioxidants and may improve brain function.",
        "Coffee contains caffeine and can boost energy levels.",
        "Das Trinken von grünem Tee kann das Risiko für Herzkrankheiten senken.",
        "Black tea is another popular beverage with its own health benefits.",
    ]

    scores, scores_sigmoid, inference_time = rerank_documents(
        model, tokenizer, query, documents
    )

    banner = "=" * 70
    print(banner)
    print("Reranking Results:")
    print(banner)
    print(f"Query: {query}")
    print()

    rows = zip(documents, scores.tolist(), scores_sigmoid.tolist())
    for number, (doc, score, prob) in enumerate(rows, start=1):
        print(f"Document {number}:")
        print(f" Text: {doc}")
        print(f" Score: {score:.4f}")
        print(f" Probability: {prob:.4f}")
        print()

    # inference_time is in milliseconds, hence the factor of 1000.
    docs_per_second = len(documents)/inference_time*1000
    print(f"Inference time: {inference_time:.1f}ms")
    print(f"Throughput: {docs_per_second:.1f} docs/s")
170
+
171
+
172
if __name__ == "__main__":
    # Script entry point: run the demo with the default model repo id.
    model_id = "nexaml/jina-v2-rerank-mlx"
    main(model_id)
@@ -0,0 +1,287 @@
1
+ # Copyright © Nexa AI
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ import json
19
+ import mlx.core as mx
20
+ import mlx.nn as nn
21
+ import numpy as np
22
+ import time
23
+ from pathlib import Path
24
+ from typing import Any, List, Optional, Sequence
25
+ from dataclasses import dataclass
26
+ from abc import ABC, abstractmethod
27
+
28
+ # Import necessary modules
29
+ from transformers import AutoTokenizer
30
+
31
+ # Import from ml.py for API alignment (assuming similar structure)
32
+ try:
33
+ from ml import (
34
+ Reranker as BaseReranker,
35
+ Path as PathType,
36
+ )
37
+ except ImportError:
38
+ # Fallback to local definitions if ml.py not available
39
+ PathType = Path
40
+ BaseReranker = ABC
41
+
42
+ # Import profiling module
43
+ from profiling import ProfilingMixin, ProfilingData, StopReason
44
+
45
+ # Import the model implementation
46
+ from .modeling.nexa_jina_rerank import Model, ModelArgs
47
+
48
+
49
@dataclass
class RerankConfig:
    """Options controlling batching and score normalization for reranking.

    The constructor is the one generated by ``@dataclass``; it accepts the
    three fields below, positionally or by keyword, with these defaults.
    """

    batch_size: int = 1                # documents scored per forward pass
    normalize: bool = True             # whether scores are normalized
    normalize_method: str = "softmax"  # "softmax" | "min-max" | "none"
65
+
66
+
67
+ class Reranker(BaseReranker, ProfilingMixin):
68
+ """
69
+ Reranker interface for MLX reranking models.
70
+ API aligned with ml.py Reranker abstract base class.
71
+ """
72
+
73
def __init__(
    self,
    model_path: PathType,
    tokenizer_path: PathType,
    device: Optional[str] = None,
) -> None:
    """Set up an unloaded reranker.

    Args:
        model_path: Model directory, or a file inside it (the parent
            directory is used in that case).
        tokenizer_path: Stored on the instance as given.
        device: Device name; defaults to ``"cpu"`` when ``None``.

    No weights are loaded here; call ``load_model()`` afterwards.
    """
    # Initialize profiling mixin
    ProfilingMixin.__init__(self)

    # A path to a file is normalized to its containing directory.
    if os.path.isfile(model_path):
        model_path = os.path.dirname(model_path)

    # Forward the arguments only when a real ml.py base class was imported.
    # In the ABC fallback, super().__init__ resolves to ProfilingMixin.__init__,
    # which takes no arguments, so the call would raise TypeError.
    if BaseReranker is not ABC:
        super().__init__(model_path, tokenizer_path, device)

    # Store paths and device
    self.model_path = model_path
    self.tokenizer_path = tokenizer_path
    self.device = device if device is not None else "cpu"

    # Model, tokenizer and config stay None until load_model() runs.
    self.model = None
    self.tokenizer = None
    self.config = None
100
+
101
def destroy(self) -> None:
    """Drop the references to the model, tokenizer and config."""
    self.model = self.tokenizer = self.config = None
106
+
107
def load_model(self, model_path: PathType, extra_data: Any = None) -> bool:
    """Load the model and tokenizer, reporting success as a boolean.

    Args:
        model_path: Directory (or a file inside it) to load from. A falsy
            value keeps the path already stored on the instance.
        extra_data: Accepted but not used by this implementation.

    Returns:
        ``True`` on success; ``False`` if any step raised, in which case the
        error is printed rather than propagated.
    """
    try:
        if model_path:
            # Same file-to-directory normalization as the constructor.
            is_file = os.path.isfile(model_path)
            self.model_path = os.path.dirname(model_path) if is_file else model_path

        self.model = self._load_jina_model(self.model_path)
        self.tokenizer = self._load_tokenizer()
    except Exception as e:
        print(f"Failed to load model: {e}")
        return False
    return True
125
+
126
def close(self) -> None:
    """Release the model; an alias for ``destroy()``."""
    self.destroy()
129
+
130
def rerank(
    self,
    query: str,
    documents: Sequence[str],
    config: Optional[RerankConfig] = None,
    clear_cache: bool = True,
) -> mx.array:
    """Score every document against ``query``.

    Args:
        query: The query text.
        documents: Candidate documents; may be empty.
        config: Batching/normalization options; defaults to ``RerankConfig()``.
        clear_cache: If true, ``mx.clear_cache()`` is called after each batch.

    Returns:
        A 1-D array with one score per document, in input order. An empty
        input yields an empty array.

    Raises:
        RuntimeError: if the model or tokenizer has not been loaded.
        ValueError: if ``config.batch_size`` is smaller than 1.
    """
    if self.model is None or self.tokenizer is None:
        raise RuntimeError("Model not loaded. Call load_model() first.")

    if config is None:
        config = RerankConfig()

    batch_size = config.batch_size
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Batches are scored WITHOUT normalization: normalizing each batch on its
    # own makes results depend on batch_size (softmax over a single-document
    # batch is always 1.0). Normalization is applied once, over all scores.
    raw_config = RerankConfig(
        batch_size=batch_size, normalize=False, normalize_method="none"
    )

    # Start profiling
    self._start_profiling()
    self._prompt_start()

    all_scores = []

    # Process documents in batches
    for i in range(0, len(documents), batch_size):
        batch_docs = documents[i:i + batch_size]
        batch_scores = self._rerank_batch(query, batch_docs, raw_config)
        # MLX is lazy; force the computation here so the timing covers real
        # work and clear_cache() releases buffers this batch actually used.
        mx.eval(batch_scores)
        all_scores.append(batch_scores)

        if clear_cache:
            mx.clear_cache()

    # End prompt processing, start decode
    self._prompt_end()
    self._decode_start()

    if not all_scores:
        # No documents: return an empty score vector instead of failing.
        res = mx.array([], dtype=mx.float32)
    else:
        res = mx.concatenate(all_scores, axis=0) if len(all_scores) > 1 else all_scores[0]
        if config.normalize:
            res = self._normalize_scores(res, config.normalize_method)

    # End decode and profiling
    self._decode_end()
    self._set_stop_reason(StopReason.ML_STOP_REASON_COMPLETED)
    self._end_profiling()

    return res
173
+
174
+ def _load_jina_model(self, model_dir: str) -> Model:
175
+ """Initialize and load the Jina V2 rerank model."""
176
+
177
+ # Validate that model path exists
178
+ if not os.path.exists(model_dir):
179
+ raise ValueError(f"Model path does not exist: {model_dir}")
180
+
181
+ # Store model directory for tokenizer loading
182
+ self._model_dir = model_dir
183
+
184
+ # Create model config
185
+ config = ModelArgs()
186
+ model = Model(config)
187
+
188
+ # Load weights
189
+ weight_file = os.path.join(model_dir, "model.safetensors")
190
+ if not os.path.exists(weight_file):
191
+ # Try alternative naming patterns
192
+ safetensors_files = [f for f in os.listdir(model_dir) if f.endswith('.safetensors')]
193
+ if safetensors_files:
194
+ weight_file = os.path.join(model_dir, safetensors_files[0])
195
+ else:
196
+ raise FileNotFoundError(f"No .safetensors file found in {model_dir}")
197
+
198
+ model.load_weights(weight_file, strict=True)
199
+ model.eval()
200
+
201
+ return model
202
+
203
def _load_tokenizer(self) -> AutoTokenizer:
    """Return the tokenizer loaded from the directory the model came from.

    Relies on ``self._model_dir``, which ``_load_jina_model`` sets.
    """
    tokenizer = AutoTokenizer.from_pretrained(self._model_dir)
    return tokenizer
206
+
207
def _rerank_batch(self, query: str, documents: List[str], config: RerankConfig) -> mx.array:
    """Score one batch of documents against ``query``.

    Scores are normalized with ``config.normalize_method`` when
    ``config.normalize`` is true; otherwise the raw scores are returned.
    """
    ids, mask, type_ids, pos_ids = self._prepare_inputs(
        query, documents, self.tokenizer, max_length=1024
    )

    logits = self.model.nexa_forward(ids, mask, type_ids, pos_ids)
    # Drop the trailing axis of the model output.
    scores = mx.squeeze(logits, axis=-1)

    if not config.normalize:
        return scores
    return self._normalize_scores(scores, config.normalize_method)
223
+
224
def _create_position_ids_from_input_ids(self, input_ids, padding_idx, past_key_values_length=0):
    """Build position ids that skip padding tokens.

    Non-padding tokens receive a running count along axis 1 (plus
    ``past_key_values_length``), offset by ``padding_idx``; padding tokens
    receive exactly ``padding_idx``.
    """
    not_pad = (input_ids != padding_idx).astype(mx.int32)
    running = mx.cumsum(not_pad, axis=1) + past_key_values_length
    # Multiplying by the mask zeroes the count at padded slots.
    positions = (running * not_pad).astype(mx.int32)
    return positions + padding_idx
229
+
230
def _prepare_inputs(self, query, documents, tokenizer, max_length=1024):
    """Tokenize (query, document) pairs and build the model input arrays.

    Every pair is padded or truncated to exactly ``max_length`` tokens.

    Returns:
        A tuple ``(input_ids, attention_mask, token_type_ids, position_ids)``:
        int32 ids, a float32 mask, all-zero int32 token type ids, and
        position ids computed with a padding index of 1.
    """
    pairs = [[query, document] for document in documents]
    encoded = tokenizer(
        pairs,
        padding="max_length",
        truncation=True,
        return_tensors="np",
        max_length=max_length,
    )

    input_ids = mx.array(encoded["input_ids"]).astype(mx.int32)
    attention_mask = mx.array(encoded["attention_mask"]).astype(mx.float32)
    batch_size = input_ids.shape[0]
    seqlen = input_ids.shape[1]

    # One row of zeros, broadcast so every batch item shares the same values
    zero_row = mx.expand_dims(mx.zeros(seqlen, dtype=mx.int32), axis=0)
    token_type_ids = mx.broadcast_to(zero_row, (batch_size, seqlen))

    # Per-sequence positions that skip over padding tokens
    position_ids = self._create_position_ids_from_input_ids(input_ids, padding_idx=1)

    return input_ids, attention_mask, token_type_ids, position_ids
256
+
257
def _normalize_scores(self, scores: mx.array, method: str) -> mx.array:
    """Normalize ``scores`` with the named method.

    ``"softmax"`` applies a softmax over the last axis. ``"min-max"``
    rescales to the [0, 1] range when the scores are not all equal. Any other
    value, including ``"none"``, returns the scores unchanged.
    """
    if method == "softmax":
        # The last axis is also axis 0 for a 1-D array, so one call covers both
        return mx.softmax(scores, axis=-1)

    if method == "min-max":
        lowest = mx.min(scores)
        highest = mx.max(scores)
        # Skip rescaling when the range is zero to avoid dividing by zero
        if highest > lowest:
            return (scores - lowest) / (highest - lowest)

    # "none", an unrecognized method, or a zero-range min-max
    return scores
275
+
276
+
277
+ # Factory function for creating reranker instances
278
def create_reranker(
    model_path: PathType,
    tokenizer_path: Optional[PathType] = None,
    device: Optional[str] = None,
) -> Reranker:
    """Construct a ``Reranker`` for the given model.

    Args:
        model_path: Location of the model.
        tokenizer_path: Location of the tokenizer; falls back to
            ``model_path`` when omitted.
        device: Passed through to ``Reranker`` unchanged.

    Returns:
        The newly created ``Reranker`` instance.
    """
    effective_tokenizer_path = model_path if tokenizer_path is None else tokenizer_path
    return Reranker(model_path, effective_tokenizer_path, device)