nexaai-1.0.16rc13-cp310-cp310-macosx_15_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nexaai might be problematic; the original registry page linked to further details about the concern.

Files changed (557):
  1. nexaai/__init__.py +83 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +64 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +92 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +44 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +4 -0
  10. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/libnexa_bridge.dylib +0 -0
  13. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  14. nexaai/binds/nexa_llama_cpp/libggml-base.dylib +0 -0
  15. nexaai/binds/nexa_llama_cpp/libggml-cpu.so +0 -0
  16. nexaai/binds/nexa_llama_cpp/libggml-metal.so +0 -0
  17. nexaai/binds/nexa_llama_cpp/libggml.dylib +0 -0
  18. nexaai/binds/nexa_llama_cpp/libllama.dylib +0 -0
  19. nexaai/binds/nexa_llama_cpp/libmtmd.dylib +0 -0
  20. nexaai/binds/nexa_llama_cpp/libnexa_plugin.dylib +0 -0
  21. nexaai/binds/nexa_mlx/libnexa_plugin.dylib +0 -0
  22. nexaai/binds/nexa_mlx/py-lib/ml.py +888 -0
  23. nexaai/binds/nexa_mlx/py-lib/mlx_audio/__init__.py +0 -0
  24. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/__init__.py +1 -0
  25. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  26. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  27. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  28. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  29. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  30. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  31. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  32. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  33. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  34. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  35. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  36. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  37. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  38. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  39. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  40. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  41. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  42. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  43. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  44. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  45. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  46. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  47. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  48. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  49. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  50. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  51. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  52. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  53. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  54. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  55. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  56. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  57. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  58. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  59. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  60. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  61. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  62. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  63. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  64. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  65. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  66. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  67. nexaai/binds/nexa_mlx/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  68. nexaai/binds/nexa_mlx/py-lib/mlx_audio/server.py +525 -0
  69. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/__init__.py +0 -0
  70. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  71. nexaai/binds/nexa_mlx/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  72. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/__init__.py +0 -0
  73. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/generate.py +174 -0
  74. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  75. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  76. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  77. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  78. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  79. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  80. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  81. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  82. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  83. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  84. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  85. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  86. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  87. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  88. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  89. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  90. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  91. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  92. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  93. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  94. nexaai/binds/nexa_mlx/py-lib/mlx_audio/stt/utils.py +195 -0
  95. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/__init__.py +1 -0
  96. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/audio_player.py +120 -0
  97. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/convert.py +71 -0
  98. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/generate.py +449 -0
  99. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  100. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  101. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  102. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  103. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  104. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/base.py +84 -0
  105. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  106. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  107. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  108. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  109. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  110. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  111. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  112. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  113. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  114. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  115. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  116. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  117. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  118. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  119. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  120. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  121. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  122. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  123. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  124. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  125. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  126. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  127. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  128. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  129. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  130. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  131. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  132. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  133. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  134. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  135. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  136. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  137. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  138. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  139. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  140. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  141. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  142. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  143. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  144. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  145. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  146. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  147. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  148. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  149. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  150. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  151. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  152. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  153. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  154. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  155. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  156. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  157. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  158. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  159. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  160. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  161. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  162. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  163. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  164. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  165. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  166. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  167. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  168. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  169. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  170. nexaai/binds/nexa_mlx/py-lib/mlx_audio/tts/utils.py +337 -0
  171. nexaai/binds/nexa_mlx/py-lib/mlx_audio/utils.py +237 -0
  172. nexaai/binds/nexa_mlx/py-lib/mlx_audio/version.py +1 -0
  173. nexaai/binds/nexa_mlx/py-lib/profiling.py +239 -0
  174. nexaai/binds/nexa_nexaml/libggml-base.dylib +0 -0
  175. nexaai/binds/nexa_nexaml/libggml-cpu.so +0 -0
  176. nexaai/binds/nexa_nexaml/libggml-metal.so +0 -0
  177. nexaai/binds/nexa_nexaml/libggml.dylib +0 -0
  178. nexaai/binds/nexa_nexaml/libnexa-mm-process.dylib +0 -0
  179. nexaai/binds/nexa_nexaml/libnexa-sampling.dylib +0 -0
  180. nexaai/binds/nexa_nexaml/libnexa_plugin.dylib +0 -0
  181. nexaai/binds/nexa_nexaml/libnexaproc.dylib +0 -0
  182. nexaai/binds/nexa_nexaml/libqwen3-vl.dylib +0 -0
  183. nexaai/binds/nexa_nexaml/libqwen3vl-vision.dylib +0 -0
  184. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  185. nexaai/common.py +104 -0
  186. nexaai/cv.py +92 -0
  187. nexaai/cv_impl/__init__.py +0 -0
  188. nexaai/cv_impl/mlx_cv_impl.py +89 -0
  189. nexaai/cv_impl/pybind_cv_impl.py +32 -0
  190. nexaai/embedder.py +72 -0
  191. nexaai/embedder_impl/__init__.py +0 -0
  192. nexaai/embedder_impl/mlx_embedder_impl.py +116 -0
  193. nexaai/embedder_impl/pybind_embedder_impl.py +95 -0
  194. nexaai/image_gen.py +140 -0
  195. nexaai/image_gen_impl/__init__.py +0 -0
  196. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  197. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  198. nexaai/llm.py +96 -0
  199. nexaai/llm_impl/__init__.py +0 -0
  200. nexaai/llm_impl/mlx_llm_impl.py +269 -0
  201. nexaai/llm_impl/pybind_llm_impl.py +218 -0
  202. nexaai/log.py +92 -0
  203. nexaai/mlx_backend/asr/__init__.py +12 -0
  204. nexaai/mlx_backend/asr/interface.py +122 -0
  205. nexaai/mlx_backend/common/__init__.py +0 -0
  206. nexaai/mlx_backend/common/utils.py +25 -0
  207. nexaai/mlx_backend/cv/__init__.py +0 -0
  208. nexaai/mlx_backend/cv/generate.py +195 -0
  209. nexaai/mlx_backend/cv/interface.py +151 -0
  210. nexaai/mlx_backend/cv/main.py +81 -0
  211. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  212. nexaai/mlx_backend/embedding/__init__.py +0 -0
  213. nexaai/mlx_backend/embedding/generate.py +333 -0
  214. nexaai/mlx_backend/embedding/interface.py +617 -0
  215. nexaai/mlx_backend/embedding/main.py +173 -0
  216. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  217. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  218. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  219. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  220. nexaai/mlx_backend/image_gen/interface.py +82 -0
  221. nexaai/mlx_backend/image_gen/main.py +281 -0
  222. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  223. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  224. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  225. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  226. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  227. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  228. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  229. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  230. nexaai/mlx_backend/llm/__init__.py +0 -0
  231. nexaai/mlx_backend/llm/generate.py +149 -0
  232. nexaai/mlx_backend/llm/interface.py +764 -0
  233. nexaai/mlx_backend/llm/main.py +68 -0
  234. nexaai/mlx_backend/ml.py +888 -0
  235. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  236. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  237. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  238. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  239. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  240. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  241. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  242. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  243. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  244. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  245. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  246. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  247. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  248. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  272. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  273. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  274. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  275. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  276. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  277. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  278. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  279. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  280. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  281. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  282. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  283. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  284. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  286. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  287. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  288. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  289. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  290. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  291. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  292. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  293. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  294. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  295. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  296. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  297. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  305. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  306. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  307. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  308. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  309. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  310. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  311. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  312. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  313. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  314. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  315. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  316. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  317. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  318. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  319. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  320. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  321. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  322. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  378. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  379. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  380. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  381. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  382. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  383. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  384. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  385. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  386. nexaai/mlx_backend/profiling.py +239 -0
  387. nexaai/mlx_backend/rerank/__init__.py +0 -0
  388. nexaai/mlx_backend/rerank/generate.py +174 -0
  389. nexaai/mlx_backend/rerank/interface.py +287 -0
  390. nexaai/mlx_backend/rerank/main.py +127 -0
  391. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  392. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  393. nexaai/mlx_backend/sd/__init__.py +1 -0
  394. nexaai/mlx_backend/sd/interface.py +362 -0
  395. nexaai/mlx_backend/sd/main.py +286 -0
  396. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  397. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  398. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  399. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  400. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  401. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  402. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  403. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  404. nexaai/mlx_backend/tts/__init__.py +12 -0
  405. nexaai/mlx_backend/tts/interface.py +276 -0
  406. nexaai/mlx_backend/vlm/__init__.py +3 -0
  407. nexaai/mlx_backend/vlm/generate.py +572 -0
  408. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +261 -0
  409. nexaai/mlx_backend/vlm/interface.py +415 -0
  410. nexaai/mlx_backend/vlm/main.py +316 -0
  411. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  412. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  413. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  414. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  415. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  416. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  417. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  418. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  419. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  420. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  421. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  422. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  423. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  424. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  425. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  426. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  427. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  429. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  430. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  431. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  432. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  433. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  434. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  435. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  436. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  437. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  438. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  439. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  440. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  441. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  442. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  443. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  444. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  445. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  446. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  447. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  448. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  449. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  450. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  451. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  452. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  453. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  454. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  455. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  456. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  457. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  458. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  459. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  460. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  461. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  462. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  463. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  464. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  465. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  466. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  467. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  468. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  469. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  470. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  471. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  475. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  476. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  477. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  478. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  479. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  480. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  481. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  482. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  483. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  484. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  485. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  486. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  487. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  488. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  489. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  490. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  492. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  493. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  494. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  496. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  497. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  498. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  499. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  500. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  501. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  502. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  503. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  504. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  505. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  506. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  507. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  508. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  509. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  510. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  511. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  512. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  513. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  514. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  515. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1223 -0
  522. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  523. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  524. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  525. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  526. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  527. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  528. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  529. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  530. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  531. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  532. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  533. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  534. nexaai/rerank.py +55 -0
  535. nexaai/rerank_impl/__init__.py +0 -0
  536. nexaai/rerank_impl/mlx_rerank_impl.py +92 -0
  537. nexaai/rerank_impl/pybind_rerank_impl.py +43 -0
  538. nexaai/runtime.py +68 -0
  539. nexaai/tts.py +74 -0
  540. nexaai/tts_impl/__init__.py +0 -0
  541. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  542. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  543. nexaai/utils/avatar_fetcher.py +104 -0
  544. nexaai/utils/decode.py +18 -0
  545. nexaai/utils/manifest_utils.py +324 -0
  546. nexaai/utils/model_manager.py +1353 -0
  547. nexaai/utils/model_types.py +47 -0
  548. nexaai/utils/progress_tracker.py +385 -0
  549. nexaai/utils/quantization_utils.py +245 -0
  550. nexaai/vlm.py +128 -0
  551. nexaai/vlm_impl/__init__.py +0 -0
  552. nexaai/vlm_impl/mlx_vlm_impl.py +258 -0
  553. nexaai/vlm_impl/pybind_vlm_impl.py +230 -0
  554. nexaai-1.0.16rc13.dist-info/METADATA +32 -0
  555. nexaai-1.0.16rc13.dist-info/RECORD +557 -0
  556. nexaai-1.0.16rc13.dist-info/WHEEL +5 -0
  557. nexaai-1.0.16rc13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,474 @@
1
+ from enum import Enum
2
+ from functools import partial
3
+ from typing import Any, Dict, List, Optional, Union
4
+
5
+
6
class MessageFormat(Enum):
    """Enum for different message format types.

    Each member selects one formatting routine in
    ``MessageFormatter.format_message`` (via its ``formatter_map``); the
    per-member comments below describe what that routine produces.
    """

    # {"role", "content": [text, image...]} -- image entries after the text.
    LIST_WITH_IMAGE = "list_with_image"
    # Same as LIST_WITH_IMAGE but image entries are placed before the text.
    LIST_WITH_IMAGE_FIRST = "list_with_image_first"
    # Typed list using {"type": "text", "content": ...}; images first, then audio.
    LIST_WITH_IMAGE_TYPE = "list_with_image_type"
    # Typed list using {"type": "text", "text": ...}; images first, then audio.
    LIST_WITH_IMAGE_TYPE_TEXT = "list_with_image_type_text"
    # Typed list using the "text" key; images after the text, then audio.
    LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST = "list_with_image_type_text_image_last"
    # Plain-string content prefixed with "<image>" once per image.
    IMAGE_TOKEN = "image_token"
    # Plain-string content prefixed with "<|image|>" once per image.
    IMAGE_TOKEN_PIPE = "image_token_pipe"
    # Plain-string content followed by "<start_of_image>" once per image.
    START_IMAGE_TOKEN = "start_image_token"
    # Plain-string content prefixed with "<image>\n" once per image.
    IMAGE_TOKEN_NEWLINE = "image_token_newline"
    # Plain-string content prefixed with "<|image_N|>" numbered tokens.
    NUMBERED_IMAGE_TOKENS = "numbered_image_tokens"
    # The bare prompt string, no message dict and no image tokens.
    PROMPT_ONLY = "prompt_only"
    # Bare string: "<image>" repeated num_images times, then the prompt.
    PROMPT_WITH_IMAGE_TOKEN = "prompt_with_image_token"
    # Bare string: the prompt, then "<start_of_image>" repeated num_images times.
    PROMPT_WITH_START_IMAGE_TOKEN = "prompt_with_start_image_token"
    # {"role", "content": [video entry, text entry]}; requires a "video" kwarg.
    VIDEO_WITH_TEXT = "video_with_text"
23
+
24
+
25
# Model configuration mapping
# Keys are lower-cased model_type strings; MessageFormatter looks the model up
# here and raises ValueError("Unsupported model: ...") for names not listed.
MODEL_CONFIG = {
    # List with image format models
    "idefics2": MessageFormat.LIST_WITH_IMAGE,
    "idefics3": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "aya_vision": MessageFormat.LIST_WITH_IMAGE,
    "qwen2_vl": MessageFormat.LIST_WITH_IMAGE,
    "qwen2_5_vl": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "mistral3": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "internvl_chat": MessageFormat.LIST_WITH_IMAGE_TYPE,
    "kimi_vl": MessageFormat.LIST_WITH_IMAGE,
    "gemma3": MessageFormat.START_IMAGE_TOKEN,
    "gemma3n": MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST,
    "llama4": MessageFormat.LIST_WITH_IMAGE,
    "smolvlm": MessageFormat.LIST_WITH_IMAGE_FIRST,
    "llava": MessageFormat.LIST_WITH_IMAGE,
    "llava_next": MessageFormat.LIST_WITH_IMAGE,
    "mllama": MessageFormat.LIST_WITH_IMAGE,
    "pixtral": MessageFormat.LIST_WITH_IMAGE_TYPE,
    # Token-based models
    "llava-qwen2": MessageFormat.IMAGE_TOKEN_NEWLINE,
    "bunny-llama": MessageFormat.IMAGE_TOKEN_NEWLINE,
    "phi3_v": MessageFormat.NUMBERED_IMAGE_TOKENS,
    "multi_modality": MessageFormat.IMAGE_TOKEN,
    "deepseek_vl_v2": MessageFormat.IMAGE_TOKEN_NEWLINE,
    # Prompt-only models
    "florence2": MessageFormat.PROMPT_ONLY,
    "molmo": MessageFormat.PROMPT_ONLY,
    "paligemma": MessageFormat.PROMPT_WITH_IMAGE_TOKEN,
}

# Models that don't support multi-image
# MessageFormatter.format_message raises ValueError when num_images > 1 for
# any model_type in this set.
SINGLE_IMAGE_ONLY_MODELS = {
    "llava_next",
    "llava-qwen2",
    "bunny-llama",
    "paligemma",
    "multi_modality",
    "mllama",
}
65
+
66
+
67
class MessageBuilder:
    """Factory helpers that produce the individual entries of a message's content list.

    Every helper returns a brand-new dict, so callers may mutate results freely.
    """

    @staticmethod
    def text_message(text: str) -> Dict[str, str]:
        """Return a text entry whose payload lives under the ``"text"`` key."""
        return dict(type="text", text=text)

    @staticmethod
    def content_message(content: str) -> Dict[str, str]:
        """Return a text entry whose payload lives under the ``"content"`` key."""
        return dict(type="text", content=content)

    @staticmethod
    def image_message() -> Dict[str, str]:
        """Return an image placeholder entry."""
        return dict(type="image")

    @staticmethod
    def audio_message() -> Dict[str, str]:
        """Return an audio placeholder entry."""
        return dict(type="audio")

    @staticmethod
    def video_message(
        video_path: str, max_pixels: int = 224 * 224, fps: int = 1
    ) -> Dict[str, Any]:
        """Return a video entry carrying the path plus ``max_pixels`` and ``fps`` settings."""
        return dict(
            type="video",
            video=video_path,
            max_pixels=max_pixels,
            fps=fps,
        )
101
+
102
+
103
class MessageFormatter:
    """Build a single chat message in the layout a given VLM family expects.

    The layout is taken from ``MODEL_CONFIG`` using the lower-cased model
    name; ``format_message`` is the entry point.
    """

    def __init__(self, model_name: str):
        """Resolve the message layout for ``model_name``.

        Raises:
            ValueError: if the (lower-cased) name is not a key of MODEL_CONFIG.
        """
        self.model_name = model_name.lower()
        self.format_type = MODEL_CONFIG.get(self.model_name)
        if self.format_type is None:
            raise ValueError(f"Unsupported model: {model_name}")

    @staticmethod
    def _image_messages(count: int) -> List[Dict[str, str]]:
        """Return ``count`` independent image placeholder dicts.

        ``[d] * count`` would repeat one shared dict object, so a caller that
        later fills in one placeholder (e.g. attaching image data) would
        silently change every placeholder; build a fresh dict per entry.
        """
        return [MessageBuilder.image_message() for _ in range(count)]

    @staticmethod
    def _audio_messages(count: int) -> List[Dict[str, str]]:
        """Return ``count`` independent audio placeholder dicts (see _image_messages)."""
        return [MessageBuilder.audio_message() for _ in range(count)]

    def format_message(
        self,
        prompt: str,
        role: str = "user",
        skip_image_token: bool = False,
        skip_audio_token: bool = False,
        num_images: int = 1,
        num_audios: int = 1,
        **kwargs,
    ) -> Union[str, Dict[str, Any]]:
        """Format one message according to this model's MessageFormat.

        Args:
            prompt: Message text.
            role: Chat role; image/audio placeholders are only added for "user"
                in the list/token layouts.
            skip_image_token: Do not add image placeholders.
            skip_audio_token: Do not add audio placeholders.
            num_images: Number of image placeholders to add.
            num_audios: Number of audio placeholders to add.
            **kwargs: Extra options; ``video`` (plus optional ``max_pixels`` and
                ``fps``) switches qwen2_vl / qwen2_5_vl to a video message.

        Returns:
            A ``{"role", "content"}`` dict for chat layouts, or a bare string
            for the PROMPT_* layouts.

        Raises:
            ValueError: if several images are requested for a single-image model.
        """
        # Check multi-image support
        if num_images > 1 and self.model_name in SINGLE_IMAGE_ONLY_MODELS:
            raise ValueError(
                f"Model {self.model_name} does not support multi-image chat. "
                "Please only use 1 image."
            )

        # Handle video format for specific models
        if self.model_name in ["qwen2_vl", "qwen2_5_vl"] and kwargs.get("video"):
            return self._format_video_message(prompt, kwargs)

        # Route to appropriate formatter
        formatter_map = {
            MessageFormat.LIST_WITH_IMAGE: self._format_list_with_image,
            MessageFormat.LIST_WITH_IMAGE_FIRST: partial(
                self._format_list_with_image, image_first=True
            ),
            MessageFormat.LIST_WITH_IMAGE_TYPE: self._format_list_with_image_type,
            MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT: partial(
                self._format_list_with_image_type, message_type="text"
            ),
            MessageFormat.LIST_WITH_IMAGE_TYPE_TEXT_IMAGE_LAST: partial(
                self._format_list_with_image_type,
                message_type="text",
                image_first=False,
            ),
            MessageFormat.IMAGE_TOKEN: partial(
                self._format_with_token, token="<image>"
            ),
            MessageFormat.IMAGE_TOKEN_PIPE: partial(
                self._format_with_token, token="<|image|>"
            ),
            MessageFormat.START_IMAGE_TOKEN: partial(
                self._format_with_token, token="<start_of_image>", image_first=False
            ),
            MessageFormat.IMAGE_TOKEN_NEWLINE: partial(
                self._format_with_token, token="<image>\n"
            ),
            MessageFormat.NUMBERED_IMAGE_TOKENS: self._format_numbered_tokens,
            # Prompt-only layouts return plain strings and ignore role/skip flags.
            MessageFormat.PROMPT_ONLY: lambda *args, **kw: prompt,
            MessageFormat.PROMPT_WITH_IMAGE_TOKEN: lambda *args, **kw: "<image>"
            * num_images
            + prompt,
            MessageFormat.PROMPT_WITH_START_IMAGE_TOKEN: lambda *args, **kw: prompt
            + "<start_of_image>" * num_images,
            MessageFormat.VIDEO_WITH_TEXT: self._format_video_message,
        }

        formatter = formatter_map.get(self.format_type)
        return formatter(
            prompt,
            role,
            skip_image_token,
            skip_audio_token,
            num_images,
            num_audios,
            **kwargs,
        )

    def _format_list_with_image(
        self,
        prompt: str,
        role: str,
        skip_image_token: bool,
        skip_audio_token: bool,
        num_images: int,
        num_audios: int,
        image_first: bool = False,
        **kwargs,
    ) -> Dict[str, Any]:
        """Return ``{"role", "content": [...]}`` with text and image entries.

        Image entries (only for user messages) go before the text when
        ``image_first`` is true, otherwise after it. Audio is not handled here.
        """
        content = [MessageBuilder.text_message(prompt)]

        if role == "user" and not skip_image_token:
            images = self._image_messages(num_images)
            content = images + content if image_first else content + images

        return {"role": role, "content": content}

    def _format_list_with_image_type(
        self,
        prompt: str,
        role: str,
        skip_image_token: bool,
        skip_audio_token: bool,
        num_images: int,
        num_audios: int,
        message_type: str = "content",
        image_first: bool = True,
        **kwargs,
    ) -> Dict[str, Any]:
        """Return a typed-list message, or a plain-string one for assistants.

        ``message_type`` picks whether the text payload is stored under the
        "content" key or the "text" key. User messages get image entries
        (before or after the text per ``image_first``) followed by audio
        entries. Assistant messages collapse to the bare text string.
        """
        msg_func = (
            MessageBuilder.content_message
            if message_type == "content"
            else MessageBuilder.text_message
        )
        content = [msg_func(prompt)]

        if role == "user":
            if not skip_image_token:
                images = self._image_messages(num_images)
                content = images + content if image_first else content + images
            if not skip_audio_token:
                content = content + self._audio_messages(num_audios)

        if role == "assistant":
            # Only the text entry exists here (media is added for users only).
            text_entry = content[0]
            return {
                "role": role,
                "content": text_entry.get("content", text_entry.get("text")),
            }

        return {"role": role, "content": content}

    def _format_with_token(
        self,
        prompt: str,
        role: str,
        skip_image_token: bool,
        skip_audio_token: bool,
        num_images: int,
        num_audios: int,
        token: str,
        image_first: bool = True,
        **kwargs,
    ) -> Dict[str, Any]:
        """Return a message whose string content embeds ``token`` once per image.

        The tokens are prepended when ``image_first`` is true and appended
        otherwise; only user messages receive them.
        """
        content = prompt

        if role == "user" and not skip_image_token:
            prefix = token * num_images
            content = f"{prefix}{content}" if image_first else f"{content}{prefix}"

        return {"role": role, "content": content}

    def _format_numbered_tokens(
        self,
        prompt: str,
        role: str,
        skip_image_token: bool,
        skip_audio_token: bool,
        num_images: int,
        num_audios: int,
        **kwargs,
    ) -> Dict[str, Any]:
        """Return a message prefixed with ``<|image_N|>`` tokens (user only)."""
        content = prompt

        if role == "user" and not skip_image_token:
            # phi3_v uses single token regardless of num_images
            prefix = (
                "<|image_1|>"
                if self.model_name == "phi3_v"
                else " ".join([f"<|image_{i+1}|>" for i in range(num_images)])
            )
            content = f"{prefix}{content}"

        return {"role": role, "content": content}

    def _format_video_message(
        self,
        prompt: str,
        role: str = "user",
        skip_image_token: bool = False,
        skip_audio_token: bool = False,
        num_images: int = 0,
        num_audios: int = 0,
        **kwargs,
    ) -> Dict[str, Any]:
        """Return a message whose content is a video entry followed by the text.

        Requires ``kwargs["video"]``; ``max_pixels`` defaults to 224 * 224 and
        ``fps`` to 1 when absent.
        """
        return {
            "role": role,
            "content": [
                MessageBuilder.video_message(
                    kwargs["video"],
                    kwargs.get("max_pixels", 224 * 224),
                    kwargs.get("fps", 1),
                ),
                MessageBuilder.text_message(prompt),
            ],
        }
310
+
311
+
312
def get_message_json(
    model_name: str,
    prompt: str,
    role: str = "user",
    skip_image_token: bool = False,
    skip_audio_token: bool = False,
    num_images: int = 0,
    num_audios: int = 0,
    **kwargs,
) -> Union[str, Dict[str, Any]]:
    """
    Build one chat message for ``model_name`` in that model's expected layout.

    Args:
        model_name: Model type used to pick the layout (case-insensitive)
        prompt: Text placed in the message
        role: Chat role of the message (default: "user")
        skip_image_token: When true, no image placeholders are inserted
        skip_audio_token: When true, no audio placeholders are inserted
        num_images: How many image placeholders to insert
        num_audios: How many audio placeholders to insert
        **kwargs: Extra formatting options (e.g. video path, max_pixels, fps)

    Returns:
        Either a message dict or, for prompt-only layouts, a plain string
    """
    message_formatter = MessageFormatter(model_name)
    formatted = message_formatter.format_message(
        prompt=prompt,
        role=role,
        skip_image_token=skip_image_token,
        skip_audio_token=skip_audio_token,
        num_images=num_images,
        num_audios=num_audios,
        **kwargs,
    )
    return formatted
349
+
350
+
351
def get_chat_template(
    processor,
    messages: List[Dict[str, Any]],
    add_generation_prompt: bool,
    tokenize: bool = False,
    **kwargs,
) -> Any:
    """Apply chat template using processor's tokenizer.

    If ``processor`` has a ``chat_template`` entry in its instance
    ``__dict__``, its own ``apply_chat_template`` is used; otherwise the call
    is delegated to ``processor.tokenizer.apply_chat_template``.

    Args:
        processor: Processor object (or wrapper holding a ``tokenizer``).
        messages: Chat messages to render.
        add_generation_prompt: Forwarded to ``apply_chat_template``.
        tokenize: Forwarded to ``apply_chat_template``.
        **kwargs: Extra keyword arguments forwarded to ``apply_chat_template``.

    Returns:
        Whatever the selected ``apply_chat_template`` returns.

    Raises:
        ValueError: if no ``apply_chat_template`` can be resolved on the
            processor or its tokenizer (the original AttributeError is chained).
    """
    # Only the instance __dict__ is consulted (class attributes are ignored),
    # as before; objects without a __dict__ (e.g. slotted) use the tokenizer.
    has_own_template = "chat_template" in getattr(processor, "__dict__", {})
    try:
        target = processor if has_own_template else processor.tokenizer
        apply_template = target.apply_chat_template
    except AttributeError as err:
        raise ValueError(
            "Error: processor does not have 'chat_template' or 'tokenizer' attribute."
        ) from err

    # Rendering happens outside the try so AttributeErrors raised inside the
    # template code are not misreported as a missing processor attribute.
    return apply_template(
        messages,
        tokenize=tokenize,
        add_generation_prompt=add_generation_prompt,
        **kwargs,
    )
376
+
377
+
378
def apply_chat_template(
    processor,
    config: Union[Dict[str, Any], Any],
    prompt: Union[str, Dict[str, Any], List[Any]],
    add_generation_prompt: bool = True,
    return_messages: bool = False,
    num_images: int = 0,
    num_audios: int = 0,
    **kwargs,
) -> Union[List[Dict[str, Any]], str, Any]:
    """
    Apply chat template to prompts.

    Args:
        processor: The processor with chat template functionality
        config: Model configuration (dict or object); must provide "model_type"
        prompt: Single prompt string, dict ({"content", optional "role"}),
            or a list of strings/dicts forming a conversation
        add_generation_prompt: Whether to add generation prompt
        return_messages: Whether to return messages list instead of template
        num_images: Number of images in the input
        num_audios: Number of audio files in the input
        **kwargs: Additional arguments for message formatting

    Returns:
        Formatted messages or chat template. For paligemma, molmo and
        florence2 only the last formatted message is returned.

    Notes:
        For a list prompt, image/audio placeholders are attached exactly once,
        to the first entry whose role is neither "system" nor "assistant"
        (plain strings count as "user"). Dict entries without a "role" are
        treated as "user". List entries that are neither str nor dict are
        ignored.
    """
    config = config if isinstance(config, dict) else config.__dict__
    model_type = config["model_type"]

    # Build messages from prompts
    messages = []

    if isinstance(prompt, str):
        # Single string prompt
        messages.append(
            get_message_json(
                model_type,
                prompt,
                num_images=num_images,
                num_audios=num_audios,
                **kwargs,
            )
        )
    elif isinstance(prompt, dict):
        # Single dict prompt; default role matches the list-of-dicts handling.
        messages.append(
            get_message_json(
                model_type,
                prompt["content"],
                prompt.get("role", "user"),
                num_images=num_images,
                num_audios=num_audios,
                **kwargs,
            )
        )
    elif isinstance(prompt, list):
        # Media placeholders belong to the first user-like turn only; later
        # turns skip them so images/audio are neither duplicated nor dropped.
        media_attached = False
        for p in prompt:
            if isinstance(p, str):
                content, role = p, "user"
            elif isinstance(p, dict):
                content, role = p["content"], p.get("role", "user")
            else:
                continue

            attach_media = not media_attached and role not in ("system", "assistant")
            if attach_media:
                media_attached = True

            messages.append(
                get_message_json(
                    model_type,
                    content,
                    role,
                    skip_image_token=not attach_media,
                    skip_audio_token=not attach_media,
                    num_images=num_images,
                    num_audios=num_audios,
                    **kwargs,
                )
            )

    if return_messages:
        return messages

    # Some models only need the last message
    if model_type in ["paligemma", "molmo", "florence2"]:
        return messages[-1]

    return get_chat_template(processor, messages, add_generation_prompt)
@@ -0,0 +1,39 @@
1
+ import mlx.core as mx
2
+
3
+
4
def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.array:
    """
    Apply top-p (nucleus) sampling to logits.

    Sampling is done independently along the last axis, so ``logits`` may
    have any number of leading (batch) dimensions. Presumably ``logits`` is
    shaped ``(..., vocab_size)``; confirm against callers.

    Args:
        logits: The logits from the model's output.
        top_p: The cumulative probability threshold for top-p filtering.
        temperature: Temperature parameter for softmax distribution reshaping.
    Returns:
        Token id(s) selected based on the top-p criterion, with the last axis
        of ``logits`` removed (shape ``(1,)`` for ``(1, vocab)`` input).
    """
    if logits.dtype == mx.bfloat16:
        # workaround for unable to load kernel contiguous_scan_inclusive_sum_bfloat16_bfloat16
        logits = logits.astype(mx.float32)

    # referenced implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L449-L460
    probs = mx.softmax(logits / temperature, axis=-1)

    # Sort probs in ascending order per row. take_along_axis keeps rows
    # independent; indexing with sorted_indices.squeeze(0) only worked for a
    # batch size of exactly 1.
    sorted_indices = mx.argsort(probs, axis=-1)
    sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1)

    cumulative_probs = mx.cumsum(sorted_probs, axis=-1)

    # Keep the tail of the ascending order whose mass reaches top_p; the rest
    # get zero probability (log -> -inf, never sampled).
    top_probs = mx.where(
        cumulative_probs > 1 - top_p,
        sorted_probs,
        mx.zeros_like(sorted_probs),
    )

    sorted_token = mx.random.categorical(mx.log(top_probs), axis=-1)

    # Map each sampled position in the sorted order back to its vocabulary id.
    token = mx.take_along_axis(
        sorted_indices, mx.expand_dims(sorted_token, -1), axis=-1
    ).squeeze(-1)

    return token