nexaai 1.0.29__cp310-cp310-macosx_14_0_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (580)
  1. nexaai/__init__.py +99 -0
  2. nexaai/_stub.cpython-310-darwin.so +0 -0
  3. nexaai/_version.py +4 -0
  4. nexaai/asr.py +68 -0
  5. nexaai/asr_impl/__init__.py +0 -0
  6. nexaai/asr_impl/mlx_asr_impl.py +93 -0
  7. nexaai/asr_impl/pybind_asr_impl.py +127 -0
  8. nexaai/base.py +39 -0
  9. nexaai/binds/__init__.py +7 -0
  10. nexaai/binds/asr_bind.cpython-310-darwin.so +0 -0
  11. nexaai/binds/common_bind.cpython-310-darwin.so +0 -0
  12. nexaai/binds/cpu_gpu/libggml-base.dylib +0 -0
  13. nexaai/binds/cpu_gpu/libggml-cpu.so +0 -0
  14. nexaai/binds/cpu_gpu/libggml-metal.so +0 -0
  15. nexaai/binds/cpu_gpu/libggml.dylib +0 -0
  16. nexaai/binds/cpu_gpu/libmtmd.dylib +0 -0
  17. nexaai/binds/cpu_gpu/libnexa_cpu_gpu.dylib +0 -0
  18. nexaai/binds/cpu_gpu/libnexa_plugin.dylib +0 -0
  19. nexaai/binds/cv_bind.cpython-310-darwin.so +0 -0
  20. nexaai/binds/diarize_bind.cpython-310-darwin.so +0 -0
  21. nexaai/binds/embedder_bind.cpython-310-darwin.so +0 -0
  22. nexaai/binds/libnexa_bridge.dylib +0 -0
  23. nexaai/binds/llm_bind.cpython-310-darwin.so +0 -0
  24. nexaai/binds/metal/libnexa_plugin.dylib +0 -0
  25. nexaai/binds/metal/py-lib/ml.py +888 -0
  26. nexaai/binds/metal/py-lib/mlx_audio/__init__.py +0 -0
  27. nexaai/binds/metal/py-lib/mlx_audio/codec/__init__.py +1 -0
  28. nexaai/binds/metal/py-lib/mlx_audio/codec/models/__init__.py +5 -0
  29. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  30. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  31. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  32. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  33. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  34. nexaai/binds/metal/py-lib/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  35. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/__init__.py +1 -0
  36. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/base.py +228 -0
  37. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/dac.py +285 -0
  38. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  39. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  40. nexaai/binds/metal/py-lib/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  41. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/__init__.py +1 -0
  42. nexaai/binds/metal/py-lib/mlx_audio/codec/models/encodec/encodec.py +777 -0
  43. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/__init__.py +1 -0
  44. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/mimi.py +286 -0
  45. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  46. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  47. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  48. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  49. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  50. nexaai/binds/metal/py-lib/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  51. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/__init__.py +1 -0
  52. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model.py +260 -0
  53. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/model_v2.py +383 -0
  54. nexaai/binds/metal/py-lib/mlx_audio/codec/models/s3/utils.py +122 -0
  55. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/__init__.py +1 -0
  56. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/attention.py +97 -0
  57. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/layers.py +306 -0
  58. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/snac.py +154 -0
  59. nexaai/binds/metal/py-lib/mlx_audio/codec/models/snac/vq.py +135 -0
  60. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/__init__.py +1 -0
  61. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/mel.py +33 -0
  62. nexaai/binds/metal/py-lib/mlx_audio/codec/models/vocos/vocos.py +359 -0
  63. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/__init__.py +0 -0
  64. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  65. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_descript.py +109 -0
  66. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_encodec.py +58 -0
  67. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_mimi.py +22 -0
  68. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_s3.py +25 -0
  69. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_snac.py +40 -0
  70. nexaai/binds/metal/py-lib/mlx_audio/codec/tests/test_vocos.py +93 -0
  71. nexaai/binds/metal/py-lib/mlx_audio/server.py +525 -0
  72. nexaai/binds/metal/py-lib/mlx_audio/sts/__init__.py +0 -0
  73. nexaai/binds/metal/py-lib/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  74. nexaai/binds/metal/py-lib/mlx_audio/sts/voice_pipeline.py +327 -0
  75. nexaai/binds/metal/py-lib/mlx_audio/stt/__init__.py +0 -0
  76. nexaai/binds/metal/py-lib/mlx_audio/stt/generate.py +174 -0
  77. nexaai/binds/metal/py-lib/mlx_audio/stt/models/__init__.py +0 -0
  78. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  79. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  80. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/attention.py +187 -0
  81. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/audio.py +76 -0
  82. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  83. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  84. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  85. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  86. nexaai/binds/metal/py-lib/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  87. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  88. nexaai/binds/metal/py-lib/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  89. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/__init__.py +1 -0
  90. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/audio.py +82 -0
  91. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/decoding.py +742 -0
  92. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/timing.py +329 -0
  93. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  94. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/whisper.py +862 -0
  95. nexaai/binds/metal/py-lib/mlx_audio/stt/models/whisper/writers.py +268 -0
  96. nexaai/binds/metal/py-lib/mlx_audio/stt/tests/test_models.py +381 -0
  97. nexaai/binds/metal/py-lib/mlx_audio/stt/utils.py +195 -0
  98. nexaai/binds/metal/py-lib/mlx_audio/tts/__init__.py +1 -0
  99. nexaai/binds/metal/py-lib/mlx_audio/tts/audio_player.py +120 -0
  100. nexaai/binds/metal/py-lib/mlx_audio/tts/convert.py +71 -0
  101. nexaai/binds/metal/py-lib/mlx_audio/tts/generate.py +449 -0
  102. nexaai/binds/metal/py-lib/mlx_audio/tts/models/__init__.py +0 -0
  103. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/__init__.py +4 -0
  104. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/bark.py +528 -0
  105. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/isftnet.py +12 -0
  106. nexaai/binds/metal/py-lib/mlx_audio/tts/models/bark/pipeline.py +442 -0
  107. nexaai/binds/metal/py-lib/mlx_audio/tts/models/base.py +84 -0
  108. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/__init__.py +1 -0
  109. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/audio.py +287 -0
  110. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/config.py +256 -0
  111. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/dia.py +592 -0
  112. nexaai/binds/metal/py-lib/mlx_audio/tts/models/dia/layers.py +870 -0
  113. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/__init__.py +3 -0
  114. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/attention.py +180 -0
  115. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  116. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/conformer.py +247 -0
  117. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  118. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  119. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  120. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  121. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  122. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  123. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/indextts.py +412 -0
  124. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/mel.py +37 -0
  125. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/normalize.py +294 -0
  126. nexaai/binds/metal/py-lib/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  127. nexaai/binds/metal/py-lib/mlx_audio/tts/models/interpolate.py +108 -0
  128. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  129. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  130. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  131. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/modules.py +659 -0
  132. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  133. nexaai/binds/metal/py-lib/mlx_audio/tts/models/kokoro/voice.py +113 -0
  134. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/__init__.py +3 -0
  135. nexaai/binds/metal/py-lib/mlx_audio/tts/models/llama/llama.py +324 -0
  136. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/__init__.py +1 -0
  137. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  138. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  139. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/outetts.py +255 -0
  140. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  141. nexaai/binds/metal/py-lib/mlx_audio/tts/models/outetts/tokens.py +36 -0
  142. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/__init__.py +3 -0
  143. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/attention.py +195 -0
  144. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/sesame.py +633 -0
  145. nexaai/binds/metal/py-lib/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  146. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/__init__.py +1 -0
  147. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  148. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/bicodec.py +269 -0
  149. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  150. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  151. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  152. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  153. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  154. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  155. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  156. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  157. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  158. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  159. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  160. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  161. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  162. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  163. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  164. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/spark.py +382 -0
  165. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  166. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/file.py +221 -0
  167. nexaai/binds/metal/py-lib/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  168. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/__init__.py +0 -0
  169. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_base.py +66 -0
  170. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_convert.py +173 -0
  171. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_interpolate.py +88 -0
  172. nexaai/binds/metal/py-lib/mlx_audio/tts/tests/test_models.py +974 -0
  173. nexaai/binds/metal/py-lib/mlx_audio/tts/utils.py +337 -0
  174. nexaai/binds/metal/py-lib/mlx_audio/utils.py +237 -0
  175. nexaai/binds/metal/py-lib/mlx_audio/version.py +1 -0
  176. nexaai/binds/metal/py-lib/profiling.py +239 -0
  177. nexaai/binds/nexaml/libfftw3.3.dylib +0 -0
  178. nexaai/binds/nexaml/libfftw3f.3.dylib +0 -0
  179. nexaai/binds/nexaml/libggml-base.dylib +0 -0
  180. nexaai/binds/nexaml/libggml-cpu.so +0 -0
  181. nexaai/binds/nexaml/libggml-metal.so +0 -0
  182. nexaai/binds/nexaml/libggml.dylib +0 -0
  183. nexaai/binds/nexaml/libmp3lame.0.dylib +0 -0
  184. nexaai/binds/nexaml/libmpg123.0.dylib +0 -0
  185. nexaai/binds/nexaml/libnexa-mm-process.dylib +0 -0
  186. nexaai/binds/nexaml/libnexa-sampling.dylib +0 -0
  187. nexaai/binds/nexaml/libnexa_plugin.dylib +0 -0
  188. nexaai/binds/nexaml/libnexaproc.dylib +0 -0
  189. nexaai/binds/nexaml/libomp.dylib +0 -0
  190. nexaai/binds/nexaml/libqwen3-vl.dylib +0 -0
  191. nexaai/binds/nexaml/libqwen3vl-vision.dylib +0 -0
  192. nexaai/binds/rerank_bind.cpython-310-darwin.so +0 -0
  193. nexaai/binds/vlm_bind.cpython-310-darwin.so +0 -0
  194. nexaai/common.py +106 -0
  195. nexaai/cv.py +95 -0
  196. nexaai/cv_impl/__init__.py +0 -0
  197. nexaai/cv_impl/mlx_cv_impl.py +91 -0
  198. nexaai/cv_impl/pybind_cv_impl.py +124 -0
  199. nexaai/diarize.py +80 -0
  200. nexaai/diarize_impl/__init__.py +1 -0
  201. nexaai/diarize_impl/pybind_diarize_impl.py +125 -0
  202. nexaai/embedder.py +73 -0
  203. nexaai/embedder_impl/__init__.py +0 -0
  204. nexaai/embedder_impl/mlx_embedder_impl.py +118 -0
  205. nexaai/embedder_impl/pybind_embedder_impl.py +96 -0
  206. nexaai/image_gen.py +141 -0
  207. nexaai/image_gen_impl/__init__.py +0 -0
  208. nexaai/image_gen_impl/mlx_image_gen_impl.py +292 -0
  209. nexaai/image_gen_impl/pybind_image_gen_impl.py +85 -0
  210. nexaai/llm.py +98 -0
  211. nexaai/llm_impl/__init__.py +0 -0
  212. nexaai/llm_impl/mlx_llm_impl.py +271 -0
  213. nexaai/llm_impl/pybind_llm_impl.py +238 -0
  214. nexaai/log.py +92 -0
  215. nexaai/mlx_backend/asr/__init__.py +12 -0
  216. nexaai/mlx_backend/asr/interface.py +122 -0
  217. nexaai/mlx_backend/common/__init__.py +0 -0
  218. nexaai/mlx_backend/common/utils.py +25 -0
  219. nexaai/mlx_backend/cv/__init__.py +0 -0
  220. nexaai/mlx_backend/cv/generate.py +195 -0
  221. nexaai/mlx_backend/cv/interface.py +162 -0
  222. nexaai/mlx_backend/cv/main.py +81 -0
  223. nexaai/mlx_backend/cv/modeling/pp_ocr_v4.py +1736 -0
  224. nexaai/mlx_backend/embedding/__init__.py +0 -0
  225. nexaai/mlx_backend/embedding/generate.py +333 -0
  226. nexaai/mlx_backend/embedding/interface.py +617 -0
  227. nexaai/mlx_backend/embedding/main.py +173 -0
  228. nexaai/mlx_backend/embedding/modeling/__init__.py +0 -0
  229. nexaai/mlx_backend/embedding/modeling/nexa_jina_v2.py +399 -0
  230. nexaai/mlx_backend/image_gen/__init__.py +1 -0
  231. nexaai/mlx_backend/image_gen/generate_sd.py +244 -0
  232. nexaai/mlx_backend/image_gen/interface.py +82 -0
  233. nexaai/mlx_backend/image_gen/main.py +281 -0
  234. nexaai/mlx_backend/image_gen/stable_diffusion/__init__.py +306 -0
  235. nexaai/mlx_backend/image_gen/stable_diffusion/clip.py +116 -0
  236. nexaai/mlx_backend/image_gen/stable_diffusion/config.py +65 -0
  237. nexaai/mlx_backend/image_gen/stable_diffusion/model_io.py +386 -0
  238. nexaai/mlx_backend/image_gen/stable_diffusion/sampler.py +105 -0
  239. nexaai/mlx_backend/image_gen/stable_diffusion/tokenizer.py +100 -0
  240. nexaai/mlx_backend/image_gen/stable_diffusion/unet.py +460 -0
  241. nexaai/mlx_backend/image_gen/stable_diffusion/vae.py +274 -0
  242. nexaai/mlx_backend/llm/__init__.py +0 -0
  243. nexaai/mlx_backend/llm/generate.py +149 -0
  244. nexaai/mlx_backend/llm/interface.py +764 -0
  245. nexaai/mlx_backend/llm/main.py +68 -0
  246. nexaai/mlx_backend/ml.py +888 -0
  247. nexaai/mlx_backend/mlx_audio/__init__.py +0 -0
  248. nexaai/mlx_backend/mlx_audio/codec/__init__.py +1 -0
  249. nexaai/mlx_backend/mlx_audio/codec/models/__init__.py +5 -0
  250. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/__init__.py +1 -0
  251. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/activation.py +51 -0
  252. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/amp.py +96 -0
  253. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/bigvgan.py +149 -0
  254. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/conv.py +114 -0
  255. nexaai/mlx_backend/mlx_audio/codec/models/bigvgan/resample.py +177 -0
  256. nexaai/mlx_backend/mlx_audio/codec/models/descript/__init__.py +1 -0
  257. nexaai/mlx_backend/mlx_audio/codec/models/descript/base.py +228 -0
  258. nexaai/mlx_backend/mlx_audio/codec/models/descript/dac.py +285 -0
  259. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/__init__.py +1 -0
  260. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/layers.py +129 -0
  261. nexaai/mlx_backend/mlx_audio/codec/models/descript/nn/quantize.py +149 -0
  262. nexaai/mlx_backend/mlx_audio/codec/models/encodec/__init__.py +1 -0
  263. nexaai/mlx_backend/mlx_audio/codec/models/encodec/encodec.py +777 -0
  264. nexaai/mlx_backend/mlx_audio/codec/models/mimi/__init__.py +1 -0
  265. nexaai/mlx_backend/mlx_audio/codec/models/mimi/mimi.py +286 -0
  266. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/__init__.py +20 -0
  267. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/conv.py +398 -0
  268. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/kv_cache.py +199 -0
  269. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/quantization.py +179 -0
  270. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/seanet.py +314 -0
  271. nexaai/mlx_backend/mlx_audio/codec/models/mimi/modules/transformer.py +256 -0
  272. nexaai/mlx_backend/mlx_audio/codec/models/s3/__init__.py +1 -0
  273. nexaai/mlx_backend/mlx_audio/codec/models/s3/model.py +260 -0
  274. nexaai/mlx_backend/mlx_audio/codec/models/s3/model_v2.py +383 -0
  275. nexaai/mlx_backend/mlx_audio/codec/models/s3/utils.py +122 -0
  276. nexaai/mlx_backend/mlx_audio/codec/models/snac/__init__.py +1 -0
  277. nexaai/mlx_backend/mlx_audio/codec/models/snac/attention.py +97 -0
  278. nexaai/mlx_backend/mlx_audio/codec/models/snac/layers.py +306 -0
  279. nexaai/mlx_backend/mlx_audio/codec/models/snac/snac.py +154 -0
  280. nexaai/mlx_backend/mlx_audio/codec/models/snac/vq.py +135 -0
  281. nexaai/mlx_backend/mlx_audio/codec/models/vocos/__init__.py +1 -0
  282. nexaai/mlx_backend/mlx_audio/codec/models/vocos/mel.py +33 -0
  283. nexaai/mlx_backend/mlx_audio/codec/models/vocos/vocos.py +359 -0
  284. nexaai/mlx_backend/mlx_audio/codec/tests/__init__.py +0 -0
  285. nexaai/mlx_backend/mlx_audio/codec/tests/test_bigvgan.py +54 -0
  286. nexaai/mlx_backend/mlx_audio/codec/tests/test_descript.py +109 -0
  287. nexaai/mlx_backend/mlx_audio/codec/tests/test_encodec.py +58 -0
  288. nexaai/mlx_backend/mlx_audio/codec/tests/test_mimi.py +22 -0
  289. nexaai/mlx_backend/mlx_audio/codec/tests/test_s3.py +25 -0
  290. nexaai/mlx_backend/mlx_audio/codec/tests/test_snac.py +40 -0
  291. nexaai/mlx_backend/mlx_audio/codec/tests/test_vocos.py +93 -0
  292. nexaai/mlx_backend/mlx_audio/server.py +525 -0
  293. nexaai/mlx_backend/mlx_audio/sts/__init__.py +0 -0
  294. nexaai/mlx_backend/mlx_audio/sts/tests/test_voice_pipeline.py +156 -0
  295. nexaai/mlx_backend/mlx_audio/sts/voice_pipeline.py +327 -0
  296. nexaai/mlx_backend/mlx_audio/stt/__init__.py +0 -0
  297. nexaai/mlx_backend/mlx_audio/stt/generate.py +174 -0
  298. nexaai/mlx_backend/mlx_audio/stt/models/__init__.py +0 -0
  299. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/__init__.py +1 -0
  300. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/alignment.py +248 -0
  301. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/attention.py +187 -0
  302. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/audio.py +76 -0
  303. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/conformer.py +331 -0
  304. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/ctc.py +34 -0
  305. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/parakeet.py +604 -0
  306. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/rnnt.py +157 -0
  307. nexaai/mlx_backend/mlx_audio/stt/models/parakeet/tokenizer.py +2 -0
  308. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/feature_extractor.py +757 -0
  309. nexaai/mlx_backend/mlx_audio/stt/models/wav2vec/wav2vec.py +738 -0
  310. nexaai/mlx_backend/mlx_audio/stt/models/whisper/__init__.py +1 -0
  311. nexaai/mlx_backend/mlx_audio/stt/models/whisper/audio.py +82 -0
  312. nexaai/mlx_backend/mlx_audio/stt/models/whisper/decoding.py +742 -0
  313. nexaai/mlx_backend/mlx_audio/stt/models/whisper/timing.py +329 -0
  314. nexaai/mlx_backend/mlx_audio/stt/models/whisper/tokenizer.py +398 -0
  315. nexaai/mlx_backend/mlx_audio/stt/models/whisper/whisper.py +862 -0
  316. nexaai/mlx_backend/mlx_audio/stt/models/whisper/writers.py +268 -0
  317. nexaai/mlx_backend/mlx_audio/stt/tests/test_models.py +381 -0
  318. nexaai/mlx_backend/mlx_audio/stt/utils.py +195 -0
  319. nexaai/mlx_backend/mlx_audio/tts/__init__.py +1 -0
  320. nexaai/mlx_backend/mlx_audio/tts/audio_player.py +120 -0
  321. nexaai/mlx_backend/mlx_audio/tts/convert.py +71 -0
  322. nexaai/mlx_backend/mlx_audio/tts/generate.py +449 -0
  323. nexaai/mlx_backend/mlx_audio/tts/models/__init__.py +0 -0
  324. nexaai/mlx_backend/mlx_audio/tts/models/bark/__init__.py +4 -0
  325. nexaai/mlx_backend/mlx_audio/tts/models/bark/bark.py +528 -0
  326. nexaai/mlx_backend/mlx_audio/tts/models/bark/isftnet.py +12 -0
  327. nexaai/mlx_backend/mlx_audio/tts/models/bark/pipeline.py +442 -0
  328. nexaai/mlx_backend/mlx_audio/tts/models/base.py +84 -0
  329. nexaai/mlx_backend/mlx_audio/tts/models/dia/__init__.py +1 -0
  330. nexaai/mlx_backend/mlx_audio/tts/models/dia/audio.py +287 -0
  331. nexaai/mlx_backend/mlx_audio/tts/models/dia/config.py +256 -0
  332. nexaai/mlx_backend/mlx_audio/tts/models/dia/dia.py +592 -0
  333. nexaai/mlx_backend/mlx_audio/tts/models/dia/layers.py +870 -0
  334. nexaai/mlx_backend/mlx_audio/tts/models/indextts/__init__.py +3 -0
  335. nexaai/mlx_backend/mlx_audio/tts/models/indextts/attention.py +180 -0
  336. nexaai/mlx_backend/mlx_audio/tts/models/indextts/bigvgan.py +124 -0
  337. nexaai/mlx_backend/mlx_audio/tts/models/indextts/conformer.py +247 -0
  338. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/__init__.py +0 -0
  339. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/asp.py +59 -0
  340. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/ecapa_tdnn.py +91 -0
  341. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/se_res2net.py +132 -0
  342. nexaai/mlx_backend/mlx_audio/tts/models/indextts/ecapa_tdnn/tdnn.py +42 -0
  343. nexaai/mlx_backend/mlx_audio/tts/models/indextts/gpt2.py +38 -0
  344. nexaai/mlx_backend/mlx_audio/tts/models/indextts/indextts.py +412 -0
  345. nexaai/mlx_backend/mlx_audio/tts/models/indextts/mel.py +37 -0
  346. nexaai/mlx_backend/mlx_audio/tts/models/indextts/normalize.py +294 -0
  347. nexaai/mlx_backend/mlx_audio/tts/models/indextts/perceiver.py +62 -0
  348. nexaai/mlx_backend/mlx_audio/tts/models/interpolate.py +108 -0
  349. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/__init__.py +4 -0
  350. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/istftnet.py +979 -0
  351. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/kokoro.py +331 -0
  352. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/modules.py +659 -0
  353. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/pipeline.py +453 -0
  354. nexaai/mlx_backend/mlx_audio/tts/models/kokoro/voice.py +113 -0
  355. nexaai/mlx_backend/mlx_audio/tts/models/llama/__init__.py +3 -0
  356. nexaai/mlx_backend/mlx_audio/tts/models/llama/llama.py +324 -0
  357. nexaai/mlx_backend/mlx_audio/tts/models/outetts/__init__.py +1 -0
  358. nexaai/mlx_backend/mlx_audio/tts/models/outetts/audio_processor.py +351 -0
  359. nexaai/mlx_backend/mlx_audio/tts/models/outetts/dac_interface.py +162 -0
  360. nexaai/mlx_backend/mlx_audio/tts/models/outetts/default_speaker.json +461 -0
  361. nexaai/mlx_backend/mlx_audio/tts/models/outetts/outetts.py +255 -0
  362. nexaai/mlx_backend/mlx_audio/tts/models/outetts/prompt_processor.py +181 -0
  363. nexaai/mlx_backend/mlx_audio/tts/models/outetts/tokens.py +36 -0
  364. nexaai/mlx_backend/mlx_audio/tts/models/sesame/__init__.py +3 -0
  365. nexaai/mlx_backend/mlx_audio/tts/models/sesame/attention.py +195 -0
  366. nexaai/mlx_backend/mlx_audio/tts/models/sesame/sesame.py +633 -0
  367. nexaai/mlx_backend/mlx_audio/tts/models/sesame/watermarking.py +105 -0
  368. nexaai/mlx_backend/mlx_audio/tts/models/spark/__init__.py +1 -0
  369. nexaai/mlx_backend/mlx_audio/tts/models/spark/audio_tokenizer.py +138 -0
  370. nexaai/mlx_backend/mlx_audio/tts/models/spark/bicodec.py +269 -0
  371. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/__init__.py +0 -0
  372. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/__init__.py +0 -0
  373. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/blocks/sampler.py +111 -0
  374. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/__init__.py +0 -0
  375. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_decoder.py +120 -0
  376. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/feat_encoder.py +136 -0
  377. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/encoder_decoder/wave_generator.py +113 -0
  378. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/finite_scalar_quantization.py +238 -0
  379. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual.py +209 -0
  380. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/residual_fsq.py +309 -0
  381. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/__init__.py +1 -0
  382. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/ecapa_tdnn.py +283 -0
  383. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/perceiver_encoder.py +326 -0
  384. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/pooling_layers.py +297 -0
  385. nexaai/mlx_backend/mlx_audio/tts/models/spark/modules/speaker/speaker_encoder.py +155 -0
  386. nexaai/mlx_backend/mlx_audio/tts/models/spark/spark.py +382 -0
  387. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/audio.py +220 -0
  388. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/file.py +221 -0
  389. nexaai/mlx_backend/mlx_audio/tts/models/spark/utils/token_parser.py +181 -0
  390. nexaai/mlx_backend/mlx_audio/tts/tests/__init__.py +0 -0
  391. nexaai/mlx_backend/mlx_audio/tts/tests/test_base.py +66 -0
  392. nexaai/mlx_backend/mlx_audio/tts/tests/test_convert.py +173 -0
  393. nexaai/mlx_backend/mlx_audio/tts/tests/test_interpolate.py +88 -0
  394. nexaai/mlx_backend/mlx_audio/tts/tests/test_models.py +974 -0
  395. nexaai/mlx_backend/mlx_audio/tts/utils.py +337 -0
  396. nexaai/mlx_backend/mlx_audio/utils.py +237 -0
  397. nexaai/mlx_backend/mlx_audio/version.py +1 -0
  398. nexaai/mlx_backend/profiling.py +239 -0
  399. nexaai/mlx_backend/rerank/__init__.py +0 -0
  400. nexaai/mlx_backend/rerank/generate.py +174 -0
  401. nexaai/mlx_backend/rerank/interface.py +287 -0
  402. nexaai/mlx_backend/rerank/main.py +127 -0
  403. nexaai/mlx_backend/rerank/modeling/__init__.py +0 -0
  404. nexaai/mlx_backend/rerank/modeling/nexa_jina_rerank.py +330 -0
  405. nexaai/mlx_backend/sd/__init__.py +1 -0
  406. nexaai/mlx_backend/sd/interface.py +362 -0
  407. nexaai/mlx_backend/sd/main.py +286 -0
  408. nexaai/mlx_backend/sd/modeling/__init__.py +306 -0
  409. nexaai/mlx_backend/sd/modeling/clip.py +116 -0
  410. nexaai/mlx_backend/sd/modeling/config.py +65 -0
  411. nexaai/mlx_backend/sd/modeling/model_io.py +385 -0
  412. nexaai/mlx_backend/sd/modeling/sampler.py +105 -0
  413. nexaai/mlx_backend/sd/modeling/tokenizer.py +100 -0
  414. nexaai/mlx_backend/sd/modeling/unet.py +460 -0
  415. nexaai/mlx_backend/sd/modeling/vae.py +274 -0
  416. nexaai/mlx_backend/tts/__init__.py +12 -0
  417. nexaai/mlx_backend/tts/interface.py +276 -0
  418. nexaai/mlx_backend/vlm/__init__.py +3 -0
  419. nexaai/mlx_backend/vlm/generate.py +572 -0
  420. nexaai/mlx_backend/vlm/generate_qwen3_vl.py +374 -0
  421. nexaai/mlx_backend/vlm/generate_qwen3_vl_moe.py +259 -0
  422. nexaai/mlx_backend/vlm/interface.py +559 -0
  423. nexaai/mlx_backend/vlm/main.py +365 -0
  424. nexaai/mlx_backend/vlm/modeling/__init__.py +0 -0
  425. nexaai/mlx_backend/vlm/modeling/convert.py +68 -0
  426. nexaai/mlx_backend/vlm/modeling/models/__init__.py +0 -0
  427. nexaai/mlx_backend/vlm/modeling/models/aya_vision/__init__.py +8 -0
  428. nexaai/mlx_backend/vlm/modeling/models/aya_vision/aya_vision.py +193 -0
  429. nexaai/mlx_backend/vlm/modeling/models/aya_vision/interpolate.py +186 -0
  430. nexaai/mlx_backend/vlm/modeling/models/aya_vision/language.py +233 -0
  431. nexaai/mlx_backend/vlm/modeling/models/aya_vision/vision.py +503 -0
  432. nexaai/mlx_backend/vlm/modeling/models/base.py +202 -0
  433. nexaai/mlx_backend/vlm/modeling/models/cache.py +230 -0
  434. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/__init__.py +10 -0
  435. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/conversation.py +264 -0
  436. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/deepseek_vl_v2.py +472 -0
  437. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/language.py +591 -0
  438. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/processing_deepsek_vl_v2.py +526 -0
  439. nexaai/mlx_backend/vlm/modeling/models/deepseek_vl_v2/vision.py +356 -0
  440. nexaai/mlx_backend/vlm/modeling/models/florence2/__init__.py +8 -0
  441. nexaai/mlx_backend/vlm/modeling/models/florence2/florence2.py +366 -0
  442. nexaai/mlx_backend/vlm/modeling/models/florence2/language.py +488 -0
  443. nexaai/mlx_backend/vlm/modeling/models/florence2/vision.py +591 -0
  444. nexaai/mlx_backend/vlm/modeling/models/gemma3/__init__.py +8 -0
  445. nexaai/mlx_backend/vlm/modeling/models/gemma3/gemma3.py +213 -0
  446. nexaai/mlx_backend/vlm/modeling/models/gemma3/language.py +315 -0
  447. nexaai/mlx_backend/vlm/modeling/models/gemma3/vision.py +238 -0
  448. nexaai/mlx_backend/vlm/modeling/models/gemma3n/__init__.py +2 -0
  449. nexaai/mlx_backend/vlm/modeling/models/gemma3n/audio.py +1038 -0
  450. nexaai/mlx_backend/vlm/modeling/models/gemma3n/config.py +139 -0
  451. nexaai/mlx_backend/vlm/modeling/models/gemma3n/gemma3n.py +322 -0
  452. nexaai/mlx_backend/vlm/modeling/models/gemma3n/language.py +629 -0
  453. nexaai/mlx_backend/vlm/modeling/models/gemma3n/vision.py +1022 -0
  454. nexaai/mlx_backend/vlm/modeling/models/idefics2/__init__.py +9 -0
  455. nexaai/mlx_backend/vlm/modeling/models/idefics2/idefics2.py +294 -0
  456. nexaai/mlx_backend/vlm/modeling/models/idefics2/language.py +191 -0
  457. nexaai/mlx_backend/vlm/modeling/models/idefics2/vision.py +267 -0
  458. nexaai/mlx_backend/vlm/modeling/models/idefics3/__init__.py +8 -0
  459. nexaai/mlx_backend/vlm/modeling/models/idefics3/idefics3.py +175 -0
  460. nexaai/mlx_backend/vlm/modeling/models/idefics3/language.py +192 -0
  461. nexaai/mlx_backend/vlm/modeling/models/idefics3/vision.py +233 -0
  462. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/__init__.py +9 -0
  463. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/internvl_chat.py +140 -0
  464. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/language.py +220 -0
  465. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/processor.py +393 -0
  466. nexaai/mlx_backend/vlm/modeling/models/internvl_chat/vision.py +293 -0
  467. nexaai/mlx_backend/vlm/modeling/models/kernels.py +307 -0
  468. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/__init__.py +8 -0
  469. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/kimi_vl.py +143 -0
  470. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/language.py +509 -0
  471. nexaai/mlx_backend/vlm/modeling/models/kimi_vl/vision.py +522 -0
  472. nexaai/mlx_backend/vlm/modeling/models/llama4/__init__.py +8 -0
  473. nexaai/mlx_backend/vlm/modeling/models/llama4/language.py +386 -0
  474. nexaai/mlx_backend/vlm/modeling/models/llama4/llama4.py +138 -0
  475. nexaai/mlx_backend/vlm/modeling/models/llama4/vision.py +560 -0
  476. nexaai/mlx_backend/vlm/modeling/models/llava/__init__.py +8 -0
  477. nexaai/mlx_backend/vlm/modeling/models/llava/language.py +240 -0
  478. nexaai/mlx_backend/vlm/modeling/models/llava/llava.py +153 -0
  479. nexaai/mlx_backend/vlm/modeling/models/llava/vision.py +259 -0
  480. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/__init__.py +9 -0
  481. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/language.py +236 -0
  482. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/llava_bunny.py +256 -0
  483. nexaai/mlx_backend/vlm/modeling/models/llava_bunny/vision.py +303 -0
  484. nexaai/mlx_backend/vlm/modeling/models/llava_next/__init__.py +8 -0
  485. nexaai/mlx_backend/vlm/modeling/models/llava_next/language.py +230 -0
  486. nexaai/mlx_backend/vlm/modeling/models/llava_next/llava_next.py +160 -0
  487. nexaai/mlx_backend/vlm/modeling/models/llava_next/vision.py +243 -0
  488. nexaai/mlx_backend/vlm/modeling/models/mistral3/__init__.py +8 -0
  489. nexaai/mlx_backend/vlm/modeling/models/mistral3/mistral3.py +283 -0
  490. nexaai/mlx_backend/vlm/modeling/models/mllama/__init__.py +8 -0
  491. nexaai/mlx_backend/vlm/modeling/models/mllama/language.py +416 -0
  492. nexaai/mlx_backend/vlm/modeling/models/mllama/mllama.py +172 -0
  493. nexaai/mlx_backend/vlm/modeling/models/mllama/vision.py +499 -0
  494. nexaai/mlx_backend/vlm/modeling/models/molmo/__init__.py +8 -0
  495. nexaai/mlx_backend/vlm/modeling/models/molmo/language.py +243 -0
  496. nexaai/mlx_backend/vlm/modeling/models/molmo/molmo.py +133 -0
  497. nexaai/mlx_backend/vlm/modeling/models/molmo/vision.py +465 -0
  498. nexaai/mlx_backend/vlm/modeling/models/multi_modality/__init__.py +10 -0
  499. nexaai/mlx_backend/vlm/modeling/models/multi_modality/language.py +230 -0
  500. nexaai/mlx_backend/vlm/modeling/models/multi_modality/multi_modality.py +385 -0
  501. nexaai/mlx_backend/vlm/modeling/models/multi_modality/sam.py +557 -0
  502. nexaai/mlx_backend/vlm/modeling/models/multi_modality/vision.py +526 -0
  503. nexaai/mlx_backend/vlm/modeling/models/paligemma/__init__.py +8 -0
  504. nexaai/mlx_backend/vlm/modeling/models/paligemma/language.py +282 -0
  505. nexaai/mlx_backend/vlm/modeling/models/paligemma/paligemma.py +160 -0
  506. nexaai/mlx_backend/vlm/modeling/models/paligemma/vision.py +242 -0
  507. nexaai/mlx_backend/vlm/modeling/models/phi3_v/__init__.py +8 -0
  508. nexaai/mlx_backend/vlm/modeling/models/phi3_v/language.py +21 -0
  509. nexaai/mlx_backend/vlm/modeling/models/phi3_v/phi3_v.py +243 -0
  510. nexaai/mlx_backend/vlm/modeling/models/phi3_v/su_rope.py +71 -0
  511. nexaai/mlx_backend/vlm/modeling/models/phi3_v/vision.py +324 -0
  512. nexaai/mlx_backend/vlm/modeling/models/pixtral/__init__.py +8 -0
  513. nexaai/mlx_backend/vlm/modeling/models/pixtral/language.py +229 -0
  514. nexaai/mlx_backend/vlm/modeling/models/pixtral/pixtral.py +161 -0
  515. nexaai/mlx_backend/vlm/modeling/models/pixtral/vision.py +320 -0
  516. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/__init__.py +2 -0
  517. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/config.py +108 -0
  518. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/language.py +490 -0
  519. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/qwen2_5_vl.py +168 -0
  520. nexaai/mlx_backend/vlm/modeling/models/qwen2_5_vl/vision.py +414 -0
  521. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/__init__.py +2 -0
  522. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/config.py +104 -0
  523. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/language.py +490 -0
  524. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/qwen2_vl.py +167 -0
  525. nexaai/mlx_backend/vlm/modeling/models/qwen2_vl/vision.py +312 -0
  526. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/__init__.py +0 -0
  527. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/base.py +117 -0
  528. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/cache.py +531 -0
  529. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/generate.py +701 -0
  530. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/rope_utils.py +255 -0
  531. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/sample_utils.py +303 -0
  532. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/llm_common/tokenizer_utils.py +407 -0
  533. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/processor.py +476 -0
  534. nexaai/mlx_backend/vlm/modeling/models/qwen3_vl/qwen3vl.py +1262 -0
  535. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/__init__.py +0 -0
  536. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/base.py +117 -0
  537. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/cache.py +531 -0
  538. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/generate.py +701 -0
  539. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/rope_utils.py +255 -0
  540. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/sample_utils.py +303 -0
  541. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/llm_common/tokenizer_utils.py +407 -0
  542. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/processor.py +476 -0
  543. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/qwen3vl_moe.py +1308 -0
  544. nexaai/mlx_backend/vlm/modeling/models/qwen3vl_moe/switch_layers.py +210 -0
  545. nexaai/mlx_backend/vlm/modeling/models/smolvlm/__init__.py +8 -0
  546. nexaai/mlx_backend/vlm/modeling/models/smolvlm/smolvlm.py +62 -0
  547. nexaai/mlx_backend/vlm/modeling/processing_qwen2_5_vl.py +209 -0
  548. nexaai/mlx_backend/vlm/modeling/processing_qwen2_vl.py +215 -0
  549. nexaai/mlx_backend/vlm/modeling/prompt_utils.py +474 -0
  550. nexaai/mlx_backend/vlm/modeling/sample_utils.py +39 -0
  551. nexaai/mlx_backend/vlm/modeling/tokenizer_utils.py +344 -0
  552. nexaai/mlx_backend/vlm/modeling/trainer/__init__.py +9 -0
  553. nexaai/mlx_backend/vlm/modeling/trainer/lora.py +70 -0
  554. nexaai/mlx_backend/vlm/modeling/trainer/trainer.py +296 -0
  555. nexaai/mlx_backend/vlm/modeling/trainer/utils.py +160 -0
  556. nexaai/mlx_backend/vlm/modeling/utils.py +928 -0
  557. nexaai/rerank.py +57 -0
  558. nexaai/rerank_impl/__init__.py +0 -0
  559. nexaai/rerank_impl/mlx_rerank_impl.py +94 -0
  560. nexaai/rerank_impl/pybind_rerank_impl.py +136 -0
  561. nexaai/runtime.py +68 -0
  562. nexaai/runtime_error.py +24 -0
  563. nexaai/tts.py +75 -0
  564. nexaai/tts_impl/__init__.py +0 -0
  565. nexaai/tts_impl/mlx_tts_impl.py +94 -0
  566. nexaai/tts_impl/pybind_tts_impl.py +43 -0
  567. nexaai/utils/decode.py +18 -0
  568. nexaai/utils/manifest_utils.py +531 -0
  569. nexaai/utils/model_manager.py +1745 -0
  570. nexaai/utils/model_types.py +49 -0
  571. nexaai/utils/progress_tracker.py +389 -0
  572. nexaai/utils/quantization_utils.py +245 -0
  573. nexaai/vlm.py +130 -0
  574. nexaai/vlm_impl/__init__.py +0 -0
  575. nexaai/vlm_impl/mlx_vlm_impl.py +259 -0
  576. nexaai/vlm_impl/pybind_vlm_impl.py +275 -0
  577. nexaai-1.0.29.dist-info/METADATA +35 -0
  578. nexaai-1.0.29.dist-info/RECORD +580 -0
  579. nexaai-1.0.29.dist-info/WHEEL +5 -0
  580. nexaai-1.0.29.dist-info/top_level.txt +1 -0
@@ -0,0 +1,742 @@
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import zlib
4
+ from dataclasses import dataclass, field, replace
5
+ from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
6
+
7
+ import mlx.core as mx
8
+ import numpy as np
9
+ from mlx.utils import tree_map
10
+
11
+ from .audio import CHUNK_LENGTH
12
+ from .tokenizer import Tokenizer, get_tokenizer
13
+
14
+
15
def compression_ratio(text) -> float:
    """Return the ratio of UTF-8 size to zlib-compressed size for *text*.

    Larger values mean the text compresses better (e.g. highly repetitive text).
    """
    raw = text.encode("utf-8")
    packed = zlib.compress(raw)
    return len(raw) / len(packed)
18
+
19
+
20
def detect_language(
    model: "Whisper", mel: mx.array, tokenizer: Tokenizer = None
) -> Tuple[mx.array, List[dict]]:
    """
    Identify the spoken language of one or more audio inputs.

    A single decoder step is run on the startoftranscript token, outside the
    main decode loop so it does not interfere with kv-caching, and the result
    is restricted to language tokens.

    Parameters
    ----------
    model : Whisper
        Model providing ``encoder``, ``logits``, ``dims`` and language metadata.
    mel : mx.array
        One 2-D input or a 3-D batch. Inputs whose last two dimensions already
        equal ``(n_audio_ctx, n_audio_state)`` are treated as encoder output and
        are not re-encoded.
    tokenizer : Tokenizer, optional
        Built from the model's settings when omitted.

    Returns
    -------
    language_tokens : mx.array, shape = (n_audio,)
        ids of the most probable language token for each input (indexed down to
        a single element when ``mel`` was 2-D).
    language_probs : List[Dict[str, float]], length = n_audio
        language code -> probability for each input (a single dict when ``mel``
        was 2-D).

    Raises
    ------
    ValueError
        If the tokenizer has no language token in its sot sequence.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    supports_lang_id = (
        tokenizer.language is not None
        and tokenizer.language_token in tokenizer.sot_sequence
    )
    if not supports_lang_id:
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    unbatched = mel.ndim == 2
    if unbatched:
        mel = mel[None]

    # Already-encoded features have the encoder's output shape; only raw input is encoded.
    encoded_shape = (model.dims.n_audio_ctx, model.dims.n_audio_state)
    features = mel if mel.shape[-2:] == encoded_shape else model.encoder(mel)

    batch_size = features.shape[0]
    sot_column = mx.array([[tokenizer.sot]] * batch_size)  # [batch_size, 1]
    step_logits = model.logits(sot_column, features)[:, 0]

    # Everything except the language tokens gets -inf, so only a language can win.
    language_only = mx.full(step_logits.shape[-1], -mx.inf, dtype=mx.float32)
    language_only[list(tokenizer.all_language_tokens)] = 0.0
    step_logits = step_logits + language_only

    best_tokens = mx.argmax(step_logits, axis=-1)
    probs = np.array(mx.softmax(step_logits, axis=-1))
    token_code_pairs = list(
        zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
    )
    per_input = [
        {code: probs[row, token_id].item() for token_id, code in token_code_pairs}
        for row in range(batch_size)
    ]

    if unbatched:
        return best_tokens[0], per_input[0]
    return best_tokens, per_input
80
+
81
+
82
@dataclass(frozen=True)
class DecodingOptions:
    """Immutable settings for one decoding pass.

    Field order is part of the public interface (dataclasses accept positional
    arguments), so new options belong at the end.
    """

    # "transcribe" keeps the spoken language (X->X); "translate" produces English (X->English)
    task: str = "transcribe"
    # spoken language of the audio; None means "detect it"
    language: Optional[str] = None

    # --- sampling ---
    temperature: float = 0.0
    sample_len: Optional[int] = None  # upper bound on sampled tokens
    best_of: Optional[int] = None  # independent samples to draw when t > 0
    beam_size: Optional[int] = None  # beam width when t == 0
    patience: Optional[float] = None  # beam-search patience (arxiv:2204.05424)

    # candidate ranking among beams / best-of-N samples: the Google NMT "alpha",
    # or None for simple length normalization
    length_penalty: Optional[float] = None

    # conditioning text or token ids; background:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # context carried from earlier audio
    prefix: Optional[Union[str, List[int]]] = None  # forced start of the current window

    # token ids (iterable, or comma-separated string) that may never be sampled;
    # "-1" stands for the set in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # forbid a blank first sampled token

    # --- timestamps ---
    without_timestamps: bool = False  # emit text only, via <|notimestamps|>
    max_initial_timestamp: Optional[float] = 1.0  # latest allowed first timestamp, seconds

    # --- numerics ---
    fp16: bool = True  # run most of the computation in float16
117
+
118
+
119
@dataclass(frozen=True)
class DecodingResult:
    """Immutable outcome of decoding a single audio input."""

    audio_features: mx.array  # encoder output for this input
    language: str  # language code that was requested or detected
    # language code -> probability; may be None when no detection was performed
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)  # sampled token ids
    text: str = ""
    # the scores below default to NaN, i.e. "not computed"
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan
130
+
131
+
132
class Inference:
    """Runs the text decoder step by step and owns its key/value cache."""

    def __init__(self, model: "Whisper"):
        self.model: "Whisper" = model
        self.kv_cache = None  # filled in by calls to ``logits``

    def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array:
        """Decode *tokens* against *audio_features*, keeping the returned cache.

        Returns the decoder's per-token logits cast to float32.
        """
        decoded, updated_cache, _ = self.model.decoder(
            tokens, audio_features, kv_cache=self.kv_cache
        )
        self.kv_cache = updated_cache
        return decoded.astype(mx.float32)

    def rearrange_kv_cache(self, source_indices):
        """Reorder every cached tensor along its first axis by *source_indices*.

        Nothing happens when the indices are already ``0, 1, ..., n-1``.
        """
        identity = list(range(len(source_indices)))
        if source_indices != identity:
            self.kv_cache = tree_map(
                lambda cached: cached[source_indices], self.kv_cache
            )

    def reset(self):
        """Forget the cache so the next decode starts from an empty context."""
        self.kv_cache = None
152
+
153
+
154
class SequenceRanker:
    """Interface for choosing one candidate out of each group of samples."""

    def rank(
        self, tokens: List[List[mx.array]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Choose a winner in every group.

        Parameters
        ----------
        tokens : groups of candidate token sequences, one group per audio input
        sum_logprobs : cumulative log probability of each candidate, same nesting

        Returns
        -------
        List[int]
            index of the selected candidate within each group
        """
        raise NotImplementedError
163
+
164
+
165
class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probability after length penalization.

    With ``length_penalty is None`` the score is ``sum_logprob / length`` (simple
    length normalization); otherwise the Google NMT paper's penalty
    ``((5 + length) / 6) ** length_penalty`` is used as the divisor.
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[List[int]]], sum_logprobs: List[List[float]]):
        """Return, for each group, the index of its best-scoring candidate.

        ``tokens[g][c]`` is candidate ``c`` of group ``g``; ``sum_logprobs[g][c]``
        is that candidate's cumulative log probability.
        """

        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                if self.length_penalty is None:
                    # Simple length normalization. An empty candidate would
                    # divide by zero, so its length counts as 1.
                    penalty = max(length, 1)
                else:
                    # from the Google NMT paper
                    penalty = ((5 + length) / 6) ** self.length_penalty
                result.append(logprob / penalty)
            return result

        # get the sequence with the highest score
        lengths = [[len(t) for t in s] for s in tokens]
        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
189
+
190
+
191
class TokenDecoder:
    """Strategy interface: how the next token is chosen at each decoding step."""

    def reset(self):
        """Clear any per-sequence state before decoding new input (no-op here)."""

    def update(
        self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
    ) -> Tuple[mx.array, bool, mx.array]:
        """Choose the next token of every sequence and append it.

        Parameters
        ----------
        tokens : mx.array, shape = (n_batch, current_sequence_length)
            the context so far, including the prefix and sot_sequence tokens

        logits : mx.array, shape = (n_batch, vocab_size)
            next-token logits at the current step

        sum_logprobs : mx.array, shape = (n_batch)
            running log-probability total of each sequence

        Returns
        -------
        tokens : mx.array, shape = (n_batch, current_sequence_length + 1)
            the input tokens with the chosen token appended

        completed : bool
            whether every sequence has reached the end of text

        sum_logprobs : mx.array, shape = (n_batch)
            the updated running totals
        """
        raise NotImplementedError

    def finalize(
        self, tokens: mx.array, sum_logprobs: mx.array
    ) -> Tuple[Sequence[Sequence[mx.array]], List[List[float]]]:
        """Close the search and hand back the final candidates.

        Parameters
        ----------
        tokens : mx.array, shape = (n_audio, n_group, current_sequence_length)
            the context so far, including the prefix and sot_sequence

        sum_logprobs : mx.array, shape = (n_audio, n_group)
            cumulative log probability of each candidate

        Returns
        -------
        tokens : Sequence[Sequence[mx.array]], length = n_audio
            the candidate token sequences of each audio input

        sum_logprobs : List[List[float]], length = n_audio
            the cumulative log probabilities matching ``tokens``
        """
        raise NotImplementedError
248
+
249
+
250
def _categorical(logits, temp):
    """Sample one token id per row from ``softmax(logits / temp)``."""
    return mx.random.categorical(logits / temp)


# mx.random draws from the global PRNG state. That state must be declared as both
# an input and an output of the compiled graph. Otherwise the key is captured as a
# constant when the function is traced, and it does not advance between calls
# (successive samples reuse the same randomness).
categorical = mx.compile(
    _categorical, inputs=mx.random.state, outputs=mx.random.state
)
253
+
254
+
255
class GreedyDecoder(TokenDecoder):
    """Choose the next token by argmax (temperature 0) or by sampling at ``temperature``."""

    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: mx.array, logits: mx.array, sum_logprobs: mx.array
    ) -> Tuple[mx.array, bool, mx.array]:
        """Append one token to every sequence and accumulate its log probability.

        Sequences whose last token is already EOT keep emitting EOT and add
        nothing to their running log probability. Returns
        ``(tokens, completed, sum_logprobs)``, where ``completed`` is a scalar
        boolean array that is true once every sequence ends in EOT.
        """
        if self.temperature == 0:
            next_tokens = logits.argmax(axis=-1)
        else:
            next_tokens = categorical(logits, self.temperature)

        # Row-wise log-softmax. keepdims=True makes the normalizer broadcast per
        # row; without it, the (n_batch,) result broadcasts along the vocabulary
        # axis and breaks for n_batch > 1.
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens]

        still_running = tokens[:, -1] != self.eot
        # Build a new array instead of using ``+=``. In-place addition on an
        # mx.array can rebind the caller's ``sum_logprobs`` object.
        sum_logprobs = sum_logprobs + current_logprobs * still_running

        eot_mask = tokens[:, -1] == self.eot
        next_tokens = next_tokens * (1 - eot_mask) + self.eot * eot_mask
        tokens = mx.concatenate([tokens, next_tokens[:, None]], axis=-1)

        completed = mx.all(tokens[:, -1] == self.eot)
        return tokens, completed, sum_logprobs

    def finalize(self, tokens: mx.array, sum_logprobs: mx.array):
        """Pad one EOT onto the last axis so every candidate ends with at least one EOT."""
        # make sure each sequence has at least one EOT token at the end
        tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot)
        return tokens, sum_logprobs
284
+
285
+
286
class LogitFilter:
    """Interface for rules that mask or penalize logits before a token is chosen."""

    def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
        """Return *logits* adjusted according to this rule.

        Parameters
        ----------
        logits : mx.array, shape = (n_batch, vocab_size)
            next-token logits at the current step

        tokens : mx.array, shape = (n_batch, current_sequence_length)
            the context so far, including the prefix and sot_sequence tokens

        """
        raise NotImplementedError
300
+
301
+
302
class SuppressBlank(LogitFilter):
    """Forbid the encoded blank (" ") tokens and EOT as the first sampled token."""

    def __init__(self, tokenizer: Tokenizer, sample_begin: int, n_vocab: int):
        self.sample_begin = sample_begin
        banned = tokenizer.encode(" ") + [tokenizer.eot]
        additive = np.zeros(n_vocab, np.float32)
        additive[banned] = -np.inf
        self.mask = mx.array(additive)

    def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
        # Only the very first sampled position (context length == sample_begin) is affected.
        at_first_sample = tokens.shape[1] == self.sample_begin
        return logits + self.mask if at_first_sample else logits
313
+
314
+
315
class SuppressTokens(LogitFilter):
    """Add -inf to the logits of a fixed set of token ids at every step."""

    def __init__(self, suppress_tokens: Sequence[int], n_vocab: int):
        penalty = np.zeros(n_vocab, np.float32)
        penalty[list(suppress_tokens)] = -np.inf
        self.mask = mx.array(penalty)

    def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
        # ``tokens`` is unused: the mask does not depend on the context.
        return logits + self.mask
323
+
324
+
325
class ApplyTimestampRules(LogitFilter):
    """
    Enforce Whisper's timestamp grammar on the next-token logits.

    For each sequence (considering only tokens from ``sample_begin`` on):

    - ``<|notimestamps|>`` is never sampled (``without_timestamps`` handles it).
    - Timestamps appear in pairs, except directly before EOT. After two
      consecutive timestamps (or a lone initial one) the next token must be
      non-timestamp. After a timestamp that follows text, the next token cannot
      be normal text.
    - Timestamps never decrease. Each segment must have a nonzero length.
    - The first sampled token must be a timestamp no later than
      ``max_initial_timestamp_index`` steps past ``timestamp_begin``.
    - If the total probability of all timestamps exceeds that of the most
      likely text token, a timestamp is forced.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: mx.array, tokens: mx.array) -> mx.array:
        timestamp_begin = self.tokenizer.timestamp_begin
        mask = np.zeros(logits.shape, np.float32)
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            mask[:, self.tokenizer.no_timestamps] = -np.inf

        ## timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        tokens = tokens.tolist()
        for k, row in enumerate(tokens):
            seq = row[self.sample_begin :]
            last_was_timestamp = len(seq) >= 1 and seq[-1] >= timestamp_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    mask[k, timestamp_begin:] = -np.inf
                else:  # cannot be normal text tokens
                    mask[k, : self.tokenizer.eot] = -np.inf

            # Sampled timestamp token ids (token values, not positions in ``seq``).
            timestamps = [t for t in seq if t >= timestamp_begin]
            if timestamps:
                # Timestamps shouldn't decrease: forbid timestamp tokens smaller than
                # the last one. Right after a segment's closing timestamp, the next
                # segment may start at that same time. Otherwise the next timestamp
                # must be strictly later, which forces each segment to have a
                # nonzero length and prevents infinite looping.
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    timestamp_last = timestamps[-1] + 1
                mask[k, timestamp_begin:timestamp_last] = -np.inf

        if len(tokens[0]) == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            mask[:, :timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = timestamp_begin + self.max_initial_timestamp_index
                mask[:, last_allowed + 1 :] = -np.inf

        # If the summed probability over timestamps is above any single text token,
        # sample a timestamp. The comparison uses the logits with the rules above
        # already applied. Otherwise, timestamps that were just forbidden could
        # still cause all text to be masked.
        mask = mx.array(mask)
        masked_logits = logits + mask
        logprobs = masked_logits - mx.logsumexp(masked_logits, axis=-1, keepdims=True)
        timestamp_logprob = logprobs[:, timestamp_begin:].logsumexp(
            axis=-1, keepdims=True
        )
        max_text_token_logprob = logprobs[:, :timestamp_begin].max(
            axis=-1, keepdims=True
        )
        mask[:, :timestamp_begin] = mx.where(
            timestamp_logprob > max_text_token_logprob,
            -mx.inf,
            mask[:, :timestamp_begin],
        )
        return logits + mask
396
+
397
+
398
+ class DecodingTask:
399
+ inference: Inference
400
+ sequence_ranker: SequenceRanker
401
+ decoder: TokenDecoder
402
+ logit_filters: List[LogitFilter]
403
+
404
def __init__(self, model: "Whisper", options: DecodingOptions):
    """Prepare the tokenizer, initial context, decoder and logit filters.

    Raises
    ------
    ValueError
        If *options* contain an incompatible combination (``_verify_options``).
    NotImplementedError
        If ``options.beam_size`` is set; beam search is not implemented.
    """
    self.model = model

    # Without an explicit language, the tokenizer is built for "en"; the
    # language token may be overwritten later by detection.
    self.tokenizer: Tokenizer = get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language=options.language or "en",
        task=options.task,
    )
    tokenizer = self.tokenizer
    self.options: DecodingOptions = self._verify_options(options)

    dims = model.dims
    self.n_group: int = options.beam_size or options.best_of or 1
    self.n_ctx: int = dims.n_text_ctx
    self.sample_len: int = options.sample_len or dims.n_text_ctx // 2

    self.sot_sequence: Tuple[int] = (
        tokenizer.sot_sequence_including_notimestamps
        if self.options.without_timestamps
        else tokenizer.sot_sequence
    )

    self.initial_tokens: Tuple[int] = self._get_initial_tokens()
    self.sample_begin: int = len(self.initial_tokens)
    self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

    # forward pass through the decoder, including kv caching
    self.inference = Inference(model)
    # how a group of sampled sequences is ranked
    self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

    # how the next token is chosen from the autoregressive distribution
    if options.beam_size is not None:
        raise NotImplementedError("Beam search decoder is not yet implemented")
    self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

    # rules that suppress or penalize particular tokens, applied in this order
    n_vocab = dims.n_vocab
    filters = []
    if self.options.suppress_blank:
        filters.append(SuppressBlank(self.tokenizer, self.sample_begin, n_vocab))
    if self.options.suppress_tokens:
        filters.append(SuppressTokens(self._get_suppress_tokens(), n_vocab))
    if not options.without_timestamps:
        time_precision = CHUNK_LENGTH / dims.n_audio_ctx  # usually 0.02 seconds
        initial_index = None
        if options.max_initial_timestamp:
            initial_index = round(
                self.options.max_initial_timestamp / time_precision
            )
        filters.append(
            ApplyTimestampRules(tokenizer, self.sample_begin, initial_index)
        )
    self.logit_filters = filters
464
+
465
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
    """Reject incompatible option combinations; return *options* unchanged.

    Raises
    ------
    ValueError
        On the first incompatible combination found.
    """
    beam_given = options.beam_size is not None
    best_of_given = options.best_of is not None

    if beam_given and best_of_given:
        raise ValueError("beam_size and best_of can't be given together")
    if options.temperature == 0 and best_of_given:
        raise ValueError("best_of with greedy sampling (T=0) is not compatible")
    if options.patience is not None and not beam_given:
        raise ValueError("patience requires beam_size to be given")

    alpha = options.length_penalty
    if alpha is not None and not (0 <= alpha <= 1):
        raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

    return options
479
+
480
def _get_initial_tokens(self) -> Tuple[int]:
    """Build the token context the decoder starts from.

    Layout: ``[sot_prev, *prompt] + sot_sequence + prefix``. The prompt part is
    present only when ``options.prompt`` is set, the prefix part only when
    ``options.prefix`` is set. String values are stripped, prefixed with a space
    and encoded; token lists are used as given. The prompt keeps at most its last
    ``n_ctx // 2 - 1`` tokens. The prefix is sliced with
    ``[-(n_ctx // 2 - sample_len):]``.

    NOTE(review): when ``n_ctx // 2 - sample_len`` is 0 (the default sample_len),
    that slice is ``[-0:]`` and keeps the whole prefix.
    """
    options = self.options
    encode = self.tokenizer.encode

    def as_tokens(value):
        # strings are encoded with a leading space; token lists pass through
        return encode(" " + value.strip()) if isinstance(value, str) else value

    context = list(self.sot_sequence)

    prefix = options.prefix
    if prefix:
        prefix_tokens = as_tokens(prefix)
        if self.sample_len is not None:
            keep = self.n_ctx // 2 - self.sample_len
            prefix_tokens = prefix_tokens[-keep:]
        context = context + prefix_tokens

    prompt = options.prompt
    if prompt:
        prompt_tokens = as_tokens(prompt)
        history = prompt_tokens[-(self.n_ctx // 2 - 1) :]
        context = [self.tokenizer.sot_prev] + history + context

    return tuple(context)
507
+
508
def _get_suppress_tokens(self) -> Tuple[int]:
    """Resolve ``options.suppress_tokens`` into the token ids to mask.

    ``options.suppress_tokens`` may be a comma-separated string (e.g. ``"-1"`` or
    ``"-1,50"``) or any iterable of ints. ``None`` or an empty value means no
    user-specified tokens. A ``-1`` entry expands to
    ``tokenizer.non_speech_tokens``. The control tokens transcribe, translate,
    sot, sot_prev and sot_lm are always added, plus the no-speech token when the
    tokenizer has one.

    The caller's value is never modified: a new list is built. The original code
    ``.extend``ed a user-supplied list in place, so it grew with every decoding
    task.

    Returns
    -------
    Tuple[int]
        Sorted, de-duplicated token ids.
    """
    requested = self.options.suppress_tokens

    if requested is None:
        suppress = []
    elif isinstance(requested, str):
        # Empty items (e.g. "" or a trailing comma) are skipped instead of
        # failing in int(""), so an empty string means an empty list.
        suppress = [int(t) for t in requested.split(",") if t.strip()]
    else:
        # Fresh copy: works for tuples/sets/generators and keeps the
        # caller's object untouched.
        suppress = list(requested)

    if -1 in suppress:
        suppress = [t for t in suppress if t >= 0]
        suppress.extend(self.tokenizer.non_speech_tokens)

    suppress.extend(
        [
            self.tokenizer.transcribe,
            self.tokenizer.translate,
            self.tokenizer.sot,
            self.tokenizer.sot_prev,
            self.tokenizer.sot_lm,
        ]
    )
    if self.tokenizer.no_speech is not None:
        # no-speech probability is collected separately
        suppress.append(self.tokenizer.no_speech)

    return tuple(sorted(set(suppress)))
536
+
537
def _get_audio_features(self, mel: mx.array):
    """Return encoder features for *mel*, running the encoder only when needed.

    *mel* is cast to float16 first when ``options.fp16`` is set. Input whose last
    two dimensions already equal ``(n_audio_ctx, n_audio_state)`` is treated as
    encoder output and returned as-is.

    Raises
    ------
    TypeError
        If the features' dtype is not float16 (fp16) / float32 (otherwise).
    """
    half_precision = self.options.fp16
    if half_precision:
        mel = mel.astype(mx.float16)

    dims = self.model.dims
    already_encoded = mel.shape[-2:] == (dims.n_audio_ctx, dims.n_audio_state)
    features = mel if already_encoded else self.model.encoder(mel)

    expected_dtype = mx.float16 if half_precision else mx.float32
    if features.dtype != expected_dtype:
        raise TypeError(
            f"audio_features has an incorrect dtype: {features.dtype}"
        )
    return features
556
+
557
def _detect_language(self, audio_features: mx.array, tokens: np.array):
    """Return ``(languages, language_probs)`` for each input.

    Detection runs when no language was requested or when the task is
    ``"lang_id"``; otherwise the requested language is repeated per input and
    the probabilities are None. When no language was requested, the detected
    language ids are written into *tokens* right after the startoftranscript
    position. That slot holds the "en" placeholder chosen in ``__init__``.
    (The caller in ``run`` passes an mx.array despite the ``np.array``
    annotation.)
    """
    requested = self.options.language
    if requested is not None and self.options.task != "lang_id":
        return [requested] * audio_features.shape[0], None

    lang_tokens, lang_probs = self.model.detect_language(
        audio_features, self.tokenizer
    )
    languages = [max(probs, key=probs.get) for probs in lang_probs]
    if requested is None:
        # write language tokens
        tokens[:, self.sot_index + 1] = np.array(lang_tokens)
    return languages, lang_probs
571
+
572
def _main_loop(self, audio_features: mx.array, tokens: mx.array):
    """Autoregressively sample tokens until all sequences finish or a limit is hit.

    The first step feeds the whole initial context. Later steps feed only the
    last token, because ``self.inference`` carries the decoder's kv-cache.
    Each step is scheduled with ``mx.async_eval`` before the previous step's
    ``completed`` flag is checked, so the next step is computed speculatively.
    When the flag is set, that speculative step is discarded.

    The loop stops when every sequence ends in EOT, when the context grows past
    ``n_ctx``, or after ``sample_len`` steps in total.

    Returns
    -------
    tokens : mx.array, shape = (n_batch, initial_length + n_sampled)
    sum_logprobs : mx.array, shape = (n_batch,)
    no_speech_probs : mx.array, shape = (n_batch,)
        probability of the no-speech token at the startoftranscript position,
        or NaN when the tokenizer has no such token
    """
    n_batch = tokens.shape[0]
    sum_logprobs = mx.zeros(n_batch)

    def _step(inputs, audio_features, tokens, sum_logprobs):
        # One decoding step: forward pass, logit filtering, next-token selection.
        # ``inputs`` is what the decoder sees; ``tokens`` is the full context used
        # by the filters and extended by the decoder strategy.
        pre_logits = self.inference.logits(inputs, audio_features)

        # consider the logits at the last token only
        logits = pre_logits[:, -1]

        # apply the logit filters, e.g. for suppressing or applying penalty to certain tokens
        for logit_filter in self.logit_filters:
            logits = logit_filter.apply(logits, tokens)

        # expand the tokens tensor with the selected next tokens
        tokens, completed, sum_logprobs = self.decoder.update(
            tokens, logits, sum_logprobs
        )
        return tokens, completed, sum_logprobs, pre_logits

    # first step over the full initial context
    tokens, completed, sum_logprobs, pre_logits = _step(
        tokens, audio_features, tokens, sum_logprobs
    )
    if self.tokenizer.no_speech is not None:  # compute no_speech_probs
        # read the no-speech probability at the startoftranscript position
        probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1)
        no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech]
    else:
        no_speech_probs = mx.full(n_batch, mx.nan)
    mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs)

    for i in range(1, self.sample_len):
        inputs = tokens[:, -1:]
        # stop once the context exceeds the decoder's text context length
        if tokens.shape[-1] > self.n_ctx:
            break
        # schedule the next step before checking whether the current one finished
        next_tokens, next_completed, next_sum_logprobs, _ = _step(
            inputs, audio_features, tokens, sum_logprobs
        )
        mx.async_eval(next_completed, next_tokens, next_sum_logprobs)
        if completed:
            # the speculative step above is dropped
            break
        tokens = next_tokens
        completed = next_completed
        sum_logprobs = next_sum_logprobs

    return tokens, sum_logprobs, no_speech_probs
617
+
618
def run(self, mel: mx.array) -> List[DecodingResult]:
    """Decode a batch of Mel spectrogram segments into ``DecodingResult``s.

    Steps:
      1. Encode ``mel`` once.
      2. Build the initial prompt for every segment and optionally detect
         the language.
      3. Replicate each prompt and its audio features ``self.n_group``
         times (best-of-n sampling / beam search).
      4. Run the sampling loop.
      5. Keep the top-ranked candidate of every group.

    Parameters
    ----------
    mel: mx.array
        Batched Mel spectrogram(s); ``mel.shape[0]`` is the number of
        audio segments.

    Returns
    -------
    List[DecodingResult]
        One result per audio segment. When ``self.options.task == "lang_id"``
        only ``audio_features``, ``language`` and ``language_probs`` are set.

    Raises
    ------
    RuntimeError
        If the per-segment result lists end up with different lengths.
    """
    self.inference.reset()
    self.decoder.reset()
    tokenizer: Tokenizer = self.tokenizer
    n_audio: int = mel.shape[0]

    audio_features: mx.array = self._get_audio_features(mel)  # encoder forward pass
    tokens: mx.array = mx.array(self.initial_tokens)
    tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens)))

    # detect language if requested, overwriting the language token
    languages, language_probs = self._detect_language(audio_features, tokens)
    if self.options.task == "lang_id":
        return [
            DecodingResult(
                audio_features=features, language=language, language_probs=probs
            )
            for features, language, probs in zip(
                audio_features, languages, language_probs
            )
        ]

    # repeat tokens and audio features by the group size, for beam search or
    # best-of-n sampling
    if self.n_group > 1:
        n_initial = len(self.initial_tokens)
        tokens = mx.broadcast_to(
            tokens[:, None, :], (n_audio, self.n_group, n_initial)
        )
        # array.reshape takes only the target shape. The former
        # `tokens.reshape(tokens, shape)` passed the array itself as a dimension.
        tokens = tokens.reshape(n_audio * self.n_group, n_initial)
        # Row r of `tokens` now belongs to segment r // n_group. Repeat the
        # audio features the same way so the batches line up. This also keeps
        # the `[:: self.n_group]` slicing below correct when n_audio > 1.
        audio_features = mx.repeat(audio_features, self.n_group, axis=0)

    # call the main sampling loop
    tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

    # take one row per segment, then reshape to (n_audio, n_group, ...)
    audio_features = audio_features[:: self.n_group]
    no_speech_probs = no_speech_probs[:: self.n_group]
    assert audio_features.shape[0] == len(no_speech_probs) == n_audio

    tokens = tokens.reshape(n_audio, self.n_group, -1)
    sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

    # get the final candidates for each group, and slice between the first sampled token and EOT
    tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
    tokens = tokens[..., self.sample_begin :]

    # eval and convert to list
    mx.eval(tokens, sum_logprobs, no_speech_probs)
    tokens = tokens.tolist()
    sum_logprobs = sum_logprobs.tolist()
    no_speech_probs = no_speech_probs.tolist()
    # Truncate each candidate at its first EOT. A candidate without EOT is
    # kept whole instead of making list.index raise ValueError.
    eot = tokenizer.eot
    tokens = [[t[: t.index(eot)] if eot in t else t for t in s] for s in tokens]

    # select the top-ranked sample in each group
    selected = self.sequence_ranker.rank(tokens, sum_logprobs)
    tokens: List[List[int]] = [t[i] for i, t in zip(selected, tokens)]
    texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

    sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
    # The +1 presumably counts the EOT token whose log-prob is part of the
    # sum (TODO confirm against decoder.update).
    avg_logprobs: List[float] = [
        lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
    ]

    fields = (
        texts,
        languages,
        tokens,
        audio_features,
        avg_logprobs,
        no_speech_probs,
    )
    if len(set(map(len, fields))) != 1:
        raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

    return [
        DecodingResult(
            audio_features=features,
            language=language,
            tokens=tokens,
            text=text,
            avg_logprob=avg_logprob,
            no_speech_prob=no_speech_prob,
            temperature=self.options.temperature,
            compression_ratio=compression_ratio(text),
        )
        for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
            *fields
        )
    ]
708
+
709
+
710
def decode(
    model: "Whisper",
    mel: mx.array,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Decode one or more 30-second audio segments given as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        The Whisper model used for decoding.

    mel: mx.array
        A single segment as a 2-D array, or a batch of segments as a 3-D array
        whose first axis indexes segments. A 2-D input gets a leading batch
        axis added before decoding.
        NOTE(review): the inherited docstring said (80, 3000); confirm the
        (frames, mels) axis order against the audio front end.

    options: DecodingOptions
        Settings for decoding 30-second segments.

    **kwargs
        Optional field overrides. When any are given they are applied as
        ``replace(options, **kwargs)`` before decoding.

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        A single ``DecodingResult`` when ``mel`` was 2-D, otherwise a list with
        one ``DecodingResult`` per segment.
    """
    is_single_segment = mel.ndim == 2
    batched_mel = mel[None] if is_single_segment else mel

    effective_options = replace(options, **kwargs) if kwargs else options

    task = DecodingTask(model, effective_options)
    results = task.run(batched_mel)

    if is_single_segment:
        return results[0]
    return results